Autoencoder for Stable Diffusion
This implements the auto-encoder model used to map between image space and latent space.
We have kept the model definition and naming unchanged from CompVis/stable-diffusion so that we can load the checkpoints directly.
from typing import List

import torch
import torch.nn.functional as F
from torch import nn
This consists of the encoder and decoder modules.
class Autoencoder(nn.Module):
encoder is the encoder
decoder is the decoder
emb_channels is the number of dimensions in the quantized embedding space
z_channels is the number of channels in the embedding space

    def __init__(self, encoder: 'Encoder', decoder: 'Decoder', emb_channels: int, z_channels: int):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
Convolution to map from embedding space to quantized embedding space moments (mean and log variance)
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)
Convolution to map from quantized embedding space back to embedding space
        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)
img is the image tensor with shape [batch_size, img_channels, img_height, img_width]

    def encode(self, img: torch.Tensor) -> 'GaussianDistribution':
Get embeddings with shape [batch_size, z_channels * 2, z_height, z_width]
        z = self.encoder(img)
Get the moments in the quantized embedding space
        moments = self.quant_conv(z)
Return the distribution
        return GaussianDistribution(moments)
z is the latent representation with shape [batch_size, emb_channels, z_height, z_width]

    def decode(self, z: torch.Tensor):
Map to embedding space from the quantized representation
        z = self.post_quant_conv(z)
Decode the image of shape [batch_size, channels, height, width]
        return self.decoder(z)
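The encode/decode round trip can be sketched standalone. The stand-in encoder and decoder below are hypothetical single convolutions, chosen only so the channel bookkeeping matches the real model; the actual Encoder and Decoder are defined later in this file.

```python
import torch
from torch import nn

# Hypothetical stand-ins for Encoder/Decoder, sized so the round trip runs;
# only the channel bookkeeping matches the real model.
z_channels, emb_channels = 4, 4
encoder = nn.Conv2d(3, 2 * z_channels, 3, stride=8, padding=1)  # image -> moments-sized map
decoder = nn.Conv2d(z_channels, 3, 3, padding=1)                # latent -> image channels

quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)
post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)

img = torch.randn(1, 3, 64, 64)
moments = quant_conv(encoder(img))              # [1, 2 * emb_channels, 8, 8]
mean, log_var = torch.chunk(moments, 2, dim=1)  # split into mean and log variance
z = mean + torch.exp(0.5 * log_var) * torch.randn_like(mean)  # reparameterized sample
out = decoder(post_quant_conv(z))               # back to image space
```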
class Encoder(nn.Module):
channels is the number of channels in the first convolution layer
channel_multipliers are the multiplicative factors for the number of channels in the subsequent blocks
n_resnet_blocks is the number of resnet layers at each resolution
in_channels is the number of channels in the image
z_channels is the number of channels in the embedding space

    def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
                 in_channels: int, z_channels: int):
        super().__init__()
Number of blocks of different resolutions. The resolution is halved at the end of each top-level block.
        n_resolutions = len(channel_multipliers)
Initial 3×3 convolution layer that maps the image to channels
        self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)
Number of channels in each top level block
        channels_list = [m * channels for m in [1] + channel_multipliers]
List of top-level blocks
        self.down = nn.ModuleList()
Create top-level blocks
        for i in range(n_resolutions):
Each top level block consists of multiple ResNet Blocks and down-sampling
            resnet_blocks = nn.ModuleList()
Add ResNet Blocks
            for _ in range(n_resnet_blocks):
                resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
                channels = channels_list[i + 1]
Top-level block
            down = nn.Module()
            down.block = resnet_blocks
Down-sampling at the end of each top level block except the last
            if i != n_resolutions - 1:
                down.downsample = DownSample(channels)
            else:
                down.downsample = nn.Identity()
            self.down.append(down)
Final ResNet blocks with attention
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(channels, channels)
        self.mid.attn_1 = AttnBlock(channels)
        self.mid.block_2 = ResnetBlock(channels, channels)
Map to embedding space with a 3×3 convolution
        self.norm_out = normalization(channels)
        self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)
img is the image tensor with shape [batch_size, img_channels, img_height, img_width]

    def forward(self, img: torch.Tensor):
Map to channels with the initial convolution
        x = self.conv_in(img)
Top-level blocks
        for down in self.down:
ResNet Blocks
            for block in down.block:
                x = block(x)
Down-sampling
            x = down.downsample(x)
Final ResNet blocks with attention
        x = self.mid.block_1(x)
        x = self.mid.attn_1(x)
        x = self.mid.block_2(x)
Normalize and map to embedding space
        x = self.norm_out(x)
        x = swish(x)
        x = self.conv_out(x)
        return x
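The channel bookkeeping can be traced with concrete numbers. The values below (channels=128, channel_multipliers=[1, 2, 4, 4]) are assumed from the Stable Diffusion first-stage config, not stated in this file.

```python
# Per-block channel counts for the encoder, with assumed config values.
# Resolution is halved after every top-level block except the last.
channels = 128
channel_multipliers = [1, 2, 4, 4]
channels_list = [m * channels for m in [1] + channel_multipliers]
n_down = len(channel_multipliers) - 1  # number of down-sampling steps
```

With these values the blocks run at 128, 128, 256, 512, 512 channels and the image is down-sampled by a factor of 8 overall.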
class Decoder(nn.Module):
channels is the number of channels in the final convolution layer
channel_multipliers are the multiplicative factors for the number of channels in the previous blocks, in reverse order
n_resnet_blocks is the number of resnet layers at each resolution
out_channels is the number of channels in the image
z_channels is the number of channels in the embedding space

    def __init__(self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int,
                 out_channels: int, z_channels: int):
        super().__init__()
Number of blocks of different resolutions. The resolution is doubled at the end of each top-level block.
        num_resolutions = len(channel_multipliers)
Number of channels in each top level block, in the reverse order
        channels_list = [m * channels for m in channel_multipliers]
Number of channels in the top-level block
        channels = channels_list[-1]
Initial 3×3 convolution layer that maps the embedding space to channels
        self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)
ResNet blocks with attention
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(channels, channels)
        self.mid.attn_1 = AttnBlock(channels)
        self.mid.block_2 = ResnetBlock(channels, channels)
List of top-level blocks
        self.up = nn.ModuleList()
Create top-level blocks
        for i in reversed(range(num_resolutions)):
Each top level block consists of multiple ResNet Blocks and up-sampling
            resnet_blocks = nn.ModuleList()
Add ResNet Blocks
            for _ in range(n_resnet_blocks + 1):
                resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
                channels = channels_list[i]
Top-level block
            up = nn.Module()
            up.block = resnet_blocks
Up-sampling at the end of each top level block except the first
            if i != 0:
                up.upsample = UpSample(channels)
            else:
                up.upsample = nn.Identity()
Prepend to be consistent with the checkpoint
            self.up.insert(0, up)
Map to image space with a 3×3 convolution
        self.norm_out = normalization(channels)
        self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)
z is the embedding tensor with shape [batch_size, z_channels, z_height, z_width]

    def forward(self, z: torch.Tensor):
Map to channels with the initial convolution
        h = self.conv_in(z)
ResNet blocks with attention
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
Top-level blocks
        for up in reversed(self.up):
ResNet Blocks
            for block in up.block:
                h = block(h)
Up-sampling
            h = up.upsample(h)
Normalize and map to image space
        h = self.norm_out(h)
        h = swish(h)
        img = self.conv_out(h)
        return img
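The decoder's bookkeeping mirrors the encoder example (same assumed multipliers, which are not stated in this file): channels are listed in the same order, but the decoder starts from the last entry and walks the blocks in reverse, up-sampling after each block except the first.

```python
# Per-block channel counts for the decoder, with assumed config values.
channels = 128
channel_multipliers = [1, 2, 4, 4]
channels_list = [m * channels for m in channel_multipliers]
start_channels = channels_list[-1]  # width of the decoder's first top-level block
```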
class GaussianDistribution:
parameters are the means and log of variances of the embedding of shape [batch_size, z_channels * 2, z_height, z_width]

    def __init__(self, parameters: torch.Tensor):
Split mean and log of variance
        self.mean, log_var = torch.chunk(parameters, 2, dim=1)
Clamp the log of variances
        self.log_var = torch.clamp(log_var, -30.0, 20.0)
Calculate standard deviation
        self.std = torch.exp(0.5 * self.log_var)
    def sample(self):
Sample from the distribution
        return self.mean + self.std * torch.randn_like(self.std)
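The sampling above is the standard reparameterization trick. A minimal sketch with zero means and zero log-variances, where the standard deviation comes out as exactly 1 and the sample is standard normal:

```python
import torch

# With zero means and zero log-variances, std = exp(0) = 1.
parameters = torch.zeros(2, 8, 4, 4)  # [batch_size, z_channels * 2, h, w]
mean, log_var = torch.chunk(parameters, 2, dim=1)
log_var = torch.clamp(log_var, -30.0, 20.0)  # guards exp() against overflow/underflow
std = torch.exp(0.5 * log_var)
sample = mean + std * torch.randn_like(std)  # reparameterized sample
```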
class AttnBlock(nn.Module):
channels is the number of channels

    def __init__(self, channels: int):
        super().__init__()
Group normalization
        self.norm = normalization(channels)
Query, key and value mappings
        self.q = nn.Conv2d(channels, channels, 1)
        self.k = nn.Conv2d(channels, channels, 1)
        self.v = nn.Conv2d(channels, channels, 1)
Final 1×1 convolution layer
        self.proj_out = nn.Conv2d(channels, channels, 1)
Attention scaling factor
        self.scale = channels ** -0.5
x is the tensor of shape [batch_size, channels, height, width]

    def forward(self, x: torch.Tensor):
Normalize x
        x_norm = self.norm(x)
Get query, key and vector embeddings
        q = self.q(x_norm)
        k = self.k(x_norm)
        v = self.v(x_norm)
Reshape the query, key and value embeddings from [batch_size, channels, height, width] to [batch_size, channels, height * width]
        b, c, h, w = q.shape
        q = q.view(b, c, h * w)
        k = k.view(b, c, h * w)
        v = v.view(b, c, h * w)
Compute softmax(QK⊤/√d_key), with the softmax taken over the sequence dimension
        attn = torch.einsum('bci,bcj->bij', q, k) * self.scale
        attn = F.softmax(attn, dim=2)
Compute softmax(QK⊤/√d_key)V
        out = torch.einsum('bij,bcj->bci', attn, v)
Reshape back to [batch_size, channels, height, width]
        out = out.view(b, c, h, w)
Final 1×1 convolution layer
        out = self.proj_out(out)
Add residual connection
        return x + out
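The einsum form of the attention above can be checked against the equivalent transpose-and-matmul form. The shapes below (batch 1, 8 channels, 16 flattened spatial positions) are arbitrary assumptions for the sketch:

```python
import torch
import torch.nn.functional as F

b, c, n = 1, 8, 16  # batch, channels, flattened h * w
q, k, v = torch.randn(b, c, n), torch.randn(b, c, n), torch.randn(b, c, n)
scale = c ** -0.5

# einsum form, as in AttnBlock.forward
attn = F.softmax(torch.einsum('bci,bcj->bij', q, k) * scale, dim=2)
out = torch.einsum('bij,bcj->bci', attn, v)

# same computation with transposes and batched matrix multiplies
attn2 = F.softmax(q.transpose(1, 2) @ k * scale, dim=2)
out2 = (attn2 @ v.transpose(1, 2)).transpose(1, 2)
```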
class UpSample(nn.Module):
channels is the number of channels

    def __init__(self, channels: int):
        super().__init__()
3×3 convolution mapping
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
x is the input feature map with shape [batch_size, channels, height, width]

    def forward(self, x: torch.Tensor):
Up-sample by a factor of 2
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
Apply convolution
        return self.conv(x)
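A shape check of the two steps above (channel count and input size are arbitrary assumptions): nearest-neighbour interpolation doubles height and width, and the 3×3 convolution with padding=1 preserves the doubled size.

```python
import torch
import torch.nn.functional as F
from torch import nn

conv = nn.Conv2d(8, 8, 3, padding=1)                           # preserves spatial size
x = F.interpolate(torch.randn(1, 8, 16, 16), scale_factor=2.0, mode="nearest")
y = conv(x)                                                    # still 32x32
```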
class DownSample(nn.Module):
channels is the number of channels

    def __init__(self, channels: int):
        super().__init__()
3×3 convolution with stride length of 2 to down-sample by a factor of 2
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)
x is the input feature map with shape [batch_size, channels, height, width]

    def forward(self, x: torch.Tensor):
Add padding
        x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
Apply convolution
        return self.conv(x)
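The asymmetric (right, bottom) padding plus the stride-2 3×3 convolution with padding=0 halves the spatial size exactly: 32 → pad to 33 → conv to 16. A sketch with assumed sizes:

```python
import torch
import torch.nn.functional as F
from torch import nn

conv = nn.Conv2d(8, 8, 3, stride=2, padding=0)
x = F.pad(torch.randn(1, 8, 32, 32), (0, 1, 0, 1), mode="constant", value=0)  # pad right, bottom
y = conv(x)  # floor((33 - 3) / 2) + 1 = 16
```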
class ResnetBlock(nn.Module):
in_channels is the number of channels in the input
out_channels is the number of channels in the output

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
First normalization and convolution layer
        self.norm1 = normalization(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
Second normalization and convolution layer
        self.norm2 = normalization(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
in_channels to out_channels mapping layer for residual connection
        if in_channels != out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
        else:
            self.nin_shortcut = nn.Identity()
x is the input feature map with shape [batch_size, channels, height, width]

    def forward(self, x: torch.Tensor):
        h = x
First normalization and convolution layer
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)
Second normalization and convolution layer
        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)
Map and add residual
        return self.nin_shortcut(x) + h
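When in_channels differs from out_channels, the skip path needs the 1×1 nin_shortcut so the residual addition has matching shapes. A sketch with assumed sizes 64 → 128:

```python
import torch
from torch import nn

in_ch, out_ch = 64, 128
nin_shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
x = torch.randn(1, in_ch, 16, 16)
h = torch.randn(1, out_ch, 16, 16)  # stands in for the conv branch output
y = nin_shortcut(x) + h             # channel counts now match
```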
Swish activation: x⋅σ(x)
def swish(x: torch.Tensor):
    return x * torch.sigmoid(x)
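x⋅σ(x) is the same function PyTorch ships as F.silu, so the hand-written version can be checked against the library one:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-4.0, 4.0, steps=9)
y = x * torch.sigmoid(x)  # should match F.silu(x) elementwise
```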
Group normalization. This is a helper function with a fixed number of groups and eps.
def normalization(channels: int):
    return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
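GroupNorm with 32 groups requires the channel count to be divisible by 32; each group is standardized over its channels and spatial positions, so the output mean is close to zero. A sketch with an assumed channel count of 64:

```python
import torch
from torch import nn

norm = nn.GroupNorm(num_groups=32, num_channels=64, eps=1e-6)  # 2 channels per group
x = torch.randn(2, 64, 8, 8)
y = norm(x)  # same shape, per-group mean ~0, std ~1
```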