Back to Cutlass

CUTLASS: default_gemm.h Source File

docs/default__gemm_8h_source.html

4.4.241.3 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

default_gemm.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

25

36 #pragma once

37

38 #include "cutlass/cutlass.h"

39

40 #include "cutlass/layout/matrix.h"

41 #include "cutlass/numeric_types.h"

42 #include "cutlass/arch/wmma.h"

43

44 #include "cutlass/epilogue/threadblock/epilogue.h"

45 #include "cutlass/epilogue/thread/linear_combination.h"

46

47 #include "cutlass/gemm/gemm.h"

48 #include "cutlass/gemm/kernel/gemm.h"

49 #include "cutlass/gemm/kernel/gemm_pipelined.h"

50 #include "[cutlass/gemm/threadblock/default_mma_core_sm75.h](default mma core__sm75_8h.html)"

51 #include "[cutlass/gemm/threadblock/default_mma_core_sm70.h](default mma core__sm70_8h.html)"

52 #include "cutlass/gemm/threadblock/default_mma.h"

53 #include "[cutlass/gemm/threadblock/default_mma_core_simt.h](default mma core__simt_8h.html)"

54 #include "cutlass/gemm/threadblock/threadblock_swizzle.h"

55

56 #include "[cutlass/epilogue/threadblock/default_epilogue_tensor_op.h](default epilogue tensor__op_8h.html)"

57 #include "[cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h](default epilogue volta tensor op_8h.html)"

58 #include "[cutlass/epilogue/threadblock/default_epilogue_simt.h](default epilogue simt_8h.html)"

59 #include "[cutlass/transform/threadblock/predicated_tile_iterator.h](transform_2threadblock_2predicated tile iterator_8h.html)"

60

61 #if defined(CUTLASS_ARCH_WMMA_ENABLED)

62 #include "[cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h](default epilogue wmma tensor op_8h.html)"

63 #endif //CUTLASS_ARCH_WMMA_ENABLED

64

65

67

