U-Net for Stable Diffusion
This implements the U-Net that gives $\epsilon_\text{cond}(x_t, c)$.
We have kept the model definition and naming unchanged from CompVis/stable-diffusion so that we can load the checkpoints directly.
```python
import math
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from labml_nn.diffusion.stable_diffusion.model.unet_attention import SpatialTransformer
```
```python
class UNetModel(nn.Module):
```
* `in_channels` is the number of channels in the input feature map
* `out_channels` is the number of channels in the output feature map
* `channels` is the base channel count for the model
* `n_res_blocks` is the number of residual blocks at each level
* `attention_levels` are the levels at which attention should be performed
* `channel_multipliers` are the multiplicative factors for the number of channels at each level
* `n_heads` is the number of attention heads in the transformers
* `tf_layers` is the number of transformer layers in the transformers
* `d_cond` is the size of the conditional embedding in the transformers

```python
    def __init__(
            self, *,
            in_channels: int,
            out_channels: int,
            channels: int,
            n_res_blocks: int,
            attention_levels: List[int],
            channel_multipliers: List[int],
            n_heads: int,
            tf_layers: int = 1,
            d_cond: int = 768):
```
```python
        super().__init__()
        self.channels = channels
```
Number of levels
```python
        levels = len(channel_multipliers)
```
Size of the time embeddings
```python
        d_time_emb = channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(channels, d_time_emb),
            nn.SiLU(),
            nn.Linear(d_time_emb, d_time_emb),
        )
```
Input half of the U-Net
```python
        self.input_blocks = nn.ModuleList()
```
Initial 3×3 convolution that maps the input to `channels`. The blocks are wrapped in a `TimestepEmbedSequential` module because different modules have different forward signatures; for example, a convolution accepts only the feature map, while residual blocks accept the feature map and the time embedding. `TimestepEmbedSequential` calls each accordingly.
```python
        self.input_blocks.append(TimestepEmbedSequential(
            nn.Conv2d(in_channels, channels, 3, padding=1)))
```
Number of channels at each block in the input half of U-Net
```python
        input_block_channels = [channels]
```
Number of channels at each level
```python
        channels_list = [channels * m for m in channel_multipliers]
```
Prepare levels
```python
        for i in range(levels):
```
Add the residual blocks and attentions
```python
            for _ in range(n_res_blocks):
```
Residual block maps from the previous number of channels to the number of channels in the current level
```python
                layers = [ResBlock(channels, d_time_emb, out_channels=channels_list[i])]
                channels = channels_list[i]
```
Add transformer
```python
                if i in attention_levels:
                    layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))
```
Add them to the input half of the U-Net and keep track of the number of channels of its output
```python
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_channels.append(channels)
```
Down-sample at all levels except the last
```python
            if i != levels - 1:
                self.input_blocks.append(TimestepEmbedSequential(DownSample(channels)))
                input_block_channels.append(channels)
```
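The channel bookkeeping above can be traced in plain Python. The base channel count, multipliers and block count below are illustrative values, not the actual Stable Diffusion configuration:

```python
# Trace the channel bookkeeping of the input half (plain-Python sketch;
# the numbers are illustrative, not from the actual checkpoints).
channels = 64
channel_multipliers = [1, 2, 4]
n_res_blocks = 2
levels = len(channel_multipliers)

input_block_channels = [channels]  # output of the initial 3x3 convolution
channels_list = [channels * m for m in channel_multipliers]  # [64, 128, 256]

for i in range(levels):
    for _ in range(n_res_blocks):
        channels = channels_list[i]            # ResBlock output channels
        input_block_channels.append(channels)
    if i != levels - 1:
        input_block_channels.append(channels)  # DownSample keeps the channel count

print(input_block_channels)
# [64, 64, 64, 64, 128, 128, 128, 256, 256]
```

The output half later pops this list in reverse, so every skip connection finds a matching channel count.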
The middle of the U-Net
```python
        self.middle_block = TimestepEmbedSequential(
            ResBlock(channels, d_time_emb),
            SpatialTransformer(channels, n_heads, tf_layers, d_cond),
            ResBlock(channels, d_time_emb),
        )
```
Second half of the U-Net
```python
        self.output_blocks = nn.ModuleList([])
```
Prepare levels in reverse order
```python
        for i in reversed(range(levels)):
```
Add the residual blocks and attentions
```python
            for j in range(n_res_blocks + 1):
```
Residual block maps from the previous number of channels, plus the skip connections from the input half of the U-Net, to the number of channels in the current level.
```python
                layers = [ResBlock(channels + input_block_channels.pop(), d_time_emb, out_channels=channels_list[i])]
                channels = channels_list[i]
```
Add transformer
```python
                if i in attention_levels:
                    layers.append(SpatialTransformer(channels, n_heads, tf_layers, d_cond))
```
Up-sample after the last residual block at every level except the last one. Note that we are iterating in reverse; i.e. `i == 0` is the last level.
```python
                if i != 0 and j == n_res_blocks:
                    layers.append(UpSample(channels))
```
Add to the output half of the U-Net
```python
                self.output_blocks.append(TimestepEmbedSequential(*layers))
```
Final normalization and 3×3 convolution
```python
        self.out = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1),
        )
```
* `time_steps` are the time steps of shape `[batch_size]`
* `max_period` controls the minimum frequency of the embeddings

```python
    def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):
```
$\frac{c}{2}$; half the channels are sin and the other half are cos,
```python
        half = self.channels // 2
```
$\frac{1}{10000^{\frac{2i}{c}}}$
```python
        frequencies = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=time_steps.device)
```
$\frac{t}{10000^{\frac{2i}{c}}}$
```python
        args = time_steps[:, None].float() * frequencies[None]
```
$\cos\Big(\frac{t}{10000^{\frac{2i}{c}}}\Big)$ and $\sin\Big(\frac{t}{10000^{\frac{2i}{c}}}\Big)$
```python
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
```
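The same embedding can be re-derived with the standard library alone. This is a minimal sketch for a single scalar time step (the function name and default sizes are illustrative), matching the cos-then-sin concatenation order used above:

```python
import math

def time_step_embedding(t: float, channels: int = 8, max_period: int = 10000):
    """Sinusoidal embedding for a single time step (plain-Python sketch)."""
    half = channels // 2
    # frequencies decay geometrically from 1 down to 1/max_period
    frequencies = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in frequencies]
    # cos half followed by sin half, matching the torch.cat order above
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]

emb = time_step_embedding(0.0)
# At t = 0 every cos component is 1 and every sin component is 0
print(emb)  # [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```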
* `x` is the input feature map of shape `[batch_size, channels, width, height]`
* `time_steps` are the time steps of shape `[batch_size]`
* `cond` is the conditioning of shape `[batch_size, n_cond, d_cond]`

```python
    def forward(self, x: torch.Tensor, time_steps: torch.Tensor, cond: torch.Tensor):
```
To store the input half outputs for skip connections
```python
        x_input_block = []
```
Get time step embeddings
```python
        t_emb = self.time_step_embedding(time_steps)
        t_emb = self.time_embed(t_emb)
```
Input half of the U-Net
```python
        for module in self.input_blocks:
            x = module(x, t_emb, cond)
            x_input_block.append(x)
```
Middle of the U-Net
```python
        x = self.middle_block(x, t_emb, cond)
```
Output half of the U-Net
```python
        for module in self.output_blocks:
            x = torch.cat([x, x_input_block.pop()], dim=1)
            x = module(x, t_emb, cond)
```
Final normalization and 3×3 convolution
```python
        return self.out(x)
```
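`x_input_block` acts as a stack: the output half pops it, so the deepest input-half features are concatenated first. A toy sketch with string labels standing in for feature maps:

```python
# Sketch of the skip-connection ordering (labels stand in for feature maps).
input_outputs = []
for name in ["in0", "in1", "in2"]:  # input half: push each block's output
    input_outputs.append(name)

concatenated = []
for _ in range(3):                  # output half: pop in reverse (LIFO) order
    concatenated.append(input_outputs.pop())

print(concatenated)  # ['in2', 'in1', 'in0'] - deepest features are used first
```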
This sequential module can be composed of different modules such as `ResBlock`, `nn.Conv` and `SpatialTransformer`, and calls each with its matching signature
```python
class TimestepEmbedSequential(nn.Sequential):
```
```python
    def forward(self, x, t_emb, cond=None):
        for layer in self:
            if isinstance(layer, ResBlock):
                x = layer(x, t_emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, cond)
            else:
                x = layer(x)
        return x
```
```python
class UpSample(nn.Module):
```
* `channels` is the number of channels

```python
    def __init__(self, channels: int):
```
```python
        super().__init__()
```
3×3 convolution mapping
```python
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
```
* `x` is the input feature map with shape `[batch_size, channels, height, width]`

```python
    def forward(self, x: torch.Tensor):
```
Up-sample by a factor of 2
```python
        x = F.interpolate(x, scale_factor=2, mode="nearest")
```
Apply convolution
```python
        return self.conv(x)
```
```python
class DownSample(nn.Module):
```
* `channels` is the number of channels

```python
    def __init__(self, channels: int):
```
```python
        super().__init__()
```
3×3 convolution with a stride of 2 to down-sample by a factor of 2
```python
        self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
```
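The stride-2 convolution halves each spatial dimension following the usual output-size formula $\lfloor (H + 2p - k)/s \rfloor + 1$; a quick sanity check in plain Python:

```python
def conv_out_size(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    # floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

print(conv_out_size(64))  # 32
print(conv_out_size(65))  # 33  (odd sizes round up with this padding)
```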
* `x` is the input feature map with shape `[batch_size, channels, height, width]`

```python
    def forward(self, x: torch.Tensor):
```
Apply convolution
```python
        return self.op(x)
```
```python
class ResBlock(nn.Module):
```
* `channels` is the number of input channels
* `d_t_emb` is the size of the timestep embeddings
* `out_channels` is the number of output channels; defaults to `channels`

```python
    def __init__(self, channels: int, d_t_emb: int, *, out_channels=None):
```
```python
        super().__init__()
```
out_channels not specified
```python
        if out_channels is None:
            out_channels = channels
```
First normalization and convolution
```python
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, padding=1),
        )
```
Time step embeddings
```python
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_t_emb, out_channels),
        )
```
Final convolution layer
```python
        self.out_layers = nn.Sequential(
            normalization(out_channels),
            nn.SiLU(),
            nn.Dropout(0.),
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
        )
```
`channels` to `out_channels` mapping layer for the residual connection
```python
        if out_channels == channels:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Conv2d(channels, out_channels, 1)
```
* `x` is the input feature map with shape `[batch_size, channels, height, width]`
* `t_emb` is the time step embeddings of shape `[batch_size, d_t_emb]`

```python
    def forward(self, x: torch.Tensor, t_emb: torch.Tensor):
```
Initial convolution
```python
        h = self.in_layers(x)
```
Time step embeddings
```python
        t_emb = self.emb_layers(t_emb).type(h.dtype)
```
Add time step embeddings
```python
        h = h + t_emb[:, :, None, None]
```
Final convolution
```python
        h = self.out_layers(h)
```
Add skip connection
```python
        return self.skip_connection(x) + h
```
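The `t_emb[:, :, None, None]` indexing turns the `[batch_size, out_channels]` embedding into a per-channel bias that broadcasts over height and width. A NumPy sketch of the same shape arithmetic (shapes chosen arbitrarily):

```python
import numpy as np

batch, ch, h, w = 2, 4, 8, 8
feature_map = np.zeros((batch, ch, h, w))
t_emb = np.arange(batch * ch, dtype=float).reshape(batch, ch)

# [batch, ch] -> [batch, ch, 1, 1] so it broadcasts over height and width
out = feature_map + t_emb[:, :, None, None]

print(out.shape)        # (2, 4, 8, 8)
print(out[1, 3, 0, 0])  # 7.0 - every spatial position of channel 3 gets t_emb[1, 3]
```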
```python
class GroupNorm32(nn.GroupNorm):
```
```python
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)
```
This is a helper function, with a fixed number of groups.
```python
def normalization(channels):
```
```python
    return GroupNorm32(32, channels)
```
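`nn.GroupNorm(32, channels)` requires the channel count to be divisible by 32; this holds for the channel counts a typical configuration produces. The base channel count and multipliers below are assumed for illustration, not taken from this file:

```python
base = 320                  # base channel count (assumed, for illustration)
multipliers = [1, 2, 4, 4]  # channel multipliers (assumed, for illustration)

for c in [base * m for m in multipliers]:
    # every level's channel count must split evenly into 32 groups
    assert c % 32 == 0, f"{c} channels cannot be split into 32 groups"
print("all channel counts divide into 32 groups")
```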
Test sinusoidal time step embeddings
```python
def _test_time_embeddings():
```
```python
    import matplotlib.pyplot as plt

    plt.figure(figsize=(15, 5))
    m = UNetModel(in_channels=1, out_channels=1, channels=320, n_res_blocks=1, attention_levels=[],
                  channel_multipliers=[],
                  n_heads=1, tf_layers=1, d_cond=1)
    te = m.time_step_embedding(torch.arange(0, 1000))
    plt.plot(np.arange(1000), te[:, [50, 100, 190, 260]].numpy())
    plt.legend(["dim %d" % p for p in [50, 100, 190, 260]])
    plt.title("Time embeddings")
    plt.show()
```
```python
if __name__ == '__main__':
    _test_time_embeddings()
```