
# Hierarchical Transformers Are More Efficient Language Models


[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/hour_glass/__init__.py)


This is a PyTorch implementation of the paper Hierarchical Transformers Are More Efficient Language Models.

This paper introduces a hierarchical transformer architecture to handle long sequences efficiently. The first half of the transformer layers down-samples tokens and the second half up-samples them, with direct skip connections between layers of the same resolution. This is somewhat similar to U-Net for vision tasks.

The paper experiments with different up-sampling and down-sampling techniques and builds a model with the best-performing ones, which they call the hourglass model.

Here we have implemented only the simplest up-sampling and down-sampling techniques. We will consider adding more complex (and better-performing) implementations later.

Here is the training code for the hourglass model.

```python
from typing import List

import torch
from torch import nn

from labml_nn.transformers import MultiHeadAttention, TransformerLayer
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.utils import subsequent_mask
```

## Hourglass model

This model recursively adds layers to the middle while shortening the sequence by down-sampling. The shortened sequence, processed by another hourglass model, is sandwiched between two normal transformer layers. (A transformer layer has a self-attention layer and a position-wise feed-forward layer.)

```python
class HourGlass(nn.Module):
    def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int, shortening_factors: List[int]):
        super().__init__()
```

The transformer layer before down-sampling

```python
        self.pre = TransformerLayer(
            d_model=d_model,
            # Multi-head attention layer
            self_attn=MultiHeadAttention(n_heads, d_model, dropout),
            # Position-wise feed-forward layer
            feed_forward=FeedForward(d_model, d_ff, dropout),
            dropout_prob=dropout)
```

Auto-regressive mask

```python
        self.mask = AutoregressiveMask()
```

The shortening factor k (or the down-sampling rate)

```python
        k = shortening_factors[0]
```

We shift the tokens to the right by k − 1 steps to make sure information doesn't leak from future tokens to past tokens as a result of down-sampling and up-sampling.

```python
        self.shift_right = ShiftRight(k - 1)
```

Shortening or the down-sampling layer. We use the simplest form: average pooling. The paper shows that attention-based down-sampling works best, which we haven't implemented yet.

```python
        self.shortening = AvgPoolShortening(k)
```

If there is no more shortening (we are at the middle of the hourglass)

```python
        if len(shortening_factors) == 1:
```

The center layer is another transformer layer

```python
            self.shortened = TransformerLayer(
                d_model=d_model,
                self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                feed_forward=FeedForward(d_model, d_ff, dropout),
                dropout_prob=dropout)
```

Autoregressive mask

```python
            self.mask_short = AutoregressiveMask()
            self.hour_glass = None
        else:
```

Insert another hourglass model recursively

```python
            self.hour_glass = HourGlass(n_heads, d_model, dropout, d_ff, shortening_factors[1:])
```

Up-sampling layer. We use naive up-sampling for simplicity; the paper shows that attention-based up-sampling works better.

```python
        self.up_sampling = NaiveUpSampling(k)
```

The final transformer layer after up-sampling

```python
        self.post = TransformerLayer(
            d_model=d_model,
            self_attn=MultiHeadAttention(n_heads, d_model, dropout),
            feed_forward=FeedForward(d_model, d_ff, dropout),
            dropout_prob=dropout)
```

```python
    def forward(self, x: torch.Tensor):
```

Initial transformer layer x ← PreVanillaLayers(x)

```python
        x = self.pre(x=x, mask=self.mask(x))
```

Shifting and shortening x′ ← Shortening(ShiftRight(x, k − 1), k)

```python
        x_short = self.shortening(self.shift_right(x))
```

If we are at the center of the hourglass, if EMPTY(shorten_factors) then

```python
        if self.hour_glass is None:
```

Center transformer layer x′ ← ShortenedLayers(x′)

```python
            x_short = self.shortened(x=x_short, mask=self.mask_short(x_short))
```

else

```python
        else:
```

x′ ← HourGlass(x′, shorten_factors)

```python
            x_short = self.hour_glass(x_short)
```

Up-sample the shortened sequence and add a skip connection x ← x + Upsampling(x, x′, k)

```python
        x = x + self.up_sampling(x, x_short)
```

Final transformer layer x ← PostVanillaLayers(x)

```python
        x = self.post(x=x, mask=self.mask(x))
```

```python
        return x
```
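Below is a minimal usage sketch of the module (not from the original repository); the hyperparameter values and tensor shapes are illustrative assumptions, and it assumes the helper classes defined later in this file (`ShiftRight`, `AvgPoolShortening`, `NaiveUpSampling`, `AutoregressiveMask`) are in scope.

```python
import torch

# Illustrative hyperparameters (assumptions, not taken from the paper)
hour_glass = HourGlass(n_heads=8, d_model=512, dropout=0.1, d_ff=2048, shortening_factors=[2, 2])

# Token embeddings of shape [seq_len, batch_size, d_model]
x = torch.randn(256, 4, 512)

# The output has the same shape as the input: [256, 4, 512]
y = hour_glass(x)
```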

## Shift right operation

This shifts the sequence to the right by the given number of steps.

```python
class ShiftRight(nn.Module):
```

* `shift` is the number of steps to shift by

```python
    def __init__(self, shift: int):
```

```python
        super().__init__()
```

`shift` cannot be negative

```python
        assert shift >= 0
```

```python
        self.shift = shift
```

* `x` is a tensor of shape `[seq_len, ...]`

```python
    def forward(self, x: torch.Tensor):
```

If the shift is 0, return the original

```python
        if self.shift == 0:
            return x
```

Zeros to be prepended on the left

```python
        prefix = x.new_zeros([self.shift, *x.shape[1:]])
```

Concatenate the zeros and truncate the right

```python
        return torch.cat([prefix, x[:-self.shift]])
```
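For example, a shift of 2 (i.e. k = 3) prepends two zero embeddings and drops the last two positions; a quick sketch with a toy tensor:

```python
import torch

shift_right = ShiftRight(2)

# A toy sequence of 5 "embeddings" of size 1: [[1.], [2.], [3.], [4.], [5.]]
x = torch.arange(1., 6.).unsqueeze(-1)

# Shifted result: [[0.], [0.], [1.], [2.], [3.]]
y = shift_right(x)
```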

## Average pool shortening

This down-samples by a given factor with average pooling.

```python
class AvgPoolShortening(nn.Module):
```

* `k` is the shortening factor

```python
    def __init__(self, k: int):
```

```python
        super().__init__()
```

Average pooling layer

```python
        self.pool = nn.AvgPool1d(k, ceil_mode=True)
```

* `x` is of shape `[seq_len, batch_size, d_model]`

```python
    def forward(self, x: torch.Tensor):
```

The pooling layer accepts a tensor of shape `[batch_size, d_model, seq_len]`, so we permute the axes before and after pooling.

```python
        return self.pool(x.permute(1, 2, 0)).permute(2, 0, 1)
```
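A quick shape check (a sketch with illustrative values, not from the original code): with k = 3, a sequence of length 8 is shortened to ceil(8 / 3) = 3 because of `ceil_mode=True`.

```python
import torch

shorten = AvgPoolShortening(3)

# [seq_len, batch_size, d_model]
x = torch.randn(8, 4, 512)

# Shortened to [3, 4, 512]
x_short = shorten(x)
```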

## Naive up-sampling

This up-samples by repeating each shortened embedding k times.

```python
class NaiveUpSampling(nn.Module):
```

* `k` is the shortening factor

```python
    def __init__(self, k: int):
```

```python
        super().__init__()
        self.k = k
```

* `x` is the tensor with embeddings before down-sampling
* `x_short` is the tensor of higher density (to be up-sampled) representations

```python
    def forward(self, x: torch.Tensor, x_short: torch.Tensor):
```

Repeat across the sequence dimension

```python
        expanded = torch.repeat_interleave(x_short, self.k, dim=0)
```

Truncate the extra embeddings at the end

```python
        expanded = expanded[:x.shape[0]]
```

```python
        return expanded
```
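A quick sketch of the shapes involved (illustrative values, assuming the classes above are in scope): with k = 3, a shortened sequence of length 3 is repeated to length 9 and then truncated back to the original length 8.

```python
import torch

up_sample = NaiveUpSampling(3)

# Original-resolution tensor; only its length is used for truncation
x = torch.randn(8, 4, 512)
# Shortened tensor to be up-sampled
x_short = torch.randn(3, 4, 512)

# Result has shape [8, 4, 512]
expanded = up_sample(x, x_short)
```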

## Generate auto-regressive mask

```python
class AutoregressiveMask(nn.Module):
```

```python
    def __init__(self):
        super().__init__()
        self.mask = None
```

```python
    def forward(self, x: torch.Tensor):
```

Create a mask if we haven't created one yet or if the size has changed

```python
        if self.mask is None or self.mask.size(0) != len(x):
```

Subsequent mask, which masks out tokens from seeing future tokens

```python
            self.mask = subsequent_mask(len(x)).to(x.device)
```

```python
        return self.mask
```

## 🚧 Linear pooling for down-sampling

This concatenates the embeddings of the consecutive tokens that need to be merged and applies a linear transformation to map the result to the size of a single token embedding.

```python
class LinearPoolingShortening(nn.Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
```
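Since this isn't implemented yet, here is a minimal sketch of what linear pooling shortening could look like; the class name, constructor parameters, and the zero-padding choice are assumptions, not the paper's reference implementation.

```python
import torch
from torch import nn


class LinearPoolingShorteningSketch(nn.Module):
    def __init__(self, d_model: int, k: int):
        super().__init__()
        self.k = k
        # Maps k concatenated token embeddings to a single embedding of size d_model
        self.linear = nn.Linear(d_model * k, d_model)

    def forward(self, x: torch.Tensor):
        # x has shape [seq_len, batch_size, d_model]
        seq_len, batch_size, d_model = x.shape
        # Zero-pad so the sequence length is a multiple of k (an assumption; the paper may handle this differently)
        pad = (-seq_len) % self.k
        if pad:
            x = torch.cat([x, x.new_zeros(pad, batch_size, d_model)])
        # Group every k consecutive embeddings: [short_len, k, batch_size, d_model]
        x = x.view(-1, self.k, batch_size, d_model)
        # Concatenate each group along the feature dimension: [short_len, batch_size, k * d_model]
        x = x.permute(0, 2, 1, 3).reshape(x.shape[0], batch_size, -1)
        # Project back to a single token embedding
        return self.linear(x)
```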

## 🚧 Down-sampling with attention

$$\begin{aligned}
x' &= S(x) + \text{Attention}(Q = S(x), K = x, V = x) \\
x' &= x' + \text{FFN}(x')
\end{aligned}$$

where S(x) is average pooling or linear pooling.

```python
class AttentionBasedShortening(nn.Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
```
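This is also unimplemented; the following is a rough sketch of the equations above using `torch.nn.MultiheadAttention` (an assumption; the paper builds on its own attention layers, and the causal masking needed to keep the model autoregressive is omitted here for brevity).

```python
import torch
from torch import nn


class AttentionBasedShorteningSketch(nn.Module):
    def __init__(self, d_model: int, n_heads: int, k: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # S(x): average pooling used as the initial shortening
        self.pool = nn.AvgPool1d(k, ceil_mode=True)
        # Attention(Q=S(x), K=x, V=x); expects [seq_len, batch_size, d_model] inputs
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        # Position-wise feed-forward network
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(),
                                 nn.Dropout(dropout), nn.Linear(d_ff, d_model))

    def forward(self, x: torch.Tensor):
        # x has shape [seq_len, batch_size, d_model]
        # S(x): [short_len, batch_size, d_model]
        s = self.pool(x.permute(1, 2, 0)).permute(2, 0, 1)
        # x' = S(x) + Attention(Q=S(x), K=x, V=x)
        attn_out, _ = self.attn(s, x, x, need_weights=False)
        x_short = s + attn_out
        # x' = x' + FFN(x')
        return x_short + self.ffn(x_short)
```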

## 🚧 Linear projection for up-sampling

Make a linear projection of the dense (shortened) token embeddings to a size of d_model · k.

```python
class LinearUpSampling(nn.Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
```
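A minimal sketch of what this could look like; the class name, constructor parameters, and the truncation to the original length are assumptions, not the paper's reference implementation.

```python
import torch
from torch import nn


class LinearUpSamplingSketch(nn.Module):
    def __init__(self, d_model: int, k: int):
        super().__init__()
        self.k = k
        # Projects one shortened embedding to k token embeddings (size d_model * k)
        self.linear = nn.Linear(d_model, d_model * k)

    def forward(self, x: torch.Tensor, x_short: torch.Tensor):
        # x_short has shape [short_len, batch_size, d_model]
        short_len, batch_size, d_model = x_short.shape
        # Project: [short_len, batch_size, k * d_model]
        expanded = self.linear(x_short)
        # Split each projection into k consecutive token embeddings: [short_len * k, batch_size, d_model]
        expanded = expanded.view(short_len, batch_size, self.k, d_model)
        expanded = expanded.permute(0, 2, 1, 3).reshape(short_len * self.k, batch_size, d_model)
        # Truncate to the original sequence length
        return expanded[:x.shape[0]]
```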

## 🚧 Attention based up-sampling

$$\begin{aligned}
x &= U(x, x') + \text{Attention}(Q = U(x, x'), K = x', V = x') \\
x &= x + \text{FFN}(x)
\end{aligned}$$

where U(x, x′) = x + LinearUpsampling(x′).

```python
class AttentionBasedUpSampling(nn.Module):
    def __init__(self):
        super().__init__()
        raise NotImplementedError
```
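Again a rough sketch of the equations above, using `torch.nn.MultiheadAttention` and an inlined linear up-sampling; the names and structure are assumptions, and the causal masking needed to keep the model autoregressive is omitted for brevity.

```python
import torch
from torch import nn


class AttentionBasedUpSamplingSketch(nn.Module):
    def __init__(self, d_model: int, n_heads: int, k: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.k = k
        # LinearUpsampling(x'): one shortened embedding -> k token embeddings
        self.linear_up = nn.Linear(d_model, d_model * k)
        # Attention(Q=U(x, x'), K=x', V=x'); expects [seq_len, batch_size, d_model] inputs
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        # Position-wise feed-forward network
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(),
                                 nn.Dropout(dropout), nn.Linear(d_ff, d_model))

    def forward(self, x: torch.Tensor, x_short: torch.Tensor):
        short_len, batch_size, d_model = x_short.shape
        # LinearUpsampling(x'): [short_len * k, batch_size, d_model], truncated to the original length
        up = self.linear_up(x_short).view(short_len, batch_size, self.k, d_model)
        up = up.permute(0, 2, 1, 3).reshape(short_len * self.k, batch_size, d_model)[:x.shape[0]]
        # U(x, x') = x + LinearUpsampling(x')
        u = x + up
        # x = U(x, x') + Attention(Q=U(x, x'), K=x', V=x')
        attn_out, _ = self.attn(u, x_short, x_short, need_weights=False)
        out = u + attn_out
        # x = x + FFN(x)
        return out + self.ffn(out)
```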
