docs/diffusion/stable_diffusion/scripts/text_to_image.html
home / diffusion / stable_diffusion / scripts
11importargparse12importos13frompathlibimportPath1415importtorch1617fromlabmlimportlab,monit18fromlabml\_nn.diffusion.stable\_diffusion.latent\_diffusionimportLatentDiffusion19fromlabml\_nn.diffusion.stable\_diffusion.sampler.ddimimportDDIMSampler20fromlabml\_nn.diffusion.stable\_diffusion.sampler.ddpmimportDDPMSampler21fromlabml\_nn.diffusion.stable\_diffusion.utilimportload\_model,save\_images,set\_seed
class Txt2Img:
    """
    ### Text-to-image generator using stable diffusion

    Loads a latent diffusion model from a checkpoint and generates images
    from text prompts with either DDIM or DDPM sampling.
    """
    # Latent diffusion model loaded from the checkpoint
    model: LatentDiffusion

    def __init__(self, *,
                 checkpoint_path: Path,
                 sampler_name: str,
                 n_steps: int = 50,
                 ddim_eta: float = 0.0,
                 ):
        """
        :param checkpoint_path: is the path of the checkpoint
        :param sampler_name: is the name of the sampler (`'ddim'` or `'ddpm'`)
        :param n_steps: is the number of sampling steps
        :param ddim_eta: is the DDIM sampling $\eta$ constant
        :raises ValueError: if `sampler_name` is not a known sampler
        """
        # Load the model from the checkpoint
        self.model = load_model(checkpoint_path)
        # Get device
        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        # Move the model to device
        self.model.to(self.device)

        # Initialize sampler
        if sampler_name == 'ddim':
            self.sampler = DDIMSampler(self.model,
                                       n_steps=n_steps,
                                       ddim_eta=ddim_eta)
        elif sampler_name == 'ddpm':
            self.sampler = DDPMSampler(self.model)
        else:
            # Fail fast: the original silently skipped unknown names, leaving
            # `self.sampler` undefined until a confusing AttributeError at call time.
            raise ValueError(f"Unknown sampler: {sampler_name!r} (expected 'ddim' or 'ddpm')")

    @torch.no_grad()
    def __call__(self, *,
                 dest_path: str,
                 batch_size: int = 3,
                 prompt: str,
                 h: int = 512, w: int = 512,
                 uncond_scale: float = 7.5,
                 ):
        """
        :param dest_path: is the path to store the generated images
        :param batch_size: is the number of images to generate in a batch
        :param prompt: is the prompt to generate images with
        :param h: is the height of the image
        :param w: is the width of the image
        :param uncond_scale: is the unconditional guidance scale $s$. This is used for
            $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (s - 1)\epsilon_\text{cond}(x_t, c_u)$
        """
        # Number of channels in the latent space
        c = 4
        # Image to latent space resolution reduction factor
        f = 8

        # Make a batch of prompts
        prompts = batch_size * [prompt]

        # AMP auto casting
        with torch.cuda.amp.autocast():
            # If unconditional scaling is not 1, get the embeddings
            # for empty prompts (no conditioning).
            if uncond_scale != 1.0:
                un_cond = self.model.get_text_conditioning(batch_size * [""])
            else:
                un_cond = None
            # Get the prompt embeddings
            cond = self.model.get_text_conditioning(prompts)
            # Sample in the latent space.
            # `x` will be of shape `[batch_size, c, h // f, w // f]`
            x = self.sampler.sample(cond=cond,
                                    shape=[batch_size, c, h // f, w // f],
                                    uncond_scale=uncond_scale,
                                    uncond_cond=un_cond)
            # Decode the image from the autoencoder
            images = self.model.autoencoder_decode(x)

        # Save images
        save_images(images, dest_path, 'txt_')
def main():
    """
    ### CLI entry point

    Parses command-line options, configures flash attention, builds a
    `Txt2Img` generator, and writes generated images to `outputs`.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )

    parser.add_argument("--batch_size", type=int, default=4, help="batch size")

    parser.add_argument(
        '--sampler',
        dest='sampler_name',
        choices=['ddim', 'ddpm'],
        default='ddim',
        # Fixed: was a pointless f-string with no placeholders (ruff F541)
        help='Set the sampler.',
    )

    parser.add_argument("--flash", action='store_true', help="whether to use flash attention")

    parser.add_argument("--steps", type=int, default=50, help="number of sampling steps")

    parser.add_argument("--scale", type=float, default=7.5,
                        help="unconditional guidance scale: "
                             "eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))")

    opt = parser.parse_args()

    # Fixed seed for reproducible generations
    set_seed(42)

    # Set flash attention (imported lazily so the flag takes effect before model setup)
    from labml_nn.diffusion.stable_diffusion.model.unet_attention import CrossAttention
    CrossAttention.use_flash_attention = opt.flash

    # Build the text-to-image generator from the standard checkpoint location
    txt2img = Txt2Img(checkpoint_path=lab.get_data_path() / 'stable-diffusion' / 'sd-v1-4.ckpt',
                      sampler_name=opt.sampler_name,
                      n_steps=opt.steps)

    # Generate and save the images
    with monit.section('Generate'):
        txt2img(dest_path='outputs',
                batch_size=opt.batch_size,
                prompt=opt.prompt,
                uncond_scale=opt.scale)
# Run the script only when executed directly, not when imported
if __name__ == "__main__":
    main()