```python
from collections import Counter
from typing import Any, Callable

import torchtext
import torchtext.vocab
from torchtext.vocab import Vocab

import torch
from labml import lab, tracker, monit
from labml.configs import option
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.metrics import Accuracy
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
from labml_nn.optimizers.configs import OptimizerConfigs
from torch import nn
from torch.utils.data import DataLoader
```
This has the basic configurations for NLP classification task training. All the properties are configurable; string defaults such as `'ag_news'` name config option functions, defined below, that compute the actual values.
```python
class NLPClassificationConfigs(TrainValidConfigs):
```
Optimizer
```python
    optimizer: torch.optim.Adam
```
Training device
```python
    device: torch.device = DeviceConfigs()
```
Classification model
```python
    model: nn.Module
```
Batch size
```python
    batch_size: int = 16
```
Length of the sequence, or context size
```python
    seq_len: int = 512
```
Vocabulary
```python
    vocab: Vocab = 'ag_news'
```
Number of tokens in the vocabulary
```python
    n_tokens: int
```
Number of classes
```python
    n_classes: int = 'ag_news'
```
Tokenizer
```python
    tokenizer: Callable = 'character'
```
Whether to periodically save models
```python
    is_save_models = True
```
Loss function
```python
    loss_func = nn.CrossEntropyLoss()
```
Accuracy function
```python
    accuracy = Accuracy()
```
Model embedding size
```python
    d_model: int = 512
```
Gradient clipping
```python
    grad_norm_clip: float = 1.0
```
Training data loader
```python
    train_loader: DataLoader = 'ag_news'
```
Validation data loader
```python
    valid_loader: DataLoader = 'ag_news'
```
Whether to log model parameters and gradients (once per epoch). These are summarized stats per layer, but they could still produce a lot of indicators for very deep networks.
```python
    is_log_model_params_grads: bool = False
```
Whether to log model activations (once per epoch). These are summarized stats per layer, but they could still produce a lot of indicators for very deep networks.
```python
    is_log_model_activations: bool = False
```
Initialization
```python
    def init(self):
```
Set tracker configurations
```python
        tracker.set_scalar("accuracy.*", True)
        tracker.set_scalar("loss.*", True)
```
Add accuracy as a state module. The name is probably confusing, since it's meant to store states between training and validation for RNNs. This will keep the accuracy metric stats separate for training and validation.
```python
        self.state_modules = [self.accuracy]
```
Training or validation step
```python
    def step(self, batch: Any, batch_idx: BatchIndex):
```
Move data to the device
```python
        data, target = batch[0].to(self.device), batch[1].to(self.device)
```
Update global step (number of samples processed) when in training mode
```python
        if self.mode.is_train:
            tracker.add_global_step(data.shape[1])
```
Get model outputs. The model returns a tuple, with extra elements for states when using RNNs; those are not implemented yet. 😜
```python
        output, *_ = self.model(data)
```
Calculate and log loss
```python
        loss = self.loss_func(output, target)
        tracker.add("loss.", loss)
```
Calculate and log accuracy
```python
        self.accuracy(output, target)
        self.accuracy.track()
```
Train the model
```python
        if self.mode.is_train:
```
Calculate gradients
```python
            loss.backward()
```
Clip gradients
```python
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
```
Take optimizer step
```python
            self.optimizer.step()
```
Log the model parameters and gradients on the last batch of every epoch
```python
            if batch_idx.is_last and self.is_log_model_params_grads:
                tracker.add('model', self.model)
```
Clear the gradients
```python
            self.optimizer.zero_grad()
```
Save the tracked metrics
```python
        tracker.save()
```
Default optimizer configurations
```python
@option(NLPClassificationConfigs.optimizer)
def _optimizer(c: NLPClassificationConfigs):
    optimizer = OptimizerConfigs()
    optimizer.parameters = c.model.parameters()
    optimizer.optimizer = 'Adam'
    optimizer.d_model = c.d_model

    return optimizer
```
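Because the optimizer is built through `OptimizerConfigs`, its sub-options can themselves be overridden from the configurations dictionary. A minimal sketch, assuming the dotted-key convention used by other labml experiments (the experiment name here is hypothetical):

```python
from labml import experiment

conf = NLPClassificationConfigs()
experiment.create(name='nlp_classification')  # hypothetical experiment name
experiment.configs(conf, {
    # Dotted keys drill into the nested OptimizerConfigs; this convention
    # is an assumption based on how other labml_nn experiments set them
    'optimizer.optimizer': 'Adam',
    'optimizer.learning_rate': 3e-4,
})
```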
We use a character-level tokenizer in this experiment. You can switch to the basic English tokenizer by setting
```python
'tokenizer': 'basic_english',
```
in the configurations dictionary when starting the experiment.
```python
@option(NLPClassificationConfigs.tokenizer)
def basic_english():
    from torchtext.data import get_tokenizer
    return get_tokenizer('basic_english')
```
Character-level tokenizer
```python
def character_tokenizer(x: str):
    return list(x)
```
Character-level tokenizer configuration
```python
@option(NLPClassificationConfigs.tokenizer)
def character():
    return character_tokenizer
```
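To make the difference concrete, here is a quick illustrative comparison of the two tokenizers; the `basic_english` output shown reflects its usual lowercasing and punctuation splitting:

```python
text = "Hello, world"

# Character-level: one token per character
print(character_tokenizer(text))
# ['H', 'e', 'l', 'l', 'o', ',', ' ', 'w', 'o', 'r', 'l', 'd']

# basic_english (torchtext): lowercases and splits on words and punctuation
from torchtext.data import get_tokenizer
print(get_tokenizer('basic_english')(text))
# ['hello', ',', 'world']
```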
Get number of tokens
```python
@option(NLPClassificationConfigs.n_tokens)
def _n_tokens(c: NLPClassificationConfigs):
    # Two extra tokens: one for padding and one for the [CLS] token
    return len(c.vocab) + 2
```
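These two extra ids line up with how `ag_news` below constructs the `CollateFunc`. A small sketch of the resulting id layout, with a hypothetical vocabulary size:

```python
V = 100                   # hypothetical vocabulary size, i.e. len(vocab)
padding_token = V         # id used to pad short sequences
classifier_token = V + 1  # id of the [CLS] token set at the end of each input
n_tokens = V + 2          # total ids the token embedding must cover
```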
Collate function for the `DataLoader`
```python
class CollateFunc:
```
* `tokenizer` is the tokenizer function
* `vocab` is the vocabulary
* `seq_len` is the length of the sequence
* `padding_token` is the token used for padding when `seq_len` is larger than the text length
* `classifier_token` is the `[CLS]` token which we set at the end of the input
```python
    def __init__(self, tokenizer, vocab: Vocab, seq_len: int, padding_token: int, classifier_token: int):
        self.classifier_token = classifier_token
        self.padding_token = padding_token
        self.seq_len = seq_len
        self.vocab = vocab
        self.tokenizer = tokenizer
```
`batch` is the batch of data collected by the `DataLoader`
```python
    def __call__(self, batch):
```
Input data tensor, initialized with padding_token
```python
        data = torch.full((self.seq_len, len(batch)), self.padding_token, dtype=torch.long)
```
Empty labels tensor
```python
        labels = torch.zeros(len(batch), dtype=torch.long)
```
Loop through the samples
```python
        for (i, (_label, _text)) in enumerate(batch):
```
Set the label; dataset labels start from 1, so shift them to be zero-based
```python
            labels[i] = int(_label) - 1
```
Tokenize the input text
```python
            _text = [self.vocab[token] for token in self.tokenizer(_text)]
```
Truncate to `seq_len`
```python
            _text = _text[:self.seq_len]
```
Transpose (each sample becomes a column) and add to `data`
```python
            data[:len(_text), i] = data.new_tensor(_text)
```
Set the final token in the sequence to [CLS]
```python
        data[-1, :] = self.classifier_token
```
```python
        return data, labels
```
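Here is a small worked example of what the collate function produces. The two-token vocabulary, the plain `dict` standing in for a `Vocab`, and the tiny sequence length are all hypothetical, for illustration only:

```python
# Toy setup: vocabulary {'a': 0, 'b': 1}, so V = 2, padding id = 2,
# [CLS] id = 3, and a sequence length of 4
collate = CollateFunc(tokenizer=list, vocab={'a': 0, 'b': 1},
                      seq_len=4, padding_token=2, classifier_token=3)
data, labels = collate([('1', 'ab'), ('2', 'b')])
# Each column is one sample, padded with 2 and ending with the [CLS] id 3:
# data   = [[0, 1],
#           [1, 2],
#           [2, 2],
#           [3, 3]]       shape (seq_len, batch) = (4, 2)
# labels = [0, 1]          labels shifted to be zero-based
```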
This loads the AG News dataset and sets the values for `n_classes`, `vocab`, `train_loader`, and `valid_loader`.
```python
@option([NLPClassificationConfigs.n_classes,
         NLPClassificationConfigs.vocab,
         NLPClassificationConfigs.train_loader,
         NLPClassificationConfigs.valid_loader])
def ag_news(c: NLPClassificationConfigs):
```
Get training and validation datasets
```python
    train, valid = torchtext.datasets.AG_NEWS(root=str(lab.get_data_path() / 'ag_news'), split=('train', 'test'))
```
Load data to memory
```python
    with monit.section('Load data'):
        from labml_nn.utils import MapStyleDataset
```
Create map-style datasets
```python
        train, valid = MapStyleDataset(train), MapStyleDataset(valid)
```
Get tokenizer
```python
    tokenizer = c.tokenizer
```
Create a counter
```python
    counter = Counter()
```
Collect tokens from training dataset
```python
    for (label, line) in train:
        counter.update(tokenizer(line))
```
Collect tokens from validation dataset
```python
    for (label, line) in valid:
        counter.update(tokenizer(line))
```
Create vocabulary
```python
    vocab = torchtext.vocab.vocab(counter, min_freq=1)
```
Create training data loader
```python
    train_loader = DataLoader(train, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
```
Create validation data loader
```python
    valid_loader = DataLoader(valid, batch_size=c.batch_size, shuffle=True,
                              collate_fn=CollateFunc(tokenizer, vocab, c.seq_len, len(vocab), len(vocab) + 1))
```
Return `n_classes`, `vocab`, `train_loader`, and `valid_loader`
```python
    return 4, vocab, train_loader, valid_loader
```
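Finally, a hedged end-to-end sketch of how this trainer might be used. This file only defines the trainer, so a concrete experiment has to register a `model` option itself; `SimpleClassifier`, the experiment name, and the `epochs` override below are assumptions for illustration, not part of this module:

```python
import torch
from torch import nn

from labml import experiment
from labml.configs import option


class SimpleClassifier(nn.Module):
    """A minimal embedding + mean-pooling classifier (illustrative only;
    a real model would attend over the sequence and read out the [CLS]
    position that CollateFunc places at the end)."""

    def __init__(self, n_tokens: int, d_model: int, n_classes: int):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, d_model)
        self.fc = nn.Linear(d_model, n_classes)

    def forward(self, x: torch.Tensor):
        # x has shape (seq_len, batch); average embeddings over the sequence
        emb = self.embedding(x).mean(dim=0)
        # Return a tuple so `output, *_ = self.model(data)` in `step` works
        return self.fc(emb), None


@option(NLPClassificationConfigs.model)
def simple_model(c: NLPClassificationConfigs):
    return SimpleClassifier(c.n_tokens, c.d_model, c.n_classes).to(c.device)


def main():
    conf = NLPClassificationConfigs()
    experiment.create(name='ag_news_classification')  # hypothetical name
    experiment.configs(conf, {'tokenizer': 'basic_english', 'epochs': 5})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
```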