docs/neox/utils/finetune.html

from typing import List, Dict

import torch
from torch import nn

from labml_nn.neox.model import TransformerLayer, NeoXModule


class FineTuner:
    def __init__(self, layers: List[NeoXModule]):
        self.layers = layers

    def get_trainable_params(self) -> Dict[str, nn.Parameter]:
        # Collect trainable parameters from every layer, with each name
        # prefixed by the layer index (e.g. `layer_00`)
        params = {}
        for i, layer in enumerate(self.layers):
            params.update(self.get_layer_trainable_params(layer, prefix=f'layer_{i:02d}'))

        return params

    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        # Subclasses override this to pick the trainable parameters of a layer
        raise NotImplementedError

    def set_trainable_params(self):
        # Set `requires_grad` to `False` for every parameter of every layer
        for layer in self.layers:
            layer.requires_grad_(False)

        # Then turn gradients back on for just the trainable parameters
        for p in self.get_trainable_params().values():
            p.requires_grad_(True)

    def state_dict(self):
        # Copy the trainable parameters to CPU for checkpointing
        return {n: p.data.cpu() for n, p in self.get_trainable_params().items()}

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        params = self.get_trainable_params()
        # Copy each saved tensor into the matching parameter, in place,
        # moving it to the parameter's device
        for n, p in params.items():
            p.data[:] = state_dict[n].to(p.data.device)

        # Make sure the checkpoint contains no unknown parameters
        for n in state_dict.keys():
            assert n in params, n
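

# Subclasses choose what stays trainable by overriding `get_layer_trainable_params`.
# As a minimal sketch (`FineTuneAll` is hypothetical, not part of the library),
# a subclass that keeps every parameter of each layer trainable could be:
class FineTuneAll(FineTuner):
    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        # `nn.Module.named_parameters()` yields `(name, parameter)` pairs
        return {f'{prefix}.{name}': p for name, p in layer.named_parameters()}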


class FineTuneBiases(FineTuner):
    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        params = {}

        if isinstance(layer, TransformerLayer):
            # No need to train the second MLP bias (`ffn.dense_h4_h.bias`)
            # because it gets added together with the attention output bias
            params[f'{prefix}.attention.output.bias'] = layer.attention.output.bias
            params[f'{prefix}.attention.qkv_lin.bias'] = layer.attention.qkv_lin.bias
            params[f'{prefix}.ffn.dense_h_h4.bias'] = layer.ffn.dense_h_h4.bias
        else:
            # Other layers are left frozen
            pass

        return params
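

# A usage sketch for bias-only fine-tuning (in the spirit of BitFit).
# Assumptions, not prescribed by the classes above: `layers` come from a
# loaded GPT-NeoX model, and the optimizer, learning rate, and file name
# are illustrative placeholders.
def finetune_biases_example(layers: List[NeoXModule]):
    fine_tuner = FineTuneBiases(layers)
    # Freeze everything except the selected biases
    fine_tuner.set_trainable_params()

    # Optimize only the trainable biases
    optimizer = torch.optim.Adam(fine_tuner.get_trainable_params().values(), lr=1e-4)

    # ... training loop ...

    # Checkpoint just the fine-tuned biases, and restore them later
    torch.save(fine_tuner.state_dict(), 'biases.pt')
    fine_tuner.load_state_dict(torch.load('biases.pt'))

    return optimizer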