This is a U-Net based model that predicts the noise $\epsilon_\theta(x_t, t)$.
U-Net gets its name from the U shape of the model diagram. It processes a given image by progressively lowering (halving) the feature-map resolution and then increasing it back. There are pass-through (skip) connections at each resolution.
This implementation contains a number of modifications to the original U-Net (residual blocks, multi-head attention) and also adds time-step embeddings $t$.
import math
from typing import Optional, Tuple, Union, List

import torch
from torch import nn
Swish activation function: $x \cdot \sigma(x)$
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)
class TimeEmbedding(nn.Module):

n_channels is the number of dimensions in the embedding

    def __init__(self, n_channels: int):
        super().__init__()
        self.n_channels = n_channels
First linear layer
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
Activation
        self.act = Swish()
Second linear layer
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)
    def forward(self, t: torch.Tensor):
Create sinusoidal position embeddings, the same as those used in the transformer
$$PE^{(1)}_{t,i} = \sin\!\left(\frac{t}{10000^{i/(d-1)}}\right), \qquad PE^{(2)}_{t,i} = \cos\!\left(\frac{t}{10000^{i/(d-1)}}\right)$$
where $d$ is half_dim
        half_dim = self.n_channels // 8
        emb = math.log(10_000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
Transform with the MLP
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)

        return emb
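As a quick sanity check (a hypothetical snippet, not part of the original code): with n_channels = 64, the sinusoidal part has 2 * half_dim = 16 dimensions, which the MLP then expands to 64.

time_emb = TimeEmbedding(n_channels=64)
t = torch.tensor([1, 10, 100])      # three time steps
emb = time_emb(t)                   # sinusoidal part is [3, 16]; the MLP maps it to [3, 64]
assert emb.shape == (3, 64)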
A residual block has two convolution layers with group normalization. Each resolution is processed with two residual blocks.
class ResidualBlock(nn.Module):

in_channels is the number of input channels
out_channels is the number of output channels
time_channels is the number of channels in the time step (t) embeddings
n_groups is the number of groups for group normalization
dropout is the dropout rate

    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 n_groups: int = 32, dropout: float = 0.1):
        super().__init__()
Group normalization and the first convolution layer
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
Group normalization and the second convolution layer
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
If the number of input channels is not equal to the number of output channels we have to project the shortcut connection
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        else:
            self.shortcut = nn.Identity()
Linear layer for time embeddings
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.time_act = Swish()

        self.dropout = nn.Dropout(dropout)
x has shape [batch_size, in_channels, height, width]
t has shape [batch_size, time_channels]

    def forward(self, x: torch.Tensor, t: torch.Tensor):
First convolution layer
        h = self.conv1(self.act1(self.norm1(x)))
Add time embeddings
        h += self.time_emb(self.time_act(t))[:, :, None, None]
Second convolution layer
        h = self.conv2(self.dropout(self.act2(self.norm2(h))))
Add the shortcut connection and return
        return h + self.shortcut(x)
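A hypothetical shape sketch (values assumed for illustration): the time embedding is projected to out_channels and broadcast over the spatial dimensions, while the 1x1 shortcut convolution matches the channel counts.

block = ResidualBlock(in_channels=64, out_channels=128, time_channels=256)
x = torch.randn(4, 64, 32, 32)      # [batch_size, in_channels, height, width]
t = torch.randn(4, 256)             # [batch_size, time_channels]
out = block(x, t)                   # shortcut is a 1x1 convolution mapping 64 -> 128 channels
assert out.shape == (4, 128, 32, 32)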
This is similar to transformer multi-head attention.
class AttentionBlock(nn.Module):

n_channels is the number of channels in the input
n_heads is the number of heads in multi-head attention
d_k is the number of dimensions in each head
n_groups is the number of groups for group normalization

    def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
        super().__init__()
Default d_k
        if d_k is None:
            d_k = n_channels
Normalization layer
        self.norm = nn.GroupNorm(n_groups, n_channels)
Projections for query, key and values
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
Linear layer for final transformation
        self.output = nn.Linear(n_heads * d_k, n_channels)
Scale for dot-product attention
        self.scale = d_k ** -0.5

        self.n_heads = n_heads
        self.d_k = d_k

x has shape [batch_size, in_channels, height, width]
t has shape [batch_size, time_channels]

    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
t is not used, but it is kept in the arguments so that the function signature matches that of ResidualBlock.
        _ = t
Get shape
        batch_size, n_channels, height, width = x.shape
Change x to shape [batch_size, seq, n_channels]
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
Get query, key, and values (concatenated) and shape it to [batch_size, seq, n_heads, 3 * d_k]
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
Split query, key, and values. Each of them will have shape [batch_size, seq, n_heads, d_k]
        q, k, v = torch.chunk(qkv, 3, dim=-1)
Calculate the scaled dot-product $\frac{QK^\top}{\sqrt{d_k}}$
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
Softmax along the sequence dimension: $\underset{seq}{\text{softmax}}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)$
        attn = attn.softmax(dim=2)
Multiply by values
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
Reshape to [batch_size, seq, n_heads * d_k]
        res = res.view(batch_size, -1, self.n_heads * self.d_k)
Transform to [batch_size, seq, n_channels]
        res = self.output(res)
Add skip connection
        res += x
Change to shape [batch_size, in_channels, height, width]
        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)

        return res
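Because attention treats the feature map as a sequence of height * width vectors, the spatial shape is preserved. A minimal sketch with assumed values (not part of the original code):

attn = AttentionBlock(n_channels=128, n_heads=4, d_k=32)
x = torch.randn(4, 128, 16, 16)     # [batch_size, n_channels, height, width]
out = attn(x)                       # sequence length is 16 * 16 = 256
assert out.shape == x.shape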
This combines ResidualBlock and AttentionBlock . These are used in the first half of U-Net at each resolution.
class DownBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        self.res = ResidualBlock(in_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x
This combines ResidualBlock and AttentionBlock . These are used in the second half of U-Net at each resolution.
class UpBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
The input has in_channels + out_channels channels because we concatenate the output of the same resolution from the first half of the U-Net
        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x
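A hypothetical sketch of that concatenation (the cat itself is done by the parent UNet in its forward pass; the values here are assumed for illustration):

up = UpBlock(in_channels=128, out_channels=64, time_channels=256, has_attn=False)
x = torch.randn(4, 128, 32, 32)     # output of the previous block in the second half
s = torch.randn(4, 64, 32, 32)      # skip connection from the first half at this resolution
t = torch.randn(4, 256)
out = up(torch.cat((x, s), dim=1), t)   # the residual block sees 128 + 64 channels
assert out.shape == (4, 64, 32, 32)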
It combines a ResidualBlock and an AttentionBlock, followed by another ResidualBlock. This block is applied at the lowest resolution of the U-Net.
class MiddleBlock(nn.Module):
    def __init__(self, n_channels: int, time_channels: int):
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x
class Upsample(nn.Module):
    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
t is not used, but it is kept in the arguments so that the function signature matches that of ResidualBlock.
        _ = t
        return self.conv(x)
class Downsample(nn.Module):
    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
t is not used, but it is kept in the arguments so that the function signature matches that of ResidualBlock.
        _ = t
        return self.conv(x)
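Both modules use a stride of 2, so Downsample halves the spatial resolution and Upsample doubles it. A hypothetical shape check (not part of the original code):

x = torch.randn(4, 64, 32, 32)
t = torch.randn(4, 256)             # unused by either module
assert Downsample(64)(x, t).shape == (4, 64, 16, 16)
assert Upsample(64)(x, t).shape == (4, 64, 64, 64)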
class UNet(nn.Module):

image_channels is the number of channels in the image; 3 for RGB
n_channels is the number of channels in the initial feature map that we transform the image into
ch_mults is the list of channel multipliers at each resolution; the channel count is scaled by ch_mults[i] at the i-th resolution
is_attn is a list of booleans that indicate whether to use attention at each resolution
n_blocks is the number of DownBlocks/UpBlocks at each resolution

    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
        super().__init__()
Number of resolutions
        n_resolutions = len(ch_mults)
Project image into feature map
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))
Time embedding layer. Time embedding has n_channels * 4 channels
        self.time_emb = TimeEmbedding(n_channels * 4)
        down = []
Number of channels
        out_channels = in_channels = n_channels
For each resolution
        for i in range(n_resolutions):
Number of output channels at this resolution
            out_channels = in_channels * ch_mults[i]
Add n_blocks
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
Down sample at all resolutions except the last
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))
Combine the set of modules
        self.down = nn.ModuleList(down)
Middle block
        self.middle = MiddleBlock(out_channels, n_channels * 4)
        up = []
Number of channels
        in_channels = out_channels
For each resolution
        for i in reversed(range(n_resolutions)):
n_blocks at the same resolution
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
Final block to reduce the number of channels
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
Up sample at all resolutions except the last
            if i > 0:
                up.append(Upsample(in_channels))
Combine the set of modules
        self.up = nn.ModuleList(up)
Final normalization and convolution layer
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))
x has shape [batch_size, in_channels, height, width]
t has shape [batch_size]

    def forward(self, x: torch.Tensor, t: torch.Tensor):
Get time-step embeddings
        t = self.time_emb(t)
Get image projection
        x = self.image_proj(x)
h will store outputs at each resolution for skip connection
        h = [x]
First half of U-Net
        for m in self.down:
            x = m(x, t)
            h.append(x)
Middle (bottom)
        x = self.middle(x, t)
Second half of U-Net
        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
Get the skip connection from first half of U-Net and concatenate
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                x = m(x, t)
Final normalization and convolution
        return self.final(self.act(self.norm(x)))
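Putting it together, a minimal usage sketch (the 32x32 image size and batch size are assumed for illustration; the training loop lives elsewhere): the model takes a batch of noisy images x_t and their time steps t and returns a noise prediction with the same shape as the images.

model = UNet(image_channels=3, n_channels=64,
             ch_mults=(1, 2, 2, 4), is_attn=(False, False, True, True))
x_t = torch.randn(8, 3, 32, 32)     # batch of noisy images
t = torch.randint(0, 1000, (8,))    # a time step for each image
eps_theta = model(x_t, t)           # predicted noise
assert eps_theta.shape == x_t.shape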