
.. _dsl_dynamic_layout:

.. |DSL| replace:: CuTe DSL
.. |SLAY| replace:: static layout
.. |DLAY| replace:: dynamic layout

Static vs Dynamic layouts
=========================

Static Layout
-------------

When integrating with popular deep learning frameworks, one question is how the layout of the converted ``cute.Tensor`` is determined. For example, when a ``torch.Tensor`` is converted to a ``cute.Tensor``, the shape of the ``torch.Tensor`` is honored in the layout of the ``cute.Tensor``.

.. code-block:: python

    import torch
    import cutlass
    import cutlass.cute as cute
    from cutlass.cute.runtime import from_dlpack

    @cute.jit
    def foo(tensor):
        print(f"tensor.layout: {tensor.layout}")  # Prints tensor layout at compile time
        cute.printf("tensor: {}", tensor)         # Prints tensor values at runtime

In this example, we define a JIT function ``foo`` that takes a ``cute.Tensor`` as input, prints its layout, and prints its values. Note that the Python ``print`` runs at compile time; it can print the layout only because a |SLAY| is fully known at compile time. ``cute.printf``, in contrast, runs at runtime.

Now let's run ``foo`` with input ``torch.Tensor`` objects of different shapes.

.. code-block:: python

    a = torch.tensor([1, 2, 3], dtype=torch.uint16)
    a_pack = from_dlpack(a)
    compiled_func = cute.compile(foo, a_pack)
    compiled_func(a_pack)

Here we first convert a 1D ``torch.Tensor`` with 3 elements to a ``cute.Tensor`` using ``from_dlpack``. Then we compile ``foo`` with the converted ``cute.Tensor`` and call the compiled function. The output is:

::

    tensor.layout: (3):(1)
    tensor: raw_ptr(0x00000000079e5100: i16, generic, align<2>) o (3):(1) =
    ( 1, 2, 3 )

It prints ``(3):(1)`` for the layout because the converted ``cute.Tensor`` has a |SLAY| whose shape ``(3)`` is the shape of ``a``.
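To make the ``shape:stride`` notation concrete, here is a minimal plain-Python sketch, assuming a compact row-major tensor such as ``a``. It only reproduces the printed format; ``row_major_strides`` and ``layout_str`` are hypothetical helpers, not part of |DSL|.

.. code-block:: python

    # Hypothetical helpers that only mimic CuTe's printed "shape:stride" format.

    def row_major_strides(shape):
        """Compact row-major strides, e.g. (2, 3) -> (3, 1)."""
        strides = []
        step = 1
        for extent in reversed(shape):
            strides.append(step)
            step *= extent
        return tuple(reversed(strides))

    def _fmt(values):
        """Format a tuple without spaces, e.g. (3, 1) -> "(3,1)"."""
        return "(" + ",".join(str(v) for v in values) + ")"

    def layout_str(shape, stride):
        """Format a layout as shape:stride, e.g. (3):(1)."""
        return f"{_fmt(shape)}:{_fmt(stride)}"

    print(layout_str((3,), row_major_strides((3,))))      # (3):(1)
    print(layout_str((2, 3), row_major_strides((2, 3))))  # (2,3):(3,1)

The 1D tensor ``a`` has a single mode of extent 3 with unit stride, which is exactly the ``(3):(1)`` printed above.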

If we now call the compiled function with an input ``torch.Tensor`` of a different shape, the result at runtime is wrong because of the type mismatch: ``compiled_func`` expects a ``cute.Tensor`` with layout ``(3):(1)``, while ``b`` has shape ``(5)``.

.. code-block:: python

    b = torch.tensor([11, 12, 13, 14, 15], dtype=torch.uint16)
    b_pack = from_dlpack(b)
    compiled_func(b_pack)  # ❌ This results in an unexpected result at runtime due to type mismatch

The output below shows only the first three elements of ``b``, because ``compiled_func`` still assumes the layout ``(3):(1)``:

::

    tensor: raw_ptr(0x00000000344804c0: i16, generic, align<2>) o (3):(1) =
    ( 11, 12, 13 )

To handle ``b`` correctly, we have to trigger another code generation and compilation for its new shape.

.. code-block:: python

    compiled_func_2 = cute.compile(foo, b_pack)  # This triggers another compilation
    compiled_func_2(b_pack)                      # ✅ Now this works fine

With the newly compiled ``compiled_func_2``, passing ``b_pack`` produces the expected output:

::

    tensor.layout: (5):(1)
    tensor: raw_ptr(0x0000000034bb2840: i16, generic, align<2>) o (5):(1) =
    ( 11, 12, 13, 14, 15 )

This time the layout is ``(5):(1)`` and all five values of ``b`` are printed correctly.

In other words, each distinct |SLAY| needs its own generated and compiled code: in this case, one for layout ``(3):(1)`` and another for layout ``(5):(1)``.
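The toy model below, written in plain Python and not based on the |DSL| implementation, mimics this behavior: the 1D extent is baked into each "compiled" function, so a function specialized for ``(3,)`` silently reads only three elements of ``b``, and a second specialization is required. ``compile_static`` and ``compiled_cache`` are hypothetical names used only for illustration.

.. code-block:: python

    # Toy model of static-layout specialization (not the CuTe DSL implementation).
    compiled_cache = {}  # static shape -> specialized function

    def compile_static(shape):
        """Return a function specialized to a fixed 1D shape, 'compiling' on a cache miss."""
        if shape not in compiled_cache:
            (n,) = shape  # the extent is fixed at "compile time"

            def specialized(data):
                # The specialized code trusts n and never inspects len(data).
                return tuple(data[i] for i in range(n))

            compiled_cache[shape] = specialized
        return compiled_cache[shape]

    a = [1, 2, 3]
    b = [11, 12, 13, 14, 15]

    compiled_func = compile_static((3,))
    print(compiled_func(a))     # (1, 2, 3)
    print(compiled_func(b))     # (11, 12, 13): the (3):(1) layout is baked in

    compiled_func_2 = compile_static((5,))
    print(compiled_func_2(b))   # (11, 12, 13, 14, 15)
    print(len(compiled_cache))  # 2 specializations for 2 static layouts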

Dynamic Layout
--------------

To avoid generating and compiling code multiple times for input ``torch.Tensor`` objects of different shapes, |DSL| provides a way to generate and compile a JIT function with a |DLAY|.

To obtain a ``cute.Tensor`` with a |DLAY|, pass the ``torch.Tensor`` object to the JIT function directly. This instructs |DSL| to call ``cute.mark_layout_dynamic`` automatically on the converted ``cute.Tensor``, according to the leading dimension of the layout.

.. code-block:: python

    import torch
    import cutlass
    import cutlass.cute as cute
    from cutlass.cute.runtime import from_dlpack

    @cute.jit
    def foo(tensor):
        print(tensor.layout)  # Prints (?,?):(?,1) for dynamic layout

    a = torch.tensor([[1, 2], [3, 4]], dtype=torch.uint16)
    compiled_func = cute.compile(foo, a)
    compiled_func(a)

    b = torch.tensor([[11, 12], [13, 14], [15, 16]], dtype=torch.uint16)
    compiled_func(b)  # Reuse the same compiled function for a different shape

In the example above, a single compilation of ``foo`` is reused for input tensors of different shapes. This is possible because the converted ``cute.Tensor`` has the |DLAY| ``(?,?):(?,1)``, which is compatible with both the ``(2,2)`` tensor ``a`` and the ``(3,2)`` tensor ``b``.
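A plain-Python sketch of why both calls can share one compilation, assuming (as the printed ``(?,?):(?,1)`` suggests) that every extent becomes dynamic and every stride except that of the leading dimension does too. ``dynamic_signature`` is a hypothetical helper, not the ``cute.mark_layout_dynamic`` implementation.

.. code-block:: python

    # Hypothetical sketch: compute a "dynamic layout signature" for row-major tensors.

    def row_major_strides(shape):
        """Compact row-major strides, e.g. (3, 2) -> (2, 1)."""
        strides, step = [], 1
        for extent in reversed(shape):
            strides.append(step)
            step *= extent
        return tuple(reversed(strides))

    def dynamic_signature(shape, stride, leading_dim):
        """Replace all extents, and all strides but the leading one, with '?'."""
        shp = ",".join("?" for _ in shape)
        std = ",".join(str(s) if i == leading_dim else "?" for i, s in enumerate(stride))
        return f"({shp}):({std})"

    # The leading (stride-1) dimension of a row-major 2D tensor is mode 1.
    sig_a = dynamic_signature((2, 2), row_major_strides((2, 2)), leading_dim=1)
    sig_b = dynamic_signature((3, 2), row_major_strides((3, 2)), leading_dim=1)
    print(sig_a)           # (?,?):(?,1)
    print(sig_a == sig_b)  # True: one compiled function can serve both shapes

Because the signature no longer depends on the concrete extents, it can act as a single cache key for both inputs.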

Alternatively, for a compact layout, ``cute.mark_compact_shape_dynamic`` can be called for finer-grained control: it specifies which mode of the layout is dynamic and the divisibility constraint for that dynamic dimension.
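As a rough illustration of a per-mode dynamic marking with a divisibility constraint, the hypothetical helper below makes one mode of a compact shape dynamic and rejects extents that violate the constraint. Its name, its signature, and the ``?{div=8}`` notation are assumptions for illustration only; see the ``mark_compact_shape_dynamic`` documentation for the real parameters.

.. code-block:: python

    # Hypothetical sketch; NOT cute.mark_compact_shape_dynamic.

    def mark_mode_dynamic(shape, mode, divisibility=1):
        """Return a signature with `mode` dynamic, validating the divisibility constraint."""
        if shape[mode] % divisibility != 0:
            raise ValueError(
                f"extent {shape[mode]} of mode {mode} is not divisible by {divisibility}"
            )
        parts = [
            f"?{{div={divisibility}}}" if i == mode else str(extent)
            for i, extent in enumerate(shape)
        ]
        return "(" + ",".join(parts) + ")"

    print(mark_mode_dynamic((16, 4), mode=0, divisibility=8))  # (?{div=8},4)
    print(mark_mode_dynamic((24, 4), mode=0, divisibility=8))  # (?{div=8},4)

Both shapes map to the same signature, so they could share one compilation, while generated code may rely on mode 0 being a multiple of 8 (for example, for vectorized accesses). A shape such as ``(20, 4)`` would be rejected.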

Refer to :doc:`framework_integration` for more details on ``from_dlpack``, ``mark_layout_dynamic``, and ``mark_compact_shape_dynamic``.

Static Layout vs. Dynamic Layout
--------------------------------

As the previous sections show, a |SLAY| leads to a distinct code generation and compilation for each shape, while a |DLAY| lets a single compilation serve different shapes.

That said, compiling a JIT function with a |SLAY| is useful when the input data has a fixed shape. Because more information is available at compile time, the compiler can apply optimizations that would not be possible for code generated for a |DLAY|, such as evaluating conditions on the tensor size at compile time or fully unrolling loops over it.

On the other hand, a |DLAY| is more flexible when the input data varies in shape: the same generated code handles all of these shapes without recompilation.

Programming with Static and Dynamic Layout
------------------------------------------

|DSL| provides an intuitive way to write a single piece of code that works with both static and dynamic layouts.

.. code-block:: python

    import torch
    import cutlass
    import cutlass.cute as cute
    from cutlass.cute.runtime import from_dlpack

    @cute.jit
    def foo(tensor, x: cutlass.Constexpr[int]):
        print(cute.size(tensor))  # Prints 3 for the 1st call
                                  # Prints ? for the 2nd call
        if cute.size(tensor) > x:
            cute.printf("tensor[2]: {}", tensor[2])
        else:
            cute.printf("tensor size <= {}", x)

    a = torch.tensor([1, 2, 3], dtype=torch.uint16)
    foo(from_dlpack(a), 3)   # First call with static layout

    b = torch.tensor([1, 2, 3, 4, 5], dtype=torch.uint16)
    foo(b, 3)                # Second call with dynamic layout

In this example, ``foo`` is compiled with the |SLAY| ``(3):(1)`` for the first call, which means the size of the tensor is known at compile time. |DSL| takes advantage of this and evaluates the ``if`` condition at compile time, so the generated code contains only the taken branch (here the ``else`` branch, since 3 is not greater than 3) and no runtime check at all.

For the second call, ``foo`` is compiled with the |DLAY| ``(?):(1)``, so the tensor size is known only at runtime. |DSL| automatically generates code that handles the |DLAY| and evaluates the ``if`` condition at runtime.
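The difference between the two calls can be modeled with a toy code generator in plain Python. This is only a sketch of the idea, not how |DSL| is implemented: a Python ``int`` stands for a static size, and the hypothetical ``DYNAMIC`` marker stands for a size known only at runtime.

.. code-block:: python

    # Toy model of compile-time vs. runtime evaluation of an `if` (not CuTe DSL code).
    DYNAMIC = "?"  # hypothetical marker for a value known only at runtime

    def gen_if(size, x):
        """Return the 'instructions' generated for the example's if/else."""
        if size != DYNAMIC:
            # Static size: decide the condition now; no branch survives.
            return ["printf tensor[2]"] if size > x else [f"printf size <= {x}"]
        # Dynamic size: emit a runtime comparison guarding both branches.
        return [f"branch size > {x}", "  printf tensor[2]", "else", f"  printf size <= {x}"]

    print(gen_if(3, 3))        # ['printf size <= 3']
    print(gen_if(DYNAMIC, 3))  # a runtime branch with both arms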

The same applies to loops:

.. code-block:: python

    import torch
    import cutlass
    import cutlass.cute as cute
    from cutlass.cute.runtime import from_dlpack

    @cute.jit
    def foo(tensor, x: cutlass.Constexpr[int]):
        for i in range(cute.size(tensor)):
            cute.printf("tensor[{}]: {}", i, tensor[i])

    a = torch.tensor([1, 2, 3], dtype=torch.uint16)
    foo(from_dlpack(a), 3)   # First call with static layout

    b = torch.tensor([1, 2, 3, 4, 5], dtype=torch.uint16)
    foo(b, 3)                # Second call with dynamic layout

With the |SLAY| in the first call, |DSL| can fully unroll the loop at compile time. In the second call, the generated code keeps the loop, and its trip count is determined at runtime from the |DLAY|.
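A toy code generator in plain Python can model this; it is hypothetical and not the |DSL| code generator. A Python ``int`` stands for a static trip count and the ``DYNAMIC`` marker for one known only at runtime.

.. code-block:: python

    # Toy model of loop unrolling for static vs. dynamic trip counts (not CuTe DSL code).
    DYNAMIC = "?"  # hypothetical marker for a value known only at runtime

    def gen_loop(size):
        """Return the 'instructions' generated for the example's for-loop."""
        if size != DYNAMIC:
            # Static trip count: fully unrolled, one print per element.
            return [f"printf tensor[{i}]" for i in range(size)]
        # Dynamic trip count: a single loop whose bound is read at runtime.
        return ["for i in 0..size:", "  printf tensor[i]"]

    print(gen_loop(3))        # ['printf tensor[0]', 'printf tensor[1]', 'printf tensor[2]']
    print(gen_loop(DYNAMIC))  # ['for i in 0..size:', '  printf tensor[i]']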

With a single JIT function implementation, |DSL| handles control-flow constructs and automatically generates optimized code for each case. This is possible because |DSL| walks the Python AST of the function and converts each control-flow construct it finds accordingly.
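Python's standard ``ast`` module is enough to locate such constructs. The sketch below only finds ``if``, ``for``, and ``while`` statements in a function's source; it illustrates AST walking and is not the |DSL| preprocessor. ``SOURCE`` is a simplified, hypothetical version of ``foo``.

.. code-block:: python

    import ast

    # Simplified, hypothetical stand-in for the JIT function's source.
    SOURCE = '''
    def foo(tensor, x):
        if size(tensor) > x:
            printf(tensor[2])
        else:
            printf(x)
        for i in range(size(tensor)):
            printf(tensor[i])
    '''

    def find_control_flow(source):
        """Return (construct, line) pairs for every if/for/while statement."""
        kinds = {ast.If: "if", ast.For: "for", ast.While: "while"}
        found = []
        for node in ast.walk(ast.parse(source)):
            kind = kinds.get(type(node))
            if kind is not None:
                found.append((kind, node.lineno))
        return sorted(found, key=lambda pair: pair[1])

    print(find_control_flow(SOURCE))  # [('if', 3), ('for', 7)]

A real converter would go further and rewrite each node it finds, deciding whether to evaluate it while tracing or to emit runtime control flow.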

Please refer to :doc:`dsl_control_flow` for more details.