
Switch Transformer Experiment

This is an annotated PyTorch experiment to train a switch transformer.

import torch
import torch.nn as nn

from labml import experiment, tracker
from labml.configs import option
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs

Autoregressive model

class AutoregressiveModel(nn.Module):

    def __init__(self, n_vocab: int, d_model: int, transformer: nn.Module):
        super().__init__()

Token embedding module

        self.src_embed = nn.Embedding(n_vocab, d_model)

Transformer

        self.transformer = transformer

Final layer

        self.generator = nn.Linear(d_model, n_vocab)
        self.mask = None

    def forward(self, x: torch.Tensor):

Initialize the subsequent mask

        if self.mask is None or self.mask.size(0) != len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask = subsequent_mask(len(x)).to(x.device)

Token embeddings

        x = self.src_embed(x)

Run it through the transformer. Besides the output, the switch transformer returns routing statistics (per-expert token counts, routing probabilities, the number of dropped tokens, and the maximum routing probability for each token), which are used later for the load balancing loss and tracking.

        res, counts, route_prob, n_dropped, route_prob_max = self.transformer(x, self.mask)

Generate logits of the next token

        res = self.generator(res)

        return res, counts, route_prob, n_dropped, route_prob_max
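
As an aside, the subsequent mask built in forward is just a causal mask: position i may only attend to positions up to i. Below is a minimal standalone sketch for intuition, with a hypothetical causal_mask helper; the actual subsequent_mask in labml_nn.transformers.utils may use a different shape or layout.

import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # Lower-triangular boolean mask: entry [i, j] is True iff j <= i.
    return torch.tril(torch.ones(seq_len, seq_len)).bool()

print(causal_mask(3))
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True,  True,  True]])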

Configurations

This extends NLPAutoRegressionConfigs.

The default configs can and will be overridden when we start the experiment.

class Configs(NLPAutoRegressionConfigs):

    model: AutoregressiveModel
    transformer: nn.Module

Token embedding size

    d_model: int = 128

Number of attention heads

    heads: int = 4

Dropout probability

    dropout: float = 0.0

Number of features in FFN hidden layer

    d_ff: int = 256
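
Each expert in a switch layer is a position-wise feed-forward network with this overall d_model to d_ff to d_model shape. A minimal sketch with a hypothetical make_expert helper; the FeedForward module in labml_nn has the same structure, though the activation and dropout placement may differ.

import torch.nn as nn

def make_expert(d_model: int = 128, d_ff: int = 256, dropout: float = 0.0) -> nn.Module:
    # Position-wise FFN: expand to the hidden size, apply a non-linearity,
    # then project back to the model dimension.
    return nn.Sequential(
        nn.Linear(d_model, d_ff),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(d_ff, d_model),
    )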

Number of transformer layers

    n_layers: int = 6

Number of experts

    n_experts: int = 4
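
Each token is sent to exactly one of these experts by a learned router. A hedged, minimal sketch of that routing for a flat batch of tokens; the actual SwitchFeedForward in labml_nn.transformers.switch additionally handles expert capacity, token dropping, and output scaling.

import torch
import torch.nn as nn

d_model, n_experts = 128, 4
router = nn.Linear(d_model, n_experts)

x = torch.randn(10, d_model)                           # 10 tokens
route_prob = torch.softmax(router(x), dim=-1)          # routing probabilities per token
route_prob_max, routes = route_prob.max(dim=-1)        # chosen expert for each token
counts = torch.bincount(routes, minlength=n_experts)   # tokens routed to each expert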

Load balancing coefficient

    load_balancing_loss_ceof = 0.01

Whether to scale the chosen expert outputs by the routing probability

    is_scale_prob: bool = True
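
The two behaviours look roughly like this; a hedged sketch using a hypothetical scale_expert_output helper, not the actual SwitchFeedForward code from labml_nn.transformers.switch.

import torch

def scale_expert_output(expert_output: torch.Tensor, p: torch.Tensor, is_scale_prob: bool) -> torch.Tensor:
    if is_scale_prob:
        # Scale the expert output by the routing probability of the chosen expert.
        return expert_output * p
    # Multiply by p / p.detach() (== 1), so the value is unchanged but
    # gradients still flow back into the router through p.
    return expert_output * (p / p.detach())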

Whether to drop tokens

    drop_tokens: bool = False

Capacity factor to determine the capacity of each expert

    capacity_factor: float = 1.0
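
As in the Switch Transformer paper, an expert's capacity is the number of tokens in the batch divided by the number of experts, scaled by this factor. A rough sketch with the hypothetical values used later in this experiment; the labml_nn SwitchFeedForward may round or clamp slightly differently.

seq_len, batch_size, n_experts = 64, 32, 4
capacity_factor = 1.2

tokens_in_batch = seq_len * batch_size                         # 2048 tokens
capacity = int(capacity_factor * tokens_in_batch / n_experts)  # 614 tokens per expert
# With drop_tokens=True, tokens routed to an expert beyond its capacity are
# dropped for that layer (in the paper they pass through the residual connection
# instead of going through an expert).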

    def init(self):
        super().init()

Initialize tracking indicators

        tracker.set_scalar("lb_loss.*", False)
        tracker.set_scalar("route.*", False)
        tracker.set_scalar("dropped.*", False)

Training or validation step

    def step(self, batch: any, batch_idx: BatchIndex):

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of tokens processed) when in training mode

        if self.mode.is_train:
            tracker.add_global_step(data.shape[0] * data.shape[1])

Get model outputs.

        output, counts, route_prob, n_dropped, route_prob_max = self.model(data)

Calculate cross entropy loss

        cross_entropy_loss = self.loss_func(output, target)

Total number of tokens processed, $T$, in the current batch $\mathscr{B}$

        total = counts.sum(dim=-1, keepdims=True)

Fraction of tokens routed to each expert, $f_i = \frac{1}{T} \sum_{x \in \mathscr{B}} \mathbf{1}\{\operatorname{argmax} p(x) = i\}$, where the indicator counts the tokens for which the argmax of $p(x)$ is equal to $i$.

        route_frac = counts / total

Mean routing probability $P_i = \frac{1}{T} \sum_{x \in \mathscr{B}} p_i(x)$

        route_prob = route_prob / total

Load balancing loss $\mathscr{L} = N \sum_{i=1}^{N} f_i \cdot P_i$. $\mathscr{L}$ is the loss for a single layer and here we are taking the sum of losses across all layers.

        load_balancing_loss = self.n_experts * (route_frac * route_prob).sum()
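
As a sanity check, here is a small self-contained numeric sketch of this loss for a single layer, using made-up (hypothetical) counts and routing probabilities for a batch of T = 8 tokens:

import torch

n_experts = 4
counts = torch.tensor([5., 1., 1., 1.])          # tokens routed to each expert
route_prob = torch.tensor([4.2, 1.4, 1.2, 1.2])  # sum of p_i(x) over the batch

total = counts.sum()             # T = 8
route_frac = counts / total      # f_i
mean_prob = route_prob / total   # P_i
lb_loss = n_experts * (route_frac * mean_prob).sum()
print(lb_loss)  # ~1.55; a perfectly balanced router would give exactly 1.0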

Track stats

        tracker.add('dropped.', total.new_tensor(n_dropped) / total)
        tracker.add('route.min.', route_frac.min())
        tracker.add('route.max.', route_frac.max())
        tracker.add('route.std.', route_frac.std())
        tracker.add('route.max_prob.', route_prob_max)
        tracker.add("loss.", cross_entropy_loss)
        tracker.add("lb_loss.", load_balancing_loss)

Combined loss. The load balancing loss is multiplied by a coefficient $\alpha$ which is set to something small like $\alpha = 0.01$.

        loss = cross_entropy_loss + self.load_balancing_loss_ceof * load_balancing_loss

Calculate and log accuracy

        self.accuracy(output, target)
        self.accuracy.track()

Train the model

        if self.mode.is_train:

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

            self.optimizer.step()

Log the model parameters and gradients on last batch of every epoch

            if batch_idx.is_last:
                tracker.add('model', self.model)

Clear the gradients

            self.optimizer.zero_grad()

Save the tracked metrics

        tracker.save()

Initialize the auto-regressive model

@option(Configs.model)
def autoregressive_model(c: Configs):

    m = AutoregressiveModel(c.n_tokens, c.d_model, c.transformer)
    return m.to(c.device)

Initialize the switch transformer

@option(Configs.transformer)
def switch_transformer(c: Configs):

    from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
    from labml_nn.transformers import MultiHeadAttention
    from labml_nn.transformers.feed_forward import FeedForward

    return SwitchTransformer(
        SwitchTransformerLayer(d_model=c.d_model,
                               attn=MultiHeadAttention(c.heads, c.d_model, c.dropout),
                               feed_forward=SwitchFeedForward(capacity_factor=c.capacity_factor,
                                                              drop_tokens=c.drop_tokens,
                                                              is_scale_prob=c.is_scale_prob,
                                                              n_experts=c.n_experts,
                                                              expert=FeedForward(c.d_model, c.d_ff, c.dropout),
                                                              d_model=c.d_model),
                               dropout_prob=c.dropout),
        c.n_layers)

Run the experiment

def main():

Create experiment

    experiment.create(name="switch_transformer", comment='')

Create configs

    conf = Configs()

Load configurations

    experiment.configs(conf,

A dictionary of configurations to override

                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 1.,
                        'optimizer.optimizer': 'Noam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'transformer': 'switch_transformer',
                        'n_experts': 4,

                        'drop_tokens': True,
                        'capacity_factor': 1.2,

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 64,
                        'epochs': 128,
                        'batch_size': 32,
                        'inner_iterations': 25,
                        })
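
The learning rate of 1. here is best read as a multiplier on the Noam schedule from "Attention Is All You Need" rather than a raw step size. A minimal sketch of that schedule with a hypothetical noam_lr helper, assuming the standard formula and a warmup of 4,000 steps (labml's Noam optimizer wrapper may use different defaults):

def noam_lr(step: int, d_model: int = 128, warmup: int = 4000, factor: float = 1.0) -> float:
    # lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

print(noam_lr(100), noam_lr(4000), noam_lr(40000))  # rises during warmup, then decays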

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

TrainValidConfigs.run

        conf.run()

if __name__ == '__main__':
    main()
