RETRO Model
This is the model definition for RETRO.
import math
from typing import Set

import torch
from torch import nn

from labml.logger import inspect
We use rotary position embeddings in the self-attention layers (both causal and non-causal). Non-causal self-attention in particular needs explicit positional information, because unlike causal attention it cannot infer token positions on its own.
class RotaryPositionalEmbeddings(nn.Module):
d is the number of features
base is the constant used for calculating $\Theta$

def __init__(self, d: int, base: int = 10_000):
super().__init__()
$$\Theta = \left\{\theta_i = 10000^{\frac{-2(i-1)}{d}},\ i \in \left[1, 2, ..., \tfrac{d}{2}\right]\right\}$$
self.theta = nn.Parameter(1. / (base ** (torch.arange(0, d, 2).float() / d)), requires_grad=False)
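As a quick check of the formula (a hypothetical example, not part of the original file): with d = 8 and the default base, the exponents $\frac{2(i-1)}{d}$ are 0, 1/4, 1/2 and 3/4, so $\theta$ = [1, 0.1, 0.01, 0.001].

# Sanity check of the theta values, assuming d = 8 and base = 10_000
d, base = 8, 10_000
theta = 1. / (base ** (torch.arange(0, d, 2).float() / d))
print(theta)  # tensor([1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03])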
x is the Tensor at the head of a key or a query with shape [batch_size, seq_len, n_heads, d]

def forward(self, x: torch.Tensor):
Extract the shape
batch_size, seq_len, n_heads, d = x.shape
$\frac{d}{2}$
d_2 = d // 2
Create position indexes [0, 1, ..., seq_len - 1]
seq_idx = torch.arange(seq_len, device=x.device).type_as(self.theta)
Calculate the product of position index and $\theta_i$
idx_theta = torch.einsum('n,d->nd', seq_idx, self.theta)
Concatenate so that for row $m$ we have $[m\theta_0, m\theta_1, ..., m\theta_{\frac{d}{2}}, m\theta_0, m\theta_1, ..., m\theta_{\frac{d}{2}}]$
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
Calculate $\big[-x^{(\frac{d}{2}+1)}, -x^{(\frac{d}{2}+2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}\big]$
neg_half_x = torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
Calculate

$$\begin{pmatrix} x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i \\ x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i \end{pmatrix}$$

for $i \in \{1, 2, ..., \frac{d}{2}\}$
rx = (x * idx_theta2.cos()[None, :, None, :]) + (neg_half_x * idx_theta2.sin()[None, :, None, :])
return rx
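A minimal usage sketch (hypothetical sizes, not from the original file): the module is shape-preserving, so it can be applied to query and key head tensors in place.

# Apply rotary embeddings to a random query head tensor, assuming d = 16
rotary = RotaryPositionalEmbeddings(d=16)
q = torch.randn(2, 10, 4, 16)  # [batch_size, seq_len, n_heads, d]
assert rotary(q).shape == q.shape  # RoPE rotates features but keeps the shape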
This applies causal and non-causal multi-headed self-attention.
class SelfAttention(nn.Module):
d_model is the number of features in transformer embeddings
n_heads is the number of attention heads
d_k is the number of features per head
is_causal indicates whether this is causal attention (masked)

def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
super().__init__()

self.is_causal = is_causal
self.n_heads = n_heads
self.d_k = d_k
To scale attention scores before the softmax by $\frac{1}{\sqrt{d_k}}$
self.scale = 1 / math.sqrt(self.d_k)
Linear layers for query, key and value heads.
self.query = nn.Linear(d_model, n_heads * d_k)
self.key = nn.Linear(d_model, n_heads * d_k)
self.value = nn.Linear(d_model, n_heads * d_k)
Pre-norm layer. The paper uses RMSNorm instead.
self.norm = nn.LayerNorm(d_model)
Softmax for attention probabilities
self.softmax = nn.Softmax(dim=-1)
Rotary positional embeddings
self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)
Final linear layer
self.output = nn.Linear(n_heads * d_k, d_model)
attn is the attention matrix of shape [batch_size, n_heads, seq_len, seq_len]

def mask_attention(self, attn: torch.Tensor):
No masking for non-causal attention
if not self.is_causal:
    return attn
Create a triangular mask
mask = torch.tril(attn.new_ones(attn.shape[-2:]))
Filter by the mask
return attn.masked_fill(mask == 0, float('-inf'))
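To illustrate the mask (a hypothetical example): for seq_len = 3, entries above the diagonal become -inf, so after the softmax each position only attends to itself and earlier positions.

# Causal mask applied to a 3x3 attention matrix of zero scores
attn = torch.zeros(1, 1, 3, 3)  # [batch_size, n_heads, seq_len, seq_len]
mask = torch.tril(attn.new_ones(attn.shape[-2:]))
print(attn.masked_fill(mask == 0, float('-inf'))[0, 0])
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])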
h is the transformer embeddings of shape [batch_size, seq_len, d_model]

def forward(self, h: torch.Tensor):
Residual connection
h_res = h
Pre-normalization
h = self.norm(h)
Get query, key, and values, and split them into heads. These will have shape [batch_size, seq_len, n_heads, d_k]
mh_shape = (*h.shape[:-1], self.n_heads, self.d_k)
q = self.query(h).view(mh_shape)
k = self.key(h).view(mh_shape)
v = self.value(h).view(mh_shape)
Apply rotary positional embeddings
q = self.rotary_pe(q)
k = self.rotary_pe(k)
Calculate attentions
attn = torch.einsum('bihd,bjhd->bhij', q, k)
Scale it by $\frac{1}{\sqrt{d_k}}$
attn = attn * self.scale
Apply masks if it's causal attention
attn = self.mask_attention(attn)
Calculate attention probabilities
attn = self.softmax(attn)
Get values
h = torch.einsum("bhij,bjhd->bihd", attn, v)
Change from shape [batch_size, seq_len, n_heads, d_k] to [batch_size, seq_len, n_heads * d_k]
h = h.reshape(*h.shape[:-2], -1)
Apply final linear layer. The result will have shape [batch_size, seq_len, d_model]
h = self.output(h)
Add the residual connection
return h + h_res
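A minimal forward-pass sketch with hypothetical sizes:

# Self-attention is a residual block, so the output shape equals the input shape
attn_layer = SelfAttention(d_model=8, n_heads=2, d_k=4, is_causal=True)
h = torch.randn(2, 10, 8)  # [batch_size, seq_len, d_model]
assert attn_layer(h).shape == (2, 10, 8)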
This is similar to the self-attention layer defined above, except that it gets keys and values from a different set of embeddings than the queries.
This is used in the encoder to encode the retrieved chunks based on the input chunks.
We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.
class CrossAttention(nn.Module):
d_model is the number of features in transformer embeddings
n_heads is the number of attention heads
d_k is the number of features per head

def __init__(self, d_model: int, n_heads: int, d_k: int):
super().__init__()

self.n_heads = n_heads
self.d_k = d_k
To scale attention scores before the softmax by $\frac{1}{\sqrt{d_k}}$
self.scale = 1 / math.sqrt(self.d_k)
Linear layers for query, key and value heads.
self.query = nn.Linear(d_model, n_heads * d_k)
self.key = nn.Linear(d_model, n_heads * d_k)
self.value = nn.Linear(d_model, n_heads * d_k)
Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.
self.norm = nn.LayerNorm(d_model)
Softmax for attention probabilities
self.softmax = nn.Softmax(dim=-1)
Final linear layer
self.output = nn.Linear(n_heads * d_k, d_model)
e are the retrieved nearest neighbor chunk embeddings with shape [batch_size, chunks, neighbors, neighbor_len, d_model]
h are the input chunks from which the nearest neighbors were retrieved, with shape [batch_size, chunks, chunk_len, d_model]. This is already normalized.

def forward(self, e: torch.Tensor, h: torch.Tensor):
Residual connection
e_res = e
Normalize retrieved chunks
e = self.norm(e)
Get query from the retrieved chunks
q = self.query(e).view(*e.shape[:-1], self.n_heads, self.d_k)
Get keys and values from the input chunks
k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)
Calculate attention scores for all chunks. Each retrieved neighbor will pay attention to the original chunk that retrieved it. This will have shape [batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]
attn = torch.einsum('bcnihd,bcjhd->bcnhij', q, k)
Scale attention scores
attn = attn * self.scale
Calculate softmax across the last dimension
attn = self.softmax(attn)
Gather values
e = torch.einsum("bcnhij,bcjhd->bcnihd", attn, v)
Change from shape [batch_size, chunks, neighbors, neighbor_len, n_heads, d_k] to [batch_size, chunks, neighbors, neighbor_len, n_heads * d_k]
e = e.reshape(*e.shape[:-2], -1)
Apply final linear layer. The result will have shape [batch_size, chunks, neighbors, neighbor_len, d_model]
e = self.output(e)
Add residual connection
return e + e_res
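A shape sketch with hypothetical sizes: each of the retrieved neighbors attends to the chunk that retrieved it, and the output has the same shape as the retrieved-neighbor embeddings.

# Cross-attention between retrieved neighbors and the chunks that retrieved them
ca = CrossAttention(d_model=8, n_heads=2, d_k=4)
e = torch.randn(2, 3, 2, 6, 8)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
h = torch.randn(2, 3, 4, 8)     # [batch_size, chunks, chunk_len, d_model], already normalized
assert ca(e, h).shape == e.shape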
This is similar to the cross-attention layer defined above.
This is used in the decoder to pay attention to the retrieved neighbor chunks.
We do not use any explicit positional embeddings here. We assume that the model can represent positional information in the embeddings implicitly.
class ChunkedCrossAttention(nn.Module):
d_model is the number of features in transformer embeddings
n_heads is the number of attention heads
d_k is the number of features per head
chunk_len is the length of a chunk

def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
super().__init__()

self.chunk_len = chunk_len
self.n_heads = n_heads
self.d_k = d_k
To scale attention scores before the softmax by $\frac{1}{\sqrt{d_k}}$
self.scale = 1 / math.sqrt(self.d_k)
Linear layers for query, key and value heads.
self.query = nn.Linear(d_model, n_heads * d_k)
self.key = nn.Linear(d_model, n_heads * d_k)
self.value = nn.Linear(d_model, n_heads * d_k)
Pre-norm layer for the query embeddings. The paper uses RMSNorm instead.
self.norm = nn.LayerNorm(d_model)
Softmax for attention probabilities
self.softmax = nn.Softmax(dim=-1)
Final linear layer
self.output = nn.Linear(n_heads * d_k, d_model)
h are the input embeddings of shape [batch_size, seq_len, d_model]
e are the retrieved nearest neighbors of shape [batch_size, chunks, neighbors, neighbor_len, d_model]
def forward(self, h: torch.Tensor, e: torch.Tensor):
Get shape
batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
No attention if there are no chunks (for short inputs when sampling)
if chunks == 0:
    return h
Residual connection
h_res = h
Remove the first chunk_len - 1 embeddings. The input should only pay attention to neighbors retrieved and encoded using past tokens, so that there is no information leakage. That is, the neighbors retrieved for the first chunk contain information from the first chunk itself. By shifting the sequence to the left by chunk_len - 1 we make sure that information only flows to the right.
h = h[:, self.chunk_len - 1:]
Pre-norm
h = self.norm(h)
Append empty embeddings to the end to be able to split the input into chunks
if h.shape[1] < chunks * self.chunk_len:
    h = torch.cat((h, h.new_zeros(batch_size, chunks * self.chunk_len - h.shape[1], d_model)), dim=1)
Reshape the input into chunks.
h = h.reshape(batch_size, chunks, self.chunk_len, d_model)
Get query from the input
q = self.query(h).view(*h.shape[:-1], self.n_heads, self.d_k)
Get keys and values from the retrieved neighbors
k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)
Calculate attention scores for input chunks. Each chunk will pay attention to neighbors retrieved by the previous chunk. This will have shape [batch_size, chunks, heads, chunk_len, neighbors, neighbor_len]
attn = torch.einsum('bcihd,bcnjhd->bchinj', q, k)
Scale attention scores
attn = attn * self.scale
Apply softmax over the last two dimensions neighbors, neighbor_len
attn = self.softmax(attn.view(*attn.shape[:-2], -1)).view(attn.shape)
Gather values
h = torch.einsum("bchinj,bcnjhd->bcihd", attn, v)
Change from shape [batch_size, chunks, chunk_len, n_heads, d_k] to [batch_size, chunks * chunk_len, n_heads * d_k]
h = h.reshape(batch_size, chunks * self.chunk_len, -1)
Apply final linear layer. The result will have shape [batch_size, chunks * chunk_len, d_model]
h = self.output(h)
Prepend chunk_len - 1 zero embeddings; i.e. shift the sequence back to the right
h = torch.cat((h.new_zeros(batch_size, self.chunk_len - 1, d_model), h), dim=1)
Truncate and add the residual connection
return h[:, :h_res.shape[1]] + h_res
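A shape sketch with hypothetical sizes: with seq_len = 10 and chunk_len = 4 there are l = 2 chunks, so the sequence is shifted left by chunk_len - 1 = 3 positions, padded from 7 to 8, split into 2 chunks of 4, attended over the neighbors, and shifted back to length 10.

# Chunked cross-attention round trip preserves the input shape
cca = ChunkedCrossAttention(d_model=8, n_heads=2, d_k=4, chunk_len=4)
h = torch.randn(2, 10, 8)       # [batch_size, seq_len, d_model]
e = torch.randn(2, 2, 2, 6, 8)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
assert cca(h, e).shape == h.shape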
This consists of two linear layers and an activation in the middle.
382classFeedForward(nn.Module):
d_model is the number of features in transformer embeddings
d_ff is the number of features in the hidden layer

def __init__(self, d_model: int, d_ff: int):
super().__init__()
The two linear layers
self.lin1 = nn.Linear(d_model, d_ff)
self.lin2 = nn.Linear(d_ff, d_model)
ReLU Activation
self.act = nn.ReLU()
Pre-norm layer
self.norm = nn.LayerNorm(d_model)
h are the embeddings of shape [batch_size, seq_len, d_model]
def forward(self, h: torch.Tensor):
Residual
h_res = h
Pre-norm
h = self.norm(h)
First linear layer
h = self.lin1(h)
Activation
h = self.act(h)
Second linear layer
h = self.lin2(h)
Add the residual connection
return h + h_res
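A tiny usage sketch with hypothetical sizes:

# The feed-forward block is also residual, so it preserves the input shape
ffw = FeedForward(d_model=8, d_ff=32)
h = torch.randn(2, 10, 8)
assert ffw(h).shape == h.shape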
This module encodes the retrieved nearest neighbors
class NearestNeighborEncoder(nn.Module):
chunk_len is the length of a chunk
n_layers is the number of layers in the encoder $L_{enc}$
ca_layers are the layers with cross-attention $P_{enc}$
d_model is the number of features in embeddings
n_heads is the number of heads in attention layers
d_k is the size of attention heads
d_ff is the size of the feed-forward networks hidden layers

def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
             d_model: int, n_heads: int, d_k: int, d_ff: int):
super().__init__()
self.ca_layers = ca_layers
self.chunk_len = chunk_len
Cross-attention layers
self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])
Bi-directional self attention layers
self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])
Feed forward layers
self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
Pre-normalization layer for H
self.norm_h = nn.LayerNorm(d_model)
e are the token embeddings of the retrieved nearest neighbors, $\text{Emb}(\text{Ret}(C_u)_{1 \le u \le l})$, of shape [batch_size, chunks, neighbors, neighbor_len, d_model]
h are the input token embeddings, $H$, of shape [batch_size, seq_len, d_model]
The chunks $u \in [1, l]$ and neighbors $j \in [1, k]$ are processed in parallel.
def forward(self, e: torch.Tensor, h: torch.Tensor):
Get shape
batch_size, chunks, neighbors, neighbor_len, d_model = e.shape
$(H_u)_{u \in [1, l]} \leftarrow \text{Split}(H)$
h_split = h[:, :self.chunk_len * chunks, :].reshape(batch_size, chunks, self.chunk_len, d_model)
Pre-norm
h_split = self.norm_h(h_split)
Keep the index of the cross attention layer
p_ca = 0
For all layers $p' \in [1, L_{enc}]$
for p in range(len(self.attn)):
Bi-directional self-attention $E^j_u \leftarrow \text{Attn}_{enc}(E^j_u)$
    e = self.attn[p](e.view(-1, neighbor_len, d_model)).view(e.shape)
Cross-attention if $p' \in P_{enc}$
    if p in self.ca_layers:
$E^j_u \leftarrow \text{Ca}_{enc}(E^j_u, H_u)$
        e = self.ca[p_ca](e, h_split)
Increment the cross-attention index
        p_ca += 1
Feed forward layer $E^j_u \leftarrow \text{Ffw}_{enc}(E^j_u)$
    e = self.ffw[p](e)
Return $E$
return e
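A shape sketch with hypothetical sizes matching the cross-attention example above: the encoder consumes input embeddings with seq_len >= chunks * chunk_len and returns encoded neighbors of the same shape as e.

# Encode two neighbors per chunk for two chunks, assuming chunk_len = 4
enc = NearestNeighborEncoder(chunk_len=4, n_layers=2, ca_layers={1},
                             d_model=8, n_heads=2, d_k=4, d_ff=32)
e = torch.randn(2, 2, 2, 6, 8)  # [batch_size, chunks, neighbors, neighbor_len, d_model]
h = torch.randn(2, 10, 8)       # [batch_size, seq_len, d_model]
assert enc(e, h).shape == e.shape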
This is the Retro decoder
class RetroModel(nn.Module):
n_vocab is the number of tokens in the vocabulary
d_model is the number of features in embeddings
n_layers is the number of layers in the decoder $L$
ca_layers are the layers with cross attention $P$
chunk_len is the length of a chunk
n_heads is the number of heads in attention layers
d_k is the size of attention heads
d_ff is the size of the feed-forward networks hidden layers
encoder is the nearest neighbor encoder

def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
             n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
super().__init__()

self.ca_layers = ca_layers
self.encoder = encoder
Token embedding layer
self.emb = nn.Embedding(n_vocab, d_model)
Chunked cross attention layers CCA
self.cca = nn.ModuleList(
    [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])
Attention layers ATTN
self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])
Feed forward layers FFW
self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
Readout layer READ
self.read = nn.Linear(d_model, n_vocab)
Pre-normalization layer for nearest neighbor embeddings from $\text{Encoder}(\text{Ret}(C_u)_{1 \le u \le l}, H)$
self.norm_e = nn.LayerNorm(d_model)
x is the input sequence, $X$, of shape [batch_size, seq_len]
ret are the retrieved neighbors $\text{Ret}(C_u)_{1 \le u \le l}$ of shape [batch_size, chunks, neighbors, neighbor_len]

def forward(self, x: torch.Tensor, ret: torch.Tensor):
Get input embeddings $H \leftarrow \text{Emb}(X)$
h = self.emb(x)
Embeddings of the retrieved neighbors $E^j_u = \text{Emb}_{enc}(\text{Ret}(C_u)^j)$.
We use the same embeddings for both the input and the neighbors.
ret_emb = self.emb(ret)
Keep index of the chunked cross attention layer
p_ca = 0
For all layers $p \in [1, L]$
for p in range(len(self.attn)):
Causal self-attention $H \leftarrow \text{Attn}(H)$
    h = self.attn[p](h)
Get encoder embeddings before the first CCA layer, when $p = \min(P)$
    if self.ca_layers and p == min(self.ca_layers):
$E = \text{Encoder}(\text{Ret}(C_u)_{1 \le u \le l}, H)$
We pass the embeddings of $\text{Ret}(C_u)_{1 \le u \le l}$ to the encoder.
        e = self.encoder(ret_emb, h)
Normalize encoder embeddings
        e = self.norm_e(e)
Chunked cross-attention if $p \in P$
    if p in self.ca_layers:
$H \leftarrow \text{Cca}(H, E)$
        h = self.cca[p_ca](h, e)
Increment chunked cross-attention index
        p_ca += 1
$H \leftarrow \text{Ffw}(H)$
    h = self.ffw[p](h)
$O \leftarrow \text{Read}(H)$
return self.read(h)
Test the model

def _test():
chunk_len = 4
d_model = 8
d_ff = 32
n_heads = 2
d_k = 4

device = torch.device('cuda:0')

m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
               encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))

m.to(device)
x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
ret = [
    [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
]
res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))

inspect(res)
if __name__ == '__main__':
    _test()