Back to Annotated Deep Learning Paper Implementations

finetune.py

docs/neox/utils/finetune.html

latest · 2.1 KB
Original Source

home › neox › utils

View code on Github

#

1fromtypingimportList,Dict23importtorch4fromtorchimportnn56fromlabml\_nn.neox.modelimportTransformerLayer,NeoXModule

#

class FineTuner:
    """
    Base class for fine-tuning only a subset of the parameters of a list of layers.

    Subclasses decide which parameters of each layer are trainable by overriding
    `get_layer_trainable_params`. This class freezes everything else and saves/loads
    only the trainable parameters.
    """

    def __init__(self, layers: 'List[NeoXModule]'):
        """
        :param layers: the model layers, in order. The position of each layer in this
            list is used to build its parameter-name prefix (`layer_00`, `layer_01`, ...).
        """
        self.layers = layers

    def get_trainable_params(self) -> Dict[str, nn.Parameter]:
        """
        Collect the trainable parameters of all layers.

        :return: mapping from `layer_{i:02d}.<name>` to the parameter object
        """
        params = {}
        for i, layer in enumerate(self.layers):
            params.update(self.get_layer_trainable_params(layer, prefix=f'layer_{i:02d}'))

        return params

    def get_layer_trainable_params(self, layer: 'NeoXModule', prefix: str) -> Dict[str, nn.Parameter]:
        """
        Return the trainable parameters of a single layer, keyed by names starting
        with `prefix`. Must be implemented by subclasses.
        """
        raise NotImplementedError

    def set_trainable_params(self):
        """
        Freeze every parameter of every layer, then re-enable gradients only for
        the parameters returned by `get_trainable_params`.
        """
        for layer in self.layers:
            # Set `requires_grad` to `False` for the entire layer.
            layer.requires_grad_(False)

        for p in self.get_trainable_params().values():
            p.requires_grad_(True)

    def state_dict(self):
        """
        Return the data of the trainable parameters, moved to the CPU.

        Note: `Tensor.cpu()` does not copy a tensor that is already on the CPU, so for
        CPU parameters the returned tensors share storage with the parameters.
        """
        return {n: p.data.cpu() for n, p in self.get_trainable_params().items()}

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        """
        Load trainable parameters previously produced by `state_dict`.

        All keys and shapes are validated before any parameter is modified, so a
        failed load leaves the parameters untouched.

        :param state_dict: mapping from parameter name to tensor; must contain exactly
            the keys returned by `get_trainable_params`
        :raises KeyError: if a key is missing from, or unexpected in, `state_dict`
        :raises ValueError: if a tensor's shape differs from its parameter's shape
        """
        params = self.get_trainable_params()

        # Validate everything up front (an `assert` would be stripped under `-O`).
        missing = [n for n in params if n not in state_dict]
        if missing:
            raise KeyError(f'Missing keys in state_dict: {missing}')
        unexpected = [n for n in state_dict if n not in params]
        if unexpected:
            raise KeyError(f'Unexpected keys in state_dict: {unexpected}')
        for n, p in params.items():
            if state_dict[n].shape != p.shape:
                raise ValueError(f'Shape mismatch for {n}: expected {tuple(p.shape)}, '
                                 f'got {tuple(state_dict[n].shape)}')

        for n, p in params.items():
            # `copy_` converts device/dtype and, unlike `p.data[:] = ...`, works for 0-dim tensors.
            p.data.copy_(state_dict[n])

#

class FineTuneBiases(FineTuner):
    """
    Fine-tune only a selected set of bias terms in each `TransformerLayer`.

    Layers of any other type contribute no trainable parameters.
    """

    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        # Only transformer layers have biases selected for training.
        if not isinstance(layer, TransformerLayer):
            return {}

        attention = layer.attention
        # The MLP (output) bias is not trained because the MLP output is added to the
        # attention output, whose bias is trained here. Presumably this refers to the FFN
        # output-projection bias; the `dense_h_h4` bias below *is* trained.
        return {
            f'{prefix}.attention.output.bias': attention.output.bias,
            f'{prefix}.attention.qkv_lin.bias': attention.qkv_lin.bias,
            f'{prefix}.ffn.dense_h_h4.bias': layer.ffn.dense_h_h4.bias,
        }

labml.ai