
BERT Embeddings of chunks of text

This is the code to get BERT embeddings of chunks of text for the RETRO model.

from typing import List

import torch
from transformers import BertTokenizer, BertModel

from labml import lab, monit

BERT Embeddings

For a given chunk of text N, this class generates BERT embeddings BERT(N). BERT(N) is the average of the BERT embeddings of all the tokens in N.
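Concretely, the averaging can be sketched with plain Python lists standing in for BERT token vectors. This is a toy 4-dimensional example (real bert-base-uncased embeddings are 768-dimensional), and chunk_embedding is a hypothetical helper, not part of the implementation:

```python
# Hypothetical toy example: 3 tokens, 4-dimensional embeddings
# (real BERT-base token vectors are 768-dimensional).
token_embeddings = [
    [1.0, 2.0, 3.0, 4.0],
    [3.0, 2.0, 1.0, 0.0],
    [2.0, 2.0, 2.0, 2.0],
]

def chunk_embedding(embs):
    """BERT(N): the element-wise mean of the token embeddings."""
    n = len(embs)
    return [sum(e[d] for e in embs) / n for d in range(len(embs[0]))]

print(chunk_embedding(token_embeddings))  # [2.0, 2.0, 2.0, 2.0]
```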

class BERTChunkEmbeddings:

    def __init__(self, device: torch.device):
        self.device = device

Load the BERT tokenizer from HuggingFace

        with monit.section('Load BERT tokenizer'):
            self.tokenizer = BertTokenizer.from_pretrained(
                'bert-base-uncased',
                cache_dir=str(lab.get_data_path() / 'cache' / 'bert-tokenizer'))

Load the BERT model from HuggingFace

        with monit.section('Load BERT model'):
            self.model = BertModel.from_pretrained(
                'bert-base-uncased',
                cache_dir=str(lab.get_data_path() / 'cache' / 'bert-model'))

Move the model to device

        self.model.to(device)

In this implementation, we do not make chunks with a fixed number of tokens. One of the reasons is that this implementation uses character-level tokens while BERT uses its own sub-word tokenizer.

So this method truncates the text to make sure there are no partial tokens.

For instance, a chunk could look like "s a popular programming la", with partial words (partial sub-word tokens) at both ends. We strip them off to get better BERT embeddings. As mentioned earlier, this would not be necessary if we broke the text into chunks after tokenizing.
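A standalone replica of this trimming logic (a hypothetical trim_chunk function mirroring the _trim_chunk method below) shows the effect on that example:

```python
def trim_chunk(chunk: str) -> str:
    """Drop the (possibly partial) first and last words of a chunk."""
    stripped = chunk.strip()
    parts = stripped.split()
    # Cut off the first and last whitespace-delimited pieces
    stripped = stripped[len(parts[0]):-len(parts[-1])]
    stripped = stripped.strip()
    # Fall back to the original chunk if nothing is left
    return stripped if stripped else chunk

print(trim_chunk("s a popular programming la"))  # a popular programming
```

Note that only the first and last whitespace-delimited pieces are dropped, so a short complete word like "a" in the middle survives; and if trimming leaves nothing, the original chunk is returned unchanged.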

    @staticmethod
    def _trim_chunk(chunk: str):

Strip whitespace

        stripped = chunk.strip()

Break into words

        parts = stripped.split()

Remove first and last pieces

        stripped = stripped[len(parts[0]):-len(parts[-1])]

Remove whitespace

        stripped = stripped.strip()

If empty, return the original string

        if not stripped:
            return chunk

Otherwise, return the stripped string

        else:
            return stripped

Get BERT(N) for a list of chunks.

    def __call__(self, chunks: List[str]):

We don't need to compute gradients

        with torch.no_grad():

Trim the chunks

            trimmed_chunks = [self._trim_chunk(c) for c in chunks]

Tokenize the chunks with BERT tokenizer

            tokens = self.tokenizer(trimmed_chunks, return_tensors='pt', add_special_tokens=False, padding=True)
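Since the trimmed chunks tokenize to different lengths, padding=True pads every sequence in the batch to the longest one and records which positions are real tokens in attention_mask. A minimal sketch of that behaviour, assuming made-up token ids and a hypothetical pad_batch helper (0 is the padding token id for bert-base-uncased):

```python
def pad_batch(batch, pad_id=0):
    """Pad token-id sequences to a common length and build an attention mask."""
    max_len = max(len(seq) for seq in batch)
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in batch]
    # 1 marks a real token, 0 marks padding
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in batch]
    return input_ids, attention_mask

# Hypothetical token ids for two chunks of different lengths
ids, mask = pad_batch([[101, 2054, 2003], [101, 2129]])
print(ids)   # [[101, 2054, 2003], [101, 2129, 0]]
print(mask)  # [[1, 1, 1], [1, 1, 0]]
```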

Move token ids, attention mask and token types to the device

            input_ids = tokens['input_ids'].to(self.device)
            attention_mask = tokens['attention_mask'].to(self.device)
            token_type_ids = tokens['token_type_ids'].to(self.device)

Evaluate the model

            output = self.model(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)

Get the token embeddings

            state = output['last_hidden_state']

Calculate the average of the token embeddings. Note that the attention mask is 0 for padding tokens, which appear because the chunks have different lengths.

            emb = (state * attention_mask[:, :, None]).sum(dim=1) / attention_mask[:, :, None].sum(dim=1)

        return emb
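The masked average can be checked by hand. Here is a toy sketch in plain Python (a hypothetical masked_average helper mirroring the tensor expression above) with one sequence of two real tokens and one padding token:

```python
def masked_average(state, attention_mask):
    """Average token embeddings, ignoring positions where the mask is 0."""
    dim = len(state[0][0])
    result = []
    for seq, mask in zip(state, attention_mask):
        total = [0.0] * dim
        count = 0
        for vec, m in zip(seq, mask):
            if m:
                total = [t + v for t, v in zip(total, vec)]
                count += 1
        result.append([t / count for t in total])
    return result

# One sequence: two real tokens and one padding token
state = [[[2.0, 4.0], [4.0, 8.0], [100.0, 100.0]]]  # pad values must not leak in
mask = [[1, 1, 0]]
print(masked_average(state, mask))  # [[3.0, 6.0]]
```

The padding row never reaches the result because its mask entry is 0, and the division uses the number of real tokens (the sum of the mask), not the padded length.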

Code to test BERT embeddings

def _test():

    from labml.logger import inspect

Initialize

    device = torch.device('cuda:0')
    bert = BERTChunkEmbeddings(device)

Sample

    text = ["Replace me by any text you'd like.",
            "Second sentence"]

Check BERT tokenizer

    encoded_input = bert.tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True)

    inspect(encoded_input, _expand=True)

Check BERT model outputs

    output = bert.model(input_ids=encoded_input['input_ids'].to(device),
                        attention_mask=encoded_input['attention_mask'].to(device),
                        token_type_ids=encoded_input['token_type_ids'].to(device))

    inspect({'last_hidden_state': output['last_hidden_state'],
             'pooler_output': output['pooler_output']},
            _expand=True)

Check recreating text from token ids

    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][0]), _n=-1)
    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][1]), _n=-1)

Get chunk embeddings

    inspect(bert(text))

if __name__ == '__main__':
    _test()
