Back to Cutlass

CUTLASS: default_mma.h Source File

docs/default__mma_8h_source.html

4.4.233.7 KB
Original Source

CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers

default_mma.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include "cutlass/cutlass.h"

32 #include "cutlass/numeric_types.h"

33 #include "cutlass/arch/arch.h"

34 #include "cutlass/arch/wmma.h"

35

36 #include "[cutlass/transform/threadblock/predicated_tile_iterator.h](transform_2threadblock_2predicated tile iterator_8h.html)"

37 #include "[cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h](predicated tile iterator__2dthreadtile_8h.html)"

38 #include "[cutlass/gemm/threadblock/default_mma_core_sm70.h](default mma core__sm70_8h.html)"

39 #include "[cutlass/gemm/threadblock/default_mma_core_sm75.h](default mma core__sm75_8h.html)"

40 #if defined(CUTLASS_ARCH_WMMA_ENABLED)

41 #include "[cutlass/gemm/threadblock/default_mma_core_wmma.h](default mma core__wmma_8h.html)"

42 #endif //CUTLASS_ARCH_WMMA_ENABLED

43

45

46 namespace cutlass {

47 namespace gemm {

48 namespace threadblock {

49

51

/// Primary template for DefaultMma. It is only declared here (no body); the
/// partial specializations later in this header provide the actual members.
/// Template arguments are positional. The specializations in this header pin
/// LayoutC_, OperatorClass_, Stages and AccumulatorsInRowMajor (and, for the
/// int8 case, the element types and InstructionShape_) to concrete values.
/// NOTE(review): the per-parameter notes below are inferred from the parameter
/// names and from how the specializations forward them — confirm against the
/// upstream header.
52 template <

/// Element and layout types of the A operand, plus its alignment value.
54typename ElementA_,

56typename LayoutA_,

58int kAlignmentA,

/// Element and layout types of the B operand, plus its alignment value.
60typename ElementB_,

62typename LayoutB_,

64int kAlignmentB,

/// Accumulator element type; forwarded to DefaultMmaCore and the threadblock MMA.
66typename ElementAccumulator_,

/// Layout forwarded as the C/accumulator layout by the specializations.
68typename LayoutC_,

/// Operator class tag (specializations match arch::OpClassSimt,
/// arch::OpClassTensorOp and arch::OpClassWmmaTensorOp).
70typename OperatorClass_,

/// Architecture tag; none of the specializations in this header reference it
/// inside their bodies.
72typename ArchTag_,

/// Tile shapes forwarded to DefaultMmaCore.
74typename ThreadblockShape_,

76typename WarpShape_,

78typename InstructionShape_,

/// Stage count; specializations in this header match 2 (MmaPipelined) and
/// 1 (MmaSingleStage, WMMA only).
80int Stages,

/// Operator tag forwarded unchanged to DefaultMmaCore.
82typename Operator,

/// Defaults to false; the ColumnMajorInterleaved specialization matches true.
85bool AccumulatorsInRowMajor = false

86 >

87 struct DefaultMma;

88

90

/// Partial specialization of DefaultMma selected when the C layout is
/// layout::RowMajor, the operator class is arch::OpClassSimt, the stage count
/// is 2 and AccumulatorsInRowMajor is false. All other arguments stay generic.
///
/// Exposes four member aliases: MmaCore, IteratorA, IteratorB, ThreadblockMma.
/// ArchTag participates in matching only; it is not referenced in the body.
92 template <

94typename ElementA,

96typename LayoutA,

98int kAlignmentA,

100typename ElementB,

102typename LayoutB,

104int kAlignmentB,

106typename ElementAccumulator,

108typename ArchTag,

110typename ThreadblockShape,

112typename WarpShape,

114typename InstructionShape,

116typename Operator>

117 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,

118 kAlignmentB, ElementAccumulator, layout::RowMajor,

119 arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape,

120 InstructionShape, 2, Operator, false> {

121// Define the MmaCore components

// DefaultMmaCore is instantiated with the same RowMajor / OpClassSimt / 2-stage
// choices this specialization matched on; every alias below is derived from it.
122using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<

123 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,

124 ElementB, LayoutB, ElementAccumulator, layout::RowMajor,

125 arch::OpClassSimt, 2, Operator>;

126

127// Define iterators over tiles from the A operand

// Tile extent is MmaCore::Shape::kM x kK; kAlignmentA is passed as the last argument.
// NOTE(review): the literal 1 here (0 for B) is presumably the advance rank of
// PredicatedTileIterator — confirm against that template's declaration.
128using IteratorA =

129cutlass::transform::threadblock::PredicatedTileIterator<

130cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,

131 ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>;

132

133// Define iterators over tiles from the B operand

// Tile extent is MmaCore::Shape::kK x kN; kAlignmentB is passed as the last argument.
134using IteratorB =

135cutlass::transform::threadblock::PredicatedTileIterator<

136cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,

137 ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>;

138

139// Define the threadblock-scoped pipelined matrix multiply

// Combines the two iterators above with the SmemIteratorA/B and MmaPolicy
// types supplied by MmaCore.
140using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<

141typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,

142IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,

143 layout::RowMajor, typename MmaCore::MmaPolicy>;

144 };

145

146

/// Partial specialization of DefaultMma selected when the C layout is
/// layout::RowMajor, the operator class is arch::OpClassTensorOp, the stage
/// count is 2 and AccumulatorsInRowMajor is false. The body has the same
/// structure as the arch::OpClassSimt specialization in this header; only the
/// operator class tag forwarded to DefaultMmaCore differs.
///
/// Exposes four member aliases: MmaCore, IteratorA, IteratorB, ThreadblockMma.
/// ArchTag participates in matching only; it is not referenced in the body.
148 template <

150typename ElementA,

152typename LayoutA,

154int kAlignmentA,

156typename ElementB,

158typename LayoutB,

160int kAlignmentB,

162typename ElementAccumulator,

164typename ArchTag,

166typename ThreadblockShape,

168typename WarpShape,

170typename InstructionShape,

172typename Operator

173 >

174 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,

175 kAlignmentB, ElementAccumulator, layout::RowMajor,

176 arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,

177 InstructionShape, 2, Operator, false> {

178// Define the MmaCore components

// DefaultMmaCore is instantiated with the RowMajor / OpClassTensorOp / 2-stage
// choices this specialization matched on.
179using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<

180 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,

181 ElementB, LayoutB, ElementAccumulator, layout::RowMajor,

182 arch::OpClassTensorOp, 2, Operator>;

183

184// Define iterators over tiles from the A operand

// Tile extent is MmaCore::Shape::kM x kK; kAlignmentA is passed as the last argument.
// NOTE(review): the literal 1 here (0 for B) is presumably the advance rank of
// PredicatedTileIterator — confirm against that template's declaration.
185using IteratorA =

186cutlass::transform::threadblock::PredicatedTileIterator<

187cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,

188 ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>;

189

190// Define iterators over tiles from the B operand

// Tile extent is MmaCore::Shape::kK x kN; kAlignmentB is passed as the last argument.
191using IteratorB =

192cutlass::transform::threadblock::PredicatedTileIterator<

193cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,

194 ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>;

195

196// Define the threadblock-scoped pipelined matrix multiply

// Combines the two iterators above with the SmemIteratorA/B and MmaPolicy
// types supplied by MmaCore.
197using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<

198typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,

199IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,

200 layout::RowMajor, typename MmaCore::MmaPolicy>;

201 };

203

/// Partial specialization of DefaultMma selected when the C layout is
/// layout::ColumnMajorInterleaved<InterleavedK>, the stage count is 2 and
/// AccumulatorsInRowMajor is true. The operator class stays generic and is
/// forwarded to DefaultMmaCore unchanged.
///
/// Differences from the RowMajor specializations in this header:
///  - DefaultMmaCore receives an extra trailing `true` argument;
///  - two static_asserts constrain kAlignmentA / kAlignmentB;
///  - the tile iterators are instantiated without an alignment argument.
///
/// Exposes four member aliases: MmaCore, IteratorA, IteratorB, ThreadblockMma.
/// ArchTag participates in matching only; it is not referenced in the body.
205 template <

207typename ElementA,

209typename LayoutA,

211int kAlignmentA,

213typename ElementB,

215typename LayoutB,

217int kAlignmentB,

219typename ElementAccumulator,

221typename OperatorClass,

223typename ArchTag,

225typename ThreadblockShape,

227typename WarpShape,

229typename InstructionShape,

231typename Operator,

// Interleaving factor; deduced from the ColumnMajorInterleaved<> C layout below.
233int InterleavedK>

234 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,

235 kAlignmentB, ElementAccumulator,

236 layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass,

237 ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,

238 Operator, true> {

239// Define the MmaCore components

// The trailing `true` mirrors the AccumulatorsInRowMajor value matched above.
// NOTE(review): presumably it selects row-major accumulators inside
// DefaultMmaCore — confirm against that template's parameter list.
240using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<

241 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,

242 ElementB, LayoutB, ElementAccumulator,

243layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,

244true>;

245

// Compile-time guard: the caller's alignment for A must equal the number of
// ElementA values that fit in 128 bits.
246static_assert(kAlignmentA == 128 / sizeof_bits<ElementA>::value,

247"Alignment must match thread data map's vector length");

248

// Same guard for B, using the bit width of ElementB.
249static_assert(kAlignmentB ==128 / sizeof_bits<ElementB>::value,

250"Alignment must match thread data map's vector length");

251

252// Define iterators over tiles from the A operand

// Tile extent is MmaCore::Shape::kM x kK. No alignment argument is supplied,
// so whatever default PredicatedTileIterator declares (not visible in this
// file) applies.
253using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<

254cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA,

255 LayoutA, 1, typename MmaCore::IteratorThreadMapA>;

256

257// Define iterators over tiles from the B operand

// Tile extent is MmaCore::Shape::kK x kN; alignment argument likewise omitted.
258using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<

259cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, ElementB,

260 LayoutB, 0, typename MmaCore::IteratorThreadMapB>;

261

262// Define the threadblock-scoped pipelined matrix multiply

// The interleaved C layout is forwarded as the layout argument of MmaPipelined.
263using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<

264typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,

265IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,

266 layout::ColumnMajorInterleaved<InterleavedK>,

267typename MmaCore::MmaPolicy>;

268 };

269

/// Partial specialization of DefaultMma selected when both operand element
/// types are int8_t, the C layout is layout::RowMajor, the operator class is
/// arch::OpClassSimt, the instruction shape is GemmShape<1, 1, 4>, the stage
/// count is 2 and AccumulatorsInRowMajor is false.
///
/// Unlike the generic SIMT specialization in this header, the tile iterators
/// are PredicatedTileIterator2dThreadTile and take a transpose flag as their
/// last argument instead of an alignment value. kAlignmentA, kAlignmentB and
/// ArchTag participate in matching only; none is referenced in the body.
///
/// The template parameter list names Operator before WarpShape; the order in
/// which they are supplied to DefaultMma<> below is unaffected.
273 template <

275typename LayoutA,

277int kAlignmentA,

279typename LayoutB,

281int kAlignmentB,

283typename ElementAccumulator,

285typename ArchTag,

287typename ThreadblockShape,

289typename Operator,

291typename WarpShape>

292 struct DefaultMma<int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB,

293 ElementAccumulator, layout::RowMajor, arch::OpClassSimt,

294 ArchTag, ThreadblockShape, WarpShape, GemmShape<1, 1, 4>, 2,

295 Operator, false> {

// Re-introduce the pinned arguments as member aliases so the body below can
// use the same names as the generic specializations.
296using InstructionShape = GemmShape<1, 1, 4>;

297using ElementA = int8_t;

298using ElementB = int8_t;

299using OperatorClass = arch::OpClassSimt;

300

// transposeA is true exactly when LayoutA is layout::ColumnMajor;
// transposeB is true exactly when LayoutB is layout::RowMajor.
301static const bool transposeA = cutlass::platform::is_same< LayoutA, layout::ColumnMajor >::value;

302static const bool transposeB = cutlass::platform::is_same< LayoutB, layout::RowMajor >::value;

303

304// Define the MmaCore components

305using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<

306 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,

307 ElementB, LayoutB, ElementAccumulator, layout::RowMajor,

308 OperatorClass, 2, Operator>;

309

310// Define iterators over tiles from the A operand

// Tile extent is MmaCore::Shape::kM x kK; transposeA is the final argument.
// NOTE(review): how the iterator acts on the transpose flag is defined in
// predicated_tile_iterator_2dthreadtile.h, not here — confirm there.
311using IteratorA =

312cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile<

313cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,

314 ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA>;

315

316// Define iterators over tiles from the B operand

// Tile extent is MmaCore::Shape::kK x kN; transposeB is the final argument.
317using IteratorB =

318cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile<

319cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,

320 ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB>;

321

322// Define the threadblock-scoped pipelined matrix multiply

323using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<

324typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,

325IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,

326 layout::RowMajor, typename MmaCore::MmaPolicy>;

327 };

328

329 #if defined(CUTLASS_ARCH_WMMA_ENABLED)

/// Partial specialization of DefaultMma selected when the operator class is
/// arch::OpClassWmmaTensorOp and the stage count is 2. The C layout stays
/// generic (LayoutC). The trailing AccumulatorsInRowMajor argument is not
/// written out, so the primary template's default of false is what is matched.
/// This specialization sits inside the CUTLASS_ARCH_WMMA_ENABLED guard.
///
/// Exposes four member aliases: MmaCore, IteratorA, IteratorB, ThreadblockMma;
/// ThreadblockMma is an MmaPipelined instantiation.
/// ArchTag participates in matching only; it is not referenced in the body.
330 template <

333typename ElementA,

335typename LayoutA,

337int kAlignmentA,

339typename ElementB,

341typename LayoutB,

343int kAlignmentB,

345typename ElementAccumulator,

347typename LayoutC,

349typename ArchTag,

351typename ThreadblockShape,

353typename WarpShape,

355typename InstructionShape,

357typename Operator>

358 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,

359 kAlignmentB, ElementAccumulator, LayoutC,

360 arch::OpClassWmmaTensorOp, ArchTag, ThreadblockShape, WarpShape,

361 InstructionShape, 2, Operator> {

362// Define the MmaCore components

// LayoutC is forwarded as-is rather than being fixed to a specific layout.
363using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<

364 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,

365 ElementB, LayoutB, ElementAccumulator, LayoutC,

366 arch::OpClassWmmaTensorOp, 2, Operator>;

367

368// Define iterators over tiles from the A operand

// Tile extent is MmaCore::Shape::kM x kK; kAlignmentA is passed as the last argument.
// NOTE(review): the literal 1 here (0 for B) is presumably the advance rank of
// PredicatedTileIterator — confirm against that template's declaration.
369using IteratorA =

370cutlass::transform::threadblock::PredicatedTileIterator<

371cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,

372 ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>;

373

374// Define iterators over tiles from the B operand

// Tile extent is MmaCore::Shape::kK x kN; kAlignmentB is passed as the last argument.
375using IteratorB =

376cutlass::transform::threadblock::PredicatedTileIterator<

377cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,

378 ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>;

379

380// Define the threadblock-scoped pipelined matrix multiply

381using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<

382typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,

383IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,

384 LayoutC, typename MmaCore::MmaPolicy>;

385 };

386

/// Partial specialization of DefaultMma selected when the operator class is
/// arch::OpClassWmmaTensorOp and the stage count is 1. It is the only
/// specialization in this header that matches a stage count other than 2, and
/// the only one whose ThreadblockMma is MmaSingleStage instead of MmaPipelined.
/// The trailing AccumulatorsInRowMajor argument is not written out, so the
/// primary template's default of false is what is matched.
/// This specialization sits inside the CUTLASS_ARCH_WMMA_ENABLED guard.
///
/// Exposes four member aliases: MmaCore, IteratorA, IteratorB, ThreadblockMma.
/// ArchTag participates in matching only; it is not referenced in the body.
388 template <

390typename ElementA,

392typename LayoutA,

394int kAlignmentA,

396typename ElementB,

398typename LayoutB,

400int kAlignmentB,

402typename ElementAccumulator,

404typename LayoutC,

406typename ArchTag,

408typename ThreadblockShape,

410typename WarpShape,

412typename InstructionShape,

414typename Operator>

415 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,

416 kAlignmentB, ElementAccumulator, LayoutC,

417 arch::OpClassWmmaTensorOp, ArchTag, ThreadblockShape, WarpShape,

418 InstructionShape, 1, Operator> {

419// Define the MmaCore components

// The stage count 1 matched above is forwarded to DefaultMmaCore.
420using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<

421 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,

422 ElementB, LayoutB, ElementAccumulator, LayoutC,

423 arch::OpClassWmmaTensorOp, 1, Operator>;

424

425// Define iterators over tiles from the A operand

// Tile extent is MmaCore::Shape::kM x kK; kAlignmentA is passed as the last argument.
// NOTE(review): the literal 1 here (0 for B) is presumably the advance rank of
// PredicatedTileIterator — confirm against that template's declaration.
426using IteratorA =

427cutlass::transform::threadblock::PredicatedTileIterator<

428cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,

429 ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>;

430

431// Define iterators over tiles from the B operand

// Tile extent is MmaCore::Shape::kK x kN; kAlignmentB is passed as the last argument.
432using IteratorB =

433cutlass::transform::threadblock::PredicatedTileIterator<

434cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,

435 ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>;

436

437// Define the threadblock-scoped singlestage matrix multiply

// Same argument list as the MmaPipelined instantiations elsewhere in this
// header; only the threadblock MMA template differs.
438using ThreadblockMma = cutlass::gemm::threadblock::MmaSingleStage<

439typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,

440IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,

441 LayoutC, typename MmaCore::MmaPolicy>;

442 };

444 #endif //CUTLASS_ARCH_WMMA_ENABLED

445

446 } // namespace threadblock

447 } // namespace gemm

448 } // namespace cutlass