68 namespace cutlass {

69 namespace gemm {

70 namespace kernel {

71

73

/// Maps a GEMM configuration (operand types and layouts, operator class,
/// architecture tag, tile shapes, epilogue functor, swizzle) onto a
/// threadblock-scoped MMA, an epilogue, and a kernel-level GEMM.
///
/// This is a declaration only: every usable configuration is provided by a
/// partial specialization, so an unsupported combination fails to compile as
/// an incomplete type.
template <
    // Operand A: element type, layout, and alignment.
    typename ElementA_, typename LayoutA_, int kAlignmentA,
    // Operand B: element type, layout, and alignment.
    typename ElementB_, typename LayoutB_, int kAlignmentB,
    // Output operand: element type and layout.
    typename ElementC_, typename LayoutC_,
    // Element type used for accumulation.
    typename ElementAccumulator,
    // Operator class tag (e.g. arch::OpClassSimt, arch::OpClassTensorOp).
    typename OperatorClass,
    // Architecture tag (e.g. arch::Sm70, arch::Sm75).
    typename ArchTag,
    // Tile shapes at threadblock, warp, and instruction granularity.
    typename ThreadblockShape, typename WarpShape, typename InstructionShape,
    // Output functor consumed by the epilogue.
    typename EpilogueOutputOp,
    // Threadblock swizzling function forwarded to kernel::Gemm.
    typename ThreadblockSwizzle,
    // Pipeline stage count (the non-WMMA specializations require 2).
    int Stages,
    // Forwarded to kernel::Gemm; selects serial split-K support.
    bool SplitKSerial,
    // Math operation tag forwarded to DefaultMma.
    typename Operator,
    // Forwarded to the epilogue by specializations that support it.
    bool IsBetaZero = false>
struct DefaultGemm;

117

120 template <

122typename ElementA,

124typename LayoutA,

126int kAlignmentA,

128typename ElementB,

130typename LayoutB,

132int kAlignmentB,

134typename ElementC,

136typename ElementAccumulator,

138typename ThreadblockShape,

140typename WarpShape,

142typename InstructionShape,

144typename EpilogueOutputOp,

146typename ThreadblockSwizzle,

148bool SplitKSerial,

150typename Operator

151 >

152 struct DefaultGemm<

153 ElementA, LayoutA, kAlignmentA,

154 ElementB, LayoutB, kAlignmentB,

155 ElementC, layout::RowMajor,

156 ElementAccumulator,

157 arch::OpClassTensorOp,

158arch::Sm75,

159 ThreadblockShape,

160 WarpShape,

161 InstructionShape,

162 EpilogueOutputOp,

163 ThreadblockSwizzle,

164 2,

165 SplitKSerial,

166 Operator

167 > {

168

170using Mma = typename cutlass::gemm::threadblock::DefaultMma<

171 ElementA,

172 LayoutA,

173 kAlignmentA,

174 ElementB,

175 LayoutB,

176 kAlignmentB,

177 ElementAccumulator,

178layout::RowMajor,

179 arch::OpClassTensorOp,

180arch::Sm75,

181 ThreadblockShape,

182 WarpShape,

183 InstructionShape,

184 2,

185 Operator

186 >::ThreadblockMma;

187

188static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

189

191using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<

192 ThreadblockShape,

193typename Mma::Operator,

194 kPartitionsK,

195 EpilogueOutputOp,

196 EpilogueOutputOp::kCount

197 >::Epilogue;

198

200using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;

201 };

202

205 template <

207typename ElementA,

209int kAlignmentA,

211typename ElementB,

213int kAlignmentB,

215typename ElementC,

217typename ThreadblockShape,

219typename WarpShape,

221typename InstructionShape,

223typename EpilogueOutputOp,

225typename ThreadblockSwizzle,

227int InterleavedK,

230bool SplitKSerial,

232typename Operator,

234bool IsBetaZero>

235 struct DefaultGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,

236 kAlignmentA, ElementB,

237layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,

238 ElementC, layout::ColumnMajorInterleaved<InterleavedK>,

239 int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape,

240 WarpShape, InstructionShape, EpilogueOutputOp,

241 ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero> {

242using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;

243using LayoutB = layout::RowMajorInterleaved<InterleavedK>;

244using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;

245

246using ElementAccumulator = int32_t;

247

249using Mma = typename cutlass::gemm::threadblock::DefaultMma<

250 ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,

251 arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape,

252 InstructionShape, 2, Operator, true>::ThreadblockMma;

253

254static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

255

257using Epilogue = typename cutlass::epilogue::threadblock::

258 DefaultInterleavedEpilogueTensorOp<

259 ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,

260 64 / sizeof_bits<ElementC>::value, InterleavedK,

261 IsBetaZero>::Epilogue;

262

264using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;

265 };

266

268

269

271 template <

273typename ElementA,

275typename LayoutA,

277int kAlignmentA,

279typename ElementB,

281typename LayoutB,

283int kAlignmentB,

285typename ElementC,

287typename ElementAccumulator,

289typename ThreadblockShape,

291typename WarpShape,

293typename EpilogueOutputOp,

295typename ThreadblockSwizzle,

297bool SplitKSerial,

299typename Operator

300 >

301 struct DefaultGemm<

302 ElementA, LayoutA, kAlignmentA,

303 ElementB, LayoutB, kAlignmentB,

304 ElementC, layout::RowMajor,

305 ElementAccumulator,

306 arch::OpClassTensorOp,

307arch::Sm70,

308 ThreadblockShape,

309 WarpShape,

310GemmShape<8, 8, 4>,

311 EpilogueOutputOp,

312 ThreadblockSwizzle,

313 2,

314 SplitKSerial,

315 Operator

316 > {

317

319using Mma = typename cutlass::gemm::threadblock::DefaultMma<

320 ElementA,

321 LayoutA,

322 kAlignmentA,

323 ElementB,

324 LayoutB,

325 kAlignmentB,

326 ElementAccumulator,

327layout::RowMajor,

328 arch::OpClassTensorOp,

329arch::Sm70,

330 ThreadblockShape,

331 WarpShape,

332GemmShape<8, 8, 4>,

333 2,

334 Operator

335 >::ThreadblockMma;

336

337static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

338

340using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp<

341 ThreadblockShape,

342typename Mma::Operator,

343 kPartitionsK,

344 EpilogueOutputOp,

345 EpilogueOutputOp::kCount

346 >::Epilogue;

347

349using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;

350 };

351

353

355 template <

357typename ElementA,

359typename LayoutA,

361int kAlignmentA,

363typename ElementB,

365typename LayoutB,

367int kAlignmentB,

369typename ElementC,

371typename ElementAccumulator,

373typename ArchTag,

375typename ThreadblockShape,

377typename WarpShape,

379typename EpilogueOutputOp,

381typename ThreadblockSwizzle,

383bool SplitKSerial,

385typename Operator

386 >

387 struct DefaultGemm<

388 ElementA,

389 LayoutA,

390 kAlignmentA,

391 ElementB,

392 LayoutB,

393 kAlignmentB,

394 ElementC,

395 layout::RowMajor,

396 ElementAccumulator,

397 arch::OpClassSimt,

398 ArchTag,

399 ThreadblockShape,

400 WarpShape,

401GemmShape<1, 1, 1>,

402 EpilogueOutputOp,

403 ThreadblockSwizzle,

404 2,

405 SplitKSerial,

406 Operator> {

408using Mma = typename cutlass::gemm::threadblock::DefaultMma<

409 ElementA,

410 LayoutA,

411 kAlignmentA,

412 ElementB,

413 LayoutB,

414 kAlignmentB,

415 ElementAccumulator,

416layout::RowMajor,

417 arch::OpClassSimt,

418arch::Sm50,

419 ThreadblockShape,

420 WarpShape,

421GemmShape<1, 1, 1>,

422 2,

423 Operator>::ThreadblockMma;

424

425static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount;

426static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars");

427

429using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt<

430 ThreadblockShape,

431typename Mma::Operator,

432 EpilogueOutputOp,

433 kEpilogueElementsPerAccess

434 >::Epilogue;

435

437using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;

438 };

439

441

444

445 template <

447typename LayoutA,

449int kAlignmentA,

451typename LayoutB,

453int kAlignmentB,

455typename LayoutC,

457typename ElementC,

459typename ArchTag,

461typename ElementAccumulator,

463typename ThreadblockShape,

465typename WarpShape,

467typename EpilogueOutputOp,

469typename ThreadblockSwizzle,

472bool SplitKSerial,

474typename Operator>

475 struct DefaultGemm<int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB,

476 ElementC, LayoutC, ElementAccumulator, arch::OpClassSimt,

477 ArchTag, ThreadblockShape, WarpShape, GemmShape<1, 1, 4>,

478 EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial,

479 Operator, false> {

480using InstructionShape = GemmShape<1, 1, 4>;

481using ElementA = int8_t;

482using ElementB = int8_t;

483

484using OperatorClass = arch::OpClassSimt;

486using Mma = typename cutlass::gemm::threadblock::DefaultMma<ElementA,

487 LayoutA,

488 kAlignmentA,

489 ElementB,

490 LayoutB,

491 kAlignmentB,

492 ElementAccumulator,

493 LayoutC,

494 arch::OpClassSimt,

495arch::Sm50,

496 ThreadblockShape,

497 WarpShape,

498 InstructionShape,

499 2,

500 Operator,

501false

502 >::ThreadblockMma;

503

504static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount;

505static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars");

506

508using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt<

509 ThreadblockShape,

510typename Mma::Operator,

511 EpilogueOutputOp,

512 kEpilogueElementsPerAccess

513 >::Epilogue;

514

516using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;

517 };

518

519

520 #if defined(CUTLASS_ARCH_WMMA_ENABLED)

521 template <

525typename ElementA,

527typename LayoutA,

529int kAlignmentA,

531typename ElementB,

533typename LayoutB,

535int kAlignmentB,

537typename ElementC,

539typename LayoutC,

541typename ElementAccumulator,

543typename ArchTag,

545typename ThreadblockShape,

547typename WarpShape,

549typename InstructionShape,

551typename EpilogueOutputOp,

553typename ThreadblockSwizzle,

555int Stages,

558bool SplitKSerial,

560typename Operator>

561 struct DefaultGemm<

562ElementA, LayoutA, kAlignmentA,

563 ElementB, LayoutB, kAlignmentB,

564 ElementC, LayoutC,

565 ElementAccumulator,

566 arch::OpClassWmmaTensorOp,

567 ArchTag,

568 ThreadblockShape, WarpShape, InstructionShape,

569 EpilogueOutputOp,

570 ThreadblockSwizzle,

571 Stages,

572 SplitKSerial,

573 Operator> {

575using Mma = typename cutlass::gemm::threadblock::DefaultMma<

576ElementA, LayoutA, kAlignmentA,

577 ElementB, LayoutB, kAlignmentB,

578 ElementAccumulator, LayoutC,

579 arch::OpClassWmmaTensorOp,

580 ArchTag,

581 ThreadblockShape,

582 WarpShape,

583 InstructionShape,

584 Stages,

585 Operator>::ThreadblockMma;

586

587static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

588

590using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp<

591 ThreadblockShape,

592typename Mma::Operator,

593 kPartitionsK,

594 EpilogueOutputOp,

595 EpilogueOutputOp::kCount

596 >::Epilogue;

597

599using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;

600 };

602 #endif //CUTLASS_ARCH_WMMA_ENABLED

603

605

606 } // namespace kernel

607 } // namespace gemm

608 } // namespace cutlass

