overfitting in statistics and machine learning (part one)

Overfitting is a common risk when designing statistical and machine-learning models. Here I give a brief demonstration of overfitting in action, using simple regression models. A later post will more rigorously address how to quantify and avoid overfitting.

We start by sampling data from the process

y = x + ε,   ε ~ N(0, 0.2²)

using the R code:

set.seed(100)
x <- seq(0, 1, 0.1)                # 11 evenly spaced points in [0, 1]
noise <- rnorm(length(x), 0, 0.2)  # Gaussian noise with mean 0 and sd 0.2
y <- x + noise                     # sampled observations of the process

Then we produce a linear model of the sampled data:

df <- data.frame(x=x, y=y)
order.1.lm <- lm(y ~ x, data=df)
plot(x, y, main="Simple Regression Model on Simulated Data")
abline(0, 1, col="magenta")
abline(order.1.lm$coefficients, col="blue")
legend("topleft", c("simulated data", "y=x", "regression line"), fill=c("black", "magenta", "blue"))

Plotting this result, we see that the regression line and the true line defined by the process largely agree:

[Figure: Simple Regression Model on Simulated Data]

Qualitatively, we can argue from this plot that the first-order regression model has “learned” the underlying process y=x.
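
For a slightly more quantitative check, here is a minimal sketch (assuming the model object fitted above) that inspects the estimated coefficients, which should lie close to the true intercept 0 and slope 1:

# Estimated intercept and slope; for the true process these are 0 and 1
coef(order.1.lm)

# R-squared gives a rough sense of how much variation the fitted line explains
summary(order.1.lm)$r.squared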

Suppose, however, that we are not satisfied with this outcome and suspect we can obtain a better model by adding x³ and x⁴ terms to the regression:

order.3_4.lm <- lm(y ~ x + I(x^3) + I(x^4), data=df)
x_pred <- seq(-1, 1.1, 0.01)
df_pred <- data.frame(x=x_pred)
prediction <- predict(order.3_4.lm, df_pred)
plot(x, y, main="Overfit Model on Simulated Data")
abline(0, 1, col="magenta")
lines(x_pred, prediction, col="blue")
legend("topleft", c("simulated data", "y=x", "regression curve"), fill=c("black", "magenta", "blue"))

The resulting plot (below) shows that the model has “learned” some of the features of this particular sample rather than generalizing to the underlying process y=x. Between x=0.6 and x=1.0 the model chooses values lower than the correct y=x line, because the sampled points there fall below that line. Similarly, between x=0.1 and x=0.6 the model chooses values higher than the correct y=x line, because the sampled points there fall above it:

[Figure: Overfit Model on Simulated Data]

Consequently, this model is overfit since it models the noise in the sample rather than just the underlying process y=x.
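
One way to see this numerically, as a rough sketch using the two fits above: because the higher-order model contains the first-order model as a special case, its least-squares fit achieves a sum of squared errors on the training sample that is at least as low, even though it describes the underlying process worse:

# In-sample sum of squared errors for both models
sse.order.1   <- sum(residuals(order.1.lm)^2)
sse.order.3_4 <- sum(residuals(order.3_4.lm)^2)
c(first.order = sse.order.1, higher.order = sse.order.3_4)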

Overfit models generalize poorly. If we draw 11 new data points from the same process (using a different random seed) and plot them against the overfit model's curve, we see that the curve does not reflect the new points well:

set.seed(400)
# Set up the axes without drawing the original sample points
plot(x, y, type='n', main="Poor Generalization of 4th-Order Model")
xg <- seq(0, 1, 0.1)
noise <- rnorm(length(xg), 0, 0.2)
yg <- xg + noise
# Plot the 11 new points against the previously fitted curve
points(xg, yg)
abline(0, 1, col="magenta")
lines(x_pred, prediction, col="blue")
legend("topleft", c("simulated data", "y=x", "regression curve"), fill=c("black", "magenta", "blue"))

[Figure: Poor Generalization of 4th-Order Model]

For the points at x={0.3, 0.4, 0.5}, the y values fall substantially below the line y=x, while the model predicts values above it. Similarly, the y values at x={0.7, 0.8} fall on the opposite side of y=x from where the model predicts they should.
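
To make the comparison concrete, here is a minimal sketch (assuming the objects defined above) that computes each model's mean squared error on the new sample; we would expect the overfit model to show the larger out-of-sample error:

# Out-of-sample mean squared error on the new points (xg, yg)
df_new <- data.frame(x = xg)
mse.order.1   <- mean((yg - predict(order.1.lm,   df_new))^2)
mse.order.3_4 <- mean((yg - predict(order.3_4.lm, df_new))^2)
c(first.order = mse.order.1, higher.order = mse.order.3_4)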
