[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/hour_glass/__init__.py)
This is a PyTorch implementation of the paper Hierarchical Transformers Are More Efficient Language Models.
This paper introduces a hierarchical transformer architecture to handle long sequences efficiently. The first half of the transformer layers down-samples tokens and the second half up-samples them, with direct skip connections between layers of the same resolution. This is somewhat similar to U-Net for vision tasks.
The paper tries different up-sampling and down-sampling techniques and builds a model with the best-performing ones, which it calls the hourglass model.
Here we have implemented only the simplest up-sampling and down-sampling techniques; we will consider adding more complex (and better-performing) implementations later.
Here is the training code for the hourglass model.
```python
from typing import List

import torch
from torch import nn

from labml_nn.transformers import MultiHeadAttention, TransformerLayer
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.utils import subsequent_mask
```
This model recursively adds layers to the middle while shortening the sequence by down-sampling. The shortened sequence processed by another hourglass model is sandwiched between two normal transformer layers. (A transformer layer has a self-attention layer and a position-wise feed-forward layer).
```python
class HourGlass(nn.Module):
```
* `n_heads` is the number of heads in multi-head attention layers
* `d_model` is the size of the token embeddings
* `dropout` is the dropout probability
* `d_ff` is the dimensionality of the hidden layer in position-wise feed-forward layers
* `shortening_factors` is the list of shortening factors

```python
    def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int,
                 shortening_factors: List[int]):
        super().__init__()
```
The transformer layer before down-sampling
```python
        self.pre = TransformerLayer(d_model=d_model,
                                    # Multi-head attention layer
                                    self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                    # Position-wise feed-forward layer
                                    feed_forward=FeedForward(d_model, d_ff, dropout),
                                    dropout_prob=dropout)
```
Auto-regressive mask
```python
        self.mask = AutoregressiveMask()
```
The shortening factor $k$ (or the down-sampling rate)
```python
        k = shortening_factors[0]
```
We shift the tokens to the right by $k - 1$ steps to make sure information doesn't leak from future tokens to past tokens as a result of down-sampling and up-sampling.
```python
        self.shift_right = ShiftRight(k - 1)
```
Shortening, or the down-sampling layer. We use the simplest form: average pooling. The paper shows that attention-based down-sampling works best, but we haven't implemented that yet.
```python
        self.shortening = AvgPoolShortening(k)
```
If there is no more shortening (i.e. we are at the middle of the hourglass)
```python
        if len(shortening_factors) == 1:
```
The center layer is another transformer layer
```python
            self.shortened = TransformerLayer(d_model=d_model,
                                              self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                              feed_forward=FeedForward(d_model, d_ff, dropout),
                                              dropout_prob=dropout)
```
Autoregressive mask
```python
            self.mask_short = AutoregressiveMask()
            self.hour_glass = None
        else:
```
Insert another hourglass model recursively
```python
            self.hour_glass = HourGlass(n_heads, d_model, dropout, d_ff, shortening_factors[1:])
```
Up-sampling layer. We use naive up-sampling for simplicity; the paper shows that attention-based up-sampling works better.
```python
        self.up_sampling = NaiveUpSampling(k)
```
The final transformer layer after up-sampling
```python
        self.post = TransformerLayer(d_model=d_model,
                                     self_attn=MultiHeadAttention(n_heads, d_model, dropout),
                                     feed_forward=FeedForward(d_model, d_ff, dropout),
                                     dropout_prob=dropout)
```
```python
    def forward(self, x: torch.Tensor):
```
Initial transformer layer: $x \leftarrow \text{PreVanillaLayers}(x)$
```python
        x = self.pre(x=x, mask=self.mask(x))
```
Shifting and shortening: $x' \leftarrow \text{Shortening}(\text{ShiftRight}(x, k - 1), k)$
```python
        x_short = self.shortening(self.shift_right(x))
```
If we are at the center of the hourglass: $\textbf{if } \text{Empty}(\text{shorten\_factors}) \textbf{ then}$
```python
        if self.hour_glass is None:
```
Center transformer layer: $x' \leftarrow \text{ShortenedLayers}(x')$
```python
            x_short = self.shortened(x=x_short, mask=self.mask_short(x_short))
```
$\textbf{else}$
```python
        else:
```
$x' \leftarrow \text{HourGlass}(x', \text{shorten\_factors})$
```python
            x_short = self.hour_glass(x_short)
```
Up-sample the shortened sequence and add a skip connection: $x \leftarrow x + \text{UpSampling}(x, x', k)$
```python
        x = x + self.up_sampling(x, x_short)
```
Final transformer layer: $x \leftarrow \text{PostVanillaLayers}(x)$
```python
        x = self.post(x=x, mask=self.mask(x))

        return x
```
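As a quick illustration (not part of the original code), here is a minimal usage sketch; the hyper-parameter values are arbitrary and it relies on the imports and classes defined in this file.

```python
# A minimal usage sketch with arbitrary, illustrative hyper-parameters
model = HourGlass(n_heads=4, d_model=256, dropout=0.1, d_ff=1024,
                  shortening_factors=[2, 4])

# Token embeddings of shape `[seq_len, batch_size, d_model]`
x = torch.randn(64, 8, 256)

# The output keeps the original sequence length and embedding size
assert model(x).shape == (64, 8, 256)
```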
This shifts the sequence to the right by the given number of steps
```python
class ShiftRight(nn.Module):
```
* `shift` is the number of steps to shift by

```python
    def __init__(self, shift: int):
        super().__init__()
```
The shift cannot be negative
```python
        assert shift >= 0
        self.shift = shift
```
* `x` is a tensor of shape `[seq_len, ...]`

```python
    def forward(self, x: torch.Tensor):
```
If the shift is 0 return the original
```python
        if self.shift == 0:
            return x
```
Zeros to be prepended on the left
```python
        prefix = x.new_zeros([self.shift, *x.shape[1:]])
```
Concatenate the zeros and truncate the right
```python
        return torch.cat([prefix, x[:-self.shift]])
```
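For illustration (not in the original code), shifting a short sequence by one step prepends a zero and drops the last element:

```python
# Illustrative example: shift a sequence of four one-dimensional embeddings right by one step
x = torch.arange(1., 5.).reshape(4, 1)
# The result is [[0.], [1.], [2.], [3.]]: a zero is prepended and the last token is truncated
print(ShiftRight(1)(x))
```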
This down-samples by a given factor with average pooling
```python
class AvgPoolShortening(nn.Module):
```
* `k` is the shortening factor

```python
    def __init__(self, k: int):
        super().__init__()
```
Average pooling layer
```python
        self.pool = nn.AvgPool1d(k, ceil_mode=True)
```
* `x` is of shape `[seq_len, batch_size, d_model]`

```python
    def forward(self, x: torch.Tensor):
```
Pooling layer accepts shape [batch_size, d_model, seq_len] so we permute axes.
```python
        return self.pool(x.permute(1, 2, 0)).permute(2, 0, 1)
```
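For illustration (not in the original code), shortening a sequence of length $6$ by a factor of $2$ halves the sequence dimension while keeping the batch size and embedding size:

```python
# Illustrative example: a sequence of length 6 shortened by a factor of 2
x = torch.randn(6, 2, 16)  # `[seq_len, batch_size, d_model]`
assert AvgPoolShortening(2)(x).shape == (3, 2, 16)
```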
This up-samples by repeating
```python
class NaiveUpSampling(nn.Module):
```
* `k` is the shortening factor

```python
    def __init__(self, k: int):
        super().__init__()
        self.k = k
```
* `x` is the tensor with embeddings before down-sampling
* `x_short` is the tensor of higher density (to be up-sampled) representations

```python
    def forward(self, x: torch.Tensor, x_short: torch.Tensor):
```
Repeat across the sequence dimension
```python
        expanded = torch.repeat_interleave(x_short, self.k, dim=0)
```
Truncate the extra embeddings at the end
```python
        expanded = expanded[:x.shape[0]]

        return expanded
```
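For illustration (not in the original code), up-sampling restores the original sequence length by repeating each shortened embedding $k$ times and truncating the excess:

```python
# Illustrative example: restore a sequence of length 7 that was shortened with k=2
x = torch.randn(7, 2, 16)        # original embeddings `[seq_len, batch_size, d_model]`
x_short = torch.randn(4, 2, 16)  # shortened embeddings, since ceil(7 / 2) = 4
assert NaiveUpSampling(2)(x, x_short).shape == (7, 2, 16)
```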
This generates the auto-regressive attention mask and caches it.

```python
class AutoregressiveMask(nn.Module):
```
```python
    def __init__(self):
        super().__init__()
        self.mask = None
```
```python
    def forward(self, x: torch.Tensor):
```
Create a mask if we haven't created one yet or if the sequence length has changed
```python
        if self.mask is None or self.mask.size(0) != len(x):
```
Subsequent mask, which blocks tokens from attending to future tokens
```python
            self.mask = subsequent_mask(len(x)).to(x.device)
```
```python
        return self.mask
```
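As a small illustration (not in the original code), the mask is built lazily on the first call and then reused for inputs of the same length:

```python
# Illustrative example: the mask is cached until the sequence length changes
mask_fn = AutoregressiveMask()
x = torch.randn(5, 2, 16)
# The second call returns the same cached mask tensor
assert mask_fn(x) is mask_fn(x)
```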
This concatenates the embeddings of consecutive tokens that need to be merged and applies a linear transformation to map them to the size of a single token embedding.
```python
class LinearPoolingShortening(nn.Module):
```
```python
    def __init__(self):
        super().__init__()
        raise NotImplementedError
```
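The class above is left unimplemented. Below is a hedged sketch, under the `[seq_len, batch_size, d_model]` layout used in this file, of how linear-pooling shortening could look; the class name and details are illustrative assumptions, not the paper's or this repository's implementation.

```python
class LinearPoolingShorteningSketch(nn.Module):
    """
    A hypothetical sketch: concatenate every `k` consecutive token embeddings
    and project the result back to `d_model` with a single linear layer.
    """
    def __init__(self, k: int, d_model: int):
        super().__init__()
        self.k = k
        self.linear = nn.Linear(k * d_model, d_model)

    def forward(self, x: torch.Tensor):
        # `x` has shape `[seq_len, batch_size, d_model]`
        seq_len, batch_size, d_model = x.shape
        # Pad with zeros so the sequence length is a multiple of `k`
        padding = (-seq_len) % self.k
        if padding:
            x = torch.cat([x, x.new_zeros(padding, batch_size, d_model)])
        # Group `k` consecutive tokens and concatenate their embeddings
        x = x.reshape(-1, self.k, batch_size, d_model)
        x = x.permute(0, 2, 1, 3).reshape(-1, batch_size, self.k * d_model)
        # Map the concatenated embeddings back to `d_model`
        return self.linear(x)
```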
$$\begin{align}
x' &= S(x) + \text{Attention}\big(Q=S(x), K=x, V=x\big) \\
x' &= x' + \text{FFN}(x')
\end{align}$$

where $S(x)$ is average pooling or linear pooling.
```python
class AttentionBasedShortening(nn.Module):
```
```python
    def __init__(self):
        super().__init__()
        raise NotImplementedError
```
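This class is also a stub. Here is a hedged sketch of the formula above, assuming the `MultiHeadAttention` and `FeedForward` modules imported at the top accept the keyword arguments shown, and taking $S(x)$ to be average pooling for simplicity:

```python
class AttentionBasedShorteningSketch(nn.Module):
    """
    A hypothetical sketch: the shortened sequence attends over the
    full-resolution sequence, followed by a feed-forward block.
    """
    def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int, k: int):
        super().__init__()
        self.shortening = AvgPoolShortening(k)
        self.attn = MultiHeadAttention(n_heads, d_model, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        # $S(x)$: average-pool the sequence to get the shortened queries
        s = self.shortening(x)
        # $x' = S(x) + \text{Attention}(Q=S(x), K=x, V=x)$
        x_short = s + self.attn(query=s, key=x, value=x, mask=mask)
        # $x' = x' + \text{FFN}(x')$
        return x_short + self.feed_forward(x_short)
```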
Make a linear projection of dense token embeddings to a size of $d_{\text{model}} \cdot k$.
```python
class LinearUpSampling(nn.Module):
```
```python
    def __init__(self):
        super().__init__()
        raise NotImplementedError
```
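This class is a stub as well. A hedged sketch of such a projection-and-reshape up-sampling might look like the following; it is an illustration under the same tensor layout, not the repository's implementation.

```python
class LinearUpSamplingSketch(nn.Module):
    """
    A hypothetical sketch: project each shortened embedding to `k * d_model`
    and unfold it into `k` consecutive token embeddings.
    """
    def __init__(self, k: int, d_model: int):
        super().__init__()
        self.k = k
        self.linear = nn.Linear(d_model, k * d_model)

    def forward(self, x: torch.Tensor, x_short: torch.Tensor):
        # `x_short` has shape `[short_seq_len, batch_size, d_model]`
        short_seq_len, batch_size, d_model = x_short.shape
        # Project each shortened embedding to `k` embeddings of size `d_model`
        expanded = self.linear(x_short).reshape(short_seq_len, batch_size, self.k, d_model)
        # Reorder so the `k` tokens of each shortened position are consecutive in the sequence
        expanded = expanded.permute(0, 2, 1, 3).reshape(short_seq_len * self.k, batch_size, d_model)
        # Truncate the extra embeddings at the end to match the original length
        return expanded[:x.shape[0]]
```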
$$\begin{align}
x &= U(x, x') + \text{Attention}\big(Q=U(x, x'), K=x', V=x'\big) \\
x &= x + \text{FFN}(x)
\end{align}$$

where $U(x, x') = x + \text{LinearUpSampling}(x')$.
```python
class AttentionBasedUpSampling(nn.Module):
```
```python
    def __init__(self):
        super().__init__()
        raise NotImplementedError
```
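This final class is also a stub. Below is a hedged sketch of the formula above; it substitutes the naive up-sampling defined earlier for the linear up-sampling in $U(x, x')$, since `LinearUpSampling` is not implemented here, and assumes the same attention and feed-forward interfaces as the sketches above.

```python
class AttentionBasedUpSamplingSketch(nn.Module):
    """
    A hypothetical sketch: the up-sampled sequence attends over the
    shortened sequence, followed by a feed-forward block.
    """
    def __init__(self, n_heads: int, d_model: int, dropout: float, d_ff: int, k: int):
        super().__init__()
        # Naive up-sampling stands in for the linear up-sampling of the formula
        self.up_sampling = NaiveUpSampling(k)
        self.attn = MultiHeadAttention(n_heads, d_model, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)

    def forward(self, x: torch.Tensor, x_short: torch.Tensor, mask: torch.Tensor = None):
        # $U(x, x')$: skip connection plus the up-sampled shortened sequence
        u = x + self.up_sampling(x, x_short)
        # $x = U(x, x') + \text{Attention}(Q=U(x, x'), K=x', V=x')$
        x = u + self.attn(query=u, key=x_short, value=x_short, mask=mask)
        # $x = x + \text{FFN}(x)$
        return x + self.feed_forward(x)
```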