449

cutlass::MatrixShape

Describes the size of a matrix tile.

Definition: matrix_shape.h:42

cutlass

Definition: aligned_buffer.h:35

cutlass::gemm::threadblock::DefaultMma< int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, GemmShape< 1, 1, 4 >, 2, Operator, false >::MmaCore

typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, 2, Operator > MmaCore

Definition: default_mma.h:308

cutlass::platform::is_same

std::is_same (false specialization)

Definition: platform.h:394

cutlass::gemm::threadblock::DefaultMmaCore

Definition: default_mma_core.h:90

cutlass::gemm::threadblock::DefaultMma< int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, GemmShape< 1, 1, 4 >, 2, Operator, false >::ElementA

int8_t ElementA

Definition: default_mma.h:297

[default_mma_core_wmma.h](default__mma__core__wmma_8h.html)

Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the global memory fragments, data types, and internal tile sizes.

cutlass::gemm::threadblock::MmaPipelined

Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.

Definition: mma_pipelined.h:86

cutlass::gemm::threadblock::MmaSingleStage

Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.

Definition: mma_singlestage.h:76

cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false >::MmaCore

typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, 2, Operator > MmaCore

Definition: default_mma.h:125

[default_mma_core_sm70.h](default__mma__core__sm70_8h.html)

Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the global memory fragments, data types, and internal tile sizes.

