Back to Cutlass

Source code for cutlass.library_defaults

python/docs/_modules/cutlass/library_defaults.html

4.5.216.2 KB
Original Source

Source code for cutlass.library_defaults

#################################################################################################
#
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Classes containing valid operations for a given compute capability and data types.

The CUTLASS profiler's kernel generator is run once per compute capability. The resulting
GEMM kernel descriptions are filtered, then indexed by opcode class, data-type combination,
layout combination and alignment.
"""

import logging

from cuda import __version__

# Strip any additional information from the CUDA version.
# Everything from the first "rc" onward is dropped (e.g. presumably "12.1.0rc1" -> "12.1.0";
# TODO confirm the exact format of cuda.__version__). The result is handed to the
# profiler's GenerateSM<cc> functions when building kernel manifests.
_cuda_version = __version__.split("rc")[0]

# Imports from CUTLASS profiler generator and manifest scripts.
# NOTE(review): these are imported as top-level modules, so the profiler script directory
# is presumably placed on sys.path elsewhere -- confirm.
import generator as prof_generator
import manifest as prof_manifest

import cutlass
from cutlass.utils.check import valid_stage_count
from cutlass.utils.datatypes import td_from_profiler_td, td_from_profiler_op, has_binding_type

# Kernel compute capabilities for which OptionRegistry builds an ArchOptions entry.
# Each value is looked up as prof_generator.GenerateSM<cc>. If that function is missing,
# ArchOptions logs a warning and registers no kernels for that CC.
_generator_ccs = [50, 60, 61, 70, 75, 80, 90]
class KernelsForDataType:
    """
    Container class for keeping track of kernels that correspond to a particular combination
    of data types for operands A, B, and accumulator.

    Kernels are bucketed by the alignment of their A operand.

    :param datatype_comb: data types of (element_A, element_B, element_accumulator)
    :type datatype_comb: tuple
    :param layout_comb: layouts of (layout_A, layout_B)
    :type layout_comb: tuple
    """

    def __init__(self, datatype_comb: tuple, layout_comb: tuple):
        self.datatype_comb = datatype_comb
        self.layout_comb = layout_comb

        # Dictionary mapping from alignment (int) to a list of kernels that fit the alignment
        # constraint for the data type combination
        self.kernels_by_alignment = {}

    def add(self, operation):
        """
        Add an operation to the list of supported kernels.

        The operation is filed under the alignment of its A operand (``operation.A.alignment``).

        :param operation: operation to register
        """
        self.kernels_by_alignment.setdefault(operation.A.alignment, []).append(operation)

    @property
    def alignments(self):
        """
        Returns an unsorted list of alignments supported by this data type combination

        :return: unsorted list of alignments supported by this data type combination
        :rtype: list
        """
        return list(self.kernels_by_alignment.keys())

    @property
    def all_operations(self):
        """
        Returns a list of all operations supported by this data type combination

        :return: list of all operations supported by this data type combination
        :rtype: list
        """
        return [op for alignment_ops in self.kernels_by_alignment.values() for op in alignment_ops]

    def operations(self, alignment: int):
        """
        Returns operations satisfying the alignment constraint indicated by `alignment`

        :param alignment: alignment constraint of operations to return
        :type alignment: int

        :return: list of operations
        :rtype: list

        :raises Exception: if no operations with the requested alignment were registered
        """
        if alignment not in self.kernels_by_alignment:
            raise Exception(
                f"No operations of alignment {alignment} found for data type and layout "
                f"combination {self.datatype_comb} {self.layout_comb}")
        return self.kernels_by_alignment[alignment]

    def find_alignment(self, shape: tuple, layout: cutlass.LayoutType) -> int:
        """
        Returns the most preferable alignment for a given shape and layout

        The largest registered alignment that evenly divides the tensor's leading
        (contiguous) dimension is chosen.

        :param shape: extent of each dimension of the tensor, as (rows, columns)
        :type shape: tuple
        :param layout: layout of the tensor
        :type layout: cutlass.LayoutType

        :return: maximum alignment supported by the data type combination and tensor size
        :rtype: int

        :raises Exception: if `layout` is neither row- nor column-major
        """
        # The leading dimension is the contiguous extent of the tensor. For column-major
        # storage that is the number of rows; for row-major storage it is the number of columns.
        # (Previously both branches tested RowMajor, so ColumnMajor always raised and
        # RowMajor used the row count.)
        if layout == cutlass.LayoutType.ColumnMajor:
            ld = shape[0]
        elif layout == cutlass.LayoutType.RowMajor:
            ld = shape[1]
        else:
            raise Exception(f"Unexpected or unsupported layout {layout}")

        for alignment in sorted(self.kernels_by_alignment, reverse=True):
            if ld % alignment == 0:
                return alignment

        # Default to alignment of 1 if no others match
        return 1

    def sort(self):
        """
        Sorts each list of kernels in `kernels_by_alignment` in descending order of threadblock shape
        (ordered by the product of the threadblock's M, N and K extents)
        """
        key = lambda op: (
            op.tile_description.threadblock_shape[0]
            * op.tile_description.threadblock_shape[1]
            * op.tile_description.threadblock_shape[2]
        )
        for alignment_ops in self.kernels_by_alignment.values():
            alignment_ops.sort(key=key, reverse=True)

class ArchOptions:
    """
    Structure for keeping track of kernels available on a given compute capability

    :param target_cc: compute capability of the device on which kernels will be run
    :type target_cc: int
    :param kernel_cc: compute capability of the kernels to generate
    :type kernel_cc: int
    :param operation_kind: type of operation to register
    :type operation_kind: cutlass.OperationKind
    :param gemm_kinds: types of GEMM operations that can be included
    :type gemm_kinds: list
    :param allowed_math_operations: types of primitive math operations allowed. When omitted,
        defaults to ``[cutlass.MathOperation.multiply_add, cutlass.MathOperation.multiply_add_saturate]``
    :type allowed_math_operations: list
    """

    def __init__(
        self,
        target_cc: int,
        kernel_cc: int,
        operation_kind: cutlass.OperationKind,
        gemm_kinds: list,
        allowed_math_operations: list = None,
    ):
        # Build the default list per call. A list literal in the signature would be a single
        # object shared by every ArchOptions instance.
        if allowed_math_operations is None:
            allowed_math_operations = [
                cutlass.MathOperation.multiply_add,
                cutlass.MathOperation.multiply_add_saturate,
            ]

        self.cc = kernel_cc

        # Dictionary with following structure:
        #  Key: OpcodeClass
        #  Value: Dictionary with the following structure:
        #      Key: tuple of ((DataType, DataType, DataType), (LayoutType, LayoutType)),
        #           representing ((element_a, element_b, element_accumulator), (layout_a, layout_b))
        #      Value: KernelsForDataType
        self.operations_by_opclass = {}
        self.op_class = None
        self.allowed_math_operations = allowed_math_operations

        manifest = self._generate_manifest(kernel_cc)
        if manifest is None:
            return

        if operation_kind not in manifest.operations:
            # No kernels generated for this architecture, this could be because the CUDA
            # toolkit is insufficient to support operations in this CC
            cutlass.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}")
            return

        self._register_generated_operations(manifest.operations[operation_kind], target_cc, gemm_kinds)

        # Set the default opclass to TensorOp, if available. Otherwise default to SIMT
        if cutlass.OpcodeClass.TensorOp in self.operations_by_opclass:
            self.op_class = cutlass.OpcodeClass.TensorOp
        else:
            self.op_class = cutlass.OpcodeClass.Simt

        self._add_generic_simt_operations(target_cc)

        # Sort all operations
        for kernels_by_comb in self.operations_by_opclass.values():
            for kernels in kernels_by_comb.values():
                kernels.sort()

    @staticmethod
    def _generate_manifest(kernel_cc: int):
        """
        Run the profiler generator for `kernel_cc` into a fresh default manifest.

        :param kernel_cc: compute capability whose ``GenerateSM<kernel_cc>`` function to run
        :type kernel_cc: int

        :return: populated manifest, or None (after logging a warning) if the generator
            script has no function for this compute capability
        """
        # Identify the method within CUTLASS generator script that generates kernel
        # descriptions for the target CC
        generate_function_name = "GenerateSM" + str(kernel_cc)
        if not hasattr(prof_generator, generate_function_name):
            cutlass.logger.warning(f"No generator found for architecture {kernel_cc}")
            return None
        generate_function = getattr(prof_generator, generate_function_name)

        # Initialize a default manifest and populate it with valid kernel descriptions
        # for the target CC
        args = ["--kernels=all", f"--log-level={logging.getLevelName(cutlass.logger.level)}"]
        manifest_args = prof_generator.define_parser().parse_args(args)
        manifest = prof_manifest.Manifest(manifest_args)
        generate_function(manifest, _cuda_version)
        return manifest

    def _register_generated_operations(self, operations_by_name: dict, target_cc: int, gemm_kinds: list):
        """
        Filter generated operations and index the survivors by opcode class and
        (data type, layout) combination in `self.operations_by_opclass`.

        :param operations_by_name: manifest mapping from name to a list of operations
        :type operations_by_name: dict
        :param target_cc: compute capability of the device on which kernels will be run
        :type target_cc: int
        :param gemm_kinds: types of GEMM operations that can be included
        :type gemm_kinds: list
        """
        for op_list in operations_by_name.values():
            for op in op_list:
                if op.gemm_kind not in gemm_kinds:
                    continue

                mi = op.tile_description.math_instruction
                if mi.math_operation not in self.allowed_math_operations:
                    continue

                datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator)

                # Skip any data types that do not currently have conversions via cutlass_bindings
                if not all(has_binding_type(elt) for elt in datatype_comb):
                    continue

                # Prune operations that don't fit in shared memory
                td = td_from_profiler_op(op)
                if not valid_stage_count(target_cc, td)[0]:
                    continue

                opclass_dict = self.operations_by_opclass.setdefault(mi.opcode_class, {})
                layout_comb = (op.A.layout, op.B.layout)

                # Register TF32 kernels as F32 to enable F32 -> TF32 conversion + TF32 Tensor Core operations
                if datatype_comb == (cutlass.DataType.tf32, cutlass.DataType.tf32, cutlass.DataType.f32):
                    # TF32 kernels only supported on SM80 and beyond
                    if self.cc < 80:
                        continue
                    elif self.cc == 90:
                        # On SM90, keep only kernels whose A, B and C tensors are stored as F32
                        if (op.A.element != cutlass.DataType.f32
                                or op.B.element != cutlass.DataType.f32
                                or op.C.element != cutlass.DataType.f32):
                            continue

                    datatype_comb = (cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32)

                key = (datatype_comb, layout_comb)
                if key not in opclass_dict:
                    opclass_dict[key] = KernelsForDataType(datatype_comb, layout_comb)
                opclass_dict[key].add(op)

    def _add_generic_simt_operations(self, target_cc: int):
        """
        Register generic alignment-1 SIMT GEMMs for common data-type/layout combinations
        that the profiler generator did not already provide.

        :param target_cc: compute capability of the device on which kernels will be run
        :type target_cc: int
        """
        # The profiler's generator may generate only a limited set of combinations of operands for SIMT kernels.
        # Here, we generate additional versions via a generic TileDescription.
        simt_ops = self.operations_by_opclass.setdefault(cutlass.OpcodeClass.Simt, {})

        types = [
            (cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s8),
            (cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s32),
            (cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16),
            (cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32),
            (cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32),
            (cutlass.DataType.f64, cutlass.DataType.f64, cutlass.DataType.f64),
        ]

        layouts = [
            (cutlass.LayoutType.RowMajor, cutlass.LayoutType.RowMajor),
            (cutlass.LayoutType.RowMajor, cutlass.LayoutType.ColumnMajor),
            (cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.RowMajor),
            (cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.ColumnMajor),
        ]

        alignment = 1
        epilogue_functor = cutlass.EpilogueFunctor.LinearCombination
        swizzling_functor = cutlass.SwizzlingFunctor.Identity8
        for type_comb in types:
            for layout_comb in layouts:
                comb = (type_comb, layout_comb)
                if comb in simt_ops:
                    continue

                A = cutlass.TensorDescription(type_comb[0], layout_comb[0], alignment)
                B = cutlass.TensorDescription(type_comb[1], layout_comb[1], alignment)
                C = cutlass.TensorDescription(type_comb[2], cutlass.LayoutType.ColumnMajor, alignment)
                math_inst = cutlass.MathInstruction(
                    [1, 1, 1],
                    type_comb[0],
                    type_comb[1],
                    type_comb[2],
                    cutlass.OpcodeClass.Simt,
                    cutlass.MathOperation.multiply_add,
                )

                td = cutlass.TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, 50, 1024)

                # Prune operations that don't fit in shared memory
                if not valid_stage_count(target_cc, td_from_profiler_td(td))[0]:
                    continue

                new_operation = prof_manifest.GemmOperation(
                    cutlass.GemmKind.Universal, td.minimum_compute_capability,
                    td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor)

                new_kernels = KernelsForDataType(type_comb, layout_comb)
                new_kernels.add(new_operation)
                simt_ops[comb] = new_kernels

    def opclass_supports_combination(
        self, op_class: cutlass.OpcodeClass, datatype_comb: tuple, layout_comb: tuple
    ) -> bool:
        """
        Returns whether the provided operation class supports the provided data type and layout combination

        :param op_class: operation class to consider
        :type op_class: cutlass.OpcodeClass
        :param datatype_comb: tuple of data types for (element_A, element_B, element_accumulator)
        :type datatype_comb: tuple[cutlass.DataType]
        :param layout_comb: tuple of layouts for (layout_A, layout_B)
        :type layout_comb: tuple[cutlass.LayoutType]

        :return: True if kernels are registered for the combination under `op_class`
        :rtype: bool

        :raises Exception: if `op_class` has no registered kernels at all
        """
        if op_class not in self.operations_by_opclass:
            raise Exception(f"Unexpected or unsupported operation class {op_class}")

        return (datatype_comb, layout_comb) in self.operations_by_opclass[op_class]

    def supporting_opclasses(
        self,
        element_a: cutlass.DataType,
        element_b: cutlass.DataType,
        element_accumulator: cutlass.DataType,
        layout_a: cutlass.LayoutType,
        layout_b: cutlass.LayoutType,
    ) -> set:
        """
        Returns a set of operation classes that support the provided data type combination

        :param element_a: data type of operand A
        :type element_a: cutlass.DataType
        :param element_b: data type of operand B
        :type element_b: cutlass.DataType
        :param element_accumulator: data type of accumulator
        :type element_accumulator: cutlass.DataType
        :param layout_a: layout of operand A
        :type layout_a: cutlass.LayoutType
        :param layout_b: layout of operand B
        :type layout_b: cutlass.LayoutType

        :return: set of operation classes that support the provided data type combination
        :rtype: set
        """
        datatype_comb = (element_a, element_b, element_accumulator)
        layout_comb = (layout_a, layout_b)
        return {
            op_class
            for op_class in self.operations_by_opclass
            if self.opclass_supports_combination(op_class, datatype_comb, layout_comb)
        }

    def operations(
        self,
        op_class: cutlass.OpcodeClass,
        element_a: cutlass.DataType,
        element_b: cutlass.DataType,
        element_accumulator: cutlass.DataType,
        layout_a: cutlass.LayoutType,
        layout_b: cutlass.LayoutType,
    ) -> KernelsForDataType:
        """
        Returns the kernels registered under `op_class` for the provided data type and layout combination

        :param op_class: operation class to consider
        :type op_class: cutlass.OpcodeClass
        :param element_a: data type of operand A
        :type element_a: cutlass.DataType
        :param element_b: data type of operand B
        :type element_b: cutlass.DataType
        :param element_accumulator: data type of accumulator
        :type element_accumulator: cutlass.DataType
        :param layout_a: layout of operand A
        :type layout_a: cutlass.LayoutType
        :param layout_b: layout of operand B
        :type layout_b: cutlass.LayoutType

        :return: container of kernels by alignment supported by the provided combination of parameters
        :rtype: KernelsForDataType

        :raises Exception: if the combination is not supported by `op_class`
        """
        datatype_comb = (element_a, element_b, element_accumulator)
        layout_comb = (layout_a, layout_b)
        if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb):
            raise Exception(
                f"Data type layout combination {datatype_comb}, {layout_comb} "
                f"is not supported by opcode class {op_class} on CC {self.cc}.")
        return self.operations_by_opclass[op_class][(datatype_comb, layout_comb)]

class OptionRegistry:
    """
    Container of all architecture-specific options.

    One ArchOptions entry is built per kernel compute capability listed in the module's
    generator CC table. All entries are evaluated against a single target device.

    :param target_cc: compute capability of the device on which operations will be run
    :type target_cc: int
    """

    def __init__(self, target_cc: int):
        # GEMM kinds that every architecture is allowed to register
        permitted_kinds = [cutlass.GemmKind.Universal, cutlass.GemmKind.Universal3x]

        # Construct options for each CC, in the order the CCs are listed
        self.registry = {
            cc: ArchOptions(target_cc, cc, cutlass.OperationKind.Gemm, permitted_kinds)
            for cc in _generator_ccs
        }

    def options_for_cc(self, cc: int) -> ArchOptions:
        """
        Look up the options registered for a kernel compute capability.

        :param cc: kernel compute capability to query
        :type cc: int

        :return: options for `cc`, or None if no entry was built for it
        :rtype: ArchOptions
        """
        return self.registry.get(cc)