[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/distillation/__init__.py)
This is a PyTorch implementation/tutorial of the paper Distilling the Knowledge in a Neural Network.
It's a way of training a small network using the knowledge in a trained larger network; i.e. distilling the knowledge from the large network.
A large model with regularization, or an ensemble of models (using dropout), generalizes better than a small model when trained directly on the data and labels. However, a small model can be trained to generalize better with the help of a large model. Smaller models are better in production: they are faster and need less compute and less memory.
The output probabilities of a trained model give more information than the labels, because the model assigns non-zero probabilities to incorrect classes as well. These probabilities tell us that a sample has a chance of belonging to certain classes. For instance, when classifying digits, given an image of the digit 7, a model that generalizes well will give a high probability to 7 and a small but non-zero probability to 2, while assigning almost zero probability to the other digits. Distillation uses this information to train a small model better.
The probabilities are usually computed with a softmax operation,
$$q_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}$$
where $q_i$ is the probability for class $i$ and $z_i$ is the logit.
We train the small model to minimize the cross-entropy or the KL divergence between its output probability distribution and the large network's output probability distribution (the soft targets).
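Either loss works because $H(p, q) = H(p) + \mathrm{KL}(p \,\|\, q)$: the two differ only by the entropy $H(p)$ of the large network's soft targets, which does not depend on the small model, so both give the same gradients for the small model. A plain-Python check of this identity (the helper names are illustrative, not part of the implementation):

```python
import math
from typing import List


def entropy(p: List[float]) -> float:
    # H(p) = -sum_i p_i log p_i (terms with p_i = 0 contribute nothing)
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def cross_entropy(p: List[float], q: List[float]) -> float:
    # H(p, q) = -sum_i p_i log q_i
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)


def kl_divergence(p: List[float], q: List[float]) -> float:
    # KL(p || q) = sum_i p_i (log p_i - log q_i)
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)


teacher = [0.7, 0.2, 0.1]  # soft targets p from the large model
for student in ([0.6, 0.3, 0.1], [0.1, 0.1, 0.8]):
    # The gap between the two losses is H(p), whatever the student predicts
    gap = cross_entropy(teacher, student) - kl_divergence(teacher, student)
    print(round(gap, 6), round(entropy(teacher), 6))
```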
One of the problems here is that the probabilities the large network assigns to incorrect classes are often very small, so they contribute almost nothing to the loss. The paper therefore softens the probabilities by applying a temperature $T$,
$$q_i = \frac{\exp(\frac{z_i}{T})}{\sum_j \exp(\frac{z_j}{T})}$$
where higher values of $T$ produce softer probabilities.
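To see the effect of $T$, here is a plain-Python sketch with made-up logits for a digit image that looks mostly like a 7 and slightly like a 2 (the logit values are invented for illustration, not taken from a trained model):

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """q_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtracting the max avoids overflow without changing the result
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for digits 0-9: high for 7, moderate for 2, low elsewhere
logits = [0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 0.0, 12.0, 0.0, 0.0]

for t in (1.0, 5.0):
    q = softmax_with_temperature(logits, t)
    print(f"T={t}: P(7)={q[7]:.4f}  P(2)={q[2]:.4f}  P(0)={q[0]:.6f}")
```

At $T = 1$ almost all of the mass sits on 7, so the "looks a bit like a 2" information is nearly invisible to the loss; at $T = 5$ the probability of 2 becomes large enough to matter.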
The paper suggests adding a second loss term for predicting the actual labels when training the small model. We calculate the composite loss as the weighted sum of the two loss terms: the soft targets loss and the actual labels loss.
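A per-sample sketch of the composite loss in plain Python, assuming the KL divergence for the soft-target term, ordinary cross-entropy at $T = 1$ for the label term, and a soft-term weight of $T^2$ as the paper suggests. The function names are illustrative, and it handles a single sample rather than averaging over a batch:

```python
import math
from typing import List


def log_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_total = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_total for s in scaled]


def distillation_loss(student_logits: List[float], teacher_logits: List[float], label: int,
                      temperature: float, soft_weight: float, label_weight: float) -> float:
    # Soft-target term: KL(p || q) between the softened teacher (p) and student (q)
    log_p = log_softmax(teacher_logits, temperature)
    log_q = log_softmax(student_logits, temperature)
    soft_loss = sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
    # Hard-label term: cross-entropy with the true label, using the student's logits at T = 1
    label_loss = -log_softmax(student_logits)[label]
    # Weighted sum of the two terms
    return soft_weight * soft_loss + label_weight * label_loss


T = 5.0
loss = distillation_loss([1.0, 2.0, 0.5], [0.5, 3.0, 0.2], label=1,
                         temperature=T, soft_weight=T ** 2, label_weight=0.5)
print(round(loss, 4))
```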
The dataset for distillation is called the transfer set, and the paper suggests using the same training data.
We train on the CIFAR-10 dataset. We train a large model that has 14,728,266 parameters with dropout, and it reaches an accuracy of 85% on the validation set. A small model with 437,034 parameters (about 34 times fewer) reaches an accuracy of 80%.
We then train the small model with distillation from the large model, and it reaches an accuracy of 82%, an increase of 2 percentage points.
import torch
import torch.nn.functional
from torch import nn

from labml import experiment, tracker
from labml.configs import option
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.distillation.large import LargeModel
from labml_nn.distillation.small import SmallModel
from labml_nn.experiments.cifar10 import CIFAR10Configs
This extends CIFAR10Configs, which defines all the dataset-related configurations, the optimizer, and a training loop.
class Configs(CIFAR10Configs):
The small model
    model: SmallModel
The large model
    large: LargeModel
KL divergence loss for the soft targets
    kl_div_loss = nn.KLDivLoss(log_target=True)
Cross-entropy loss for the true labels
    loss_func = nn.CrossEntropyLoss()
Temperature, $T$
    temperature: float = 5.
Weight for the soft targets loss.
The gradients produced by the soft targets get scaled by $\frac{1}{T^2}$. To compensate for this, the paper suggests scaling the soft targets loss by a factor of $T^2$.
    soft_targets_weight: float = 100.
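The $\frac{1}{T^2}$ factor comes from the gradient of the soft-target cross-entropy $C = -\sum_i p_i \log q_i$ with respect to a student logit, $\frac{\partial C}{\partial z_i} = \frac{1}{T}(q_i - p_i)$, where $p_i$ and $q_i$ are both computed at temperature $T$. For large $T$ and zero-mean logits this is approximately $\frac{1}{N T^2}(z_i - v_i)$, with $N$ the number of classes. (The KL divergence used in the code has the same gradient, since it differs only by a constant.) A plain-Python finite-difference check of the exact expression; the helper names are our own:

```python
import math
from typing import List


def softmax_t(logits: List[float], temperature: float) -> List[float]:
    """Temperature-softened softmax: exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_cross_entropy(student: List[float], teacher: List[float], temperature: float) -> float:
    """C = -sum_i p_i log q_i, with p (teacher) and q (student) at the same temperature."""
    p = softmax_t(teacher, temperature)
    q = softmax_t(student, temperature)
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q))


def analytic_grad(student: List[float], teacher: List[float], temperature: float) -> List[float]:
    """dC/dz_i = (q_i - p_i) / T."""
    p = softmax_t(teacher, temperature)
    q = softmax_t(student, temperature)
    return [(qi - pi) / temperature for qi, pi in zip(q, p)]


def numerical_grad(student: List[float], teacher: List[float], temperature: float,
                   eps: float = 1e-6) -> List[float]:
    """Central finite differences of C with respect to each student logit z_i."""
    grads = []
    for i in range(len(student)):
        up = list(student)
        up[i] += eps
        down = list(student)
        down[i] -= eps
        diff = soft_cross_entropy(up, teacher, temperature) - soft_cross_entropy(down, teacher, temperature)
        grads.append(diff / (2 * eps))
    return grads


z = [1.0, -0.5, 2.0, 0.3]  # student logits
v = [0.2, 0.1, 2.5, -0.4]  # teacher logits
for T in (1.0, 5.0, 20.0):
    grad = analytic_grad(z, v, T)
    err = max(abs(a - n) for a, n in zip(grad, numerical_grad(z, v, T)))
    # The gradient magnitude shrinks roughly like 1/T^2 as T grows
    print(f"T={T}: max|grad|={max(abs(g) for g in grad):.6f}, finite-difference error={err:.1e}")
```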
Weight for the true label cross-entropy loss
    label_loss_weight: float = 0.5
We define a custom training/validation step to include distillation.
    def step(self, batch: any, batch_idx: BatchIndex):
Training/evaluation mode for the small model
        self.model.train(self.mode.is_train)
Large model in evaluation mode
        self.large.eval()
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update the global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))
Get the output logits, $v_i$, from the large model
        with torch.no_grad():
            large_logits = self.large(data)
Get the output logits, $z_i$, from the small model
        output = self.model(data)
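Soft targets $p_i = \frac{\exp(\frac{v_i}{T})}{\sum_j \exp(\frac{v_j}{T})}$, computed as log-probabilities $\log p_i$ because the KL divergence loss was created with log_target=True
        soft_targets = nn.functional.log_softmax(large_logits / self.temperature, dim=-1)
Temperature-adjusted probabilities of the small model $q_i = \frac{\exp(\frac{z_i}{T})}{\sum_j \exp(\frac{z_j}{T})}$, also as log-probabilities $\log q_i$, which is what nn.KLDivLoss expects as its input
        soft_prob = nn.functional.log_softmax(output / self.temperature, dim=-1)
Calculate the soft targets loss
        soft_targets_loss = self.kl_div_loss(soft_prob, soft_targets)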
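For reference, here is a plain-Python re-implementation of what nn.KLDivLoss(log_target=True) is documented to compute; it is our own illustration, not PyTorch code. With both arguments given as log-probabilities, the pointwise loss is $e^{\text{target}}(\text{target} - \text{input})$, and the default reduction='mean' averages over every element (batch size times number of classes), whereas reduction='batchmean' averages over the batch only and matches the mathematical KL divergence:

```python
import math
from typing import List


def kl_div_loss_log_target(input_log_q: List[List[float]], target_log_p: List[List[float]],
                           reduction: str = "mean") -> float:
    """KL(p || q) in the style of nn.KLDivLoss(log_target=True).

    input_log_q: student log-probabilities, shape [batch][classes]
    target_log_p: teacher log-probabilities, same shape
    """
    pointwise = [[math.exp(lp) * (lp - lq) for lq, lp in zip(row_q, row_p)]
                 for row_q, row_p in zip(input_log_q, target_log_p)]
    total = sum(sum(row) for row in pointwise)
    if reduction == "sum":
        return total
    if reduction == "batchmean":
        return total / len(pointwise)  # KL divergence averaged over the batch
    if reduction == "mean":
        return total / sum(len(row) for row in pointwise)  # averaged over all elements
    raise ValueError(f"unknown reduction: {reduction}")


def log_softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    log_total = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_total for z in logits]


student = [log_softmax([1.0, 0.0, -1.0]), log_softmax([0.5, 0.5, 0.0])]
teacher = [log_softmax([2.0, 0.0, -2.0]), log_softmax([0.0, 1.0, 0.0])]
mean_loss = kl_div_loss_log_target(student, teacher)
batchmean_loss = kl_div_loss_log_target(student, teacher, reduction="batchmean")
# 'mean' equals 'batchmean' divided by the number of classes (3 here)
print(mean_loss, batchmean_loss)
```

With the 10 CIFAR-10 classes, the default 'mean' reduction therefore makes the soft targets loss 10 times smaller than its 'batchmean' value, which is worth keeping in mind when interpreting soft_targets_weight.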
Calculate the true label loss
        label_loss = self.loss_func(output, target)
Weighted sum of the two losses
        loss = self.soft_targets_weight * soft_targets_loss + self.label_loss_weight * label_loss
Log the losses
        tracker.add({"loss.kl_div.": soft_targets_loss,
                     "loss.nll": label_loss,
                     "loss.": loss})
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Take an optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
Create the large model
@option(Configs.large)
def _large_model(c: Configs):
    return LargeModel().to(c.device)
Create the small (student) model
@option(Configs.model)
def _small_student_model(c: Configs):
    return SmallModel().to(c.device)
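In labml, @option(Configs.model) registers _small_student_model as one way of computing the model config, and main selects it by name with 'model': '_small_student_model' in experiment.configs. As a rough mental model only, the toy registry below mimics that name-based selection. It is a hypothetical simplification, not labml's implementation: register_option, resolve and option_registry are invented names, and the "models" are plain strings.

```python
from typing import Callable, Dict

# Hypothetical registry: config name -> {option name -> function that computes the value}
option_registry: Dict[str, Dict[str, Callable[[dict], object]]] = {}


def register_option(config_name: str):
    """Decorator recording `fn` as one way to compute `config_name`, keyed by its name."""
    def decorator(fn: Callable[[dict], object]) -> Callable[[dict], object]:
        option_registry.setdefault(config_name, {})[fn.__name__] = fn
        return fn
    return decorator


@register_option("large")
def _large_model(c: dict) -> str:
    return f"LargeModel on {c['device']}"


@register_option("model")
def _small_student_model(c: dict) -> str:
    return f"SmallModel on {c['device']}"


def resolve(config_name: str, choice: str, context: dict) -> object:
    # Pick the option by its function name and call it only when the value is needed
    return option_registry[config_name][choice](context)


print(resolve("model", "_small_student_model", {"device": "cpu"}))
```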
Load a trained large model from a saved checkpoint
def get_saved_model(run_uuid: str, checkpoint: int):
    from labml_nn.distillation.large import Configs as LargeConfigs
In evaluation mode (no recording)
    experiment.evaluate()
Initialize configs of the large model training experiment
    conf = LargeConfigs()
Load saved configs
    experiment.configs(conf, experiment.load_configs(run_uuid))
Set models for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
Set which run and checkpoint to load
    experiment.load(run_uuid, checkpoint)
Start the experiment; this will load the model and prepare everything
    experiment.start()
Return the model
    return conf.model
Train a small model with distillation
def main(run_uuid: str, checkpoint: int):
Load the saved large model
    large_model = get_saved_model(run_uuid, checkpoint)
Create experiment
    experiment.create(name='distillation', comment='cifar10')
Create configurations
    conf = Configs()
Set the loaded large model
    conf.large = large_model
Load configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'model': '_small_student_model',
    })
Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
Start experiment from scratch
    experiment.load(None, None)
Start the experiment and run the training loop
    with experiment.start():
        conf.run()

if __name__ == '__main__':
    main('d46cd53edaec11eb93c38d6538aee7d6', 1_000_000)