
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/feedback/__init__.py)

# Feedback Transformer

This is a PyTorch implementation of the paper [Accessing Higher-level Representations in Sequential Transformers with Feedback Memory](https://arxiv.org/abs/2002.09402).

Normal transformers process tokens in parallel, and each transformer layer pays attention to the outputs of the previous layer. The feedback transformer instead pays attention to the outputs of all layers in previous steps. This adds recurrence, so tokens must be processed one by one, which slows down training significantly (about 5X to 10X depending on the sequence length). However, the feedback transformer is faster at prediction, because the next token can be predicted from cached memory vectors.

In order to speed up the training, the paper discusses starting with a short sequence length and gradually increasing it. They also discuss using a pretrained parallel transformer as the starting point.

The original feedback transformer doesn't keep the outputs of all layers. Instead it keeps a weighted sum of the outputs of all layers. This reduces the memory used for caching during prediction. The first half of this file implements this.
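
As a concrete illustration, here is a minimal sketch of that weighted sum (the sizes are illustrative; the real computation appears in `FeedbackTransformer.forward` below):

```python
import torch

# Illustrative sizes (not from the paper)
n_layers, batch_size, d_model = 4, 2, 16

# Outputs of the embedding plus each of the n_layers layers, for one step
layer_outputs = torch.stack([torch.randn(batch_size, d_model) for _ in range(n_layers + 1)])

# One learnable scalar weight per layer output, normalized with a softmax
weights = torch.ones(n_layers + 1).softmax(dim=0)

# The memory vector for this step is the weighted sum over the layer dimension
memory = torch.einsum('lbd,l->bd', layer_outputs, weights)
assert memory.shape == (batch_size, d_model)
```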

The updated feedback transformer shares the weights $W^l_k$ and $W^l_v$, used to calculate keys and values, among the layers. The keys and values for each step then need to be calculated only once, and they are kept cached. The second half of this file implements this. We implemented a custom PyTorch function to improve performance.
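
Here is a rough sketch of the shared key/value idea, where `w_key` and `w_value` are hypothetical stand-ins for the shared projections $W^l_k$ and $W^l_v$:

```python
import torch
from torch import nn

# Illustrative sizes; w_key and w_value stand in for the shared projections
d_model, batch_size = 16, 2
w_key = nn.Linear(d_model, d_model, bias=False)
w_value = nn.Linear(d_model, d_model, bias=False)

cached_keys, cached_values = [], []
for step in range(4):
    memory = torch.randn(batch_size, d_model)  # memory vector for this step
    # Because the projections are shared by all layers, each step's key and
    # value are computed once and cached, instead of once per layer
    cached_keys.append(w_key(memory))
    cached_values.append(w_value(memory))
```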

Here's the training code and a notebook for training a feedback transformer on the Tiny Shakespeare dataset.

```python
import math
from typing import Optional

import torch
from torch import nn

from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import PrepareForMultiHeadAttention
from labml_nn.utils import clone_module_list
```

## Feedback Attention

This module computes recurrent attention, similar to the attention in the original transformer paper.

$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q^\top K}{\sqrt{d_k}}\Bigg)V$$

```python
class FeedbackAttention(nn.Module):
```

- `heads` is the number of attention heads
- `d_model` is the number of features in the transformer
- `dropout_prob` is the attention dropout probability
- `is_kv_precomputed` is whether the key and value tensors are already calculated

```python
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, *,
                 is_kv_precomputed: bool = False):
```

```python
        super().__init__()
```

Number of features per head

```python
        self.d_k = d_model // heads
```

```python
        self.heads = heads
```

These transform the `query` for multi-head attention.

```python
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
```

These transform the `key` and `value` for multi-head attention.

```python
        if not is_kv_precomputed:
            self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=False)
            self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
```

Keys and values are already calculated

```python
        else:
            self.key = None
            self.value = None
```

Output layer

```python
        self.output = nn.Linear(d_model, d_model)
```

Dropout

```python
        self.dropout = nn.Dropout(dropout_prob)
```

Scaling factor before the softmax

```python
        self.scale = 1 / math.sqrt(self.d_k)
```

Softmax for attention along the time dimension of `key`

```python
        self.softmax = nn.Softmax(dim=0)
```

Number of relative positions

```python
        self.P = 2 ** 12
```

Relative positional embeddings for key relative to the query.

```python
        self.key_pos_embeddings = nn.Parameter(torch.zeros((self.P, heads, self.d_k)), requires_grad=True)
```

Relative positional embedding bias for key relative to the query.

```python
        self.key_pos_bias = nn.Parameter(torch.zeros((self.P, heads)), requires_grad=True)
```

The positional embedding bias for the query is independent of the position of the query.

```python
        self.query_pos_bias = nn.Parameter(torch.zeros((heads, self.d_k)), requires_grad=True)
```

We store the attention weights so that they can be used for logging or other computations if needed.

```python
        self.attn = None
```

### Get attention scores

We use relative positional encodings for attention, similar to the relative multi-head attention from the Transformer-XL paper.

The attention from the current step's query to the key in step $j$ (relative to the current step) is,

$$\begin{aligned}
A_{j} &= Q^\top K_j \\
      &= lin_q(X^q + P_q)^\top lin_k(X^k_j + P_j) \\
      &= (Q + U^Q)^\top(K_j + U^K_j) \\
      &= \underset{A}{Q^\top K_j} + \underset{B}{Q^\top U^K_j} + \underset{C}{{U^Q}^\top K_j} + \underset{D}{{U^Q}^\top U^K_j}
\end{aligned}$$

where $Q$, $K_j$ are linear transformations of the original embeddings $X^q$, $X^k_j$, and $U^Q$, $U^K_j$ are linear transformations of the positional encodings $P_q$, $P_j$.

We replace term $D$ with $S_j$.

```python
    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
```

$U^K_j$

```python
        key_pos_emb = self.key_pos_embeddings[-key.shape[0]:]
```

$U^Q$

```python
        query_pos_bias = self.query_pos_bias[None, :, :]
```

$S_j$

```python
        key_pos_bias = self.key_pos_bias[-key.shape[0]:]
```

$\underset{A}{Q^\top K_j} + \underset{C}{{U^Q}^\top K_j}$

```python
        ac = torch.einsum('bhd,jbhd->jbh', query + query_pos_bias, key)
```

$\underset{B}{Q^\top U^K_j} + \underset{D}{S_j}$

```python
        bd = torch.einsum('bhd,jhd->jbh', query, key_pos_emb) + key_pos_bias[:, None, :]
```

$A_j$

```python
        return ac + bd
```

- `query` has shape `[batch_size, d_model]`
- `key` and `value` have shape `[seq_len, batch_size, d_model]`

```python
    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor):
```

Prepare `query`, `key` and `value` for the attention computation. `key` and `value` will then have shape `[seq_len, batch_size, heads, d_k]` and `query` will have shape `[batch_size, heads, d_k]`.

```python
        query = self.query(query)
        if self.key:
            key = self.key(key)
        if self.value:
            value = self.value(value)
```

Compute attention scores. This results in a tensor of shape `[seq_len, batch_size, heads]`.

```python
        scores = self.get_scores(query, key)
```

Scale scores by $\frac{1}{\sqrt{d_k}}$

```python
        scores *= self.scale
```

Softmax

```python
        attn = self.softmax(scores)
```

Apply dropout

```python
        attn = self.dropout(attn)
```

Multiply by the values

```python
        x = torch.einsum("jbh,jbhd->bhd", attn, value)
```

Concatenate multiple heads

```python
        x = x.reshape(x.shape[0], -1)
```

Output layer

```python
        return self.output(x)
```
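
If you want to try the module on its own, a minimal usage sketch might look like this (the shapes follow the docstrings above; the sizes are arbitrary):

```python
import torch
from labml_nn.transformers.feedback import FeedbackAttention

heads, d_model = 4, 64
attn = FeedbackAttention(heads, d_model)

seq_len, batch_size = 10, 2
query = torch.randn(batch_size, d_model)            # current step
memory = torch.randn(seq_len, batch_size, d_model)  # memory from previous steps
out = attn(query=query, key=memory, value=memory)
assert out.shape == (batch_size, d_model)
```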

## Feedback Transformer Layer

This implements a single transformer layer in the feedback transformer.

```python
class FeedbackTransformerLayer(nn.Module):
```

- `d_model` is the number of features in the transformer
- `attn` is the feedback attention module
- `feed_forward` is the position-wise feed forward layer
- `dropout_prob` is the dropout probability for the dropout layers after attention and feed-forward

```python
    def __init__(self, *,
                 d_model: int,
                 attn: FeedbackAttention,
                 feed_forward: FeedForward,
                 dropout_prob: float):
```

```python
        super().__init__()
```

Transformer size $d_{model}$

```python
        self.size = d_model
```

```python
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
```

Normalization layers

```python
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
```

```python
    def forward(self, *,
                x: torch.Tensor,
                key: Optional[torch.Tensor],
                value: Optional[torch.Tensor]):
```

If there is memory

```python
        if key is not None:
```

Normalize the vectors before doing self attention

```python
            z = self.norm_self_attn(x)
```

Run through self attention, i.e. keys and values are from self

```python
            self_attn = self.attn(query=z, key=key, value=value)
```

Add the self attention results

```python
            x = x + self.dropout(self_attn)
```

Normalize for feed-forward

```python
        z = self.norm_ff(x)
```

Pass through the feed-forward network

```python
        ff = self.feed_forward(z)
```

Add the feed-forward results back

```python
        x = x + self.dropout(ff)
```

```python
        return x
```

## Feedback Transformer Module

```python
class FeedbackTransformer(nn.Module):
```

- `layer` is the feedback transformer layer, which we clone for each layer
- `n_layers` is the number of layers in the transformer

```python
    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int):
```

```python
        super().__init__()
```

Make copies of the transformer layer

```python
        self.layers = clone_module_list(layer, n_layers)
```

Final normalization layer

```python
        self.norm = nn.LayerNorm([layer.size])
```

Memory vectors are computed as a weighted sum of representations of each layer. This is the `weights` parameter for that.

```python
        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
```

Softmax for weights before taking the weighted sum

```python
        self.softmax = nn.Softmax(0)
```

- `x_seq` is the input with shape `[seq_len, batch_size, d_model]`

```python
    def forward(self, x_seq: torch.Tensor):
```

Split the input to a list along the sequence axis

```python
        x_seq = torch.unbind(x_seq, dim=0)
```

List to store the outputs

```python
        res = []
```

List to store the memory vectors

```python
        mem = []
```

For each input step

```python
        for x in x_seq:
```

List to store layer outputs

```python
            layer_outputs = [x]
```

If there is memory, stack the memory vectors into a tensor

```python
            mem_tensor = torch.stack(mem) if mem else None
```

Run through each layer

```python
            for layer in self.layers:
```

Get layer output

```python
                x = layer(x=x, key=mem_tensor, value=mem_tensor)
```

Append it to the list of layer outputs

```python
                layer_outputs.append(x)
```

Stack the layer outputs to a tensor

```python
            layer_outputs = torch.stack(layer_outputs)
```

Calculate the memory vector as a weighted sum of layer outputs

```python
            mem.append(torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights)))
```

Append the output to results

```python
            res.append(x)
```

Stack the output tensors

```python
        res = torch.stack(res)
```

Normalize the output

```python
        return self.norm(res)
```
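
Putting the pieces together, a minimal usage sketch of the module (hyper-parameters are illustrative; `FeedForward` takes the model size and the hidden size of the feed-forward network):

```python
import torch
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.feedback import (FeedbackAttention, FeedbackTransformer,
                                            FeedbackTransformerLayer)

d_model, heads, n_layers = 64, 4, 2
layer = FeedbackTransformerLayer(d_model=d_model,
                                 attn=FeedbackAttention(heads, d_model),
                                 feed_forward=FeedForward(d_model, d_ff=256),
                                 dropout_prob=0.1)
model = FeedbackTransformer(layer, n_layers)

x = torch.randn(10, 2, d_model)  # [seq_len, batch_size, d_model]
out = model(x)
assert out.shape == x.shape
```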

## Shared keys and values among layers

### Stack Function implementation

We implement a custom function instead of appending to a Python list and then doing `torch.stack`. This greatly improves performance over calling `torch.stack` at each step along the sequence. Every time `torch.stack` is called, it creates a new tensor, while this method and the accompanying class `Stack` share memory across steps.

```python
class StackFunction(torch.autograd.Function):
```

- `ctx` is the context of the function (which lets us cache stuff)
- `memory` is the shared memory tensor where we stack and store the values of each step (keys & values)
- `memory_grad` is the shared memory tensor to store and accumulate gradients of each step
- `last` is the last value stacked
- `n` is the number of steps (i.e. the size of the stack)

This returns the stacked tensor for steps up to `n`.

```python
    @staticmethod
    def forward(ctx, memory, memory_grad, last, n):
```

Cache accumulated gradients

```python
        ctx._mem_grad = memory_grad
```

Cache the size of the stack

```python
        ctx._n = n
```

Return the stack

```python
        return memory[:n + 1]
```

- `grad_output` is the gradient with respect to the output of the above `forward` function

This accumulates the gradients in the shared memory tensor and returns the gradients with respect to the last result in the stack.

```python
    @staticmethod
    def backward(ctx, grad_output):
```

Get the current size of the stack

```python
        n = ctx._n
```

Get the accumulated gradients

```python
        memory_grad = ctx._mem_grad
```

Add the gradients

```python
        memory_grad[:n + 1] += grad_output
```

Return the gradients w.r.t. the last value in the stack

```python
        return None, None, memory_grad[n], None
```

### Stack Module

This uses the stack function defined above, and does the necessary initializations.

```python
class Stack:
```

- `max_len` is the maximum size of the stack

```python
    def __init__(self, max_len: int):
```

```python
        self.max_len = max_len
        self.memory = None
        self.memory_grad = None
        self.last = None
        self.n = -1
        self.last_get_n = -1
```

- `n` is the size of the stack
- `value` is the tensor that needs to be added to the stack

```python
    def append(self, n: int, value: torch.Tensor):
```

You need to get (use) the stack after adding a value. Otherwise this implementation fails.

```python
        assert n == 0 or self.last_get_n == n - 1, f"{n}, {self.last_get_n}"
```

Do this without gradients

```python
        with torch.no_grad():
```

Initialize the shared memory tensor to keep the stack

```python
            if self.memory is None or self.memory.shape[1:] != value.shape:
```

This should only happen when the stack is empty

```python
                assert n == 0
```

Create a tensor for the stack

```python
                self.memory = value.new_zeros(self.max_len, *value.shape, requires_grad=False)
```

Create a tensor to accumulate the gradients

```python
                self.memory_grad = value.new_zeros(self.memory.shape, requires_grad=False)
```

The memory is already initialized but we are resetting the stack.

This could have been another function like `reset`, but we found this easier to use.

```python
            elif n == 0:
```

Reset accumulated gradients

```python
                self.memory_grad.fill_(0.)
```

Set the value in the correct position of the stack

```python
            self.memory.data[n] = value.detach()
```

Keep track of the stack (for debugging)

```python
            self.n = n
```

Keep track of the last value added to the stack. We need this to be passed on to `StackFunction` in order to get the gradients propagated backwards.

```python
        self.last = value
```

Returns the stack

```python
    def get(self):
```

Keep track of the size of the stack when it was used. This is used for a sanity check in `append`.

```python
        self.last_get_n = self.n
```

Take it all through `StackFunction` so that `StackFunction.backward` is called by PyTorch during backpropagation.

```python
        return StackFunction.apply(self.memory, self.memory_grad, self.last, self.n)
```

To release memory

```python
    def free(self):
```

```python
        self.memory = None
        self.memory_grad = None
        self.last = None
```
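
A short sketch of the append/get protocol that the assertion in `append` enforces (sizes are illustrative):

```python
import torch
from labml_nn.transformers.feedback import Stack

stack = Stack(max_len=8)
for step in range(4):
    value = torch.randn(2, 16)  # e.g. the keys computed for this step
    stack.append(step, value)
    stacked = stack.get()       # must call get() before the next append
assert stacked.shape == (4, 2, 16)
stack.free()                    # release the shared memory tensors
```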

### Updated Feedback Transformer Module

This is the updated feedback transformer module that caches the keys and values.

```python
class FeedbackTransformerKV(nn.Module):
```

- `layer` is the feedback transformer layer, which we clone for each layer
- `n_layers` is the number of layers in the transformer
- `d_model` is the number of features in the transformer
- `heads` is the number of attention heads

```python
    def __init__(self, layer: FeedbackTransformerLayer, n_layers: int, d_model: int, heads: int):
```

```python
        super().__init__()
```

Make copies of the transformer layer

```python
        self.layers = clone_module_list(layer, n_layers)
```

Final normalization layer

```python
        self.norm = nn.LayerNorm([layer.size])
```

Memory vectors are computed as a weighted sum of representations of each layer. This is the `weights` parameter for that.

```python
        self.weights = nn.Parameter(torch.ones(n_layers + 1), requires_grad=True)
```

Softmax for weights before taking the weighted sum

```python
        self.softmax = nn.Softmax(0)
```

Number of features in a head

```python
        d_k = d_model // heads
```

Module to transform embeddings (memory) to get keys

```python
        self.key = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
```

Module to transform embeddings (memory) to get values

```python
        self.value = PrepareForMultiHeadAttention(d_model, heads, d_k, bias=False)
```

Memory for stacked keys

```python
        self.mem_key = Stack(512)
```

Memory for stacked values

```python
        self.mem_value = Stack(512)
```

- `x_seq` is the input with shape `[seq_len, batch_size, d_model]`

```python
    def forward(self, x_seq: torch.Tensor):
```

Split the input to a list along the sequence axis

```python
        x_seq = torch.unbind(x_seq, dim=0)
```

List to store the outputs

```python
        res = []
```

For each input step

```python
        for step, x in enumerate(x_seq):
```

List to store layer outputs

```python
            layer_outputs = [x]
```

Stack of keys and values

```python
            key_tensor = None
            value_tensor = None
```

Get the key and value tensors if we are beyond the initial step

```python
            if step > 0:
                key_tensor = self.mem_key.get()
                value_tensor = self.mem_value.get()
```

Run through each layer

```python
            for layer in self.layers:
```

Get layer output

```python
                x = layer(x=x, key=key_tensor, value=value_tensor)
```

Append it to the list of layer outputs

```python
                layer_outputs.append(x)
```

Stack the layer outputs to a tensor

```python
            layer_outputs = torch.stack(layer_outputs)
```

Calculate the memory vector as a weighted sum of layer outputs

```python
            mem = torch.einsum('lbd,l->bd', layer_outputs, self.softmax(self.weights))
```

Calculate the keys from the memory and add them to the stack

```python
            self.mem_key.append(step, self.key(mem))
```

Calculate the values from the memory and add them to the stack

```python
            self.mem_value.append(step, self.value(mem))
```

Append the output to results

```python
            res.append(x)
```

Stack the output tensors

```python
        res = torch.stack(res)
```

Normalize the output

```python
        return self.norm(res)
```

```python
    def free(self):
        self.mem_key.free()
        self.mem_value.free()
```
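
Finally, a usage sketch for this variant. Note that the attention module is constructed with `is_kv_precomputed=True`, since keys and values are computed once outside the layers (hyper-parameters are illustrative):

```python
import torch
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.feedback import (FeedbackAttention, FeedbackTransformerKV,
                                            FeedbackTransformerLayer)

d_model, heads, n_layers = 64, 4, 2
layer = FeedbackTransformerLayer(d_model=d_model,
                                 attn=FeedbackAttention(heads, d_model, is_kv_precomputed=True),
                                 feed_forward=FeedForward(d_model, d_ff=256),
                                 dropout_prob=0.1)
model = FeedbackTransformerKV(layer, n_layers, d_model, heads)

x = torch.randn(10, 2, d_model)  # [seq_len, batch_size, d_model]
out = model(x)
assert out.shape == x.shape
model.free()  # release the cached key/value stacks
```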
