Gated Linear Units and Variants


This trains a simple transformer model for auto-regression. We try different variants for the position-wise feedforward network.

This is a simpler implementation that doesn't use the labml.configs module. We wrote it this way to make it easier to follow for readers who are not familiar with that module.

import dataclasses

import torch
from labml import experiment, lab, tracker, monit, logger
from labml.logger import Text
from labml.utils.download import download_file
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask
from torch import nn
from torch.utils.data import Dataset, DataLoader

Auto regressive model

class AutoregressiveModel(nn.Module):

    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
        super().__init__()

Token embedding module

        self.src_embed = src_embed

Transformer based encoder

        self.encoder = encoder

Next token generation layer; this gives the logits of the next token

        self.generator = generator

This will be initialized on the first call

        self.src_mask = None

    def forward(self, src: torch.Tensor):

Create a subsequent mask so that the transformer can only pay attention to past tokens.

        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = subsequent_mask(len(src)).to(src.device)
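
A subsequent (causal) mask is lower-triangular: position i may only attend to positions at or before i. The sketch below illustrates the idea in plain PyTorch; it is an assumption here that a boolean [seq_len, seq_len] matrix captures the essentials, and the exact shape returned by labml_nn's subsequent_mask may differ.

import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # Lower-triangular boolean matrix: row i is True for columns <= i
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

print(causal_mask(4))
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])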

Embed the tokens (src) and run them through the transformer

        res = self.encoder(self.src_embed(src), self.src_mask)

Generate logits of the next token

        return self.generator(res)

Configurations

@dataclasses.dataclass
class Configs:
    d_model: int = 512
    seq_len: int = 128
    batch_size: int = 32
    n_layers: int = 6
    n_heads: int = 8
    dropout: float = 0.1
    d_ff: int = 2048
    glu_variant: str = 'GLU'
    epochs: int = 5
    grad_norm_clip: float = 0.5
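
Since Configs is a plain dataclass, individual settings can be overridden at construction time; for example (hypothetical values):

configs = Configs(glu_variant='SwiGLU', epochs=10, batch_size=16)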

Tiny Shakespeare Dataset

class TinyShakespeareDataset(Dataset):

    def __init__(self, seq_len: int):

Location of the text file

        path = lab.get_data_path() / 'tiny_shakespeare.txt'

Download the file

        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
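
download_file is a labml helper; presumably it skips the download when the file is already on disk. A rough standard-library stand-in (a sketch, not the actual implementation) might look like:

import urllib.request
from pathlib import Path

def download_file_sketch(url: str, path: Path):
    # Create parent directories and skip the download if the file already exists
    path.parent.mkdir(parents=True, exist_ok=True)
    if not path.exists():
        urllib.request.urlretrieve(url, str(path))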

Read the downloaded file

        with open(str(path), 'r') as f:
            text = f.read()

Extract the characters

        chars = list(set(text))

Character to id (integer) map

        self.stoi = {c: i for i, c in enumerate(chars)}

Id to character map

        self.itos = {i: c for i, c in enumerate(chars)}
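
The two maps are inverses of each other, so encoding a string and decoding it back yields the original text. A toy illustration with a hypothetical three-character vocabulary:

chars = ['a', 'b', 'c']
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for i, c in enumerate(chars)}

# Encode to ids and decode back
ids = [stoi[c] for c in 'abca']
assert ''.join(itos[i] for i in ids) == 'abca'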

Length of a training sample

        self.seq_len = seq_len

Data in the form of a tensor of ids

        self.data = self.text_to_i(text)

Transform the text into a tensor of ids

    def text_to_i(self, text: str):
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

Number of samples in the dataset.

This will read the dataset seq_len times in a single epoch.

    def __len__(self):
        return len(self.data) - self.seq_len - 1

Return a sample

    def __getitem__(self, idx):
        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
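
Each sample is a window of the text together with the same window shifted one character to the right, so the model learns to predict the next character at every position. A quick (hypothetical) check:

dataset = TinyShakespeareDataset(seq_len=8)
x, y = dataset[0]
assert x.shape == y.shape == (8,)
assert torch.equal(x[1:], y[:-1])  # the target is the input shifted left by one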

Trainer

class Trainer:

    def __init__(self, configs: Configs):

Get the device

        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')

Initialize the dataset

        self.dataset = TinyShakespeareDataset(configs.seq_len)

Initialize the dataloader

        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)
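
transpose_batch comes from labml_nn and is assumed here to collate the samples so that the sequence dimension comes first, giving batches of shape [seq_len, batch_size]. A sketch of such a collate function:

from typing import List, Tuple

import torch

def transpose_batch_sketch(batch: List[Tuple[torch.Tensor, torch.Tensor]]):
    # Stack along dim=1 so the result is [seq_len, batch_size]
    data = torch.stack([x for x, _ in batch], dim=1)
    target = torch.stack([y for _, y in batch], dim=1)
    return data, target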

FFN with Gated Linear Unit $\text{FFN}_{GLU}(x, W_1, V, W_2) = (\sigma(x W_1) \otimes x V) W_2$

        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)

FFN with Bilinear hidden layer $\text{FFN}_{Bilinear}(x, W_1, V, W_2) = (x W_1 \otimes x V) W_2$

        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)

FFN with ReLU gate $\text{FFN}_{ReGLU}(x, W_1, V, W_2) = (\max(0, x W_1) \otimes x V) W_2$

        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)

FFN with GELU gate $\text{FFN}_{GEGLU}(x, W_1, V, W_2) = (\text{GELU}(x W_1) \otimes x V) W_2$

        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)

FFN with Swish gate $\text{FFN}_{SwiGLU}(x, W_1, V, W_2) = (\text{Swish}_1(x W_1) \otimes x V) W_2$ where $\text{Swish}_\beta(x) = x \sigma(\beta x)$

        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)

FFN with ReLU activation $\text{FFN}_{ReLU}(x, W_1, W_2, b_1, b_2) = \text{ReLU}(x W_1 + b_1) W_2 + b_2$

        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())

FFN with GELU activation $\text{FFN}_{GELU}(x, W_1, W_2, b_1, b_2) = \text{GELU}(x W_1 + b_1) W_2 + b_2$

        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')
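
All the gated variants share the same structure: two parallel projections of the input, one passed through an activation and used to gate the other, followed by a final projection. The sketch below is a minimal stand-alone version of that pattern; it is not the labml_nn FeedForward class, and the class name and dropout placement are assumptions.

import torch
from torch import nn

class GatedFFNSketch(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float, activation: nn.Module):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection, passed through the activation
        self.v = nn.Linear(d_model, d_ff, bias=False)   # linear projection that gets gated
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # final projection back to d_model
        self.activation = activation
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # e.g. SwiGLU when activation is nn.SiLU(): (Swish(x W1) ⊗ x V) W2
        return self.w2(self.dropout(self.activation(self.w1(x)) * self.v(x)))

swiglu = GatedFFNSketch(d_model=512, d_ff=2048, dropout=0.1, activation=nn.SiLU())
out = swiglu(torch.randn(128, 32, 512))  # shape is preserved: [seq_len, batch, d_model]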

Number of different characters

        n_chars = len(self.dataset.stoi)

Initialize the Multi-Head Attention module

        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)

Initialize the Transformer Block

        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
                                             feed_forward=ffn, dropout_prob=configs.dropout)

Initialize the model with an embedding layer (with fixed positional encoding), a transformer encoder, and a linear layer to generate logits.

        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
                                         Encoder(transformer_layer, configs.n_layers),
                                         nn.Linear(configs.d_model, n_chars))

Move the model to the current device

        self.model.to(self.device)

Initialize the Noam optimizer

        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
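
The Noam optimizer wraps Adam with the learning-rate schedule from the original Transformer paper: the rate grows linearly for warmup steps and then decays with the inverse square root of the step number, scaled by d_model^{-1/2}. A small sketch of that schedule (the labml_nn Noam class may differ in details such as how the lr multiplier is applied):

def noam_lr(step: int, d_model: int = 512, warmup: int = 2_000, lr: float = 1.0) -> float:
    # Linear warmup followed by inverse square-root decay; the peak occurs at step == warmup
    return lr * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

print(noam_lr(1), noam_lr(2_000), noam_lr(20_000))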

Cross-entropy loss

        self.loss_func = nn.CrossEntropyLoss()

Number of training epochs; note that our dataset definition repeats the data seq_len times in a single epoch

        self.epochs = configs.epochs

Gradient clipping norm

        self.grad_norm_clip = configs.grad_norm_clip

Set tracker configurations

        tracker.set_scalar("loss.*", True)

Sampling function to generate samples periodically while training

    def sample(self):

Starting prompt

        prompt = 'It is'

Collect output for printing

        log = [(prompt, Text.subtle)]

Sample 25 tokens

        for i in monit.iterate('Sample', 25):

Tokenize the prompt

            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)

Get the model output

            output = self.model(data)

Get the model prediction (greedy)

            output = output.argmax(dim=-1).squeeze()

Add the prediction to the prompt

            prompt += self.dataset.itos[output[-1].item()]

Add the prediction for logging

            log += [(self.dataset.itos[output[-1].item()], Text.value)]

Print the sampled output

        logger.log(log)

Train the model

    def train(self):

Loop for the given number of epochs

        for _ in monit.loop(self.epochs):

Iterate over the minibatches

            for i, batch in monit.enum('Train', self.dataloader):

Move data to the device

                data, target = batch[0].to(self.device), batch[1].to(self.device)

Set tracker step, as the number of characters trained on

                tracker.add_global_step(data.shape[0] * data.shape[1])

Set model state to training

                self.model.train()

Run the model

                output = self.model(data)

Calculate loss

                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
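
CrossEntropyLoss expects [N, C] logits and [N] class indices, so both tensors are flattened over the sequence and batch dimensions. A small shape check (65 is a hypothetical vocabulary size, roughly the number of distinct characters in Tiny Shakespeare):

import torch
from torch import nn

logits = torch.randn(128, 32, 65)          # [seq_len, batch_size, n_chars]
targets = torch.randint(0, 65, (128, 32))  # [seq_len, batch_size]
loss = nn.CrossEntropyLoss()(logits.view(-1, 65), targets.view(-1))
print(loss.shape)  # a scalar (0-dim tensor)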

Log the loss

                tracker.add("loss.train", loss)

Calculate gradients

                loss.backward()

Clip gradients

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

                self.optimizer.step()

Log the model parameters and gradients

                if (i + 1) % 100 == 0:
                    tracker.add('model', self.model)

Clear the gradients

                self.optimizer.zero_grad()

Generate a sample

                if (i + 1) % 100 == 0:
                    self.model.eval()
                    with torch.no_grad():
                        self.sample()

Save the tracked metrics

                if (i + 1) % 10 == 0:
                    tracker.save()

def main():

Create experiment

    experiment.create(name="glu_variants")

Create configs

    configs = Configs()

Load configurations

    experiment.configs(dataclasses.asdict(configs))

Create trainer

    trainer = Trainer(configs)

Set models for training and loading

    experiment.add_pytorch_models({'model': trainer.model})

Start the experiment

    with experiment.start():

Train the model

        trainer.train()


if __name__ == '__main__':
    main()
