公司新闻

Adam和AdamW

最近在很多关于大语言模型计算与推理的资源消耗的文章中看到使用AdamW进行语言模型训练时需要的参数量是模型参数量的4倍,因此借机梳理一下Adam与AdamW的计算流程。

首先是Adam[1],给定在迭代步数 t 时模型的参数 \	heta_t 与梯度 g_t ,Adam的计算公式如下:

m_t=\\beta_1m_{t-1}+(1-\\beta_1)g_t\	ag{1}

v_t=\\beta_2v_{t-1}+(1-\\beta_2)g_t^2\	ag{2}

\\hat{m_t}=\\frac{1}{1-\\beta_1^t}m_t\	ag{3}

\\hat{v}_t=\\frac{1}{1-\\beta_2^t}v_t\	ag{4}

\	heta_{t+1}=\	heta_t-\\frac{\\eta}{\\sqrt{\\hat{v}_t}+\\epsilon}\\hat{m}_t\	ag{5} 式(1)用于计算梯度的一阶指数滑动平均,式(2)用于计算梯度的二阶项的指数滑动平均,式(3)与(4)对计算得到的指数滑动平均值进行消偏。式(5)为Adam的更新公式,其可以拆成两部分理解:动量更新与自适应学习率。其中 \\hat{m}_t 对应动量更新部分,这部分比较好理解,使用历史梯度的指数加权平均来更新参数,使得更新的过程相对稳定,防止跳变。第二部分是 \\frac{\\eta}{\\sqrt{v_t}+\\epsilon} 带来的自适应学习率效应,可以直观的理解成,如果当前所处的区域比较平坦(梯度的二阶项很小)则我们可以用较大的学习率来更新,快速走出鞍点,如果当前所处的区域比较陡峭(梯度的二阶项很大),则为了防止梯度爆炸等不稳定的情况发生,我们需要用较小的学习率谨慎地更新。

AdamW[2]相对与Adam的改动十分简单,其将权重衰减项从梯度的计算中拿出来直接加在了最后的权重更新步骤上(图1,式12)。其提出的动机在于:原先Adam的实现中如果采用了L2权重衰减,则相应的权重衰减项会被直接加在loss里,从而导致动量的一阶与二阶滑动平均均考虑了该权重衰减项(图1. 式6),而这影响了Adam的优化效果,而将权重衰减与梯度的计算进行解耦能够显著提升Adam的效果。目前,AdamW现在已经成为transformer训练中的默认优化器了。

图1. AdamW与Adam的区别[1]

从上述的计算步骤中可以看出,Adam和AdamW在反向传播时需要维护的变量分别为原始参数 \	heta_t ,梯度 g_t ,动量 m_t 与二阶动量 v_t ,因此其训练时的显存占用为参数量的4倍。

[1]Kingma, Diederik P., and Jimmy Ba. "Adam: A method for stochastic optimization."arXiv preprint arXiv:1412.6980(2014).

[2]Loshchilov, Ilya, and Frank Hutter. "Decoupled weight decay regularization."arXiv preprint arXiv:1711.05101(2017).

平台注册入口