CUTLASS: default_mma_tensor_op.h Source File


CUTLASS: CUDA Templates for Linear Algebra Subroutines and Solvers

default_mma_tensor_op.h


/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 * * Redistributions of source code must retain the above copyright notice, this list of
 * conditions and the following disclaimer.
 * * Redistributions in binary form must reproduce the above copyright notice, this list of
 * conditions and the following disclaimer in the documentation and/or other materials
 * provided with the distribution.
 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 * to endorse or promote products derived from this software without specific prior written
 * permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"

namespace cutlass {
namespace gemm {
namespace warp {

template <
  typename WarpShape_,
  typename InstructionShape_,
  typename ElementA_,
  typename LayoutA_,
  typename ElementB_,
  typename LayoutB_,
  typename ElementC_,
  typename LayoutC_,
  typename Operator_ = arch::OpMultiplyAdd,
  int PartitionsK = 1,
  bool AccumulatorsInRowMajor = false,
  int PartitionsN = 1
>
struct DefaultMmaTensorOp;

template <
  typename WarpShape_,
  typename InstructionShape_,
  typename ElementA,
  typename LayoutA,
  typename ElementB,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename Operator_,
  int PartitionsK,
  bool AccumulatorsInRowMajor,
  int PartitionsN>
struct DefaultMmaTensorOp {
  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
      cutlass::arch::Mma<InstructionShape_, 32, ElementA,
                         cutlass::layout::RowMajor, ElementB,
                         cutlass::layout::ColumnMajor, ElementC,
                         cutlass::layout::RowMajor, Operator_>,
      cutlass::MatrixShape<1, 1> >;

  // Define the warp-level tensor op
  using Type = cutlass::gemm::warp::MmaTensorOp<
      WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
      Policy, PartitionsK, AccumulatorsInRowMajor, PartitionsN>;
};

} // namespace warp
} // namespace gemm
} // namespace cutlass
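The listing above declares `DefaultMmaTensorOp` once with default template arguments and then defines it, exposing a `Policy` and a `Type` alias. The following is a minimal, self-contained sketch of that pattern only. Every name in `trait_sketch` is a hypothetical stand-in with a reduced parameter list, not the CUTLASS API, and the shapes used are arbitrary illustration values.

```cpp
#include <type_traits>

// Hypothetical stand-ins that mimic the structure of the listing; they are
// NOT the real CUTLASS types and carry far fewer parameters.
namespace trait_sketch {

struct OpMultiplyAdd {};  // stand-in for the default operator tag

template <int M, int N, int K>
struct GemmShape {
  static constexpr int kM = M;
  static constexpr int kN = N;
  static constexpr int kK = K;
};

// Stand-in for the policy that records the instruction shape and operator.
template <typename InstructionShape_, typename Operator_>
struct MmaPolicy {
  using InstructionShape = InstructionShape_;
  using Operator = Operator_;
};

// Stand-in for the warp-level operator selected by the default trait.
template <typename WarpShape_, typename Element_, typename Policy_, int PartitionsK_>
struct WarpMma {
  using Shape = WarpShape_;
  using Element = Element_;
  using Policy = Policy_;
  static constexpr int kPartitionsK = PartitionsK_;
};

// Declaration carrying the default template arguments.
template <typename WarpShape_, typename InstructionShape_, typename Element_,
          typename Operator_ = OpMultiplyAdd, int PartitionsK = 1>
struct DefaultWarpMma;

// Definition: assembles the policy, then the operator type built on it.
template <typename WarpShape_, typename InstructionShape_, typename Element_,
          typename Operator_, int PartitionsK>
struct DefaultWarpMma {
  using Policy = MmaPolicy<InstructionShape_, Operator_>;
  using Type = WarpMma<WarpShape_, Element_, Policy, PartitionsK>;
};

// Only the leading arguments are supplied; the rest fall back to defaults.
using Default = DefaultWarpMma<GemmShape<64, 64, 32>, GemmShape<16, 8, 8>, float>;

static_assert(Default::Type::kPartitionsK == 1,
              "PartitionsK falls back to its default");
static_assert(std::is_same<Default::Policy::Operator, OpMultiplyAdd>::value,
              "Operator falls back to its default");

}  // namespace trait_sketch
```

Keeping the defaults on the declaration lets callers name only the arguments they care about, while the definition stays the single place where the policy and operator types are assembled.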

cutlass::MatrixShape

Describes the size of a matrix tile.

Definition: matrix_shape.h:42
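As a rough illustration of what a compile-time tile-size descriptor can look like, here is a minimal sketch. The struct and its member names (`kRow`, `kColumn`, `kCount`) are assumptions made for this example and are not taken from `matrix_shape.h`.

```cpp
// Hypothetical stand-in for a compile-time matrix tile shape.
namespace shape_sketch {

template <int Row, int Column>
struct MatrixShape {
  static constexpr int kRow = Row;             // rows in the tile
  static constexpr int kColumn = Column;       // columns in the tile
  static constexpr int kCount = Row * Column;  // total elements in the tile
};

// A 1-by-1 shape, matching the arguments used in the Policy alias of the listing.
using Unit = MatrixShape<1, 1>;
static_assert(Unit::kCount == 1, "a 1-by-1 tile holds one element");

}  // namespace shape_sketch
```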

cutlass

Definition: aligned_buffer.h:35

cutlass::gemm::warp::DefaultMmaTensorOp

Partial specialization for m-by-n-by-kgroup.

Definition: default_mma_tensor_op.h:67

cutlass::gemm::warp::MmaTensorOp

Structure to compute the matrix product targeting Tensor Cores.

Definition: mma_tensor_op.h:82

cutlass::layout::ColumnMajor

Mapping function for column-major matrices.

Definition: layout/matrix.h:142

cutlass::gemm::warp::MmaTensorOpPolicy

Policy.

Definition: mma_tensor_op_policy.h:48

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50
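The row-major and column-major mapping functions differ only in which coordinate is scaled by the leading dimension. The sketch below shows the standard offset arithmetic with hypothetical free functions. These are not the CUTLASS layout classes, and the parameter name `ld` (leading dimension) is an assumption of this example.

```cpp
#include <cstddef>

// Hypothetical helpers illustrating the two mappings from a logical
// (row, column) coordinate to a linear offset in memory.
namespace layout_sketch {

// Row-major: consecutive columns are adjacent; ld is the distance between rows.
constexpr std::size_t row_major_offset(std::size_t row, std::size_t column,
                                       std::size_t ld) {
  return row * ld + column;
}

// Column-major: consecutive rows are adjacent; ld is the distance between columns.
constexpr std::size_t column_major_offset(std::size_t row, std::size_t column,
                                          std::size_t ld) {
  return column * ld + row;
}

// Element (1, 2) of a densely packed 3-by-4 matrix.
static_assert(row_major_offset(1, 2, 4) == 6, "row-major: 1 * 4 + 2");
static_assert(column_major_offset(1, 2, 3) == 7, "column-major: 2 * 3 + 1");

}  // namespace layout_sketch
```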

cutlass::gemm::warp::DefaultMmaTensorOp::Policy

cutlass::gemm::warp::MmaTensorOpPolicy< cutlass::arch::Mma< InstructionShape_, 32, ElementA, cutlass::layout::RowMajor, ElementB, cutlass::layout::ColumnMajor, ElementC, cutlass::layout::RowMajor, Operator_ >, cutlass::MatrixShape< 1, 1 > > Policy

Definition: default_mma_tensor_op.h:104

cutlass::arch::Mma

Matrix multiply-add operation.

Definition: arch/mma.h:92
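A matrix multiply-add computes D = A * B + C. The host-side reference below is a minimal sketch of that arithmetic on a 2-by-2-by-2 tile, using the same operand layouts as the `Policy` alias in the listing (row-major A, column-major B, row-major C). All names and values are hypothetical, and it says nothing about how the hardware instruction distributes data across threads.

```cpp
// Hypothetical host-side reference for D = A * B + C on a tiny tile.
namespace mma_sketch {

constexpr int kM = 2, kN = 2, kK = 2;

constexpr int A[kM * kK] = {1, 2, 3, 4};  // row-major:    A(m, k) at m * kK + k
constexpr int B[kK * kN] = {5, 6, 7, 8};  // column-major: B(k, n) at n * kK + k
constexpr int C[kM * kN] = {1, 1, 1, 1};  // row-major:    C(m, n) at m * kN + n

// One element of D: accumulate A(m, k) * B(k, n) over k, then add C(m, n).
constexpr int element(int m, int n, int k = 0) {
  return k == kK ? C[m * kN + n]
                 : A[m * kK + k] * B[n * kK + k] + element(m, n, k + 1);
}

static_assert(element(0, 0) == 18, "1*5 + 2*6 + 1");
static_assert(element(1, 1) == 54, "3*7 + 4*8 + 1");

}  // namespace mma_sketch
```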

mma_tensor_op.h

Templates implementing warp-level matrix multiply-accumulate operations targeting Tensor Cores...

cutlass.h

Basic include for CUTLASS.

