
Transformer for Stable Diffusion U-Net

This implements the transformer module used in the U-Net that gives $\epsilon_\text{cond}(x_t, c)$.

We have kept the model definition and naming unchanged from CompVis/stable-diffusion so that we can load the checkpoints directly.

from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

Spatial Transformer

class SpatialTransformer(nn.Module):

  • channels is the number of channels in the feature map
  • n_heads is the number of attention heads
  • n_layers is the number of transformer layers
  • d_cond is the size of the conditional embedding

def __init__(self, channels: int, n_heads: int, n_layers: int, d_cond: int):

super().__init__()

Initial group normalization

self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)

Initial 1×1 convolution

self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)

Transformer layers

self.transformer_blocks = nn.ModuleList(
    [BasicTransformerBlock(channels, n_heads, channels // n_heads, d_cond=d_cond) for _ in range(n_layers)]
)

Final 1×1 convolution

self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)

  • x is the feature map of shape [batch_size, channels, height, width]
  • cond is the conditional embeddings of shape [batch_size, n_cond, d_cond]

def forward(self, x: torch.Tensor, cond: torch.Tensor):

Get shape [batch_size, channels, height, width]

b, c, h, w = x.shape

For residual connection

x_in = x

Normalize

x = self.norm(x)

Initial 1×1 convolution

x = self.proj_in(x)

Transpose and reshape from [batch_size, channels, height, width] to [batch_size, height * width, channels]

x = x.permute(0, 2, 3, 1).view(b, h * w, c)

Apply the transformer layers

for block in self.transformer_blocks:
    x = block(x, cond)

Reshape and transpose from [batch_size, height * width, channels] to [batch_size, channels, height, width]

x = x.view(b, h, w, c).permute(0, 3, 1, 2)

Final 1×1 convolution

x = self.proj_out(x)

Add residual

return x + x_in
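As a rough usage sketch (not part of the original file), the module maps a feature map plus conditioning embeddings back to a feature map of the same shape; all sizes below are illustrative assumptions (a 32×32 map with 320 channels and CLIP-like conditioning of width 768):

model = SpatialTransformer(channels=320, n_heads=8, n_layers=1, d_cond=768)
x = torch.randn(2, 320, 32, 32)   # [batch_size, channels, height, width]
cond = torch.randn(2, 77, 768)    # [batch_size, n_cond, d_cond]
out = model(x, cond)
assert out.shape == x.shape       # spatial shape is preserved by the residual design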

Transformer Layer

class BasicTransformerBlock(nn.Module):

  • d_model is the input embedding size
  • n_heads is the number of attention heads
  • d_head is the size of an attention head
  • d_cond is the size of the conditional embeddings

def __init__(self, d_model: int, n_heads: int, d_head: int, d_cond: int):

super().__init__()

Self-attention layer and pre-norm layer

self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
self.norm1 = nn.LayerNorm(d_model)

Cross-attention layer and pre-norm layer

self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)
self.norm2 = nn.LayerNorm(d_model)

Feed-forward network and pre-norm layer

self.ff = FeedForward(d_model)
self.norm3 = nn.LayerNorm(d_model)

  • x are the input embeddings of shape [batch_size, height * width, d_model]
  • cond is the conditional embeddings of shape [batch_size, n_cond, d_cond]

def forward(self, x: torch.Tensor, cond: torch.Tensor):

Self-attention

x = self.attn1(self.norm1(x)) + x

Cross-attention with conditioning

x = self.attn2(self.norm2(x), cond=cond) + x

Feed-forward network

x = self.ff(self.norm3(x)) + x

return x
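A hedged sketch (not in the original source) of a single block's flattened interface; the sizes are assumptions chosen so that n_heads * d_head equals d_model, which keeps the residual connections consistent:

block = BasicTransformerBlock(d_model=320, n_heads=8, d_head=40, d_cond=768)
x = torch.randn(2, 32 * 32, 320)   # [batch_size, height * width, d_model]
cond = torch.randn(2, 77, 768)     # [batch_size, n_cond, d_cond]
assert block(x, cond).shape == x.shape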

Cross Attention Layer

This falls back to self-attention when conditional embeddings are not specified.

class CrossAttention(nn.Module):

use_flash_attention: bool = False

  • d_model is the input embedding size
  • n_heads is the number of attention heads
  • d_head is the size of an attention head
  • d_cond is the size of the conditional embeddings
  • is_inplace specifies whether to perform the attention softmax computation in-place to save memory

def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):

super().__init__()

self.is_inplace = is_inplace
self.n_heads = n_heads
self.d_head = d_head

Attention scaling factor

self.scale = d_head ** -0.5

Query, key and value mappings

d_attn = d_head * n_heads
self.to_q = nn.Linear(d_model, d_attn, bias=False)
self.to_k = nn.Linear(d_cond, d_attn, bias=False)
self.to_v = nn.Linear(d_cond, d_attn, bias=False)

Final linear layer

self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))

Set up flash attention. Flash attention is only used if it's installed and CrossAttention.use_flash_attention is set to True.

try:

You can install flash attention by cloning their Github repo, https://github.com/HazyResearch/flash-attention, and then running python setup.py install

from flash_attn.flash_attention import FlashAttention
self.flash = FlashAttention()

Set the scale for scaled dot-product attention.

self.flash.softmax_scale = self.scale

Set to None if it's not installed

except ImportError:
    self.flash = None

  • x are the input embeddings of shape [batch_size, height * width, d_model]
  • cond is the conditional embeddings of shape [batch_size, n_cond, d_cond]

def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):

If cond is None, we perform self-attention

has_cond = cond is not None
if not has_cond:
    cond = x

Get query, key and value vectors

q = self.to_q(x)
k = self.to_k(cond)
v = self.to_v(cond)

Use flash attention if it's available and the head size is less than or equal to 128

if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
    return self.flash_attention(q, k, v)

Otherwise, fall back to normal attention

else:
    return self.normal_attention(q, k, v)
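The fallback can be seen in a small sketch (assumed sizes, not from the source): when cond is omitted the keys and values come from x itself, so d_cond must equal d_model for self-attention to work:

attn = CrossAttention(d_model=320, d_cond=320, n_heads=8, d_head=40)
x = torch.randn(2, 1024, 320)
self_out = attn(x)                                  # cond is None, so this is self-attention
cross_out = attn(x, cond=torch.randn(2, 77, 320))   # cross-attention over 77 conditioning tokens
assert self_out.shape == cross_out.shape == x.shape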

Flash Attention

  • q are the query vectors before splitting heads, of shape [batch_size, seq, d_attn]
  • k are the key vectors before splitting heads, of shape [batch_size, seq, d_attn]
  • v are the value vectors before splitting heads, of shape [batch_size, seq, d_attn]

def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

Get batch size and number of elements along sequence axis (width * height)

batch_size, seq_len, _ = q.shape

Stack q, k, v vectors for flash attention, to get a single tensor of shape [batch_size, seq_len, 3, n_heads * d_head]

qkv = torch.stack((q, k, v), dim=2)

Split the heads

qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)

Flash attention works for head sizes 32, 64 and 128, so we have to pad the heads to fit this size.

if self.d_head <= 32:
    pad = 32 - self.d_head
elif self.d_head <= 64:
    pad = 64 - self.d_head
elif self.d_head <= 128:
    pad = 128 - self.d_head
else:
    raise ValueError(f'Head size ${self.d_head} too large for Flash Attention')

Pad the heads

if pad:
    qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1)

Compute attention $\underset{seq}{softmax}\Big(\frac{Q K^\top}{\sqrt{d_{key}}}\Big) V$. This gives a tensor of shape [batch_size, seq_len, n_heads, d_padded].

out, _ = self.flash(qkv)

Truncate the extra head size

out = out[:, :, :, :self.d_head]

Reshape to [batch_size, seq_len, n_heads * d_head]

out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)

Map to [batch_size, height * width, d_model] with a linear layer

return self.to_out(out)
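The head-padding step in this method amounts to rounding d_head up to the next size flash attention supports. A tiny illustrative check (not part of the source):

for d_head in (32, 40, 64, 80, 128):
    padded = next(s for s in (32, 64, 128) if d_head <= s)
    print(f'{d_head} -> padded to {padded}')   # 32->32, 40->64, 64->64, 80->128, 128->128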

Normal Attention

  • q are the query vectors before splitting heads, of shape [batch_size, seq, d_attn]
  • k are the key vectors before splitting heads, of shape [batch_size, seq, d_attn]
  • v are the value vectors before splitting heads, of shape [batch_size, seq, d_attn]

def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):

Split them into heads of shape [batch_size, seq_len, n_heads, d_head]

q = q.view(*q.shape[:2], self.n_heads, -1)
k = k.view(*k.shape[:2], self.n_heads, -1)
v = v.view(*v.shape[:2], self.n_heads, -1)

Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$

attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale

Compute softmax $\underset{seq}{softmax}\Big(\frac{Q K^\top}{\sqrt{d_{key}}}\Big)$

if self.is_inplace:
    half = attn.shape[0] // 2
    attn[half:] = attn[half:].softmax(dim=-1)
    attn[:half] = attn[:half].softmax(dim=-1)
else:
    attn = attn.softmax(dim=-1)

Compute attention output $\underset{seq}{softmax}\Big(\frac{Q K^\top}{\sqrt{d_{key}}}\Big) V$

out = torch.einsum('bhij,bjhd->bihd', attn, v)

Reshape to [batch_size, height * width, n_heads * d_head]

out = out.reshape(*out.shape[:2], -1)

Map to [batch_size, height * width, d_model] with a linear layer

return self.to_out(out)
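As a quick sanity check (not from the source), the einsum used above is equivalent to moving the head axis forward and taking a plain matrix product:

b, i, j, h, d = 2, 16, 5, 4, 8
q = torch.randn(b, i, h, d)
k = torch.randn(b, j, h, d)
attn_einsum = torch.einsum('bihd,bjhd->bhij', q, k)
attn_matmul = q.permute(0, 2, 1, 3) @ k.permute(0, 2, 3, 1)   # [b, h, i, d] @ [b, h, d, j]
assert torch.allclose(attn_einsum, attn_matmul, atol=1e-5)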

Feed-Forward Network

class FeedForward(nn.Module):

  • d_model is the input embedding size
  • d_mult is the multiplicative factor for the hidden layer size

def __init__(self, d_model: int, d_mult: int = 4):

super().__init__()
self.net = nn.Sequential(
    GeGLU(d_model, d_model * d_mult),
    nn.Dropout(0.),
    nn.Linear(d_model * d_mult, d_model)
)

def forward(self, x: torch.Tensor):
    return self.net(x)

GeGLU Activation

$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$

class GeGLU(nn.Module):

def __init__(self, d_in: int, d_out: int):
    super().__init__()

Combined linear projections $xW + b$ and $xV + c$

self.proj = nn.Linear(d_in, d_out * 2)

def forward(self, x: torch.Tensor):

Get $xW + b$ and $xV + c$

x, gate = self.proj(x).chunk(2, dim=-1)

$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$

return x * F.gelu(gate)
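A small shape check (assumed sizes, not part of the original file): the GeGLU projection produces twice d_out so it can be split into a value and a gate, and FeedForward keeps d_model while using a hidden width of d_model * d_mult:

geglu = GeGLU(d_in=320, d_out=1280)
assert geglu(torch.randn(2, 10, 320)).shape == (2, 10, 1280)

ff = FeedForward(d_model=320)   # hidden width is 320 * 4 = 1280
assert ff(torch.randn(2, 10, 320)).shape == (2, 10, 320)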
