Back to Cutlass

CUTLASS: mma_tensor_op_sm70.h Source File

docs/mma__tensor__op__sm70_8h_source.html

4.4.222.1 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

mma_tensor_op_sm70.h

[Go to the documentation of this file.](mma__tensor__op__sm70_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

32 #pragma once

33

34 #include "cutlass/cutlass.h"

35 #include "cutlass/array.h"

36

37 #include "cutlass/numeric_types.h"

38 #include "cutlass/matrix_shape.h"

39

40 #include "cutlass/arch/mma.h"

41

42 #include "cutlass/gemm/gemm.h"

43 #include "cutlass/gemm/warp/mma.h"

44

45 #include "[cutlass/gemm/warp/mma_tensor_op_policy.h](mma tensor op__policy_8h.html)"

46 #include "[cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h](mma tensor op tile iterator__sm70_8h.html)"

47

49

50 namespace cutlass {

51 namespace gemm {

52 namespace warp {

53

55

57 template <

59typename Shape_,

61typename ElementA_,

63typename LayoutA_,

65typename ElementB_,

67typename LayoutB_,

69typename ElementC_,

71typename LayoutC_,

73typename Policy_,

75typename Enable = bool

76 >

77 class MmaVoltaTensorOp {

78 public:

80using Shape = Shape_;

81

83using ElementA = ElementA_;

84

86using LayoutA = LayoutA_;

87

89using ElementB = ElementB_;

90

92using LayoutB = LayoutB_;

93

95using ElementC = ElementC_;

96

98using LayoutC = LayoutC_;

99

101using Policy = Policy_;

102

104using OperatorClass = arch::OpClassTensorOp;

105

107static int const kThreadCount = 32;

108

110using InterleavedTileShape = GemmShape<32, 32, 4>;

111

112static_assert(!(Shape::kM % InterleavedTileShape::kM) &&

113 !(Shape::kN % InterleavedTileShape::kN),

114"Shape must be a multiple of InterleavedTileShape.");

115 public:

116

118using IteratorA = MmaVoltaTensorOpMultiplicandTileIterator<

119MatrixShape<Shape::kM, Shape::kK>,

120Operand::kA,

121ElementA,

122LayoutA,

123MatrixShape<

124 Policy::Operator::Shape::kM,

125 Policy::Operator::Shape::kK

126 >,

127 Policy::OpDelta::kRow,

128 kThreadCount

129 >;

130

132using FragmentA = typename IteratorA::Fragment;

133

135using IteratorB = MmaVoltaTensorOpMultiplicandTileIterator<

136MatrixShape<Shape::kK, Shape::kN>,

137Operand::kB,

138ElementB,

139LayoutB,

140MatrixShape<

141 Policy::Operator::Shape::kK,

142 Policy::Operator::Shape::kN

143 >,

144 Policy::OpDelta::kRow,

145 kThreadCount

146 >;

147

149using FragmentB = typename IteratorB::Fragment;

150

152using IteratorC = MmaVoltaTensorOpAccumulatorTileIterator<

153MatrixShape<Shape::kM, Shape::kN>,

154ElementC,

155LayoutC,

156typename Policy::Operator::Shape,

157typename Policy::OpDelta

158 >;

159

161using FragmentC = typename IteratorC::Fragment;

162

163 private:

164

165static_assert(

166 !(Shape::kM % Policy::Operator::Shape::kM) &&

167 !(Shape::kN % Policy::Operator::Shape::kN),

168"Shape of warp-level Mma must be divisible by operator shape.");

169

171using MmaIterations = MatrixShape<

172 InterleavedTileShape::kM / Policy::Operator::Shape::kM,

173InterleavedTileShape::kN / Policy::Operator::Shape::kN

174 >;

175using TileIterations = MatrixShape<

176 Shape::kM / InterleavedTileShape::kM,

177 Shape::kN / InterleavedTileShape::kN

178 >;

179

180// Whether matrix B is reordered

181bool reorder_B_;

182

183 public:

184

186typename Policy::Operator mma;

187

188 public:

189

190//

191// Methods

192//

193

195 CUTLASS_DEVICE

196MmaVoltaTensorOp() {}

197

199 CUTLASS_DEVICE

200void operator()(

201FragmentC &D,

202FragmentA const &A,

203FragmentB const &B,

204FragmentC const &C,

205int const &partitionN_idx = 0) {

206

207using MmaOperandA = typename Policy::Operator::FragmentA;

208using MmaOperandB = typename Policy::Operator::FragmentB;

209using MmaOperandC = typename Policy::Operator::FragmentC;

210

211 D = C;

212

213 MmaOperandA const *ptr_A = reinterpret_cast<MmaOperandA const *>(&A);

214 MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);

215 MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);

216

217CUTLASS_PRAGMA_UNROLL

218for (int outer_col = 0; outer_col < TileIterations::kColumn; ++outer_col) {

219CUTLASS_PRAGMA_UNROLL

220for (int inner_col = 0; inner_col < MmaIterations::kColumn; ++inner_col) {

221CUTLASS_PRAGMA_UNROLL

222for (int outer_row = 0; outer_row < TileIterations::kRow; ++outer_row) {

223CUTLASS_PRAGMA_UNROLL

224

225for (int inner_row = 0; inner_row < MmaIterations::kRow; ++inner_row) {

226

227int op_col = inner_col + MmaIterations::kColumn * outer_col;

228

229// Column-major serpentine sequence to maximize reuse of A operand.

230int inner_row_serp = inner_row;

231int outer_row_serp = outer_row;

232if (op_col & 1) {

233 inner_row_serp = MmaIterations::kRow - inner_row - 1;

234 outer_row_serp = TileIterations::kRow - outer_row - 1;

235 }

236int op_row = inner_row_serp + MmaIterations::kRow * outer_row_serp;

237int op_idx = inner_row_serp + MmaIterations::kRow *

238 (inner_col + MmaIterations::kColumn *

239 (outer_row_serp + TileIterations::kRow * outer_col));

240mma(

241 ptr_D[op_idx],

242 ptr_A[op_row],

243 ptr_B[op_col],

244 ptr_D[op_idx]);

245

246 }

247 }

248 }

249 }

250 }

251 };

252

254

255 } // namespace warp

256 } // namespace gemm

