
DeepNorm


[View code on GitHub](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/normalization/deep_norm/__init__.py)


This is a PyTorch implementation of DeepNorm from the paper DeepNet: Scaling Transformers to 1,000 Layers.

The paper proposes a method to stabilize extremely deep transformers through a new normalization function to replace LayerNorm and a weight initialization scheme. This combines the performance of Post-LayerNorm with the stability of Pre-LayerNorm. Transformers with DeepNorm are reported to be stable even without a learning rate warm-up.

The paper first shows that layer outputs (for the same input) change gradually during stable training, while in unstable training they change rapidly during the initial training steps. Training is stable when the weights are initialized to small values or when a learning rate warm-up is used, and in both cases the changes to layer outputs stay small. The authors use this idea of keeping the changes to layer outputs small to derive the new normalization and weight initialization mechanism.

Weight Initializations

Usually, the weights are initialized with Xavier or Kaiming initialization. This paper scales the weights (sets the gain) by a constant β that depends on the size of the transformer.

DeepNorm suggests scaling the weights of the two linear transforms in the feed-forward network, the value projection transform, and the output projection transform of the attention layer. The weights of these transforms are scaled by β (that is, they have a gain equal to β).

The scaling is implemented in the DeepNormTransformerLayer constructor.
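One way to read "has a gain equal to β" is as the `gain` argument of Xavier initialization. The sketch below is an illustration under that assumption: it uses a plain `torch.nn.Linear`, a hypothetical helper named `init_with_gain`, and an assumed decoder depth of 12. It is not the code used on this page, which multiplies already-initialized weights by β instead.

```python
import math

import torch
from torch import nn


def init_with_gain(linear: nn.Linear, beta: float) -> None:
    """Xavier-normal initialization with gain beta (hypothetical helper)."""
    nn.init.xavier_normal_(linear.weight, gain=beta)


torch.manual_seed(0)
d_model, d_ff = 256, 1024
beta = (8 * 12) ** -0.25  # decoder-only beta for an assumed depth of M = 12

ffn_in = nn.Linear(d_model, d_ff)
init_with_gain(ffn_in, beta)

# Xavier-normal draws weights with std = gain * sqrt(2 / (fan_in + fan_out))
expected_std = beta * math.sqrt(2.0 / (d_model + d_ff))
observed_std = ffn_in.weight.std().item()
print(f"expected std {expected_std:.5f}, observed std {observed_std:.5f}")
```

With a gain below 1 the weights start smaller than under unscaled Xavier initialization, which is the intended effect for deep stacks.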

Normalization Function

$$x_{l+1} = \mathrm{LN}\big(\alpha x_l + G_l(x_l, \theta_l)\big)$$

where α is a constant that depends on the depth of the transformer, LN is Layer Normalization, and $G_l(x_l, \theta_l)$ is the function of the $l$-th transformer sub-layer (FFN or attention).

This function is used to replace Post-LayerNorm.

α and β constants

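| Type | Enc-α | Enc-β | Dec-α | Dec-β |
|---|---|---|---|---|
| Encoder only | $(2N)^{\frac{1}{4}}$ | $(8N)^{-\frac{1}{4}}$ | - | - |
| Decoder only | - | - | $(2M)^{\frac{1}{4}}$ | $(8M)^{-\frac{1}{4}}$ |
| Enc-Dec | $0.81 (N^4 M)^{\frac{1}{16}}$ | $0.87 (N^4 M)^{-\frac{1}{16}}$ | $(3M)^{\frac{1}{4}}$ | $(12M)^{-\frac{1}{4}}$ |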

Where N is the number of layers in the encoder and M is the number of layers in the decoder.

Refer to the paper for derivation.
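The constants in the table can be computed directly from the layer counts. The following is a minimal sketch; the names `DeepNormConstants` and `deep_norm_constants` are hypothetical and do not appear in the implementation on this page.

```python
from typing import NamedTuple, Optional


class DeepNormConstants(NamedTuple):
    enc_alpha: Optional[float]
    enc_beta: Optional[float]
    dec_alpha: Optional[float]
    dec_beta: Optional[float]


def deep_norm_constants(n_encoder_layers: int = 0,
                        n_decoder_layers: int = 0) -> DeepNormConstants:
    """Return alpha and beta for N encoder layers and M decoder layers."""
    n, m = n_encoder_layers, n_decoder_layers
    if n > 0 and m == 0:
        # Encoder only
        return DeepNormConstants((2 * n) ** 0.25, (8 * n) ** -0.25, None, None)
    if n == 0 and m > 0:
        # Decoder only
        return DeepNormConstants(None, None, (2 * m) ** 0.25, (8 * m) ** -0.25)
    if n > 0 and m > 0:
        # Encoder-decoder
        return DeepNormConstants(
            0.81 * (n ** 4 * m) ** (1 / 16),
            0.87 * (n ** 4 * m) ** (-1 / 16),
            (3 * m) ** 0.25,
            (12 * m) ** -0.25,
        )
    raise ValueError("At least one of the layer counts must be positive")


decoder_only = deep_norm_constants(n_decoder_layers=8)
enc_dec = deep_norm_constants(n_encoder_layers=2, n_decoder_layers=1)
print(decoder_only)
print(enc_dec)
```

Note that α grows and β shrinks as the model gets deeper, so deeper models weight the residual branch more heavily and start with smaller sub-layer weights.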

Here is an experiment implementation that uses DeepNorm.

```python
from typing import Union, List

import torch
from torch import nn, Size

from labml_nn.normalization.layer_norm import LayerNorm
from labml_nn.transformers import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.utils import subsequent_mask
```

DeepNorm Normalization

$$x_{l+1} = \mathrm{LN}\big(\alpha x_l + G_l(x_l, \theta_l)\big)$$

```python
class DeepNorm(nn.Module):
```

  • alpha is α
  • normalized_shape is the shape for LayerNorm LN
  • eps is ϵ for LayerNorm
  • elementwise_affine is a flag indicating whether to do an elementwise transformation in LayerNorm

```python
    def __init__(self, alpha: float, normalized_shape: Union[int, List[int], Size], *,
                 eps: float = 1e-5,
                 elementwise_affine: bool = True):
```

```python
        super().__init__()

        self.alpha = alpha
```

Initialize LN

```python
        self.layer_norm = LayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
```

  • x is the output from the previous layer $x_l$
  • gx is the output of the current sub-layer $G_l(x_l, \theta_l)$

```python
    def forward(self, x: torch.Tensor, gx: torch.Tensor):
```

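$$x_{l+1} = \mathrm{LN}\big(\alpha x_l + G_l(x_l, \theta_l)\big)$$

```python
        # alpha scales the residual input x, not the sub-layer output gx
        return self.layer_norm(self.alpha * x + gx)
```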
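The formula can be checked in isolation with stock PyTorch functions. The sketch below assumes hypothetical helper names `post_layer_norm` and `deep_norm`, and uses `torch.nn.functional.layer_norm` without learnable parameters rather than the `LayerNorm` module used on this page. It illustrates that DeepNorm reduces to Post-LayerNorm when α = 1.

```python
import torch
import torch.nn.functional as F


def post_layer_norm(x: torch.Tensor, gx: torch.Tensor) -> torch.Tensor:
    """Post-LayerNorm residual: LN(x + G(x))."""
    return F.layer_norm(x + gx, x.shape[-1:])


def deep_norm(x: torch.Tensor, gx: torch.Tensor, alpha: float) -> torch.Tensor:
    """DeepNorm residual: LN(alpha * x + G(x))."""
    return F.layer_norm(alpha * x + gx, x.shape[-1:])


torch.manual_seed(0)
x = torch.randn(5, 2, 16)   # [seq_len, batch_size, d_model]
gx = torch.randn(5, 2, 16)  # stand-in for a sub-layer output

alpha = (2 * 8) ** 0.25     # decoder-only alpha for an assumed depth of M = 8
y = deep_norm(x, gx, alpha)

# With alpha = 1 the two residual forms are identical
same_as_post_ln = torch.allclose(deep_norm(x, gx, 1.0), post_layer_norm(x, gx))
print(y.shape, same_as_post_ln)
```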

Transformer Decoder Layer with DeepNorm

This implements a transformer decoder layer with DeepNorm. Encoder layers will have a similar form.

```python
class DeepNormTransformerLayer(nn.Module):
```

  • d_model is the token embedding size
  • self_attn is the self attention module
  • feed_forward is the feed forward module
  • deep_norm_alpha is the α coefficient in DeepNorm
  • deep_norm_beta is the β constant for scaling the weight initialization

```python
    def __init__(self, *,
                 d_model: int,
                 self_attn: MultiHeadAttention,
                 feed_forward: FeedForward,
                 deep_norm_alpha: float,
                 deep_norm_beta: float,
                 ):
```

```python
        super().__init__()

        self.self_attn = self_attn
        self.feed_forward = feed_forward
```

DeepNorms after attention and feed forward network

```python
        self.self_attn_norm = DeepNorm(deep_norm_alpha, [d_model])
        self.feed_forward_norm = DeepNorm(deep_norm_alpha, [d_model])
```

Scale weights after initialization

```python
        with torch.no_grad():
```

Feed forward network linear transformations

```python
            feed_forward.layer1.weight *= deep_norm_beta
            feed_forward.layer2.weight *= deep_norm_beta
```

Attention value projection

```python
            self_attn.value.linear.weight *= deep_norm_beta
```

Attention output projection

```python
            self_attn.output.weight *= deep_norm_beta
```

The mask will be initialized on the first call

```python
        self.mask = None
```

  • x are the embeddings of shape [seq_len, batch_size, d_model]

```python
    def forward(self, x: torch.Tensor):
```

Create causal mask

```python
        if self.mask is None or self.mask.size(0) != len(x):
```

Subsequent mask, which prevents tokens from attending to future tokens

```python
            self.mask = subsequent_mask(len(x)).to(x.device)
```

Run through self attention, i.e. keys and values are from self

```python
        x = self.self_attn_norm(x, self.self_attn(query=x, key=x, value=x, mask=self.mask))
```

Pass through the feed-forward network

```python
        x = self.feed_forward_norm(x, self.feed_forward(x))
```

```python
        return x
```
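For readers without `labml_nn` installed, the same layer structure can be sketched with stock PyTorch modules. Everything below is an assumption made for illustration: the class name `DeepNormDecoderLayerSketch`, the use of `torch.nn.MultiheadAttention` and `torch.nn.LayerNorm` in place of the labml modules, the ReLU feed-forward network, and the layer sizes. It applies the residual form $\mathrm{LN}(\alpha x_l + G_l(x_l, \theta_l))$ and scales the same four weight matrices by β.

```python
import torch
from torch import nn


class DeepNormDecoderLayerSketch(nn.Module):
    """Decoder layer with DeepNorm residuals built from stock PyTorch modules (illustrative)."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, alpha: float, beta: float):
        super().__init__()
        self.alpha = alpha
        self.attn = nn.MultiheadAttention(d_model, n_heads)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.attn_norm = nn.LayerNorm(d_model)
        self.ffn_norm = nn.LayerNorm(d_model)

        # Scale the FFN, value projection and output projection weights by beta
        with torch.no_grad():
            self.ffn[0].weight.mul_(beta)
            self.ffn[2].weight.mul_(beta)
            # in_proj_weight stacks the query, key and value projections;
            # the last d_model rows are the value projection
            self.attn.in_proj_weight[2 * d_model:].mul_(beta)
            self.attn.out_proj.weight.mul_(beta)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape [seq_len, batch_size, d_model]
        seq_len = x.shape[0]
        # True marks positions that must not be attended to (future tokens)
        mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
        attn_out, _ = self.attn(x, x, x, attn_mask=mask, need_weights=False)
        x = self.attn_norm(self.alpha * x + attn_out)
        x = self.ffn_norm(self.alpha * x + self.ffn(x))
        return x


torch.manual_seed(0)
n_layers = 4
# Decoder-only constants for M = n_layers
alpha, beta = (2 * n_layers) ** 0.25, (8 * n_layers) ** -0.25
model = nn.Sequential(*[DeepNormDecoderLayerSketch(32, 4, 64, alpha, beta) for _ in range(n_layers)])

x = torch.randn(6, 2, 32)
y = model(x)
print(y.shape)
```

Building the mask on every call keeps the sketch short; the layer on this page caches it in `self.mask` and rebuilds it only when the sequence length changes.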
