Distilling the Knowledge in a Neural Network

[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/distillation/__init__.py)

This is a PyTorch implementation/tutorial of the paper Distilling the Knowledge in a Neural Network.

Distillation trains a small network using the knowledge in a larger network that has already been trained, i.e., it distills the knowledge from the large network into the small one.

A large model with regularization, or an ensemble of models (using dropout), generalizes better than a small model when both are trained directly on the data and labels. However, a small model can be trained to generalize better with the help of a large model. Smaller models are better suited for production: they are faster and need less compute and memory.

The output probabilities of a trained model carry more information than the labels, because the model also assigns non-zero probabilities to incorrect classes. These probabilities tell us how likely a sample is to belong to each of those classes. For instance, given an image of the digit 7, a model that generalizes well will give a high probability to 7 and a small but non-zero probability to 2, while assigning almost zero probability to the other digits. Distillation uses this information to train a better small model.

Soft Targets

The probabilities are usually computed with a softmax operation,

$$q_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}$$

where $q_i$ is the probability for class $i$ and $z_i$ is the corresponding logit.

We train the small model to minimize the cross entropy or the KL divergence between its output probability distribution and the large network's output probability distribution (the soft targets).
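Minimizing either loss gives the small model the same gradients, because the cross entropy $H(p, q)$ equals the KL divergence $\mathrm{KL}(p \,\|\, q)$ plus the entropy $H(p)$ of the soft targets, and $H(p)$ does not depend on the small model. The short PyTorch sketch below checks this numerically; the random logits are made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical logits: 3 samples, 5 classes
teacher_logits = torch.randn(3, 5)
student_logits = torch.randn(3, 5, requires_grad=True)

p = F.softmax(teacher_logits, dim=-1)          # soft targets (constant w.r.t. the student)
log_q = F.log_softmax(student_logits, dim=-1)  # small model's log-probabilities

# Cross entropy with soft targets: H(p, q) = -sum_i p_i log q_i
cross_entropy = -(p * log_q).sum(dim=-1).mean()
# KL divergence: KL(p || q) = sum_i p_i (log p_i - log q_i)
kl = (p * (p.log() - log_q)).sum(dim=-1).mean()
# Entropy of the soft targets, H(p), which the student cannot change
entropy = -(p * p.log()).sum(dim=-1).mean()

grad_ce, = torch.autograd.grad(cross_entropy, student_logits, retain_graph=True)
grad_kl, = torch.autograd.grad(kl, student_logits)

print(torch.allclose(cross_entropy, kl + entropy))  # True: H(p, q) = H(p) + KL(p || q)
print(torch.allclose(grad_ce, grad_kl))  # True: identical gradients
```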

One problem is that the probabilities the large network assigns to incorrect classes are often very small, so they contribute almost nothing to the loss. The paper therefore softens the probabilities by applying a temperature $T$,

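$$q_i = \frac{\exp\left(\frac{z_i}{T}\right)}{\sum_j \exp\left(\frac{z_j}{T}\right)}$$

where higher values of $T$ produce softer probabilities, i.e. a flatter distribution over the classes; $T = 1$ recovers the ordinary softmax.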
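A minimal sketch of the temperature-scaled softmax in PyTorch. The logits are hypothetical values for a digit classifier shown an image of a 7, chosen so that 7 dominates and 2 gets a moderate score; the helper `softmax_with_temperature` is introduced here for illustration.

```python
import torch
import torch.nn.functional as F

# Hypothetical logits for digits 0..9 given an image of a 7:
# class 7 dominates, class 2 gets a moderate score, the rest are low.
logits = torch.tensor([-2.0, -1.0, 3.0, -1.5, -2.0, -1.0, -2.5, 9.0, -0.5, -1.0])


def softmax_with_temperature(z: torch.Tensor, temperature: float) -> torch.Tensor:
    """q_i = exp(z_i / T) / sum_j exp(z_j / T)"""
    return F.softmax(z / temperature, dim=-1)


for t in (1.0, 5.0, 20.0):
    q = softmax_with_temperature(logits, t)
    print(f"T={t:>4}: p(7)={q[7].item():.3f}  p(2)={q[2].item():.4f}")
```

Raising $T$ lowers the probability of 7 and raises the probabilities of the other digits, so the small but informative probabilities (such as the one for 2) carry more weight in the loss.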

Training

The paper suggests adding a second loss term for predicting the actual labels when training the small model. We calculate the composite loss as the weighted sum of the two loss terms: the soft targets loss and the actual labels loss.
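Written out, the composite loss is a weighted sum of the two terms. Here $p$ and $q$ are the large and small models' distributions softened with temperature $T$, $z$ are the small model's logits, $y$ is the true label, and $w_{\text{soft}}$, $w_{\text{label}}$ are the two weights (notation introduced here for illustration):

$$
\mathcal{L} = w_{\text{soft}} \, \mathrm{KL}\left(p \,\|\, q\right) + w_{\text{label}} \, \mathrm{CE}\left(\operatorname{softmax}(z), y\right),
\qquad
\mathrm{KL}\left(p \,\|\, q\right) = \sum_i p_i \left(\log p_i - \log q_i\right)
$$

The true-label term uses the small model's logits without a temperature, i.e. at $T = 1$.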

The dataset used for distillation is called the transfer set; the paper suggests using the same training data that the large model was trained on.

Our experiment

We train on the CIFAR-10 dataset. A large model with 14,728,266 parameters, trained with dropout, reaches an accuracy of 85% on the validation set. A small model with 437,034 parameters reaches 80%.

We then train the small model with distillation from the large model, and it reaches 82%, an increase of 2 percentage points.

```python
import torch
import torch.nn.functional
from torch import nn

from labml import experiment, tracker
from labml.configs import option
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.distillation.large import LargeModel
from labml_nn.distillation.small import SmallModel
from labml_nn.experiments.cifar10 import CIFAR10Configs
```

Configurations

This extends `CIFAR10Configs`, which defines all the dataset-related configurations, the optimizer, and a training loop.

```python
class Configs(CIFAR10Configs):
    # The small model
    model: SmallModel
    # The large model
    large: LargeModel
    # KL divergence loss for soft targets
    kl_div_loss = nn.KLDivLoss(log_target=True)
    # Cross entropy loss for the true labels
    loss_func = nn.CrossEntropyLoss()
    # Temperature, T
    temperature: float = 5.
```

Weight for the soft targets loss.

The gradients produced by the soft targets are scaled by $\frac{1}{T^2}$. To compensate for this, the paper suggests scaling the soft targets loss by a factor of $T^2$. Here the weight is set to 100, which is larger than $T^2 = 25$ for $T = 5$.

```python
    soft_targets_weight: float = 100.
```

Weight for the true label cross entropy loss

```python
    label_loss_weight: float = 0.5
```
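A small numerical check of the $\frac{1}{T^2}$ scaling, as a sketch: it computes the gradient of the soft targets loss with respect to the small model's logits at two temperatures. The logits are hypothetical zero-mean values, and the loss uses `reduction="sum"` (not the default `'mean'` used in `Configs`) so that the gradient is that of a single sample's KL divergence. The $\frac{1}{T^2}$ behaviour is an approximation that holds when $T$ is large compared with the logits.

```python
import torch
import torch.nn.functional as F

# Hypothetical zero-mean logits for one sample with 4 classes
teacher_logits = torch.tensor([1.0, -0.5, 0.2, -0.7])
student_init = torch.tensor([0.3, 0.1, -0.2, -0.2])


def soft_targets_grad_norm(temperature: float) -> float:
    """Norm of the gradient of KL(p || q) w.r.t. the student logits at temperature T."""
    student_logits = student_init.clone().requires_grad_(True)
    log_p = F.log_softmax(teacher_logits / temperature, dim=-1)  # soft targets
    log_q = F.log_softmax(student_logits / temperature, dim=-1)  # softened student
    loss = F.kl_div(log_q, log_p, reduction="sum", log_target=True)
    loss.backward()
    return student_logits.grad.norm().item()


g10 = soft_targets_grad_norm(10.0)
g20 = soft_targets_grad_norm(20.0)
# Doubling T shrinks the gradient about four-fold (ratio close to 4)
print(g10 / g20)
# Scaling by T^2 brings both back to a similar magnitude
print(g10 * 10.0 ** 2, g20 * 20.0 ** 2)
```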

Training/validation step

We define a custom training/validation step to include the distillation.

```python
    def step(self, batch: any, batch_idx: BatchIndex):
        # Training/evaluation mode for the small model
        self.model.train(self.mode.is_train)
        # Large model in evaluation mode
        self.large.eval()

        # Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)

        # Update global step (number of samples processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(len(data))

        # Get the output logits, v_i, from the large model
        with torch.no_grad():
            large_logits = self.large(data)

        # Get the output logits, z_i, from the small model
        output = self.model(data)
```

Soft targets

$$p_i = \frac{\exp\left(\frac{v_i}{T}\right)}{\sum_j \exp\left(\frac{v_j}{T}\right)}$$

```python
        soft_targets = nn.functional.log_softmax(large_logits / self.temperature, dim=-1)
```

Temperature-adjusted probabilities of the small model

$$q_i = \frac{\exp\left(\frac{z_i}{T}\right)}{\sum_j \exp\left(\frac{z_j}{T}\right)}$$

```python
        soft_prob = nn.functional.log_softmax(output / self.temperature, dim=-1)
```

Both are computed as log-probabilities: `nn.KLDivLoss` expects its input as log-probabilities and, because of `log_target=True`, its target as well.

Calculate the soft targets loss

```python
        soft_targets_loss = self.kl_div_loss(soft_prob, soft_targets)
```
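`nn.KLDivLoss` with `log_target=True` computes $p_i (\log p_i - \log q_i)$ for every element and then reduces. The sketch below, with made-up logits, shows what the default `reduction='mean'` used here does: it averages over all batch × class elements, so it equals the per-sample KL divergence (`reduction='batchmean'`) divided by the number of classes. PyTorch may print a warning recommending `'batchmean'` when `'mean'` is used.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
temperature = 5.0
batch_size, num_classes = 4, 10
# Hypothetical logits standing in for the large and small models' outputs
large_logits = torch.randn(batch_size, num_classes)
output = torch.randn(batch_size, num_classes)

# Both arguments as log-probabilities, as in the step above
soft_targets = F.log_softmax(large_logits / temperature, dim=-1)  # log p
soft_prob = F.log_softmax(output / temperature, dim=-1)           # log q

# Element-wise terms p_i * (log p_i - log q_i)
pointwise = soft_targets.exp() * (soft_targets - soft_prob)

mean_loss = nn.KLDivLoss(log_target=True)(soft_prob, soft_targets)  # default reduction='mean'
batchmean_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)(soft_prob, soft_targets)

# 'mean' averages over all batch_size * num_classes elements
print(torch.allclose(mean_loss, pointwise.mean()))  # True
# 'batchmean' sums over classes and averages over the batch: the KL divergence per sample
print(torch.allclose(batchmean_loss, pointwise.sum(dim=-1).mean()))  # True
print(torch.allclose(batchmean_loss, mean_loss * num_classes))  # True
```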

Calculate the true label loss, combine the losses, log them, and train

```python
        # Calculate the true label loss
        label_loss = self.loss_func(output, target)
        # Weighted sum of the two losses
        loss = self.soft_targets_weight * soft_targets_loss + self.label_loss_weight * label_loss
        # Log the losses
        tracker.add({"loss.kl_div.": soft_targets_loss,
                     "loss.nll": label_loss,
                     "loss.": loss})

        # Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()

        # Train the model
        if self.mode.is_train:
            # Calculate gradients
            loss.backward()
            # Take optimizer step
            self.optimizer.step()
            # Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
            # Clear the gradients
            self.optimizer.zero_grad()

        # Save the tracked metrics
        tracker.save()
```
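For readers not using labml, here is a minimal plain-PyTorch sketch of the same training step, without tracking or accuracy logging. The two `nn.Sequential` models are hypothetical tiny stand-ins for `LargeModel` and `SmallModel`, and the hyperparameters are copied from `Configs` and `main`. Only the small model's parameters are given to the optimizer, and the large model runs under `torch.no_grad()`, so it is never updated.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
# Hypothetical tiny stand-ins for LargeModel and SmallModel (CIFAR-10-shaped input, 10 classes)
large = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 256), nn.ReLU(), nn.Linear(256, 10))
small = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
# Only the small model's parameters are optimized
optimizer = torch.optim.Adam(small.parameters(), lr=2.5e-4)
kl_div_loss = nn.KLDivLoss(log_target=True)
temperature, soft_targets_weight, label_loss_weight = 5.0, 100.0, 0.5


def train_step(data: torch.Tensor, target: torch.Tensor) -> float:
    small.train()
    large.eval()
    # The large model only provides targets, so no graph is built for it
    with torch.no_grad():
        large_logits = large(data)
    output = small(data)
    soft_targets = F.log_softmax(large_logits / temperature, dim=-1)
    soft_prob = F.log_softmax(output / temperature, dim=-1)
    loss = (soft_targets_weight * kl_div_loss(soft_prob, soft_targets)
            + label_loss_weight * F.cross_entropy(output, target))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()


# One step on a random CIFAR-10-shaped batch
large_before = [p.detach().clone() for p in large.parameters()]
small_before = [p.detach().clone() for p in small.parameters()]
loss_value = train_step(torch.randn(4, 3, 32, 32), torch.randint(0, 10, (4,)))
```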

Create the large model

```python
@option(Configs.large)
def _large_model(c: Configs):
    return LargeModel().to(c.device)
```

Create the small model

```python
@option(Configs.model)
def _small_student_model(c: Configs):
    return SmallModel().to(c.device)
```

Load the trained large model

```python
def get_saved_model(run_uuid: str, checkpoint: int):
    from labml_nn.distillation.large import Configs as LargeConfigs

    # In evaluation mode (no recording)
    experiment.evaluate()
    # Initialize configs of the large model training experiment
    conf = LargeConfigs()
    # Load saved configs
    experiment.configs(conf, experiment.load_configs(run_uuid))
    # Set models for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Set which run and checkpoint to load
    experiment.load(run_uuid, checkpoint)
    # Start the experiment - this will load the model, and prepare everything
    experiment.start()

    # Return the model
    return conf.model
```
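Outside labml, loading a trained large model usually means restoring a saved `state_dict`. The sketch below is a plain-PyTorch alternative, not a description of how labml's `experiment.load` works internally; `load_frozen_model` and the file name are hypothetical. It also turns off `requires_grad` on the large model's parameters as an extra safeguard, whereas the step above relies on `torch.no_grad()` instead.

```python
import os
import tempfile

import torch
from torch import nn


def load_frozen_model(model: nn.Module, checkpoint_path: str, device: str = "cpu") -> nn.Module:
    """Load weights from a state_dict file, switch to eval mode and freeze all parameters."""
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device).eval()
    for p in model.parameters():
        p.requires_grad_(False)
    return model


# Usage: save a (hypothetical) trained model, then load it into a fresh instance
trained = nn.Linear(8, 3)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "large_model.pt")  # hypothetical file name
    torch.save(trained.state_dict(), path)
    large = load_frozen_model(nn.Linear(8, 3), path)
```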

Train a small model with distillation

```python
def main(run_uuid: str, checkpoint: int):
    # Load saved model
    large_model = get_saved_model(run_uuid, checkpoint)
    # Create experiment
    experiment.create(name='distillation', comment='cifar10')
    # Create configurations
    conf = Configs()
    # Set the loaded large model
    conf.large = large_model
    # Load configurations
    experiment.configs(conf, {
        'optimizer.optimizer': 'Adam',
        'optimizer.learning_rate': 2.5e-4,
        'model': '_small_student_model',
    })
    # Set model for saving/loading
    experiment.add_pytorch_models({'model': conf.model})
    # Start experiment from scratch
    experiment.load(None, None)
    # Start the experiment and run the training loop
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main('d46cd53edaec11eb93c38d6538aee7d6', 1_000_000)
```

labml.ai