cutlass::gemm::kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, arch::Sm70, ThreadblockShape, WarpShape, GemmShape< 8, 8, 4 >, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator >::Mma

typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm70, ThreadblockShape, WarpShape, GemmShape< 8, 8, 4 >, 2, Operator >::ThreadblockMma Mma

Define the threadblock-scoped matrix multiply-accumulate.

Definition: default_gemm.h:335

cutlass::gemm::kernel::DefaultGemm

Definition: default_gemm.h:116

cutlass::gemm::kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator >::Mma

typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, 2, Operator >::ThreadblockMma Mma

Define the threadblock-scoped matrix multiply-accumulate.

Definition: default_gemm.h:186

cutlass

Definition: aligned_buffer.h:35

cutlass::epilogue::threadblock::DefaultEpilogueSimt

Defines sensible defaults for epilogues for SimtOps.

Definition: default_epilogue_simt.h:70

cutlass::arch::Sm50

Definition: arch.h:37

[default_epilogue_wmma_tensor_op.h](default epilogue wmma tensor op_8h.html)

Epilogue for threadblock-scoped GEMMs using WMMA Tensor Ops.

cutlass::arch::Sm70

Definition: arch.h:46

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::gemm::kernel::DefaultGemm< ElementA, layout::ColumnMajorInterleaved< InterleavedK >, kAlignmentA, ElementB, layout::RowMajorInterleaved< InterleavedK >, kAlignmentB, ElementC, layout::ColumnMajorInterleaved< InterleavedK >, int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero >::Mma

typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, true >::ThreadblockMma Mma

Define the threadblock-scoped matrix multiply-accumulate.

Definition: default_gemm.h:252

default_mma.h

Defines cutlass::gemm::threadblock::DefaultMma, which selects the threadblock-scoped matrix multiply-accumulate (ThreadblockMma) for a given GEMM configuration.

cutlass::gemm::kernel::DefaultGemm< int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB, ElementC, LayoutC, ElementAccumulator, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, GemmShape< 1, 1, 4 >, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator, false >::Mma

typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, arch::OpClassSimt, arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false >::ThreadblockMma Mma

Define the threadblock-scoped matrix multiply-accumulate.

Definition: default_gemm.h:502

[default_mma_core_sm70.h](default mma core__sm70_8h.html)

Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...

cutlass::gemm::threadblock::DefaultMma

Definition: default_mma.h:87

cutlass::arch::Sm75

Definition: arch.h:52

linear_combination.h

Functor performing linear combination operations used by epilogues.

cutlass::sizeof_bits

Defines the size of an element in bits.

Definition: numeric_types.h:42

cutlass::gemm::kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, arch::Sm70, ThreadblockShape, WarpShape, GemmShape< 8, 8, 4 >, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator >::Epilogue

typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount >::Epilogue Epilogue

Define the epilogue.

Definition: default_gemm.h:346

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

[default_epilogue_tensor_op.h](default epilogue tensor__op_8h.html)

Epilogue for threadblock scoped GEMMs using Tensor Ops.

cutlass::gemm::kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, GemmShape< 1, 1, 1 >, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator >::Epilogue

typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< ThreadblockShape, typename Mma::Operator, EpilogueOutputOp, kEpilogueElementsPerAccess >::Epilogue Epilogue

Define the epilogue.

Definition: default_gemm.h:434

[default_mma_core_sm75.h](default mma core__sm75_8h.html)

Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...

cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp

Defines sensible defaults for epilogues for TensorOps.

Definition: default_epilogue_volta_tensor_op.h:71

cutlass::gemm::kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator >::Epilogue

typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount >::Epilogue Epilogue

Define the epilogue.

Definition: default_gemm.h:197

cutlass::gemm::kernel::DefaultGemm< ElementA, layout::ColumnMajorInterleaved< InterleavedK >, kAlignmentA, ElementB, layout::RowMajorInterleaved< InterleavedK >, kAlignmentB, ElementC, layout::ColumnMajorInterleaved< InterleavedK >, int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero >::Epilogue

typename cutlass::epilogue::threadblock::DefaultInterleavedEpilogueTensorOp< ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, 64/sizeof_bits< ElementC >::value, InterleavedK, IsBetaZero >::Epilogue Epilogue

Define the epilogue.

Definition: default_gemm.h:261

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

epilogue.h

Epilogue for threadblock scoped GEMMs using Tensor Ops.

cutlass::gemm::kernel::Gemm

Definition: include/cutlass/gemm/kernel/gemm.h:52

matrix.h

Defines layout functions used by TensorRef and derived classes.

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp

Defines sensible defaults for epilogues for WMMA TensorOps.

Definition: default_epilogue_wmma_tensor_op.h:71

cutlass::epilogue::threadblock::DefaultInterleavedEpilogueTensorOp

Definition: default_epilogue_tensor_op.h:147

threadblock_swizzle.h

Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...

cutlass::layout::ColumnMajorInterleaved

Definition: layout/matrix.h:343

cutlass::gemm::kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, GemmShape< 1, 1, 1 >, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator >::Mma

typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, arch::Sm50, ThreadblockShape, WarpShape, GemmShape< 1, 1, 1 >, 2, Operator >::ThreadblockMma Mma

Define the threadblock-scoped matrix multiply-accumulate.

Definition: default_gemm.h:423

[default_mma_core_simt.h](default mma core__simt_8h.html)

Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...

[predicated_tile_iterator.h](transform_2threadblock_2predicated tile iterator_8h.html)

Templates implementing loading of tiles from pitch-linear rank=2 tensors.

cutlass::gemm::kernel::DefaultGemm< int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB, ElementC, LayoutC, ElementAccumulator, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, GemmShape< 1, 1, 4 >, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator, false >::ElementA

int8_t ElementA

Definition: default_gemm.h:481

gemm_pipelined.h

Template for a pipelined GEMM kernel. Does not compute batching or support split-K.

wmma.h

Templates exposing architecture support for warp matrix multiply-add (WMMA) operations.

cutlass::epilogue::threadblock::DefaultEpilogueTensorOp

Defines sensible defaults for epilogues for TensorOps.

Definition: default_epilogue_tensor_op.h:72

cutlass::gemm::kernel::DefaultGemm< int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB, ElementC, LayoutC, ElementAccumulator, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, GemmShape< 1, 1, 4 >, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator, false >::Epilogue

typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< ThreadblockShape, typename Mma::Operator, EpilogueOutputOp, kEpilogueElementsPerAccess >::Epilogue Epilogue

Define the epilogue.

Definition: default_gemm.h:513

cutlass.h

Basic include for CUTLASS.

gemm.h

Template for a pipelined GEMM kernel (cutlass::gemm::kernel::Gemm). Does not compute batching; serial split-K is selected by its SplitKSerial template argument.

[default_epilogue_simt.h](default epilogue simt_8h.html)

Epilogue for threadblock scoped GEMMs using SIMT.

[default_epilogue_volta_tensor_op.h](default epilogue volta tensor op_8h.html)

Epilogue for threadblock scoped GEMMs using Tensor Ops on Volta.

cutlass::layout::RowMajorInterleaved

Definition: layout/matrix.h:237

<!-- fragment --> <!-- contents --><!-- start footer part -->
<address class="footer"><small> Generated by 1.8.11 </small></address>