How to choose a predictive model after k-fold cross-validation?

  • I am wondering how to choose a predictive model after doing K-fold cross-validation.

    This may be awkwardly phrased, so let me explain in more detail: whenever I run K-fold cross-validation, I use K subsets of the training data, and end up with K different models.

    I would like to know how to pick one of the K models, so that I can present it to someone and say "this is the best model that we can produce."

    Is it OK to pick any one of the K models? Or is there some kind of best practice involved, such as picking the model that achieves the median test error?

    You will need to repeat 5-fold CV 100 times and average the results to get sufficient precision. And the answer from @bogdanovist is spot on. You can get the same precision of accuracy estimate from the bootstrap with fewer model fits.

    @Frank Harrell, why do you say 100 repetitions are necessary (I usually use 10 repeats of 10-fold)? Is this a rule of thumb, since the OP didn't give any specifics?

    For 10-fold CV it is best to do $\geq 50$ repeats. More repeats will be needed with 5-fold. These are rules of thumb. A single 10-fold CV will give an unstable answer, i.e., repeat the 10 splits and you get enough of a different answer to worry.
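
    As a rough sketch of what such repeated CV looks like in practice (assuming scikit-learn; the dataset and estimator below are only placeholders, not anything from this thread):

    ```python
    from sklearn.datasets import load_breast_cancer
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import RepeatedKFold, cross_val_score
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler

    X, y = load_breast_cancer(return_X_y=True)
    model = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))

    # 10-fold CV repeated 50 times (500 fits in total) gives a far more stable
    # accuracy estimate than a single 10-fold split.
    cv = RepeatedKFold(n_splits=10, n_repeats=50, random_state=0)
    scores = cross_val_score(model, X, y, cv=cv, scoring="accuracy")
    print(f"accuracy: {scores.mean():.3f} +/- {scores.std():.3f}")
    ```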

    Almost an exact duplicate: http://stats.stackexchange.com/questions/11602 with lots of worthy answers. Perhaps these threads should be merged but I am not sure in which direction. Both have accepted answers that are very good. But the other one is older and has more views/upvotes so it might make sense to merge this one into that one.

  • Bogdanovist (accepted answer, 8 years ago)

    I think that you are missing something still in your understanding of the purpose of cross-validation.

    Let's get some terminology straight: generally, when we say 'a model' we refer to a particular method for describing how some input data relates to what we are trying to predict. We don't generally refer to particular instances of that method as different models. So you might say 'I have a linear regression model', but you wouldn't call two different sets of trained coefficients different models. At least not in the context of model selection.

    So, when you do K-fold cross-validation, you are testing how well your model is able to get trained by some data and then predict data it hasn't seen. We use cross-validation for this because if you train using all the data you have, you have none left for testing. You could do this once, say by using 80% of the data to train and 20% to test, but what if the 20% you happened to pick to test contains a bunch of points that are particularly easy (or particularly hard) to predict? We will not have come up with the best possible estimate of the model's ability to learn and predict.

    We want to use all of the data. So to continue the above example of an 80/20 split, we would do 5-fold cross validation by training the model 5 times on 80% of the data and testing on 20%. We ensure that each data point ends up in the 20% test set exactly once. We've therefore used every data point we have to contribute to an understanding of how well our model performs the task of learning from some data and predicting some new data.
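
    A hedged illustration of that 5-fold loop (scikit-learn, with a made-up dataset and a plain linear model standing in for "the model"): every point lands in the held-out 20% exactly once, and only the averaged error is kept.

    ```python
    import numpy as np
    from sklearn.datasets import make_regression
    from sklearn.linear_model import LinearRegression
    from sklearn.metrics import mean_squared_error
    from sklearn.model_selection import KFold

    X, y = make_regression(n_samples=200, n_features=5, noise=5.0, random_state=0)

    fold_errors = []
    for train_idx, test_idx in KFold(n_splits=5, shuffle=True, random_state=0).split(X):
        fold_model = LinearRegression().fit(X[train_idx], y[train_idx])  # train on 80%
        preds = fold_model.predict(X[test_idx])                          # test on the held-out 20%
        fold_errors.append(mean_squared_error(y[test_idx], preds))

    # The five fitted instances are thrown away; the averaged error is what we report.
    print("estimated MSE:", np.mean(fold_errors))
    ```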

    But the purpose of cross-validation is not to come up with our final model. We don't use these 5 instances of our trained model to do any real prediction. For that we want to use all the data we have to come up with the best model possible. The purpose of cross-validation is model checking, not model building.

    Now suppose we have two models, say a linear regression model and a neural network. How can we say which model is better? We can do K-fold cross-validation and see which one proves better at predicting the test set points. But once we have used cross-validation to select the better-performing model, we train that model (whether it be the linear regression or the neural network) on all the data. We don't use the actual model instances we trained during cross-validation for our final predictive model.
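
    A sketch of that "compare with CV, then refit the winner on everything" workflow, assuming scikit-learn (the two candidate methods and the dataset are illustrative):

    ```python
    from sklearn.datasets import make_regression
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import cross_val_score
    from sklearn.neural_network import MLPRegressor

    X, y = make_regression(n_samples=300, n_features=10, noise=10.0, random_state=0)

    candidates = {
        "linear regression": LinearRegression(),
        "neural network": MLPRegressor(hidden_layer_sizes=(50,), max_iter=2000, random_state=0),
    }

    # Score each candidate method with the same CV splits (negated MSE: higher is better).
    cv_scores = {name: cross_val_score(m, X, y, cv=5, scoring="neg_mean_squared_error").mean()
                 for name, m in candidates.items()}

    # Refit the winning *method* on all the data; this, not any fold instance, is the final model.
    best_name = max(cv_scores, key=cv_scores.get)
    final_model = candidates[best_name].fit(X, y)
    ```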

    Note that there is a technique called bootstrap aggregation (usually shortened to 'bagging') that does use model instances produced in a manner similar to cross-validation to build up an ensemble model, but that is an advanced technique beyond the scope of your question here.
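
    Purely for context on that aside (bagging really is beyond the question's scope), a minimal scikit-learn sketch with illustrative defaults; note that here, unlike in CV, the individually fitted instances are kept and their predictions averaged:

    ```python
    from sklearn.datasets import make_regression
    from sklearn.ensemble import BaggingRegressor

    X, y = make_regression(n_samples=300, n_features=10, noise=10.0, random_state=0)

    # 50 trees, each fit on a bootstrap resample of the data; predictions are averaged.
    bagged = BaggingRegressor(n_estimators=50, random_state=0).fit(X, y)
    print(bagged.predict(X[:3]))
    ```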

    I agree with this point entirely and thought about using all of the data. That said, if we trained our final model using the entire data set then wouldn't this result in overfitting and thereby sabotage future predictions?

    No! Overfitting has to do with model complexity; it has nothing to do with the amount of data used to train the model. Model complexity has to do with the method the model uses, not the values its parameters take. For instance, whether to include x^2 coefficients as well as x coefficients in a regression model.
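
    A small illustration of that complexity point (assuming scikit-learn; the synthetic data and degrees are made up): the same amount of data, three methods of increasing complexity, judged by CV.

    ```python
    import numpy as np
    from sklearn.linear_model import LinearRegression
    from sklearn.model_selection import cross_val_score
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import PolynomialFeatures

    rng = np.random.default_rng(0)
    x = rng.uniform(-3, 3, size=(100, 1))
    y = np.sin(x).ravel() + rng.normal(scale=0.3, size=100)

    for degree in (1, 2, 15):  # x only; x and x^2; a deliberately over-complex fit
        model = make_pipeline(PolynomialFeatures(degree), LinearRegression())
        mse = -cross_val_score(model, x, y, cv=5, scoring="neg_mean_squared_error").mean()
        print(f"degree {degree:2d}: CV MSE = {mse:.3f}")
    ```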

    @Bogdanovist: I would rather say that overfitting has to do with having too few training cases for too complex a model. So it (also) has to do with the number of training cases. But having more training cases will reduce the risk of overfitting (for constant model complexity).

    @Bogdanovist I totally agree with your points! Can you recommend some paper (or book chapter) that explains this?

    I believe you've given me the missing piece of the puzzle! Thanks so much!

    @Bogdanovist `For that we want to use all the data we have to come up with the best model possible.` - When doing a grid search with K-fold cross validation, does this mean you would use the best params found by grid search and fit a model on the entire training data, and then evaluate generalization performance using the test set?

    @arun, if you've used k-fold cross-validation and selected the best model with the best parameters and hyperparameters, then after fitting the final model on the training set you don't need to check performance again on a test set. This is because you've already checked how the model with those parameters behaved on unseen data.

    What test metric would you use to compare two different models, e.g., regression and a neural network? If the regression minimizes least-squares error, then won't it presumably do better than the neural network on the test data if you use a squared-error test metric?

    In my opinion, "For that we want to use all the data we have to come up with the best model possible." is exactly why we would have an overfitting problem.

    If we train on the whole data set, what do we then report as the test or performance error? Do we report the error obtained during cross-validation as the performance measure of the selected model?

    I believe the question raised by @arun is meaningful, but I disagree with the answer: cross-validation helps us choose among a set of models, and it should be performed on a training set. To evaluate whether the best model we have chosen is predictive, it should be tested on an independent set, the test set, on which cross-validation was not performed. See, for example: http://scikit-learn.org/stable/modules/cross_validation.html
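
    A hedged sketch of the workflow @arun asked about, in the spirit of the scikit-learn docs linked above (the estimator, parameter grid and dataset are illustrative): do the grid search with CV on the training set only, let it refit the best setting on all of the training data, then report performance on a test set the search never touched.

    ```python
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import GridSearchCV, train_test_split
    from sklearn.svm import SVC

    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

    param_grid = {"C": [0.1, 1, 10], "gamma": ["scale", 0.01]}
    search = GridSearchCV(SVC(), param_grid, cv=5)  # the CV happens inside the training set
    search.fit(X_train, y_train)                    # refit=True (default): best params refit on all of X_train

    print("CV estimate of accuracy:", search.best_score_)
    print("independent test-set check:", search.score(X_test, y_test))
    ```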

    @Bogdanovist Thanks for the useful answer. In some applications a validation set is crucial during model training, for example for early stopping. In such cases, the final training (on the whole data) may not work properly for lack of a validation set. This is the case, for example, when training a deep neural network or a GBM with early stopping. So my question is: how do we train the model on the whole data when we no longer have any criterion for deciding where to stop training (which can lead to overfitting)?

    @Bogdanovist Secondly, I have seen cases where the test-set predictions from the model trained on each fold are averaged together to get the final predictions. Based on your answer, why doesn't this make sense?

    I have the same question as @M.Reza. If we use the whole data set to re-train the model, we no longer have a validation set. In some cases, like gradient boosting, the model will overfit badly without a validation set. Any suggestions? Thanks

    Bogdanovist's answer is correct for the general case; the CV average error is the best estimate of the model trained on all data. There's no reason to think the result on a validation set is a better estimate, and in fact it may be less accurate for smaller datasets due to high variance. The issues raised by @M.Reza and Catbuilts are the exception to the rule, only because the logic of early-stopping methods requires an independent validation set.
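
    One common workaround for the early-stopping cases raised by @M.Reza and Catbuilts, sketched here with a scikit-learn gradient-boosting model (the data and settings are illustrative, not from the thread): use the CV folds themselves to choose the number of boosting rounds, then refit on all the data with that count fixed, so no separate validation set is needed for the final fit.

    ```python
    import numpy as np
    from sklearn.datasets import make_regression
    from sklearn.ensemble import GradientBoostingRegressor
    from sklearn.metrics import mean_squared_error
    from sklearn.model_selection import KFold

    X, y = make_regression(n_samples=500, n_features=20, noise=10.0, random_state=0)

    best_rounds = []
    for train_idx, test_idx in KFold(n_splits=5, shuffle=True, random_state=0).split(X):
        gbm = GradientBoostingRegressor(n_estimators=500, learning_rate=0.05, random_state=0)
        gbm.fit(X[train_idx], y[train_idx])
        # staged_predict yields predictions after each boosting round; pick the
        # round with the lowest error on this fold's held-out data.
        fold_errors = [mean_squared_error(y[test_idx], pred)
                       for pred in gbm.staged_predict(X[test_idx])]
        best_rounds.append(int(np.argmin(fold_errors)) + 1)

    # Refit on *all* the data with the round count fixed in advance (no early stopping).
    final_model = GradientBoostingRegressor(n_estimators=int(np.mean(best_rounds)),
                                            learning_rate=0.05, random_state=0)
    final_model.fit(X, y)
    ```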

    @Bogdanovist Can we consider different hyperparameter settings as different models? E.g., NNs with different numbers of hidden layers as different models, or decision trees with different depths as different models?

    Then what do we generally mean by the term *hypothesis*?
