Flash Attention

[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/flash/__init__.py)


Flash attention speeds up the transformer attention mechanism by reducing the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM.

It was introduced in the paper FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness and further optimized in FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. The official CUDA implementation can be found at Dao-AILab/flash-attention.

Our implementation is based on Triton's example implementation.


You can run test.py to check the correctness and measure the performance of this implementation.

Forward pass

Here's the attention forward pass. The formulas represent a single attention head. $Q_i$ is the query vector (row vector) at position $i$, and $K_j$ and $V_j$ are the key and value row vectors at position $j$. $O_i$ is the output vector at position $i$, and $\sigma$ is the softmax scale.

$$\begin{aligned}
S_{ij} &= \sigma \, Q_i K_j^T \\
L_i &= \sum_j e^{S_{ij}} \\
P_{ij} &= \frac{e^{S_{ij}}}{L_i} \\
O_i &= \sum_j P_{ij} V_j = \frac{1}{L_i} \sum_j e^{S_{ij}} V_j
\end{aligned}$$

$S_{ij}$ is the attention score matrix before softmax, $L_i$ is the softmax denominator, and $P_{ij}$ is the attention matrix after softmax.
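These formulas translate directly into a reference implementation. Here is a minimal single-head sketch in PyTorch (illustrative, not the Triton kernel below):

```python
import torch

def naive_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sm_scale: float) -> torch.Tensor:
    """Single-head attention computed directly from the formulas above.

    q: [q_seq_len, d_head]; k, v: [kv_seq_len, d_head]
    """
    s = sm_scale * (q @ k.T)      # S_ij = sigma Q_i K_j^T
    p = torch.softmax(s, dim=-1)  # P_ij = e^{S_ij} / L_i
    return p @ v                  # O_i = sum_j P_ij V_j
```

This materializes the full $S$ and $P$ matrices, which is exactly what flash attention avoids.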

Flash Attention Optimization

Instead of doing the full softmax, you can compute $O_i$ by accumulating the sum of exponents $l_i$ and the unnormalized output $\tilde{O}_i$ while iterating over keys:

$$\begin{aligned}
S_{ij} &= \sigma \, Q_i K_j^T \\
l_i &\leftarrow l_i + e^{S_{ij}} \\
\tilde{O}_i &\leftarrow \tilde{O}_i + e^{S_{ij}} V_j
\end{aligned}$$

Finally, you can compute

$$O_i = \frac{\tilde{O}_i}{l_i}$$

To make this numerically stable, flash attention subtracts the running maximum of $S_{ij}$ before exponentiating.

So it maintains the following while iterating over keys:

  • $m_i$, the maximum of $S_{ij}$
  • $l_i$, the sum of exponents $\sum_j e^{S_{ij} - m_i}$, and
  • $\tilde{O}_i$, the unnormalized output

For each block of keys $j_1 \dots j_2$ it updates them:

$$\begin{aligned}
m_i^{\text{new}} &= \max\Big(m_i, \max_{j=j_1}^{j_2} S_{ij}\Big) \\
\tilde{P}_{ij} &= \exp\big(S_{ij} - m_i^{\text{new}}\big) \\
l_i &\leftarrow e^{m_i - m_i^{\text{new}}} l_i + \sum_{j=j_1}^{j_2} \tilde{P}_{ij} \\
\tilde{O}_i &\leftarrow e^{m_i - m_i^{\text{new}}} \tilde{O}_i + \sum_{j=j_1}^{j_2} \tilde{P}_{ij} V_j \\
m_i &\leftarrow m_i^{\text{new}}
\end{aligned}$$

Then finally,

$$O_i = \frac{\tilde{O}_i}{l_i}$$

This reduces memory usage since we don't have to compute the full $S_{ij}$ or $P_{ij}$ matrices. It is also faster since we don't have to read and write these large matrices; instead, only blocks of $K$ and $V$ are loaded as we iterate over them.
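Here is a minimal sketch of this streaming computation for a single query row in plain PyTorch (names are illustrative); it maintains $m_i$, $l_i$ and the unnormalized output while iterating over key blocks:

```python
import torch

def online_attention_row(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_k: int = 4) -> torch.Tensor:
    """q: [d_head]; k, v: [kv_seq_len, d_head]. Assumes sm_scale = 1 for brevity."""
    m = torch.tensor(-float('inf'))  # running max m_i
    l = torch.tensor(0.0)            # running sum of exponents l_i
    o = torch.zeros(v.shape[-1])     # unnormalized output O~_i
    for j in range(0, k.shape[0], block_k):
        s = k[j:j + block_k] @ q                 # S_ij for this block of keys
        m_new = torch.maximum(m, s.max())        # m_i^new
        p = torch.exp(s - m_new)                 # P~_ij
        scale = torch.exp(m - m_new)             # e^{m_i - m_i^new}
        l = l * scale + p.sum()
        o = o * scale + p @ v[j:j + block_k]
        m = m_new
    return o / l                                 # O_i = O~_i / l_i
```

The result matches `torch.softmax(k @ q, dim=-1) @ v` without ever holding all the scores at once.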

Backward pass

Here's the standard backward pass. $dO_i$ is the gradient vector on the output $O_i$.

$$\begin{aligned}
dV_j &= \sum_i P_{ij} dO_i \\
dP_{ij} &= dO_i V_j^T \\
dS_{ij} &= \sum_k P_{ik} (\delta_{jk} - P_{ij}) dP_{ik} = P_{ij} dP_{ij} - P_{ij} \sum_k P_{ik} dP_{ik} \\
dQ_i &= \sigma \sum_j dS_{ij} K_j \\
dK_j &= \sigma \sum_i dS_{ij} Q_i
\end{aligned}$$

where $\delta_{jk}$ is $1$ when $j = k$ and $0$ otherwise.

The flash attention paper introduces $D_i$ to simplify the $dS$ computation.

$$D_i = \sum_k P_{ik} dP_{ik} = \sum_k P_{ik} \, dO_i V_k^T = dO_i \sum_k P_{ik} V_k^T = dO_i O_i^T$$

Then,

$$dS_{ij} = P_{ij} dP_{ij} - D_i P_{ij}$$

Flash attention saves $L_i$ from the forward pass since it doesn't take much memory, so the backward pass doesn't have to recompute $l_i$ or $m_i$.

It first computes $D_i$. Then it iterates over the queries to compute (accumulate) $dK_j$ and $dV_j$, and finally iterates over the keys to compute (accumulate) $dQ_i$.
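As a sanity check of these formulas, here is a minimal single-head backward sketch in PyTorch that recomputes $P$ from the saved log-sum-exp (illustrative names, not the kernel API):

```python
import torch

def attention_backward_reference(q, k, v, o, do, lse, sm_scale):
    """q: [q_seq_len, d_head]; k, v: [kv_seq_len, d_head]; lse[i] = ln L_i saved from the forward pass."""
    s = sm_scale * (q @ k.T)                 # S_ij
    p = torch.exp(s - lse[:, None])          # P_ij recomputed from the saved LSE
    d = (do * o).sum(-1, keepdim=True)       # D_i = dO_i O_i^T
    ds = p * ((do @ v.T) - d)                # dS_ij = P_ij (dP_ij - D_i)
    dq = sm_scale * (ds @ k)                 # dQ_i = sigma sum_j dS_ij K_j
    dk = sm_scale * (ds.T @ q)               # dK_j = sigma sum_i dS_ij Q_i
    dv = p.T @ do                            # dV_j = sum_i P_ij dO_i
    return dq, dk, dv
```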

In both the forward and backward passes we calculate logarithms and exponentials in base $2$ instead of base $e$ for performance.
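GPUs have fast `exp2`/`log2` hardware instructions, and since $e^x = 2^{x \log_2 e}$ the two bases are interchangeable up to a constant factor. A quick check of the identities and constants used in the kernels below:

```python
import math
import torch

LOG2_E = 1.4426950408889634   # log2(e); the code folds this into the softmax scale
LN_2 = 0.6931471824645996     # ln(2) in float32; used to undo the base change for dQ

x = torch.randn(8)
assert torch.allclose(torch.exp(x), torch.exp2(x * LOG2_E))
assert math.isclose(LOG2_E * LN_2, 1.0, rel_tol=1e-7)
```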

```python
from typing import Any, Tuple

import torch
import triton
import triton.language as tl

HI_PRES_TL: tl.constexpr = tl.float32
HI_PRES_TORCH: torch.dtype = torch.float32
```

```python
class AttentionFunc(torch.autograd.Function):
```

Forward pass

Group query attention forward pass. Returns the output in shape `[batch_size, n_heads, q_seq_len, d_head]`.

  • `ctx` is the context for torch autograd
  • `q` has shape `[batch_size, n_heads, q_seq_len, d_head]`
  • `k` has shape `[batch_size, k_heads, kv_seq_len, d_head]`
  • `v` has shape `[batch_size, k_heads, kv_seq_len, d_head]`
  • `causal` whether to apply causal attention mask
  • `sm_scale` softmax scale factor $\sigma$

```python
    @staticmethod
    def forward(ctx: Any,
                q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                causal: bool, sm_scale: float) -> torch.Tensor:
```
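For reference, grouped-query attention with `k_heads` key/value heads is equivalent to standard multi-head attention with each key/value head shared across `n_heads // k_heads` query heads. A minimal sketch of that equivalence (illustrative, separate from the implementation):

```python
import torch

def gqa_reference(q, k, v, sm_scale):
    """q: [batch, n_heads, q_seq_len, d_head]; k, v: [batch, k_heads, kv_seq_len, d_head]"""
    n_groups = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(n_groups, dim=1)   # share each K head across its group
    v = v.repeat_interleave(n_groups, dim=1)   # share each V head across its group
    s = sm_scale * torch.einsum('bhid,bhjd->bhij', q, k)
    return torch.einsum('bhij,bhjd->bhid', torch.softmax(s, dim=-1), v)
```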

```python
        batch_size, n_heads, q_seq_len, d_head = q.shape
        _, k_heads, kv_seq_len, _ = k.shape
        assert n_heads % k_heads == 0
        n_groups = n_heads // k_heads
```

Shape constraints

```python
        assert d_head == k.shape[-1] == v.shape[-1]
        assert d_head in {16, 32, 64, 128, 256}
```

Reshape the tensors, combining the heads with the batch dimension

```python
        q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
        k = k.view(batch_size * k_heads, kv_seq_len, d_head)
        v = v.view(batch_size * k_heads, kv_seq_len, d_head)
```

Make sure the tensors are contiguous and the strides are the same

```python
        assert q.is_contiguous()
        assert k.is_contiguous()
        assert v.is_contiguous()
        assert k.stride() == v.stride()
```

Tensor for the output

```python
        o = torch.empty_like(q)
```

Tensor for the log of the sum of exponentials $\log_2 L_i = \log_2 \sum_j e^{S_{ij}}$

```python
        lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
```
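As an aside, saving $\log_2 L_i$ is enough to reconstruct the softmax later: $P_{ij} = 2^{(\log_2 e) S_{ij} - \log_2 L_i}$. A quick standalone check of that identity (illustrative values):

```python
import torch

s = torch.randn(4, 6)                                   # S_ij
lse2 = torch.logsumexp(s, dim=-1) * 1.4426950408889634  # log2 L_i = (log2 e) ln L_i
p = torch.exp2(s * 1.4426950408889634 - lse2[:, None])
assert torch.allclose(p, torch.softmax(s, dim=-1), atol=1e-6)
```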

The forward computation will be parallelized along the batch dimension and the queries in blocks of size `BLOCK_Q`

```python
        grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
        _attn_fwd[grid](
            q, k, v, sm_scale * 1.4426950408889634, lse, o,
            n_groups=n_groups,
            q_seq_len=q_seq_len,
            kv_seq_len=kv_seq_len,
            d_head=d_head,
            is_causal=causal,
        )
```

Save the reshaped inputs and outputs for the backward pass

```python
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.sm_scale = sm_scale
        ctx.n_groups = n_groups
        ctx.causal = causal
```

Return the output in shape `[batch_size, n_heads, q_seq_len, d_head]`

```python
        return o.view(batch_size, n_heads, q_seq_len, d_head)
```

Backward pass

The backward pass computes the gradients of the input tensors.

  • `ctx` is the context for torch autograd
  • `do` is the gradient tensor of the attention output with shape `[batch_size, n_heads, q_seq_len, d_head]`

```python
    @staticmethod
    def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
```

Get saved tensors and attributes

```python
        n_groups = ctx.n_groups
        sm_scale = ctx.sm_scale
        causal = ctx.causal
        q, k, v, o, lse = ctx.saved_tensors
```

Get shapes

```python
        batch_size, n_heads, q_seq_len, d_head = do.shape
        _, kv_seq_len, _ = k.shape
        k_heads = n_heads // n_groups
```

Combine the heads with the batch dimension of the output gradients tensor

```python
        do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
```

Make sure it's contiguous and the strides are the same

```python
        assert do.is_contiguous()
        assert k.stride() == v.stride()
        assert q.stride() == o.stride() == do.stride()
```

Create tensors for input gradients

```python
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
```

Precompute $\sigma (\log_2 e) K_j$

```python
        k_scaled = k * (sm_scale * 1.4426950408889634)
```

$D_i = P_{i:}^T dP_{i:} = dO_i^T O_i$

```python
        pdp = torch.empty_like(lse)
```

We use a fixed `BLOCK_Q` for the backward pass on $D$.

Compute $D_i$. This is parallelized along the batch and queries in blocks of size `BLOCK_Q`.

```python
        BLOCK_Q = 16
        pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
        _attn_bwd_d[pre_grid](
            o, do,
            pdp,
            BLOCK_Q=16,
            d_head=d_head,
            q_seq_len=q_seq_len,
            n_groups=n_groups,
            num_stages=1,
        )
```

Compute $dK$ and $dV$. This is parallelized along the batch and keys in blocks of size `BLOCK_K`.

```python
        grid = lambda meta: (triton.cdiv(kv_seq_len, meta['BLOCK_K']), batch_size * k_heads)
        _attn_bwd_dkdv[grid](
            q, k_scaled, v, sm_scale, do, dk, dv,
            lse, pdp,
            q_seq_len, kv_seq_len, n_groups, d_head,
            is_causal=causal,
        )
```

Compute $dQ$. This is parallelized along the batch and queries in blocks of size `BLOCK_Q`.

```python
        grid = lambda meta: (triton.cdiv(q_seq_len, meta['BLOCK_Q']), batch_size * k_heads * n_groups)
        _attn_bwd_dq[grid](
            q, k_scaled, v, do,
            dq,
            lse, pdp,
            q_seq_len, kv_seq_len, n_groups, d_head,
            is_causal=causal,
        )
```

Split the combined batch and heads

```python
        dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
        dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
        dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)
```

```python
        return dq, dk, dv, None, None


attention = AttentionFunc.apply
```
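A hedged usage sketch (it needs a CUDA device with Triton installed; `test.py` in the repo does this more thoroughly), comparing against PyTorch's `scaled_dot_product_attention`:

```python
import torch
import torch.nn.functional as F

batch, n_heads, k_heads, seq_len, d_head = 2, 8, 2, 128, 64
q = torch.randn(batch, n_heads, seq_len, d_head, device='cuda', dtype=torch.float16)
k = torch.randn(batch, k_heads, seq_len, d_head, device='cuda', dtype=torch.float16)
v = torch.randn(batch, k_heads, seq_len, d_head, device='cuda', dtype=torch.float16)
sm_scale = d_head ** -0.5

out = attention(q, k, v, True, sm_scale)

# Reference: repeat K/V heads per group and use PyTorch SDPA
ref = F.scaled_dot_product_attention(q,
                                     k.repeat_interleave(n_heads // k_heads, dim=1),
                                     v.repeat_interleave(n_heads // k_heads, dim=1),
                                     is_causal=True, scale=sm_scale)
assert torch.allclose(out, ref, atol=1e-2)
```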

Configs for auto-tuning

```python
def _get_autotune_configs(inner_loop: str) -> list:
```

```python
    configs = []
```

Possible options for `BLOCK_Q`

```python
    for bq in [64, 128, 256]:
```

Possible options for `BLOCK_K`

```python
        for bk in [64, 128, 256]:
```

If the inner loop is along keys, `BLOCK_Q` must be a multiple of `BLOCK_K` for causal masking

```python
            if inner_loop == 'key' and bq % bk != 0:
                continue
```

Similarly, `BLOCK_K` must be a multiple of `BLOCK_Q` when the inner loop is along queries

```python
            if inner_loop == 'query' and bk % bq != 0:
                continue
```

Number of stages and warps

```python
            for s in [2, 3, 4]:
                for w in [4, 8]:
                    if bq * bk < 128 * 128 and w == 8:
                        continue

                    configs.append(triton.Config({'BLOCK_Q': bq, 'BLOCK_K': bk}, num_stages=s, num_warps=w))
```

Use `return configs` to autotune over all combinations. Trying all combinations is slow for testing, so only the first configuration is returned here.

```python
    return configs[:1]
```

Triton kernel for the flash attention forward pass

  • `t_q` queries $Q_i$
  • `t_k` keys $K_j$
  • `t_v` values $V_j$
  • `sm_scale_log2e` $\sigma \log_2 e$, the softmax scale multiplied by $\log_2 e$
  • `t_lse` $\log_2 \sum_j e^{S_{ij}}$ (output)
  • `t_o` $O_i$ output
  • `n_groups` number of groups in GQA
  • `q_seq_len` query sequence length
  • `kv_seq_len` key/value sequence length
  • `d_head` number of dimensions in a head
  • `BLOCK_Q` block size for query sequence length
  • `BLOCK_K` block size for key sequence length
  • `is_causal` whether to apply a causal attention mask

Strides `z`, `h`, `m` and `d` denote the strides of the corresponding dimensions (`batch_size`, `n_heads`, `q_seq_len`, `d_head`) of the query. Stride `n` denotes the stride along `kv_seq_len` of the key.

```python
@triton.autotune(_get_autotune_configs(inner_loop='key'),
                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
@triton.jit
def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
              n_groups: tl.constexpr,
              q_seq_len: tl.constexpr,
              kv_seq_len: tl.constexpr,
              d_head: tl.constexpr,
              is_causal: tl.constexpr,
              BLOCK_Q: tl.constexpr,
              BLOCK_K: tl.constexpr,
              ):
```

We compute the attention output $O_i$ for queries `i` ... `i + BLOCK_Q` in batch/head combination `z` and group `g`.

```python
    i = tl.program_id(0)
    z = tl.program_id(1) // n_groups
    g = tl.program_id(1) % n_groups
```

Create block pointers

```python
    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                            (q_seq_len, d_head),
                            (d_head, 1),
                            (i * BLOCK_Q, 0),
                            (BLOCK_Q, d_head),
                            (1, 0))
    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
                            (kv_seq_len, d_head),
                            (d_head, 1),
                            (0, 0),
                            (BLOCK_K, d_head),
                            (1, 0))
    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
                             (d_head, kv_seq_len),
                             (1, d_head),
                             (0, 0),
                             (d_head, BLOCK_K),
                             (0, 1))
    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                            (q_seq_len, d_head),
                            (d_head, 1),
                            (i * BLOCK_Q, 0),
                            (BLOCK_Q, d_head),
                            (1, 0))
    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
                              (q_seq_len,),
                              (1,),
                              (i * BLOCK_Q,),
                              (BLOCK_Q,),
                              (0,))
```

Initialize offsets

```python
    offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
    offs_j = tl.arange(0, BLOCK_K)
```

Mask for Q for the last block

```python
    i_mask = offs_i < q_seq_len
```

Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$. So in the first update, the effect of the initial $l_i$ is $e^{m_i - m_i^{\text{new}}} l_i = 0$.

`b_m` will store $m_i \log_2 e$.

```python
    b_m = tl.where(i_mask, -float("inf"), 0.0)
    b_l = tl.where(i_mask, 1.0, 0.0)
```

$\tilde{O}_i$, the unnormalized output accumulator

```python
    b_o = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
```

Load $Q_i$ outside the loop since it will be reused throughout the loop over $K_j$.

```python
    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")

    if is_causal:
```

Inner loop up to the diagonal block

```python
        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
                                        p_kT, p_v,
                                        sm_scale_log2e,
                                        BLOCK_Q, d_head, BLOCK_K,
                                        offs_i, offs_j,
                                        j=tl.full([], 0, tl.int32),  # type: ignore
                                        steps=(i * BLOCK_Q) // BLOCK_K,
                                        MASK=False,
                                        q_seq_len=q_seq_len,
                                        kv_seq_len=kv_seq_len
                                        )
```

Diagonal block, with masking within it

```python
        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
                                        sm_scale_log2e,
                                        BLOCK_Q, d_head, BLOCK_K,
                                        offs_i, offs_j,
                                        j=i * BLOCK_Q,
                                        steps=BLOCK_Q // BLOCK_K,
                                        MASK=True,
                                        q_seq_len=q_seq_len,
                                        kv_seq_len=kv_seq_len
                                        )
    else:
```

Iterate through all $K_j$

```python
        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
                                        sm_scale_log2e,
                                        BLOCK_Q, d_head, BLOCK_K,
                                        offs_i, offs_j,
                                        j=tl.full([], 0, tl.int32),  # type: ignore
                                        steps=tl.cdiv(kv_seq_len, BLOCK_K),
                                        MASK=False,
                                        q_seq_len=q_seq_len,
                                        kv_seq_len=kv_seq_len
                                        )
```

Store the LSE $\log_2 L_i = \log_2 (l_i e^{m_i}) = \log_2 l_i + m_i \log_2 e$

```python
    tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))
```

Store $O_i = \frac{\tilde{O}_i}{l_i}$

```python
    tl.store(p_o, (b_o / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
```

Inner loop to calculate $O_i$

This iterates through keys and values starting from `j` for `steps` number of steps. In each step it processes `BLOCK_K` entries of keys/values.

```python
@triton.jit
def _attn_fwd_inner(b_o, b_l, b_m, b_q,
                    p_kT, p_v,
                    sm_scale_log2e,
                    BLOCK_Q: tl.constexpr,
                    d_head: tl.constexpr,
                    BLOCK_K: tl.constexpr,
                    offs_i, offs_j,
                    j,
                    steps,
                    MASK: tl.constexpr,
                    q_seq_len: tl.constexpr,
                    kv_seq_len: tl.constexpr
                    ):
```

```python
    tl.static_assert(BLOCK_Q % BLOCK_K == 0)
```

Move the $K_j$ and $V_j$ pointers

```python
    p_kT = tl.advance(p_kT, (0, j))
    p_v = tl.advance(p_v, (j, 0))
```

Iterate over $K$, $V$ and update $\tilde{O}_i$ and $l_i$

```python
    for _ in range(steps):
```

Load $K_j^T$

```python
        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
```

Compute $(\log_2 e) S_{ij} = \sigma (\log_2 e) Q_i K_j^T$

```python
        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
        b_s = b_s * sm_scale_log2e
```

Apply causal mask

```python
        if MASK:
            causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
            b_s = tl.where(causal_mask, b_s, -float("inf"))
```

Mask out if the block is beyond the end of $K_j$

```python
        j_mask = (j + offs_j) < kv_seq_len
        b_s = tl.where(j_mask[None, :], b_s, -float("inf"))
```

$(\log_2 e) m_i^{\text{new}} = \max\big((\log_2 e) m_i, \max_{j=j_1}^{j_2} (\log_2 e) S_{ij}\big)$

```python
        b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
```

$\tilde{P}_{ij} = e^{S_{ij} - m_i^{\text{new}}} = 2^{(\log_2 e) S_{ij} - (\log_2 e) m_i^{\text{new}}}$

```python
        b_p = tl.math.exp2(b_s - b_m_new[:, None])
```

$\sum_{j=j_1}^{j_2} \tilde{P}_{ij}$

```python
        b_l_new = tl.sum(b_p, -1)
```

$e^{m_i - m_i^{\text{new}}}$

```python
        b_m_m_new = tl.math.exp2(b_m - b_m_new)
```

$l_i \leftarrow e^{m_i - m_i^{\text{new}}} l_i + \sum_{j=j_1}^{j_2} \tilde{P}_{ij}$

```python
        b_l = b_l * b_m_m_new + b_l_new
```

$\tilde{O}_i \leftarrow e^{m_i - m_i^{\text{new}}} \tilde{O}_i + \tilde{P}_{ij} V_j$

```python
        b_o = b_o * b_m_m_new[:, None]
        b_p = b_p.to(b_q.dtype)  # cast to the input dtype for the tensor-core dot product
        b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
        b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
```

$(\log_2 e) m_i \leftarrow (\log_2 e) m_i^{\text{new}}$

```python
        b_m = b_m_new
```

Move the pointers

```python
        j += BLOCK_K
        p_v = tl.advance(p_v, (BLOCK_K, 0))
        p_kT = tl.advance(p_kT, (0, BLOCK_K))

    tl.static_assert(b_o.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")

    return b_o, b_l, b_m
```

Triton kernel to compute $D_i$

```python
@triton.jit
def _attn_bwd_d(t_o, t_do,
                t_pdp,
                BLOCK_Q: tl.constexpr, d_head: tl.constexpr,
                q_seq_len: tl.constexpr,
                n_groups: tl.constexpr,
                ):
```

```python
    i = tl.program_id(0) * BLOCK_Q
    z = tl.program_id(1)
```

Create block pointers

```python
    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
                            (n_groups, q_seq_len, d_head),
                            (q_seq_len * d_head, d_head, 1),
                            (0, i, 0),
                            (n_groups, BLOCK_Q, d_head),
                            (2, 1, 0))
    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
                             (n_groups, q_seq_len, d_head),
                             (q_seq_len * d_head, d_head, 1),
                             (0, i, 0),
                             (n_groups, BLOCK_Q, d_head),
                             (2, 1, 0))
    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
                              (n_groups, q_seq_len),
                              (q_seq_len, 1),
                              (0, i),
                              (n_groups, BLOCK_Q),
                              (1, 0))
```

Load $O_i$

```python
    o = tl.load(p_o, boundary_check=(1,), padding_option="zero")
```

Load $dO_i$

```python
    do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)
```

Calculate $D_i = dO_i O_i^T$

```python
    d = tl.sum(o * do, axis=-1)
```

Save $D_i$

```python
    tl.store(p_pdp, d, boundary_check=(1,))
```

Triton kernel to compute $dK_j$ and $dV_j$

```python
@triton.autotune(_get_autotune_configs(inner_loop='query'),
                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
@triton.jit
def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
                   t_do,
                   t_dk, t_dv,
                   t_lse, t_pdp,
                   q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
                   n_groups: tl.constexpr, d_head: tl.constexpr,
                   is_causal: tl.constexpr,
                   BLOCK_Q: tl.constexpr,
                   BLOCK_K: tl.constexpr,
                   ):
```

Compute $dK_j$ and $dV_j$ for `j` ... `j + BLOCK_K` by iterating over $Q_i$

```python
    j = tl.program_id(0) * BLOCK_K
    z = tl.program_id(1)
```

Create block pointers

```python
    p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
                            (kv_seq_len, d_head),
                            (d_head, 1),
                            (j, 0),
                            (BLOCK_K, d_head),
                            (1, 0))
    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
                            (kv_seq_len, d_head),
                            (d_head, 1),
                            (j, 0),
                            (BLOCK_K, d_head),
                            (1, 0))
    p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
                             (kv_seq_len, d_head),
                             (d_head, 1),
                             (j, 0),
                             (BLOCK_K, d_head),
                             (1, 0))
    p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
                             (kv_seq_len, d_head),
                             (d_head, 1),
                             (j, 0),
                             (BLOCK_K, d_head),
                             (1, 0))
```

Initialize $\frac{1}{\sigma} dK$ and $dV$

```python
    b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
    b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
```

Load $\sigma (\log_2 e) K$ and $V$ outside the loop.

```python
    b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
    b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
```

Iterate through the query groups in GQA

```python
    for g in range(n_groups):
```

Create block pointers

```python
        p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                                 (d_head, q_seq_len),
                                 (1, d_head),
                                 (0, 0),
                                 (d_head, BLOCK_Q),
                                 (0, 1))

        p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                                 (q_seq_len, d_head),
                                 (d_head, 1),
                                 (0, 0),
                                 (BLOCK_Q, d_head),
                                 (1, 0))
        p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
                                  (q_seq_len,),
                                  (1,),
                                  (0,),
                                  (BLOCK_Q,),
                                  (0,))
        p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
                                  (q_seq_len,),
                                  (1,),
                                  (0,),
                                  (BLOCK_Q,),
                                  (0,))

        if is_causal:
```

Inner loop at the diagonal block

```python
            b_dk, b_dv = _attn_bwd_dkdv_inner(
                b_dk, b_dv,
                p_qT, b_k, b_v, p_do,
                p_lse, p_pdp,
                BLOCK_Q, BLOCK_K,
                d_head,
                j=j, i=j,
                steps=BLOCK_K // BLOCK_Q,
                MASK=True,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len,
            )
```

Inner loop on queries after the diagonal

```python
            b_dk, b_dv = _attn_bwd_dkdv_inner(
                b_dk, b_dv,
                p_qT, b_k, b_v, p_do,
                p_lse, p_pdp,
                BLOCK_Q, BLOCK_K,
                d_head,
                j=j, i=j + BLOCK_K,
                steps=tl.cdiv((q_seq_len - (j + BLOCK_K)), BLOCK_Q),
                MASK=False,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len
            )
        else:
```

Iterate through all queries

```python
            b_dk, b_dv = _attn_bwd_dkdv_inner(
                b_dk, b_dv,
                p_qT, b_k, b_v, p_do,
                p_lse, p_pdp,
                BLOCK_Q, BLOCK_K,
                d_head,
                j=j, i=tl.full([], 0, tl.int32),
                steps=tl.cdiv(q_seq_len, BLOCK_Q),
                MASK=False,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len
            )
```

Save $dV$

```python
    tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))
```

`b_dk` held $\frac{1}{\sigma} dK$, so scale it by $\sigma$

```python
    b_dk *= sm_scale
```

Save $dK$

```python
    tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))
```

Inner loop to calculate $dK_j$, $dV_j$

```python
@triton.jit
def _attn_bwd_dkdv_inner(b_dk, b_dv,
                         p_qT, b_k, b_v, p_do,
                         p_lse, p_pdp,
                         BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
                         d_head: tl.constexpr,
                         j, i, steps,
                         MASK: tl.constexpr,
                         q_seq_len: tl.constexpr,
                         kv_seq_len: tl.constexpr):
```

To apply the mask

```python
    tl.static_assert(BLOCK_K % BLOCK_Q == 0)
```

Offsets and mask

```python
    offs_i = i + tl.arange(0, BLOCK_Q)
    offs_j = j + tl.arange(0, BLOCK_K)
```

Move the pointers

```python
    p_qT = tl.advance(p_qT, (0, i))
    p_do = tl.advance(p_do, (i, 0))
    p_lse = tl.advance(p_lse, (i,))
    p_pdp = tl.advance(p_pdp, (i,))
```

Iterate over $Q$

```python
    for _ in range(steps):
```

Load $Q_i^T$

```python
        b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")
```

$\log_2 L_i$

```python
        b_l = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
```

$(\log_2 e) S_{ij}^T = \sigma (\log_2 e) K_j Q_i^T$

```python
        b_sT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
```

$P_{ij} = \frac{e^{S_{ij}}}{L_i} = \frac{2^{(\log_2 e) S_{ij}}}{2^{\log_2 L_i}} = 2^{(\log_2 e) S_{ij} - \log_2 L_i}$

```python
        b_pT = tl.math.exp2(b_sT - b_l[None, :])
```

Autoregressive masking

```python
        if MASK:
            mask = (offs_i[None, :] >= offs_j[:, None])
            b_pT = tl.where(mask, b_pT, 0.0)
```

Mask out if the block is beyond the end of $Q_i$

Note: There is no need to mask based on $j$ because contributions from positions outside the boundary never get stored in $dK$ or $dV$. Masking by $i$ may also be unnecessary since the tensors are zero-padded on loading.

```python
        i_mask = offs_i < q_seq_len
        b_pT = tl.where(i_mask[None, :], b_pT, 0.0)
```

$dV_j = \sum_i P_{ij} dO_i$

```python
        b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
        b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)
```

$D_i$

```python
        b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
```

$dP_{ij}^T = V_j dO_i^T$

```python
        b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
```

$dS_{ij}^T = P_{ij}^T (dP_{ij}^T - D_i)$

```python
        b_dsT = b_pT * (b_dpT - b_pdp[None, :])
```

$\frac{1}{\sigma} dK_j = \sum_i dS_{ij} Q_i$

```python
        b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)
```

Increment pointers.

```python
        offs_i += BLOCK_Q
        p_lse = tl.advance(p_lse, (BLOCK_Q,))
        p_pdp = tl.advance(p_pdp, (BLOCK_Q,))
        p_qT = tl.advance(p_qT, (0, BLOCK_Q))
        p_do = tl.advance(p_do, (BLOCK_Q, 0))
```

Return the accumulated $dK$ and $dV$

```python
    return b_dk, b_dv
```

Triton kernel to compute $dQ_i$

```python
@triton.autotune(_get_autotune_configs(inner_loop='key'),
                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
@triton.jit
def _attn_bwd_dq(t_q, t_k, t_v, t_do,
                 t_dq,
                 t_lse, t_pdp,
                 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
                 n_groups: tl.constexpr, d_head: tl.constexpr,
                 is_causal: tl.constexpr,
                 BLOCK_Q: tl.constexpr,
                 BLOCK_K: tl.constexpr,
                 ):
```

```python
    i = tl.program_id(0) * BLOCK_Q
    z = tl.program_id(1) // n_groups
    g = tl.program_id(1) % n_groups
```

Create block pointers

```python
    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                            (q_seq_len, d_head),
                            (d_head, 1),
                            (i, 0),
                            (BLOCK_Q, d_head),
                            (1, 0))
    p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                             (q_seq_len, d_head),
                             (d_head, 1),
                             (i, 0),
                             (BLOCK_Q, d_head),
                             (1, 0))
    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                             (q_seq_len, d_head),
                             (d_head, 1),
                             (i, 0),
                             (BLOCK_Q, d_head),
                             (1, 0))
    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
                             (d_head, kv_seq_len),
                             (1, d_head),
                             (0, 0),
                             (d_head, BLOCK_K),
                             (0, 1))
    p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
                             (d_head, kv_seq_len),
                             (1, d_head),
                             (0, 0),
                             (d_head, BLOCK_K),
                             (0, 1))
    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
                              (q_seq_len,),
                              (1,),
                              (i,),
                              (BLOCK_Q,),
                              (0,))
    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
                              (q_seq_len,),
                              (1,),
                              (i,),
                              (BLOCK_Q,),
                              (0,))
```

Load $Q_i$, $dO_i$, $D_i$, and $\log_2 L_i$ outside the loop

```python
    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
    b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
    b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
    b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
```

Initialize $(\log_2 e) dQ$

```python
    b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
```

$dQ_i = \sum_j dS_{ij} K_j = \sum_j P_{ij} (dP_{ij} - D_i) K_j$

```python
    if is_causal:
```

Compute $dQ$ for masked (diagonal) blocks.

```python
        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
                                  b_do, b_lse, b_pdp,
                                  BLOCK_Q, BLOCK_K,
                                  i=i, j=i,
                                  steps=BLOCK_Q // BLOCK_K,
                                  MASK=True,
                                  q_seq_len=q_seq_len,
                                  kv_seq_len=kv_seq_len
                                  )
```

Compute for the other (unmasked) blocks

```python
        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
                                  b_do, b_lse, b_pdp,
                                  BLOCK_Q, BLOCK_K,
                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
                                  steps=i // BLOCK_K,
                                  MASK=False,
                                  q_seq_len=q_seq_len,
                                  kv_seq_len=kv_seq_len
                                  )
    else:
```

Iterate through all $K$

```python
        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
                                  b_do, b_lse, b_pdp,
                                  BLOCK_Q, BLOCK_K,
                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
                                  steps=tl.cdiv(kv_seq_len, BLOCK_K),
                                  MASK=False,
                                  q_seq_len=q_seq_len,
                                  kv_seq_len=kv_seq_len
                                  )
```

`b_dq` stores $(\log_2 e) dQ$, so multiply by $\log_e 2$ to get $dQ$

```python
    b_dq *= 0.6931471824645996
```

Save $dQ$

```python
    tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
```

Inner loop to calculate $dQ_i$

```python
@triton.jit
def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
                       b_do, b_lse, b_pdp,
                       BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
                       i, j, steps,
                       MASK: tl.constexpr,
                       q_seq_len: tl.constexpr,
                       kv_seq_len: tl.constexpr):
```

Offsets

```python
    offs_i = i + tl.arange(0, BLOCK_Q)
    offs_j = j + tl.arange(0, BLOCK_K)
```

Move the pointers

```python
    p_kT = tl.advance(p_kT, (0, j))
    p_vT = tl.advance(p_vT, (0, j))

    tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')
```

Iterate over $K$

```python
    for _ in range(steps):
```

Load $K_j^T$

```python
        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
```

Load $V_j^T$

```python
        b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")
```

$(\log_2 e) S_{ij} = \sigma (\log_2 e) Q_i K_j^T$

```python
        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
```

$P_{ij} = \frac{e^{S_{ij}}}{L_i} = \frac{2^{(\log_2 e) S_{ij}}}{2^{\log_2 L_i}} = 2^{(\log_2 e) S_{ij} - \log_2 L_i}$

```python
        b_p = tl.math.exp2(b_s - b_lse[:, None])
```

Autoregressive masking

```python
        if MASK:
            causal_mask = (offs_i[:, None] >= offs_j[None, :])
            b_p = tl.where(causal_mask, b_p, 0.0)
```

Mask out if the block is beyond the end of $K_j$

```python
        j_mask = offs_j < kv_seq_len
        b_p = tl.where(j_mask[None, :], b_p, 0.0)
```

$dQ_i = \sum_j dS_{ij} K_j = \sum_j P_{ij} (dP_{ij} - D_i) K_j$

$dP_{ij} = dO_i V_j^T$

```python
        b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)
```

$dS_{ij} = P_{ij} (dP_{ij} - D_i)$

```python
        b_ds = b_p * (b_dp - b_pdp[:, None])
```

$(\log_2 e) dQ_i = \sum_j dS_{ij} \, \sigma (\log_2 e) K_j$

```python
        b_dq += tl.dot(b_ds.to(b_kT.dtype), tl.trans(b_kT), out_dtype=HI_PRES_TL)
```

Increment pointers.

```python
        offs_j += BLOCK_K
        p_kT = tl.advance(p_kT, (0, BLOCK_K))
        p_vT = tl.advance(p_vT, (0, BLOCK_K))
```

Return the accumulated $dQ$

```python
    return b_dq
```
