This is an annotated PyTorch experiment to train a Masked Language Model (MLM): tokens in the input are randomly masked, and the transformer is trained to predict the original tokens at those positions.
from typing import List

import torch
from torch import nn

from labml import experiment, tracker, logger
from labml.configs import option
from labml.logger import Text
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import Encoder, Generator
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.mlm import MLM
class TransformerMLM(nn.Module):
encoder is the transformer Encoder, src_embed is the token embedding module (with positional encodings), and generator is the final fully connected layer that gives the logits.

    def __init__(self, *, encoder: Encoder, src_embed: nn.Module, generator: Generator):
        super().__init__()
        self.generator = generator
        self.src_embed = src_embed
        self.encoder = encoder
    def forward(self, x: torch.Tensor):
Get the token embeddings with positional encodings
        x = self.src_embed(x)
Transformer encoder (the mask is None, since for MLM every position can attend to every other position)
        x = self.encoder(x, None)
Logits for the output
        y = self.generator(x)
Return results (second value is for state, since our trainer is used with RNNs also)
        return y, None
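For reference, the expected tensor shapes (an assumption based on the sequence-first layout used throughout this experiment, not stated in the source):

# x: [seq_len, batch_size]            token ids
# y: [seq_len, batch_size, n_tokens]  per-position vocabulary logits
# y, _ = model(x)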
This inherits from NLPAutoRegressionConfigs because it has the data pipeline implementations that we reuse here. We have implemented a custom training step for MLM.
class Configs(NLPAutoRegressionConfigs):
MLM model
    model: TransformerMLM
Transformer
    transformer: TransformerConfigs
Number of tokens
    n_tokens: int = 'n_tokens_mlm'
Tokens that shouldn't be masked
    no_mask_tokens: List[int] = []
Probability of masking a token
    masking_prob: float = 0.15
Probability of replacing the mask with a random token
    randomize_prob: float = 0.1
Probability of replacing the mask with original token
    no_change_prob: float = 0.1
Masked Language Model (MLM) class to generate the mask
    mlm: MLM
[MASK] token
    mask_token: int
[PADDING] token
    padding_token: int
Prompt to sample
    prompt: List[str] = [
        "We are accounted poor citizens, the patricians good.",
        "What authority surfeits on would relieve us: if they",
        "would yield us but the superfluity, while it were",
        "wholesome, we might guess they relieved us humanely;",
        "but they think we are too dear: the leanness that",
        "afflicts us, the object of our misery, is as an",
        "inventory to particularise their abundance; our",
        "sufferance is a gain to them Let us revenge this with",
        "our pikes, ere we become rakes: for the gods know I",
        "speak this in hunger for bread, not in thirst for revenge.",
    ]
    def init(self):
[MASK] token
        self.mask_token = self.n_tokens - 1
[PAD] token
        self.padding_token = self.n_tokens - 2
Masked Language Model (MLM) class to generate the mask
        self.mlm = MLM(padding_token=self.padding_token,
                       mask_token=self.mask_token,
                       no_mask_tokens=self.no_mask_tokens,
                       n_tokens=self.n_tokens,
                       masking_prob=self.masking_prob,
                       randomize_prob=self.randomize_prob,
                       no_change_prob=self.no_change_prob)
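To make the masking concrete, here is a minimal sketch of what such a mask generator does, assuming the configuration above. This is an illustration only, not the actual MLM implementation (which, for example, also skips no_mask_tokens):

import torch

def mlm_mask_sketch(x: torch.Tensor, n_tokens: int, mask_token: int, padding_token: int,
                    masking_prob: float = 0.15, randomize_prob: float = 0.1,
                    no_change_prob: float = 0.1):
    # Pick roughly `masking_prob` of the positions as prediction targets
    picked = torch.rand(x.shape) < masking_prob
    # Some picked positions keep the original token ...
    unchanged = picked & (torch.rand(x.shape) < no_change_prob)
    # ... some get a random token ...
    randomized = picked & ~unchanged & (torch.rand(x.shape) < randomize_prob)
    # ... and the rest are replaced with [MASK]
    masked = picked & ~unchanged & ~randomized
    # Labels hold the original token at picked positions and [PAD] elsewhere,
    # so the loss and accuracy (which ignore [PAD]) only count the targets
    labels = torch.where(picked, x, torch.full_like(x, padding_token))
    y = x.clone()
    y[masked] = mask_token
    y[randomized] = torch.randint(0, n_tokens, (int(randomized.sum()),))
    return y, labels

With the defaults above this gives roughly the BERT-style split: about 80% of the picked tokens become [MASK], about 10% are randomized, and about 10% stay unchanged.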
Accuracy metric (ignore the labels equal to [PAD])
        self.accuracy = Accuracy(ignore_index=self.padding_token)
Cross entropy loss (ignore the labels equal to [PAD])
        self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)
        super().init()
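Both the accuracy metric and the loss ignore [PAD] labels. As a quick standalone illustration of how ignore_index behaves (a toy example, not part of the experiment):

import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss(ignore_index=9)  # pretend token id 9 is [PAD]
logits = torch.randn(4, 10)                    # 4 positions, 10-token vocabulary
labels = torch.tensor([3, 9, 9, 7])            # the two [PAD] labels are skipped
print(loss_fn(logits, labels))                 # mean loss over the 2 non-[PAD] positions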
    def step(self, batch: any, batch_idx: BatchIndex):
Move the input to the device
        data = batch[0].to(self.device)
Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Get the masked input and labels
        with torch.no_grad():
            data, labels = self.mlm(data)
Get model outputs. The model returns a tuple, with the second element reserved for state when used with RNNs; this is not implemented yet.
        output, *_ = self.model(data)
Calculate and log the loss
        loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
        tracker.add("loss.", loss)
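For reference, the reshape flattens the sequence and batch dimensions (the shapes below assume the sequence-first layout this experiment uses):

# output: [seq_len, batch_size, n_tokens] -> view(-1, n_tokens): [seq_len * batch_size, n_tokens]
# labels: [seq_len, batch_size]           -> view(-1):           [seq_len * batch_size]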
Calculate and log accuracy
        self.accuracy(output, labels)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
    @torch.no_grad()
    def sample(self):
Empty tensor for the data, filled with [PAD].
        data = torch.full((self.seq_len, len(self.prompt)), self.padding_token, dtype=torch.long)
Add the prompts one by one
        for i, p in enumerate(self.prompt):
Get token indexes
            d = self.text.text_to_i(p)
Add to the tensor
            s = min(self.seq_len, len(d))
            data[:s, i] = d[:s]
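For example, with seq_len = 32 and the ten prompt strings above:

# data is a [32, 10] tensor; prompts longer than 32 characters are truncated,
# shorter ones keep their trailing [PAD] tokens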
Move the tensor to current device
        data = data.to(self.device)
Get masked input and labels
        data, labels = self.mlm(data)
Get model outputs
        output, *_ = self.model(data)
Print the samples generated
        for j in range(data.shape[1]):
Collect output for printing
            log = []
For each token
            for i in range(len(data)):
If the label is not [PAD]
                if labels[i, j] != self.padding_token:
Get the prediction
                    t = output[i, j].argmax().item()
If it's a printable character
                    if t < len(self.text.itos):
Correct prediction
                        if t == labels[i, j]:
                            log.append((self.text.itos[t], Text.value))
Incorrect prediction
                        else:
                            log.append((self.text.itos[t], Text.danger))
If it's not a printable character
                    else:
                        log.append(('*', Text.danger))
If the label is [PAD] (unmasked), print the original token.
                elif data[i, j] < len(self.text.itos):
                    log.append((self.text.itos[data[i, j]], Text.subtle))
            logger.log(log)
Number of tokens including [PAD] and [MASK]
@option(Configs.n_tokens)
def n_tokens_mlm(c: Configs):
    return c.text.n_tokens + 2
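A worked example, assuming a character-level vocabulary of 65 characters (an illustrative figure, not taken from the dataset):

# c.text.n_tokens == 65  ->  n_tokens == 67
# padding_token   == 65  (n_tokens - 2)
# mask_token      == 66  (n_tokens - 1)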
@option(Configs.transformer)
def _transformer_configs(c: Configs):
We use our configurable transformer implementation
    conf = TransformerConfigs()
Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
Embedding size
    conf.d_model = c.d_model
    return conf
Create the MLM model
@option(Configs.model)
def _model(c: Configs):
    m = TransformerMLM(encoder=c.transformer.encoder,
                       src_embed=c.transformer.src_embed,
                       generator=c.transformer.generator).to(c.device)

    return m
def main():
Create experiment
    experiment.create(name="mlm")
Create configs
    conf = Configs()
Override configurations
    experiment.configs(conf, {
Batch size
        'batch_size': 64,
Sequence length of 32. We use a short sequence length so training doesn't take forever.
        'seq_len': 32,
Train for 1024 epochs.
        'epochs': 1024,
Switch between training and validation once per epoch
        'inner_iterations': 1,
Transformer configurations (same as defaults)
        'd_model': 128,
        'transformer.ffn.d_ff': 256,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,
Use Noam optimizer
        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
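The Noam schedule is from "Attention Is All You Need": the learning rate warms up linearly and then decays with the inverse square root of the step. A sketch, assuming labml's Noam optimizer follows the same rate formula (the warmup of 4000 steps is an assumed value for illustration):

def noam_rate(step: int, d_model: int = 128, warmup: int = 4000, factor: float = 1.0) -> float:
    # Linear warmup for `warmup` steps, then decay proportional to step^-0.5
    # (`step` starts at 1, so there is no division by zero)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)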
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run training
        conf.run()
if __name__ == '__main__':
    main()