k-fold cross validation example in R
We’ll work with 173 simulated observations on variables \(y\), \(x\), and \(z\). The two models we’re considering are \[y_{t} = \alpha_{x} + \beta_{x} x_{t} + \varepsilon_{xt}\] \[y_{t} = \alpha_{z} + \beta_{z} z_{t} + \varepsilon_{zt}\]
We’re going to use k-fold cross validation with five folds to determine which of the above two models is the preferred specification. First generate the data:
```r
set.seed(460)
y <- rnorm(173)
x <- rnorm(173)
z <- rnorm(173)
```
One of the more challenging parts of this for an inexperienced programmer is splitting the sample. This is complicated by the possibility that the total number of observations is not divisible by the number of folds. In this example, we have 173 observations and 5 folds. Three folds will have 35 observations and two will have 34.
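To see where those fold sizes come from, here's the arithmetic as a quick check (`n` and `k` are illustrative names, not part of the function below):

```r
n <- 173
k <- 5
n %/% k  # 34, the minimum fold size (integer division)
n %% k   # 3, the number of folds that get one extra observation
# 3 folds of 35 plus 2 folds of 34: 3*35 + 2*34 = 173
```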
Investing in understanding recursive functions pays off for this problem, because it’s nearly trivial if you make use of recursion:
```r
obs.numbers <- function(n, k) {
  n.min <- n %/% k    # minimum number of observations in a fold
  n.max <- n.min + 1  # fold size when a fold gets an extra observation
  aux <- function(indexValues, extra, result=list()) {
    if (length(indexValues) > 0) {
      if (extra > 0) {
        # This fold gets n.max observations
        result[[length(result)+1]] <- indexValues[1:n.max]
        Recall(indexValues[-(1:n.max)], extra-1, result)
      } else {
        # No extra observations left; this fold gets n.min
        result[[length(result)+1]] <- indexValues[1:n.min]
        Recall(indexValues[-(1:n.min)], 0, result)
      }
    } else {
      # Every observation number has been assigned to a fold
      return(result)
    }
  }
  return(aux(1:n, n %% k))
}
```
`n.min` is the minimum number of observations in every fold; `%/%` does integer division.

Here's a little explanation if you're not familiar with recursive functions. `aux` is a recursive function that takes the vector of all observation index numbers (1 through 173 in this example). Each time through, it pulls out the observation numbers associated with the next fold and removes them from `indexValues`. `extra` is the number of folds that get an extra observation.
- If `indexValues` has elements and `extra` is positive, the next fold consists of the first `n.max` elements of `indexValues`. Save those values as an element in `result` and remove them from `indexValues`. Decrement `extra`.
- If `indexValues` has elements and `extra` is zero, the next fold consists of the first `n.min` elements of `indexValues`. Save those values as an element in `result` and remove them from `indexValues`.
- If `indexValues` has no elements, we've assigned all observation numbers to a fold, so return `result`.
You can test the `obs.numbers` function. For instance, with 10 observations and 3 folds, `obs.numbers(10, 3)` gives:
```
[[1]]
[1] 1 2 3 4

[[2]]
[1] 5 6 7

[[3]]
[1] 8 9 10
```
As expected, the first fold is observations 1 through 4, the second is 5 through 7, and the last is 8 through 10.
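As an aside, you don't strictly need recursion here; a base R idiom gives the same contiguous folds, since sorting the recycled fold labels puts the larger folds first (a sketch for comparison, not the approach used below):

```r
# Same folds as obs.numbers(10, 3): labels 1, 2, 3 recycled over the
# index vector, then sorted so the folds with an extra element come first
split(1:10, sort(rep(1:3, length.out = 10)))
```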
Here's a function that does the MSE calculation, where you pass in the fold number for the validation sample and the output of the call to `obs.numbers`:
```r
cv.mse <- function(fold, obs) {
  # Training data: everything outside the validation fold
  y.tilde <- y[-obs[[fold]]]
  x.tilde <- x[-obs[[fold]]]
  z.tilde <- z[-obs[[fold]]]
  # Validation data: the observations in this fold
  y.fold <- y[obs[[fold]]]
  x.fold <- x[obs[[fold]]]
  z.fold <- z[obs[[fold]]]
  # Fit each model on the training data only, then predict the held-out fold
  fit1 <- lm(y.tilde ~ x.tilde)
  pred1 <- predict(fit1, data.frame(x.tilde = x.fold))
  mse1 <- mean((y.fold - pred1)^2)
  fit2 <- lm(y.tilde ~ z.tilde)
  pred2 <- predict(fit2, data.frame(z.tilde = z.fold))
  mse2 <- mean((y.fold - pred2)^2)
  return(list(with.x=mse1, with.z=mse2))
}
```
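Before looping over every fold, it's worth a quick spot check on a single fold (illustrative only; the actual MSE values depend on the simulated data):

```r
# MSEs with fold 1 held out as the validation sample;
# returns a two-element list with components with.x and with.z
cv.mse(1, obs.numbers(173, 5))
```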
We're done with the hard parts. Call `cv.mse` for each of the five folds and save the output in a list:
```r
mse.values <- lapply(1:5, cv.mse, obs=obs.numbers(173, 5))
```
This is a utility function that will be added to tstools at some point. We need it to pull out the individual pieces of `mse.values`:
```r
listToVector <- function(obj, name) {
  # Extract element `name` from each list item and return it as a vector
  return(sapply(obj, function(z) { z[[name]] }))
}
```
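Incidentally, you can get the same vector without a helper by handing the `[[` extractor to `sapply` directly (an equivalent aside, not what the code below uses):

```r
# Equivalent to listToVector(mse.values, "with.x")
sapply(mse.values, "[[", "with.x")
```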
The output we're after:

```r
print(mean(listToVector(mse.values, "with.x")))
print(mean(listToVector(mse.values, "with.z")))
```
Since the MSE for the model with x is (slightly) lower, that’s the model chosen by five-fold cross validation.