Back to Annotated Deep Learning Paper Implementations

Autoencoder for Stable Diffusion

docs/diffusion/stable_diffusion/model/autoencoder.html

latest — 14.5 KB
Original Source

home / diffusion / stable_diffusion / model

View code on Github

#

Autoencoder for Stable Diffusion

This implements the auto-encoder model used to map between image space and latent space.

We have kept to the model definition and naming unchanged from CompVis/stable-diffusion so that we can load the checkpoints directly.

18fromtypingimportList1920importtorch21importtorch.nn.functionalasF22fromtorchimportnn

#

Autoencoder

This consists of the encoder and decoder modules.

class Autoencoder(nn.Module):
    """
    ## Autoencoder

    Maps between image space and latent space using an encoder and a decoder.
    Module/attribute naming follows CompVis/stable-diffusion so that those
    checkpoints can be loaded directly.
    """

    def __init__(self, encoder: 'Encoder', decoder: 'Decoder', emb_channels: int, z_channels: int):
        """
        * `encoder` is the encoder
        * `decoder` is the decoder
        * `emb_channels` is the number of dimensions in the quantized embedding space
        * `z_channels` is the number of channels in the embedding space
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # 1x1 convolution: embedding space -> quantized embedding space moments
        # (mean and log variance)
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)
        # 1x1 convolution: quantized embedding space -> embedding space
        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)

    def encode(self, img: torch.Tensor) -> 'GaussianDistribution':
        """
        ### Encode images to latent representation

        * `img` is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
        """
        # Embeddings with shape `[batch_size, z_channels * 2, z_height, z_height]`
        z = self.encoder(img)
        # Moments (mean and log variance) in the quantized embedding space
        moments = self.quant_conv(z)
        # Wrap the moments in a distribution object
        return GaussianDistribution(moments)

    def decode(self, z: torch.Tensor):
        """
        ### Decode images from latent representation

        * `z` is the latent representation with shape `[batch_size, emb_channels, z_height, z_height]`
        """
        # Map from the quantized representation back to embedding space
        h = self.post_quant_conv(z)
        # Decode the image of shape `[batch_size, channels, height, width]`
        return self.decoder(h)

#

Encoder module

class Encoder(nn.Module):
    """
    ## Encoder module
    """

    def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
                 in_channels: int, z_channels: int):
        """
        * `channels` is the number of channels in the first convolution layer
        * `channel_multipliers` are the multiplicative factors for the number of channels
          in the subsequent blocks
        * `n_resnet_blocks` is the number of resnet layers at each resolution
        * `in_channels` is the number of channels in the image
        * `z_channels` is the number of channels in the embedding space
        """
        super().__init__()

        # Number of top-level blocks of different resolutions.
        # The resolution is halved at the end of each top-level block except the last.
        n_resolutions = len(channel_multipliers)

        # Initial 3x3 convolution mapping the image to `channels`
        self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)

        # Number of channels in each top-level block
        channels_list = [m * channels for m in [1] + channel_multipliers]

        # Top-level blocks
        self.down = nn.ModuleList()
        for i in range(n_resolutions):
            # Each top-level block holds several ResNet blocks followed by down-sampling
            resnet_blocks = nn.ModuleList()
            for _ in range(n_resnet_blocks):
                resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
                channels = channels_list[i + 1]
            # Assemble the top-level block; attribute names match the checkpoint layout
            down = nn.Module()
            down.block = resnet_blocks
            # Down-sample after every top-level block except the last
            down.downsample = DownSample(channels) if i != n_resolutions - 1 else nn.Identity()
            self.down.append(down)

        # Final ResNet blocks with attention in between
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(channels, channels)
        self.mid.attn_1 = AttnBlock(channels)
        self.mid.block_2 = ResnetBlock(channels, channels)

        # Normalization and 3x3 convolution mapping to the embedding space
        # (twice `z_channels` for mean and log variance)
        self.norm_out = normalization(channels)
        self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)

    def forward(self, img: torch.Tensor):
        """
        * `img` is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
        """
        # Initial convolution to `channels`
        h = self.conv_in(img)

        # Top-level blocks: ResNet blocks, then down-sampling
        for stage in self.down:
            for resnet in stage.block:
                h = resnet(h)
            h = stage.downsample(h)

        # Final ResNet blocks with attention
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # Normalize, activate and map to the embedding space
        h = self.norm_out(h)
        h = swish(h)
        return self.conv_out(h)

#

Decoder module

class Decoder(nn.Module):
    """
    ## Decoder module
    """

    def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
                 out_channels: int, z_channels: int):
        """
        * `channels` is the number of channels in the final convolution layer
        * `channel_multipliers` are the multiplicative factors for the number of channels
          in the previous blocks, in reverse order
        * `n_resnet_blocks` is the number of resnet layers at each resolution
        * `out_channels` is the number of channels in the image
        * `z_channels` is the number of channels in the embedding space
        """
        super().__init__()

        # Number of top-level blocks of different resolutions
        num_resolutions = len(channel_multipliers)

        # Number of channels in each top-level block, in reverse order
        channels_list = [m * channels for m in channel_multipliers]

        # Number of channels at the deepest (first processed) block
        channels = channels_list[-1]

        # Initial 3x3 convolution mapping the embedding space to `channels`
        self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)

        # ResNet blocks with attention in between
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(channels, channels)
        self.mid.attn_1 = AttnBlock(channels)
        self.mid.block_2 = ResnetBlock(channels, channels)

        # Top-level blocks
        self.up = nn.ModuleList()
        for i in reversed(range(num_resolutions)):
            # Each top-level block holds several ResNet blocks followed by up-sampling.
            # Note: one more ResNet block per resolution than the encoder.
            resnet_blocks = nn.ModuleList()
            for _ in range(n_resnet_blocks + 1):
                resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
                channels = channels_list[i]
            # Assemble the top-level block; attribute names match the checkpoint layout
            up = nn.Module()
            up.block = resnet_blocks
            # Up-sample after every top-level block except the first
            up.upsample = UpSample(channels) if i != 0 else nn.Identity()
            # Prepend so module order stays consistent with the checkpoint
            self.up.insert(0, up)

        # Normalization and 3x3 convolution mapping to image space
        self.norm_out = normalization(channels)
        self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)

    def forward(self, z: torch.Tensor):
        """
        * `z` is the embedding tensor with shape `[batch_size, z_channels, z_height, z_height]`
        """
        # Initial convolution to `channels`
        h = self.conv_in(z)

        # ResNet blocks with attention
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # Top-level blocks, traversed deepest-first: ResNet blocks, then up-sampling
        for stage in reversed(self.up):
            for resnet in stage.block:
                h = resnet(h)
            h = stage.upsample(h)

        # Normalize, activate and map to image space
        h = self.norm_out(h)
        h = swish(h)
        return self.conv_out(h)

#

Gaussian Distribution

class GaussianDistribution:
    """
    ## Gaussian Distribution
    """

    def __init__(self, parameters: torch.Tensor):
        """
        * `parameters` are the means and log of variances of the embedding of shape
          `[batch_size, z_channels * 2, z_height, z_height]`
        """
        # Split into mean and log of variance along the channel dimension
        mean, log_var = parameters.chunk(2, dim=1)
        self.mean = mean
        # Clamp the log of variances to keep the exponent numerically stable
        self.log_var = log_var.clamp(-30.0, 20.0)
        # Standard deviation: exp(0.5 * log variance)
        self.std = (0.5 * self.log_var).exp()

    def sample(self):
        """Draw a sample via the reparameterization `mean + std * eps`."""
        return self.mean + self.std * torch.randn_like(self.std)

#

Attention block

class AttnBlock(nn.Module):
    """
    ## Attention block
    """

    def __init__(self, channels: int):
        """
        * `channels` is the number of channels
        """
        super().__init__()
        # Group normalization
        self.norm = normalization(channels)
        # 1x1 convolutions producing query, key and value
        self.q = nn.Conv2d(channels, channels, 1)
        self.k = nn.Conv2d(channels, channels, 1)
        self.v = nn.Conv2d(channels, channels, 1)
        # Final 1x1 projection
        self.proj_out = nn.Conv2d(channels, channels, 1)
        # Attention scaling factor 1/sqrt(d_key)
        self.scale = channels ** -0.5

    def forward(self, x: torch.Tensor):
        """
        * `x` is the tensor of shape `[batch_size, channels, height, width]`
        """
        # Normalize the input
        x_norm = self.norm(x)
        # Query, key and value embeddings
        q = self.q(x_norm)
        k = self.k(x_norm)
        v = self.v(x_norm)

        # Flatten the spatial dimensions:
        # `[batch_size, channels, height, width]` -> `[batch_size, channels, height * width]`
        b, c, h, w = q.shape
        q = q.view(b, c, h * w)
        k = k.view(b, c, h * w)
        v = v.view(b, c, h * w)

        # Attention weights: softmax(Q K^T / sqrt(d_key)) over the key positions
        attn = torch.einsum('bci,bcj->bij', q, k) * self.scale
        attn = F.softmax(attn, dim=2)

        # Weighted sum of values: softmax(Q K^T / sqrt(d_key)) V
        out = torch.einsum('bij,bcj->bci', attn, v)

        # Restore the spatial shape `[batch_size, channels, height, width]`
        out = out.view(b, c, h, w)
        # Final 1x1 projection
        out = self.proj_out(out)

        # Residual connection
        return x + out

#

Up-sampling layer

class UpSample(nn.Module):
    """
    ## Up-sampling layer
    """

    def __init__(self, channels: int):
        """
        * `channels` is the number of channels
        """
        super().__init__()
        # 3x3 convolution applied after up-sampling
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the input feature map with shape `[batch_size, channels, height, width]`
        """
        # Nearest-neighbour up-sampling by a factor of 2, then the convolution
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)

