
#

Auto-regressive NLP model trainer

from typing import Callable

import torch
import torch.nn as nn

from labml import lab, monit, logger, tracker
from labml.configs import option
from labml.logger import Text
from labml_nn.helpers.datasets import TextDataset, SequentialDataLoader, SequentialUnBatchedDataset, TextFileDataset
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
from torch.utils.data import DataLoader, RandomSampler

#

Cross entropy loss

class CrossEntropyLoss(nn.Module):

#

def __init__(self):
    super().__init__()
    self.loss = nn.CrossEntropyLoss()

#

def forward(self, outputs, targets):
    return self.loss(outputs.view(-1, outputs.shape[-1]), targets.view(-1))
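To see why the reshaping is needed, here is a minimal sketch with made-up shapes: the model produces logits of shape [seq_len, batch_size, n_tokens], while nn.CrossEntropyLoss expects a flat [N, C] class dimension, so positions and batch are folded into one dimension.

import torch

# Hypothetical shapes for illustration: 512 positions, 16 sequences, 65-token vocabulary
seq_len, batch_size, n_tokens = 512, 16, 65
outputs = torch.randn(seq_len, batch_size, n_tokens)          # logits from the model
targets = torch.randint(0, n_tokens, (seq_len, batch_size))   # true next-token indices

loss_func = CrossEntropyLoss()
# Internally this computes nn.CrossEntropyLoss on
# outputs.view(-1, n_tokens), i.e. [8192, 65], and targets.view(-1), i.e. [8192]
loss = loss_func(outputs, targets)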

#

Trainer configurations

This has the basic configurations for training auto-regressive NLP models. All the properties are configurable.

class NLPAutoRegressionConfigs(TrainValidConfigs):

#

Optimizer

optimizer: torch.optim.Adam

#

Training device

device: torch.device = DeviceConfigs()

#

Autoregressive model

model: nn.Module

#

Text dataset

text: TextDataset

#

Batch size

batch_size: int = 16

#

Length of the sequence, or context size

seq_len: int = 512

#

Number of tokens in the vocabulary

n_tokens: int

#

Tokenizer

tokenizer: Callable = 'character'

#

Text prompt to start sampling (for illustration)

prompt: str

#

The token separator when sampling (blank for character level tokenization)

prompt_separator: str

#

Whether to periodically save models

is_save_models = True

#

Loss function

loss_func = CrossEntropyLoss()

#

Accuracy function

accuracy = Accuracy()

#

Model embedding size

d_model: int = 512

#

Gradient clipping

grad_norm_clip: float = 1.0

#

Training data loader

train_loader: DataLoader = 'shuffled_train_loader'

#

Validation data loader

valid_loader: DataLoader = 'shuffled_valid_loader'

#

Whether the data loaders should shuffle with replacement

dataloader_shuffle_with_replacement: bool = False

#

Whether to log model parameters and gradients (once per epoch). These are summarized stats per layer, but it could still lead to many indicators for very deep networks.

is_log_model_params_grads: bool = False

#

Whether to log model activations (once per epoch). These are summarized stats per layer, but it could still lead to many indicators for very deep networks.

is_log_model_activations: bool = False

#

Initialization

def init(self):

#

Set tracker configurations

tracker.set_scalar("accuracy.*", True)
tracker.set_scalar("loss.*", True)
tracker.set_text("sampled", False)

#

Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation.

self.state_modules = [self.accuracy]

#

Override to calculate and log other metrics

def other_metrics(self, output: torch.Tensor, target: torch.Tensor):

#

pass

#

Training or validation step

def step(self, batch: any, batch_idx: BatchIndex):

#

Set training/eval mode

self.model.train(self.mode.is_train)

#

Move data to the device

data, target = batch[0].to(self.device), batch[1].to(self.device)

#

Update global step (number of tokens processed) when in training mode

if self.mode.is_train:
    tracker.add_global_step(data.shape[0] * data.shape[1])
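For example, with the default batch_size = 16 and seq_len = 512, each training step advances the global step by 16 × 512 = 8,192 tokens.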

#

Get model outputs. It's returning a tuple for states when using RNNs. This is not implemented yet. 😜

output, *_ = self.model(data)
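The output, *_ unpacking means the model is expected to return a tuple whose first element is the logits. A minimal sketch of a compatible model (purely hypothetical, not part of this codebase):

class TinyAutoRegressiveModel(nn.Module):
    def __init__(self, n_tokens: int, d_model: int = 512):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, d_model)
        self.rnn = nn.GRU(d_model, d_model)
        self.readout = nn.Linear(d_model, n_tokens)

    def forward(self, x: torch.Tensor):
        # x: [seq_len, batch_size] token indices (the data loaders are sequence-first)
        h, state = self.rnn(self.embedding(x))
        # Return logits plus the recurrent state so `output, *_` unpacking works
        return self.readout(h), state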

#

Calculate and log loss

loss = self.loss_func(output, target)
tracker.add("loss.", loss)

#

Calculate and log accuracy

self.accuracy(output, target)
self.accuracy.track()

self.other_metrics(output, target)

#

Train the model

if self.mode.is_train:

#

Calculate gradients

loss.backward()

#

Clip gradients

torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

#

Take optimizer step

self.optimizer.step()

#

Log the model parameters and gradients on the last batch of every epoch

if batch_idx.is_last and self.is_log_model_params_grads:
    tracker.add('model', self.model)

#

Clear the gradients

self.optimizer.zero_grad()

#

Save the tracked metrics

tracker.save()

#

Sampling function to generate samples periodically while training

def sample(self):

#

Starting prompt

prompt = self.prompt

#

Collect output for printing

log = [(prompt, Text.subtle)]

#

Sample 25 tokens

for i in monit.iterate('Sample', 25):

#

Tokenize the prompt

data = self.text.text_to_i(prompt).unsqueeze(-1)
data = data.to(self.device)

#

Get the model output

output, *_ = self.model(data)

#

Get the model prediction (greedy)

output = output.argmax(dim=-1).squeeze()

#

Add the prediction to the prompt

prompt += self.prompt_separator + self.text.itos[output[-1]]

#

Add the prediction for logging

log += [(self.prompt_separator + self.text.itos[output[-1]], Text.value)]

tracker.add({'sampled': prompt})

#

Print the sampled output

logger.log(log)
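The loop above decodes greedily with argmax. If you wanted stochastic samples instead, one common alternative (not what this trainer does) is to sample the next token from the softmax over the last position's logits; here output means the raw model output before the argmax above:

probs = torch.softmax(output[-1], dim=-1)               # distribution over the vocabulary
next_token = torch.multinomial(probs, num_samples=1)    # sample instead of taking the argmax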

#

Default optimizer configurations

@option(NLPAutoRegressionConfigs.optimizer)
def _optimizer(c: NLPAutoRegressionConfigs):

#

optimizer = OptimizerConfigs()
optimizer.parameters = c.model.parameters()
optimizer.optimizer = 'Adam'
optimizer.d_model = c.d_model

return optimizer

#

Get number of tokens

@option(NLPAutoRegressionConfigs.n_tokens)
def _n_tokens(c: NLPAutoRegressionConfigs):

#

return c.text.n_tokens

#

Basic English tokenizer

We use a character-level tokenizer in this experiment. You can switch to this tokenizer by setting,

'tokenizer': 'basic_english',

in the configurations dictionary when starting the experiment.
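As a rough sketch of what that looks like with the labml experiment API (the experiment name and the overridden values here are just illustrative; in practice you would use a subclass of NLPAutoRegressionConfigs that also defines the model):

from labml import experiment

conf = NLPAutoRegressionConfigs()   # in practice, a subclass that also sets up `model`
experiment.create(name='nlp_autoregression_demo')
experiment.configs(conf, {
    'tokenizer': 'basic_english',   # switch from the default character tokenizer
    'text': 'tiny_shakespeare',
    'prompt': 'It is ',
    'prompt_separator': ' ',        # basic_english produces word-level tokens
    'seq_len': 256,
    'batch_size': 32,
})
with experiment.start():
    conf.run()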

@option(NLPAutoRegressionConfigs.tokenizer)
def basic_english():

#

from torchtext.data import get_tokenizer
return get_tokenizer('basic_english')

#

Character level tokenizer

def character_tokenizer(x: str):

#

return list(x)

#

Character level tokenizer configuration

@option(NLPAutoRegressionConfigs.tokenizer)
def character():

#

return character_tokenizer

#

Tiny Shakespeare dataset

It will be downloaded from the URL if it is not already present.

@option(NLPAutoRegressionConfigs.text)
def tiny_shakespeare(c: NLPAutoRegressionConfigs):

#

return TextFileDataset(
    lab.get_data_path() / 'tiny_shakespeare.txt',
    c.tokenizer,
    url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

#

Sequential training data loader

@option(NLPAutoRegressionConfigs.train_loader)
def sequential_train_loader(c: NLPAutoRegressionConfigs):

#

return SequentialDataLoader(text=c.text.train,
                            dataset=c.text,
                            batch_size=c.batch_size,
                            seq_len=c.seq_len)

#

Sequential validation data loader

@option(NLPAutoRegressionConfigs.valid_loader)
def sequential_valid_loader(c: NLPAutoRegressionConfigs):

#

return SequentialDataLoader(text=c.text.valid,
                            dataset=c.text,
                            batch_size=c.batch_size,
                            seq_len=c.seq_len)

#

Transpose batch

DataLoader collates batches along the first (batch) dimension. We need to transpose them so that the sequence dimension comes first.

def transpose_batch(batch):

#

transposed_data = list(zip(*batch))

#

Stack the batch along the second dimension (dim=1)

src = torch.stack(transposed_data[0], dim=1)
tgt = torch.stack(transposed_data[1], dim=1)

return src, tgt
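A tiny sketch (with made-up data) of what transpose_batch produces:

# Two (src, tgt) pairs, each a sequence of 5 token ids, collated batch-first by DataLoader
batch = [(torch.arange(5), torch.arange(1, 6)),
         (torch.arange(10, 15), torch.arange(11, 16))]
src, tgt = transpose_batch(batch)
# src.shape == tgt.shape == torch.Size([5, 2]): sequence dimension first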

#

Shuffled training data loader

@option(NLPAutoRegressionConfigs.train_loader)
def shuffled_train_loader(c: NLPAutoRegressionConfigs):

#

dataset = SequentialUnBatchedDataset(text=c.text.train,
                                     dataset=c.text,
                                     seq_len=c.seq_len)
sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

return DataLoader(dataset,
                  batch_size=c.batch_size,
                  collate_fn=transpose_batch,
                  sampler=sampler)

#

Shuffled validation data loader

@option(NLPAutoRegressionConfigs.valid_loader)
def shuffled_valid_loader(c: NLPAutoRegressionConfigs):

#

dataset = SequentialUnBatchedDataset(text=c.text.valid,
                                     dataset=c.text,
                                     seq_len=c.seq_len)
sampler = RandomSampler(dataset, replacement=c.dataloader_shuffle_with_replacement)

return DataLoader(dataset,
                  batch_size=c.batch_size,
                  collate_fn=transpose_batch,
                  sampler=sampler)
