[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/neox/utils/__init__.py)
```python
import typing
from typing import List, Optional

import torch

from labml import logger
from labml.logger import Text
from labml_nn.neox.tokenizer import get_tokenizer

if typing.TYPE_CHECKING:
    from tokenizers import Tokenizer
```
Tokenizer singleton
```python
_TOKENIZER: Optional['Tokenizer'] = None
```
* `text` is the text to tokenize
* **Returns** the token ids
```python
def get_tokens(text: str) -> List[int]:
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = get_tokenizer()
    return _TOKENIZER.encode_batch([text])[0].ids
```
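The lazy-singleton pattern above loads the tokenizer once and caches it for every later call. A minimal standalone sketch of the same pattern (`_load_tokenizer` is a hypothetical stand-in for the real `get_tokenizer`, which is assumed to be expensive to call):

```python
from typing import Optional

_TOKENIZER: Optional[object] = None


def _load_tokenizer() -> object:
    # Hypothetical stand-in for labml_nn's get_tokenizer();
    # assume the real load is expensive.
    return object()


def get_tokenizer_singleton() -> object:
    # Return the cached tokenizer, loading it on first use only.
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = _load_tokenizer()
    return _TOKENIZER


a = get_tokenizer_singleton()  # first call triggers the load
b = get_tokenizer_singleton()  # later calls reuse the cached instance
print(a is b)  # True
```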
Pretty prints target tokens alongside outputs from the model(s).
* `ids` are the target token ids
* `xs` are the model(s) outputs

```python
def print_token_outputs(ids: List[int], *xs: torch.Tensor):
    ids = ids + [-1]
    xs = [[-1] + x[0].max(dim=-1)[1].tolist() for x in xs]

    print_tokens(ids, xs)
```
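The expression `x[0].max(dim=-1)[1]` takes the greedy (argmax) token at each position of the model's logits. A pure-Python sketch of that step, using made-up logit values for illustration:

```python
# Hypothetical logits for a 2-position sequence over a 3-token vocabulary.
logits = [
    [0.1, 2.5, 0.3],  # position 0 -> token 1
    [1.9, 0.2, 0.4],  # position 1 -> token 0
]


def argmax(row):
    # Index of the largest value in a row (what max(dim=-1)[1] does per position).
    return max(range(len(row)), key=row.__getitem__)


# A leading -1 aligns predictions with the shifted targets,
# mirroring what print_token_outputs does.
predicted = [-1] + [argmax(row) for row in logits]
print(predicted)  # [-1, 1, 0]
```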
Pretty prints tokens for comparison
* `target` are the target token ids
* `others` are the sampled outputs from the model(s)

```python
def print_tokens(target: List[int], others: List[List[int]]):
```
Load tokenizer
```python
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = get_tokenizer()
```
Convert the tokens to a list of strings
```python
    text = []
    for i in range(len(target)):
        tokens = [_TOKENIZER.decode([target[i]]) if target[i] != -1 else '---']
        for j in range(len(others)):
            tokens.append(_TOKENIZER.decode([others[j][i]]) if others[j][i] != -1 else '---')

        text.append(tokens)
```
Stats
```python
    correct = [0 for _ in others]
    total = 0
```
Iterate through tokens
```python
    for i in range(len(target)):
        parts = [(f'{i}: ', Text.meta)]
        parts += [('"', Text.subtle), (text[i][0], Text.subtle), ('"', Text.subtle), '\t']
```
Empty target
```python
        if target[i] == -1:
            for j in range(len(others)):
                parts += [('"', Text.subtle), (text[i][j + 1], Text.subtle), ('"', Text.subtle), '\t']

            logger.log(parts)
            continue
```
Number of tokens
```python
        total += 1
```
Other outputs
```python
        for j in range(len(others)):
            correct[j] += 1 if others[j][i] == target[i] else 0

            parts += [('"', Text.subtle),
                      (text[i][j + 1], Text.success if others[j][i] == target[i] else Text.danger),
                      ('"', Text.subtle), '\t']

        logger.log(parts)
```
Stats
```python
    parts = [(f'{total}', Text.highlight), '\t']
    for j in range(len(others)):
        parts += [(f'{correct[j]}', Text.value), '\t']
    logger.log(parts)
```
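The accuracy bookkeeping above can be seen in isolation with a minimal sketch that drops the tokenizer and logger; the token ids below are hypothetical:

```python
# -1 marks positions with no target; these are skipped, as in print_tokens.
target = [5, 9, -1, 7]
others = [[5, 2, -1, 7],   # predictions from a hypothetical model A
          [5, 9, -1, 3]]   # predictions from a hypothetical model B

correct = [0 for _ in others]
total = 0
for i, t in enumerate(target):
    if t == -1:
        continue
    total += 1
    for j in range(len(others)):
        # Count a match for model j when its token equals the target.
        correct[j] += 1 if others[j][i] == t else 0

print(total, correct)  # 3 [2, 2]
```

Out of the 3 non-empty targets, each model matches 2, which is what the final stats line would print.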
Split `n_layers` into `n_chunks`. This is used for pipeline-parallel training.
* `n_layers` is the number of layers
* `n_chunks` is the number of chunks
* **Returns** a list with the number of layers for each chunk
```python
def balance_layers_simple(n_layers: int, n_chunks: int):
    balance = []
    for i in range(n_chunks):
        balance.append((n_layers - sum(balance)) // (n_chunks - i))

    return list(reversed(balance))
```
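Because the list is built back-to-front and then reversed, any remainder layers land in the earliest chunks. A standalone copy of the function (restated here so the example runs without `labml_nn`) shows this:

```python
def balance_layers_simple(n_layers: int, n_chunks: int):
    # Same logic as above: give each remaining chunk the floor of an even
    # share, then reverse so larger shares come first.
    balance = []
    for i in range(n_chunks):
        balance.append((n_layers - sum(balance)) // (n_chunks - i))
    return list(reversed(balance))


# 10 layers over 3 chunks: the remainder goes to the earliest chunk.
print(balance_layers_simple(10, 3))  # [4, 3, 3]
# An exact split stays even.
print(balance_layers_simple(44, 4))  # [11, 11, 11, 11]
```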