This is an annotated PyTorch experiment to train a switch transformer.
import torch
import torch.nn as nn

from labml import experiment, tracker
from labml.configs import option
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
class AutoregressiveModel(nn.Module):
    def __init__(self, n_vocab: int, d_model: int, transformer: nn.Module):
        super().__init__()
Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
Transformer
        self.transformer = transformer
Final layer
        self.generator = nn.Linear(d_model, n_vocab)
        self.mask = None
    def forward(self, x: torch.Tensor):
Initialize the subsequent mask
        if self.mask is None or self.mask.size(0) != len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask = subsequent_mask(len(x)).to(x.device)
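For intuition, here is a minimal stand-in for `subsequent_mask` built with `torch.tril` (a sketch; the actual labml helper may differ in exact shape and dtype). Position $i$ may attend only to positions $\le i$:

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # Lower-triangular boolean matrix: entry [i, j] is True iff j <= i.
    # A hypothetical stand-in; labml's subsequent_mask may use a
    # different shape (e.g. an extra trailing dimension).
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

print(causal_mask(4))
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])
```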
Token embeddings
        x = self.src_embed(x)
Run it through the switch transformer, which returns the transformed sequence along with routing statistics: per-expert token counts, summed routing probabilities, the number of dropped tokens, and the maximum routing probabilities. These are returned so the training loop can compute the load balancing loss and log routing stats.
        res, counts, route_prob, n_dropped, route_prob_max = self.transformer(x, self.mask)
Generate logits of the next token
        res = self.generator(res)
        return res, counts, route_prob, n_dropped, route_prob_max
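As a usage sketch, any module that returns this five-tuple can be plugged in as the `transformer`. The stand-in below is purely hypothetical and only exercises the interface (it assumes `labml_nn` is installed, since the model's forward pass imports the mask helper):

```python
import torch
import torch.nn as nn

class DummyTransformer(nn.Module):
    # Hypothetical stand-in: returns the input unchanged plus fake routing
    # stats in the (counts, route_prob, n_dropped, route_prob_max) format.
    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        n_tokens = x.shape[0] * x.shape[1]
        counts = torch.tensor([[n_tokens, 0, 0, 0]])  # all tokens to expert 0
        return x, counts, counts.float(), [0], torch.tensor([1.0])

model = AutoregressiveModel(n_vocab=65, d_model=128, transformer=DummyTransformer())
tokens = torch.randint(0, 65, (64, 32))  # [seq_len, batch_size]
logits, *routing_stats = model(tokens)
print(logits.shape)  # torch.Size([64, 32, 65])
```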
This extends NLPAutoRegressionConfigs.
The default configs can and will be overridden when we start the experiment.
class Configs(NLPAutoRegressionConfigs):
    model: AutoregressiveModel
    transformer: nn.Module
Token embedding size
    d_model: int = 128
Number of attention heads
    heads: int = 4
Dropout probability
    dropout: float = 0.0
Number of features in FFN hidden layer
    d_ff: int = 256
Number of transformer layers
    n_layers: int = 6
Number of experts
    n_experts: int = 4
Load balancing coefficient
    load_balancing_loss_coef: float = 0.01
Whether to scale the chosen expert outputs by the routing probability
    is_scale_prob: bool = True
Whether to drop tokens
    drop_tokens: bool = False
Capacity factor to determine the capacity of each expert
    capacity_factor: float = 1.0
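To make the capacity factor concrete, here is the arithmetic as a sketch (following the Switch Transformer paper's formula; the exact computation lives in `SwitchFeedForward`):

```python
# Assumed formula, per the Switch Transformer paper:
#   capacity = capacity_factor * tokens_in_batch / n_experts
seq_len, batch_size, n_experts = 64, 32, 4
tokens = seq_len * batch_size              # 2048 tokens per batch
capacity = int(1.2 * tokens / n_experts)   # 614 tokens per expert
# With drop_tokens=True, tokens routed to an expert beyond its capacity
# are dropped and carried forward only by the residual connection.
```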
    def init(self):
        super().init()
Initialize tracking indicators; the `False` flag marks these as metrics that are tracked but not printed in the console output

        tracker.set_scalar("lb_loss.*", False)
        tracker.set_scalar("route.*", False)
        tracker.set_scalar("dropped.*", False)
    def step(self, batch: any, batch_idx: BatchIndex):
Move data to the device
        data, target = batch[0].to(self.device), batch[1].to(self.device)
Update global step (number of tokens processed) when in training mode
        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])
Get model outputs.
        output, counts, route_prob, n_dropped, route_prob_max = self.model(data)
Calculate cross entropy loss
        cross_entropy_loss = self.loss_func(output, target)
Total number of tokens processed, $T$, in the current batch $B$
        total = counts.sum(dim=-1, keepdim=True)
Fraction of tokens routed to each expert: $f_i = \frac{1}{T} \sum_{x \in B} \mathbf{1}\{\mathrm{argmax}\, p(x) = i\}$. That is, $f_i$ is the fraction of tokens in the batch for which expert $i$ has the highest routing probability $p(x)$.
        route_frac = counts / total
Mean routing probability: $P_i = \frac{1}{T} \sum_{x \in B} p_i(x)$
        route_prob = route_prob / total
Load balancing loss: $\mathcal{L} = N \sum_{i=1}^{N} f_i \cdot P_i$. $\mathcal{L}$ is the loss for a single switch layer; here we take the sum of losses across all layers. Since $f$ and $P$ each sum to one, the loss attains its minimum of $1$ when routing is perfectly uniform ($N \sum_i \frac{1}{N} \cdot \frac{1}{N} = 1$).
        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
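A quick standalone sanity check of the load balancing term (a toy sketch, not part of the experiment): the loss is $1$ when routing is uniform and grows as routing collapses onto one expert.

```python
import torch

n_experts = 4

# Perfectly balanced routing: each expert gets 1/4 of the tokens
# and 1/4 of the total routing probability mass.
f = torch.full((n_experts,), 0.25)
p = torch.full((n_experts,), 0.25)
print(n_experts * (f * p).sum())  # tensor(1.) -- the minimum

# Collapsed routing: every token goes to expert 0.
f = torch.tensor([1.00, 0.00, 0.00, 0.00])
p = torch.tensor([0.97, 0.01, 0.01, 0.01])
print(n_experts * (f * p).sum())  # tensor(3.8800)
```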
Track stats
        tracker.add('dropped.', total.new_tensor(n_dropped) / total)
        tracker.add('route.min.', route_frac.min())
        tracker.add('route.max.', route_frac.max())
        tracker.add('route.std.', route_frac.std())
        tracker.add('route.max_prob.', route_prob_max)
        tracker.add("loss.", cross_entropy_loss)
        tracker.add("lb_loss.", load_balancing_loss)
Combined loss. The load balancing loss is multiplied by a coefficient $\alpha$, which is set to something small, like $\alpha = 0.01$.
        loss = cross_entropy_loss + self.load_balancing_loss_coef * load_balancing_loss
Calculate and log accuracy
        self.accuracy(output, target)
        self.accuracy.track()
Train the model
        if self.mode.is_train:
Calculate gradients
            loss.backward()
Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)
Take optimizer step
            self.optimizer.step()
Log the model parameters and gradients on the last batch of every epoch
            if batch_idx.is_last:
                tracker.add('model', self.model)
Clear the gradients
            self.optimizer.zero_grad()
Save the tracked metrics
        tracker.save()
@option(Configs.model)
def autoregressive_model(c: Configs):
    m = AutoregressiveModel(c.n_tokens, c.d_model, c.transformer)
    return m.to(c.device)
@option(Configs.transformer)
def switch_transformer(c: Configs):
    from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
    from labml_nn.transformers import MultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward

    return SwitchTransformer(
        SwitchTransformerLayer(d_model=c.d_model,
                               attn=MultiHeadAttention(c.heads, c.d_model, c.dropout),
                               feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
                                                              drop_tokens=c.drop_tokens,
                                                              is_scale_prob=c.is_scale_prob,
                                                              n_experts=c.n_experts,
                                                              expert=FeedForward(c.d_model, c.d_ff, c.dropout),
                                                              d_model=c.d_model),
                               dropout_prob=c.dropout),
        c.n_layers)
def main():
Create experiment
    experiment.create(name="switch_transformer", comment='')
Create configs
    conf = Configs()
Load configurations
    experiment.configs(conf,
A dictionary of configurations to override. String values such as 'switch_transformer' select the matching @option-registered function by name.
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'transformer': 'switch_transformer',
                        'n_experts': 4,

                        'drop_tokens': True,
                        'capacity_factor': 1.2,

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 64,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })
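A learning rate of $1.0$ only makes sense together with the Noam scheduler, which multiplies it by a warmup/decay factor from Vaswani et al. A sketch of that factor (the warmup of 5,000 steps is an illustrative assumption, not necessarily this experiment's default):

```python
def noam_lr_factor(step: int, d_model: int = 128, warmup: int = 5000) -> float:
    # factor = d_model^{-0.5} * min(step^{-0.5}, step * warmup^{-1.5})
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

print(noam_lr_factor(1))      # tiny at the start of warmup
print(noam_lr_factor(5000))   # peak, ~0.00125 for d_model=128
print(noam_lr_factor(50000))  # decays as step^{-0.5}
```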
Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})
Start the experiment
    with experiment.start():
Run the training loop: TrainValidConfigs.run
        conf.run()
if __name__ == '__main__':
    main()