Difference between neural net weight decay and learning rate

  • In the context of neural networks, what is the difference between the learning rate and weight decay?

  • mrig (Correct answer, 8 years ago)

    The learning rate is a parameter that determines how much an update step changes the current value of the weights, while weight decay is an additional term in the weight update rule that causes the weights to decay exponentially toward zero if no other update is scheduled.

    So let's say that we have a cost or error function $E(\mathbf{w})$ that we want to minimize. Gradient descent tells us to modify the weights $\mathbf{w}$ in the direction of steepest descent in $E$: \begin{equation} w_i \leftarrow w_i-\eta\frac{\partial E}{\partial w_i}, \end{equation} where $\eta$ is the learning rate, and if it's large you will have a correspondingly large modification of the weights $w_i$ (in general it shouldn't be too large, otherwise you'll overshoot the local minimum in your cost function).
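
    For concreteness, here is a minimal NumPy sketch of this plain gradient-descent step. The quadratic cost, the random data, and the names `grad_E` and `eta` are illustrative assumptions, not something from the question:

    ```python
    import numpy as np

    # Hypothetical cost E(w) = ||Xw - y||^2 / (2n); grad_E is its gradient.
    def grad_E(w, X, y):
        n = len(y)
        return X.T @ (X @ w - y) / n

    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 5))
    y = X @ np.array([1.0, -2.0, 0.5, 0.0, 3.0]) + 0.1 * rng.normal(size=100)

    eta = 0.1           # learning rate: scales how far each step moves the weights
    w = np.zeros(5)
    for _ in range(1000):
        w -= eta * grad_E(w, X, y)   # w_i <- w_i - eta * dE/dw_i
    ```

    A larger `eta` takes bigger steps and risks overshooting the minimum; a smaller one converges more slowly.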

    In order to effectively limit the number of free parameters in your model and so avoid over-fitting, it is possible to regularize the cost function. An easy way to do that is by introducing a zero-mean Gaussian prior over the weights, which is equivalent to changing the cost function to $\widetilde{E}(\mathbf{w})=E(\mathbf{w})+\frac{\lambda}{2}\|\mathbf{w}\|^2$. In practice this penalizes large weights and effectively limits the freedom in your model. The regularization parameter $\lambda$ determines how you trade off the original cost $E$ against the penalty on large weights.

    Applying gradient descent to this new cost function we obtain: \begin{equation} w_i \leftarrow w_i-\eta\frac{\partial E}{\partial w_i}-\eta\lambda w_i. \end{equation} The new term $-\eta\lambda w_i$ coming from the regularization causes the weight to decay in proportion to its size.
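
    As a sketch of this update (reusing the hypothetical `grad_E`, `X`, `y`, `eta`, and `w` from the snippet above, with an assumed value for `lam`), the same step can be written either as the gradient of the penalized cost or as an explicit decay of the weights; for plain gradient descent the two forms are identical:

    ```python
    lam = 0.01   # regularization strength lambda (assumed value)

    # Form 1: gradient of the penalized cost E~(w) = E(w) + (lambda/2) * ||w||^2
    w = np.zeros(5)
    for _ in range(1000):
        w -= eta * (grad_E(w, X, y) + lam * w)

    # Form 2: the same step rearranged as "weight decay":
    # shrink the weights by (1 - eta*lambda), then apply the plain gradient step
    w = np.zeros(5)
    for _ in range(1000):
        w = (1 - eta * lam) * w - eta * grad_E(w, X, y)
    ```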

    Thanks for the useful explanation. A question: the "nnet" R package has a parameter called "decay" that is used when training the network. Do you know whether it corresponds to your $\lambda$ or to $\eta\lambda$?

    I would also add that weight decay is the same thing as L2 regularization, for those who are familiar with the latter.

    @Sergei please no, stop spreading this misinformation! This is only true in the very special case of vanilla SGD. See the paper "Fixing Weight Decay Regularization in Adam".

    To clarify: at the time of writing, the PyTorch docs for Adam use the term "weight decay" (parenthetically called "L2 penalty") to refer to what I think those authors call L2 regularization. If I understand correctly, this answer refers to SGD without momentum, where the two are equivalent.
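
    To make the distinction in these comments concrete, a hedged PyTorch sketch (the model, learning rates, and `lam` are placeholder choices): `torch.optim.Adam`'s `weight_decay` argument adds $\lambda w$ to the gradient before the adaptive step (the L2-penalty formulation), whereas `torch.optim.AdamW` applies decoupled weight decay as in the paper cited above; with vanilla SGD and no momentum the two formulations coincide.

    ```python
    import torch

    model = torch.nn.Linear(10, 1)   # placeholder model
    lam = 1e-2                       # weight-decay / L2 strength (assumed value)

    # L2 penalty: lam * w is added to the gradient, then Adam's adaptive step is applied
    opt_adam = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=lam)

    # Decoupled weight decay: the weights are shrunk separately from the Adam step
    opt_adamw = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=lam)

    # Vanilla SGD without momentum: L2 penalty and weight decay give the same update
    opt_sgd = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=lam)
    ```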

License under CC-BY-SA with attribution

