[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/switch/__init__.py)
This is a miniature PyTorch implementation of the paper *Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity*. Our implementation only has a few million parameters and doesn't do model-parallel distributed training. It does single-GPU training, but we implement the concept of switching as described in the paper.
The Switch Transformer uses different parameters for each token by switching among parameters based on the token. Therefore, only a fraction of the parameters are used for each token, so the model can have many more parameters without a matching increase in computational cost.
The switching happens at the position-wise feed-forward network (FFN) of each transformer block. The position-wise feed-forward network consists of two fully connected layers applied one after the other. In the switch transformer we have multiple FFNs (multiple experts), and we choose which one to use based on a router. The router outputs a set of probabilities for picking an FFN, and we pick the one with the highest probability and evaluate only that. So essentially the computational cost is the same as having a single FFN. In our implementation this doesn't parallelize well when you have many or large FFNs, since it all happens on a single GPU. In a distributed setup you would have each FFN (each very large) on a different device.
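As a rough standalone sketch of this top-1 routing idea (hypothetical sizes and tensor names; the actual module below adds capacity limits and token dropping):

```python
import torch
from torch import nn

# Hypothetical toy sizes: route 10 tokens of dimension 512 among 4 experts.
router = nn.Linear(512, 4)
tokens = torch.randn(10, 512)
probs = torch.softmax(router(tokens), dim=-1)  # routing probabilities per token
best_prob, best_expert = probs.max(dim=-1)     # top-1: one expert per token
# Each token is processed only by its chosen expert's FFN,
# so per-token compute matches a single FFN.
```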
The paper introduces another loss term to balance load among the experts (FFNs) and discusses dropping tokens when routing is not balanced.
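The load balancing term isn't part of the modules below, but a minimal sketch of the paper's formulation, computed from the `counts` and summed routing probabilities this implementation returns, could look like this (the function name, argument shapes, and the omission of the paper's α coefficient are assumptions):

```python
import torch

def load_balancing_loss(counts: torch.Tensor, route_prob_sum: torch.Tensor,
                        total_tokens: int) -> torch.Tensor:
    # Paper's auxiliary loss: N * sum_i f_i * P_i, where f_i is the fraction
    # of tokens routed to expert i and P_i is the mean routing probability
    # the router assigns to expert i.
    n_experts = counts.shape[-1]
    route_frac = counts / total_tokens          # f_i
    route_prob_mean = route_prob_sum / total_tokens  # P_i
    return n_experts * (route_frac * route_prob_mean).sum(-1).mean()
```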
Here's the training code and a notebook for training a switch transformer on the Tiny Shakespeare dataset.
```python
import torch
from torch import nn

from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import MultiHeadAttention
from labml_nn.utils import clone_module_list
```
```python
class SwitchFeedForward(nn.Module):
```
* `capacity_factor` is the capacity of each expert as a factor relative to ideally balanced load
* `drop_tokens` specifies whether to drop tokens if more tokens are routed to an expert than its capacity
* `is_scale_prob` specifies whether to scale the chosen expert's output by the routing probability
* `n_experts` is the number of experts
* `expert` is the expert layer, a FFN module
* `d_model` is the number of features in a token embedding
* `d_ff` is the number of features in the hidden layer of the FFN, and `dropout` is the dropout probability in the FFN (both are properties of the `expert` module passed in, not constructor arguments here)

```python
    def __init__(self, *,
                 capacity_factor: float,
                 drop_tokens: bool,
                 is_scale_prob: bool,
                 n_experts: int,
                 expert: FeedForward,
                 d_model: int):
```
```python
        super().__init__()

        self.capacity_factor = capacity_factor
        self.is_scale_prob = is_scale_prob
        self.n_experts = n_experts
        self.drop_tokens = drop_tokens
```
Make copies of the FFN, one per expert
```python
        self.experts = clone_module_list(expert, n_experts)
```
Routing layer and softmax
```python
        self.switch = nn.Linear(d_model, n_experts)
        self.softmax = nn.Softmax(dim=-1)
```
* `x` is the input to the switching module with shape `[seq_len, batch_size, d_model]`

```python
    def forward(self, x: torch.Tensor):
```
Capture the shape so it can be restored later
```python
        seq_len, batch_size, d_model = x.shape
```
Flatten the sequence and batch dimensions
```python
        x = x.view(-1, d_model)
```
Get routing probabilities for each of the tokens. $p_i(x) = \frac{e^{h(x)_i}}{\sum_{j=1}^{N} e^{h(x)_j}}$, where $N$ is the number of experts (`n_experts`) and $h(\cdot)$ is the linear transformation of token embeddings.
```python
        route_prob = self.softmax(self.switch(x))
```
Get the maximum routing probabilities and the routes. We route each token to the expert with the highest probability.
```python
        route_prob_max, routes = torch.max(route_prob, dim=-1)
```
Get indexes of tokens going to each expert
```python
        indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)]
```
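For instance, with toy values: if `routes` were `tensor([0, 2, 0, 1])` with 3 experts, `indexes_list` would be `[tensor([0, 2]), tensor([3]), tensor([1])]`, i.e. the positions of the tokens each expert will process.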
Initialize an empty tensor to store outputs
```python
        final_output = x.new_zeros(x.shape)
```
Capacity of each expert. $\text{expert capacity} = \frac{\text{tokens per batch}}{\text{number of experts}} \times \text{capacity factor}$
```python
        capacity = int(self.capacity_factor * len(x) / self.n_experts)
```
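For example (assumed numbers): a batch of 32 sequences of length 32 gives $32 \times 32 = 1024$ tokens; with 4 experts and a capacity factor of $1.0$, each expert accepts at most $\lfloor 1.0 \times 1024 / 4 \rfloor = 256$ tokens.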
Number of tokens routed to each expert.
```python
        counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_experts)])
```
Initialize an empty list of dropped tokens
```python
        dropped = []
```
Only drop tokens if `drop_tokens` is `True`.
```python
        if self.drop_tokens:
```
Drop tokens in each of the experts
```python
            for i in range(self.n_experts):
```
Ignore if the expert is not over capacity
```python
                if len(indexes_list[i]) <= capacity:
                    continue
```
Shuffle indexes before dropping
```python
                indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]
```
Collect the tokens over capacity as dropped tokens
```python
                dropped.append(indexes_list[i][capacity:])
```
Keep only the tokens up to the capacity of the expert
```python
                indexes_list[i] = indexes_list[i][:capacity]
```
Get outputs of the expert FFNs
```python
        expert_output = [self.experts[i](x[indexes_list[i], :]) for i in range(self.n_experts)]
```
Assign to final output
```python
        for i in range(self.n_experts):
            final_output[indexes_list[i], :] = expert_output[i]
```
Pass the dropped tokens through unchanged
```python
        if dropped:
            dropped = torch.cat(dropped)
            final_output[dropped, :] = x[dropped, :]

        if self.is_scale_prob:
```
Multiply the expert outputs by the routing probabilities, $y = p_i(x) E_i(x)$
```python
            final_output = final_output * route_prob_max.view(-1, 1)
        else:
```
Don't scale the values, but multiply by $\frac{p}{\hat{p}} = 1$, where $\hat{p}$ is $p$ detached from the gradient graph, so that the gradients flow (this is something we experimented with).
```python
            final_output = final_output * (route_prob_max / route_prob_max.detach()).view(-1, 1)
```
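A standalone check of why this keeps gradients flowing (toy values, not part of the module): dividing by the detached probability makes the scale factor exactly $1$ in the forward pass, while the backward pass still sees the dependence on $p$.

```python
import torch

p = torch.tensor(0.7, requires_grad=True)  # a routing probability
y = torch.tensor(3.0) * (p / p.detach())   # scale factor is exactly 1.0
assert y.item() == 3.0                     # forward value unchanged
y.backward()
print(p.grad)                              # tensor(4.2857) = 3.0 / 0.7, gradient flows
```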
Change the shape of the final output back to `[seq_len, batch_size, d_model]`
```python
        final_output = final_output.view(seq_len, batch_size, d_model)
```
Return the final output, the number of tokens routed to each expert, the sum of routing probabilities for each expert, the number of dropped tokens, and the routing probabilities of the selected experts. These are used for the load balancing loss and logging.
```python
        return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max
```
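As a quick smoke test of the module (assumed toy hyperparameters; `FeedForward(d_model, d_ff, dropout)` as imported above):

```python
import torch
from labml_nn.transformers.feed_forward import FeedForward

d_model, d_ff, n_experts = 128, 256, 4
switch_ffn = SwitchFeedForward(capacity_factor=1.0, drop_tokens=True,
                               is_scale_prob=True, n_experts=n_experts,
                               expert=FeedForward(d_model, d_ff, 0.1),
                               d_model=d_model)

x = torch.randn(16, 8, d_model)  # [seq_len, batch_size, d_model]
out, counts, route_prob_sum, n_dropped, route_prob_max = switch_ffn(x)
print(out.shape, counts.tolist(), n_dropped)  # torch.Size([16, 8, 128]) ...
```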
This is the same as a normal transformer block, except that it also handles the extra outputs of the switch feed-forward module.
```python
class SwitchTransformerLayer(nn.Module):
```
* `d_model` is the token embedding size
* `attn` is the attention module
* `feed_forward` is the feed-forward module (the switching module in this case)
* `dropout_prob` is the probability of dropping out after self-attention and the FFN

```python
    def __init__(self, *,
                 d_model: int,
                 attn: MultiHeadAttention,
                 feed_forward: SwitchFeedForward,
                 dropout_prob: float):
```
```python
        super().__init__()
        self.size = d_model
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])
```
```python
    def forward(self, *,
                x: torch.Tensor,
                mask: torch.Tensor):
```
Normalize the vectors before doing self attention
```python
        z = self.norm_self_attn(x)
```
Run through self attention, i.e. keys and values are from self
```python
        self_attn = self.attn(query=z, key=z, value=z, mask=mask)
```
Add the self attention results
```python
        x = x + self.dropout(self_attn)
```
Normalize for feed-forward
```python
        z = self.norm_ff(x)
```
Pass through the switching feed-forward network
```python
        ff, counts, route_prob, n_dropped, route_prob_max = self.feed_forward(z)
```
Add the feed-forward results back
```python
        x = x + self.dropout(ff)

        return x, counts, route_prob, n_dropped, route_prob_max
```
```python
class SwitchTransformer(nn.Module):
```
```python
    def __init__(self, layer: SwitchTransformerLayer, n_layers: int):
        super().__init__()
```
Make copies of the transformer layer
```python
        self.layers = clone_module_list(layer, n_layers)
```
Final normalization layer
```python
        self.norm = nn.LayerNorm([layer.size])
```
```python
    def forward(self, x: torch.Tensor, mask: torch.Tensor):
```
Run through each transformer layer
```python
        counts, route_prob, n_dropped, route_prob_max = [], [], [], []
        for layer in self.layers:
            x, f, p, n_d, p_max = layer(x=x, mask=mask)
            counts.append(f)
            route_prob.append(p)
            n_dropped.append(n_d)
            route_prob_max.append(p_max)
```
Finally, normalize the vectors
```python
        x = self.norm(x)
```
```python
        return x, torch.stack(counts), torch.stack(route_prob), n_dropped, torch.stack(route_prob_max)
```
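Putting it all together, a hedged end-to-end sketch with assumed toy hyperparameters (`MultiHeadAttention(heads, d_model, dropout_prob)` as imported above; the mask shape follows the labml_nn attention convention of `[seq_len, seq_len, batch_size]` with broadcasting):

```python
import torch
from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import MultiHeadAttention

# Assumed toy configuration for a forward-pass sketch.
d_model, d_ff, heads, n_experts, n_layers = 128, 256, 4, 4, 2
layer = SwitchTransformerLayer(
    d_model=d_model,
    attn=MultiHeadAttention(heads, d_model, 0.1),
    feed_forward=SwitchFeedForward(capacity_factor=1.25, drop_tokens=True,
                                   is_scale_prob=True, n_experts=n_experts,
                                   expert=FeedForward(d_model, d_ff, 0.1),
                                   d_model=d_model),
    dropout_prob=0.1)
model = SwitchTransformer(layer, n_layers)

seq_len, batch_size = 16, 8
x = torch.randn(seq_len, batch_size, d_model)
# Causal mask of shape [seq_len, seq_len, 1], broadcast over the batch.
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(-1)
out, counts, route_prob, n_dropped, route_prob_max = model(x, mask)
print(out.shape)  # torch.Size([16, 8, 128])
```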