docs/neox/checkpoint.html
11frompathlibimportPath12fromtypingimportDict,Union,Tuple,Optional1314importtorch15fromtorchimportnn1617fromlabmlimportmonit,lab,logger18fromlabml.loggerimportText,inspect19fromlabml.utils.downloadimportdownload\_file
Parent URL of the checkpoint files
# Parent URL under which all checkpoint files are hosted
CHECKPOINTS_URL = 'https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/'

# Cached local download directory; resolved lazily by `get_checkpoints_download_path`
_CHECKPOINTS_DOWNLOAD_PATH: Optional[Path] = None
Download path
def get_checkpoints_download_path() -> Path:
    """
    ### Get the local checkpoint directory

    The result is cached in the module-level `_CHECKPOINTS_DOWNLOAD_PATH`,
    so the lookup (and the `inspect` log line) only happens on the first call.

    :return: `<data path>/neox_fast/slim_weights` if that directory exists,
        otherwise `<data path>/neox/slim_weights`
    """
    global _CHECKPOINTS_DOWNLOAD_PATH

    # Already resolved on an earlier call
    if _CHECKPOINTS_DOWNLOAD_PATH is not None:
        return _CHECKPOINTS_DOWNLOAD_PATH

    # Prefer the `neox_fast` location and fall back to `neox` when it is absent
    path = lab.get_data_path() / 'neox_fast' / 'slim_weights'
    if not path.exists():
        path = lab.get_data_path() / 'neox' / 'slim_weights'

    _CHECKPOINTS_DOWNLOAD_PATH = path
    inspect(neox_checkpoint_path=_CHECKPOINTS_DOWNLOAD_PATH)

    return _CHECKPOINTS_DOWNLOAD_PATH
Returns a list of files to be downloaded
def get_files_to_download(n_layers: int = 44):
    """
    ### Get the list of files to download

    :param n_layers: is the number of transformer layers to include
    :return: a list of file paths relative to the checkpoints parent URL
    """
    # Checkpoint directory name and the number of partitions each layer is split into
    step_dir = 'global_step150000'
    n_partitions = 2

    layers = (
        # Embedding layer
            [0] +
            # Transformer layers
            list(range(2, 2 + n_layers)) +
            # Final normalization layer and readout layer
            [47, 48]
    )

    return (
        # Vocabulary and configs
            ['20B_tokenizer.json', 'configs/20B.yml', 'latest'] +
            # Layer checkpoints
            [f'{step_dir}/layer_{i:02d}-model_{p:02d}-model_states.pt'
             for i in layers for p in range(n_partitions)] +
            # Empty states (not used)
            [f'{step_dir}/mp_rank_{i:02d}_model_states.pt' for i in range(8)]
    )
def download(n_layers: int = 44):
    """
    ### Download all checkpoint files

    :param n_layers: is the number of transformer layers to download;
        passed through to `get_files_to_download`
    """
    # Get files to download
    files = get_files_to_download(n_layers)
    total = len(files)

    # Iterate
    for i, f in monit.enum('Download All', files):
        # Log
        logger.log(['Downloading ', (f'{i + 1:3d}/{total}', Text.meta), ': ', (f, Text.value)])
        # Download
        download_file(CHECKPOINTS_URL + f, get_checkpoints_download_path() / f)
`files` is the pair of files to load. Returns the loaded parameter tensors.
def load_checkpoint_files(files: Tuple[str, str]):
    """
    ### Load a pair of checkpoint files

    :param files: pair of file names, relative to the `global_step150000`
        directory inside the checkpoint download path
    :return: a list with the object loaded from each file, in the same order
    """
    checkpoint_path = get_checkpoints_download_path() / 'global_step150000'
    with monit.section('Load checkpoint'):
        # NOTE(review): `torch.load` unpickles the file; only load checkpoints from a trusted source
        data = [torch.load(checkpoint_path / f) for f in files]

    return data
def merge_params_dim_0(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                       p2: Dict[str, torch.Tensor]):
    """
    ### Load a parameter by concatenating two partitions along the first dimension

    The rows `[:n]` of `param` are filled from the first partition and the
    rows `[n:]` from the second, where `n` is the first partition's size
    along dimension 0. `param` is modified in place through `param.data`.

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    w1, w2 = p1[key], p2[key]
    split = w1.shape[0]
    param.data[:split] = w1
    param.data[split:] = w2
def merge_params_dim_1(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                       p2: Dict[str, torch.Tensor]):
    """
    ### Load a parameter by concatenating two partitions along the second dimension

    The columns `[:, :n]` of `param` are filled from the first partition and
    the columns `[:, n:]` from the second, where `n` is the first partition's
    size along dimension 1. `param` is modified in place through `param.data`.

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    w1, w2 = p1[key], p2[key]
    split = w1.shape[1]
    param.data[:, :split] = w1
    param.data[:, split:] = w2
This does a sanity check to make sure both partitions are the same
def merge_params_duplicate(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                           p2: Dict[str, torch.Tensor]):
    """
    ### Load a parameter that is duplicated in both partitions

    This does a sanity check to make sure both partitions are the same,
    then writes their element-wise mean into `param` in place.

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    :raises AssertionError: if the summed squared difference between the
        two partitions is not below `1e-4`
    """
    w1, w2 = p1[key], p2[key]

    # Reduce over *all* elements with the tensor method. The builtin `sum()`
    # only reduces the first dimension, so `.item()` would fail for tensors
    # with more than one dimension.
    diff = ((w1 - w2) ** 2).sum().item()
    assert diff < 1e-4, f'The partitions do not match: {key}'

    param.data[:] = (w1 + w2) / 2.
def merge_params_sum(param: Union[nn.Parameter, torch.Tensor], key: str, p1: Dict[str, torch.Tensor],
                     p2: Dict[str, torch.Tensor]):
    """
    ### Load a parameter as the sum of its two partitions

    Writes `p1[key] + p2[key]` into `param` in place through `param.data`.

    :param param: is the parameter
    :param key: is the name of the parameter
    :param p1: first partition dictionary
    :param p2: second partition dictionary
    """
    w1, w2 = p1[key], p2[key]

    param.data[:] = w1 + w2
# Download all checkpoint files when this module is run as a script
if __name__ == '__main__':
    download()