[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/normalization/deep_norm/__init__.py)
This is a PyTorch implementation of DeepNorm from the paper [DeepNet: Scaling Transformers to 1,000 Layers](https://arxiv.org/abs/2203.00555).
The paper proposes a method to stabilize extremely deep transformers through a new normalization function, which replaces LayerNorm, and a weight initialization scheme. This combines the performance of Post-LayerNorm with the stability of Pre-LayerNorm. Transformers with DeepNorm are supposed to be stable even without a learning rate warm-up.
The paper first shows that, for the same input, layer outputs change gradually during stable training, whereas during unstable training they change rapidly in the initial training steps. Training is stable when weights are initialized to small values and a learning rate warm-up is used. The authors use this idea of keeping the changes to layer outputs small to derive the new normalization and weight initialization scheme.
Usually, the weights are initialized with Xavier or Kaiming initialization. This paper scales the weights (i.e., sets the initialization gain) by a constant $\beta$ that depends on the size of the transformer.
DeepNorm suggests scaling the weights of the two linear transforms in the Feed-Forward Network, the value projection transform, and the output projection transform of the attention layer. The weights of these transforms are scaled by (have a gain equal to) $\beta$.
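As a rough sketch (not part of the original code), there are two equivalent ways to apply this gain, assuming the linear layers start from a gain-1 Xavier initialization:

```python
import torch
from torch import nn

beta = (8 * 64) ** -0.25  # e.g. Dec-beta for a 64-layer decoder-only model
linear = nn.Linear(512, 2048)

# Option 1: initialize directly with gain = beta
nn.init.xavier_normal_(linear.weight, gain=beta)

# Option 2: scale an existing gain-1 Xavier initialization by beta,
# which is what DeepNormTransformerLayer below does to its sub-layers
nn.init.xavier_normal_(linear.weight)
with torch.no_grad():
    linear.weight *= beta
```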
The scaling is implemented in the `DeepNormTransformerLayer` initialization below.

DeepNorm normalizes the output of each sub-layer as

$$x_{l+1} = \mathrm{LN}\big(\alpha x_l + G_l(x_l, \theta_l)\big)$$

where $\alpha$ is a constant that depends on the depth of the transformer, $\mathrm{LN}$ is Layer Normalization, and $G_l(x_l, \theta_l)$ is the function of the $l$-th transformer sub-layer (FFN or attention).
This function is used to replace Post-LayerNorm.
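For comparison, Post-LayerNorm computes

$$x_{l+1} = \mathrm{LN}\big(x_l + G_l(x_l, \theta_l)\big)$$

so DeepNorm reduces to Post-LayerNorm when $\alpha = 1$.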
| Type | Encoder only | Decoder only | Enc-Dec |
|------|--------------|--------------|---------|
| Enc-$\alpha$ | $(2N)^{\frac{1}{4}}$ | − | $0.81\,(N^4 M)^{\frac{1}{16}}$ |
| Enc-$\beta$ | $(8N)^{-\frac{1}{4}}$ | − | $0.87\,(N^4 M)^{-\frac{1}{16}}$ |
| Dec-$\alpha$ | − | $(2M)^{\frac{1}{4}}$ | $(3M)^{\frac{1}{4}}$ |
| Dec-$\beta$ | − | $(8M)^{-\frac{1}{4}}$ | $(12M)^{-\frac{1}{4}}$ |
where $N$ is the number of layers in the encoder and $M$ is the number of layers in the decoder.
Refer to the paper for derivation.
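Here is a small helper, not part of the original implementation (the function name is illustrative), that evaluates the table above for a given depth:

```python
def deep_norm_coefficients(n_encoder_layers: int = 0, n_decoder_layers: int = 0):
    """Return ((enc_alpha, enc_beta), (dec_alpha, dec_beta)) from the table above.

    An entry is None when the architecture lacks that stack.
    """
    N, M = n_encoder_layers, n_decoder_layers
    if N and not M:  # encoder-only
        return ((2 * N) ** 0.25, (8 * N) ** -0.25), None
    if M and not N:  # decoder-only
        return None, ((2 * M) ** 0.25, (8 * M) ** -0.25)
    # encoder-decoder
    enc = (0.81 * (N ** 4 * M) ** (1 / 16), 0.87 * (N ** 4 * M) ** (-1 / 16))
    dec = ((3 * M) ** 0.25, (12 * M) ** -0.25)
    return enc, dec
```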
Here is an experiment implementation that uses DeepNorm.
```python
from typing import Union, List

import torch
from torch import nn, Size

from labml_nn.normalization.layer_norm import LayerNorm
from labml_nn.transformers import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.utils import subsequent_mask
```
$$x_{l+1} = \mathrm{LN}\big(\alpha x_l + G_l(x_l, \theta_l)\big)$$
```python
class DeepNorm(nn.Module):
```
* `alpha` is $\alpha$
* `normalized_shape` is the shape for LayerNorm $\mathrm{LN}$
* `eps` is $\epsilon$ for LayerNorm
* `elementwise_affine` is a flag indicating whether to do an elementwise transformation in LayerNorm

```python
    def __init__(self, alpha: float, normalized_shape: Union[int, List[int], Size], *,
                 eps: float = 1e-5,
                 elementwise_affine: bool = True):
```
```python
        super().__init__()

        self.alpha = alpha
```
Initialize LN
```python
        self.layer_norm = LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
```
* `x` is the output from the previous layer $x_l$
* `gx` is the output of the current sub-layer $G_l(x_l, \theta_l)$

```python
    def forward(self, x: torch.Tensor, gx: torch.Tensor):
```
$$x_{l+1} = \mathrm{LN}\big(\alpha x_l + G_l(x_l, \theta_l)\big)$$
```python
        # LN(alpha * x_l + G_l(x_l, theta_l)): alpha scales the residual branch
        return self.layer_norm(self.alpha * x + gx)
```
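A hypothetical quick check of the module (the shapes and 100-layer depth are made up for illustration):

```python
alpha = (2 * 100) ** 0.25      # Dec-alpha for an M = 100 decoder-only model
norm = DeepNorm(alpha, [512])

x = torch.randn(10, 32, 512)   # previous layer output x_l
gx = torch.randn(10, 32, 512)  # sub-layer output G_l(x_l, theta_l)
assert norm(x, gx).shape == x.shape
```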
This implements a transformer decoder layer with DeepNorm. Encoder layers will have a similar form.
```python
class DeepNormTransformerLayer(nn.Module):
```
* `d_model` is the token embedding size
* `self_attn` is the self-attention module
* `feed_forward` is the feed-forward module
* `deep_norm_alpha` is the $\alpha$ coefficient in DeepNorm
* `deep_norm_beta` is the $\beta$ constant for scaling the weight initialization

```python
    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 feed_forward: FeedForward,
                 deep_norm_alpha: float,
                 deep_norm_beta: float):
```
```python
        super().__init__()

        self.self_attn = self_attn
        self.feed_forward = feed_forward
```
DeepNorm layers after the attention and the feed-forward network
```python
        self.self_attn_norm = DeepNorm(deep_norm_alpha, [d_model])
        self.feed_forward_norm = DeepNorm(deep_norm_alpha, [d_model])
```
Scale weights after initialization
```python
        with torch.no_grad():
```
Feed forward network linear transformations
```python
            feed_forward.layer1.weight *= deep_norm_beta
            feed_forward.layer2.weight *= deep_norm_beta
```
Attention value projection
```python
            self_attn.value.linear.weight *= deep_norm_beta
```
Attention output projection
```python
            self_attn.output.weight *= deep_norm_beta
```
The mask will be initialized on the first call
```python
        self.mask = None
```
* `x` is the tensor of embeddings with shape `[seq_len, batch_size, d_model]`

```python
    def forward(self, x: torch.Tensor):
```
Create causal mask
```python
        if self.mask is None or self.mask.size(0) != len(x):
```
Subsequent mask, which masks out tokens from seeing future tokens
```python
            self.mask = subsequent_mask(len(x)).to(x.device)
```
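A minimal sketch of the causal-mask idea in plain PyTorch, where each position may attend only to itself and earlier positions (the exact shape convention of the `labml_nn` helper is an assumption here):

```python
def causal_mask_sketch(seq_len: int) -> torch.Tensor:
    # True where attention is allowed: position i may attend to j <= i.
    # The trailing dimension is an assumed broadcast axis.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(-1)
```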
Run through self attention, i.e. keys and values are from self
```python
        x = self.self_attn_norm(x, self.self_attn(query=x, key=x, value=x, mask=self.mask))
```
Pass through the feed-forward network
```python
        x = self.feed_forward_norm(x, self.feed_forward(x))
```
```python
        return x
```
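A minimal usage sketch, assuming the constructor signatures `MultiHeadAttention(heads, d_model)` and `FeedForward(d_model, d_ff)` from `labml_nn` (check those modules before relying on this):

```python
# Hypothetical wiring of a single layer for an M = 64 decoder-only model
M = 64
d_model = 512
layer = DeepNormTransformerLayer(
    d_model=d_model,
    self_attn=MultiHeadAttention(8, d_model),
    feed_forward=FeedForward(d_model, 2048),
    deep_norm_alpha=(2 * M) ** 0.25,  # Dec-alpha = (2M)^(1/4)
    deep_norm_beta=(8 * M) ** -0.25,  # Dec-beta = (8M)^(-1/4)
)

x = torch.randn(10, 32, d_model)  # [seq_len, batch_size, d_model]
y = layer(x)                      # causal self-attention + FFN, same shape as x
```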