On Layer Normalization in the Transformer Architecture
本文介绍了Pre-LN, 将归一化层放置在残差分支,以降低训练初始状态的训练梯度爆炸的现象。通过Post-LN架构进行训练刚需Warm-up(即通过初始降低学习率的方式进行训练), 本文提出的Pre-LN通过迁移LN层位置的方式降低了整体梯度的稳定性与相对大小。将模型从Warm-up 中解脱出来。
Layer Normalization的作用
LN(x)=γ⊙σ2+εx−μ+β
其中 γ,β 是可学习的参数。
整个 LN层的作用可视为一个归一化与一个仿射变换作用,内层归一化可表示为
Normal:x→σ2+εx−μ
归一化的 Normal(x)变为期望0, 方差1的标准向量
Var(Normal(x))E(Normal(x))=σ2+εσ2→1=0
可学习的参数 γ,β 能改变向量的整体期望与均值以增强 LN 层的调节能力。
Post-Layer Normalization 的 梯度爆炸与 Warm-up
本文对于Multi Head Attention 的梯度阶估计过程提出了一个简化计算的模型,再通过实验论证假设对于完整的MHA Residue Flow 也成立。
初始权重使用Xavier初始化,每个权重w满足
Var(w)=nin+nout2
Post-LN 层的一次前向传播的公式
{x~tpost=LN(xtpost+MHA(xtpost))xt+1post=LN(x~tpost+FFN(x~tpost))
Pre-LN 层的一次前向传播的公式
{x~tpre=xtpre+MHA(LN(xtpre))xt+1pre=x~tpre+FFN(LN(x~tpre))
记
JLN(x)=∂x∂LN(x)
为 LN层的Jacobian 矩阵
则Post-LN满足
dx~tpost=JLN(xtpost+MHA(xtpost))⋅(dxtpost+dMHA(xtpost))=JLN(xtpost+MHA(xtpost))⋅(I+JMHA(xtpost))⋅dxtpost
dxt+1post=JLN(x~tpost+FFN(x~tpost))⋅(I+JFFN(x~tpost))dx~tpost
∂xtpost∂xt+1post=JLN(x~tpost+FFN(x~tpost))⋅(I+JFFN(x~tpost))⋅JLN(xtpost+MHA(xtpost))⋅(I+JMHA(xtpost))
Pre-LN满足
dx~tpredxt+1pre=(I+JMHA(LN(xtpre))⋅JLN(xtpre))dxtpre=(I+JFFN(LN(x~tpre))⋅JLN(x~tpre))dx~tpre
∂xtpre∂xt+1pre=(I+JFFN(LN(x~tpre))⋅JLN(x~tpre))⋅(I+JMHA(LN(xtpre))⋅JLN(xtpre))
MHA贡献的梯度流动
基于本文关于MHA的假定,有WQ=WK=0, 因此单一Attention头的输出为
h=Softmax(dQKT)⋅V=Softmax(0)⋅X⋅WV=n1X⋅WV=n1j=1∑nxjwVj
MHA(X)=Concat(h1,⋯,hn)⋅WO
计算MHA的微分
dMHA(X)=dConcat(h1,⋯,hn)⋅WO+Concat(h1,⋯,hn)⋅dWO=Concat(dh1,⋯,dhn)⋅WO+Concat(h1,⋯,hn)⋅dWO=Concat(dX⋅WVi)⋅WO=dX⋅Concat(WVi)⋅WO=n1j=i∑n(dxj)⋅Concat(WVi)⋅WO:=n1j=i∑n(dxj)WV,l
其中 WV,l 是等效的随机矩阵
WV,l=Concat(WVi)⋅WO
对应Jacobian矩阵为
JMHA=n111T⊗WV,l
残差梯度流为
I+JMHA=I+n111T⊗WV,l
LN层Jacobian矩阵谱范数的阶估计
根据上文推导,需要计算LN层的Jacobian矩阵的大小。在此我们只考虑未仿射变换的归一化映射的梯度,因为仿射变换后只需要进行梯度的线性缩放。
LN(x)=σx−μ
取无偏向量
y=x(I−n11T1)=x1−n1∑xix2−n1∑xi⋮xn−n1∑xi
其中
∥y∥=n1i∑(xi−μ)2=n1i∑(xi2−2μxi+μ2)=n1i∑xi2−μ2
有O(∥y∥)=O(∥x∥)
因此
LN(x)=n1∑yj2y
∂yj∂LN(x)i=n∑yj2δi,j∑yj2−yi∑yj2yj=∥y∥n(δi,j−∥y∥2yiyj)
因此
JLN(x)=∂y∂LN(x)∂x∂y=∥y∥n(I−∥y∥2yiyj)(I−1T1)
∥JLN(x)∥=O(∥y∥n)=O(∥x∥n)
基于以上结果进行主定理的叙述
Definition 1.1: 随机变量的 (ε,δ)-Bounded
对于实随机变量 Z≥0, 如果Z满足
P(μZ−μ≤ε)≥1−δ
也即
P(μZ−μ≥ε)≤δ
其中ε>0,0<δ<1, 则称随机变量Z是 (ε−δ)-Bounded
这个结论和Chebyshev不等式的结构相似, Chebyshev不等式能说明对方差有界随机变量都是 (ε,ε2σ2)-Bounded的
整体损失函数梯度谱范数
Post-LN架构的损失函数定义为顶部第L层的交叉熵
L(xL+1,ipost)=−logsoftmaxyi(WembxL+1,ipost)=−log(P(Softmax(WembxL+1,ipost)∣yi))
Pre-LN架构尾部多一个LN块,损失函数为
L(xfinal,ipre)=−logsoftmaxyi(Wembxfinal,ipre)=−log(P(Softmax(Wembxfinal,ipre)∣yi))
其中
xfinal,ipre=LN(xL+1,ipre)
Theorem 1. 假设 ∥WL+1,ipost∥, ∥WL+1,ipre∥ 均为(ε,δ)-Bounded的。 则Post-LN与Pre-LN结构的梯度谱范数满足
⎩⎨⎧∂W2,L∂L~(xL+1post)F=O(dlnd)∂W2,L∂L~(xfinalpre)F=O(dLlnd)
其中 W2,L 是FFN中的参数矩阵
Proof:
由链式法则
∂W2,L∂L~(xL+1post)=∂xL+1post∂L~(xL+1post)(k=l∏L∂xkpost∂xk+1post)W2,L∂xlpost
∂xL+1post∂L~ 是有界的,因为 xL+1post 是 (ε,δ)-Bounded的
∂xL+1post∂L~=P(Softmax(WembxL+1post∣yi))⋅∂xL+1post∂P(Softmax(WembxL+1post∣yi))=O(1)
(此处略相关递推的阶估计,上文有相关Jacobian矩阵,只需进行估阶即可)关键在于
⎩⎨⎧Post-LN:Pre-LN:JLN(xL+1post)2=O(∥xL+1post∥2n)=O(1)JLN(xfinalpre)2=O(∥xfinalpre∥2n)=O(L1)
Theorem 1 的结论证明了:在初始化时刻,Post-LN 的梯度规模是常数阶,这意味着它与模型深度 L 无关,无法感知并抑制深层带来的不稳定因素;而 Pre-LN 的梯度规模具有 O(L1) 的衰减性,能随着模型深度的增加降低初始梯度强度,减弱了对 warmup 的依赖。