Back to Annotated Deep Learning Paper Implementations

Generate images using stable diffusion with a prompt

docs/diffusion/stable_diffusion/scripts/text_to_image.html

latest · 4.9 KB
Original Source

home / diffusion / stable_diffusion / scripts

View code on Github

#

Generate images using stable diffusion with a prompt

11importargparse12importos13frompathlibimportPath1415importtorch1617fromlabmlimportlab,monit18fromlabml\_nn.diffusion.stable\_diffusion.latent\_diffusionimportLatentDiffusion19fromlabml\_nn.diffusion.stable\_diffusion.sampler.ddimimportDDIMSampler20fromlabml\_nn.diffusion.stable\_diffusion.sampler.ddpmimportDDPMSampler21fromlabml\_nn.diffusion.stable\_diffusion.utilimportload\_model,save\_images,set\_seed

#

Text to image class

class Txt2Img:
    """
    ### Text to image class

    Loads a latent diffusion model from a checkpoint, moves it to the available device,
    and samples images for a text prompt with a DDIM or DDPM sampler.
    """

    # The latent diffusion model loaded from the checkpoint.
    # (String annotation so the class can be defined without importing `LatentDiffusion`.)
    model: 'LatentDiffusion'

    def __init__(self, *,
                 checkpoint_path: Path,
                 sampler_name: str,
                 n_steps: int = 50,
                 ddim_eta: float = 0.0,
                 ):
        """
        * `checkpoint_path` is the path of the checkpoint
        * `sampler_name` is the name of the sampler; either `'ddim'` or `'ddpm'`
        * `n_steps` is the number of sampling steps (only used by the DDIM sampler)
        * `ddim_eta` is the DDIM sampling η constant (only used by the DDIM sampler)

        Raises `ValueError` if `sampler_name` is not a supported sampler.
        """
        # Reject an unknown sampler up front, before loading the checkpoint.
        # Previously `self.sampler` was silently left unset and the first call failed
        # with an `AttributeError`.
        if sampler_name not in ('ddim', 'ddpm'):
            raise ValueError(f"Unknown sampler {sampler_name!r}; expected 'ddim' or 'ddpm'")

        # Load latent diffusion model
        self.model = load_model(checkpoint_path)
        # Get device
        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        # Move the model to device
        self.model.to(self.device)

        # Initialize sampler
        if sampler_name == 'ddim':
            self.sampler = DDIMSampler(self.model,
                                       n_steps=n_steps,
                                       ddim_eta=ddim_eta)
        else:
            self.sampler = DDPMSampler(self.model)

    @torch.no_grad()
    def __call__(self, *,
                 dest_path: str,
                 batch_size: int = 3,
                 prompt: str,
                 h: int = 512, w: int = 512,
                 uncond_scale: float = 7.5,
                 ):
        r"""
        * `dest_path` is the path to store the generated images
        * `batch_size` is the number of images to generate in a batch
        * `prompt` is the prompt to generate images with
        * `h` is the height of the image
        * `w` is the width of the image
        * `uncond_scale` is the unconditional guidance scale $s$. This is used for
          $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) - (s - 1)\epsilon_\text{cond}(x_t, c_u)$,
          where $c_u$ is the embedding of the empty prompt

        `h` and `w` are expected to be multiples of the latent reduction factor (8);
        otherwise the latent resolution is floored by integer division.
        """
        # Number of channels in the latent space
        c = 4
        # Image to latent space resolution reduction
        f = 8

        # Make a batch of prompts
        prompts = batch_size * [prompt]

        # AMP auto casting. Mixed precision is only enabled on CUDA. On CPU this is a no-op,
        # which matches the deprecated `torch.cuda.amp.autocast()`: that call disabled itself
        # (with a warning) when CUDA was unavailable.
        with torch.autocast(device_type=self.device.type, enabled=self.device.type == 'cuda'):
            # If the unconditional scale is not $1$, get the embeddings for empty prompts (no conditioning).
            if uncond_scale != 1.0:
                un_cond = self.model.get_text_conditioning(batch_size * [""])
            else:
                un_cond = None
            # Get the prompt embeddings
            cond = self.model.get_text_conditioning(prompts)
            # Sample in the latent space. `x` will be of shape `[batch_size, c, h / f, w / f]`
            x = self.sampler.sample(cond=cond,
                                    shape=[batch_size, c, h // f, w // f],
                                    uncond_scale=uncond_scale,
                                    uncond_cond=un_cond)
            # Decode the image from the autoencoder
            images = self.model.autoencoder_decode(x)

        # Save images
        save_images(images, dest_path, 'txt_')

#

CLI

def main():
    """
    ### CLI

    Parse command line arguments, configure flash attention, load the model from the
    `sd-v1-4.ckpt` checkpoint under the labml data path, and write the generated images
    to `outputs`.
    """
    # Default prompt. It is also used when `--prompt` is passed without a value.
    default_prompt = "a painting of a virus monster playing guitar"

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default=default_prompt,
        # With `nargs="?"`, a bare `--prompt` produces `const`.
        # Without this it would be `None` and would reach the text encoder.
        const=default_prompt,
        help="the prompt to render"
    )

    parser.add_argument("--batch_size", type=int, default=4, help="batch size")

    parser.add_argument(
        '--sampler',
        dest='sampler_name',
        choices=['ddim', 'ddpm'],
        default='ddim',
        help='Set the sampler.',
    )

    parser.add_argument("--flash", action='store_true', help="whether to use flash attention")

    parser.add_argument("--steps", type=int, default=50, help="number of sampling steps")

    parser.add_argument("--scale", type=float, default=7.5,
                        help="unconditional guidance scale: "
                             "eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))")

    opt = parser.parse_args()

    # Fixed seed so runs are reproducible
    set_seed(42)

    # Set flash attention
    from labml_nn.diffusion.stable_diffusion.model.unet_attention import CrossAttention
    CrossAttention.use_flash_attention = opt.flash

    # Build the generator from the checkpoint stored under the labml data path
    txt2img = Txt2Img(checkpoint_path=lab.get_data_path() / 'stable-diffusion' / 'sd-v1-4.ckpt',
                      sampler_name=opt.sampler_name,
                      n_steps=opt.steps)

    # Generate images and save them to `outputs`
    with monit.section('Generate'):
        txt2img(dest_path='outputs',
                batch_size=opt.batch_size,
                prompt=opt.prompt,
                uncond_scale=opt.scale)


if __name__ == "__main__":
    main()

labml.ai