This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues.
We first import the packages needed for the example and construct the input and output tensors.
[1]:
import numpy as np

import cutlass

# This controls whether the C++ GEMM declaration will be printed at each step.
# Set to `False` to omit this information.
print_module = True

m = 256
n = m
k = m

type_A = np.float16
type_B = np.float16
type_C = np.float16
type_D = np.float16

np.random.seed(1234)
scope_min = -4
scope_max = 4
tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))
tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))
tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))

alpha = np.float16(1.)
beta = np.float16(0.)

tensor_D = np.zeros(tensor_C.shape).astype(type_D)
To begin, we run a default GEMM with an identity activation function. This performs the well-known operation D = alpha * (A @ B) + beta * C. The identity activation is the default, so it does not need to be specified.
[2]:
plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)
plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
cutlass::arch::OpMultiplyAdd
>::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
[2]:
<cutlass.backend.gemm_operation.GemmArguments2x at 0x7fed907287c0>
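For reference, the default linear combination can be reproduced on the host with NumPy. The sketch below is not part of the CUTLASS interface: `reference_gemm` is a hypothetical helper, the inputs are 16x16 rather than 256x256 (an assumption made here so that every intermediate value is a small integer that float16 represents exactly), and accumulation is done in float32.

```python
import numpy as np


def reference_gemm(A, B, C, alpha=1.0, beta=0.0):
    """Hypothetical host-side reference: D = alpha * (A @ B) + beta * C."""
    # Accumulate in float32, then cast back to the output element type.
    acc = alpha * (A.astype(np.float32) @ B.astype(np.float32))
    acc = acc + beta * C.astype(np.float32)
    return acc.astype(C.dtype)


rng = np.random.default_rng(1234)
size = 16  # smaller than the notebook's 256, purely for illustration

# Same recipe as the notebook: uniform values rounded up to integers.
A = np.ceil(rng.uniform(-4, 4, size=(size, size))).astype(np.float16)
B = np.ceil(rng.uniform(-4, 4, size=(size, size))).astype(np.float16)
C = np.ceil(rng.uniform(-4, 4, size=(size, size))).astype(np.float16)

# With beta = 0 the C operand does not contribute to the result.
D_ref = reference_gemm(A, B, C, alpha=1.0, beta=0.0)
# With beta = 1 the C operand is added to the matrix product.
D_ref_beta = reference_gemm(A, B, C, alpha=1.0, beta=1.0)
```

A result computed this way could be compared against `tensor_D` with `np.testing.assert_array_equal`, provided the reference is built from the same input tensors that were passed to `plan.run`.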
CUTLASS makes it easy to support other element-wise activation functions. The activation is applied as an element-wise operation after the generic linear combination performed in a GEMM. If we call such an activation function act, the resulting formulation is:
D = alpha * (A @ B) + beta * C
D = act(D)
Here, we will add a ReLU activation function. Given an input x, ReLU returns max(x, 0).
This is easy to do in CUTLASS. One only needs to set the plan’s activation field.
[3]:
tensor_D_relu = np.zeros(tensor_C.shape).astype(type_D)
plan.activation = cutlass.epilogue.relu
plan.run(tensor_A, tensor_B, tensor_C, tensor_D_relu, print_module=print_module)
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::ReLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
cutlass::arch::OpMultiplyAdd
>::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
[3]:
<cutlass.backend.gemm_operation.GemmArguments2x at 0x7fed906f2460>
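The two-step formulation D = act(alpha * (A @ B) + beta * C) can be sketched in NumPy as follows. This is a host-side illustration only, under the assumption that the activation is applied to the full linear combination; `gemm_with_activation` and `relu` are hypothetical helper names, not CUTLASS functions.

```python
import numpy as np


def relu(x):
    # ReLU returns max(x, 0) element-wise.
    return np.maximum(x, 0)


def gemm_with_activation(A, B, C, alpha, beta, act):
    """Hypothetical reference: D = act(alpha * (A @ B) + beta * C)."""
    # Step 1: the generic linear combination, accumulated in float32.
    D = alpha * (A.astype(np.float32) @ B.astype(np.float32))
    D = D + beta * C.astype(np.float32)
    # Step 2: the element-wise activation, then cast to the output type.
    return act(D).astype(C.dtype)


A = np.array([[1, -2], [3, 4]], dtype=np.float16)
B = np.array([[1, 0], [0, 1]], dtype=np.float16)  # identity, so A @ B == A
C = np.zeros((2, 2), dtype=np.float16)

D_relu = gemm_with_activation(A, B, C, alpha=1.0, beta=0.0, act=relu)
print(D_relu)  # the negative entry of A is clamped to zero
```

Passing the activation as a function argument mirrors how the plan's activation field swaps the epilogue without changing the rest of the GEMM.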
We can now verify the result of the GEMM that used a ReLU activation function:
[4]:
relu_ref = (tensor_D >= 0).astype(type_D) * tensor_D
np.testing.assert_array_equal(relu_ref, tensor_D_relu)
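The check above builds the ReLU reference by multiplying the data with a 0/1 mask. A small standalone sketch, using made-up sample values, shows that this agrees with the direct max(x, 0) form under `np.testing.assert_array_equal`, even though the masked form produces negative zeros for negative inputs.

```python
import numpy as np

# Illustrative sample values, not taken from the notebook's tensors.
x = np.array([-3.0, -0.5, 0.0, 2.0, 4.0], dtype=np.float16)

# Mask-and-multiply form, as used in the verification cell.
relu_mask = (x >= 0).astype(np.float16) * x

# Direct form: element-wise max(x, 0).
relu_max = np.maximum(x, np.float16(0))

# Negative inputs give 0 * x = -0.0 in the masked form. The comparison
# treats -0.0 and 0.0 as equal, so the two forms still match.
np.testing.assert_array_equal(relu_mask, relu_max)
```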
CUTLASS supports a variety of widely used element-wise activation functions. We can obtain a list of these functions via the plan's activations() method.
[5]:
activations = plan.activations()
for activation in activations:
    print(activation)
<class 'cutlass.backend.epilogue.gelu'>
<class 'cutlass.backend.epilogue.hardswish'>
<class 'cutlass.backend.epilogue.identity'>
<class 'cutlass.backend.epilogue.leaky_relu'>
<class 'cutlass.backend.epilogue.relu'>
<class 'cutlass.backend.epilogue.sigmoid'>
<class 'cutlass.backend.epilogue.silu'>
<class 'cutlass.backend.epilogue.tanh'>
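As a reminder of what the listed names usually denote, the sketch below writes out common textbook definitions as scalar Python functions. These are assumptions for illustration, not the CUTLASS implementations: the kernels may use approximations (for example, a different GELU formulation), and the `negative_slope` parameter and its default for leaky ReLU are hypothetical here.

```python
import math


def identity(x):
    return x


def relu(x):
    return max(x, 0.0)


def leaky_relu(x, negative_slope=0.01):
    # negative_slope is an assumed parameter name and default.
    return x if x >= 0 else negative_slope * x


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def silu(x):
    # SiLU (also called swish): x * sigmoid(x).
    return x * sigmoid(x)


def tanh(x):
    return math.tanh(x)


def hardswish(x):
    # x * ReLU6(x + 3) / 6
    return x * min(max(x + 3.0, 0.0), 6.0) / 6.0


def gelu(x):
    # Exact (erf-based) GELU.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


# Keys follow the class names printed in the listing above.
reference_activations = {
    "gelu": gelu,
    "hardswish": hardswish,
    "identity": identity,
    "leaky_relu": leaky_relu,
    "relu": relu,
    "sigmoid": sigmoid,
    "silu": silu,
    "tanh": tanh,
}

for name, fn in reference_activations.items():
    print(f"{name}(1.0) = {fn(1.0):.4f}")
```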
We can then run each of them:
[6]:
for activation in activations:
    print('=============================================================================================')
    print(f'Compiling and running activation {activation}')
    print('=============================================================================================')
    plan.activation = activation
    plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)
=============================================================================================
Compiling and running activation <class 'cutlass.backend.epilogue.gelu'>
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
cutlass::arch::OpMultiplyAdd
>::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
=============================================================================================
Compiling and running activation <class 'cutlass.backend.epilogue.hardswish'>
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::HardSwish, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
cutlass::arch::OpMultiplyAdd
>::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
=============================================================================================
Compiling and running activation <class 'cutlass.backend.epilogue.identity'>
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
cutlass::arch::OpMultiplyAdd
>::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
=============================================================================================
Compiling and running activation <class 'cutlass.backend.epilogue.leaky_relu'>
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::LeakyReLU, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
cutlass::arch::OpMultiplyAdd
>::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
=============================================================================================
Compiling and running activation <class 'cutlass.backend.epilogue.relu'>
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::ReLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
cutlass::arch::OpMultiplyAdd
>::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
=============================================================================================
Compiling and running activation <class 'cutlass.backend.epilogue.sigmoid'>
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::Sigmoid, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
cutlass::arch::OpMultiplyAdd
>::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
=============================================================================================
Compiling and running activation <class 'cutlass.backend.epilogue.silu'>
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::SiLu, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
cutlass::arch::OpMultiplyAdd
>::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
=============================================================================================
Compiling and running activation <class 'cutlass.backend.epilogue.tanh'>
=============================================================================================
// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::Tanh, cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
3,
cutlass::arch::OpMultiplyAdd
>::GemmKernel;
// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_type :
public cutlass_sm80_tensorop_h16x8x16gemm_1x1x1_256x128_64x3_tt_align8_base { };
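Every kernel printed above shares the same tile configuration and differs only in its epilogue functor. The arithmetic below sketches what that configuration implies for the 256x256x256 problem in this notebook, reading the three `GemmShape` arguments as threadblock, warp, and instruction shapes in that order. The numbers describe tiling only; the actual launch grid also depends on the threadblock swizzle and any split-K settings, which are not modeled here.

```python
import math

# Problem size from the notebook (m = n = k = 256).
m, n, k = 256, 256, 256

# Shapes copied from the printed kernel declarations.
threadblock_shape = (256, 128, 64)  # GemmShape<256, 128, 64>
warp_shape = (64, 64, 64)           # GemmShape<64, 64, 64>
instruction_shape = (16, 8, 16)     # GemmShape<16, 8, 16>

# Number of output tiles needed to cover the M x N result.
tiles_m = math.ceil(m / threadblock_shape[0])
tiles_n = math.ceil(n / threadblock_shape[1])
threadblock_tiles = tiles_m * tiles_n

# Warps per threadblock: threadblock shape divided by warp shape.
warps_per_threadblock = math.prod(
    tb // w for tb, w in zip(threadblock_shape, warp_shape)
)
threads_per_threadblock = warps_per_threadblock * 32  # 32 threads per warp

# Main-loop iterations along K for each threadblock.
k_iterations = math.ceil(k / threadblock_shape[2])

# Tensor-core instructions per warp tile per K slice.
mma_per_warp_tile = math.prod(
    w // i for w, i in zip(warp_shape, instruction_shape)
)

print(threadblock_tiles, warps_per_threadblock, threads_per_threadblock)
print(k_iterations, mma_per_warp_tile)
```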