#

Down-sampling layer

class DownSample(nn.Module):
    """
    ## Down-sampling layer
    """

    def __init__(self, channels: int):
        """
        * `channels` is the number of channels
        """
        super().__init__()
        # 3x3 convolution with stride 2 to halve the resolution;
        # padding is done manually in `forward`
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the input feature map with shape `[batch_size, channels, height, width]`
        """
        # Zero-pad one pixel on the right and bottom only (asymmetric padding)
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        # Strided convolution down-samples by a factor of 2
        return self.conv(padded)

#

ResNet Block

class ResnetBlock(nn.Module):
    """
    ## ResNet Block
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        * `in_channels` is the number of channels in the input
        * `out_channels` is the number of channels in the output
        """
        super().__init__()
        # First normalization and convolution layer
        self.norm1 = normalization(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
        # Second normalization and convolution layer
        self.norm2 = normalization(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
        # 1x1 projection on the residual path when channel counts differ,
        # otherwise identity
        if in_channels != out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
        else:
            self.nin_shortcut = nn.Identity()

    def forward(self, x: torch.Tensor):
        """
        * `x` is the input feature map with shape `[batch_size, channels, height, width]`
        """
        # First norm -> swish -> conv
        h = self.conv1(swish(self.norm1(x)))
        # Second norm -> swish -> conv
        h = self.conv2(swish(self.norm2(h)))
        # Map the input (if needed) and add the residual
        return self.nin_shortcut(x) + h

#

Swish activation

x⋅σ(x)

def swish(x: torch.Tensor):
    """
    ### Swish activation

    $x \cdot \sigma(x)$
    """
    return x * x.sigmoid()

#

Group normalization

This is a helper function, with fixed number of groups and eps .

def normalization(channels: int):
    """
    ### Group normalization

    Helper that builds a `nn.GroupNorm` with a fixed number of groups (32)
    and a fixed `eps` (1e-6).
    """
    return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)

labml.ai