docs/transformers/glu_variants/simple.html
This trains a simple transformer model for auto-regression. We try different variants for the position-wise feedforward network.
This is a simpler implementation that doesn't use the labml.configs module, so it should be easier to follow for readers who are not familiar with that module.
import dataclasses

import torch
from labml import experiment, lab, tracker, monit, logger
from labml.logger import Text
from labml.utils.download import download_file
from labml_nn.experiments.nlp_autoregression import transpose_batch
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers import Encoder, MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.models import EmbeddingsWithPositionalEncoding, TransformerLayer
from labml_nn.transformers.utils import subsequent_mask
from torch import nn
from torch.utils.data import Dataset, DataLoader
class AutoregressiveModel(nn.Module):

    def __init__(self, src_embed: nn.Module, encoder: Encoder, generator: nn.Module):
        super().__init__()
Token embedding module
        self.src_embed = src_embed
Transformer based encoder
        self.encoder = encoder
Next token generation layer; this gives the logits of the next token
        self.generator = generator
This will be initialized on the first call
        self.src_mask = None

    def forward(self, src: torch.Tensor):
Create subsequent mask, so that the transformer can only pay attention to past tokens.
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = subsequent_mask(len(src)).to(src.device)
Embed the tokens (src) and run them through the transformer encoder
        res = self.encoder(self.src_embed(src), self.src_mask)
Generate logits of the next token
        return self.generator(res)
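The subsequent mask used in forward is simply lower-triangular: position $i$ can attend only to positions $\le i$. Below is a standalone illustration of the idea (not the exact tensor returned by subsequent_mask, whose shape and dtype may differ):

seq_len = 4
causal = torch.tril(torch.ones(seq_len, seq_len)).bool()
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])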
@dataclasses.dataclass
class Configs:
    d_model: int = 512
    seq_len: int = 128
    batch_size: int = 32
    n_layers: int = 6
    n_heads: int = 8
    dropout: float = 0.1
    d_ff: int = 2048
    glu_variant: str = 'GLU'
    epochs: int = 5
    grad_norm_clip: float = 0.5
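Since Configs is a plain dataclass, a different experiment is set up just by constructing it with different field values, for instance (hypothetical values):

configs = Configs(glu_variant='SwiGLU', n_layers=4, epochs=10)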
class TinyShakespeareDataset(Dataset):
    def __init__(self, seq_len: int):
Location of the text file
        path = lab.get_data_path() / 'tiny_shakespeare.txt'
Download the file
        download_file('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', path)
Read the downloaded file
        with open(str(path), 'r') as f:
            text = f.read()
Extract the characters
        chars = list(set(text))
Character to id (integer) map
        self.stoi = {c: i for i, c in enumerate(chars)}
Id to character map
        self.itos = {i: c for i, c in enumerate(chars)}
Length of a training sample
        self.seq_len = seq_len
Data in the form of a tensor of ids
        self.data = self.text_to_i(text)
Transform the text into a tensor of ids
    def text_to_i(self, text: str):
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)
Number of samples in the dataset.
This will read the dataset seq_len times in a single epoch.
    def __len__(self):
        return len(self.data) - self.seq_len - 1
Return a sample
    def __getitem__(self, idx):
        return self.data[idx:idx + self.seq_len], self.data[idx + 1:idx + self.seq_len + 1]
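As a quick sanity check, a sample is a seq_len character window and its target is the same window shifted one character to the right. A small usage sketch (this downloads the text the first time it runs; the first sample should be the opening characters of the corpus):

dataset = TinyShakespeareDataset(seq_len=8)
x, y = dataset[0]  # two LongTensors of shape [8]
print(''.join(dataset.itos[i.item()] for i in x))  # e.g. 'First Ci'
print(''.join(dataset.itos[i.item()] for i in y))  # e.g. 'irst Cit'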
class Trainer:

    def __init__(self, configs: Configs):
Get the device
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
Initialize the dataset
        self.dataset = TinyShakespeareDataset(configs.seq_len)
Initialize the dataloader; transpose_batch collates the samples so that the sequence dimension comes first (shape [seq_len, batch_size])
        self.dataloader = DataLoader(self.dataset,
                                     batch_size=configs.batch_size,
                                     collate_fn=transpose_batch,
                                     shuffle=True)
FFN with Gated Linear Unit $\text{FFN}_{\text{GLU}}(x, W_1, V, W_2) = (\sigma(xW_1) \otimes xV)W_2$
        if configs.glu_variant == 'GLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Sigmoid(), True, False, False, False)
FFN with Bilinear hidden layer $\text{FFN}_{\text{Bilinear}}(x, W_1, V, W_2) = (xW_1 \otimes xV)W_2$
        elif configs.glu_variant == 'Bilinear':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.Identity(), True, False, False, False)
FFN with ReLU gate $\text{FFN}_{\text{ReGLU}}(x, W_1, V, W_2) = (\max(0, xW_1) \otimes xV)W_2$
        elif configs.glu_variant == 'ReGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU(), True, False, False, False)
FFN with GELU gate $\text{FFN}_{\text{GEGLU}}(x, W_1, V, W_2) = (\text{GELU}(xW_1) \otimes xV)W_2$
        elif configs.glu_variant == 'GEGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU(), True, False, False, False)
FFN with Swish gate $\text{FFN}_{\text{SwiGLU}}(x, W_1, V, W_2) = (\text{Swish}_1(xW_1) \otimes xV)W_2$ where $\text{Swish}_\beta(x) = x\sigma(\beta x)$
        elif configs.glu_variant == 'SwiGLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.SiLU(), True, False, False, False)
FFN with ReLU activation $\text{FFN}_{\text{ReLU}}(x, W_1, W_2, b_1, b_2) = \max(0, xW_1 + b_1)W_2 + b_2$
        elif configs.glu_variant == 'ReLU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.ReLU())
FFN with GELU activation $\text{FFN}_{\text{GELU}}(x, W_1, W_2, b_1, b_2) = \text{GELU}(xW_1 + b_1)W_2 + b_2$
        elif configs.glu_variant == 'GELU':
            ffn = FeedForward(configs.d_model, configs.d_ff, configs.dropout, nn.GELU())
        else:
            raise ValueError(f'Unknown variant {configs.glu_variant}')
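The four boolean flags passed to FeedForward for the gated variants switch on gating and switch off the biases of the linear layers (that reading of the positional arguments is an assumption about the FeedForward signature). Conceptually, every gated variant computes $(\text{activation}(xW_1) \otimes xV)W_2$; the following is a minimal standalone sketch of that computation, not the labml_nn FeedForward class itself:

class GatedFFN(nn.Module):
    """Sketch of a gated position-wise feed-forward layer: (activation(x W1) * x V) W2."""

    def __init__(self, d_model: int, d_ff: int, dropout: float, activation: nn.Module):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate input
        self.v = nn.Linear(d_model, d_ff, bias=False)   # value that gets gated
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # project back to d_model
        self.activation = activation
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        gate = self.activation(self.w1(x))
        return self.w2(self.dropout(gate * self.v(x)))

For example, GatedFFN(512, 2048, 0.1, nn.GELU()) corresponds to the GEGLU variant above.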
Number of different characters
        n_chars = len(self.dataset.stoi)
Initialize Multi-Head Attention module
        mha = MultiHeadAttention(configs.n_heads, configs.d_model, configs.dropout)
Initialize the Transformer Block
        transformer_layer = TransformerLayer(d_model=configs.d_model, self_attn=mha, src_attn=None,
                                             feed_forward=ffn, dropout_prob=configs.dropout)
Initialize the model with an embedding layer (with fixed positional encoding), a transformer encoder, and a linear layer to generate logits.
        self.model = AutoregressiveModel(EmbeddingsWithPositionalEncoding(configs.d_model, n_chars),
                                         Encoder(transformer_layer, configs.n_layers),
                                         nn.Linear(configs.d_model, n_chars))
Move the model to the current device
        self.model.to(self.device)
Initialize Noam optimizer
        self.optimizer = Noam(self.model.parameters(), lr=1.0, warmup=2_000, d_model=configs.d_model)
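The Noam schedule increases the learning rate linearly for the first warmup steps and then decays it with the inverse square root of the step number; with the factor lr=1.0 used here, it is (this is the schedule from Attention Is All You Need, which the Noam optimizer follows):

$$\eta_{\text{step}} = d_{\text{model}}^{-0.5} \cdot \min\left(\text{step}^{-0.5},\ \text{step} \cdot \text{warmup}^{-1.5}\right)$$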
Cross-entropy loss
        self.loss_func = nn.CrossEntropyLoss()
Number of training epochs; note that our dataset definition repeats the data seq_len times in a single epoch
        self.epochs = configs.epochs
Gradient clipping norm
        self.grad_norm_clip = configs.grad_norm_clip
Set tracker configurations
        tracker.set_scalar("loss.*", True)
    def sample(self):
Starting prompt
        prompt = 'It is'
Collect output for printing
        log = [(prompt, Text.subtle)]
Sample 25 tokens
        for i in monit.iterate('Sample', 25):
Tokenize the prompt
            data = self.dataset.text_to_i(prompt).unsqueeze(-1)
            data = data.to(self.device)
Get the model output
            output = self.model(data)
Get the model prediction (greedy)
            output = output.argmax(dim=-1).squeeze()
Add the prediction to prompt
            prompt += self.dataset.itos[output[-1].item()]
Add the prediction for logging
            log += [(self.dataset.itos[output[-1].item()], Text.value)]
Print the sampled output
        logger.log(log)
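To make the indexing in the sampling loop concrete: for a prompt of n characters, with the batch dimension of size one added by unsqueeze(-1), the shapes are roughly as follows (an illustrative walkthrough, not extra code in the file):

# data:   [n, 1]            prompt token ids
# output: [n, 1, n_chars]   logits for the next token at every position
# output.argmax(dim=-1).squeeze()  -> [n] predicted token ids
# output[-1]                prediction for the character that follows the prompt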
    def train(self):
Loop for the given number of epochs
        for _ in monit.loop(self.epochs):
Iterate over the minibatches
            for i, batch in monit.enum('Train', self.dataloader):
Move data to the device
                data, target = batch[0].to(self.device), batch[1].to(self.device)
Set tracker step, as the number of characters trained on
                tracker.add_global_step(data.shape[0] * data.shape[1])
Set model state to training
                self.model.train()
Evaluate the model
                output = self.model(data)
Calculate loss
                loss = self.loss_func(output.view(-1, output.shape[-1]), target.view(-1))
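Cross-entropy expects a 2D tensor of logits and a 1D tensor of target ids, so both are flattened. With the default configuration, and assuming Tiny Shakespeare has roughly 65 distinct characters, the shapes are approximately:

# output: [seq_len, batch_size, n_chars] ≈ [128, 32, 65]  -> view(-1, n_chars): [4096, 65]
# target: [seq_len, batch_size]          ≈ [128, 32]      -> view(-1):          [4096]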
Log the loss
                tracker.add("loss.train", loss)
Calculate gradients
                loss.backward()
Clip gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
                self.optimizer.step()
Log the model parameters and gradients
                if (i + 1) % 100 == 0:
                    tracker.add('model', self.model)
Clear the gradients
                self.optimizer.zero_grad()
Generate a sample
                if (i + 1) % 100 == 0:
                    self.model.eval()
                    with torch.no_grad():
                        self.sample()
Save the tracked metrics
                if (i + 1) % 10 == 0:
                    tracker.save()
def main():
Create experiment
    experiment.create(name="glu_variants")
Create configs
    configs = Configs()
Load configurations
    experiment.configs(dataclasses.asdict(configs))
Create trainer
    trainer = Trainer(configs)
Set models for training and loading
    experiment.add_pytorch_models({'model': trainer.model})
Start the experiment
    with experiment.start():
Train the model
        trainer.train()


if __name__ == '__main__':
    main()