
#

Test Flash Attention Implementation

This is the code to test and measure the performance of our flash attention implementation.

import torch
import triton

from labml import logger, monit
from labml_nn.transformers.flash import attention

HI_PRES_TORCH = torch.float32

#

Calculate the maximum absolute and relative error for reporting. The relative error is measured after subtracting the absolute tolerance.

[email protected]\_grad()17def\_calc\_abs\_rel\_error(a:torch.Tensor,b:torch.Tensor,atol=1e-2):

#

    d = (a - b).abs()
    max_abs = d.max()
    d = (d - atol).clamp(min=0)
    d = d / b.abs()
    max_rel = d.max()

    return max_abs.cpu().item(), max_rel.cpu().item()
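For instance (a hypothetical call, not part of the test file), a pair of tensors that agree to within the tolerance on one element reports the slack-adjusted relative error on the other:

a = torch.tensor([1.000, 2.000])
b = torch.tensor([1.005, 2.020])
# |a - b| = [0.005, 0.020], so max_abs = 0.020.
# After subtracting atol=1e-2: [0, 0.010]; dividing by |b| gives max_rel ~ 0.010 / 2.02 ~ 0.005.
max_abs, max_rel = _calc_abs_rel_error(a, b)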

#

Compare our implementation with naive PyTorch attention

def test_fwd_bwd(batch_size, n_heads, k_heads, q_seq_len, kv_seq_len, d_head, causal, dtype, device):

#

    with monit.section(f'Init {q_seq_len} {kv_seq_len} {d_head}'):
        torch.manual_seed(20)
        q = (torch.empty((batch_size, n_heads, q_seq_len, d_head),
                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
        k = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
        v = (torch.empty((batch_size, k_heads, kv_seq_len, d_head),
                         dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
        sm_scale = d_head ** -0.5
        d_out = torch.randn_like(q)
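Since the attention outputs are tensors rather than scalars, `backward` needs an upstream gradient `d_out`; `out.backward(d_out)` computes the vector-Jacobian product, i.e. the same gradients as `(out * d_out).sum().backward()`. A minimal sketch of that equivalence (hypothetical, not part of the test):

x = torch.randn(3, requires_grad=True)
d = torch.randn(3)
(x * 2).backward(d)               # vector-Jacobian product with upstream gradient d
g, x.grad = x.grad.clone(), None
(x * 2 * d).sum().backward()      # equivalent scalar formulation
assert torch.allclose(g, x.grad)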

#

Reference implementation

    mask = torch.tril(torch.ones((q_seq_len, kv_seq_len), device=device, dtype=torch.bool))
    torch.cuda.synchronize()

    with monit.section('Pytorch'):
        p = torch.matmul(q.view(batch_size, k_heads, -1, q_seq_len, d_head),
                         k.transpose(2, 3)[:, :, None, :, :]) * sm_scale
        if causal:
            p[:, :, :, ~mask] = float("-inf")
        p = torch.softmax(p.to(HI_PRES_TORCH), dim=-1).to(dtype)
        ref_out = torch.matmul(p, v[:, :, None, :, :])
        ref_out = ref_out.view(q.shape)
        ref_out.backward(d_out)
        ref_dv, v.grad = v.grad.clone(), None
        ref_dk, k.grad = k.grad.clone(), None
        ref_dq, q.grad = q.grad.clone(), None
    torch.cuda.synchronize()

    with monit.section('Triton'):
        assert q.dtype == dtype
        tri_out = attention(q, k, v, causal, sm_scale).to(dtype)
        monit.progress(0.5)

        tri_out.backward(d_out)
        monit.progress(0.9)
        tri_dv, v.grad = v.grad.clone(), None  # type: ignore
        tri_dk, k.grad = k.grad.clone(), None  # type: ignore
        tri_dq, q.grad = q.grad.clone(), None  # type: ignore
    torch.cuda.synchronize()

    with monit.section('Test') as s:
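The `view`/broadcast trick in the reference implements grouped-query attention: `q` is reshaped to `(batch_size, k_heads, n_groups, q_seq_len, d_head)` so that each group of `n_heads // k_heads` query heads shares one key/value head. A minimal sketch (with assumed toy shapes) showing this is equivalent to repeating each KV head:

batch, k_heads, n_groups, seq, d = 2, 3, 4, 5, 8
q_ = torch.randn(batch, k_heads * n_groups, seq, d)
k_ = torch.randn(batch, k_heads, seq, d)
# Grouped view with k broadcast over the group axis, as in the reference above
p1 = torch.matmul(q_.view(batch, k_heads, n_groups, seq, d),
                  k_.transpose(2, 3)[:, :, None, :, :])
# Explicitly repeating each KV head n_groups times gives the same scores
p2 = torch.matmul(q_, k_.repeat_interleave(n_groups, dim=1).transpose(2, 3))
assert torch.allclose(p1.reshape(batch, k_heads * n_groups, seq, seq), p2, atol=1e-5)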

#

Compare the outputs and gradients against the reference

        passed = True
        if not torch.allclose(tri_out, ref_out, atol=1e-2, rtol=0.):
            abs_err, rel_err = _calc_abs_rel_error(ref_out, tri_out)
            logger.log(('[FAILED]', logger.Text.danger), f' Out mismatch {abs_err} {rel_err}')
            passed = False
        rtol = 1e-1
        if not torch.allclose(tri_dq, ref_dq, atol=1e-2, rtol=rtol):
            abs_err, rel_err = _calc_abs_rel_error(ref_dq, tri_dq)
            logger.log(('[FAILED]', logger.Text.danger), f' dQ mismatch {abs_err} {rel_err}')
            passed = False
        if not torch.allclose(tri_dv, ref_dv, atol=1e-2, rtol=rtol):
            abs_err, rel_err = _calc_abs_rel_error(ref_dv, tri_dv)
            logger.log(('[FAILED]', logger.Text.danger), f' dV mismatch {abs_err} {rel_err}')
            passed = False
        if not torch.allclose(tri_dk, ref_dk, atol=1e-2, rtol=rtol):
            abs_err, rel_err = _calc_abs_rel_error(ref_dk, tri_dk)
            logger.log(('[FAILED]', logger.Text.danger), f' dK mismatch {abs_err} {rel_err}')
            passed = False

        if passed:
            logger.log('[PASSED]', logger.Text.success)
            s.success = True
        else:
            s.success = False
    torch.cuda.synchronize()
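`torch.allclose(a, b, atol, rtol)` passes when `|a - b| <= atol + rtol * |b|` elementwise, so the outputs are held to a purely absolute tolerance (`rtol=0.`) while the fp16 gradients get an extra 10% relative slack. A small illustration with hypothetical values:

ref = torch.tensor([1.0, 100.0])
tri = torch.tensor([1.009, 109.0])
assert not torch.allclose(tri, ref, atol=1e-2, rtol=0.)  # 109 vs 100 exceeds atol alone
assert torch.allclose(tri, ref, atol=1e-2, rtol=1e-1)    # but is within 10% relative slack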

#

Get a function that runs our implementation, for performance measurement

def _perf_triton_fn(*, device, dtype, batch_size, k_heads, n_groups, seq_len, d_head, causal):

#

    q = torch.randn((batch_size, k_heads * n_groups, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
    k = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
    v = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
    sm_scale = d_head ** -0.5
    return lambda: attention(q, k, v, causal, sm_scale)
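One could add PyTorch's built-in `F.scaled_dot_product_attention` as a further baseline; it accepts the same `(batch, heads, seq, d_head)` layout and, in recent PyTorch versions (2.5 and later, to my knowledge), supports grouped KV heads via `enable_gqa=True`. A hypothetical extra benchmark function, not part of this file:

def _perf_sdpa_fn(*, device, dtype, batch_size, k_heads, n_groups, seq_len, d_head, causal):
    import torch.nn.functional as F
    q = torch.randn((batch_size, k_heads * n_groups, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
    k = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
    v = torch.randn((batch_size, k_heads, seq_len, d_head), dtype=dtype, device=device, requires_grad=True)
    return lambda: F.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=d_head ** -0.5, enable_gqa=True)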

#

Get a function that runs the original flash attention implementation, for performance measurement

def _perf_flash(*, batch_size, k_heads, n_groups, seq_len, d_head, causal, device, dtype):

#

    q = torch.randn((batch_size, seq_len, k_heads * n_groups, d_head), dtype=dtype, device=device, requires_grad=True)
    k = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
    v = torch.randn((batch_size, seq_len, k_heads, d_head), dtype=dtype, device=device, requires_grad=True)
    from flash_attn import flash_attn_func
    return lambda: flash_attn_func(q, k, v, causal=causal)
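Note the layout difference: `flash_attn_func` takes `(batch, seq_len, n_heads, d_head)` tensors, whereas our kernel takes `(batch, n_heads, seq_len, d_head)`, which is why the shapes here differ from `_perf_triton_fn`. Converting between the two is a single transpose (illustrative shapes):

q_bhsd = torch.randn((2, 8, 1024, 128), dtype=torch.float16, device='cuda')  # (batch, heads, seq, d_head)
q_bshd = q_bhsd.transpose(1, 2).contiguous()  # (batch, seq, heads, d_head), as flash_attn_func expects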

#

Measure the speed and report throughput in TFLOP/s

def measure_performance(name, fn, *, batch_size, k_heads, n_groups, seq_len, d_head, causal, is_bwd: bool):

#

    if is_bwd:
        o = fn()
        do = torch.randn_like(o)
        fn = lambda: o.backward(do, retain_graph=True)
    ms = triton.testing.do_bench(fn)

    flops_per_matmul = 2.0 * batch_size * k_heads * n_groups * seq_len * seq_len * d_head
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    if is_bwd:
        total_flops *= 2.5  # 2.0 (bwd) + 0.5 (recompute)

    tf_ps = total_flops * 1e-12 / (ms * 1e-3)
    logger.log((f'{name}', logger.Text.key), ': ', f'{ms :,.1f}ms', ' ', f'{tf_ps :,.2f}TFps')
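As a worked example of this accounting, with the benchmark configuration used below (`batch_size=16`, `k_heads=8`, `n_groups=4`, `seq_len=2048`, `d_head=128`):

flops_per_matmul = 2.0 * 16 * 8 * 4 * 2048 * 2048 * 128  # ~0.55e12, for one of QK^T or PV
total_flops = 2 * flops_per_matmul                        # ~1.10e12 for a non-causal forward
# Causal masking computes only half the score matrix (x0.5), and the backward
# pass costs 2.5x the forward (2x for dQ/dK/dV plus 0.5x to recompute the scores).
# At a hypothetical 10 ms, the non-causal forward would report
# 1.10e12 * 1e-12 / (10 * 1e-3) = 110 TFLOP/s.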

#

def main():
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)

    dtype = torch.float16

#

This only works on post-Ampere GPUs right now.

    test_fwd_bwd(1, 4, 1, 2048, 2048, 128, True, dtype=dtype, device=device)
    test_fwd_bwd(16, 32, 8, 2001, 4001, 128, False, dtype=dtype, device=device)
    test_fwd_bwd(4, 32, 8, 2048, 1024, 128, False, dtype=dtype, device=device)
    test_fwd_bwd(4, 32, 8, 2001, 4001, 128, True, dtype=dtype, device=device)

    _conf = {
        'batch_size': 16,
        'k_heads': 8,
        'n_groups': 4,
        'seq_len': 2048,
        'd_head': 128,
    }

    for _causal in [False, True]:
        for is_bwd in [False, True]:
            logger.log(f'{"Causal" if _causal else "Non-causal"} {" Backward" if is_bwd else ""}', logger.Text.title)
            measure_performance('flash', _perf_flash(causal=_causal, device=device, dtype=dtype, **_conf),
                                is_bwd=is_bwd,
                                causal=_causal, **_conf)
            measure_performance('triton', _perf_triton_fn(causal=_causal, device=device, dtype=dtype, **_conf),
                                is_bwd=is_bwd,
                                causal=_causal, **_conf)


if __name__ == "__main__":
    main()
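Since the kernels need Ampere-class hardware, the entry point could be guarded with a device-capability check; a hypothetical sketch, not part of the file:

if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    main()
else:
    print('Skipping: requires an Ampere (compute capability 8.0) or newer GPU')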
