examples/python/CuTeDSL/notebooks/async_pipeline.ipynb
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
This tutorial explores advanced CUDA programming techniques for implementing efficient producer-consumer patterns using asynchronous communication primitives in the CuTe Domain Specific Language (DSL).
A warp is the fundamental execution unit in CUDA, consisting of 32 threads that execute instructions in Single Instruction, Multiple Thread (SIMT) fashion on a Streaming Multiprocessor (SM). Understanding warp-level programming is crucial for achieving optimal GPU performance.
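As a quick, minimal sketch (not one of the tutorial kernels, and assuming the integer arithmetic on thread indices shown here behaves as in the kernels below), the warp and lane of a thread can be derived from its linear thread index; the later kernels instead use the dedicated helper cute.arch.warp_idx():

@cute.kernel
def warp_info_kernel(out: cute.Tensor):
    # Linear thread index within the thread block
    tidx, _, _ = cute.arch.thread_idx()
    # A warp is 32 consecutive threads: derive warp and lane indices
    warp_idx = tidx // 32   # equivalent to cute.arch.warp_idx()
    lane_idx = tidx % 32
    # One write per warp: lane 0 records its warp's index
    if lane_idx == 0:
        out[warp_idx] = warp_idx * 1.0

@cute.jit
def run_warp_info(out: cute.Tensor):
    warp_info_kernel(out).launch(grid=(1, 1, 1), block=(64, 1, 1))

out = torch.zeros((2,), device="cuda")
run_warp_info(from_dlpack(out))
out  # two warps in a 64-thread block, so this should read [0., 1.]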
Key Concepts:
Shared memory serves as a programmer-managed cache with several important characteristics:
The conventional approach for inter-warp communication relies on explicit synchronization barriers. The following sequence diagram illustrates the typical producer-consumer pattern:
sequenceDiagram
participant W0 as Producer Warp
participant SMEM as Shared Memory
participant W1 as Consumer Warp
W0->>SMEM: Write data
critical Synchronization Barrier
W0-->W1: __syncthreads()
SMEM->>W1: Read data
W0-->W1: __syncthreads()
end
Limitations of Synchronous Communication:
@cute.kernel
def synced_producer_consumer(SharedStorage: cutlass.Constexpr, res: cute.Tensor):
    warp_idx = cute.arch.warp_idx()
    warp_idx = cute.arch.make_warp_uniform(warp_idx)

    smem = cutlass.utils.SmemAllocator()
    storage = smem.allocate(SharedStorage, 64)

    staging_smem = storage.staging_buffer.get_tensor(cute.make_layout(1))
    staging_smem.fill(0)
    cute.arch.sync_threads()

    for i in cutlass.range(cute.size(res)):
        if warp_idx == 0:
            staging_smem[0] = i * 1.0
        # Mark entry into the critical region
        cute.arch.sync_threads()
        if warp_idx == 1:
            res[i] = staging_smem[0]
        # Mark exit from the critical region
        cute.arch.sync_threads()
@cute.jit
def run_synced_producer_consumer(res: cute.Tensor):
    @cute.struct
    class SharedStorage:
        staging_buffer: cute.struct.Align[
            cute.struct.MemRange[cutlass.Float32, 1], 1024
        ]

    synced_producer_consumer(SharedStorage, res).launch(
        grid=(1, 1, 1), block=(64, 1, 1), smem=SharedStorage.size_in_bytes()
    )
res = torch.zeros((8,), device="cuda")
run_synced_producer_consumer(from_dlpack(res))
res
The previous example demonstrates traditional synchronized communication between warps. While functional, this approach has significant performance limitations:
Critical Section Analysis:
- First __syncthreads(): Ensures data is written and ready for consumption
- Second __syncthreads(): Guarantees data has been consumed and memory can be safely overwritten

Performance Impact:
Starting with the Hopper architecture, CUDA introduced sophisticated asynchronous communication primitives that enable warp specialization—allowing different warps to perform distinct, specialized roles while maintaining loose coupling.
Key Benefits:
The async pipeline abstraction provides an elegant solution for producer-consumer communication without rigid synchronization constraints:
sequenceDiagram
participant W0 as Producer Warp
participant Pipeline as Async Pipeline
participant SMEM as Shared Memory
participant W1 as Consumer Warp
W0->>Pipeline: Acquire (request write slot)
activate W1
Pipeline-->>W0: Grant access
deactivate W1
W1->>Pipeline: Wait (for data availability)
activate Pipeline
W0->>SMEM: Write data
W0->>Pipeline: Commit (signal data ready)
Pipeline-->>W1: Data available
deactivate Pipeline
activate W0
SMEM->>W1: Read data
deactivate W0
W1->>Pipeline: Release (mark slot available)
Async Pipeline Advantages:
The PipelineAsync abstraction in CuTe DSL provides a comprehensive set of primitives for implementing efficient producer-consumer patterns:
- PipelineProducer.acquire(): Blocks until a write slot becomes available (i.e., has been released by the consumer)
  - PipelineProducer.acquire_and_advance() additionally moves the producer's write index to the next buffer slot
- PipelineProducer.commit(PipelineProducer.ImmutableProducerHandle) / PipelineProducer.ImmutableProducerHandle.commit(): Signals that data has been written to the handle-pointed slot and is ready for consumption
  - Calling commit() without an explicit handle uses the producer's internally maintained handle (pointing to the last slot it acquired)
- PipelineConsumer.wait(): Blocks until data becomes available for reading
  - PipelineConsumer.wait_and_advance() additionally moves the consumer's read index to the next buffer slot
- PipelineConsumer.release(PipelineConsumer.ImmutableConsumerHandle) / PipelineConsumer.ImmutableConsumerHandle.release(): Marks the handle-pointed slot as consumed and available for reuse
  - Calling release() without an explicit handle uses the consumer's internally maintained handle (pointing to the last slot it waited for)

The pipeline APIs provide abstractions for managing synchronization between warps, thread blocks, etc. They do not provide a deadlock-free guarantee; it remains the developer's responsibility to write correct code that avoids deadlock.
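To make that warning concrete, here is a deliberately broken fragment (it reuses the names from the single-stage kernel below and is shown only to illustrate the hazard, not to be run): the consumer waits for one more value than the producer ever commits, so its final wait blocks forever and the kernel hangs.

# DO NOT RUN: mismatched produce/consume counts cause a hang.
if warp_idx == 0:
    for i in cutlass.range(cute.size(res)):        # commits size(res) values
        handle = producer.acquire_and_advance()
        staging_smem[handle.index] = 1.0 * i
        handle.commit()
if warp_idx == 1:
    for i in cutlass.range(cute.size(res) + 1):    # waits for one value too many
        handle = consumer.wait_and_advance()       # the extra wait never returns
        handle.release()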
Computational Overlap: This asynchronous communication pattern enables limited but significant computational overlap:
Memory Efficiency: Explicit slot management ensures optimal memory utilization without unnecessary copying or buffering.
@cute.kernel
def async_pipeline_kernel(res: cute.Tensor):
    warp_idx = cute.arch.warp_idx()
    warp_idx = cute.arch.make_warp_uniform(warp_idx)

    @cute.struct
    class SharedStorage:
        tma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
        staging_buffer: cute.struct.Align[
            cute.struct.MemRange[cutlass.Float32, 1], 1024
        ]

    smem = cutlass.utils.SmemAllocator()
    storage = smem.allocate(SharedStorage, 64)

    # Warp 0
    producer_group = cutlass.pipeline.CooperativeGroup(
        cutlass.pipeline.Agent.Thread, 32
    )
    # Warp 1
    consumer_group = cutlass.pipeline.CooperativeGroup(
        cutlass.pipeline.Agent.Thread, 32
    )
    pipeline = cutlass.pipeline.PipelineAsync.create(
        num_stages=1,
        producer_group=producer_group,
        consumer_group=consumer_group,
        barrier_storage=storage.tma_mbar_ptr.data_ptr(),
    )

    staging_smem = storage.staging_buffer.get_tensor(cute.make_layout(1))
    staging_smem.fill(0)
    cute.arch.sync_threads()

    producer, consumer = pipeline.make_participants()

    # Producer warp
    if warp_idx == 0:
        for i in cutlass.range(cute.size(res)):
            # Producer: Wait until the staging buffer is available for writing
            handle = producer.acquire_and_advance()
            # Producer: Write data to shared memory
            staging_smem[handle.index] = 1.0 * i
            # Producer: Signal that data is ready for consumption
            handle.commit()
        producer.tail()

    # Consumer warp
    if warp_idx == 1:
        for i in cutlass.range(cute.size(res)):
            # Consumer: Wait for the producer to signal that data is available
            handle = consumer.wait_and_advance()
            # Consumer: Consume the data
            res[i] = staging_smem[handle.index]
            # Consumer: Signal that the data buffer is ready to be written again
            handle.release()
@cute.jit
def async_pipeline(res: cute.Tensor):
    # Launch kernel with two warps: producer and consumer
    async_pipeline_kernel(res).launch(grid=(1, 1, 1), block=(64, 1, 1))
res = torch.zeros((8,), device="cuda")
async_pipeline(from_dlpack(res))
res
While async communication provides significant improvements over synchronous patterns, single-stage pipelines still exhibit serialization bottlenecks:
Dependency Chain Analysis:
sequenceDiagram
participant W0 as Producer
participant Pipeline as Pipeline
participant W1 as Consumer
W0->>Pipeline: Acquire
Note over W0,W1: Producer waits here
W1->>Pipeline: Release
Pipeline-->>W0: Granted
Performance Bottleneck: The producer must wait for the consumer to complete processing and release the buffer before acquiring the next write slot. This creates a serialization point that limits overall throughput.
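To put rough numbers on this (an illustrative estimate, not a measurement): if producing one element takes P cycles and consuming it takes C cycles, a single-stage pipeline pays roughly P + C cycles per element, because the two phases cannot overlap on the same slot. With enough stages to keep both warps busy, the steady-state cost per element approaches max(P, C). This is the motivation for the multi-stage design described next.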
The staged async pipeline implements a circular buffer managed by an array of synchronization barriers, enabling much higher degrees of parallelism:
Circular Buffer Management:
- PipelineProducer.advance(): Moves the producer's write index to the next buffer slot
  - PipelineProducer.acquire_and_advance() combines acquire with this advance
- PipelineConsumer.advance(): Moves the consumer's read index to the next buffer slot
  - PipelineConsumer.wait_and_advance() combines wait with this advance
- PipelineProducer.ImmutableResourceHandle.index / PipelineConsumer.ImmutableResourceHandle.index: Returns the index of the buffer slot the handle points to
Legend:
  W: Currently being written (producer active)
  D: Data ready for consumption
  R: Currently being read (consumer active)
  X: Empty slot available for writing

        Advance Direction
      <-----------------
        Producer Consumer
           |         ^
           v         |
      +-----------------+
   .--|X|X|W|D|D|D|D|R|X|<--.
   |  +-----------------+   |
   `------------------------'
Key Advantages:
The following implementation demonstrates efficient multi-stage pipeline communication with proper circular buffer management:
@cute.kernel
def async_pipeline_staged_kernel(
    SharedStorage: cutlass.Constexpr, res: cute.Tensor, staging: cute.Tensor
):
    stages = cute.size(staging)

    warp_idx = cute.arch.warp_idx()
    warp_idx = cute.arch.make_warp_uniform(warp_idx)

    smem = cutlass.utils.SmemAllocator()
    storage = smem.allocate(SharedStorage, 64)

    # Warp 0
    producer_group = cutlass.pipeline.CooperativeGroup(
        cutlass.pipeline.Agent.Thread, 32
    )
    # Warp 1
    consumer_group = cutlass.pipeline.CooperativeGroup(
        cutlass.pipeline.Agent.Thread, 32
    )
    pipeline = cutlass.pipeline.PipelineAsync.create(
        num_stages=stages,
        producer_group=producer_group,
        consumer_group=consumer_group,
        barrier_storage=storage.tma_mbar_ptr.data_ptr(),
    )

    staging_smem = storage.staging_buffer.get_tensor(staging.layout)
    staging_smem.fill(0)
    cute.arch.sync_threads()

    producer, consumer = pipeline.make_participants()

    # Producer warp
    if warp_idx == 0:
        for i in cutlass.range(cute.size(res)):
            handle = producer.acquire_and_advance()
            staging_smem[handle.index] = 1.0 * i
            handle.commit()  # or producer.commit(handle)
        # Prevents CTA0 from retiring until it has received all expected arrive operations.
        producer.tail()

    # Consumer warp
    if warp_idx == 1:
        for i in cutlass.range(cute.size(res)):
            handle = consumer.wait_and_advance()
            res[i] = staging_smem[handle.index]
            handle.release()  # or consumer.release(handle)

    tidx, _, _ = cute.arch.thread_idx()
    if tidx == 0:
        staging.store(staging_smem.load())
@cute.jit
def async_pipeline_staged(res: cute.Tensor, staging: cute.Tensor):
    stages = cute.size(staging)

    @cute.struct
    class SharedStorage:
        tma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, stages * 2]
        staging_buffer: cute.struct.Align[
            cute.struct.MemRange[cutlass.Float32, stages], 1024
        ]

    async_pipeline_staged_kernel(SharedStorage, res, staging).launch(
        grid=(1, 1, 1), block=(64, 1, 1), smem=SharedStorage.size_in_bytes()
    )
res = torch.zeros((8,), device="cuda")
staging = torch.zeros((5,), device="cuda")
async_pipeline_staged(from_dlpack(res), from_dlpack(staging))
torch.cuda.synchronize()
res, staging
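As a sanity check on the circular-buffer behavior (a derivation from the code above, assuming the producer's write index starts at slot 0 and advances by one per commit): the 8 produced values 0 through 7 cycle through the 5 slots, so each slot ends up holding the last value written to it. staging should therefore read [5., 6., 7., 3., 4.], while res again holds [0., 1., ..., 7.].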
In some circumstances, developers may want to check the pipeline state without blocking. This is useful when independent instructions are available to hide the latency of that check. The non-blocking APIs try_acquire and try_wait are provided for this purpose.
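A minimal sketch of how such polling could be structured inside the producer warp (the return convention of try_acquire is an assumption here, not verified against the PipelineAsync API, and do_independent_work is a hypothetical placeholder):

# Sketch only (assumed semantics): try_acquire is assumed to report whether a write
# slot is currently free without blocking; the exact signature may differ in the library.
while not producer.try_acquire():
    do_independent_work()  # hypothetical: independent work that hides the polling latency
# A slot is available; proceed with the usual acquire/commit sequence as in the kernels above.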