257 } // namespace cutlass

cutlass::gemm::warp::MmaVoltaTensorOp::Policy

Policy_ Policy

Policy describing implementation details of the warp-level GEMM targeting Tensor Cores (provides Operator and OpDelta)

Definition: mma_tensor_op_sm70.h:101

cutlass::gemm::warp::MmaVoltaTensorOp::FragmentB

typename IteratorB::Fragment FragmentB

Storage for B tile.

Definition: mma_tensor_op_sm70.h:149

cutlass::gemm::GemmShape::kM

static int const kM

Definition: include/cutlass/gemm/gemm.h:58

cutlass::MatrixShape

Describes the size of a matrix tile.

Definition: matrix_shape.h:42

cutlass::gemm::warp::MmaVoltaTensorOp::LayoutB

LayoutB_ LayoutB

Layout of multiplicand B.

Definition: mma_tensor_op_sm70.h:92

cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator

Definition: mma_tensor_op_tile_iterator_sm70.h:70

cutlass

Definition: aligned_buffer.h:35

cutlass::MatrixShape::kColumn

static int const kColumn

columns of a matrix

Definition: matrix_shape.h:44

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::gemm::warp::MmaVoltaTensorOp::Shape

Shape_ Shape

Shape of warp-level matrix operation (concept: GemmShape)

Definition: mma_tensor_op_sm70.h:80

cutlass::gemm::warp::MmaVoltaTensorOp::OperatorClass

