Masked Language Model (MLM) Experiment

This is an annotated PyTorch experiment to train a Masked Language Model.

from typing import List

import torch
from torch import nn

from labml import experiment, tracker, logger
from labml.configs import option
from labml.logger import Text
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.transformers import Encoder, Generator
from labml_nn.transformers import TransformerConfigs
from labml_nn.transformers.mlm import MLM

Transformer-based model for MLM

class TransformerMLM(nn.Module):

    def __init__(self, *, encoder: Encoder, src_embed: nn.Module, generator: Generator):

        super().__init__()
        self.generator = generator
        self.src_embed = src_embed
        self.encoder = encoder

    def forward(self, x: torch.Tensor):

Get the token embeddings with positional encodings

        x = self.src_embed(x)

Transformer encoder. We pass None for the attention mask, since in MLM every token can attend to every other token (unlike autoregressive models).

        x = self.encoder(x, None)

Logits for the output

        y = self.generator(x)

Return results (the second value is for state, since our trainer is also used with RNNs)

        return y, None

Configurations

This inherits from NLPAutoRegressionConfigs because it has the data pipeline implementations that we reuse here. We have implemented a custom training step for MLM.

class Configs(NLPAutoRegressionConfigs):

MLM model

    model: TransformerMLM

Transformer

    transformer: TransformerConfigs

Number of tokens

    n_tokens: int = 'n_tokens_mlm'

Tokens that shouldn't be masked

    no_mask_tokens: List[int] = []

Probability of masking a token

    masking_prob: float = 0.15

Probability of replacing the mask with a random token

    randomize_prob: float = 0.1

Probability of replacing the mask with the original token

    no_change_prob: float = 0.1

Masked Language Model (MLM) class to generate the mask

    mlm: MLM

[MASK] token

    mask_token: int

[PADDING] token

    padding_token: int

Prompt to sample

    prompt: List[str] = [
        "We are accounted poor citizens, the patricians good.",
        "What authority surfeits on would relieve us: if they",
        "would yield us but the superfluity, while it were",
        "wholesome, we might guess they relieved us humanely;",
        "but they think we are too dear: the leanness that",
        "afflicts us, the object of our misery, is as an",
        "inventory to particularise their abundance; our",
        "sufferance is a gain to them Let us revenge this with",
        "our pikes, ere we become rakes: for the gods know I",
        "speak this in hunger for bread, not in thirst for revenge.",
    ]

Initialization

    def init(self):

[MASK] token

        self.mask_token = self.n_tokens - 1

[PAD] token

        self.padding_token = self.n_tokens - 2

Masked Language Model (MLM) class to generate the mask

        self.mlm = MLM(padding_token=self.padding_token,
                       mask_token=self.mask_token,
                       no_mask_tokens=self.no_mask_tokens,
                       n_tokens=self.n_tokens,
                       masking_prob=self.masking_prob,
                       randomize_prob=self.randomize_prob,
                       no_change_prob=self.no_change_prob)

Accuracy metric (ignore the labels equal to [PAD])

        self.accuracy = Accuracy(ignore_index=self.padding_token)

Cross entropy loss (ignore the labels equal to [PAD])

        self.loss_func = nn.CrossEntropyLoss(ignore_index=self.padding_token)

        super().init()

Training or validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Move the input to the device

        data = batch[0].to(self.device)

Update global step (number of tokens processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Get the masked input and labels

        with torch.no_grad():
            data, labels = self.mlm(data)

Get model outputs. The model returns a tuple whose second element is for state when using RNNs; this is not implemented yet.

        output, *_ = self.model(data)

Calculate and log the loss

        loss = self.loss_func(output.view(-1, output.shape[-1]), labels.view(-1))
        tracker.add("loss.", loss)

Calculate and log accuracy

        self.accuracy(output, labels)
        self.accuracy.track()

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch

            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()

Sampling function to generate samples periodically while training

    @torch.no_grad()
    def sample(self):

Empty tensor for data, filled with [PAD]

        data = torch.full((self.seq_len, len(self.prompt)), self.padding_token, dtype=torch.long)

Add the prompts one by one

        for i, p in enumerate(self.prompt):

Get token indexes

            d = self.text.text_to_i(p)

Add to the tensor

            s = min(self.seq_len, len(d))
            data[:s, i] = d[:s]

Move the tensor to the current device

        data = data.to(self.device)

Get masked input and labels

        data, labels = self.mlm(data)

Get model outputs

        output, *_ = self.model(data)

Print the samples generated

        for j in range(data.shape[1]):

Collect output for printing

            log = []

For each token

            for i in range(len(data)):

If the label is not [PAD]

                if labels[i, j] != self.padding_token:

Get the prediction

                    t = output[i, j].argmax().item()

If it's a printable character

                    if t < len(self.text.itos):

Correct prediction

                        if t == labels[i, j]:
                            log.append((self.text.itos[t], Text.value))

Incorrect prediction

                        else:
                            log.append((self.text.itos[t], Text.danger))

If it's not a printable character

                    else:
                        log.append(('*', Text.danger))

If the label is [PAD] (unmasked), print the original token

                elif data[i, j] < len(self.text.itos):
                    log.append((self.text.itos[data[i, j]], Text.subtle))

Print

            logger.log(log)

Number of tokens including [PAD] and [MASK]

@option(Configs.n_tokens)
def n_tokens_mlm(c: Configs):
    return c.text.n_tokens + 2
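For example, if the dataset's character vocabulary has 65 tokens (ids 0 to 64, an illustrative size), this gives n_tokens = 67, and init above assigns id 65 to [PAD] (n_tokens - 2) and id 66 to [MASK] (n_tokens - 1).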

Transformer configurations

@option(Configs.transformer)
def _transformer_configs(c: Configs):

We use our configurable transformer implementation

    conf = TransformerConfigs()

Set the vocabulary sizes for embeddings and generating logits

    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

Embedding size

    conf.d_model = c.d_model

    return conf

Create the MLM model

@option(Configs.model)
def _model(c: Configs):
    m = TransformerMLM(encoder=c.transformer.encoder,
                       src_embed=c.transformer.src_embed,
                       generator=c.transformer.generator).to(c.device)

    return m

def main():

Create experiment

    experiment.create(name="mlm")

Create configs

    conf = Configs()

Override configurations

    experiment.configs(conf, {

Batch size

        'batch_size': 64,

Sequence length of 32. We use a short sequence length to train faster. Otherwise it takes forever to train.

        'seq_len': 32,

Train for 1024 epochs.

        'epochs': 1024,

Switch between training and validation once per epoch

        'inner_iterations': 1,

Transformer configurations (same as defaults)

        'd_model': 128,
        'transformer.ffn.d_ff': 256,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6,

Use Noam optimizer

        'optimizer.optimizer': 'Noam',
        'optimizer.learning_rate': 1.,
    })
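The Noam schedule, from "Attention Is All You Need" (Vaswani et al., 2017), scales the learning rate by d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5): a linear warmup followed by inverse square-root decay. With learning_rate set to 1.0 this factor is the effective rate. A sketch of the schedule follows; the warmup length is the paper's value, assumed here rather than taken from this experiment's optimizer config:

def noam_lr(step: int, d_model: int = 128, warmup: int = 4000) -> float:
    # Linear warmup for `warmup` steps, then decay proportional to
    # 1/sqrt(step). warmup=4000 is the paper's value, assumed here.
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)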

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()

if __name__ == '__main__':
    main()
