Back to Cutlass

CUTLASS: mma_tensor_op.h Source File

docs/mma__tensor__op_8h_source.html

4.4.219.9 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

mma_tensor_op.h

[Go to the documentation of this file.](mma__tensor__op_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

30 #pragma once

31

32 #include "cutlass/cutlass.h"

33 #include "cutlass/array.h"

34

35 #include "cutlass/numeric_types.h"

36 #include "cutlass/matrix_shape.h"

37

38 #include "cutlass/arch/memory_sm75.h"

39 #include "cutlass/arch/mma_sm75.h"

40 #include "cutlass/gemm/gemm.h"

41 #include "cutlass/gemm/warp/mma.h"

42

43 #include "[cutlass/gemm/warp/mma_tensor_op_policy.h](mma__tensor__op__policy_8h.html)"

44

45 #include "[cutlass/gemm/warp/mma_tensor_op_tile_iterator.h](mma__tensor__op__tile__iterator_8h.html)"

47

48 namespace cutlass {

49 namespace gemm {

50 namespace warp {

51

53

55 template <

57typename Shape_,

59typename ElementA_,

61typename LayoutA_,

63typename ElementB_,

65typename LayoutB_,

67typename ElementC_,

69typename LayoutC_,

71typename Policy_,

73int PartitionsK_ = 1,

76bool AccumulatorsInRowMajor = false,

78int PartitionsN_ = 1,

80typename Enable = bool

81 >

82 class MmaTensorOp {

83 public:

85using Shape = Shape_;

86

88using ElementA = ElementA_;

89

91using LayoutA = LayoutA_;

92

94using ElementB = ElementB_;

95

97using LayoutB = LayoutB_;

98

100using ElementC = ElementC_;

101

103using LayoutC = LayoutC_;

104

106using Policy = Policy_;

107

109using OperatorClass = arch::OpClassTensorOp;

110

112static int const kThreadCount = 32;

113

115static int const kPartitionsK = PartitionsK_;

116

118static int const kPartitionsN = PartitionsN_;

119

120 public:

121

123using IteratorA = MmaTensorOpMultiplicandTileIterator<

124MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,

125MatrixShape<Policy::Operator::Shape::kM, Policy::Operator::Shape::kK>,

126 Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

127

129using FragmentA = typename IteratorA::Fragment;

130

132using IteratorB = MmaTensorOpMultiplicandTileIterator<

133MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,

134MatrixShape<Policy::Operator::Shape::kK, Policy::Operator::Shape::kN>,

135 Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

136

138using FragmentB = typename IteratorB::Fragment;

139

141using IteratorC = MmaTensorOpAccumulatorTileIterator<

142MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,

143typename Policy::Operator::Shape, typename Policy::OpDelta>;

144

146using FragmentC = typename IteratorC::Fragment;

147

148 private:

149

150static_assert(

151 !(Shape::kM % Policy::Operator::Shape::kM) &&

152 !(Shape::kN % Policy::Operator::Shape::kN),

153"Shape of warp-level Mma must be divisible by operator shape.");

154

156using MmaIterations = MatrixShape<

157 Shape::kM / Policy::Operator::Shape::kM,

158 (Shape::kN / Policy::Operator::Shape::kN / kPartitionsN > 0) ?

159 Shape::kN / Policy::Operator::Shape::kN / kPartitionsN :

160 1

161 >;

162

163 public:

164

166typename Policy::Operator mma;

167

168 public:

169

170//

171// Methods

172//

173

175 CUTLASS_DEVICE

176MmaTensorOp() {}

177

179 CUTLASS_DEVICE

180void operator()(

181FragmentC &D,

182FragmentA const &A,

183FragmentB const &B,

184FragmentC const &C,

185int const &partitionN_idx = 0) const {

186

187using MmaOperandA = typename Policy::Operator::FragmentA;

188using MmaOperandB = typename Policy::Operator::FragmentB;

189using MmaOperandC = typename Policy::Operator::FragmentC;

190

191 D = C;

192

193 MmaOperandA const *ptr_A = reinterpret_cast<MmaOperandA const *>(&A);

194 MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);

195 MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);

196

197// The offset of multilicand B for current partition

198const int n_off = partitionN_idx * FragmentB::kElements / MmaOperandB::kElements / kPartitionsN;

199// Serpentine visitation order maximizing reuse of Rb

200CUTLASS_PRAGMA_UNROLL

201for (int n = 0; n < MmaIterations::kColumn; ++n) {

202

203CUTLASS_PRAGMA_UNROLL

204for (int m = 0; m < MmaIterations::kRow; ++m) {

205

206int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);

207

208if (AccumulatorsInRowMajor) { // matrix B is reordered

209mma(

210 ptr_D[n + m_serpentine * MmaIterations::kColumn],

211 ptr_A[m_serpentine],

212 ptr_B[n],

213 ptr_D[n + m_serpentine * MmaIterations::kColumn]);

214 } else {

215mma(

216 ptr_D[m_serpentine + (n + n_off) * MmaIterations::kRow],

217 ptr_A[m_serpentine],

218 ptr_B[n + n_off],

219 ptr_D[m_serpentine + (n + n_off) * MmaIterations::kRow]);

220 }

221 }

222 }

223 }

224 };

225

227

228 } // namespace warp

