
U-Net model for Denoising Diffusion Probabilistic Models (DDPM)

This is a U-Net based model to predict noise $\epsilon_\theta(x_t, t)$.

U-Net gets its name from the U shape of the model diagram. It processes a given image by progressively lowering (halving) the feature map resolution and then increasing it again. There are pass-through connections at each resolution.

This implementation contains a number of modifications to the original U-Net (residual blocks, multi-head attention) and also adds time-step embeddings $t$.
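For context, here is a minimal sketch of how such a noise predictor is typically trained in DDPM. It is illustrative only: model stands for the UNet defined below, and alpha_bar is an assumed 1-D tensor of the cumulative products $\bar\alpha_t$ from the noise schedule; the actual training loop lives in the DDPM implementation, not in this file.

```python
import torch
import torch.nn.functional as F

def ddpm_loss(model, alpha_bar: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
    # Pick a random time step for each sample in the batch
    t = torch.randint(0, alpha_bar.shape[0], (x0.shape[0],), device=x0.device)
    # Sample the noise epsilon ~ N(0, I)
    eps = torch.randn_like(x0)
    # Diffuse x_0 to x_t: sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    a_bar = alpha_bar[t][:, None, None, None]
    xt = a_bar.sqrt() * x0 + (1. - a_bar).sqrt() * eps
    # Train the U-Net to recover the noise that was added
    return F.mse_loss(model(xt, t), eps)
```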

```python
import math
from typing import Optional, Tuple, Union, List

import torch
from torch import nn
```

Swish activation function

$x \cdot \sigma(x)$

```python
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)
```
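This is the same function that PyTorch provides built in as nn.SiLU.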

Embeddings for $t$

```python
class TimeEmbedding(nn.Module):
```

  • n_channels is the number of dimensions in the embedding
```python
    def __init__(self, n_channels: int):
        super().__init__()
        self.n_channels = n_channels
```

First linear layer

```python
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
```

Activation

```python
        self.act = Swish()
```

Second linear layer

```python
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)
```

```python
    def forward(self, t: torch.Tensor):
```

Create sinusoidal position embeddings, the same as those used in the transformer:

$$PE^{(1)}_{t,i} = \sin\Bigg(\frac{t}{10000^{\frac{i}{d-1}}}\Bigg) \qquad PE^{(2)}_{t,i} = \cos\Bigg(\frac{t}{10000^{\frac{i}{d-1}}}\Bigg)$$

where $d$ is half_dim.

```python
        half_dim = self.n_channels // 8
        emb = math.log(10_000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
```

Transform with the MLP

```python
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)
```

```python
        return emb
```
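A quick shape check (hypothetical usage, not part of the original file): with n_channels = 256, the sinusoidal part has 2 * (256 // 8) = 64 dimensions, which the MLP expands to 256.

```python
emb = TimeEmbedding(n_channels=256)
t = torch.randint(0, 1_000, (8,))  # one time step per sample
print(emb(t).shape)  # torch.Size([8, 256])
```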

Residual block

A residual block has two convolution layers with group normalization. Each resolution is processed with two residual blocks.

```python
class ResidualBlock(nn.Module):
```

  • in_channels is the number of input channels
  • out_channels is the number of output channels
  • time_channels is the number of channels in the time step ($t$) embeddings
  • n_groups is the number of groups for group normalization
  • dropout is the dropout rate
```python
    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 n_groups: int = 32, dropout: float = 0.1):
        super().__init__()
```

Group normalization and the first convolution layer

```python
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
```

Group normalization and the second convolution layer

```python
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
```

If the number of input channels is not equal to the number of output channels, we have to project the shortcut connection

```python
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        else:
            self.shortcut = nn.Identity()
```

Linear layer for time embeddings

```python
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.time_act = Swish()

        self.dropout = nn.Dropout(dropout)
```

  • x has shape [batch_size, in_channels, height, width]
  • t has shape [batch_size, time_channels]
```python
    def forward(self, x: torch.Tensor, t: torch.Tensor):
```

First convolution layer

```python
        h = self.conv1(self.act1(self.norm1(x)))
```

Add time embeddings

```python
        h += self.time_emb(self.time_act(t))[:, :, None, None]
```

Second convolution layer

```python
        h = self.conv2(self.dropout(self.act2(self.norm2(h))))
```

Add the shortcut connection and return

```python
        return h + self.shortcut(x)
```
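For example (a hypothetical check, not in the original file), the block changes the channel count while preserving the spatial dimensions:

```python
block = ResidualBlock(in_channels=64, out_channels=128, time_channels=256)
x = torch.randn(2, 64, 32, 32)
t = torch.randn(2, 256)
print(block(x, t).shape)  # torch.Size([2, 128, 32, 32])
```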

Attention block

This is similar to transformer multi-head attention.

```python
class AttentionBlock(nn.Module):
```

  • n_channels is the number of channels in the input
  • n_heads is the number of heads in multi-head attention
  • d_k is the number of dimensions in each head
  • n_groups is the number of groups for group normalization
```python
    def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
        super().__init__()
```

Default d_k

```python
        if d_k is None:
            d_k = n_channels
```

Normalization layer

```python
        self.norm = nn.GroupNorm(n_groups, n_channels)
```

Projections for query, key and values

```python
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
```

Linear layer for final transformation

```python
        self.output = nn.Linear(n_heads * d_k, n_channels)
```

Scale for dot-product attention

```python
        self.scale = d_k ** -0.5
```

```python
        self.n_heads = n_heads
        self.d_k = d_k
```

  • x has shape [batch_size, in_channels, height, width]
  • t has shape [batch_size, time_channels]
```python
    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
```

t is not used; it is kept in the arguments so that the function signature matches that of ResidualBlock.

```python
        _ = t
```

Get shape

```python
        batch_size, n_channels, height, width = x.shape
```

Change x to shape [batch_size, seq, n_channels]

```python
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
```

Get query, key, and values (concatenated) and shape it to [batch_size, seq, n_heads, 3 * d_k]

```python
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
```

Split query, key, and values. Each of them will have shape [batch_size, seq, n_heads, d_k]

```python
        q, k, v = torch.chunk(qkv, 3, dim=-1)
```

Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$

```python
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
```

Softmax along the sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$

```python
        attn = attn.softmax(dim=2)
```

Multiply by values

```python
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
```

Reshape to [batch_size, seq, n_heads * d_k]

```python
        res = res.view(batch_size, -1, self.n_heads * self.d_k)
```

Transform to [batch_size, seq, n_channels]

```python
        res = self.output(res)
```

Add skip connection

```python
        res += x
```

Change to shape [batch_size, in_channels, height, width]

```python
        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)
```

```python
        return res
```
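Since attention operates over the flattened spatial positions and projects back to n_channels, both the channel count and the spatial size are preserved (hypothetical usage):

```python
attn = AttentionBlock(n_channels=128)
x = torch.randn(2, 128, 16, 16)
print(attn(x).shape)  # torch.Size([2, 128, 16, 16])
```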

Down block

This combines ResidualBlock and AttentionBlock. These are used in the first half of the U-Net at each resolution.

```python
class DownBlock(nn.Module):
```

```python
    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        self.res = ResidualBlock(in_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()
```

```python
    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x
```

Up block

This combines ResidualBlock and AttentionBlock. These are used in the second half of the U-Net at each resolution.

```python
class UpBlock(nn.Module):
```

```python
    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
```

The input has in_channels + out_channels because we concatenate the output of the same resolution from the first half of the U-Net

```python
        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        if has_attn:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()
```

```python
    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res(x, t)
        x = self.attn(x)
        return x
```

Middle block

It combines a ResidualBlock and an AttentionBlock, followed by another ResidualBlock. This block is applied at the lowest resolution of the U-Net.

```python
class MiddleBlock(nn.Module):
```

```python
    def __init__(self, n_channels: int, time_channels: int):
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)
```

```python
    def forward(self, x: torch.Tensor, t: torch.Tensor):
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x
```

Scale up the feature map by 2×

```python
class Upsample(nn.Module):
```

```python
    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))
```

```python
    def forward(self, x: torch.Tensor, t: torch.Tensor):
```

t is not used; it is kept in the arguments so that the function signature matches that of ResidualBlock.

```python
        _ = t
        return self.conv(x)
```

Scale down the feature map by $\frac{1}{2}\times$

```python
class Downsample(nn.Module):
```

```python
    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))
```

```python
    def forward(self, x: torch.Tensor, t: torch.Tensor):
```

t is not used; it is kept in the arguments so that the function signature matches that of ResidualBlock.

```python
        _ = t
        return self.conv(x)
```
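With kernel size 3, stride 2 and padding 1, the convolution halves each spatial dimension, while the transposed convolution in Upsample (kernel 4, stride 2, padding 1) doubles it. A hypothetical check:

```python
x = torch.randn(2, 64, 16, 16)
print(Downsample(64)(x, None).shape)  # torch.Size([2, 64, 8, 8])
print(Upsample(64)(x, None).shape)    # torch.Size([2, 64, 32, 32])
```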

U-Net

```python
class UNet(nn.Module):
```

  • image_channels is the number of channels in the image. 3 for RGB.
  • n_channels is the number of channels in the initial feature map that we transform the image into
  • ch_mults is the list of channel multipliers at each resolution; moving to resolution i multiplies the running channel count by ch_mults[i]
  • is_attn is a list of booleans that indicate whether to use attention at each resolution
  • n_blocks is the number of DownBlocks (and UpBlocks) at each resolution
```python
    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
```

```python
        super().__init__()
```

Number of resolutions

```python
        n_resolutions = len(ch_mults)
```

Project image into feature map

```python
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))
```

Time embedding layer. Time embedding has n_channels * 4 channels

```python
        self.time_emb = TimeEmbedding(n_channels * 4)
```

First half of U-Net - decreasing resolution

```python
        down = []
```

Number of channels

```python
        out_channels = in_channels = n_channels
```

For each resolution

```python
        for i in range(n_resolutions):
```

Number of output channels at this resolution

```python
            out_channels = in_channels * ch_mults[i]
```

Add n_blocks

```python
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                in_channels = out_channels
```

Down sample at all resolutions except the last

```python
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))
```

Combine the set of modules

```python
        self.down = nn.ModuleList(down)
```
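With the default arguments (n_channels=64, ch_mults=(1, 2, 2, 4), n_blocks=2), the down path therefore yields feature maps with 64 channels at full resolution, 128 at 1/2, 256 at 1/4, and 1024 at 1/8 resolution, since each multiplier is applied to the running channel count.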

Middle block

```python
        self.middle = MiddleBlock(out_channels, n_channels * 4)
```

Second half of U-Net - increasing resolution

```python
        up = []
```

Number of channels

```python
        in_channels = out_channels
```

For each resolution

```python
        for i in reversed(range(n_resolutions)):
```

Add n_blocks at the same resolution

```python
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
```

Final block to reduce the number of channels

```python
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
```

Up sample at all resolutions except last

```python
            if i > 0:
                up.append(Upsample(in_channels))
```

Combine the set of modules

```python
        self.up = nn.ModuleList(up)
```
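With the same defaults, the up path mirrors this: each UpBlock pops one skip connection, and the final block at each resolution divides the running channel count by ch_mults[i], stepping 1024 → 256 → 128 → 64 as the resolution doubles back to full size.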

Final normalization and convolution layer

```python
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))
```

  • x has shape [batch_size, in_channels, height, width]
  • t has shape [batch_size]
```python
    def forward(self, x: torch.Tensor, t: torch.Tensor):
```

Get time-step embeddings

```python
        t = self.time_emb(t)
```

Get image projection

```python
        x = self.image_proj(x)
```

h will store outputs at each resolution for the skip connections

```python
        h = [x]
```

First half of U-Net

```python
        for m in self.down:
            x = m(x, t)
            h.append(x)
```

Middle (bottom)

```python
        x = self.middle(x, t)
```

Second half of U-Net

```python
        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
                # Get the skip connection from the first half of the U-Net
                # and concatenate
                s = h.pop()
                x = torch.cat((x, s), dim=1)

                x = m(x, t)
```

Final normalization and convolution

```python
        return self.final(self.act(self.norm(x)))
```
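Finally, a hypothetical end-to-end check for a batch of 32×32 RGB images (the spatial size should be divisible by $2^{\,len(ch\_mults) - 1} = 8$ here so that the skip connections line up after the three downsampling steps):

```python
model = UNet(image_channels=3, n_channels=64)
x = torch.randn(4, 3, 32, 32)      # x_t: noisy images
t = torch.randint(0, 1_000, (4,))  # time steps
print(model(x, t).shape)           # torch.Size([4, 3, 32, 32]): predicted noise
```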
