GPT
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/gpt/__init__.py)
This is a tutorial/implementation of the OpenAI GPT architecture in PyTorch. We got a number of implementation details from minGPT by @karpathy. This implementation also trains on the character-level Tiny Shakespeare dataset.
The GPT model is essentially a standard transformer with a few tweaks. GPT-2, and especially GPT-3, models are quite large and won't fit on a single GPU, so they need model parallelism. This implementation doesn't even use data parallelism and is intended to be more of a tutorial.
The main differences from a simple autoregressive transformer are the parameter initialization, weight decay, and learning rate schedule. For the transformer itself we reuse the existing labml/nn transformer implementation.
Here's a notebook for training a GPT model on the Tiny Shakespeare dataset.
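The learning rate schedule mentioned above warms up linearly and then follows a cosine decay. Here is a minimal, self-contained sketch of that idea; this is only an illustration, not the `AdamWarmupCosineDecay` optimizer used later in this file, and the peak rate and step counts are placeholders matching the values configured below.

```python
import math


def gpt_lr(step: int, max_lr: float = 6e-4, warmup: int = 20, total_steps: int = 2000) -> float:
    """Illustrative GPT-style schedule: linear warmup, then cosine decay towards zero."""
    if step < warmup:
        # Linear warmup from 0 to max_lr
        return max_lr * step / max(1, warmup)
    # Cosine decay over the remaining steps
    progress = min(1.0, (step - warmup) / max(1, total_steps - warmup))
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```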
```python
import torch
from torch import nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.optimizers.configs import OptimizerConfigs
from labml_nn.transformers import TransformerConfigs, Encoder
from labml_nn.transformers.utils import subsequent_mask
```
This consists of a token embedding layer, transformer encoder, and a final linear layer that gives token logits.
```python
class GPT(nn.Module):
```
- `encoder` is the transformer Encoder
- `src_embed` is the token embedding module (with positional encodings)
- `generator` is the final fully connected layer that gives the logits

```python
    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Module):
```
```python
        super().__init__()
        self.src_embed = src_embed
        self.encoder = encoder
        self.generator = generator
```
The mask will be initialized on the first call
```python
        self.mask = None
```
```python
    def forward(self, x: torch.Tensor):
```
Create subsequent mask if mask is not initialized or if the size of the mask is different
```python
        if self.mask is None or self.mask.size(0) != len(x):
```
Subsequent mask, which will mask out tokens from seeing future tokens
```python
            self.mask = subsequent_mask(len(x)).to(x.device)
```
Get the token embeddings with positional encodings
```python
        x = self.src_embed(x)
```
Transformer encoder
```python
        x = self.encoder(x, self.mask)
```
Get logits
```python
        x = self.generator(x)
```
Return results (second value is for state, since our trainer is used with RNNs also)
```python
        return x, None
```
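To see what the causal mask does, here is a small standalone illustration with a sequence length of 4. It builds an equivalent lower-triangular mask directly with `torch.tril`; the exact tensor shape returned by `subsequent_mask` in `labml_nn.transformers.utils` may differ, but the pattern is the same.

```python
import torch

seq_len = 4
# Row i (query position) may attend only to columns 0..i (current and past tokens)
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
print(mask)
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])
```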
This inherits from NLPAutoRegressionConfigs
```python
class Configs(NLPAutoRegressionConfigs):
```
GPT model
```python
    model: GPT
```
Transformer
```python
    transformer: TransformerConfigs
```
Weight decay
```python
    weight_decay: float = 0.1
```
Number of tokens for warmup
```python
    warmup_steps: int = 128 * 128 * 20
```
Custom optimizer
```python
    optimizer = 'transformer_optimizer'
```
```python
@option(Configs.transformer, 'GPT')
def _transformer_configs(c: Configs):
```
We use our configurable transformer implementation
```python
    conf = TransformerConfigs()
```
Set the vocabulary sizes for embeddings and generating logits
```python
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens
```
GPT uses the GELU activation for the position-wise feedforward network
```python
    conf.ffn.activation = 'GELU'
```
```python
    return conf
```
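GELU weights each input by the Gaussian CDF instead of hard-clipping negatives the way ReLU does. A quick comparison with PyTorch's built-in modules (illustrative only):

```python
import torch
from torch import nn

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
print(nn.GELU()(x))  # tensor([-0.0455, -0.1543,  0.0000,  0.3457,  1.9545])
print(nn.ReLU()(x))  # tensor([0.0000, 0.0000, 0.0000, 0.5000, 2.0000])
```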
Weights of linear layers and embedding layers are initialized from $\mathcal{N}(0, 0.02)$ (a standard deviation of $0.02$) instead of the default Xavier initialization.
```python
def _init_weights(module):
```
```python
    if not isinstance(module, (nn.Linear, nn.Embedding)):
        return

    module.weight.data.normal_(mean=0.0, std=0.02)
```
Initialize biases to 0
```python
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
```
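As a quick sanity check (a minimal sketch reusing the `_init_weights` function defined above), apply it to a single linear layer and inspect the statistics:

```python
from torch import nn

layer = nn.Linear(512, 512)
_init_weights(layer)
# Weight should be roughly zero-mean with a standard deviation close to 0.02; bias exactly zero
print(layer.weight.mean().item(), layer.weight.std().item())
print(layer.bias.abs().max().item())  # 0.0
```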
Create GPT model and initialize weights
```python
@option(Configs.model)
def _model(c: Configs):
```
```python
    m = GPT(c.transformer.encoder,
            c.transformer.src_embed,
            c.transformer.generator).to(c.device)
```
Apply custom weight initialization
```python
    m.apply(_init_weights)

    return m
```
This code is taken from minGPT. It applies weight decay only to the weights of linear layers.
```python
@option(NLPAutoRegressionConfigs.optimizer)
def transformer_optimizer(c: NLPAutoRegressionConfigs):
```
Collect names of parameters to apply weight decay
```python
    decay = set()
    for mn, m in c.model.named_modules():
        for pn, p in m.named_parameters():
            fpn = f'{mn}.{pn}' if mn else pn  # full param name

            if fpn.endswith('weight') and isinstance(m, nn.Linear):
                decay.add(fpn)
```
Get all the parameters
```python
    param_dict = {pn: p for pn, p in c.model.named_parameters()}
```
Parameters that are not decayed
```python
    no_decay = set(param_dict.keys()) - decay
```
Create the PyTorch parameter groups for the optimizer
```python
    opt_groups = [
        {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": c.weight_decay},
        {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
    ]
```
Create a configurable optimizer, so that we can change its hyperparameters simply by passing a config dictionary.
```python
    optimizer = OptimizerConfigs()
```
Set parameter groups for optimization.
```python
    optimizer.parameters = opt_groups
```
Use Adam with warmup and cosine decay, which is what GPT uses.
```python
    optimizer.optimizer = 'AdamWarmupCosineDecay'
```
Set the model embedding size; this is only needed if we use the Noam optimizer, whose learning rate depends on it.
```python
    optimizer.d_model = c.d_model
```
Set default weight decay. This is not required since we set the weight decay in the parameter groups.
```python
    optimizer.weight_decay = c.weight_decay
```
GPT uses a maximum learning rate of $6 \times 10^{-4}$.
```python
    optimizer.learning_rate = 6e-4
```
$\beta_1 = 0.9$, $\beta_2 = 0.95$
```python
    optimizer.betas = (0.9, 0.95)
```
$\epsilon = 10^{-8}$
```python
    optimizer.eps = 1e-8
```
Weight decay is decoupled from gradients
```python
    optimizer.weight_decouple = True
```
Total number of optimization steps for learning rate cosine decay
```python
    optimizer.total_steps = c.epochs * len(c.text.train) // (c.batch_size * c.seq_len)
```
Number of warmup optimization steps
```python
    optimizer.warmup = c.warmup_steps // (c.batch_size * c.seq_len)

    return optimizer
```
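To make these step counts concrete, here is a small worked example using the values set in `main()` below; the Tiny Shakespeare training-split size is only an approximation (the full corpus is roughly a million characters), so the total-step figure is indicative rather than exact.

```python
batch_size, seq_len = 128, 128           # values configured in main() below
tokens_per_step = batch_size * seq_len   # 16,384 tokens consumed per optimizer step

warmup_tokens = 128 * 128 * 20           # the warmup_steps config value (given in tokens)
print(warmup_tokens // tokens_per_step)  # 20 warmup optimizer steps

train_chars = 1_000_000                  # approximate training-split size (assumption)
epochs = 32
print(epochs * train_chars // tokens_per_step)  # roughly 1953 total steps for the cosine decay
```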
```python
def main():
```
Create experiment
```python
    experiment.create(name="gpt")
```
Create configs
```python
    conf = Configs()
```
Override configurations
```python
    experiment.configs(conf, {
```
Use character level tokenizer
```python
        'tokenizer': 'character',
```
Prompt separator is blank
```python
        'prompt_separator': '',
```
Starting prompt for sampling
```python
        'prompt': 'It is ',
```
Use Tiny Shakespeare dataset
```python
        'text': 'tiny_shakespeare',
```
Use a context size of 128
```python
        'seq_len': 128,
```
Train for 32 epochs
```python
        'epochs': 32,
```
Batch size 128
```python
        'batch_size': 128,
```
Switch between training and validation 10 times per epoch
```python
        'inner_iterations': 10,
```
Transformer configurations
```python
        'transformer.d_model': 512,
        'transformer.ffn.d_ff': 2048,
        'transformer.n_heads': 8,
        'transformer.n_layers': 6
    })
```
Set models for saving and loading
```python
    experiment.add_pytorch_models({'model': conf.model})
```
Start the experiment
```python
    with experiment.start():
```
Run training
```python
        conf.run()
```
```python
if __name__ == '__main__':
    main()
```