
CLIP Text Embedder



This is used to get prompt embeddings for Stable Diffusion. It uses the HuggingFace Transformers CLIP model.

from typing import List

from torch import nn
from transformers import CLIPTokenizer, CLIPTextModel


CLIP Text Embedder

class CLIPTextEmbedder(nn.Module):


  • version is the CLIP model version to load from HuggingFace
  • device is the device to run the model on
  • max_length is the maximum length of the tokenized prompt

def __init__(self, version: str = "openai/clip-vit-large-patch14", device="cuda:0", max_length: int = 77):


super().__init__()


Load the tokenizer

self.tokenizer = CLIPTokenizer.from_pretrained(version)


Load the CLIP transformer

self.transformer = CLIPTextModel.from_pretrained(version).eval()

self.device = device
self.max_length = max_length


  • prompts are the list of prompts to embed
def forward(self, prompts: List[str]):


Tokenize the prompts

batch_encoding = self.tokenizer(prompts, truncation=True, max_length=self.max_length, return_length=True,
                                return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
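Concretely, `truncation=True` with `padding="max_length"` makes every prompt encode to exactly `max_length` token ids: longer sequences are cut off and shorter ones are padded. Below is a simplified pure-Python sketch of that fixed-length behavior; it ignores the special begin/end-of-text tokens the real `CLIPTokenizer` inserts, and the pad id `49407` (CLIP's `<|endoftext|>`) is an assumption here:

```python
def pad_and_truncate(token_ids, max_length=77, pad_id=49407):
    # Truncate to max_length, then fill the remainder with the pad token id.
    # 49407 is assumed to be CLIP's <|endoftext|>/pad id; the actual work in
    # the embedder is done by CLIPTokenizer, not this sketch.
    ids = token_ids[:max_length]
    return ids + [pad_id] * (max_length - len(ids))

short = pad_and_truncate([100, 200, 300])     # 3 ids, then 74 pad ids
long = pad_and_truncate(list(range(100)))     # cut down to 77 ids
```

Every row of the resulting batch therefore has the same length, which is what lets the tokenizer return a single rectangular `input_ids` tensor.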


Get token ids

tokens = batch_encoding["input_ids"].to(self.device)


Get CLIP embeddings

return self.transformer(input_ids=tokens).last_hidden_state

labml.ai