CLIP Text Embedder
This is used to get prompt embeddings for stable diffusion. It uses the text encoder of the HuggingFace Transformers CLIP model.
from typing import List

from torch import nn
from transformers import CLIPTokenizer, CLIPTextModel
class CLIPTextEmbedder(nn.Module):
version is the model version, device is the device to run on, and max_length is the max length of the tokenized prompt. CLIP's text encoder has a fixed context length of 77 tokens, hence the default.

    def __init__(self, version: str = "openai/clip-vit-large-patch14", device="cuda:0", max_length: int = 77):
        super().__init__()
Load the tokenizer
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
Load the CLIP transformer and set it to evaluation mode, since it is only used for inference
        self.transformer = CLIPTextModel.from_pretrained(version).eval()

        self.device = device
        self.max_length = max_length
prompts is the list of prompts to embed.

    def forward(self, prompts: List[str]):
Tokenize the prompts. Prompts longer than max_length are truncated and shorter ones are padded, so every sequence in the batch is exactly max_length tokens long
        batch_encoding = self.tokenizer(prompts, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
Get token ids
        tokens = batch_encoding["input_ids"].to(self.device)
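tokens is an integer tensor of shape [batch_size, max_length]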
Get CLIP embeddings
        return self.transformer(input_ids=tokens).last_hidden_state
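Below is a minimal usage sketch, not part of the original module. It runs on CPU for simplicity; the stable diffusion pipeline would construct this with device="cuda:0" and move the whole model to that device.

import torch

# Build the embedder on CPU; `device` here only controls where the
# token ids are moved, so keep the module on the same device.
embedder = CLIPTextEmbedder(device="cpu")

with torch.no_grad():
    embeddings = embedder(["a photograph of an astronaut riding a horse"])

# For the default ViT-L/14 version this prints torch.Size([1, 77, 768]):
# one 768-dimensional embedding per token position.
print(embeddings.shape)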