Transformer for Stable Diffusion U-Net
This implements the transformer module used in the U-Net that gives $\epsilon_\text{cond}(x_t, c)$.
We have kept the model definition and naming unchanged from CompVis/stable-diffusion so that we can load the checkpoints directly.
```python
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn
```
```python
class SpatialTransformer(nn.Module):
```
- `channels` is the number of channels in the feature map
- `n_heads` is the number of attention heads
- `n_layers` is the number of transformer layers
- `d_cond` is the size of the conditional embedding

```python
    def __init__(self, channels: int, n_heads: int, n_layers: int, d_cond: int):
```
```python
        super().__init__()
```
Initial group normalization
```python
        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)
```
Initial 1×1 convolution
```python
        self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
```
Transformer layers
```python
        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(channels, n_heads, channels // n_heads, d_cond=d_cond)
             for _ in range(n_layers)]
        )
```
Final 1×1 convolution
```python
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)
```
- `x` is the feature map of shape [batch_size, channels, height, width]
- `cond` is the conditional embeddings of shape [batch_size, n_cond, d_cond]

```python
    def forward(self, x: torch.Tensor, cond: torch.Tensor):
```
Get shape [batch_size, channels, height, width]
```python
        b, c, h, w = x.shape
```
For residual connection
```python
        x_in = x
```
Normalize
```python
        x = self.norm(x)
```
Initial 1×1 convolution
```python
        x = self.proj_in(x)
```
Transpose and reshape from [batch_size, channels, height, width] to [batch_size, height * width, channels]
```python
        x = x.permute(0, 2, 3, 1).view(b, h * w, c)
```
Apply the transformer layers
```python
        for block in self.transformer_blocks:
            x = block(x, cond)
```
Reshape and transpose from [batch_size, height * width, channels] to [batch_size, channels, height, width]
```python
        x = x.view(b, h, w, c).permute(0, 3, 1, 2)
```
Final 1×1 convolution
```python
        x = self.proj_out(x)
```
Add residual
```python
        return x + x_in
```
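As a sanity check, here is a minimal usage sketch; the channel, spatial, and conditioning sizes below are illustrative, not taken from an actual U-Net configuration:

```python
# Hypothetical sizes for illustration only.
model = SpatialTransformer(channels=320, n_heads=8, n_layers=1, d_cond=768)
x = torch.randn(2, 320, 64, 64)  # feature map [batch_size, channels, height, width]
cond = torch.randn(2, 77, 768)   # conditional embeddings [batch_size, n_cond, d_cond]
assert model(x, cond).shape == x.shape  # the spatial transformer preserves the feature map shape
```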
```python
class BasicTransformerBlock(nn.Module):
```
- `d_model` is the input embedding size
- `n_heads` is the number of attention heads
- `d_head` is the size of an attention head
- `d_cond` is the size of the conditional embeddings

```python
    def __init__(self, d_model: int, n_heads: int, d_head: int, d_cond: int):
```
```python
        super().__init__()
```
Self-attention layer and pre-norm layer
```python
        self.attn1 = CrossAttention(d_model, d_model, n_heads, d_head)
        self.norm1 = nn.LayerNorm(d_model)
```
Cross-attention layer and pre-norm layer
```python
        self.attn2 = CrossAttention(d_model, d_cond, n_heads, d_head)
        self.norm2 = nn.LayerNorm(d_model)
```
Feed-forward network and pre-norm layer
```python
        self.ff = FeedForward(d_model)
        self.norm3 = nn.LayerNorm(d_model)
```
- `x` are the input embeddings of shape [batch_size, height * width, d_model]
- `cond` is the conditional embeddings of shape [batch_size, n_cond, d_cond]

```python
    def forward(self, x: torch.Tensor, cond: torch.Tensor):
```
Self attention
```python
        x = self.attn1(self.norm1(x)) + x
```
Cross-attention with conditioning
```python
        x = self.attn2(self.norm2(x), cond=cond) + x
```
Feed-forward network
```python
        x = self.ff(self.norm3(x)) + x
```
```python
        return x
```
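Again for illustration (made-up sizes), a single block maps a sequence of embeddings to a sequence of the same shape:

```python
# Hypothetical sizes, chosen as SpatialTransformer does: d_head = channels // n_heads.
block = BasicTransformerBlock(d_model=320, n_heads=8, d_head=40, d_cond=768)
y = block(torch.randn(2, 4096, 320), torch.randn(2, 77, 768))
assert y.shape == (2, 4096, 320)
```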
This falls back to self-attention when conditional embeddings are not specified.
```python
class CrossAttention(nn.Module):
```
```python
    use_flash_attention: bool = False
```
- `d_model` is the input embedding size
- `n_heads` is the number of attention heads
- `d_head` is the size of an attention head
- `d_cond` is the size of the conditional embeddings
- `is_inplace` specifies whether to perform the attention softmax computation in-place to save memory

```python
    def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
```
```python
        super().__init__()

        self.is_inplace = is_inplace
        self.n_heads = n_heads
        self.d_head = d_head
```
Attention scaling factor
```python
        self.scale = d_head ** -0.5
```
Query, key and value mappings
```python
        d_attn = d_head * n_heads
        self.to_q = nn.Linear(d_model, d_attn, bias=False)
        self.to_k = nn.Linear(d_cond, d_attn, bias=False)
        self.to_v = nn.Linear(d_cond, d_attn, bias=False)
```
Final linear layer
```python
        self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))
```
Set up flash attention. Flash attention is only used if it's installed and `CrossAttention.use_flash_attention` is set to `True`.
```python
        try:
```
You can install flash attention by cloning their GitHub repo, https://github.com/HazyResearch/flash-attention, and then running `python setup.py install`.
```python
            from flash_attn.flash_attention import FlashAttention
            self.flash = FlashAttention()
```
Set the scale for scaled dot-product attention.
```python
            self.flash.softmax_scale = self.scale
```
Set to None if it's not installed
```python
        except ImportError:
            self.flash = None
```
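Flash attention is opt-in through the class attribute; a usage sketch:

```python
# Opt in globally before running the model; this only takes effect when
# the flash_attn package is installed (i.e. self.flash is not None).
CrossAttention.use_flash_attention = True
```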
- `x` are the input embeddings of shape [batch_size, height * width, d_model]
- `cond` is the conditional embeddings of shape [batch_size, n_cond, d_cond]

```python
    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
```
If `cond` is `None`, we perform self-attention
```python
        has_cond = cond is not None
        if not has_cond:
            cond = x
```
Get query, key and value vectors
```python
        q = self.to_q(x)
        k = self.to_k(cond)
        v = self.to_v(cond)
```
Use flash attention if it's available and the head size is less than or equal to 128
```python
        if CrossAttention.use_flash_attention and self.flash is not None and not has_cond and self.d_head <= 128:
            return self.flash_attention(q, k, v)
```
Otherwise, fall back to normal attention
```python
        else:
            return self.normal_attention(q, k, v)
```
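To illustrate the fallback (made-up sizes, and with the default `use_flash_attention = False`): for the self-attention case the module must be constructed with `d_cond == d_model`, as `attn1` is in `BasicTransformerBlock` above, since `to_k` and `to_v` project from `d_cond`.

```python
# Hypothetical sizes; d_cond == d_model so the self-attention fallback is valid.
attn = CrossAttention(d_model=320, d_cond=320, n_heads=8, d_head=40)
x = torch.randn(2, 4096, 320)  # [batch_size, height * width, d_model]

out_self = attn(x)          # cond is None, so keys and values come from x itself
out_cond = attn(x, cond=x)  # the same computation, since the fallback sets cond = x
assert torch.allclose(out_self, out_cond)
```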
- `q` are the query vectors before splitting heads, of shape [batch_size, seq, d_attn]
- `k` are the key vectors before splitting heads, of shape [batch_size, seq, d_attn]
- `v` are the value vectors before splitting heads, of shape [batch_size, seq, d_attn]

```python
    def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
```
Get batch size and number of elements along the sequence axis (width * height)
```python
        batch_size, seq_len, _ = q.shape
```
Stack the `q`, `k` and `v` vectors for flash attention, to get a single tensor of shape [batch_size, seq_len, 3, n_heads * d_head]
```python
        qkv = torch.stack((q, k, v), dim=2)
```
Split the heads
```python
        qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
```
Flash attention works for head sizes 32, 64 and 128, so we have to pad the heads to fit this size.
```python
        if self.d_head <= 32:
            pad = 32 - self.d_head
        elif self.d_head <= 64:
            pad = 64 - self.d_head
        elif self.d_head <= 128:
            pad = 128 - self.d_head
        else:
            raise ValueError(f'Head size {self.d_head} too large for Flash Attention')
```
Pad the heads
```python
        if pad:
            qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1)
```
Compute attention $\underset{seq}{softmax}\Big(\frac{Q K^\top}{\sqrt{d_{key}}}\Big)V$. This gives a tensor of shape [batch_size, seq_len, n_heads, d_padded]
```python
        out, _ = self.flash(qkv)
```
Truncate the extra head size
```python
        out = out[:, :, :, :self.d_head]
```
Reshape to [batch_size, seq_len, n_heads * d_head]
```python
        out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
```
Map to [batch_size, height * width, d_model] with a linear layer
```python
        return self.to_out(out)
```
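As a side note, zero-padding the heads does not change the result: padded query and key dimensions contribute nothing to $Q K^\top$, and the padded value dimensions only produce extra output columns, which are truncated above. A small numeric check of the first point (shapes made up):

```python
# Padding q and k with zero feature dimensions leaves the attention scores unchanged.
q, k = torch.randn(1, 4, 8), torch.randn(1, 4, 8)
q_pad = torch.cat((q, q.new_zeros(1, 4, 24)), dim=-1)
k_pad = torch.cat((k, k.new_zeros(1, 4, 24)), dim=-1)
assert torch.allclose(q @ k.transpose(1, 2), q_pad @ k_pad.transpose(1, 2))
```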
- `q` are the query vectors before splitting heads, of shape [batch_size, seq, d_attn]
- `k` are the key vectors before splitting heads, of shape [batch_size, seq, d_attn]
- `v` are the value vectors before splitting heads, of shape [batch_size, seq, d_attn]

```python
    def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
```
Split them to heads of shape [batch_size, seq_len, n_heads, d_head]
```python
        q = q.view(*q.shape[:2], self.n_heads, -1)
        k = k.view(*k.shape[:2], self.n_heads, -1)
        v = v.view(*v.shape[:2], self.n_heads, -1)
```
Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
```python
        attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
```
Compute softmax $\underset{seq}{softmax}\Big(\frac{Q K^\top}{\sqrt{d_{key}}}\Big)$
```python
        if self.is_inplace:
            # Softmax the two halves of the batch separately, so that at most
            # half of the attention matrix needs extra memory at a time
            half = attn.shape[0] // 2
            attn[half:] = attn[half:].softmax(dim=-1)
            attn[:half] = attn[:half].softmax(dim=-1)
        else:
            attn = attn.softmax(dim=-1)
```
Compute attention output $\underset{seq}{softmax}\Big(\frac{Q K^\top}{\sqrt{d_{key}}}\Big)V$
```python
        out = torch.einsum('bhij,bjhd->bihd', attn, v)
```
Reshape to [batch_size, height * width, n_heads * d_head]
```python
        out = out.reshape(*out.shape[:2], -1)
```
Map to [batch_size, height * width, d_model] with a linear layer
```python
        return self.to_out(out)
```
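For reference, the head-splitting and attention above can be cross-checked against PyTorch's built-in `F.scaled_dot_product_attention` (available since PyTorch 2.0). The `reference_attention` helper below is a hypothetical sketch, not part of the model:

```python
def reference_attention(attn: CrossAttention, q, k, v):
    # q, k, v are [batch_size, seq, n_heads * d_head], as in normal_attention
    b, s, _ = q.shape
    # Split heads and move them before the sequence axis: [batch_size, n_heads, seq, d_head]
    q, k, v = (t.view(b, -1, attn.n_heads, attn.d_head).transpose(1, 2) for t in (q, k, v))
    # The default scaling is 1 / sqrt(d_head), matching self.scale
    out = F.scaled_dot_product_attention(q, k, v)
    # Merge the heads back and apply the final linear layer
    return attn.to_out(out.transpose(1, 2).reshape(b, s, -1))
```

Up to floating-point error, this should agree with `normal_attention(q, k, v)`.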
```python
class FeedForward(nn.Module):
```
- `d_model` is the input embedding size
- `d_mult` is the multiplicative factor for the hidden layer size

```python
    def __init__(self, d_model: int, d_mult: int = 4):
```
```python
        super().__init__()
        self.net = nn.Sequential(
            GeGLU(d_model, d_model * d_mult),
            nn.Dropout(0.),
            nn.Linear(d_model * d_mult, d_model)
        )
```
```python
    def forward(self, x: torch.Tensor):
        return self.net(x)
```
$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$
```python
class GeGLU(nn.Module):
```
```python
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
```
Combined linear projections $xW + b$ and $xV + c$
```python
        self.proj = nn.Linear(d_in, d_out * 2)
```
```python
    def forward(self, x: torch.Tensor):
```
Get $xW + b$ and $xV + c$
```python
        x, gate = self.proj(x).chunk(2, dim=-1)
```
$\text{GeGLU}(x) = (xW + b) * \text{GELU}(xV + c)$
```python
        return x * F.gelu(gate)
```
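Finally, a quick shape check of the gated unit and the feed-forward network that wraps it (sizes are illustrative):

```python
# GeGLU projects to twice d_out, then gates one half with the other.
geglu = GeGLU(d_in=320, d_out=1280)
assert geglu(torch.randn(2, 4096, 320)).shape == (2, 4096, 1280)

# FeedForward expands by d_mult = 4 through GeGLU and projects back to d_model.
ff = FeedForward(d_model=320)
assert ff(torch.randn(2, 4096, 320)).shape == (2, 4096, 320)
```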