examples/python/CuTeDSL/notebooks/tensorssa.ipynb
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
import numpy as np
This tutorial introduces what is the TensorSSA and why we need it. We also give some examples to show how to use TensorSSA.
TensorSSA is a Python class that represents a tensor value in Static Single Assignment (SSA) form within the CuTe DSL. You can think of it as a tensor residing in a (simulated) register.
TensorSSA encapsulates the underlying MLIR tensor value into an object that's easier to manipulate in Python. By overloading numerous Python operators (like +, -, *, /, [], etc.), it allows users to express tensor computations (primarily element-wise operations and reductions) in a more Pythonic way. These element-wise operations are then translated into optimized vectorization instructions.
It's part of the CuTe DSL, serving as a bridge between the user-described computational logic and the lower-level MLIR IR, particularly for representing and manipulating register-level data.
TensorSSA is primarily used in the following scenarios:
@cute.jit
def load_and_store(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):
"""
Load data from memory and store the result to memory.
:param res: The destination tensor to store the result.
:param a: The source tensor to be loaded.
:param b: The source tensor to be loaded.
"""
a_vec = a.load()
print(f"a_vec: {a_vec}") # prints `a_vec: vector<12xf32> o (3, 4)`
b_vec = b.load()
print(f"b_vec: {b_vec}") # prints `b_vec: vector<12xf32> o (3, 4)`
res.store(a_vec + b_vec)
cute.print_tensor(res)
a = np.ones(12).reshape((3, 4)).astype(np.float32)
b = np.ones(12).reshape((3, 4)).astype(np.float32)
c = np.zeros(12).reshape((3, 4)).astype(np.float32)
load_and_store(from_dlpack(c), from_dlpack(a), from_dlpack(b))
When writing kernel logic, various computations, transformations, slicing, etc., are performed on data loaded into registers.
@cute.jit
def apply_slice(src: cute.Tensor, dst: cute.Tensor, indices: cutlass.Constexpr):
"""
Apply slice operation on the src tensor and store the result to the dst tensor.
:param src: The source tensor to be sliced.
:param dst: The destination tensor to store the result.
:param indices: The indices to slice the source tensor.
"""
src_vec = src.load()
dst_vec = src_vec[indices]
print(f"{src_vec} -> {dst_vec}")
if cutlass.const_expr(isinstance(dst_vec, cute.TensorSSA)):
dst.store(dst_vec)
cute.print_tensor(dst)
else:
dst[0] = dst_vec
cute.print_tensor(dst)
def slice_1():
src_shape = (4, 2, 3)
dst_shape = (4, 3)
indices = (None, 1, None)
"""
a:
[[[ 0. 1. 2.]
[ 3. 4. 5.]]
[[ 6. 7. 8.]
[ 9. 10. 11.]]
[[12. 13. 14.]
[15. 16. 17.]]
[[18. 19. 20.]
[21. 22. 23.]]]
"""
a = np.arange(np.prod(src_shape)).reshape(*src_shape).astype(np.float32)
dst = np.random.randn(*dst_shape).astype(np.float32)
apply_slice(from_dlpack(a), from_dlpack(dst), indices)
slice_1()
def slice_2():
src_shape = (4, 2, 3)
dst_shape = (1,)
indices = 10
a = np.arange(np.prod(src_shape)).reshape(*src_shape).astype(np.float32)
dst = np.random.randn(*dst_shape).astype(np.float32)
apply_slice(from_dlpack(a), from_dlpack(dst), indices)
slice_2()
As we mentioned earlier, there're many tensor operations whose operands are TensorSSA. And they are all element-wise operations. We give some examples below.
For binary operations, the LHS operand is TensorSSA and the RHS operand can be either TensorSSA or Numeric. When the RHS is Numeric, it will be broadcast to a TensorSSA.
@cute.jit
def binary_op_1(a: cute.Tensor, b: cute.Tensor):
a_vec = a.load()
b_vec = b.load()
add_res = a_vec + b_vec
cute.print_tensor(add_res) # prints [3.000000, 3.000000, 3.000000]
sub_res = a_vec - b_vec
cute.print_tensor(sub_res) # prints [-1.000000, -1.000000, -1.000000]
mul_res = a_vec * b_vec
cute.print_tensor(mul_res) # prints [2.000000, 2.000000, 2.000000]
div_res = a_vec / b_vec
cute.print_tensor(div_res) # prints [0.500000, 0.500000, 0.500000]
floor_div_res = a_vec // b_vec
cute.print_tensor(floor_div_res) # prints [0.000000, 0.000000, 0.000000]
mod_res = a_vec % b_vec
cute.print_tensor(mod_res) # prints [1.000000, 1.000000, 1.000000]
a = np.empty((3,), dtype=np.float32)
a.fill(1.0)
b = np.empty((3,), dtype=np.float32)
b.fill(2.0)
binary_op_1(from_dlpack(a), from_dlpack(b))
@cute.jit
def binary_op_2(a: cute.Tensor, c: cutlass.Constexpr):
a_vec = a.load()
add_res = a_vec + c
cute.print_tensor(add_res) # prints [3.000000, 3.000000, 3.000000]
sub_res = a_vec - c
cute.print_tensor(sub_res) # prints [-1.000000, -1.000000, -1.000000]
mul_res = a_vec * c
cute.print_tensor(mul_res) # prints [2.000000, 2.000000, 2.000000]
div_res = a_vec / c
cute.print_tensor(div_res) # prints [0.500000, 0.500000, 0.500000]
floor_div_res = a_vec // c
cute.print_tensor(floor_div_res) # prints [0.000000, 0.000000, 0.000000]
mod_res = a_vec % c
cute.print_tensor(mod_res) # prints [1.000000, 1.000000, 1.000000]
a = np.empty((3,), dtype=np.float32)
a.fill(1.0)
c = 2.0
binary_op_2(from_dlpack(a), c)
@cute.jit
def binary_op_3(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):
a_vec = a.load()
b_vec = b.load()
gt_res = a_vec > b_vec
res.store(gt_res)
"""
ge_res = a_ >= b_ # [False, True, False]
lt_res = a_ < b_ # [True, False, True]
le_res = a_ <= b_ # [True, False, True]
eq_res = a_ == b_ # [False, False, False]
"""
a = np.array([1, 2, 3], dtype=np.float32)
b = np.array([2, 1, 4], dtype=np.float32)
res = np.empty((3,), dtype=np.bool_)
binary_op_3(from_dlpack(res), from_dlpack(a), from_dlpack(b))
print(res) # prints [False, True, False]
@cute.jit
def binary_op_4(res: cute.Tensor, a: cute.Tensor, b: cute.Tensor):
a_vec = a.load()
b_vec = b.load()
xor_res = a_vec ^ b_vec
res.store(xor_res)
# or_res = a_vec | b_vec
# res.store(or_res) # prints [3, 2, 7]
# and_res = a_vec & b_vec
# res.store(and_res) # prints [0, 2, 0]
a = np.array([1, 2, 3], dtype=np.int32)
b = np.array([2, 2, 4], dtype=np.int32)
res = np.empty((3,), dtype=np.int32)
binary_op_4(from_dlpack(res), from_dlpack(a), from_dlpack(b))
print(res) # prints [3, 0, 7]
@cute.jit
def unary_op_1(res: cute.Tensor, a: cute.Tensor):
a_vec = a.load()
sqrt_res = cute.math.sqrt(a_vec)
cute.print_tensor(sqrt_res) # prints [2.000000, 2.000000, 2.000000]
sin_res = cute.math.sin(a_vec)
res.store(sin_res)
cute.print_tensor(sin_res) # prints [-0.756802, -0.756802, -0.756802]
exp2_res = cute.math.exp2(a_vec)
cute.print_tensor(exp2_res) # prints [16.000000, 16.000000, 16.000000]
a = np.array([4.0, 4.0, 4.0], dtype=np.float32)
res = np.empty((3,), dtype=np.float32)
unary_op_1(from_dlpack(res), from_dlpack(a))
The TensorSSA's reduce method applies a specified reduction operation (ReductionOp.ADD,
ReductionOp.MUL, ReductionOp.MAX, ReductionOp.MIN) starting with an initial value, and
performs this reduction along the dimensions specified by the reduction_profile. The result
is typically a new TensorSSA with reduced dimensions or a scalar value if it reduces across
all axes.
@cute.jit
def reduction_op(a: cute.Tensor):
"""
Apply reduction operation on the src tensor.
:param src: The source tensor to be reduced.
"""
a_vec = a.load()
red_res = a_vec.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile=0)
cute.printf(red_res) # prints 21.000000
red_res = a_vec.reduce(cute.ReductionOp.ADD, 0.0, reduction_profile=(None, 1))
cute.print_tensor(red_res) # prints [6.000000, 15.000000]
red_res = a_vec.reduce(cute.ReductionOp.ADD, 1.0, reduction_profile=(1, None))
cute.print_tensor(red_res) # prints [6.000000, 8.000000, 10.000000]
a = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
reduction_op(from_dlpack(a))
TensorSSA supports broadcasting operations following NumPy's broadcasting rules. Broadcasting
allows you to perform operations on arrays of different shapes when certain conditions are met.
The key rules are:
Let's look at some examples of broadcasting in action:
import cutlass
import cutlass.cute as cute
@cute.jit
def broadcast_examples():
a = cute.make_rmem_tensor((1, 3), dtype=cutlass.Float32)
a[0] = 0.0
a[1] = 1.0
a[2] = 2.0
a_val = a.load()
cute.print_tensor(a_val.broadcast_to((4, 3)))
# tensor(raw_ptr(0x00007ffe26625740: f32, rmem, align<32>) o (4,3):(1,4), data=
# [[ 0.000000, 1.000000, 2.000000, ],
# [ 0.000000, 1.000000, 2.000000, ],
# [ 0.000000, 1.000000, 2.000000, ],
# [ 0.000000, 1.000000, 2.000000, ]])
c = cute.make_rmem_tensor((4, 1), dtype=cutlass.Float32)
c[0] = 0.0
c[1] = 1.0
c[2] = 2.0
c[3] = 3.0
cute.print_tensor(a.load() + c.load())
# tensor(raw_ptr(0x00007ffe26625780: f32, rmem, align<32>) o (4,3):(1,4), data=
# [[ 0.000000, 1.000000, 2.000000, ],
# [ 1.000000, 2.000000, 3.000000, ],
# [ 2.000000, 3.000000, 4.000000, ],
# [ 3.000000, 4.000000, 5.000000, ]])
broadcast_examples()
The examples above demonstrate two key broadcasting scenarios:
Row Vector Broadcasting: In the first example, we create a row vector a with shape
(1, 3) containing values [0.0, 1.0, 2.0]. When we broadcast it to shape (4, 3), the values
are repeated across the first dimension, resulting in:
[[0.0, 1.0, 2.0],
[0.0, 1.0, 2.0],
[0.0, 1.0, 2.0],
[0.0, 1.0, 2.0]]
This demonstrates how a row vector can be broadcast to create multiple identical rows.
Column Vector and Row Vector Addition: In the second example, we have:
a with shape (1, 3) containing [0.0, 1.0, 2.0]c with shape (4, 1) containing [0.0, 1.0, 2.0, 3.0]When we add these together, both vectors are broadcast to shape (4, 3):
The result is:
[[0.0 + 0.0, 1.0 + 0.0, 2.0 + 0.0],
[0.0 + 1.0, 1.0 + 1.0, 2.0 + 1.0],
[0.0 + 2.0, 1.0 + 2.0, 2.0 + 2.0],
[0.0 + 3.0, 1.0 + 3.0, 2.0 + 3.0]]
=
[[0.0, 1.0, 2.0],
[1.0, 2.0, 3.0],
[2.0, 3.0, 4.0],
[3.0, 4.0, 5.0]]
This demonstrates how TensorSSA can automatically handle broadcasting of both row and column
vectors in arithmetic operations, following the broadcasting rules where each dimension must
either be 1 or match the target size. The broadcasting is handled implicitly during operations,
making it easy to work with tensors of different shapes.