docs/default__gemm_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
default_gemm.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
25
36 #pragma once
37
38 #include "cutlass/cutlass.h"
39
40 #include "cutlass/layout/matrix.h"
41 #include "cutlass/numeric_types.h"
42 #include "cutlass/arch/wmma.h"
43
44 #include "cutlass/epilogue/threadblock/epilogue.h"
45 #include "cutlass/epilogue/thread/linear_combination.h"
46
47 #include "cutlass/gemm/gemm.h"
48 #include "cutlass/gemm/kernel/gemm.h"
49 #include "cutlass/gemm/kernel/gemm_pipelined.h"
50 #include "[cutlass/gemm/threadblock/default_mma_core_sm75.h](default__mma__core__sm75_8h.html)"
51 #include "[cutlass/gemm/threadblock/default_mma_core_sm70.h](default__mma__core__sm70_8h.html)"
52 #include "cutlass/gemm/threadblock/default_mma.h"
53 #include "[cutlass/gemm/threadblock/default_mma_core_simt.h](default__mma__core__simt_8h.html)"
54 #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
55
56 #include "[cutlass/epilogue/threadblock/default_epilogue_tensor_op.h](default__epilogue__tensor__op_8h.html)"
57 #include "[cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h](default__epilogue__volta__tensor__op_8h.html)"
58 #include "[cutlass/epilogue/threadblock/default_epilogue_simt.h](default__epilogue__simt_8h.html)"
59 #include "[cutlass/transform/threadblock/predicated_tile_iterator.h](transform_2threadblock_2predicated__tile__iterator_8h.html)"
60
61 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
62 #include "[cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h](default__epilogue__wmma__tensor__op_8h.html)"
63 #endif //CUTLASS_ARCH_WMMA_ENABLED
64
65
67
68 namespace cutlass {
69 namespace gemm {
70 namespace kernel {
71
73
74 template <
76typename ElementA_,
78typename LayoutA_,
80int kAlignmentA,
82typename ElementB_,
84typename LayoutB_,
86int kAlignmentB,
88typename ElementC_,
90typename LayoutC_,
92typename ElementAccumulator,
94typename OperatorClass,
96typename ArchTag,
98typename ThreadblockShape,
100typename WarpShape,
102typename InstructionShape,
104typename EpilogueOutputOp,
106typename ThreadblockSwizzle,
108int Stages,
111bool SplitKSerial,
113typename Operator,
115bool IsBetaZero = false>
116 struct DefaultGemm;
117
120 template <
122typename ElementA,
124typename LayoutA,
126int kAlignmentA,
128typename ElementB,
130typename LayoutB,
132int kAlignmentB,
134typename ElementC,
136typename ElementAccumulator,
138typename ThreadblockShape,
140typename WarpShape,
142typename InstructionShape,
144typename EpilogueOutputOp,
146typename ThreadblockSwizzle,
148bool SplitKSerial,
150typename Operator
151 >
152 struct DefaultGemm<
153 ElementA, LayoutA, kAlignmentA,
154 ElementB, LayoutB, kAlignmentB,
155 ElementC, layout::RowMajor,
156 ElementAccumulator,
157 arch::OpClassTensorOp,
158arch::Sm75,
159 ThreadblockShape,
160 WarpShape,
161 InstructionShape,
162 EpilogueOutputOp,
163 ThreadblockSwizzle,
164 2,
165 SplitKSerial,
166 Operator
167 > {
168
170using Mma = typename cutlass::gemm::threadblock::DefaultMma<
171 ElementA,
172 LayoutA,
173 kAlignmentA,
174 ElementB,
175 LayoutB,
176 kAlignmentB,
177 ElementAccumulator,
178layout::RowMajor,
179 arch::OpClassTensorOp,
180arch::Sm75,
181 ThreadblockShape,
182 WarpShape,
183 InstructionShape,
184 2,
185 Operator
186 >::ThreadblockMma;
187
188static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
189
191using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
192 ThreadblockShape,
193typename Mma::Operator,
194 kPartitionsK,
195 EpilogueOutputOp,
196 EpilogueOutputOp::kCount
197 >::Epilogue;
198
200using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
201 };
202
205 template <
207typename ElementA,
209int kAlignmentA,
211typename ElementB,
213int kAlignmentB,
215typename ElementC,
217typename ThreadblockShape,
219typename WarpShape,
221typename InstructionShape,
223typename EpilogueOutputOp,
225typename ThreadblockSwizzle,
227int InterleavedK,
230bool SplitKSerial,
232typename Operator,
234bool IsBetaZero>
235 struct DefaultGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
236 kAlignmentA, ElementB,
237layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
238 ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
239 int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape,
240 WarpShape, InstructionShape, EpilogueOutputOp,
241 ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero> {
242using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
243using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
244using LayoutC = layout::ColumnMajorInterleaved<InterleavedK>;
245
246using ElementAccumulator = int32_t;
247
249using Mma = typename cutlass::gemm::threadblock::DefaultMma<
250 ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
251 arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape,
252 InstructionShape, 2, Operator, true>::ThreadblockMma;
253
254static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
255
257using Epilogue = typename cutlass::epilogue::threadblock::
258 DefaultInterleavedEpilogueTensorOp<
259 ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
260 64 / sizeof_bits<ElementC>::value, InterleavedK,
261 IsBetaZero>::Epilogue;
262
264using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
265 };
266
268
269
271 template <
273typename ElementA,
275typename LayoutA,
277int kAlignmentA,
279typename ElementB,
281typename LayoutB,
283int kAlignmentB,
285typename ElementC,
287typename ElementAccumulator,
289typename ThreadblockShape,
291typename WarpShape,
293typename EpilogueOutputOp,
295typename ThreadblockSwizzle,
297bool SplitKSerial,
299typename Operator
300 >
301 struct DefaultGemm<
302 ElementA, LayoutA, kAlignmentA,
303 ElementB, LayoutB, kAlignmentB,
304 ElementC, layout::RowMajor,
305 ElementAccumulator,
306 arch::OpClassTensorOp,
307arch::Sm70,
308 ThreadblockShape,
309 WarpShape,
310GemmShape<8, 8, 4>,
311 EpilogueOutputOp,
312 ThreadblockSwizzle,
313 2,
314 SplitKSerial,
315 Operator
316 > {
317
319using Mma = typename cutlass::gemm::threadblock::DefaultMma<
320 ElementA,
321 LayoutA,
322 kAlignmentA,
323 ElementB,
324 LayoutB,
325 kAlignmentB,
326 ElementAccumulator,
327layout::RowMajor,
328 arch::OpClassTensorOp,
329arch::Sm70,
330 ThreadblockShape,
331 WarpShape,
332 GemmShape<8, 8, 4>,
333 2,
334 Operator
335 >::ThreadblockMma;
336
337static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
338
340using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp<
341 ThreadblockShape,
342typename Mma::Operator,
343 kPartitionsK,
344 EpilogueOutputOp,
345 EpilogueOutputOp::kCount
346 >::Epilogue;
347
349using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
350 };
351
353
355 template <
357typename ElementA,
359typename LayoutA,
361int kAlignmentA,
363typename ElementB,
365typename LayoutB,
367int kAlignmentB,
369typename ElementC,
371typename ElementAccumulator,
373typename ArchTag,
375typename ThreadblockShape,
377typename WarpShape,
379typename EpilogueOutputOp,
381typename ThreadblockSwizzle,
383bool SplitKSerial,
385typename Operator
386 >
387 struct DefaultGemm<
388 ElementA,
389 LayoutA,
390 kAlignmentA,
391 ElementB,
392 LayoutB,
393 kAlignmentB,
394 ElementC,
395 layout::RowMajor,
396 ElementAccumulator,
397 arch::OpClassSimt,
398 ArchTag,
399 ThreadblockShape,
400 WarpShape,
401GemmShape<1, 1, 1>,
402 EpilogueOutputOp,
403 ThreadblockSwizzle,
404 2,
405 SplitKSerial,
406 Operator> {
408using Mma = typename cutlass::gemm::threadblock::DefaultMma<
409 ElementA,
410 LayoutA,
411 kAlignmentA,
412 ElementB,
413 LayoutB,
414 kAlignmentB,
415 ElementAccumulator,
416layout::RowMajor,
417 arch::OpClassSimt,
418arch::Sm50,
419 ThreadblockShape,
420 WarpShape,
421 GemmShape<1, 1, 1>,
422 2,
423 Operator>::ThreadblockMma;
424
425static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount;
426static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars");
427
429using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt<
430 ThreadblockShape,
431typename Mma::Operator,
432 EpilogueOutputOp,
433 kEpilogueElementsPerAccess
434 >::Epilogue;
435
437using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
438 };
439
441
444
445 template <
447typename LayoutA,
449int kAlignmentA,
451typename LayoutB,
453int kAlignmentB,
455typename LayoutC,
457typename ElementC,
459typename ArchTag,
461typename ElementAccumulator,
463typename ThreadblockShape,
465typename WarpShape,
467typename EpilogueOutputOp,
469typename ThreadblockSwizzle,
472bool SplitKSerial,
474typename Operator>
475 struct DefaultGemm<int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB,
476 ElementC, LayoutC, ElementAccumulator, arch::OpClassSimt,
477 ArchTag, ThreadblockShape, WarpShape, GemmShape<1, 1, 4>,
478 EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial,
479 Operator, false> {
480using InstructionShape = GemmShape<1, 1, 4>;
481 using ElementA = int8_t;
482using ElementB = int8_t;
483
484using OperatorClass = arch::OpClassSimt;
486using Mma = typename cutlass::gemm::threadblock::DefaultMma<ElementA,
487 LayoutA,
488 kAlignmentA,
489 ElementB,
490 LayoutB,
491 kAlignmentB,
492 ElementAccumulator,
493 LayoutC,
494 arch::OpClassSimt,
495arch::Sm50,
496 ThreadblockShape,
497 WarpShape,
498 InstructionShape,
499 2,
500 Operator,
501false
502 >::ThreadblockMma;
503
504static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount;
505static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars");
506
508using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimt<
509 ThreadblockShape,
510typename Mma::Operator,
511 EpilogueOutputOp,
512 kEpilogueElementsPerAccess
513 >::Epilogue;
514
516using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
517 };
518
519
520 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
521 template <
525typename ElementA,
527typename LayoutA,
529int kAlignmentA,
531typename ElementB,
533typename LayoutB,
535int kAlignmentB,
537typename ElementC,
539typename LayoutC,
541typename ElementAccumulator,
543typename ArchTag,
545typename ThreadblockShape,
547typename WarpShape,
549typename InstructionShape,
551typename EpilogueOutputOp,
553typename ThreadblockSwizzle,
555int Stages,
558bool SplitKSerial,
560typename Operator>
561 struct DefaultGemm<
562ElementA, LayoutA, kAlignmentA,
563 ElementB, LayoutB, kAlignmentB,
564 ElementC, LayoutC,
565 ElementAccumulator,
566 arch::OpClassWmmaTensorOp,
567 ArchTag,
568 ThreadblockShape, WarpShape, InstructionShape,
569 EpilogueOutputOp,
570 ThreadblockSwizzle,
571 Stages,
572 SplitKSerial,
573 Operator> {
575using Mma = typename cutlass::gemm::threadblock::DefaultMma<
576ElementA, LayoutA, kAlignmentA,
577 ElementB, LayoutB, kAlignmentB,
578 ElementAccumulator, LayoutC,
579 arch::OpClassWmmaTensorOp,
580 ArchTag,
581 ThreadblockShape,
582 WarpShape,
583 InstructionShape,
584 Stages,
585 Operator>::ThreadblockMma;
586
587static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
588
590using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp<
591 ThreadblockShape,
592typename Mma::Operator,
593 kPartitionsK,
594 EpilogueOutputOp,
595 EpilogueOutputOp::kCount
596 >::Epilogue;
597
599using GemmKernel = kernel::Gemm<Mma, Epilogue, ThreadblockSwizzle, SplitKSerial>;
600 };
602 #endif //CUTLASS_ARCH_WMMA_ENABLED
603
605
606 } // namespace kernel
607 } // namespace gemm
608 } // namespace cutlass
typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm70, ThreadblockShape, WarpShape, GemmShape< 8, 8, 4 >, 2, Operator >::ThreadblockMma Mma
Define the threadblock-scoped matrix multiply-accumulate.
Definition: default_gemm.h:335
cutlass::gemm::kernel::DefaultGemm
Definition: default_gemm.h:116
typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, 2, Operator >::ThreadblockMma Mma
Define the threadblock-scoped matrix multiply-accumulate.
Definition: default_gemm.h:186
Definition: aligned_buffer.h:35
cutlass::epilogue::threadblock::DefaultEpilogueSimt
Defines sensible defaults for epilogues for SimtOps.
Definition: default_epilogue_simt.h:70
Definition: arch.h:37
[default_epilogue_wmma_tensor_op.h](default__epilogue__wmma__tensor__op_8h.html)
Epilogue for threadblock scoped GEMMs using Tensor Ops.
Definition: arch.h:46
Defines common types used for all GEMM-like operators.
typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, true >::ThreadblockMma Mma
Define the threadblock-scoped matrix multiply-accumulate.
Definition: default_gemm.h:252
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, arch::OpClassSimt, arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false >::ThreadblockMma Mma
Define the threadblock-scoped matrix multiply-accumulate.
Definition: default_gemm.h:502
[default_mma_core_sm70.h](default__mma__core__sm70_8h.html)
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
cutlass::gemm::threadblock::DefaultMma
Definition: default_mma.h:87
Definition: arch.h:52
Functor performing linear combination operations used by epilogues.
Defines the size of an element in bits.
Definition: numeric_types.h:42
typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount >::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm.h:346
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
#define static_assert(__e, __m)
Definition: platform.h:153
[default_epilogue_tensor_op.h](default__epilogue__tensor__op_8h.html)
Epilogue for threadblock scoped GEMMs using Tensor Ops.
typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< ThreadblockShape, typename Mma::Operator, EpilogueOutputOp, kEpilogueElementsPerAccess >::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm.h:434
[default_mma_core_sm75.h](default__mma__core__sm75_8h.html)
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp
Defines sensible defaults for epilogues for TensorOps.
Definition: default_epilogue_volta_tensor_op.h:71
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount >::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm.h:197
typename cutlass::epilogue::threadblock::DefaultInterleavedEpilogueTensorOp< ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, 64/sizeof_bits< ElementC >::value, InterleavedK, IsBetaZero >::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm.h:261
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Epilogue for threadblock scoped GEMMs using Tensor Ops.
Definition: include/cutlass/gemm/kernel/gemm.h:52
Defines layout functions used by TensorRef and derived classes.
cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp
Defines sensible defaults for epilogues for WMMA TensorOps.
Definition: default_epilogue_wmma_tensor_op.h:71
cutlass::epilogue::threadblock::DefaultInterleavedEpilogueTensorOp
Definition: default_epilogue_tensor_op.h:147
Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...
cutlass::layout::ColumnMajorInterleaved
Definition: layout/matrix.h:343
typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, arch::Sm50, ThreadblockShape, WarpShape, GemmShape< 1, 1, 1 >, 2, Operator >::ThreadblockMma Mma
Define the threadblock-scoped matrix multiply-accumulate.
Definition: default_gemm.h:423
[default_mma_core_simt.h](default__mma__core__simt_8h.html)
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
[predicated_tile_iterator.h](transform_2threadblock_2predicated__tile__iterator_8h.html)
Templates implementing loading of tiles from pitch-linear rank=2 tensors.
int8_t ElementA
Definition: default_gemm.h:481
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
Templates exposing architecture support for warp matrix multiply-add (WMMA) operations.
cutlass::epilogue::threadblock::DefaultEpilogueTensorOp
Defines sensible defaults for epilogues for TensorOps.
Definition: default_epilogue_tensor_op.h:72
typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< ThreadblockShape, typename Mma::Operator, EpilogueOutputOp, kEpilogueElementsPerAccess >::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm.h:513
Basic include for CUTLASS.
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
[default_epilogue_simt.h](default__epilogue__simt_8h.html)
Epilogue for threadblock scoped GEMMs using SIMT.
[default_epilogue_volta_tensor_op.h](default__epilogue__volta__tensor__op_8h.html)
Epilogue for threadblock scoped GEMMs using Tensor Ops on Volta.
cutlass::layout::RowMajorInterleaved
Definition: layout/matrix.h:237
<!-- fragment --> <!-- contents --><!-- start footer part -->