Switch Transformer

[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/switch/__init__.py)

This is a miniature PyTorch implementation of the paper Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. Our implementation has only a few million parameters and doesn't do model-parallel distributed training. It trains on a single GPU, but it implements the concept of switching as described in the paper.

The Switch Transformer uses different parameters for each token by switching among parameters based on the token. Therefore, only a fraction of the parameters is used for each token. This lets the model have more parameters without a proportional increase in computational cost.

The switching happens at the position-wise feed-forward network (FFN) of each transformer block. A position-wise feed-forward network consists of two sequential fully connected layers. In the Switch Transformer we have multiple FFNs (multiple experts), and we choose which one to use based on a router. The router outputs a probability for picking each FFN; we pick the one with the highest probability and evaluate only that one. So the computational cost per token is essentially the same as having a single FFN. In our implementation this doesn't parallelize well when you have many or large FFNs, since it all happens on a single GPU. In a distributed setup, each FFN (each very large) would be on a different device.
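
To make the parameter-versus-compute trade-off concrete, here is a back-of-the-envelope calculation. It assumes each expert is a two-layer FFN with biases (d_model → d_ff → d_model) plus a linear router. It also treats the parameters touched per token as a rough proxy for compute. The sizes are illustrative and are not taken from this implementation's training configuration.

```python
# Back-of-the-envelope comparison of a dense FFN and a top-1 switch layer.
# Sizes are illustrative assumptions, not values from this implementation.

def ffn_params(d_model: int, d_ff: int) -> int:
    """Weights and biases of two linear layers: d_model -> d_ff -> d_model."""
    return (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)


def switch_stats(d_model: int, d_ff: int, n_experts: int):
    """Return (total parameters, parameters used per token) for a top-1 switch FFN."""
    router = d_model * n_experts + n_experts          # nn.Linear(d_model, n_experts)
    total = n_experts * ffn_params(d_model, d_ff) + router
    per_token = ffn_params(d_model, d_ff) + router    # only one expert runs per token
    return total, per_token


dense = ffn_params(128, 256)
total, per_token = switch_stats(128, 256, n_experts=4)
print(dense, total, per_token)  # 65920 264196 66436
```

With 4 experts, the layer holds about 4x the parameters of a single FFN. Each token, however, only touches one expert plus the small router.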

The paper introduces another loss term to balance load among the experts (FFNs) and discusses dropping tokens when routing is not balanced.
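
As a sketch of that extra loss term: the paper's auxiliary loss is α·N·Σᵢ fᵢ·Pᵢ. Here fᵢ is the fraction of tokens dispatched to expert i, and Pᵢ is the fraction of router probability assigned to expert i. The function below takes per-expert token counts and per-expert sums of router probabilities as inputs, matching the counts and route_prob.sum(0) values that SwitchFeedForward returns. The name load_balancing_loss is ours, and the coefficient α is left to the caller. How the training code actually combines this with the main loss is not shown here.

```python
import torch


def load_balancing_loss(counts: torch.Tensor, route_prob_sum: torch.Tensor) -> torch.Tensor:
    """N * sum_i f_i * P_i for one switching layer (multiply by alpha outside).

    counts:         tokens routed to each expert, shape [n_experts]
    route_prob_sum: router probabilities summed over all tokens, shape [n_experts]
    """
    n_experts = counts.shape[0]
    n_tokens = counts.sum()
    f = counts.float() / n_tokens      # fraction of tokens dispatched to each expert
    p = route_prob_sum / n_tokens      # mean router probability for each expert
    return n_experts * (f * p).sum()


# Perfectly balanced routing of 8 tokens over 4 experts
print(load_balancing_loss(torch.tensor([2., 2., 2., 2.]), torch.tensor([2., 2., 2., 2.])))  # tensor(1.)
# Every token sent to expert 0 with probability 1
print(load_balancing_loss(torch.tensor([8., 0., 0., 0.]), torch.tensor([8., 0., 0., 0.])))  # tensor(4.)
```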

The repository also provides training code and a notebook for training a Switch Transformer on the Tiny Shakespeare dataset.

import torch
from torch import nn

from labml_nn.transformers.feed_forward import FeedForward
from labml_nn.transformers.mha import MultiHeadAttention
from labml_nn.utils import clone_module_list

Routing among multiple FFNs

class SwitchFeedForward(nn.Module):

  • capacity_factor is the capacity of each expert as a factor relative to a perfectly balanced load
  • drop_tokens specifies whether to drop tokens if more tokens are routed to an expert than its capacity
  • is_scale_prob specifies whether to multiply the output of the FFN by the routing probability
  • n_experts is the number of experts
  • expert is the expert layer, an FFN module that is copied n_experts times
  • d_model is the number of features in a token embedding

    def __init__(self, *,
                 capacity_factor: float,
                 drop_tokens: bool,
                 is_scale_prob: bool,
                 n_experts: int,
                 expert: FeedForward,
                 d_model: int):

        super().__init__()

        self.capacity_factor = capacity_factor
        self.is_scale_prob = is_scale_prob
        self.n_experts = n_experts
        self.drop_tokens = drop_tokens

Make copies of the FFNs

        self.experts = clone_module_list(expert, n_experts)
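
The copies must be independent modules. Otherwise all experts would share a single set of weights. The sketch below assumes that a helper like clone_module_list deep-copies the module n times into an nn.ModuleList. The function clone_modules is a hypothetical stand-in, not the labml_nn implementation.

```python
import copy

import torch
from torch import nn


def clone_modules(module: nn.Module, n: int) -> nn.ModuleList:
    """Hypothetical stand-in for clone_module_list: n independent deep copies."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])


expert = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
experts = clone_modules(expert, 4)

# Copies start with identical weights but are separate parameters,
# so updating one expert leaves the others untouched.
with torch.no_grad():
    experts[0][0].weight.add_(1.0)
print(torch.equal(experts[0][0].weight, experts[1][0].weight))  # False
print(torch.equal(experts[1][0].weight, experts[2][0].weight))  # True
```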

Routing layer and softmax

        self.switch = nn.Linear(d_model, n_experts)
        self.softmax = nn.Softmax(dim=-1)

  • x is the input to the switching module with shape [seq_len, batch_size, d_model]

    def forward(self, x: torch.Tensor):

Capture the shape to change shapes later

        seq_len, batch_size, d_model = x.shape

Flatten the sequence and batch dimensions

        x = x.view(-1, d_model)

Get routing probabilities for each of the tokens,

$$p_i(x) = \frac{e^{h(x)_i}}{\sum_{j}^{N} e^{h(x)_j}}$$

where $N$ is the number of experts n_experts and $h(\cdot)$ is the linear transformation of token embeddings.

        route_prob = self.softmax(self.switch(x))
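
The router is a linear layer followed by a softmax over experts. This standalone sketch uses random weights and illustrative sizes. It checks that nn.Softmax applied to the nn.Linear output matches the routing formula written out by hand, and that each token's probabilities sum to one.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, n_experts, n_tokens = 8, 4, 6

switch = nn.Linear(d_model, n_experts)   # plays the role of h(.)
softmax = nn.Softmax(dim=-1)

x = torch.randn(n_tokens, d_model)       # flattened tokens [seq_len * batch_size, d_model]
route_prob = softmax(switch(x))          # [n_tokens, n_experts]

# The same computation written out: exp(h(x)_i) / sum_j exp(h(x)_j)
h = switch(x)
manual = torch.exp(h) / torch.exp(h).sum(dim=-1, keepdim=True)
```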

Get the maximum routing probabilities and the routes. We route to the expert with the highest probability.

        route_prob_max, routes = torch.max(route_prob, dim=-1)

Get indexes of tokens going to each expert

        indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(self.n_experts)]
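
Here is a toy run of top-1 selection and grouping, using a hand-written probability table for 5 tokens and 3 experts (the values are made up for illustration). torch.max picks the top-1 expert per token, and nonzero(as_tuple=True) collects the token positions assigned to each expert.

```python
import torch

route_prob = torch.tensor([[0.7, 0.2, 0.1],
                           [0.1, 0.8, 0.1],
                           [0.6, 0.3, 0.1],
                           [0.2, 0.3, 0.5],
                           [0.5, 0.4, 0.1]])
n_experts = route_prob.shape[1]

# Top-1 routing: highest probability and its expert index for every token
route_prob_max, routes = torch.max(route_prob, dim=-1)

# Token indexes sent to each expert
indexes_list = [torch.eq(routes, i).nonzero(as_tuple=True)[0] for i in range(n_experts)]
print(routes)        # tensor([0, 1, 0, 2, 0])
print(indexes_list)  # [tensor([0, 2, 4]), tensor([1]), tensor([3])]
```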

Initialize a zero-filled tensor to store outputs

        final_output = x.new_zeros(x.shape)

Capacity of each expert,

$$\text{expert capacity} = \frac{\text{tokens per batch}}{\text{number of experts}} \times \text{capacity factor}$$

        capacity = int(self.capacity_factor * len(x) / self.n_experts)
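
Here are worked numbers for the capacity formula, using the same truncating int(...) as the code. The token and expert counts are illustrative. Because int() rounds down, a capacity factor of 1.0 can leave fewer slots than tokens when the division is not exact, even with perfectly balanced routing.

```python
def expert_capacity(n_tokens: int, n_experts: int, capacity_factor: float) -> int:
    # Same rule as the code: int(capacity_factor * len(x) / n_experts)
    return int(capacity_factor * n_tokens / n_experts)


# seq_len=16, batch_size=4 -> 64 tokens routed over 4 experts
print(expert_capacity(64, 4, 1.0))   # 16
print(expert_capacity(64, 4, 1.25))  # 20
print(expert_capacity(10, 4, 1.0))   # 2  (2.5 truncated, so only 8 slots for 10 tokens)
```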

Number of tokens routed to each expert.

        counts = x.new_tensor([len(indexes_list[i]) for i in range(self.n_experts)])

Initialize an empty list of dropped tokens

        dropped = []

Only drop tokens if drop_tokens is True.

        if self.drop_tokens:

Drop tokens in each of the experts

            for i in range(self.n_experts):

Ignore if the expert is not over capacity

                if len(indexes_list[i]) <= capacity:
                    continue

Shuffle indexes before dropping

                indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]

Collect the tokens over capacity as dropped tokens

                dropped.append(indexes_list[i][capacity:])

Keep only the tokens up to the capacity of the expert

                indexes_list[i] = indexes_list[i][:capacity]
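
The following standalone sketch runs the dropping loop on a made-up routing result with capacity 2. Experts 0 and 2 are over capacity, so a random subset of their tokens is dropped. Every token ends up either kept by exactly one expert or dropped.

```python
import torch

torch.manual_seed(0)
capacity = 2
# Hypothetical routing result: token indexes assigned to each of 3 experts
indexes_list = [torch.tensor([0, 2, 4, 5]), torch.tensor([1]), torch.tensor([3, 6, 7])]

dropped = []
for i in range(len(indexes_list)):
    if len(indexes_list[i]) <= capacity:
        continue  # expert is within capacity
    # Shuffle so that which tokens get dropped does not depend on position
    indexes_list[i] = indexes_list[i][torch.randperm(len(indexes_list[i]))]
    dropped.append(indexes_list[i][capacity:])
    indexes_list[i] = indexes_list[i][:capacity]

dropped = torch.cat(dropped)
print([len(ix) for ix in indexes_list], len(dropped))  # [2, 1, 2] 3
```

The shuffle matters because nonzero returns indexes in ascending order. Without it, the dropped tokens would always be those with the largest flattened index. For the [seq_len, batch_size] layout flattened with view, those are the latest sequence positions.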

Get outputs of the expert FFNs

        expert_output = [self.experts[i](x[indexes_list[i], :]) for i in range(self.n_experts)]

Assign to final output

        for i in range(self.n_experts):
            final_output[indexes_list[i], :] = expert_output[i]
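
Here is a toy illustration of this scatter step. It uses assumed sizes, hypothetical index lists, and two simple functions standing in for the expert FFNs. Each expert's output rows are written back to the positions of the tokens it processed. Rows of tokens that no expert processed stay zero, so such tokens have to be filled in separately.

```python
import torch

x = torch.arange(12, dtype=torch.float32).view(6, 2)        # 6 flattened tokens, d_model = 2
indexes_list = [torch.tensor([0, 3]), torch.tensor([1, 4])]  # hypothetical kept tokens per expert


def expert_a(t: torch.Tensor) -> torch.Tensor:  # toy stand-in for an expert FFN
    return t * 10


def expert_b(t: torch.Tensor) -> torch.Tensor:  # another toy stand-in
    return -t


experts = [expert_a, expert_b]

final_output = x.new_zeros(x.shape)
expert_output = [experts[i](x[indexes_list[i], :]) for i in range(len(experts))]
for i in range(len(experts)):
    # Write each expert's outputs back to the rows of the tokens it processed
    final_output[indexes_list[i], :] = expert_output[i]
```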

Pass through the dropped tokens

        if dropped:
            dropped = torch.cat(dropped)
            final_output[dropped, :] = x[dropped, :]

        if self.is_scale_prob:

Multiply the expert outputs by the probabilities, $y = p_i(x) E_i(x)$

            final_output = final_output * route_prob_max.view(-1, 1)
        else:

Don't scale the values, but multiply by $\frac{p}{\hat{p}} = 1$ so that the gradients flow, where $\hat{p}$ is $p$ detached from the computation graph (this is something we experimented with).

            final_output = final_output * (route_prob_max / route_prob_max.detach()).view(-1, 1)
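
Here is a small check of the p / p̂ trick on one token with made-up router logits. The multiplier is exactly 1 in the forward pass, yet the router logits still receive a gradient. Since d(p/p̂)/dθ = (1/p̂)·dp/dθ and p̂ equals p numerically, the router receives the gradient of log p, weighted by the downstream gradient of the output.

```python
import torch

logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad=True)  # made-up router logits, 1 token
route_prob = torch.softmax(logits, dim=-1)
route_prob_max, _ = torch.max(route_prob, dim=-1)               # p for the chosen expert
y = torch.tensor([[3.0, 4.0]])                                   # stand-in expert output

scaled = y * route_prob_max.view(-1, 1)                                 # is_scale_prob=True
unscaled = y * (route_prob_max / route_prob_max.detach()).view(-1, 1)  # is_scale_prob=False

unscaled.sum().backward()
print(unscaled)     # same values as y
print(logits.grad)  # non-zero: gradient reaches the router
```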

Change the shape of the final output back to [seq_len, batch_size, d_model]

        final_output = final_output.view(seq_len, batch_size, d_model)

Return

  • the final output
  • number of tokens routed to each expert (counted before any tokens are dropped)
  • sum of probabilities for each expert
  • number of tokens dropped
  • routing probabilities of the selected experts

These are used for the load balancing loss and logging.

        return final_output, counts, route_prob.sum(0), len(dropped), route_prob_max

Switch Transformer Block

This is the same as a normal transformer block, except that it also handles the extra outputs of the switching feed-forward module.

class SwitchTransformerLayer(nn.Module):

  • d_model is the token embedding size
  • attn is the attention module
  • feed_forward is the feed-forward module (which is the switching module in this case)
  • dropout_prob is the probability of dropping out after self attention and FFN

    def __init__(self, *,
                 d_model: int,
                 attn: MultiHeadAttention,
                 feed_forward: SwitchFeedForward,
                 dropout_prob: float):

        super().__init__()
        self.size = d_model
        self.attn = attn
        self.feed_forward = feed_forward
        self.dropout = nn.Dropout(dropout_prob)
        self.norm_self_attn = nn.LayerNorm([d_model])
        self.norm_ff = nn.LayerNorm([d_model])

    def forward(self, *,
                x: torch.Tensor,
                mask: torch.Tensor):

Normalize the vectors before doing self attention

        z = self.norm_self_attn(x)

Run through self attention, i.e. keys and values are from self

        self_attn = self.attn(query=z, key=z, value=z, mask=mask)

Add the self attention results

        x = x + self.dropout(self_attn)

Normalize for feed-forward

        z = self.norm_ff(x)

Pass through the switching feed-forward network

        ff, counts, route_prob, n_dropped, route_prob_max = self.feed_forward(z)

Add the feed-forward results back

        x = x + self.dropout(ff)

        return x, counts, route_prob, n_dropped, route_prob_max

Switch Transformer

class SwitchTransformer(nn.Module):

    def __init__(self, layer: SwitchTransformerLayer, n_layers: int):
        super().__init__()

Make copies of the transformer layer

        self.layers = clone_module_list(layer, n_layers)

Final normalization layer

        self.norm = nn.LayerNorm([layer.size])

    def forward(self, x: torch.Tensor, mask: torch.Tensor):

Run through each transformer layer

        counts, route_prob, n_dropped, route_prob_max = [], [], [], []
        for layer in self.layers:
            x, f, p, n_d, p_max = layer(x=x, mask=mask)
            counts.append(f)
            route_prob.append(p)
            n_dropped.append(n_d)
            route_prob_max.append(p_max)

Finally, normalize the vectors

        x = self.norm(x)

        return x, torch.stack(counts), torch.stack(route_prob), n_dropped, torch.stack(route_prob_max)
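
Here is a quick illustration of the shapes this forward pass returns. It uses made-up per-layer values shaped like the outputs of each switching layer: 3 layers, 4 experts, and 10 flattened tokens. The per-layer tensors are stacked along a new leading layer dimension, while n_dropped stays a plain Python list of integers.

```python
import torch

n_layers, n_experts, n_tokens = 3, 4, 10

# Made-up per-layer outputs with the shapes each switching layer produces
counts = [torch.tensor([3., 2., 4., 1.]) for _ in range(n_layers)]                        # [n_experts]
route_prob = [torch.full((n_tokens, n_experts), 0.25).sum(0) for _ in range(n_layers)]  # [n_experts]
n_dropped = [0, 2, 1]                                                                   # ints
route_prob_max = [torch.full((n_tokens,), 0.5) for _ in range(n_layers)]               # [n_tokens]

print(torch.stack(counts).shape)          # torch.Size([3, 4])
print(torch.stack(route_prob).shape)      # torch.Size([3, 4])
print(torch.stack(route_prob_max).shape)  # torch.Size([3, 10])
print(n_dropped)                          # [0, 2, 1]
```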

labml.ai