Data Leakage in Machine Learning

Photo by Franki Chamaki on Unsplash

I have come across a noticeable pattern on many data science websites, such as Kaggle: someone shares a machine learning model that claims 100% accuracy (or a perfect score on every metric). While scores that approach perfection are possible, they are rare in most real-world problems.

When I looked into these models further, I noticed that in almost every case the scores were the result of some form of data leakage, which the author had either introduced by mistake or simply not recognized.

What is Data Leakage?

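In the real world, a model in production makes predictions on data it has never seen before. The test set exists to mimic that unseen data, so care must always be taken to ensure that no information passes between the test set and the training process. Data leakage is any situation in which a model is trained with information that would not be available at prediction time (such as test-set data or the target itself). It inflates evaluation scores without improving how the model performs on genuinely new data.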

The simplest (and silliest) form of data leakage occurs in a supervised learning problem when the label itself, or a feature derived directly from it, is included among the training features. In this situation, regardless of what other features you include, the leaked label information will give the model (near-)perfect accuracy.

Common Causes of Data Leakage

1. Mishandling Missing Values

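If missing values are imputed on the full dataset before splitting (for example, by filling gaps with the column mean), statistics computed from the test set leak into the training data. To prevent this, split the data into training and test sets first, fit the imputation on the training set only, and then apply those same fitted values to the test set. A pipeline of transformers (for example, scikit-learn's Pipeline with a SimpleImputer step) can be a neat way to do this.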
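A minimal sketch of this idea, assuming scikit-learn: the imputer sits inside a Pipeline, so calling fit on the training split learns the column means from the training rows only, and the test rows are later filled with those training means. The dataset and the roughly 10% missing-value rate are made up for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline

X, y = make_classification(n_samples=300, n_features=5, random_state=0)
rng = np.random.default_rng(0)
X[rng.random(X.shape) < 0.1] = np.nan  # simulate ~10% missing values

# Split BEFORE any imputation. The leaky alternative would be calling
# SimpleImputer().fit_transform(X) on the full dataset first.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0)

pipe = Pipeline([
    ("imputer", SimpleImputer(strategy="mean")),
    ("model", LogisticRegression()),
])
pipe.fit(X_train, y_train)        # imputer means are learned from X_train only
print(pipe.score(X_test, y_test))  # test gaps are filled with the training means
```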

2. Shuffling Time Series Data

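Randomly shuffling time series data before splitting mixes future observations into the training set, so the model effectively learns from data it would not have at prediction time. To prevent this, choose a cut-off point in time: use everything before it for training and everything after it for testing. This way, no future observations leak into the training set.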
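A cut-off split can be sketched with pandas as below. The helper name chronological_split, the column name "date", and the ten-day toy series are hypothetical, chosen only for illustration.

```python
import pandas as pd

def chronological_split(df: pd.DataFrame, time_col: str, cutoff: str):
    """Hypothetical helper: rows before `cutoff` go to train, the rest to test."""
    df = df.sort_values(time_col)
    cutoff_ts = pd.Timestamp(cutoff)
    train = df[df[time_col] < cutoff_ts]
    test = df[df[time_col] >= cutoff_ts]
    return train, test

# Toy daily series, 2023-01-01 through 2023-01-10.
dates = pd.date_range("2023-01-01", periods=10, freq="D")
data = pd.DataFrame({"date": dates, "value": range(10)})

train, test = chronological_split(data, "date", "2023-01-08")
print(len(train), len(test))  # 7 3
```

Unlike a shuffled split, every training timestamp here is strictly earlier than every test timestamp.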

3. Over/Under Sampling Imbalanced Datasets

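When oversampling, it is again important to split the data into training and test sets before resampling, and to resample the training set only. Otherwise, duplicated copies of the same minority-class rows can end up in both sets, leaking test data into the training set and inflating the test score.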
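The order of operations can be sketched as follows, assuming scikit-learn and plain random oversampling via sklearn.utils.resample (a simple stand-in; libraries such as imbalanced-learn offer more elaborate samplers). The 90/10 class imbalance is made up for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

# Imbalanced synthetic data: roughly 90% class 0, 10% class 1.
X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)

# 1. Split first, stratifying so both sets keep the original class ratio.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0)

# 2. Oversample the minority class in the TRAINING set only,
#    drawing with replacement until it matches the majority count.
minority = y_train == 1
X_min_up, y_min_up = resample(
    X_train[minority], y_train[minority],
    replace=True, n_samples=int((~minority).sum()), random_state=0)

X_train_bal = np.vstack([X_train[~minority], X_min_up])
y_train_bal = np.concatenate([y_train[~minority], y_min_up])

# The test set is untouched, so no duplicated row can appear in both sets.
print(np.bincount(y_train_bal), np.bincount(y_test))
```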

Good Habits to Prevent Indirect Data Leakage

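  1. Incorporate a pipeline, so that every preprocessing step (such as imputation or scaling) is fitted on the training data only
  2. Validate using K-fold cross-validation; scores that vary widely across folds, or that look suspiciously high, can point to leakage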
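Both habits can be combined, as in the minimal sketch below, assuming scikit-learn. Because the imputer and scaler live inside the pipeline, cross_val_score refits them on the training portion of each fold instead of on the whole dataset. The synthetic data and the 5% missing-value rate are illustrative assumptions.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_classification(n_samples=500, random_state=0)
rng = np.random.default_rng(0)
X[rng.random(X.shape) < 0.05] = np.nan  # knock out ~5% of values

# Preprocessing is part of the model, so it is refit inside every fold.
pipe = make_pipeline(
    SimpleImputer(strategy="median"),
    StandardScaler(),
    LogisticRegression(max_iter=1000),
)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(pipe, X, y, cv=cv)

# A large spread across folds is worth investigating.
print(scores.round(3), round(scores.mean(), 3), round(scores.std(), 3))
```

For time series data, shuffled folds reintroduce the problem described earlier; scikit-learn's TimeSeriesSplit, which always validates on observations that come after the training ones, is the safer choice there.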