arch::OpClassTensorOp OperatorClass

Indicates class of matrix operator.

Definition: mma_tensor_op_sm70.h:104

cutlass::gemm::warp::MmaVoltaTensorOp::LayoutA

LayoutA_ LayoutA

Layout of multiplicand A.

Definition: mma_tensor_op_sm70.h:86

cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator::Fragment

Array< Element, Shape::kCount/kThreads > Fragment

Fragment object holding a thread's part of a tile.

Definition: mma_tensor_op_tile_iterator_sm70.h:1213

cutlass::gemm::warp::MmaVoltaTensorOp::ElementA

ElementA_ ElementA

Data type of multiplicand A.

Definition: mma_tensor_op_sm70.h:83

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

CUTLASS_PRAGMA_UNROLL

#define CUTLASS_PRAGMA_UNROLL

Definition: cutlass.h:110

mma.h

Templates exposing architecture support for warp-level multiply-add operations.

mma.h

Templates exposing architecture support for multiply-add operations.

cutlass::gemm::Operand::kA

[mma_tensor_op_tile_iterator_sm70.h](mma__tensor__op__tile__iterator__sm70_8h.html)

Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.

matrix_shape.h

Defines a Shape template for matrix tiles.

cutlass::gemm::warp::MmaVoltaTensorOp::ElementB

ElementB_ ElementB

Data type of multiplicand B.

Definition: mma_tensor_op_sm70.h:89

cutlass::gemm::warp::MmaVoltaTensorOp::operator()

CUTLASS_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C, int const &partitionN_idx=0)

Performs a warp-level matrix multiply-accumulate operation.

Definition: mma_tensor_op_sm70.h:200

cutlass::gemm::warp::MmaVoltaTensorOp::mma

Policy::Operator mma

Underlying matrix multiply operator (concept: arch::Mma)

Definition: mma_tensor_op_sm70.h:186

cutlass::gemm::warp::MmaVoltaTensorOp

Structure to compute the matrix product targeting Volta Tensor Cores.

Definition: mma_tensor_op_sm70.h:77

cutlass::MatrixShape::kRow

static int const kRow

rows of a matrix

Definition: matrix_shape.h:43

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::gemm::warp::MmaVoltaTensorOp::FragmentC

typename IteratorC::Fragment FragmentC

Storage for C tile.

Definition: mma_tensor_op_sm70.h:161

cutlass::gemm::warp::MmaVoltaTensorOp::LayoutC

LayoutC_ LayoutC

Layout of accumulator matrix C.

Definition: mma_tensor_op_sm70.h:98

cutlass::gemm::warp::MmaVoltaTensorOp::kThreadCount

static int const kThreadCount

Number of threads participating in warp-level matrix product.

Definition: mma_tensor_op_sm70.h:107

cutlass::gemm::warp::MmaVoltaTensorOp::MmaVoltaTensorOp

CUTLASS_DEVICE MmaVoltaTensorOp()

Ctor.

Definition: mma_tensor_op_sm70.h:196

cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator

Definition: mma_tensor_op_tile_iterator_sm70.h:1135

cutlass::gemm::warp::MmaVoltaTensorOp::ElementC

ElementC_ ElementC

Data type of accumulator matrix C.

Definition: mma_tensor_op_sm70.h:95

cutlass::gemm::Operand::kB

B multiplicand.

cutlass::gemm::warp::MmaVoltaTensorOp::FragmentA

typename IteratorA::Fragment FragmentA

Storage for A tile.

Definition: mma_tensor_op_sm70.h:132

cutlass.h

Basic include for CUTLASS.

[mma_tensor_op_policy.h](mma__tensor__op__policy_8h.html)

Policy describing implementation details of warp-level GEMM targeting Tensor Cores.

cutlass::gemm::GemmShape::kN

static int const kN

Definition: include/cutlass/gemm/gemm.h:59


Generated by 1.8.11