
RETRO model

This is the model definition for RETRO.

import math
from typing import Set

import torch
from torch import nn

from labml.logger import inspect

RoPE embeddings

We use rotary position embeddings in the self-attention layers. We assume the positional information gets embedded into the embeddings, so causal attention does not strictly need it, but non-causal self-attention needs explicit positional information because it cannot infer it.

class RotaryPositionalEmbeddings(nn.Module):

  • d is the number of features $d$
  • base is the constant used for calculating $\Theta$
    def __init__(self, d: int, base: int = 10_000):
        super().__init__()

$\Theta = \{\theta_i = 10000^{-\frac{2(i-1)}{d}} \mid i \in [1, 2, ..., \frac{d}{2}]\}$

        self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)

  • x is the Tensor at the head of a key or a query with shape [batch_size, seq_len, n_heads, d]
    def forward(self, x: torch.Tensor):

Extract the shape

        batch_size, seq_len, n_heads, d = x.shape

$\frac{d}{2}$

        d_2 = d // 2

Create position indexes [0, 1, ..., seq_len - 1]

        seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)

Calculate the product of position index and $\theta_i$

        idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)

Concatenate so that for row $m$ we have $[m\theta_0, m\theta_1, ..., m\theta_{\frac{d}{2}}, m\theta_0, m\theta_1, ..., m\theta_{\frac{d}{2}}]$

        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

Calculate $[-x^{(\frac{d}{2}+1)}, -x^{(\frac{d}{2}+2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$

        neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

Calculate

$\begin{pmatrix} x_m^{(i)} \cos m\theta_i - x_m^{(i+\frac{d}{2})} \sin m\theta_i \\ x_m^{(i+\frac{d}{2})} \cos m\theta_i + x_m^{(i)} \sin m\theta_i \end{pmatrix}$

for $i \in \{1, 2, ..., \frac{d}{2}\}$

        rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])

        return rx

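As a quick, illustrative sanity check (not part of the original implementation; the sizes below are arbitrary), the rotary embedding keeps the tensor shape and, since it only rotates pairs of features, preserves per-head norms:

# Minimal usage sketch (assumed sizes, for illustration only)
rope = RotaryPositionalEmbeddings(d=16)
x = torch.randn(2, 10, 4, 16)  # [batch_size, seq_len, n_heads, d]
rx = rope(x)
assert rx.shape == x.shape  # rotation mixes feature pairs but keeps the shape
assert torch.allclose(rx.norm(dim=-1), x.norm(dim=-1), atol=1e-5)  # rotations preserve norms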
Self-Attention Layer ATTN

This applies causal and non-causal multi-headed self-attention.

class SelfAttention(nn.Module):

  • d_model is the number of features in transformer embeddings
  • n_heads is the number of attention heads
  • d_k is the number of features per head
  • is_causal indicates whether this is causal attention (masked)

    def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
        super().__init__()

        self.is_causal = is_causal
        self.n_heads = n_heads
        self.d_k = d_k

To scale attentions before softmax by $\frac{1}{\sqrt{d_k}}$

        self.scale = 1 / math.sqrt(self.d_k)

Linear layers for query, key and value heads.

        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)

Pre-norm layer. The paper uses RMSNorm instead.

        self.norm = nn.LayerNorm(d_model)

Softmax for attention probabilities

        self.softmax = nn.Softmax(dim=-1)

Rotary positional embeddings

        self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)

Final linear layer

        self.output = nn.Linear(n_heads * d_k, d_model)

Mask the attention layer for causal attention

  • attn is the attention matrix of shape [batch_size, n_heads, seq_len, seq_len]

    def mask_attention(self, attn: torch.Tensor):

No masking for non-causal attention

        if not self.is_causal:
            return attn

Create a triangular mask

        mask = torch.tril(attn.new_ones(attn.shape[-2:]))

Filter by the mask

        return attn.masked_fill(mask == 0, float('-inf'))

  • h is the transformer embeddings of shape [batch_size, seq_len, d_model]

    def forward(self, h: torch.Tensor):

Residual connection

        h_res = h

Pre-normalization

        h = self.norm(h)

Get query, key, and values and split them into heads. These will have shapes [batch_size, seq_len, n_heads, d_k]

        mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
        q = self.query(h).view(mh_shape)
        k = self.key(h).view(mh_shape)
        v = self.value(h).view(mh_shape)

Apply rotary positional embeddings

        q = self.rotary_pe(q)
        k = self.rotary_pe(k)

Calculate attention scores

        attn = torch.einsum('bihd,bjhd->bhij', q, k)

Scale it by $\frac{1}{\sqrt{d_k}}$

        attn = attn * self.scale

Apply masks if it's causal attention

        attn = self.mask_attention(attn)

Calculate attention probabilities

        attn = self.softmax(attn)

Get values

        h = torch.einsum("bhij,bjhd->bihd", attn, v)

Change from shape [batch_size, seq_len, n_heads, d_k] to [batch_size, seq_len, n_heads * d_k]

        h = h.reshape(*h.shape[:-2], -1)

Apply final linear layer. The result will have shape [batch_size, seq_len, d_model]

        h = self.output(h)

Add the residual connection

        return h + h_res

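A minimal, illustrative usage sketch for the causal variant (not from the original file; sizes are arbitrary):

# Minimal usage sketch (assumed sizes, for illustration only)
attn_layer = SelfAttention(d_model=8, n_heads=2, d_k=4, is_causal=True)
h = torch.randn(3, 16, 8)  # [batch_size, seq_len, d_model]
out = attn_layer(h)
assert out.shape == h.shape  # residual connection keeps the shape
# With is_causal=True, position i only attends to positions j <= i because of the triangular mask.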
Cross-Attention Layer CA

This is similar to the self-attention layer defined above, except that it gets keys and values from a different set of embeddings than the queries.

This is used in the encoder to encode the retrieved chunks based on the input chunks.

We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.

class CrossAttention(nn.Module):

  • d_model is the number of features in transformer embeddings
  • n_heads is the number of attention heads
  • d_k is the number of features per head

    def __init__(self, d_model: int, n_heads: int, d_k: int):
        super().__init__()

        self.n_heads = n_heads
        self.d_k = d_k

To scale attentions before softmax by $\frac{1}{\sqrt{d_k}}$

        self.scale = 1 / math.sqrt(self.d_k)

Linear layers for query, key and value heads.

        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)

Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.

        self.norm = nn.LayerNorm(d_model)

Softmax for attention probabilities

        self.softmax = nn.Softmax(dim=-1)

Final linear layer

        self.output = nn.Linear(n_heads * d_k, d_model)

  • e are the retrieved nearest neighbor chunk embeddings with shape [batch_size, chunks, neighbors, neighbor_len, d_model]
  • h are the input chunks from which the nearest neighbors were retrieved, with shape [batch_size, chunks, chunk_len, d_model]. This is already normalized.

    def forward(self, e: torch.Tensor, h: torch.Tensor):

Residual connection

        e_res = e

Normalize retrieved chunks

        e = self.norm(e)

Get query from the retrieved chunks

        q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)

Get keys and values from the input chunks

        k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
        v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)

Calculate attention scores for all chunks. Each retrieved neighbor will pay attention to the original chunk that retrieved it. This will have shape [batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]

        attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)

Scale attention scores

        attn = attn * self.scale

Calculate softmax across the last dimension

        attn = self.softmax(attn)

Gather values

        e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)

Change from shape [batch_size, chunks, neighbors, neighbor_len, n_heads, d_k] to [batch_size, chunks, neighbors, neighbor_len, n_heads * d_k]

        e = e.reshape(*e.shape[:-2], -1)

Apply final linear layer. The result will have shape [batch_size, chunks, neighbors, neighbor_len, d_model]

        e = self.output(e)

Add residual connection

        return e + e_res

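A minimal, illustrative shape check (not from the original file; sizes are arbitrary). Note that the queries come from the retrieved neighbors e while the keys and values come from the input chunks h:

# Minimal usage sketch (assumed sizes, for illustration only)
ca = CrossAttention(d_model=8, n_heads=2, d_k=4)
e = torch.randn(3, 5, 2, 6, 8)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
h = torch.randn(3, 5, 4, 8)     # [batch_size, chunks, chunk_len, d_model], assumed already normalized
out = ca(e, h)
assert out.shape == e.shape  # residual connection keeps the neighbor embedding shape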
Chunked Cross-Attention Layer CCA

This is similar to the cross-attention layer defined above.

This is used in the decoder to pay attention to the retrieved neighbor chunks.

We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.

class ChunkedCrossAttention(nn.Module):

  • d_model is the number of features in transformer embeddings
  • n_heads is the number of attention heads
  • d_k is the number of features per head
  • chunk_len is the length of a chunk

    def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
        super().__init__()

        self.chunk_len = chunk_len
        self.n_heads = n_heads
        self.d_k = d_k

To scale attentions before softmax by $\frac{1}{\sqrt{d_k}}$

        self.scale = 1 / math.sqrt(self.d_k)

Linear layers for query, key and value heads.

        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)

Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.

        self.norm = nn.LayerNorm(d_model)

Softmax for attention probabilities

        self.softmax = nn.Softmax(dim=-1)

Final linear layer

        self.output = nn.Linear(n_heads * d_k, d_model)

  • h are the input embeddings of shape [batch_size, seq_len, d_model]
  • e are the retrieved nearest neighbors of shape [batch_size, chunks, neighbors, neighbor_len, d_model]

    def forward(self, h: torch.Tensor, e: torch.Tensor):

Get shape

        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

No attention if there are no chunks (for short inputs when sampling)

        if chunks == 0:
            return h

Residual connection

        h_res = h

Remove the first chunk_len - 1 embeddings. The input pays attention only to neighbors that were retrieved and encoded using past tokens, so that there is no information leakage. That is, the neighbors retrieved for the first chunk contain information from the first chunk itself, so by shifting the sequence to the left by chunk_len - 1 we make sure that information only flows to the right. (A worked example of this shift follows the class, below.)

        h = h[:, self.chunk_len - 1:]

Pre-norm

        h = self.norm(h)

Append empty embeddings to the end to be able to split the input into chunks

        if h.shape[1] < chunks * self.chunk_len:
            h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)

Reshape the input into chunks.

        h = h.reshape(batch_size, chunks, self.chunk_len, d_model)

Get query from the input

        q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)

Get keys and values from the retrieved neighbors

        k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
        v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)

Calculate attention scores for input chunks. Each chunk will pay attention to neighbors retrieved by the previous chunk. This will have shape [batch_size, chunks, heads, chunk_len, neighbors, neighbor_len]

        attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)

Scale attention scores

        attn = attn * self.scale

Apply softmax over the last two dimensions neighbors, neighbor_len

        attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)

Gather values

        h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)

Change from shape [batch_size, chunks, chunk_len, n_heads, d_k] to [batch_size, chunks * chunk_len, n_heads * d_k]

        h = h.reshape(batch_size, chunks * self.chunk_len, -1)

Apply final linear layer. The result will have shape [batch_size, chunks * chunk_len, d_model]

        h = self.output(h)

Prepend chunk_len - 1 zero embeddings; i.e. shift the sequence back to the right

        h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)

Truncate and add the residual connection

        return h[:, :h_res.shape[1]] + h_res

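A minimal, illustrative sketch of the shift discussed above (not from the original file; sizes are arbitrary): with chunk_len = 4, the first chunk_len - 1 positions are shifted out before attention and only receive the prepended zeros plus the residual, so no information from the retrieved neighbors can flow into them.

# Minimal usage sketch (assumed sizes, for illustration only)
cca = ChunkedCrossAttention(d_model=8, n_heads=2, d_k=4, chunk_len=4)
h = torch.randn(1, 10, 8)       # [batch_size, seq_len, d_model]
e = torch.randn(1, 2, 2, 6, 8)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
out = cca(h, e)
assert out.shape == h.shape
# The first chunk_len - 1 = 3 positions carry only the residual (no neighbor information).
assert torch.allclose(out[:, :3], h[:, :3])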
Position-wise Feed Forward Layer FFW

This consists of two linear layers and an activation in the middle.

class FeedForward(nn.Module):

  • d_model is the number of features in transformer embeddings
  • d_ff is the number of features in the hidden layer

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()

The two linear layers

        self.lin1 = nn.Linear(d_model, d_ff)
        self.lin2 = nn.Linear(d_ff, d_model)

ReLU Activation

        self.act = nn.ReLU()

Pre-norm layer

        self.norm = nn.LayerNorm(d_model)

  • h are the embeddings of shape [batch_size, seq_len, d_model]

    def forward(self, h: torch.Tensor):

Residual

        h_res = h

Pre-norm

        h = self.norm(h)

First linear layer

        h = self.lin1(h)

Activation

        h = self.act(h)

Second linear layer

        h = self.lin2(h)

Add the residual connection

        return h + h_res

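A one-line shape check (illustrative, not from the original file):

# Minimal usage sketch (assumed sizes, for illustration only)
ffw = FeedForward(d_model=8, d_ff=32)
assert ffw(torch.randn(3, 16, 8)).shape == (3, 16, 8)  # residual connection keeps the shape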
Nearest Neighbor Encoder $\text{ENCODER}(\text{RET}(C_u)_{1 \le u \le l}, H)$

This module encodes the retrieved nearest neighbors.

class NearestNeighborEncoder(nn.Module):

  • chunk_len is the length of a chunk
  • n_layers is the number of layers in the encoder $L_{enc}$
  • ca_layers are the layers with cross attention $P_{enc}$
  • d_model is the number of features in embeddings
  • n_heads is the number of heads in attention layers
  • d_k is the size of attention heads
  • d_ff is the size of the feed-forward networks hidden layers

    def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
                 d_model: int, n_heads: int, d_k: int, d_ff: int):

        super().__init__()
        self.ca_layers = ca_layers
        self.chunk_len = chunk_len

Cross-attention layers

        self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])

Bi-directional self attention layers

        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])

Feed forward layers

        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

Pre-normalization layer for $H$

        self.norm_h = nn.LayerNorm(d_model)

  • e are the token embeddings of the retrieved nearest neighbors, $\text{EMB}(\text{RET}(C_u)_{1 \le u \le l})$, of shape [batch_size, chunks, neighbors, neighbor_len, d_model]
  • h are the input token embeddings, $H$, of shape [batch_size, seq_len, d_model]

The chunks $u \in [1, l]$ and neighbors $j \in [1, k]$ are processed in parallel.

    def forward(self, e: torch.Tensor, h: torch.Tensor):

Get shape

        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

$(H_u)_{u \in [1, l]} \leftarrow \text{SPLIT}(H)$

        h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)

Pre-norm

        h_split = self.norm_h(h_split)

Keep the index of the cross attention layer

        p_ca = 0

For all layers $p' \in [1, L_{enc}]$

        for p in range(len(self.attn)):

Bi-directional self attention $E_u^j \leftarrow \text{ATTN}_{enc}(E_u^j)$

            e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)

Cross attention if $p' \in P_{enc}$

            if p in self.ca_layers:

$E_u^j \leftarrow \text{CA}_{enc}(E_u^j, H_u)$

                e = self.ca[p_ca](e, h_split)

Increment the cross attention index

                p_ca += 1

Feed forward layer $E_u^j \leftarrow \text{FFW}_{enc}(E_u^j)$

            e = self.ffw[p](e)

Return $E$

        return e

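A minimal, illustrative shape check (not from the original file; layer counts and sizes are arbitrary):

# Minimal usage sketch (assumed sizes, for illustration only)
nn_encoder = NearestNeighborEncoder(chunk_len=4, n_layers=2, ca_layers={1},
                                    d_model=8, n_heads=2, d_k=4, d_ff=32)
e = torch.randn(1, 2, 2, 6, 8)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
h = torch.randn(1, 10, 8)       # [batch_size, seq_len, d_model]
assert nn_encoder(e, h).shape == e.shape  # neighbors are encoded in place, shape unchanged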
Retro Model

This is the Retro decoder.

class RetroModel(nn.Module):

  • n_vocab is the number of tokens in the vocabulary
  • d_model is the number of features in embeddings
  • n_layers is the number of layers in the decoder $L$
  • ca_layers are the layers with cross attention $P$
  • chunk_len is the length of a chunk
  • n_heads is the number of heads in attention layers
  • d_k is the size of attention heads
  • d_ff is the size of the feed-forward networks hidden layers
  • encoder is the nearest neighbor encoder

    def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
                 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):

        super().__init__()

        self.ca_layers = ca_layers
        self.encoder = encoder

Token embedding layer

        self.emb = nn.Embedding(n_vocab, d_model)

Chunked cross attention layers CCA

        self.cca = nn.ModuleList(
            [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])

Attention layers ATTN

        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])

Feed forward layers FFW

        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

Readout layer READ

        self.read = nn.Linear(d_model, n_vocab)

Pre-normalization layer for nearest neighbor embeddings from $\text{ENCODER}(\text{RET}(C_u)_{1 \le u \le l}, H)$

        self.norm_e = nn.LayerNorm(d_model)

  • x is the input sequence, $X$, of shape [batch_size, seq_len]
  • ret are the retrieved neighbors $\text{RET}(C_u)_{1 \le u \le l}$ of shape [batch_size, chunks, neighbors, neighbor_len]

    def forward(self, x: torch.Tensor, ret: torch.Tensor):

Get input embeddings $H \leftarrow \text{EMB}(X)$

        h = self.emb(x)

Embeddings of the retrieved neighbors $E_u^j = \text{EMB}_{enc}(\text{RET}(C_u)^j)$.

We use the same embeddings for both the input and the neighbors.

        ret_emb = self.emb(ret)

Keep index of the chunked cross attention layer

        p_ca = 0

For all layers $p \in [1, L]$

        for p in range(len(self.attn)):

Causal self attention $H \leftarrow \text{ATTN}(H)$

            h = self.attn[p](h)

Get encoder embeddings before the first CCA layer, when $p = \min(P)$

            if self.ca_layers and p == min(self.ca_layers):

$E = \text{ENCODER}(\text{RET}(C_u)_{1 \le u \le l}, H)$

We pass the embeddings of $\text{RET}(C_u)_{1 \le u \le l}$ to the encoder.

                e = self.encoder(ret_emb, h)

Normalize encoder embeddings

                e = self.norm_e(e)

Chunked cross-attention if $p \in P$

            if p in self.ca_layers:

$H \leftarrow \text{CCA}(H, E)$

                h = self.cca[p_ca](h, e)

Increment the chunked cross-attention index

                p_ca += 1

$H \leftarrow \text{FFW}(H)$

            h = self.ffw[p](h)

$O \leftarrow \text{READ}(H)$

        return self.read(h)

Test the model with fake data

def _test():
    chunk_len = 4
    d_model = 8
    d_ff = 32
    n_heads = 2
    d_k = 4

    device = torch.device('cuda:0')

    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))

    m.to(device)
    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
    ret = [
        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    ]
    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))

    inspect(res)

if __name__ == '__main__':
    _test()

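The test builds a 6-layer decoder with chunked cross-attention at layers {2, 5} and a 2-layer neighbor encoder with cross attention at layer {1}. Since the readout layer maps d_model features to the 5-token vocabulary, res has shape [10, 10, 5] (batch, sequence, vocabulary). The test moves the model to cuda:0; switch the device to cpu to run it without a GPU.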