What should I do when my neural network doesn't learn?

  • I'm training a neural network but the training loss doesn't decrease. How can I fix this?

    I'm not asking about overfitting or regularization. I'm asking about how to solve the problem where my network's performance doesn't improve on the training set.

    This question is intentionally general so that other questions about how to train a neural network can be closed as a duplicate of this one, with the attitude that "if you give a man a fish you feed him for a day, but if you teach a man to fish, you can feed him for the rest of his life." See this Meta thread for a discussion: What's the best way to answer "my neural network doesn't work, please fix" questions?

    If your neural network does not generalize well, see: What should I do when my neural network doesn't generalize well?

    Here's a case where the NN could not make progress: https://youtu.be/iakFfOmanJU?t=144

    Ivanov's blog "Reasons why your Neural Network is not working", especially sections II, III, and IV, could be helpful.

  • Unit Testing Is Your Friend

    There's a saying among writers that "All writing is re-writing" -- that is, the greater part of writing is revising. For programmers (or at least data scientists) the expression could be re-phrased as "All coding is debugging."

    Any time you're writing code, you need to verify that it works as intended. The best method I've ever found for verifying correctness is to break your code into small segments, and verify that each segment works. This can be done by comparing the segment output to what you know to be the correct answer. This is called unit testing. Writing good unit tests is a key piece of becoming a good statistician/data scientist/machine learning expert/neural network practitioner. There is simply no substitute.
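As a minimal sketch of what "comparing the segment output to what you know to be the correct answer" can look like, here is a unit test for a small softmax helper. The function and test names are illustrative assumptions, not taken from any particular project; the expected values are ones that can be worked out by hand.

```python
import math
import unittest


def softmax(logits):
    """Convert a list of real-valued scores into probabilities."""
    # Subtracting the max improves numerical stability without changing the result.
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


class TestSoftmax(unittest.TestCase):
    def test_known_answer(self):
        # Equal scores must give the uniform distribution.
        for p in softmax([2.0, 2.0, 2.0, 2.0]):
            self.assertAlmostEqual(p, 0.25)

    def test_sums_to_one(self):
        self.assertAlmostEqual(sum(softmax([-1.0, 0.0, 3.5])), 1.0)

    def test_large_inputs_do_not_overflow(self):
        # A naive math.exp(1000.0) would raise OverflowError.
        probs = softmax([1000.0, 1000.0])
        self.assertAlmostEqual(probs[0], 0.5)
```

Saved in a file, tests like these can be run with `python -m unittest`.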

    You have to check that your code is free of bugs before you can tune network performance! Otherwise, you might as well be re-arranging deck chairs on the RMS Titanic.

    There are two features of neural networks that make verification even more important than for other types of machine learning or statistical models.

    1. Neural networks are not "off-the-shelf" algorithms in the way that random forest or logistic regression are. Even for simple, feed-forward networks, the onus is largely on the user to make numerous decisions about how the network is configured, connected, initialized and optimized. This means writing code, and writing code means debugging.

    2. Even when neural network code executes without raising an exception, the network can still have bugs! These bugs might even be the insidious kind for which the network will train, but get stuck at a sub-optimal solution, or the resulting network does not have the desired architecture. (This is an example of the difference between a syntactic and semantic error.)

    This Medium post, "How to unit test machine learning code," by Chase Roberts discusses unit-testing for machine learning models in more detail. I borrowed this example of buggy code from the article:

    def make_convnet(input_image):
        net = slim.conv2d(input_image, 32, [11, 11], scope="conv1_11x11")
        net = slim.conv2d(input_image, 64, [5, 5], scope="conv2_5x5")
        net = slim.max_pool2d(net, [4, 4], stride=4, scope='pool1')
        net = slim.conv2d(input_image, 64, [5, 5], scope="conv3_5x5")
        net = slim.conv2d(input_image, 128, [3, 3], scope="conv4_3x3")
        net = slim.max_pool2d(net, [2, 2], scope='pool2')
        net = slim.conv2d(input_image, 128, [3, 3], scope="conv5_3x3")
        net = slim.max_pool2d(net, [2, 2], scope='pool3')
        net = slim.conv2d(input_image, 32, [1, 1], scope="conv6_1x1")
        return net

    Do you see the error? Many of the operations are not actually used: every convolution takes input_image rather than the running net as its input, so each convolution discards everything computed before it. Using this block of code in a network will still train and the weights will update and the loss might even decrease -- but the code definitely isn't doing what was intended. (The author is also inconsistent about using single- or double-quotes but that's purely stylistic.)
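A framework-free sketch of how a test can catch this kind of wiring mistake: the stand-in "layers" below just tag their input with their own name, so the output reveals which layers actually contributed. `make_layer`, `buggy_net` and `fixed_net` are hypothetical names used only for illustration; with a real framework, a comparable check would be that every trainable variable changes after a training step.

```python
def make_layer(name):
    """Return a stand-in 'layer' that appends its name to its input."""
    def layer(x):
        return x + [name]  # builds a new list; the input is left untouched
    return layer


def buggy_net(x):
    # Same mistake as the example above: every layer reads the original input.
    net = make_layer("conv1")(x)
    net = make_layer("conv2")(x)
    net = make_layer("conv3")(x)
    return net


def fixed_net(x):
    # Each layer consumes the previous layer's output.
    net = make_layer("conv1")(x)
    net = make_layer("conv2")(net)
    net = make_layer("conv3")(net)
    return net


buggy_out = buggy_net(["input"])
fixed_out = fixed_net(["input"])
print(buggy_out)  # ['input', 'conv3']
print(fixed_out)  # ['input', 'conv1', 'conv2', 'conv3']
```

A test asserting that the output depends on every layer fails for `buggy_net` and passes for `fixed_net`, even though both run without raising an exception.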

    The most common programming errors pertaining to neural networks are

    • Variables are created but never used (usually because of copy-paste errors);
    • Expressions for gradient updates are incorrect;
    • Weight updates are not applied;
    • Loss functions are not measured on the correct scale (for example, cross-entropy loss can be expressed in terms of probabilities or logits);
    • The loss is not appropriate for the task (for example, using categorical cross-entropy loss for a regression task).
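One way to catch an incorrect gradient expression is to compare it against a finite-difference approximation. The sketch below does this for a single-example logistic model with binary cross-entropy; the model, the weights and the data point are arbitrary choices made for illustration.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def loss(w, x, y):
    """Binary cross-entropy of a logistic model for one example."""
    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def analytic_grad(w, x, y):
    """Hand-derived gradient of the loss with respect to w: (p - y) * x."""
    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
    return [(p - y) * xi for xi in x]


def numeric_grad(f, w, eps=1e-6):
    """Central finite-difference approximation of the gradient of f at w."""
    grad = []
    for i in range(len(w)):
        up = w[:i] + [w[i] + eps] + w[i + 1:]
        down = w[:i] + [w[i] - eps] + w[i + 1:]
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad


# Arbitrary illustrative weights, features and label.
w, x, y = [0.3, -0.7, 0.5], [1.0, 2.0, -1.5], 1
g_analytic = analytic_grad(w, x, y)
g_numeric = numeric_grad(lambda v: loss(v, x, y), w)
max_abs_diff = max(abs(a - n) for a, n in zip(g_analytic, g_numeric))
print(max_abs_diff)  # should be tiny if the hand-derived gradient is right
```

If the analytic expression had a sign error or a missing factor, the difference would be of the same order as the gradient itself rather than close to zero.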

    Crawl Before You Walk; Walk Before You Run

    Wide and deep neural networks, and neural networks with exotic wiring, are the Hot Thing right now in machine learning. But these networks didn't spring fully-formed into existence; their designers built up to them from smaller units. First, build a small network with a single hidden layer and verify that it works correctly. Then incrementally add additional model complexity, and verify that each of those works as well.

    • Too few neurons in a layer can restrict the representation that the network learns, causing under-fitting. Too many neurons can cause over-fitting because the network will "memorize" the training data.

      Even if you can prove that there is, mathematically, only a small number of neurons necessary to model a problem, it is often the case that having "a few more" neurons makes it easier for the optimizer to find a "good" configuration. (But I don't think anyone fully understands why this is the case.) I provide an example of this in the context of the XOR problem here: Aren't my iterations needed to train NN for XOR with MSE < 0.001 too high?.

    • Choosing the number of hidden layers lets the network learn an abstraction from the raw data. Deep learning is all the rage these days, and networks with a large number of layers have shown impressive results. But adding too many hidden layers can risk overfitting or make it very hard to optimize the network.

    • Choosing a clever network wiring can do a lot of the work for you. Is your data source amenable to specialized network architectures? Convolutional neural networks can achieve impressive results on "structured" data sources, such as image or audio data. Recurrent neural networks can do well on sequential data types, such as natural language or time series data. Residual connections can improve deep feed-forward networks.
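Before adding any complexity, it can help to verify the forward pass of the smallest network on inputs whose outputs are known. The sketch below is one way to do that, assuming a single hidden ReLU layer and hand-picked weights that compute XOR exactly; the function names and the `params` layout are illustrative assumptions.

```python
def relu(values):
    return [max(0.0, v) for v in values]


def dense(x, weights, bias):
    """Affine layer: one output per row of `weights`."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, bias)]


def tiny_net(x, params):
    """One hidden ReLU layer followed by a linear output layer."""
    hidden = relu(dense(x, params["W1"], params["b1"]))
    return dense(hidden, params["W2"], params["b2"])


# Hand-picked weights that compute XOR exactly, so the right answers are known.
params = {
    "W1": [[1.0, 1.0], [1.0, 1.0]],
    "b1": [0.0, -1.0],
    "W2": [[1.0, -2.0]],
    "b2": [0.0],
}
outputs = {(a, b): tiny_net([float(a), float(b)], params)[0]
           for a in (0, 1) for b in (0, 1)}
print(outputs)  # {(0, 0): 0.0, (0, 1): 1.0, (1, 0): 1.0, (1, 1): 0.0}
```

Once the forward pass is known to be right, a failure to learn XOR from random initial weights points at the training loop or the optimizer rather than at the model code.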

    Neural Network Training Is Like Lock Picking

    To achieve state of the art, or even merely good, results, you have to set up all of the parts so that they are configured to work well together. Setting up a neural network configuration that actually learns is a lot like picking a lock: all of the pieces have to be lined up just right. Just as it is not sufficient to have a single tumbler in the right place, neither is it sufficient to have only the architecture, or only the optimizer, set up correctly.

    Tuning configuration choices is not really as simple as saying that one kind of configuration choice (e.g. learning rate) is more or less important than another (e.g. number of units), since all of these choices interact with all of the other choices, so one choice can do well in combination with another choice made elsewhere.

    This is a non-exhaustive list of the configuration options which are not also regularization options or numerical optimization options.

    All of these topics are active areas of research.

    Non-convex optimization is hard

    The objective function of a neural network is only convex when there are no hidden units, all activations are linear, and the design matrix is full-rank -- because this configuration is identically an ordinary regression problem.

    In all other cases, the optimization problem is non-convex, and non-convex optimization is hard. The challenges of training neural networks are well-known (see: Why is it hard to train deep neural networks?). Additionally, neural networks have a very large number of parameters, which restricts us to solely first-order methods (see: Why is Newton's method not widely used in machine learning?). This is a very active area of research.

    • Setting the learning rate too large will cause the optimization to diverge, because you will leap from one side of the "canyon" to the other. Setting this too small will prevent you from making any real progress, and possibly allow the noise inherent in SGD to overwhelm your gradient estimates.

    • Gradient clipping re-scales the norm of the gradient if it's above some threshold. I used to think that this was a set-and-forget parameter, typically at 1.0, but I found that I could make an LSTM language model dramatically better by setting it to 0.25. I don't know why that is.

    • Learning rate scheduling can decrease the learning rate over the course of training. In my experience, trying to use scheduling is a lot like regex: it replaces one problem ("How do I get learning to continue after a certain epoch?") with two problems ("How do I get learning to continue after a certain epoch?" and "How do I choose a good schedule?"). Other people insist that scheduling is essential. I'll let you decide.

    • Choosing a good mini-batch size can influence the learning process indirectly, since the gradient estimate from a larger mini-batch will tend to have a smaller variance than the estimate from a smaller mini-batch. You want the mini-batch to be large enough to be informative about the direction of the gradient, but small enough that SGD can regularize your network.

    • There are a number of variants on stochastic gradient descent which use momentum, adaptive learning rates, Nesterov updates and so on to improve upon vanilla SGD. Designing a better optimizer is very much an active area of research. Some examples:

    • When it first came out, the Adam optimizer generated a lot of interest. But some recent research has found that SGD with momentum can out-perform adaptive gradient methods for neural networks. "The Marginal Value of Adaptive Gradient Methods in Machine Learning" by Ashia C. Wilson, Rebecca Roelofs, Mitchell Stern, Nathan Srebro, Benjamin Recht

    • But on the other hand, this very recent paper proposes a new adaptive learning-rate optimizer which supposedly closes the gap between adaptive-rate methods and SGD with momentum. "Closing the Generalization Gap of Adaptive Gradient Methods in Training Deep Neural Networks" by Jinghui Chen, Quanquan Gu

      Adaptive gradient methods, which adopt historical gradient information to automatically adjust the learning rate, have been observed to generalize worse than stochastic gradient descent (SGD) with momentum in training deep neural networks. This leaves how to close the generalization gap of adaptive gradient methods an open problem. In this work, we show that adaptive gradient methods such as Adam, Amsgrad, are sometimes "over adapted". We design a new algorithm, called Partially adaptive momentum estimation method (Padam), which unifies the Adam/Amsgrad with SGD to achieve the best from both worlds. Experiments on standard benchmarks show that Padam can maintain fast convergence rate as Adam/Amsgrad while generalizing as well as SGD in training deep neural networks. These results would suggest practitioners pick up adaptive gradient methods once again for faster training of deep neural networks.
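The optimization knobs listed above (learning rate, gradient clipping, scheduling, mini-batch size, momentum) can be illustrated on a toy problem. This is only a sketch under strong assumptions: the objective is the one-dimensional quadratic f(x) = x**2, the per-example gradients in the mini-batch experiment are simulated as a true value plus unit Gaussian noise, and all function names are hypothetical. Behaviour on a real network will be far less tidy.

```python
import math
import random


def clip_by_norm(grads, max_norm):
    """Rescale `grads` so that its Euclidean norm is at most `max_norm`."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)  # small gradients pass through unchanged
    return [g * max_norm / norm for g in grads]


def step_decay(initial_lr, epoch, drop=0.5, epochs_per_drop=10):
    """Multiply the learning rate by `drop` every `epochs_per_drop` epochs."""
    return initial_lr * drop ** (epoch // epochs_per_drop)


def minimize(lr, steps=50, mu=0.0, max_norm=None, x0=1.0):
    """Run gradient descent on f(x) = x**2 and return the final |x|.

    `mu` is the momentum coefficient (0.0 gives plain gradient descent);
    `max_norm` enables gradient clipping when it is not None.
    """
    x, velocity = x0, 0.0
    for _ in range(steps):
        grad = [2.0 * x]  # gradient of x**2
        if max_norm is not None:
            grad = clip_by_norm(grad, max_norm)
        velocity = mu * velocity - lr * grad[0]
        x += velocity
    return abs(x)


def variance_of_batch_mean(batch_size, trials=2000, seed=0):
    """Variance, across trials, of the mean of `batch_size` noisy gradients."""
    rng = random.Random(seed)
    means = []
    for _ in range(trials):
        # Assumed per-example gradients: true value 1.0 plus unit Gaussian noise.
        batch = [rng.gauss(1.0, 1.0) for _ in range(batch_size)]
        means.append(sum(batch) / batch_size)
    center = sum(means) / trials
    return sum((m - center) ** 2 for m in means) / (trials - 1)


# Learning rate: too large diverges, too small barely moves.
too_large = minimize(lr=1.1)     # each step multiplies x by -1.2
too_small = minimize(lr=0.001)   # each step multiplies x by 0.998
reasonable = minimize(lr=0.3)    # each step multiplies x by 0.4

# Clipping caps the step size, so the too-large rate stays bounded here.
clipped_run = minimize(lr=1.1, max_norm=0.5)

# Momentum versus plain gradient descent at the same small learning rate.
plain_20 = minimize(lr=0.005, steps=20)
momentum_20 = minimize(lr=0.005, steps=20, mu=0.9)

# Step decay: the rate at epochs 0, 9, 10 and 25.
schedule = [step_decay(0.1, epoch) for epoch in (0, 9, 10, 25)]

# Mini-batch size: a larger batch gives a less noisy gradient estimate.
var_small = variance_of_batch_mean(batch_size=4)
var_large = variance_of_batch_mean(batch_size=64)
```

Note that in this toy run clipping only keeps the iterate bounded; it does not make the over-large learning rate converge. The `drop` and `epochs_per_drop` arguments of `step_decay` are exactly the kind of extra choices that scheduling introduces.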


    The scale of the data can make a big difference in training.
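One common way to control the scale of the inputs is to standardize each feature to zero mean and unit variance. The sketch below assumes a single numeric feature column held in a plain list; the function name and the example values are illustrative only.

```python
import statistics


def standardize(column):
    """Rescale a feature column to zero mean and unit (population) variance."""
    mean = statistics.mean(column)
    std = statistics.pstdev(column)
    if std == 0.0:
        return [0.0 for _ in column]  # a constant feature carries no signal
    return [(v - mean) / std for v in column]


raw = [1000.0, 1500.0, 2000.0, 2500.0, 3000.0]  # made-up feature values
scaled = standardize(raw)
print(scaled)
```

In practice the mean and standard deviation should be computed on the training set and then reused, unchanged, for the validation and test sets.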


    Choosing and tuning network regularization is a key part of building a model that generalizes well (that is, a model that is not overfit to the training data). However, at the time that your network is struggling to decrease the loss on the training data -- when the network is not learning -- regularization can obscure what the problem is.

    When my network doesn't learn, I turn off all regularization and verify that the non-regularized network works correctly. Then I add each regularization piece back, and verify that each of those works along the way.

    This tactic can pinpoint where some regularization might be poorly set.
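A sketch of how this "switch everything off, then add the pieces back" routine could be organized when the settings live in a dictionary. The keys (`dropout_rate`, `weight_decay`, `label_smoothing`) and the values are hypothetical; substitute whatever regularizers your own setup exposes.

```python
import copy

# Hypothetical regularization settings and the values that disable them.
REGULARIZATION_OFF = {"dropout_rate": 0.0, "weight_decay": 0.0, "label_smoothing": 0.0}


def without_regularization(config):
    """Return a copy of `config` with every regularizer switched off."""
    stripped = copy.deepcopy(config)
    stripped.update(REGULARIZATION_OFF)
    return stripped


def reintroduce_step_by_step(config):
    """Yield (name, config) pairs that restore one more regularizer each time."""
    trial = without_regularization(config)
    for name in REGULARIZATION_OFF:
        trial = dict(trial, **{name: config[name]})  # new dict at every step
        yield name, trial


config = {"hidden_units": 64, "dropout_rate": 0.5,
          "weight_decay": 1e-4, "label_smoothing": 0.1}
baseline = without_regularization(config)
steps = list(reintroduce_step_by_step(config))
for name, trial in steps:
    print(name, trial)
```

Training once with `baseline` and then once per entry in `steps` shows which regularizer, if any, is the one that stops the training loss from decreasing.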

    Keep a Logbook of Experiments

    When I set up a neural network, I don't hard-code any parameter settings. Instead, I do that in a configuration file (e.g., JSON) that is read and used to populate network configuration details at runtime. I keep all of these configuration files. If I make any parameter modification, I make a new configuration file. Finally, I append as comments all of the per-epoch losses for training and validation.
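A minimal sketch of such a logbook, assuming one JSON file per experiment. Since JSON has no comment syntax, this version stores the per-epoch losses in a `history` field instead of appending them as comments; the file layout, the function names and the loss values are placeholders for illustration, not results from a real run.

```python
import json
import tempfile
from pathlib import Path


def save_experiment(directory, name, config):
    """Write one experiment's settings to its own JSON file."""
    path = Path(directory) / f"{name}.json"
    path.write_text(json.dumps({"config": config, "history": []}, indent=2))
    return path


def log_epoch(path, epoch, train_loss, val_loss):
    """Append one epoch's losses to the experiment's record."""
    record = json.loads(path.read_text())
    record["history"].append(
        {"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss}
    )
    path.write_text(json.dumps(record, indent=2))


with tempfile.TemporaryDirectory() as logbook:
    path = save_experiment(logbook, "run_001",
                           {"learning_rate": 0.01, "hidden_units": 64})
    # Placeholder loss values, for illustration only.
    log_epoch(path, epoch=1, train_loss=2.31, val_loss=2.35)
    log_epoch(path, epoch=2, train_loss=1.87, val_loss=1.99)
    record = json.loads(path.read_text())

print(record["config"], len(record["history"]))
```

In real use the directory would be a permanent one, and each parameter change would get a new file name so that old runs are never overwritten.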

    The reason that I'm so obsessive about retaining old results is that this makes it very easy to go back and review previous experiments. It also hedges against mistakenly repeating the same dead-end experiment. Psychologically, it also lets you look back and observe "Well, the project might not be where I want it to be today, but I am making progress compared to where I was $k$ weeks ago."

    As an example, I wanted to learn about LSTM language models, so I decided to make a Twitter bot that writes new tweets in response to other Twitter users. I worked on this in my free time, between grad school and my job. It took about a year, and I iterated over about 150 different models before getting to a model that did what I wanted: generate new English-language text that (sort of) makes sense. (One key sticking point, and part of the reason that it took so many attempts, is that it was not sufficient to simply get a low out-of-sample loss, since early low-loss models had managed to memorize the training data, so it was just reproducing germane blocks of text verbatim in reply to prompts -- it took some tweaking to make the model more spontaneous and still have low loss.)

    Lots of good advice there. It's interesting how many of your comments are similar to comments I have made (or have seen others make) in relation to debugging estimation of parameters or predictions for complex models with MCMC sampling schemes. (For example, the code may seem to work when it's not correctly implemented.)

    @Glen_b I don’t think coding best practices receive enough emphasis in most stats/machine learning curricula which is why I emphasized that point so heavily. I’ve seen a number of NN posts where OP left a comment like “oh I found a bug now it works.”

    I teach a programming for data science course in python, and we actually do functions and unit testing on the first day, as primary concepts. Fighting the good fight.

    +1 for "All coding is debugging". I am amazed how many posters on SO seem to think that coding is a simple exercise requiring little effort; who expect their code to work correctly the first time they run it; and who seem to be unable to proceed when it doesn't. The funny thing is that they're half right: coding *is* easy - but programming is hard.

    It is a really nice answer. I knew a good part of this stuff; what stood out for me is **Keep a Logbook of Experiments**. It is a really good suggestion: very intuitive, but not a very obvious way to keep track of experiments. Mine was very messy. Thank you @sycorax-says-reinstate-monica. Btw, if I may ask, what do you suggest is the best way to keep the log of experiments: JSON, YAML, or something else? I mean, which one do you use or find most convenient and readable? Please do tell. Thanks :)

    @zeal I use `datascientist` (as in `pip install datascientist`), but the "log book" can be as simple as a table that tracks the model's performance and the hyper-parameters you used and the saved model checkpoint (if you retain model artifacts).

License under CC-BY-SA with attribution

Content dated before 6/26/2020 9:53 AM