cutlass::gemm::threadblock::DefaultMma

Definition: default_mma.h:87

cutlass::sizeof_bits

Defines the size of an element in bits.

Definition: numeric_types.h:42

cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false >::MmaCore

typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 2, Operator > MmaCore

Definition: default_mma.h:182

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

[default_mma_core_sm75.h](default__mma__core__sm75_8h.html)

Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the global memory fragments, data types, and internal tile sizes.

cutlass::transform::threadblock::PredicatedTileIterator

Definition: transform/threadblock/predicated_tile_iterator.h:133

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

arch.h

Defines tags for architecture-specific configurations.

cutlass::layout::ColumnMajorInterleaved

Definition: layout/matrix.h:343

[predicated_tile_iterator.h](transform_2threadblock_2predicated__tile__iterator_8h.html)

Templates implementing loading of tiles from pitch-linear rank=2 tensors.

cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::ColumnMajorInterleaved< InterleavedK >, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, true >::MmaCore

typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::ColumnMajorInterleaved< InterleavedK >, OperatorClass, 2, Operator, true > MmaCore

Definition: default_mma.h:244

cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile

Definition: predicated_tile_iterator_2dthreadtile.h:133

wmma.h

Templates exposing architecture support for warp matrix multiply-add (WMMA) operations.

cutlass.h

Basic include for CUTLASS.

[predicated_tile_iterator_2dthreadtile.h](predicated__tile__iterator__2dthreadtile_8h.html)

Templates implementing loading of tiles from pitch-linear rank=2 tensors.


Generated by 1.8.11