
U-Net for Stable Diffusion

This implements the U-Net that gives ϵ_cond(x_t, c).

We have kept the model definition and naming unchanged from CompVis/stable-diffusion so that we can load the checkpoints directly.

import math
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from labml_nn.diffusion.stable_diffusion.model.unet_attention import SpatialTransformer

U-Net model

class UNetModel(nn.Module):


  • in_channels is the number of channels in the input feature map
  • out_channels is the number of channels in the output feature map
  • channels is the base channel count for the model
  • n_res_blocks is the number of residual blocks at each level
  • attention_levels are the levels at which attention should be performed
  • channel_multipliers are the multiplicative factors for the number of channels at each level
  • n_heads is the number of attention heads in the transformers
  • tf_layers is the number of transformer layers in the transformers
  • d_cond is the size of the conditional embedding in the transformers
def __init__(
        self, *,
        in_channels: int,
        out_channels: int,
        channels: int,
        n_res_blocks: int,
        attention_levels: List[int],
        channel_multipliers: List[int],
        n_heads: int,
        tf_layers: int = 1,
        d_cond: int = 768):
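As a concrete reference, the sketch below lists the hyperparameters used by the Stable Diffusion v1 checkpoints; these values are assumed from the CompVis/stable-diffusion config, so verify them against your own checkpoint.

```python
# Hypothetical sketch: hyperparameters matching the Stable Diffusion v1 U-Net
# (assumed from the CompVis/stable-diffusion config; verify against your checkpoint).
sd_v1_unet = dict(
    in_channels=4,                 # latent channels from the autoencoder
    out_channels=4,                # predicted noise has the same shape as the latent
    channels=320,                  # base channel count
    n_res_blocks=2,
    attention_levels=[0, 1, 2],    # no attention at the deepest level
    channel_multipliers=[1, 2, 4, 4],
    n_heads=8,
    tf_layers=1,
    d_cond=768,                    # CLIP text embedding size
)

# The per-level channel counts the constructor derives from these values
channels_list = [sd_v1_unet["channels"] * m for m in sd_v1_unet["channel_multipliers"]]
print(channels_list)  # [320, 640, 1280, 1280]
```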

super().__init__()
self.channels = channels

Number of levels

levels = len(channel_multipliers)

Size of the time embeddings

d_time_emb = channels * 4
self.time_embed = nn.Sequential(
    nn.Linear(channels, d_time_emb),
    nn.SiLU(),
    nn.Linear(d_time_emb, d_time_emb),
)

Input half of the U-Net

self.input_blocks = nn.ModuleList()

Initial 3×3 convolution that maps the input to channels. The blocks are wrapped in a TimestepEmbedSequential module because different modules have different forward signatures; for example, a convolution only accepts the feature map, while residual blocks accept the feature map and the time embedding. TimestepEmbedSequential calls each accordingly.

self.input_blocks.append(TimestepEmbedSequential(
    nn.Conv2d(in_channels, channels, 3, padding=1)))

Number of channels at each block in the input half of the U-Net

input_block_channels = [channels]

Number of channels at each level

channels_list = [channels * m for m in channel_multipliers]

Prepare levels

for i in range(levels):

Add the residual blocks and attentions

    for _ in range(n_res_blocks):

Residual block maps from the previous number of channels to the number of channels in the current level

        layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])]
        channels = channels_list[i]

Add transformer

        if i in attention_levels:
            layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))

Add them to the input half of the U-Net and keep track of the number of channels of its output

        self.input_blocks.append(TimestepEmbedSequential(*layers))
        input_block_channels.append(channels)

Down-sample at all levels except the last

    if i != levels - 1:
        self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
        input_block_channels.append(channels)

The middle of the U-Net

self.middle_block = TimestepEmbedSequential(
    ResBlock(channels, d_time_emb),
    SpatialTransformer(channels, n_heads, tf_layers, d_cond),
    ResBlock(channels, d_time_emb),
)

Second half of the U-Net

self.output_blocks = nn.ModuleList([])

Prepare levels in reverse order

for i in reversed(range(levels)):

Add the residual blocks and attentions

    for j in range(n_res_blocks + 1):

Residual block maps from the previous number of channels plus the skip connections from the input half of the U-Net to the number of channels in the current level.

        layers = [ResBlock(channels + input_block_channels.pop(), d_time_emb, out_channels=channels_list[i])]
        channels = channels_list[i]
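The channel bookkeeping between the two halves is easy to get wrong, so it can help to replay it with plain Python. The sketch below is a standalone simulation, not part of the model, using hyperparameter values assumed from the Stable Diffusion v1 config:

```python
# Standalone sketch: replay the constructor's channel bookkeeping without torch.
# channels=320, n_res_blocks=2 and multipliers [1, 2, 4, 4] are assumed SD v1 values.
channels, n_res_blocks = 320, 2
channel_multipliers = [1, 2, 4, 4]
levels = len(channel_multipliers)
channels_list = [channels * m for m in channel_multipliers]

# Input half: record the output channels of every input block.
input_block_channels = [channels]          # initial 3x3 convolution
for i in range(levels):
    for _ in range(n_res_blocks):
        channels = channels_list[i]
        input_block_channels.append(channels)
    if i != levels - 1:
        input_block_channels.append(channels)  # down-sample block

# Output half: each of the n_res_blocks + 1 blocks per level pops one skip connection.
pops = 0
for i in reversed(range(levels)):
    for j in range(n_res_blocks + 1):
        skip_channels = input_block_channels.pop()
        pops += 1
        channels = channels_list[i]

# Every recorded skip connection is consumed exactly once.
assert input_block_channels == []
print(pops)  # 12: 4 levels x (2 + 1) output blocks
```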

Add transformer

        if i in attention_levels:
            layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))

Up-sample after the last residual block at every level except the last one. Note that we are iterating in reverse; i.e. i == 0 is the last level.

        if i != 0 and j == n_res_blocks:
            layers.append(UpSample(channels))

Add to the output half of the U-Net

        self.output_blocks.append(TimestepEmbedSequential(*layers))

Final normalization and 3×3 convolution

self.out = nn.Sequential(
    normalization(channels),
    nn.SiLU(),
    nn.Conv2d(channels, out_channels, 3, padding=1),
)

Create sinusoidal time step embeddings

  • time_steps are the time steps of shape [batch_size]
  • max_period controls the minimum frequency of the embeddings

def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):

c/2; half the channels are sin and the other half cos

half = self.channels // 2

1 / 10000^(2i/c)

frequencies = torch.exp(
    -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=time_steps.device)

t / 10000^(2i/c)

args = time_steps[:, None].float() * frequencies[None]

cos(t / 10000^(2i/c)) and sin(t / 10000^(2i/c))

return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
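The same computation can be checked with the standard library alone. A minimal sketch, assuming channels = 8 (so half = 4):

```python
import math

# Standalone sketch of the sinusoidal time step embedding for channels = 8.
channels, max_period = 8, 10000
half = channels // 2

# frequencies[i] = 1 / max_period^(i / half), i.e. 1 / 10000^(2i / c)
frequencies = [math.exp(-math.log(max_period) * i / half) for i in range(half)]

def embedding(t):
    # cos components first, then sin components, matching the torch.cat order above
    args = [t * f for f in frequencies]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]

# The first frequency is 1, so that dimension oscillates fastest;
# at t = 0 all cosines are 1 and all sines are 0.
assert frequencies[0] == 1.0
assert embedding(0) == [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```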

  • x is the input feature map of shape [batch_size, channels, width, height]
  • time_steps are the time steps of shape [batch_size]
  • cond is the conditioning of shape [batch_size, n_cond, d_cond]

def forward(self, x: torch.Tensor, time_steps: torch.Tensor, cond: torch.Tensor):

To store the input half outputs for skip connections

x_input_block = []

Get time step embeddings

t_emb = self.time_step_embedding(time_steps)
t_emb = self.time_embed(t_emb)

Input half of the U-Net

for module in self.input_blocks:
    x = module(x, t_emb, cond)
    x_input_block.append(x)

Middle of the U-Net

x = self.middle_block(x, t_emb, cond)

Output half of the U-Net

for module in self.output_blocks:
    x = torch.cat([x, x_input_block.pop()], dim=1)
    x = module(x, t_emb, cond)

Final normalization and 3×3 convolution

return self.out(x)

Sequential block for modules with different inputs

This sequential module can compose different modules such as ResBlock, nn.Conv2d, and SpatialTransformer, and calls each with its matching signature.

class TimestepEmbedSequential(nn.Sequential):

def forward(self, x, t_emb, cond=None):
    for layer in self:
        if isinstance(layer, ResBlock):
            x = layer(x, t_emb)
        elif isinstance(layer, SpatialTransformer):
            x = layer(x, cond)
        else:
            x = layer(x)
    return x

Up-sampling layer

class UpSample(nn.Module):

  • channels is the number of channels

def __init__(self, channels: int):

super().__init__()

3×3 convolution mapping

self.conv = nn.Conv2d(channels, channels, 3, padding=1)

  • x is the input feature map with shape [batch_size, channels, height, width]

def forward(self, x: torch.Tensor):

Up-sample by a factor of 2

x = F.interpolate(x, scale_factor=2, mode="nearest")

Apply convolution

return self.conv(x)
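Nearest-mode interpolation with scale factor 2 simply repeats each pixel in a 2×2 block. A pure-Python sketch of that spatial effect (the convolution that follows is omitted):

```python
# Standalone sketch: nearest-neighbour up-sampling by a factor of 2,
# the same spatial effect as F.interpolate(x, scale_factor=2, mode="nearest").
def upsample_nearest(grid):
    out = []
    for row in grid:
        doubled = [v for v in row for _ in range(2)]  # repeat each value twice
        out.append(doubled)
        out.append(list(doubled))                     # repeat each row twice
    return out

x = [[1, 2],
     [3, 4]]
print(upsample_nearest(x))
# [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
```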

Down-sampling layer

class DownSample(nn.Module):

  • channels is the number of channels

def __init__(self, channels: int):

super().__init__()

3×3 convolution with a stride of 2 to down-sample by a factor of 2

self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

  • x is the input feature map with shape [batch_size, channels, height, width]

def forward(self, x: torch.Tensor):

Apply convolution

return self.op(x)
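With kernel 3, stride 2 and padding 1, the standard convolution size formula gives out = (H + 2·1 − 3) // 2 + 1, i.e. roughly half the input. A quick standalone check:

```python
# Standalone check of the output size of a 3x3, stride-2, padding-1 convolution.
def conv_out_size(size, kernel=3, stride=2, padding=1):
    return (size + 2 * padding - kernel) // stride + 1

assert conv_out_size(64) == 32   # even sizes halve exactly
assert conv_out_size(65) == 33   # odd sizes round up: ceil(65 / 2)
```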

ResNet Block

class ResBlock(nn.Module):

  • channels is the number of input channels
  • d_t_emb is the size of the timestep embeddings
  • out_channels is the number of output channels; defaults to channels

def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):

super().__init__()

out_channels not specified

if out_channels is None:
    out_channels = channels

First normalization and convolution

self.in_layers = nn.Sequential(
    normalization(channels),
    nn.SiLU(),
    nn.Conv2d(channels, out_channels, 3, padding=1),
)

Time step embeddings

self.emb_layers = nn.Sequential(
    nn.SiLU(),
    nn.Linear(d_t_emb, out_channels),
)

Final convolution layer

self.out_layers = nn.Sequential(
    normalization(out_channels),
    nn.SiLU(),
    nn.Dropout(0.),
    nn.Conv2d(out_channels, out_channels, 3, padding=1)
)

channels to out_channels mapping layer for the residual connection

if out_channels == channels:
    self.skip_connection = nn.Identity()
else:
    self.skip_connection = nn.Conv2d(channels, out_channels, 1)

  • x is the input feature map with shape [batch_size, channels, height, width]
  • t_emb is the time step embeddings of shape [batch_size, d_t_emb]

def forward(self, x: torch.Tensor, t_emb: torch.Tensor):

Initial convolution

h = self.in_layers(x)

Time step embeddings

t_emb = self.emb_layers(t_emb).type(h.dtype)

Add time step embeddings

h = h + t_emb[:, :, None, None]
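The indexing t_emb[:, :, None, None] reshapes [batch_size, out_channels] to [batch_size, out_channels, 1, 1], so the addition broadcasts the same per-channel shift over every spatial position. A sketch using NumPy, which follows the same broadcasting rules as PyTorch:

```python
import numpy as np

# Sketch of the time-embedding broadcast, shown with NumPy instead of torch.
h = np.zeros((2, 3, 4, 4))             # [batch_size, out_channels, height, width]
t_emb = np.arange(6.0).reshape(2, 3)   # [batch_size, out_channels]

out = h + t_emb[:, :, None, None]      # [2, 3, 1, 1] broadcasts over height/width
assert out.shape == (2, 3, 4, 4)

# Every spatial position of channel c in batch b is shifted by t_emb[b, c].
assert out[1, 2, 0, 0] == t_emb[1, 2]
```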

Final convolution

h = self.out_layers(h)

Add skip connection

return self.skip_connection(x) + h

Group normalization with float32 casting

class GroupNorm32(nn.GroupNorm):

def forward(self, x):
    return super().forward(x.float()).type(x.dtype)

Group normalization

This is a helper function with a fixed number of groups.

def normalization(channels):

return GroupNorm32(32, channels)
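Group normalization with 32 groups requires every normalized channel count to be divisible by 32. A standalone check, using the base channel count 320 and multipliers [1, 2, 4, 4] assumed from the Stable Diffusion v1 config:

```python
# Standalone check: all channel counts reached by the model must split evenly
# into 32 groups for GroupNorm32(32, channels) to be valid.
channels = 320                       # assumed SD v1 base channel count
channel_multipliers = [1, 2, 4, 4]   # assumed SD v1 multipliers

for c in [channels * m for m in channel_multipliers]:
    assert c % 32 == 0, f"{c} channels cannot be split into 32 groups"
print("all levels valid")  # 320, 640, 1280, 1280 are all multiples of 32
```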

Test sinusoidal time step embeddings

def _test_time_embeddings():

import matplotlib.pyplot as plt

plt.figure(figsize=(15, 5))
m = UNetModel(in_channels=1, out_channels=1, channels=320, n_res_blocks=1, attention_levels=[],
              channel_multipliers=[],
              n_heads=1, tf_layers=1, d_cond=1)
te = m.time_step_embedding(torch.arange(0, 1000))
plt.plot(np.arange(1000), te[:, [50, 100, 190, 260]].numpy())
plt.legend(["dim %d" % p for p in [50, 100, 190, 260]])
plt.title("Time embeddings")
plt.show()

if __name__ == '__main__':
    _test_time_embeddings()
