
Transformer Auto-Regression Experiment with Sophia-G optimizer

This trains a simple transformer, introduced in Attention Is All You Need, on an NLP auto-regression task (the Tiny Shakespeare dataset) with the Sophia-G optimizer.

import torch

from labml import experiment, tracker
from labml_nn.helpers.trainer import BatchIndex
from labml_nn.optimizers.sophia import Sophia
from labml_nn.transformers.basic.autoregressive_experiment import Configs as TransformerAutoRegressionConfigs

Configurations

This inherits from the transformer auto-regression experiment Configs.

class Configs(TransformerAutoRegressionConfigs):

    hess_interval: int = 10

    optimizer: Sophia
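Setting $k$ (hess_interval) to 10 means the extra Hessian-estimation forward and backward pass in the step below runs on roughly one training batch in ten, so its amortized cost is only about $1/k = 10\%$ of a regular training step.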

Training or validation step with Gauss-Newton-Bartlett (GNB) Hessian diagonal estimator

    def step(self, batch: any, batch_idx: BatchIndex):

Set training/eval mode

        self.model.train(self.mode.is_train)

Move data to the device

        data, target = batch[0].to(self.device), batch[1].to(self.device)

Estimate the Hessian diagonal every k (= hess_interval) training steps

        if isinstance(self.optimizer, Sophia) and self.mode.is_train and batch_idx.idx % self.hess_interval == 0:

Get model outputs

            output, *_ = self.model(data)

Create a categorical distribution from logits

            samp_dist = torch.distributions.Categorical(logits=output)

Sample $\hat{y}$ from the model's output distribution

            y_sample = samp_dist.sample()
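For intuition, a small self-contained example of this sampling step; the logits shape [seq_len, batch_size, n_tokens] here is an assumption used only for illustration, not something taken from the experiment above:

import torch

# hypothetical logits with an assumed shape [seq_len, batch_size, n_tokens]
logits = torch.randn(8, 2, 65)

# Categorical treats the last dimension as the category axis
samp_dist = torch.distributions.Categorical(logits=logits)

# sample() drops the category dimension, giving integer token ids of shape [8, 2]
y_sample = samp_dist.sample()
print(y_sample.shape)  # torch.Size([8, 2])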

Calculate the loss against the sampled labels $\hat{y}$ and log it

            loss = self.loss_func(output, y_sample)
            tracker.add("loss.hess.", loss)
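The sampled labels are what make this a Gauss-Newton-Bartlett estimate: for cross-entropy loss with $\hat{y}$ drawn from the model's own predicted distribution, the scaled element-wise square of the mini-batch gradient is, roughly in expectation,

$$\mathbb{E}_{\hat{y} \sim p_\theta}\Big[\, B \cdot \nabla_\theta \hat{L}(\theta) \odot \nabla_\theta \hat{L}(\theta) \,\Big] \approx \operatorname{diag}\big(\text{Gauss-Newton matrix of the loss}\big),$$

so the gradients computed next are squared and accumulated into the optimizer's Hessian state rather than used for a parameter update.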

Calculate gradients

            loss.backward()

Clip gradients

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Update EMA Hessian diagonal

$$\begin{aligned}
\hat{h}_t &= B \cdot \nabla_\theta \hat{L}(\theta) \odot \nabla_\theta \hat{L}(\theta) \\
h_t &= \beta_2 \, h_{t-k} + (1 - \beta_2) \, \hat{h}_t
\end{aligned}$$

            self.optimizer.update_hessian(data.numel())
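As a rough sketch of what an update of this form could do per parameter (an illustration of the formula above using hypothetical names params, hessian_state and beta2, with an assumed beta2 default; this is not the actual labml_nn.optimizers.sophia implementation):

import torch


def update_hessian_sketch(params, hessian_state: dict, n_tokens: int, beta2: float = 0.99):
    # Assumes loss.backward() was just called on the sampled-label loss,
    # so p.grad holds the gradient of L_hat for each parameter.
    for p in params:
        if p.grad is None:
            continue
        # hat(h)_t = B * grad ⊙ grad, with B = n_tokens (tokens in the batch)
        hess_estimate = n_tokens * p.grad.detach() ** 2
        # h_t = beta2 * h_{t-k} + (1 - beta2) * hat(h)_t
        h = hessian_state.setdefault(p, torch.zeros_like(p))
        h.mul_(beta2).add_(hess_estimate, alpha=1 - beta2)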

Clear the gradients

            self.optimizer.zero_grad()
        else:

Move data to the device

            data, target = batch[0].to(self.device), batch[1].to(self.device)

Update global step (number of tokens processed) when in training mode

            if self.mode.is_train:
                tracker.add_global_step(data.shape[0] * data.shape[1])
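For example, with the configuration used below (a seq_len of 512 and a batch_size of 16), each training batch advances the global step by $512 \times 16 = 8192$ tokens.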

Get model outputs. The model returns a tuple (with states) when using RNNs; that is not implemented yet. 😜

            output, *_ = self.model(data)

Calculate and log loss

            loss = self.loss_func(output, target)
            tracker.add("loss.", loss)

Calculate and log accuracy and other metrics

            self.accuracy(output, target)
            self.accuracy.track()

            self.other_metrics(output, target)

Train the model

            if self.mode.is_train:

Calculate gradients

                loss.backward()

Clip gradients

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.grad_norm_clip)

Take optimizer step

                self.optimizer.step()
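This is where the Sophia-G update is applied. Roughly, following the Sophia paper (the exact scaling and clipping used in labml_nn.optimizers.sophia may differ), each parameter takes an element-wise clipped, Hessian-scaled momentum step

$$\theta_{t+1} = \theta_t - \eta \cdot \operatorname{clip}\!\left(\frac{m_t}{\max(\rho \, h_t,\ \epsilon)},\, -1,\, 1\right),$$

where $m_t$ is an EMA of the gradients, $h_t$ is the EMA Hessian diagonal maintained above, $\eta$ is the learning rate, and $\rho$ is the scale set in the configurations below.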

Log the model parameters and gradients on last batch of every epoch

                if batch_idx.is_last and self.is_log_model_params_grads:
                    tracker.add('model', self.model)

Clear the gradients

                self.optimizer.zero_grad()

Save the tracked metrics

            tracker.save()

def main():

Create experiment

    experiment.create(name="transformer")

Create configs

    conf = Configs()

Override configurations

    experiment.configs(conf, {

Use character level tokenizer

        'tokenizer': 'character',

Prompt separator is blank

        'prompt_separator': '',

Starting prompt for sampling

        'prompt': 'It is ',

Use Tiny Shakespeare dataset

        'text': 'tiny_shakespeare',

Use a context size of 512

        'seq_len': 512,

Train for 32 epochs

        'epochs': 32,

Batch size of 16

        'batch_size': 16,

Switch between training and validation 10 times per epoch

        'inner_iterations': 10,

Model size

        'd_model': 256,
        'transformer.n_heads': 16,
        'transformer.ffn.d_ff': 1024,

Use Sophia optimizer

        'optimizer.optimizer': 'Sophia',
        'optimizer.learning_rate': 3e-4,
        'optimizer.rho': 0.03,
    })

Set models for saving and loading

    experiment.add_pytorch_models({'model': conf.model})

Start the experiment

    with experiment.start():

Run training

        conf.run()

if __name__ == '__main__':
    main()
