[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/flash/__init__.py)
Flash attention speeds up the transformer attention mechanism by reducing the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM.
It was introduced in the paper FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness and further optimized in the paper FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. The official CUDA implementation can be found at Dao-AILab/flash-attention.
Our implementation is based on Triton's example implementation.
Note: You can click on the mathematical symbols or identifiers to highlight them.
You can run test.py to check the correctness and measure the performance of this implementation.
Here's the attention forward pass. The formulas represent a single attention head. $Q_i$ is the query vector (a row vector) at position $i$, $K_j$ and $V_j$ are the key and value row vectors at position $j$, and $O_i$ is the output vector at position $i$.
$$\begin{aligned}
S_{ij} &= \sigma Q_i K_j^T \\
L_i &= \sum_j e^{S_{ij}} \\
P_{ij} &= \frac{e^{S_{ij}}}{L_i} \\
O_i &= \sum_j P_{ij} V_j = \frac{1}{L_i} \sum_j e^{S_{ij}} V_j
\end{aligned}$$
$S_{ij}$ is the attention score matrix before softmax, $L_i$ is the softmax denominator, $P_{ij}$ is the attention matrix after softmax, and $\sigma$ is the softmax scale.
Instead of doing the full softmax, you can compute $O_i$ by accumulating the sum of exponentials $l_i$ and the unnormalized output $\tilde{O}_i$ while iterating over keys:
$$\begin{aligned}
S_{ij} &= \sigma Q_i K_j^T \\
l_i &\leftarrow l_i + e^{S_{ij}} \\
\tilde{O}_i &\leftarrow \tilde{O}_i + e^{S_{ij}} V_j
\end{aligned}$$
Finally, you can compute

$$O_i = \frac{\tilde{O}_i}{l_i}$$
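Here's a plain PyTorch sketch of this accumulation for a single query row (a readable reference only, not the Triton kernel; the function and variable names are illustrative):

```python
import torch

def online_attention_row(q: torch.Tensor, keys: torch.Tensor, values: torch.Tensor,
                         sigma: float, block: int = 4) -> torch.Tensor:
    """Compute O_i for one query row vector by iterating over blocks of keys."""
    l = torch.zeros(())                     # running sum of exponentials l_i
    o = torch.zeros(values.shape[-1])       # unnormalized output Õ_i
    for start in range(0, keys.shape[0], block):
        k_blk = keys[start:start + block]   # block of K_j
        v_blk = values[start:start + block] # block of V_j
        s = sigma * (q @ k_blk.T)           # S_ij = σ Q_i K_j^T
        e = torch.exp(s)
        l = l + e.sum()                     # l_i ← l_i + Σ_j e^{S_ij}
        o = o + e @ v_blk                   # Õ_i ← Õ_i + Σ_j e^{S_ij} V_j
    return o / l                            # O_i = Õ_i / l_i
```

This is mathematically equal to the softmax-weighted sum, but the raw exponentials overflow easily, which is what the next step fixes.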
To make this numerically stable, flash attention subtracts the current maximum of $S_{ij}$ before exponentiating.

So it maintains the running maximum $m_i$, the sum of exponentials $l_i$, and the unnormalized output $\tilde{O}_i$ while iterating over keys.
For each block of keys $j_1 \dots j_2$ it updates them:

$$\begin{aligned}
m_i^{\text{new}} &= \max\Big(m_i, \max_{j=j_1}^{j_2} S_{ij}\Big) \\
\tilde{P}_{ij} &= \exp\big(S_{ij} - m_i^{\text{new}}\big) \\
l_i &\leftarrow e^{m_i - m_i^{\text{new}}} l_i + \sum_{j=j_1}^{j_2} \tilde{P}_{ij} \\
\tilde{O}_i &\leftarrow e^{m_i - m_i^{\text{new}}} \tilde{O}_i + \sum_{j=j_1}^{j_2} \tilde{P}_{ij} V_j \\
m_i &\leftarrow m_i^{\text{new}}
\end{aligned}$$
Then finally,

$$O_i = \frac{\tilde{O}_i}{l_i}$$
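The same sketch with the running maximum added, mirroring the update rules above (again an illustrative PyTorch reference, not the kernel):

```python
import math
import torch

def stable_online_attention_row(q: torch.Tensor, keys: torch.Tensor, values: torch.Tensor,
                                sigma: float, block: int = 4) -> torch.Tensor:
    """Numerically stable online softmax attention for one query row."""
    m = float('-inf')                       # running maximum m_i
    l = torch.zeros(())                     # running sum l_i
    o = torch.zeros(values.shape[-1])       # unnormalized output Õ_i
    for start in range(0, keys.shape[0], block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        s = sigma * (q @ k_blk.T)           # S_ij for this block
        m_new = max(m, s.max().item())      # m_i^new
        p = torch.exp(s - m_new)            # P̃_ij = exp(S_ij - m_i^new)
        correction = math.exp(m - m_new)    # e^{m_i - m_i^new}, 0 on the first block
        l = l * correction + p.sum()
        o = o * correction + p @ v_blk
        m = m_new
    return o / l                            # O_i = Õ_i / l_i
```

For any input this should match `torch.softmax(sigma * (keys @ q), dim=0) @ values` up to floating point error.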
This reduces memory usage since we don't have to materialize the full $S_{ij}$ or $P_{ij}$ matrices. It also speeds things up since we don't have to read and write these large matrices; instead, only blocks of $K$ and $V$ are loaded as the kernel iterates over them.
Here's the standard backward pass. $dO_i$ is the gradient vector on the output $O_i$:
$$\begin{aligned}
dV_j &= \sum_i P_{ij} dO_i \\
dP_{ij} &= dO_i V_j^T \\
dS_{ij} &= \mathrm{dsoftmax}(dP_{ij}) = \sum_k P_{ik} (\delta_{jk} - P_{ij}) dP_{ik} = P_{ij} dP_{ij} - P_{ij} \sum_k P_{ik} dP_{ik} \\
dQ_i &= \sigma \sum_j dS_{ij} K_j \\
dK_j &= \sigma \sum_i dS_{ij} Q_i
\end{aligned}$$

where $\delta_{jk}$ is $1$ when $j = k$ and $0$ otherwise.
The flash attention paper introduces $D_i$ to simplify the computation of $dS$:

$$D_i = \sum_k P_{ik} dP_{ik} = \sum_k P_{ik} dO_i V_k^T = dO_i \sum_k P_{ik} V_k^T = dO_i O_i^T$$
Then,
$$dS_{ij} = P_{ij} dP_{ij} - D_i P_{ij}$$
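A quick numerical check of the $D_i = dO_i O_i^T$ identity (an illustrative snippet; the shapes and seed are arbitrary):

```python
import torch

torch.manual_seed(0)
d_head, kv_len = 8, 16
dO = torch.randn(d_head)                       # dO_i, gradient of the output row
V = torch.randn(kv_len, d_head)                # value rows V_k
P = torch.softmax(torch.randn(kv_len), dim=0)  # one softmax row P_i:

O = P @ V                                      # O_i = Σ_k P_ik V_k
dP = V @ dO                                    # dP_ik = dO_i V_k^T
D_from_sum = (P * dP).sum()                    # Σ_k P_ik dP_ik
D_from_output = dO @ O                         # dO_i O_i^T
assert torch.allclose(D_from_sum, D_from_output, atol=1e-6)
```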
Flash attention saves $L_i$ from the forward pass since it doesn't take much memory, so the backward pass doesn't have to recompute $l_i$ or $m_i$.
It first computes $D_i$. Then it iterates over the queries and computes (accumulates) $dK_j$ and $dV_j$. Finally, it iterates over the keys and computes (accumulates) $dQ_i$.
In both the forward and backward passes we calculate logarithms and exponentials in base 2 instead of base $e$ for performance.
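Concretely, $e^x = 2^{x \log_2 e}$, so the kernels fold $\log_2 e \approx 1.4426950408889634$ into the softmax scale and use `exp2`, and the $dQ$ kernel rescales by $\ln 2 \approx 0.6931471824645996$ at the end. A tiny illustration of the identities (nothing kernel-specific, just the arithmetic):

```python
import math

LOG2_E = 1.4426950408889634   # log2(e), folded into sm_scale in the forward kernel
LN_2 = 0.6931471824645996     # ln(2), used to rescale dQ in the backward kernel

x = 0.75
# exp(x) computed through exp2
assert math.isclose(math.exp(x), 2.0 ** (x * LOG2_E))
# a quantity accumulated in base 2 is brought back to base e by multiplying with ln(2)
assert math.isclose(x, (x * LOG2_E) * LN_2, rel_tol=1e-7)
```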
```python
from typing import Any, Tuple

import torch
import triton
import triton.language as tl

HI_PRES_TL: tl.constexpr = tl.float32
HI_PRES_TORCH: torch.dtype = torch.float32
```
```python
class AttentionFunc(torch.autograd.Function):
```
Group query attention forward pass. Returns the output in shape `[batch_size, n_heads, q_seq_len, d_head]`.

- `ctx` is the context for torch autograd
- `q` has shape `[batch_size, n_heads, q_seq_len, d_head]`
- `k` has shape `[batch_size, k_heads, kv_seq_len, d_head]`
- `v` has shape `[batch_size, k_heads, kv_seq_len, d_head]`
- `causal` whether to apply causal attention mask
- `sm_scale` softmax scale factor $\sigma$

```python
    @staticmethod
    def forward(ctx: Any,
                q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                causal: bool, sm_scale: float) -> torch.Tensor:
```
```python
        batch_size, n_heads, q_seq_len, d_head = q.shape
        _, k_heads, kv_seq_len, _ = k.shape
        assert n_heads % k_heads == 0
        n_groups = n_heads // k_heads
```
Shape constraints
```python
        assert d_head == k.shape[-1] == v.shape[-1]
        assert d_head in {16, 32, 64, 128, 256}
```
Reshape the tensors, combining the heads with the batch dimension

```python
        q = q.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
        k = k.view(batch_size * k_heads, kv_seq_len, d_head)
        v = v.view(batch_size * k_heads, kv_seq_len, d_head)
```

Make sure the tensors are contiguous and the strides are the same

```python
        assert q.is_contiguous()
        assert k.is_contiguous()
        assert v.is_contiguous()
        assert k.stride() == v.stride()
```
Tensor for the output
```python
        o = torch.empty_like(q)
```

Tensor for the log of the sum of exponentials $\log_2 L_i = \log_2 \sum_j e^{S_{ij}}$

```python
        lse = torch.empty((batch_size * k_heads, n_groups, q_seq_len), device=q.device, dtype=HI_PRES_TORCH)
```
The forward computation will be parallelized along the batch dimension and the queries in blocks of size BLOCK_Q
```python
        grid = lambda meta: (triton.cdiv(q_seq_len, meta["BLOCK_Q"]), batch_size * k_heads * n_groups, 1)
        _attn_fwd[grid](
            q, k, v, sm_scale * 1.4426950408889634, lse, o,
            n_groups=n_groups,
            q_seq_len=q_seq_len,
            kv_seq_len=kv_seq_len,
            d_head=d_head,
            is_causal=causal,
        )
```
Save the reshaped inputs and outputs for the backward pass
```python
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.sm_scale = sm_scale
        ctx.n_groups = n_groups
        ctx.causal = causal
```
Return the output in shape [batch_size, n_heads, q_seq_len, d_head]
```python
        return o.view(batch_size, n_heads, q_seq_len, d_head)
```
The backward pass computes the gradients of the input tensors.
- `ctx` is the context for torch autograd
- `do` is the gradient tensor of the attention output with shape `[batch_size, n_heads, q_seq_len, d_head]`

```python
    @staticmethod
    def backward(ctx: Any, do: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None]:
```
Get saved tensors and attributes
```python
        n_groups = ctx.n_groups
        sm_scale = ctx.sm_scale
        causal = ctx.causal
        q, k, v, o, lse = ctx.saved_tensors
```
Get shapes
```python
        batch_size, n_heads, q_seq_len, d_head = do.shape
        _, kv_seq_len, _ = k.shape
        k_heads = n_heads // n_groups
```
Combine the heads with the batch dimension of the output gradients tensor
```python
        do = do.view(batch_size * k_heads, n_groups, q_seq_len, d_head)
```
Make sure it's contiguous and the strides are the same
```python
        assert do.is_contiguous()
        assert k.stride() == v.stride()
        assert q.stride() == o.stride() == do.stride()
```
Create tensors for input gradients
```python
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
```

Precompute $\sigma (\log_2 e) K_j$

```python
        k_scaled = k * (sm_scale * 1.4426950408889634)
```

$D_i = P_{i:}^T dP_{i:} = dO_i O_i^T$

```python
        pdp = torch.empty_like(lse)
```
We use a fixed `BLOCK_Q` for the backward pass on $D_i$
Compute Di
This is parallelized along the batch and query in blocks of size BLOCK_Q
```python
        BLOCK_Q = 16
        pre_grid = (triton.cdiv(q_seq_len, BLOCK_Q), batch_size * k_heads)
        _attn_bwd_d[pre_grid](
            o, do,
            pdp,
            BLOCK_Q=16,
            d_head=d_head,
            q_seq_len=q_seq_len,
            n_groups=n_groups,
            num_stages=1,
        )
```
Compute dK and dV
This is parallelized along the batch and keys in blocks of size BLOCK_K
```python
        grid = lambda meta: (triton.cdiv(kv_seq_len, meta['BLOCK_K']), batch_size * k_heads)
        _attn_bwd_dkdv[grid](
            q, k_scaled, v, sm_scale, do, dk, dv,
            lse, pdp,
            q_seq_len, kv_seq_len, n_groups, d_head,
            is_causal=causal,
        )
```
Compute dQ
This is parallelized along the batch and queries in blocks of size BLOCK_Q
```python
        grid = lambda meta: (triton.cdiv(q_seq_len, meta['BLOCK_Q']), batch_size * k_heads * n_groups)
        _attn_bwd_dq[grid](
            q, k_scaled, v, do,
            dq,
            lse, pdp,
            q_seq_len, kv_seq_len, n_groups, d_head,
            is_causal=causal,
        )
```
Split the combined batch and heads
```python
        dq = dq.view(batch_size, n_heads, q_seq_len, d_head)
        dk = dk.view(batch_size, k_heads, kv_seq_len, d_head)
        dv = dv.view(batch_size, k_heads, kv_seq_len, d_head)

        return dq, dk, dv, None, None


attention = AttentionFunc.apply
```
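A usage sketch for `attention` (the import path follows the repository layout above; the shapes, dtypes, and tolerance are illustrative and assume a CUDA device with Triton available; the `scale` argument of PyTorch's built-in attention needs a recent PyTorch):

```python
import torch
from labml_nn.transformers.flash import attention

batch_size, n_heads, seq_len, d_head = 2, 8, 1024, 64
q = torch.randn(batch_size, n_heads, seq_len, d_head, device='cuda', dtype=torch.float16)
k = torch.randn(batch_size, n_heads, seq_len, d_head, device='cuda', dtype=torch.float16)
v = torch.randn(batch_size, n_heads, seq_len, d_head, device='cuda', dtype=torch.float16)

# Positional arguments are (q, k, v, causal, sm_scale)
out = attention(q, k, v, True, d_head ** -0.5)

# Compare against PyTorch's built-in attention as a sanity check
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True, scale=d_head ** -0.5)
assert torch.allclose(out, ref, atol=1e-2, rtol=0.0)
```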
```python
def _get_autotune_configs(inner_loop: str) -> list:
    configs = []
```
Possible options for BLOCK_Q
```python
    for bq in [64, 128, 256]:
```
Possible options for BLOCK_K
```python
        for bk in [64, 128, 256]:
```
If the inner loop is along keys, `BLOCK_Q` must be a multiple of `BLOCK_K` for causal masking

```python
            if inner_loop == 'key' and bq % bk != 0:
                continue
```
Similarly when the inner loop is along queries
```python
            if inner_loop == 'query' and bk % bq != 0:
                continue
```
Number of stages and warps
```python
            for s in [2, 3, 4]:
                for w in [4, 8]:
                    if bq * bk < 128 * 128 and w == 8:
                        continue

                    configs.append(triton.Config({'BLOCK_Q': bq, 'BLOCK_K': bk}, num_stages=s, num_warps=w))
```
Use `return configs` to autotune over all of them. Trying all combinations is slow for testing, so only the first configuration is returned here.

```python
    return configs[:1]
```
- `t_q` queries $Q_i$
- `t_k` keys $K_j$
- `t_v` values $V_j$
- `sm_scale_log2e` softmax scale multiplied by $\log_2 e$, i.e. $\sigma \log_2 e$
- `t_lse` $\log_2 \sum_j e^{S_{ij}}$ (out)
- `t_o` output $O_i$
- `n_groups` number of groups in GQA
- `q_seq_len` query sequence length
- `kv_seq_len` key/value sequence length
- `d_head` number of dimensions in a head
- `BLOCK_Q` block size for the query sequence length
- `BLOCK_K` block size for the key sequence length
- `is_causal` whether causal attention

Strides `z`, `h`, `m`, and `d` denote the stride of the corresponding dimensions (`batch_size`, `n_heads`, `q_seq_len`, `d_head`) in the query. Stride `n` denotes the stride on `kv_seq_len` of the key.

```python
@triton.autotune(_get_autotune_configs(inner_loop='key'),
                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
@triton.jit
def _attn_fwd(t_q, t_k, t_v, sm_scale_log2e, t_lse, t_o,
              n_groups: tl.constexpr,
              q_seq_len: tl.constexpr,
              kv_seq_len: tl.constexpr,
              d_head: tl.constexpr,
              is_causal: tl.constexpr,
              BLOCK_Q: tl.constexpr,
              BLOCK_K: tl.constexpr,
              ):
```
We compute the attention output $O_i$ for queries `i` ... `i + BLOCK_Q` in batch/head combination `z`.

```python
    i = tl.program_id(0)
    z = tl.program_id(1) // n_groups
    g = tl.program_id(1) % n_groups

    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                            (q_seq_len, d_head), (d_head, 1),
                            (i * BLOCK_Q, 0), (BLOCK_Q, d_head), (1, 0))
    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
                            (kv_seq_len, d_head), (d_head, 1),
                            (0, 0), (BLOCK_K, d_head), (1, 0))
    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
                             (d_head, kv_seq_len), (1, d_head),
                             (0, 0), (d_head, BLOCK_K), (0, 1))
    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                            (q_seq_len, d_head), (d_head, 1),
                            (i * BLOCK_Q, 0), (BLOCK_Q, d_head), (1, 0))
    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
                              (q_seq_len,), (1,),
                              (i * BLOCK_Q,), (BLOCK_Q,), (0,))
```
Initialize offsets
```python
    offs_i = i * BLOCK_Q + tl.arange(0, BLOCK_Q)
    offs_j = tl.arange(0, BLOCK_K)
```
Mask for Q for the last block
```python
    i_mask = offs_i < q_seq_len
```

Initialize $m_i$ and $l_i$. $m_i$ is initialized to $-\infty$ and $l_i$ to $1$. So in the first update, the contribution of the initial value is $e^{m_i - m_i^{\text{new}}} l_i = 0$ because $e^{-\infty} = 0$.

`b_m` will store $m_i \log_2 e$

```python
    b_m = tl.where(i_mask, -float("inf"), 0.0)
    b_l = tl.where(i_mask, 1.0, 0.0)
```
Oi
```python
    b_o = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
```

Load $Q_i$ outside the loop since it will be reused throughout the loop over $K_j$.

```python
    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")

    if is_causal:
```
Inner loop up to the diagonal block

```python
        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q,
                                        p_kT, p_v,
                                        sm_scale_log2e,
                                        BLOCK_Q, d_head, BLOCK_K,
                                        offs_i, offs_j,
                                        j=tl.full([], 0, tl.int32),  # type: ignore
                                        steps=(i * BLOCK_Q) // BLOCK_K,
                                        MASK=False,
                                        q_seq_len=q_seq_len,
                                        kv_seq_len=kv_seq_len)
```
Diagonal block with masking within it
```python
        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
                                        sm_scale_log2e,
                                        BLOCK_Q, d_head, BLOCK_K,
                                        offs_i, offs_j,
                                        j=i * BLOCK_Q,
                                        steps=BLOCK_Q // BLOCK_K,
                                        MASK=True,
                                        q_seq_len=q_seq_len,
                                        kv_seq_len=kv_seq_len)
    else:
```
Iterate through all Kj
```python
        b_o, b_l, b_m = _attn_fwd_inner(b_o, b_l, b_m, b_q, p_kT, p_v,
                                        sm_scale_log2e,
                                        BLOCK_Q, d_head, BLOCK_K,
                                        offs_i, offs_j,
                                        j=tl.full([], 0, tl.int32),  # type: ignore
                                        steps=tl.cdiv(kv_seq_len, BLOCK_K),
                                        MASK=False,
                                        q_seq_len=q_seq_len,
                                        kv_seq_len=kv_seq_len)
```

Store LSE $\log_2 L_i = \log_2 \left( l_i e^{m_i} \right) = \log_2 l_i + m_i \log_2 e$

```python
    tl.store(p_lse, b_m + tl.math.log2(b_l), boundary_check=(0,))
```

Store $O_i = \frac{\tilde{O}_i}{l_i}$

```python
    tl.store(p_o, (b_o / b_l[:, None]).to(t_o.type.element_ty), boundary_check=(0,))
```
This iterates through keys and values starting from j for steps number of steps. In each step it processes BLOCK_K entries of keys/values.
```python
@triton.jit
def _attn_fwd_inner(b_o, b_l, b_m, b_q,
                    p_kT, p_v,
                    sm_scale_log2e,
                    BLOCK_Q: tl.constexpr,
                    d_head: tl.constexpr,
                    BLOCK_K: tl.constexpr,
                    offs_i, offs_j,
                    j,
                    steps,
                    MASK: tl.constexpr,
                    q_seq_len: tl.constexpr,
                    kv_seq_len: tl.constexpr
                    ):
```

```python
    tl.static_assert(BLOCK_Q % BLOCK_K == 0)
```
Move Kj and Vj pointers
```python
    p_kT = tl.advance(p_kT, (0, j))
    p_v = tl.advance(p_v, (j, 0))
```

Iterate over $K$, $V$ and update $\tilde{O}_i$ and $l_i$

```python
    for _ in range(steps):
```

Load $K_j^T$

```python
        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
```

Compute $(\log_2 e) S_{ij} = (\log_2 e) \sigma Q_i K_j^T$

```python
        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
        b_s = b_s * sm_scale_log2e
```
Apply causal mask
```python
        if MASK:
            causal_mask = offs_i[:, None] >= (j + offs_j[None, :])
            b_s = tl.where(causal_mask, b_s, -float("inf"))
```
Mask out if the block is beyond the end of Kj
```python
        j_mask = (j + offs_j) < kv_seq_len
        b_s = tl.where(j_mask[None, :], b_s, -float("inf"))
```

$(\log_2 e) m_i^{\text{new}} = \max\left((\log_2 e) m_i, \max_{j=j_1}^{j_2} (\log_2 e) S_{ij}\right)$

```python
        b_m_new = tl.maximum(b_m, tl.max(b_s, -1))
```

$\tilde{P}_{ij} = e^{S_{ij} - m_i^{\text{new}}} = 2^{(\log_2 e) S_{ij} - (\log_2 e) m_i^{\text{new}}}$

```python
        b_p = tl.math.exp2(b_s - b_m_new[:, None])
```

$\sum_{j=j_1}^{j_2} \tilde{P}_{ij}$

```python
        b_l_new = tl.sum(b_p, -1)
```

$e^{m_i - m_i^{\text{new}}}$

```python
        b_m_m_new = tl.math.exp2(b_m - b_m_new)
```

$l_i \leftarrow e^{m_i - m_i^{\text{new}}} l_i + \sum_{j=j_1}^{j_2} \tilde{P}_{ij}$

```python
        b_l = b_l * b_m_m_new + b_l_new
```

$\tilde{O}_i \leftarrow e^{m_i - m_i^{\text{new}}} \tilde{O}_i + \tilde{P}_{ij} V_j$

```python
        b_o = b_o * b_m_m_new[:, None]
        b_p = b_p.to(b_q.dtype)  # TODO
        b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
        b_o += tl.dot(b_p, b_v, out_dtype=HI_PRES_TL)
```

$(\log_2 e) m_i \leftarrow (\log_2 e) m_i^{\text{new}}$

```python
        b_m = b_m_new
```
Move pointers
```python
        j += BLOCK_K
        p_v = tl.advance(p_v, (BLOCK_K, 0))
        p_kT = tl.advance(p_kT, (0, BLOCK_K))

    tl.static_assert(b_o.dtype == HI_PRES_TL, "attn_fwd_inner requires accumulator to be in HI_PRES_TL precision")

    return b_o, b_l, b_m
```
```python
@triton.jit
def _attn_bwd_d(t_o, t_do,
                t_pdp,
                BLOCK_Q: tl.constexpr, d_head: tl.constexpr,
                q_seq_len: tl.constexpr,
                n_groups: tl.constexpr,
                ):
```

```python
    i = tl.program_id(0) * BLOCK_Q
    z = tl.program_id(1)
```
Create block pointers
```python
    p_o = tl.make_block_ptr(t_o + z * n_groups * q_seq_len * d_head,
                            (n_groups, q_seq_len, d_head), (q_seq_len * d_head, d_head, 1),
                            (0, i, 0), (n_groups, BLOCK_Q, d_head), (2, 1, 0))
    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head,
                             (n_groups, q_seq_len, d_head), (q_seq_len * d_head, d_head, 1),
                             (0, i, 0), (n_groups, BLOCK_Q, d_head), (2, 1, 0))
    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len,
                              (n_groups, q_seq_len), (q_seq_len, 1),
                              (0, i), (n_groups, BLOCK_Q), (1, 0))
```
Load Oi
```python
    o = tl.load(p_o, boundary_check=(1,), padding_option="zero")
```
Load dOi
```python
    do = tl.load(p_do, boundary_check=(1,), padding_option="zero").to(HI_PRES_TL)
```

Calculate $D_i = dO_i O_i^T$

```python
    d = tl.sum(o * do, axis=-1)
```
Save Di
```python
    tl.store(p_pdp, d, boundary_check=(1,))
```

```python
@triton.autotune(_get_autotune_configs(inner_loop='query'),
                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
@triton.jit
def _attn_bwd_dkdv(t_q, t_k, t_v, sm_scale,
                   t_do,
                   t_dk, t_dv,
                   t_lse, t_pdp,
                   q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
                   n_groups: tl.constexpr, d_head: tl.constexpr,
                   is_causal: tl.constexpr,
                   BLOCK_Q: tl.constexpr,
                   BLOCK_K: tl.constexpr,
                   ):
```
Compute $dK_j$ and $dV_j$ for `j` ... `j + BLOCK_K` by iterating over $Q_i$

```python
    j = tl.program_id(0) * BLOCK_K
    z = tl.program_id(1)
```
Create block pointers
```python
    p_k = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
                            (kv_seq_len, d_head), (d_head, 1),
                            (j, 0), (BLOCK_K, d_head), (1, 0))
    p_v = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
                            (kv_seq_len, d_head), (d_head, 1),
                            (j, 0), (BLOCK_K, d_head), (1, 0))
    p_dk = tl.make_block_ptr(t_dk + z * kv_seq_len * d_head,
                             (kv_seq_len, d_head), (d_head, 1),
                             (j, 0), (BLOCK_K, d_head), (1, 0))
    p_dv = tl.make_block_ptr(t_dv + z * kv_seq_len * d_head,
                             (kv_seq_len, d_head), (d_head, 1),
                             (j, 0), (BLOCK_K, d_head), (1, 0))
```

Initialize $\frac{1}{\sigma} dK$ and $dV$

```python
    b_dk = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
    b_dv = tl.zeros([BLOCK_K, d_head], dtype=HI_PRES_TL)
```

Load $(\log_2 e)\, \sigma K$ and $V$ outside the loop.

```python
    b_k = tl.load(p_k, boundary_check=(0,), padding_option="zero")
    b_v = tl.load(p_v, boundary_check=(0,), padding_option="zero")
```
Iterate through the query head groups in GQA

```python
    for g in range(n_groups):
```
Create block pointers
```python
        p_qT = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                                 (d_head, q_seq_len), (1, d_head),
                                 (0, 0), (d_head, BLOCK_Q), (0, 1))

        p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                                 (q_seq_len, d_head), (d_head, 1),
                                 (0, 0), (BLOCK_Q, d_head), (1, 0))
        p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
                                  (q_seq_len,), (1,),
                                  (0,), (BLOCK_Q,), (0,))
        p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
                                  (q_seq_len,), (1,),
                                  (0,), (BLOCK_Q,), (0,))

        if is_causal:
```
Inner loop at the diagonal block
```python
            b_dk, b_dv = _attn_bwd_dkdv_inner(
                b_dk, b_dv,
                p_qT, b_k, b_v, p_do,
                p_lse, p_pdp,
                BLOCK_Q, BLOCK_K,
                d_head,
                j=j, i=j,
                steps=BLOCK_K // BLOCK_Q,
                MASK=True,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len,
            )
```
Inner loop on queries after the diagonal
```python
            b_dk, b_dv = _attn_bwd_dkdv_inner(
                b_dk, b_dv,
                p_qT, b_k, b_v, p_do,
                p_lse, p_pdp,
                BLOCK_Q, BLOCK_K,
                d_head,
                j=j, i=j + BLOCK_K,
                steps=tl.cdiv((q_seq_len - (j + BLOCK_K)), BLOCK_Q),
                MASK=False,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len
            )
        else:
```
Iterate through all queries
```python
            b_dk, b_dv = _attn_bwd_dkdv_inner(
                b_dk, b_dv,
                p_qT, b_k, b_v, p_do,
                p_lse, p_pdp,
                BLOCK_Q, BLOCK_K,
                d_head,
                j=j, i=tl.full([], 0, tl.int32),
                steps=tl.cdiv(q_seq_len, BLOCK_Q),
                MASK=False,
                q_seq_len=q_seq_len,
                kv_seq_len=kv_seq_len
            )
```
Save dV
```python
    tl.store(p_dv, b_dv.to(t_dv.type.element_ty), boundary_check=(0,))
```

`b_dk` holds $\frac{1}{\sigma} dK$

```python
    b_dk *= sm_scale
```
Save dK
```python
    tl.store(p_dk, b_dk.to(t_dk.type.element_ty), boundary_check=(0,))
```

```python
@triton.jit
def _attn_bwd_dkdv_inner(b_dk, b_dv,
                         p_qT, b_k, b_v, p_do,
                         p_lse, p_pdp,
                         BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
                         d_head: tl.constexpr,
                         j, i, steps,
                         MASK: tl.constexpr,
                         q_seq_len: tl.constexpr,
                         kv_seq_len: tl.constexpr):
```
To apply the mask
```python
    tl.static_assert(BLOCK_K % BLOCK_Q == 0)
```
Offsets and mask
```python
    offs_i = i + tl.arange(0, BLOCK_Q)
    offs_j = j + tl.arange(0, BLOCK_K)
```
Move the pointers
```python
    p_qT = tl.advance(p_qT, (0, i))
    p_do = tl.advance(p_do, (i, 0))
    p_lse = tl.advance(p_lse, (i,))
    p_pdp = tl.advance(p_pdp, (i,))
```
Iterate over Q
```python
    for _ in range(steps):
```

Load $Q_i^T$

```python
        b_qT = tl.load(p_qT, boundary_check=(1,), padding_option="zero")
```

$\log_2 L_i$

```python
        b_l = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
```

$(\log_2 e) S_{ij}^T = \sigma (\log_2 e) K_j Q_i^T$

```python
        b_sT = tl.dot(b_k, b_qT, out_dtype=HI_PRES_TL)
```

$P_{ij} = \frac{e^{S_{ij}}}{L_i} = \frac{2^{(\log_2 e) S_{ij}}}{2^{\log_2 L_i}} = 2^{(\log_2 e) S_{ij} - \log_2 L_i}$

```python
        b_pT = tl.math.exp2(b_sT - b_l[None, :])
```
Autoregressive masking
```python
        if MASK:
            mask = (offs_i[None, :] >= offs_j[:, None])
            b_pT = tl.where(mask, b_pT, 0.0)
```
Mask out if the block is beyond the end of Qi
Note: There is no need to mask based on $j$ because the effect of positions outside the boundary will not get stored in $dK$ or $dV$. Masking by $i$ may also not be necessary since the tensors are zero-padded on loading.

```python
        i_mask = offs_i < q_seq_len
        b_pT = tl.where(i_mask[None, :], b_pT, 0.0)
```

$dV_j = \sum_i P_{ij} dO_i$

```python
        b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
        b_dv += tl.dot(b_pT.to(b_do.dtype), b_do, out_dtype=HI_PRES_TL)
```
Di
```python
        b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
```

$dP_{ij}^T = V_j dO_i^T$

```python
        b_dpT = tl.dot(b_v, tl.trans(b_do), out_dtype=HI_PRES_TL).to(HI_PRES_TL)
```

$dS_{ij} = P_{ij} (dP_{ij} - D_i)$

```python
        b_dsT = b_pT * (b_dpT - b_pdp[None, :])
```

$\frac{1}{\sigma} dK_j = \sum_i dS_{ij} Q_i$

```python
        b_dk += tl.dot(b_dsT.to(b_qT.dtype), tl.trans(b_qT), out_dtype=HI_PRES_TL)
```
Increment pointers.
```python
        offs_i += BLOCK_Q
        p_lse = tl.advance(p_lse, (BLOCK_Q,))
        p_pdp = tl.advance(p_pdp, (BLOCK_Q,))
        p_qT = tl.advance(p_qT, (0, BLOCK_Q))
        p_do = tl.advance(p_do, (BLOCK_Q, 0))
```
Return accumulated dK and dV
```python
    return b_dk, b_dv
```
```python
@triton.autotune(_get_autotune_configs(inner_loop='key'),
                 key=["q_seq_len", "kv_seq_len", "d_head", "n_groups", "is_causal"])
@triton.jit
def _attn_bwd_dq(t_q, t_k, t_v, t_do,
                 t_dq,
                 t_lse, t_pdp,
                 q_seq_len: tl.constexpr, kv_seq_len: tl.constexpr,
                 n_groups: tl.constexpr, d_head: tl.constexpr,
                 is_causal: tl.constexpr,
                 BLOCK_Q: tl.constexpr,
                 BLOCK_K: tl.constexpr,
                 ):
```

```python
    i = tl.program_id(0) * BLOCK_Q
    z = tl.program_id(1) // n_groups
    g = tl.program_id(1) % n_groups  # TODO
```
Create block pointers
```python
    p_q = tl.make_block_ptr(t_q + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                            (q_seq_len, d_head), (d_head, 1),
                            (i, 0), (BLOCK_Q, d_head), (1, 0))
    p_dq = tl.make_block_ptr(t_dq + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                             (q_seq_len, d_head), (d_head, 1),
                             (i, 0), (BLOCK_Q, d_head), (1, 0))
    p_do = tl.make_block_ptr(t_do + z * n_groups * q_seq_len * d_head + g * q_seq_len * d_head,
                             (q_seq_len, d_head), (d_head, 1),
                             (i, 0), (BLOCK_Q, d_head), (1, 0))
    p_kT = tl.make_block_ptr(t_k + z * kv_seq_len * d_head,
                             (d_head, kv_seq_len), (1, d_head),
                             (0, 0), (d_head, BLOCK_K), (0, 1))
    p_vT = tl.make_block_ptr(t_v + z * kv_seq_len * d_head,
                             (d_head, kv_seq_len), (1, d_head),
                             (0, 0), (d_head, BLOCK_K), (0, 1))
    p_lse = tl.make_block_ptr(t_lse + z * n_groups * q_seq_len + g * q_seq_len,
                              (q_seq_len,), (1,),
                              (i,), (BLOCK_Q,), (0,))
    p_pdp = tl.make_block_ptr(t_pdp + z * n_groups * q_seq_len + g * q_seq_len,
                              (q_seq_len,), (1,),
                              (i,), (BLOCK_Q,), (0,))
```
Load $Q_i$, $dO_i$, $D_i$, and $\log_2 L_i$ outside the loop

```python
    b_q = tl.load(p_q, boundary_check=(0,), padding_option="zero")
    b_do = tl.load(p_do, boundary_check=(0,), padding_option="zero")
    b_pdp = tl.load(p_pdp, boundary_check=(0,), padding_option="zero")
    b_lse = tl.load(p_lse, boundary_check=(0,), padding_option="zero")
```

Initialize $(\log_2 e) dQ$

```python
    b_dq = tl.zeros([BLOCK_Q, d_head], dtype=HI_PRES_TL)
```

$dQ_i = \sum_j dS_{ij} K_j = \sum_j P_{ij} (dP_{ij} - D_i) K_j$

```python
    if is_causal:
```
Compute dQ for masked (diagonal) blocks.
```python
        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
                                  b_do, b_lse, b_pdp,
                                  BLOCK_Q, BLOCK_K,
                                  i=i, j=i,
                                  steps=BLOCK_Q // BLOCK_K,
                                  MASK=True,
                                  q_seq_len=q_seq_len,
                                  kv_seq_len=kv_seq_len)
```
Compute for other blocks
```python
        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
                                  b_do, b_lse, b_pdp,
                                  BLOCK_Q, BLOCK_K,
                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
                                  steps=i // BLOCK_K,
                                  MASK=False,
                                  q_seq_len=q_seq_len,
                                  kv_seq_len=kv_seq_len)
    else:
```
Iterate through all K
```python
        b_dq = _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
                                  b_do, b_lse, b_pdp,
                                  BLOCK_Q, BLOCK_K,
                                  i=i, j=tl.full([], 0, tl.int32),  # type: ignore
                                  steps=tl.cdiv(kv_seq_len, BLOCK_K),
                                  MASK=False,
                                  q_seq_len=q_seq_len,
                                  kv_seq_len=kv_seq_len)
```

`b_dq` stores $(\log_2 e) dQ$, so multiply by $\ln 2$ to get $dQ$

```python
    b_dq *= 0.6931471824645996
```
Save dQ
```python
    tl.store(p_dq, b_dq.to(t_dq.type.element_ty), boundary_check=(0,))
```

```python
@triton.jit
def _attn_bwd_dq_inner(b_dq, b_q, p_kT, p_vT,
                       b_do, b_lse, b_pdp,
                       BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr,
                       i, j, steps,
                       MASK: tl.constexpr,
                       q_seq_len: tl.constexpr,
                       kv_seq_len: tl.constexpr):
```
Offsets
```python
    offs_i = i + tl.arange(0, BLOCK_Q)
    offs_j = j + tl.arange(0, BLOCK_K)
```
Move the pointers
```python
    p_kT = tl.advance(p_kT, (0, j))
    p_vT = tl.advance(p_vT, (0, j))

    tl.static_assert(BLOCK_Q % BLOCK_K == 0, 'BLOCK_Q must be divisible by BLOCK_K')
```
Iterate over K
```python
    for _ in range(steps):
```

Load $K_j^T$

```python
        b_kT = tl.load(p_kT, boundary_check=(1,), padding_option="zero")
```

Load $V_j^T$

```python
        b_vT = tl.load(p_vT, boundary_check=(1,), padding_option="zero")
```

$(\log_2 e) S_{ij} = \sigma (\log_2 e) Q_i K_j^T$

```python
        b_s = tl.dot(b_q, b_kT, out_dtype=HI_PRES_TL)
```

$P_{ij} = \frac{e^{S_{ij}}}{L_i} = \frac{2^{(\log_2 e) S_{ij}}}{2^{\log_2 L_i}} = 2^{(\log_2 e) S_{ij} - \log_2 L_i}$

```python
        b_p = tl.math.exp2(b_s - b_lse[:, None])
```
Autoregressive masking
```python
        if MASK:
            causal_mask = (offs_i[:, None] >= offs_j[None, :])
            b_p = tl.where(causal_mask, b_p, 0.0)
```
Mask out if the block is beyond the end of $K_j$

```python
        j_mask = offs_j < kv_seq_len
        b_p = tl.where(j_mask[None, :], b_p, 0.0)
```

$dQ_i = \sum_j dS_{ij} K_j = \sum_j P_{ij} (dP_{ij} - D_i) K_j$

$dP_{ij} = dO_i V_j^T$

```python
        b_dp = tl.dot(b_do, b_vT, out_dtype=HI_PRES_TL).to(HI_PRES_TL)
```

$dS_{ij} = P_{ij} (dP_{ij} - D_i)$

```python
        b_ds = b_p * (b_dp - b_pdp[:, None])
```

$(\log_2 e) dQ_i = \sum_j dS_{ij}\, \sigma (\log_2 e) K_j$

```python
        b_dq += tl.dot(b_ds.to(b_kT.dtype), tl.trans(b_kT), out_dtype=HI_PRES_TL)
```
Increment pointers.
```python
        offs_j += BLOCK_K
        p_kT = tl.advance(p_kT, (0, BLOCK_K))
        p_vT = tl.advance(p_vT, (0, BLOCK_K))
```
Return accumulated dQ
```python
    return b_dq
```