
NLP model trainer for classification

from collections import Counter
from typing import Callable

import torchtext
import torchtext.vocab
from torchtext.vocab import Vocab

import torch
from labml import lab, tracker, monit
from labml.configs import option
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
from torch import nn
from torch.utils.data import DataLoader

Trainer configurations

This class defines the basic configurations for training an NLP classification model. All the properties are configurable.

class NLPClassificationConfigs(TrainValidConfigs):

Optimizer

    optimizer: torch.optim.Adam

Training device

    device: torch.device = DeviceConfigs()

Classification model

    model: nn.Module

Batch size

    batch_size: int = 16

Length of the sequence, or context size

    seq_len: int = 512

Vocabulary

    vocab: Vocab = 'ag_news'

Number of tokens in the vocabulary

    n_tokens: int

Number of classes

    n_classes: int = 'ag_news'

Tokenizer

    tokenizer: Callable = 'character'

Whether to periodically save models

    is_save_models = True

Loss function

    loss_func = nn.CrossEntropyLoss()

Accuracy function

    accuracy = Accuracy()

Model embedding size

    d_model: int = 512

Gradient clipping

    grad_norm_clip: float = 1.0

Training data loader

    train_loader: DataLoader = 'ag_news'

Validation data loader

    valid_loader: DataLoader = 'ag_news'

Whether to log model parameters and gradients (once per epoch). These are summarized stats per layer, but they could still produce a large number of indicators for very deep networks.

    is_log_model_params_grads: bool = False

Whether to log model activations (once per epoch). These are summarized stats per layer, but they could still produce a large number of indicators for very deep networks.

    is_log_model_activations: bool = False

Initialization

    def init(self):

Set tracker configurations

        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)

Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation.

        self.state_modules = [self.accuracy]

Training or validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of samples processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])

Get the model outputs. The model returns a tuple, with the extra elements reserved for RNN states (not implemented yet 😜).

        output, *_ = self.model(data)

Calculate and log loss

        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on the last batch of every epoch

            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()
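The step above only assumes that the model takes a `[seq_len, batch_size]` tensor of token ids and returns a tuple whose first element is the `[batch_size, n_classes]` logits. As a rough illustration (not part of this file), a minimal embedding-and-pool classifier with that interface might look like the sketch below; the class name and pooling choice are hypothetical.

```python
import torch
from torch import nn


class EmbeddingPoolClassifier(nn.Module):
    """Hypothetical minimal model matching the interface expected by `step`."""

    def __init__(self, n_tokens: int, d_model: int, n_classes: int):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, d_model)
        self.fc = nn.Linear(d_model, n_classes)

    def forward(self, x: torch.Tensor):
        # x: [seq_len, batch_size] of token ids -> [seq_len, batch_size, d_model]
        emb = self.embedding(x)
        # Mean-pool over the sequence dimension -> [batch_size, d_model]
        pooled = emb.mean(dim=0)
        # Return logits plus a placeholder state, so `output, *_ = self.model(data)` works
        return self.fc(pooled), None
```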

Default optimizer configurations

@option(NLPClassificationConfigs.optimizer)
def _optimizer(c: NLPClassificationConfigs):

    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer

Basic English tokenizer

We use a character-level tokenizer in this experiment. You can switch to the basic English tokenizer by setting,

'tokenizer': 'basic_english',

in the configurations dictionary when starting the experiment.

@option(NLPClassificationConfigs.tokenizer)
def basic_english():

    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')
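Concretely, that dictionary is the one passed to `experiment.configs` when setting up the run. A minimal sketch (the `conf` object is illustrative):

```python
from labml import experiment

conf = NLPClassificationConfigs()
# Pick the basic English tokenizer instead of the default character tokenizer
experiment.configs(conf, {'tokenizer': 'basic_english'})
```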

Character-level tokenizer

def character_tokenizer(x: str):

    return list(x)
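For instance, each character simply becomes a token:

```python
assert character_tokenizer("AG News") == ['A', 'G', ' ', 'N', 'e', 'w', 's']
```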

Character-level tokenizer configuration

@option(NLPClassificationConfigs.tokenizer)
def character():

    return character_tokenizer

Get the number of tokens, including the two extra ids used for the padding token and the [CLS] token

@option(NLPClassificationConfigs.n_tokens)
def _n_tokens(c: NLPClassificationConfigs):

    return len(c.vocab) + 2

Collate function used by the DataLoader to assemble samples into batches

class CollateFunc:

  • tokenizer is the tokenizer function
  • vocab is the vocabulary
  • seq_len is the length of the sequence
  • padding_token is the token used for padding when the text is shorter than seq_len
  • classifier_token is the [CLS] token, which we set at the end of the input

    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):

        self.classifier_token = classifier_token
        self.padding_token = padding_token
        self.seq_len = seq_len
        self.vocab = vocab
        self.tokenizer = tokenizer

  • batch is the batch of data collected by the DataLoader

    def __call__(self, batch):

Input data tensor, initialized with padding_token

        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)

Empty labels tensor

        labels = torch.zeros(len(batch), dtype=torch.long)

Loop through the samples

        for (i, (_label, _text)) in enumerate(batch):

Set the label (AG News labels are 1 to 4, so shift them to 0 to 3)

            labels[i] = int(_label) - 1

Tokenize the input text and map the tokens to vocabulary ids

            _text = [self.vocab[token] for token in self.tokenizer(_text)]

Truncate up to seq_len

            _text = _text[:self.seq_len]

Copy the token ids into column i of the data tensor

            data[:len(_text), i] = data.new_tensor(_text)

Set the final token in the sequence to [CLS]

        data[-1, :] = self.classifier_token

        return data, labels
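A small, self-contained sketch of how `CollateFunc` might be used directly, with a throwaway character-level vocabulary; the sample texts and sizes are made up for illustration, and the padding and [CLS] ids follow the same convention as `ag_news` below.

```python
from collections import Counter

import torchtext.vocab

# Build a tiny character-level vocabulary from some made-up texts
texts = ["Stocks rally", "Elections ahead"]
counter = Counter()
for t in texts:
    counter.update(character_tokenizer(t))
vocab = torchtext.vocab.vocab(counter, min_freq=1)

# Padding and [CLS] tokens take the two ids just past the vocabulary
collate = CollateFunc(character_tokenizer, vocab, seq_len=16,
                      padding_token=len(vocab), classifier_token=len(vocab) + 1)

# AG News labels are 1-4; CollateFunc shifts them to 0-3
data, labels = collate([(3, texts[0]), (1, texts[1])])
assert data.shape == (16, 2) and labels.tolist() == [2, 0]
```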

AG News dataset

This loads the AG News dataset and sets the values for n_classes, vocab, train_loader, and valid_loader.

@option([NLPClassificationConfigs.n_classes,
         NLPClassificationConfigs.vocab,
         NLPClassificationConfigs.train_loader,
         NLPClassificationConfigs.valid_loader])
def ag_news(c: NLPClassificationConfigs):

Get training and validation datasets

    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))

Load data to memory

    with monit.section('Load data'):
        from labml_nn.utils import MapStyleDataset

Create map-style datasets

        train, valid = MapStyleDataset(train), MapStyleDataset(valid)

Get tokenizer

    tokenizer = c.tokenizer

Create a counter

    counter = Counter()

Collect tokens from training dataset

    for (label, line) in train:
        counter.update(tokenizer(line))

Collect tokens from validation dataset

    for (label, line) in valid:
        counter.update(tokenizer(line))

Create vocabulary

    vocab = torchtext.vocab.vocab(counter, min_freq=1)

Create training data loader

    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

Create validation data loader

    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))

Return n_classes, vocab, train_loader, and valid_loader

    return 4, vocab, train_loader, valid_loader
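Finally, a rough sketch of how these configurations might be wired into a runnable experiment. The module path, experiment name, and the `EmbeddingPoolClassifier` model from the earlier sketch are assumptions; in practice the model is supplied by the specific experiment, since it is not defined in this file.

```python
from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_classification import NLPClassificationConfigs


@option(NLPClassificationConfigs.model)
def _model(c: NLPClassificationConfigs):
    # Build the hypothetical classifier sketched earlier, using the lazily
    # computed vocabulary size, embedding size, and number of classes
    return EmbeddingPoolClassifier(c.n_tokens, c.d_model, c.n_classes).to(c.device)


def main():
    # Create the experiment and the configurations object
    experiment.create(name="ag_news_classifier")
    conf = NLPClassificationConfigs()
    # Override any configurable property through the configurations dictionary
    experiment.configs(conf, {'tokenizer': 'basic_english', 'epochs': 10})
    # Start the experiment and run the training/validation loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
```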
