docs/transformers/retro/bert_embeddings.html
This is the code to get BERT embeddings of chunks for the RETRO model.
from typing import List

import torch
from transformers import BertTokenizer, BertModel

from labml import lab, monit
For a given chunk of text N, this class generates BERT embeddings BERT(N). BERT(N) is the average of the BERT embeddings of all the tokens in N.
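Written out as a formula (the notation e_i for the per-token embeddings is introduced here for illustration, not part of the original):

$$\text{BERT}(N) = \frac{1}{|N|} \sum_{i=1}^{|N|} e_i$$

where e_1, ..., e_{|N|} are the BERT embeddings of the tokens in N.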
class BERTChunkEmbeddings:
    def __init__(self, device: torch.device):
        self.device = device
Load the BERT tokenizer from HuggingFace
        with monit.section('Load BERT tokenizer'):
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                                            cache_dir=str(
                                                                lab.get_data_path() / 'cache' / 'bert-tokenizer'))
Load the BERT model from HuggingFace
        with monit.section('Load BERT model'):
            self.model = BertModel.from_pretrained("bert-base-uncased",
                                                   cache_dir=str(lab.get_data_path() / 'cache' / 'bert-model'))
Move the model to device
        self.model.to(device)
In this implementation, we do not make chunks with a fixed number of tokens. One of the reasons is that this implementation uses character-level tokens and BERT uses its sub-word tokenizer.
So this method will truncate the text to make sure there are no partial tokens.
For instance, a chunk could look like 's a popular programming la', with partial words (partial sub-word tokens) at both ends. We strip them off to get better BERT embeddings (a small usage sketch follows the method below). As mentioned earlier, this would not be necessary if we broke the text into chunks after tokenizing.
    @staticmethod
    def _trim_chunk(chunk: str):
Strip whitespace
        stripped = chunk.strip()
Split into words
        parts = stripped.split()
Remove first and last pieces
        stripped = stripped[len(parts[0]):-len(parts[-1])]
Remove whitespace
        stripped = stripped.strip()
If empty, return the original string
        if not stripped:
            return chunk
Otherwise, return the stripped string
        else:
            return stripped
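A small usage sketch of this method (the first string is the example from the explanation above; the outputs in the comments are what the logic above produces):

# Partial words at both ends get stripped off.
print(BERTChunkEmbeddings._trim_chunk('s a popular programming la'))  # 'a popular programming'
# If stripping would leave nothing, the original chunk is returned unchanged.
print(BERTChunkEmbeddings._trim_chunk('word'))  # 'word'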
    def __call__(self, chunks: List[str]):
We don't need to compute gradients
        with torch.no_grad():
Trim the chunks
            trimmed_chunks = [self._trim_chunk(c) for c in chunks]
Tokenize the chunks with BERT tokenizer
            tokens = self.tokenizer(trimmed_chunks, return_tensors='pt', add_special_tokens=False, padding=True)
Move token ids, attention mask and token types to the device
            input_ids = tokens['input_ids'].to(self.device)
            attention_mask = tokens['attention_mask'].to(self.device)
            token_type_ids = tokens['token_type_ids'].to(self.device)
Evaluate the model
            output = self.model(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)
Get the token embeddings
            state = output['last_hidden_state']
Calculate the average token embeddings. Note that the attention mask is 0 for padding tokens, which are added because the chunks have different lengths.
            emb = (state * attention_mask[:, :, None]).sum(dim=1) / attention_mask[:, :, None].sum(dim=1)
            return emb
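A minimal sketch of the masked averaging above, with made-up tensor values (shapes and the expression mirror the code; the numbers are purely illustrative):

import torch

# Two chunks padded to length 3, hidden size 2 for readability; the first chunk has only two real tokens.
state = torch.tensor([[[1., 2.], [3., 4.], [0., 0.]],
                      [[5., 6.], [7., 8.], [9., 10.]]])
attention_mask = torch.tensor([[1, 1, 0],
                               [1, 1, 1]])

# Same expression as in __call__: padded positions add nothing to the sum and are excluded from the count.
emb = (state * attention_mask[:, :, None]).sum(dim=1) / attention_mask[:, :, None].sum(dim=1)
print(emb)  # tensor([[2., 3.], [7., 8.]])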
def _test():
    from labml.logger import inspect
Initialize
    device = torch.device('cuda:0')
    bert = BERTChunkEmbeddings(device)
Sample
    text = ["Replace me by any text you'd like.",
            "Second sentence"]
Check BERT tokenizer
    encoded_input = bert.tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True)

    inspect(encoded_input, _expand=True)
Check BERT model outputs
    output = bert.model(input_ids=encoded_input['input_ids'].to(device),
                        attention_mask=encoded_input['attention_mask'].to(device),
                        token_type_ids=encoded_input['token_type_ids'].to(device))

    inspect({'last_hidden_state': output['last_hidden_state'],
             'pooler_output': output['pooler_output']},
            _expand=True)
Check recreating text from token ids
    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][0]), _n=-1)
    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][1]), _n=-1)
Get chunk embeddings
    inspect(bert(text))
if __name__ == '__main__':
    _test()