229 } // namespace gemm

230 } // namespace cutlass

cutlass::MatrixShape

Describes the size of a matrix tile.

Definition: matrix_shape.h:42

cutlass::gemm::warp::MmaTensorOp::FragmentA

typename IteratorA::Fragment FragmentA

Storage for A tile.

Definition: mma_tensor_op.h:129

cutlass

Definition: aligned_buffer.h:35

cutlass::gemm::warp::MmaTensorOp::LayoutB

LayoutB_ LayoutB

Layout of multiplicand B.

Definition: mma_tensor_op.h:97

cutlass::gemm::warp::MmaTensorOp::MmaTensorOp

CUTLASS_DEVICE MmaTensorOp()

Ctor.

Definition: mma_tensor_op.h:176

memory_sm75.h

Architecture-specific operators on memory added for SM75.

[mma_tensor_op_tile_iterator.h](mma__tensor__op__tile__iterator_8h.html)

Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::gemm::warp::MmaTensorOp::kThreadCount

static int const kThreadCount

Number of threads participating in warp-level matrix product.

Definition: mma_tensor_op.h:112

cutlass::gemm::warp::MmaTensorOp::kPartitionsN

static int const kPartitionsN

PartitionsN indicating how many PartitionsN for multiplicand B.

Definition: mma_tensor_op.h:118

cutlass::gemm::warp::MmaTensorOp

Structure to compute the matrix product targeting Tensor Cores (warp-level MMA).

Definition: mma_tensor_op.h:82

cutlass::gemm::warp::MmaTensorOp::LayoutA

LayoutA_ LayoutA

Layout of multiplicand A.

Definition: mma_tensor_op.h:91

cutlass::gemm::warp::MmaTensorOp::FragmentB

typename IteratorB::Fragment FragmentB

Storage for B tile.

Definition: mma_tensor_op.h:138

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

CUTLASS_PRAGMA_UNROLL

#define CUTLASS_PRAGMA_UNROLL

Definition: cutlass.h:110

mma.h

Templates exposing architecture support for warp-level multiply-add operations.

cutlass::gemm::Operand::kA

matrix_shape.h

Defines a Shape template for matrix tiles.

cutlass::gemm::warp::MmaTensorOp::FragmentC

typename IteratorC::Fragment FragmentC

Storage for C tile.

Definition: mma_tensor_op.h:146

cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator

Definition: mma_tensor_op_tile_iterator.h:1794

cutlass::gemm::warp::MmaTensorOp::operator()

CUTLASS_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C, int const &partitionN_idx=0) const

Performs a warp-level matrix multiply-accumulate operation.

Definition: mma_tensor_op.h:180

cutlass::gemm::warp::MmaTensorOp::mma

Policy::Operator mma

Underlying matrix multiply operator (concept: arch::Mma)

Definition: mma_tensor_op.h:166

cutlass::gemm::warp::MmaTensorOp::ElementC

ElementC_ ElementC

Data type of accumulator matrix C.

Definition: mma_tensor_op.h:100

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator

Definition: mma_tensor_op_tile_iterator.h:75

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::gemm::warp::MmaTensorOp::LayoutC

LayoutC_ LayoutC

Layout of accumulator matrix C.

Definition: mma_tensor_op.h:103

cutlass::gemm::warp::MmaTensorOp::Policy

Policy_ Policy

Shape of the warp in units of thread (concept: MmaLanePolicySimt)

Definition: mma_tensor_op.h:106

cutlass::gemm::warp::MmaTensorOp::Shape

Shape_ Shape

Shape of warp-level matrix operation (concept: GemmShape)

Definition: mma_tensor_op.h:85

cutlass::gemm::warp::MmaTensorOp::ElementB

ElementB_ ElementB

Data type of multiplicand B.

Definition: mma_tensor_op.h:94

cutlass::gemm::warp::MmaTensorOp::kPartitionsK

static int const kPartitionsK

Number of partitions along K dimension.

Definition: mma_tensor_op.h:115

cutlass::gemm::warp::MmaTensorOp::OperatorClass

arch::OpClassTensorOp OperatorClass

Indicates class of matrix operator.

Definition: mma_tensor_op.h:109

cutlass::gemm::warp::MmaTensorOp::ElementA

ElementA_ ElementA

Data type of multiplicand A.

Definition: mma_tensor_op.h:88

mma_sm75.h

Matrix multiply for SM75.

cutlass::gemm::Operand::kB

B multiplicand.

cutlass.h

Basic include for CUTLASS.

[mma_tensor_op_policy.h](mma__tensor__op__policy_8h.html)

Policy describing implementation details of warp-level GEMM targeting Tensor Cores.


Generated by 1.8.11