docs/default__mma_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
default_mma.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/cutlass.h"
32 #include "cutlass/numeric_types.h"
33 #include "cutlass/arch/arch.h"
34 #include "cutlass/arch/wmma.h"
35
36 #include "[cutlass/transform/threadblock/predicated_tile_iterator.h](transform_2threadblock_2predicated tile iterator_8h.html)"
37 #include "[cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h](predicated tile iterator__2dthreadtile_8h.html)"
38 #include "[cutlass/gemm/threadblock/default_mma_core_sm70.h](default mma core__sm70_8h.html)"
39 #include "[cutlass/gemm/threadblock/default_mma_core_sm75.h](default mma core__sm75_8h.html)"
40 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
41 #include "[cutlass/gemm/threadblock/default_mma_core_wmma.h](default mma core__wmma_8h.html)"
42 #endif //CUTLASS_ARCH_WMMA_ENABLED
43
45
46 namespace cutlass {
47 namespace gemm {
48 namespace threadblock {
49
51
/// Primary template of DefaultMma. It is only declared here (no body), so a
/// usable type exists only for argument combinations matched by one of the
/// partial specializations that follow in this file.
///
/// How the specializations in this file consume the parameters:
///  - ElementA_/LayoutA_ and ElementB_/LayoutB_ are forwarded to DefaultMmaCore
///    and to the global-memory tile iterators IteratorA / IteratorB.
///  - kAlignmentA / kAlignmentB are passed as the last argument of
///    PredicatedTileIterator (non-interleaved cases) or checked by
///    static_assert (interleaved case).
///  - ElementAccumulator_ and LayoutC_ are forwarded to DefaultMmaCore and to
///    the resulting threadblock-scoped Mma type.
///  - OperatorClass_ selects the specialization (arch::OpClassSimt,
///    arch::OpClassTensorOp, arch::OpClassWmmaTensorOp).
///  - ThreadblockShape_, WarpShape_, InstructionShape_ and Operator are
///    forwarded to DefaultMmaCore.
///  - Stages is matched against the literals 2 (MmaPipelined) and 1
///    (MmaSingleStage, WMMA only).
///  - AccumulatorsInRowMajor defaults to false; only the
///    ColumnMajorInterleaved specialization matches true.
///  - ArchTag_ is matched as a free parameter; none of the specialization
///    bodies in this file reference it.
52 template <
54typename ElementA_,
56typename LayoutA_,
58int kAlignmentA,
60typename ElementB_,
62typename LayoutB_,
64int kAlignmentB,
66typename ElementAccumulator_,
68typename LayoutC_,
70typename OperatorClass_,
72typename ArchTag_,
74typename ThreadblockShape_,
76typename WarpShape_,
78typename InstructionShape_,
80int Stages,
82typename Operator,
85bool AccumulatorsInRowMajor = false
86 >
87 struct DefaultMma;
88
90
/// Partial specialization of DefaultMma selected when:
///   LayoutC == layout::RowMajor, OperatorClass == arch::OpClassSimt,
///   Stages == 2 and AccumulatorsInRowMajor == false.
///
/// Exposes four member types: MmaCore, IteratorA, IteratorB and
/// ThreadblockMma (an MmaPipelined instantiation). ArchTag is accepted for
/// matching but is not referenced inside the body.
92 template <
94typename ElementA,
96typename LayoutA,
98int kAlignmentA,
100typename ElementB,
102typename LayoutB,
104int kAlignmentB,
106typename ElementAccumulator,
108typename ArchTag,
110typename ThreadblockShape,
112typename WarpShape,
114typename InstructionShape,
116typename Operator>
117 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
118 kAlignmentB, ElementAccumulator, layout::RowMajor,
119 arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape,
120 InstructionShape, 2, Operator, false> {
121// Define the MmaCore components
// The stage count (2) and operator class passed here mirror the values this
// specialization matched on.
122using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
123 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
124 ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
125 arch::OpClassSimt, 2, Operator>;
126
127// Define iterators over tiles from the A operand
// Tile extent is kM x kK of MmaCore::Shape; kAlignmentA is forwarded as the
// iterator's final template argument.
// NOTE(review): the literal 1 is presumably the advance rank - confirm
// against PredicatedTileIterator.
128using IteratorA =
129cutlass::transform::threadblock::PredicatedTileIterator<
130cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
131 ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>;
132
133// Define iterators over tiles from the B operand
// Tile extent is kK x kN of MmaCore::Shape; kAlignmentB is forwarded as the
// iterator's final template argument.
134using IteratorB =
135cutlass::transform::threadblock::PredicatedTileIterator<
136cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
137 ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>;
138
139// Define the threadblock-scoped pipelined matrix multiply
// Combines the two iterators above with the shared-memory iterators and
// policy taken from MmaCore.
140using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<
141typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
142IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
143 layout::RowMajor, typename MmaCore::MmaPolicy>;
144 };
145
146
/// Partial specialization of DefaultMma selected when:
///   LayoutC == layout::RowMajor, OperatorClass == arch::OpClassTensorOp,
///   Stages == 2 and AccumulatorsInRowMajor == false.
///
/// Structurally identical to the arch::OpClassSimt specialization in this
/// file except for the operator-class tag handed to DefaultMmaCore. Exposes
/// MmaCore, IteratorA, IteratorB and ThreadblockMma (an MmaPipelined
/// instantiation). ArchTag is matched but not referenced inside the body.
148 template <
150typename ElementA,
152typename LayoutA,
154int kAlignmentA,
156typename ElementB,
158typename LayoutB,
160int kAlignmentB,
162typename ElementAccumulator,
164typename ArchTag,
166typename ThreadblockShape,
168typename WarpShape,
170typename InstructionShape,
172typename Operator
173 >
174 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
175 kAlignmentB, ElementAccumulator, layout::RowMajor,
176 arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
177 InstructionShape, 2, Operator, false> {
178// Define the MmaCore components
// The stage count (2) and operator class passed here mirror the values this
// specialization matched on.
179using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
180 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
181 ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
182 arch::OpClassTensorOp, 2, Operator>;
183
184// Define iterators over tiles from the A operand
// Tile extent is kM x kK of MmaCore::Shape; kAlignmentA is forwarded as the
// iterator's final template argument.
185using IteratorA =
186cutlass::transform::threadblock::PredicatedTileIterator<
187cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
188 ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>;
189
190// Define iterators over tiles from the B operand
// Tile extent is kK x kN of MmaCore::Shape; kAlignmentB is forwarded as the
// iterator's final template argument.
191using IteratorB =
192cutlass::transform::threadblock::PredicatedTileIterator<
193cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
194 ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>;
195
196// Define the threadblock-scoped pipelined matrix multiply
// Combines the two iterators above with the shared-memory iterators and
// policy taken from MmaCore.
197using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<
198typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
199IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
200 layout::RowMajor, typename MmaCore::MmaPolicy>;
201 };
203
/// Partial specialization of DefaultMma selected when:
///   LayoutC == layout::ColumnMajorInterleaved<InterleavedK>,
///   Stages == 2 and AccumulatorsInRowMajor == true.
/// OperatorClass is left generic and forwarded to DefaultMmaCore.
///
/// Differences from the RowMajor specializations in this file:
///  - DefaultMmaCore receives an extra trailing `true` argument.
///  - kAlignmentA / kAlignmentB are not passed to the tile iterators; they are
///    instead required (via static_assert) to equal
///    128 / sizeof_bits<Element>::value.
/// ArchTag is matched but not referenced inside the body.
205 template <
207typename ElementA,
209typename LayoutA,
211int kAlignmentA,
213typename ElementB,
215typename LayoutB,
217int kAlignmentB,
219typename ElementAccumulator,
221typename OperatorClass,
223typename ArchTag,
225typename ThreadblockShape,
227typename WarpShape,
229typename InstructionShape,
231typename Operator,
233int InterleavedK>
234 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
235 kAlignmentB, ElementAccumulator,
236 layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass,
237 ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
238 Operator, true> {
239// Define the MmaCore components
// The trailing `true` mirrors the AccumulatorsInRowMajor value this
// specialization matched on.
240using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
241 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
242 ElementB, LayoutB, ElementAccumulator,
243layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
244true>;
245
// Compile-time guard: the requested alignment (in elements) must correspond
// to exactly 128 bits of ElementA.
246static_assert(kAlignmentA == 128 / sizeof_bits<ElementA>::value,
247"Alignment must match thread data map's vector length");
248
// Same 128-bit requirement for the B operand.
249static_assert(kAlignmentB ==128 / sizeof_bits<ElementB>::value,
250"Alignment must match thread data map's vector length");
251
252// Define iterators over tiles from the A operand
// No alignment argument is supplied here, so the iterator's own default for
// that template parameter applies.
253using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
254cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA,
255 LayoutA, 1, typename MmaCore::IteratorThreadMapA>;
256
257// Define iterators over tiles from the B operand
// As for A, the alignment template argument is left at the iterator default.
258using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
259cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, ElementB,
260 LayoutB, 0, typename MmaCore::IteratorThreadMapB>;
261
262// Define the threadblock-scoped pipelined matrix multiply
// Uses the interleaved C layout that this specialization matched on.
263using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<
264typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
265IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
266 layout::ColumnMajorInterleaved<InterleavedK>,
267typename MmaCore::MmaPolicy>;
268 };
269
/// Partial specialization of DefaultMma selected when:
///   ElementA == ElementB == int8_t, LayoutC == layout::RowMajor,
///   OperatorClass == arch::OpClassSimt,
///   InstructionShape == GemmShape<1, 1, 4>, Stages == 2 and
///   AccumulatorsInRowMajor == false.
///
/// Because the element types, instruction shape and operator class are fixed
/// by the argument list rather than being template parameters, they are
/// re-introduced as member aliases so the body can refer to them by name.
/// Unlike the other specializations in this file, the global-memory iterators
/// are PredicatedTileIterator2dThreadTile and take a transpose flag instead of
/// an alignment; kAlignmentA / kAlignmentB and ArchTag are matched but not
/// referenced inside the body.
273 template <
275typename LayoutA,
277int kAlignmentA,
279typename LayoutB,
281int kAlignmentB,
283typename ElementAccumulator,
285typename ArchTag,
287typename ThreadblockShape,
289typename Operator,
291typename WarpShape>
292 struct DefaultMma<int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB,
293 ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
294 ArchTag, ThreadblockShape, WarpShape, GemmShape<1, 1, 4>, 2,
295 Operator, false> {
// Member aliases for the values fixed by this specialization's argument list.
296using InstructionShape = GemmShape<1, 1, 4>;
// Restored: this alias (listing line 297) was missing from the listing even
// though ElementA is used by MmaCore and IteratorA below.
297using ElementA = int8_t;
298using ElementB = int8_t;
299using OperatorClass = arch::OpClassSimt;
300
// transposeA is true only for a column-major A; transposeB is true only for a
// row-major B. Both are forwarded to the 2D-thread-tile iterators below.
301static const bool transposeA = cutlass::platform::is_same< LayoutA, layout::ColumnMajor >::value;
302static const bool transposeB = cutlass::platform::is_same< LayoutB, layout::RowMajor >::value;
303
304// Define the MmaCore components
305using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
306 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
307 ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
308 OperatorClass, 2, Operator>;
309
310// Define iterators over tiles from the A operand
// Tile extent is kM x kK of MmaCore::Shape; the final argument is the
// transpose flag computed above.
311using IteratorA =
312cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile<
313cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
314 ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA>;
315
316// Define iterators over tiles from the B operand
// Tile extent is kK x kN of MmaCore::Shape; the final argument is the
// transpose flag computed above.
317using IteratorB =
318cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile<
319cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
320 ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB>;
321
322// Define the threadblock-scoped pipelined matrix multiply
323using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<
324typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
325IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
326 layout::RowMajor, typename MmaCore::MmaPolicy>;
327 };
328
329 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
/// Partial specialization of DefaultMma selected when:
///   OperatorClass == arch::OpClassWmmaTensorOp and Stages == 2.
/// Only 15 template arguments are given, so AccumulatorsInRowMajor takes the
/// primary template's default (false). LayoutC is left generic.
/// Compiled only when CUTLASS_ARCH_WMMA_ENABLED is defined (enclosing #if).
///
/// Exposes MmaCore, IteratorA, IteratorB and ThreadblockMma (an MmaPipelined
/// instantiation). ArchTag is matched but not referenced inside the body.
330 template <
333typename ElementA,
335typename LayoutA,
337int kAlignmentA,
339typename ElementB,
341typename LayoutB,
343int kAlignmentB,
345typename ElementAccumulator,
347typename LayoutC,
349typename ArchTag,
351typename ThreadblockShape,
353typename WarpShape,
355typename InstructionShape,
357typename Operator>
358 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
359 kAlignmentB, ElementAccumulator, LayoutC,
360 arch::OpClassWmmaTensorOp, ArchTag, ThreadblockShape, WarpShape,
361 InstructionShape, 2, Operator> {
362// Define the MmaCore components
// The stage count (2) and operator class passed here mirror the values this
// specialization matched on.
363using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
364 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
365 ElementB, LayoutB, ElementAccumulator, LayoutC,
366 arch::OpClassWmmaTensorOp, 2, Operator>;
367
368// Define iterators over tiles from the A operand
// Tile extent is kM x kK of MmaCore::Shape; kAlignmentA is forwarded as the
// iterator's final template argument.
369using IteratorA =
370cutlass::transform::threadblock::PredicatedTileIterator<
371cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
372 ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>;
373
374// Define iterators over tiles from the B operand
// Tile extent is kK x kN of MmaCore::Shape; kAlignmentB is forwarded as the
// iterator's final template argument.
375using IteratorB =
376cutlass::transform::threadblock::PredicatedTileIterator<
377cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
378 ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>;
379
380// Define the threadblock-scoped pipelined matrix multiply
// The accumulator layout is whatever LayoutC the caller supplied.
381using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<
382typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
383IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
384 LayoutC, typename MmaCore::MmaPolicy>;
385 };
386
/// Partial specialization of DefaultMma selected when:
///   OperatorClass == arch::OpClassWmmaTensorOp and Stages == 1.
/// Only 15 template arguments are given, so AccumulatorsInRowMajor takes the
/// primary template's default (false). LayoutC is left generic.
/// Compiled only when CUTLASS_ARCH_WMMA_ENABLED is defined (enclosing #if).
///
/// This is the only specialization in this file whose ThreadblockMma is an
/// MmaSingleStage instantiation rather than MmaPipelined. ArchTag is matched
/// but not referenced inside the body.
388 template <
390typename ElementA,
392typename LayoutA,
394int kAlignmentA,
396typename ElementB,
398typename LayoutB,
400int kAlignmentB,
402typename ElementAccumulator,
404typename LayoutC,
406typename ArchTag,
408typename ThreadblockShape,
410typename WarpShape,
412typename InstructionShape,
414typename Operator>
415 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
416 kAlignmentB, ElementAccumulator, LayoutC,
417 arch::OpClassWmmaTensorOp, ArchTag, ThreadblockShape, WarpShape,
418 InstructionShape, 1, Operator> {
419// Define the MmaCore components
// The stage count (1) and operator class passed here mirror the values this
// specialization matched on.
420using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
421 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
422 ElementB, LayoutB, ElementAccumulator, LayoutC,
423 arch::OpClassWmmaTensorOp, 1, Operator>;
424
425// Define iterators over tiles from the A operand
// Tile extent is kM x kK of MmaCore::Shape; kAlignmentA is forwarded as the
// iterator's final template argument.
426using IteratorA =
427cutlass::transform::threadblock::PredicatedTileIterator<
428cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
429 ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>;
430
431// Define iterators over tiles from the B operand
// Tile extent is kK x kN of MmaCore::Shape; kAlignmentB is forwarded as the
// iterator's final template argument.
432using IteratorB =
433cutlass::transform::threadblock::PredicatedTileIterator<
434cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
435 ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>;
436
437// Define the threadblock-scoped singlestage matrix multiply
// The accumulator layout is whatever LayoutC the caller supplied.
438using ThreadblockMma = cutlass::gemm::threadblock::MmaSingleStage<
439typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
440IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
441 LayoutC, typename MmaCore::MmaPolicy>;
442 };
444 #endif //CUTLASS_ARCH_WMMA_ENABLED
445
446 } // namespace threadblock
447 } // namespace gemm
448 } // namespace cutlass
449
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, 2, Operator > MmaCore
Definition: default_mma.h:308
std::is_same (false specialization)
Definition: platform.h:394
cutlass::gemm::threadblock::DefaultMmaCore
Definition: default_mma_core.h:90
int8_t ElementA
Definition: default_mma.h:297
[default_mma_core_wmma.h](default mma core__wmma_8h.html)
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
cutlass::gemm::threadblock::MmaPipelined
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Definition: mma_pipelined.h:86
cutlass::gemm::threadblock::MmaSingleStage
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Definition: mma_singlestage.h:76
typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, 2, Operator > MmaCore
Definition: default_mma.h:125
[default_mma_core_sm70.h](default mma core__sm70_8h.html)
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
cutlass::gemm::threadblock::DefaultMma
Definition: default_mma.h:87
Defines the size of an element in bits.
Definition: numeric_types.h:42
typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 2, Operator > MmaCore
Definition: default_mma.h:182
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
#define static_assert(__e, __m)
Definition: platform.h:153
[default_mma_core_sm75.h](default mma core__sm75_8h.html)
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
cutlass::transform::threadblock::PredicatedTileIterator
Definition: transform/threadblock/predicated_tile_iterator.h:133
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Defines tags for architecture-specific configurations.
cutlass::layout::ColumnMajorInterleaved
Definition: layout/matrix.h:343
[predicated_tile_iterator.h](transform_2threadblock_2predicated tile iterator_8h.html)
Templates implementing loading of tiles from pitch-linear rank=2 tensors.
typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::ColumnMajorInterleaved< InterleavedK >, OperatorClass, 2, Operator, true > MmaCore
Definition: default_mma.h:244
cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile
Definition: predicated_tile_iterator_2dthreadtile.h:133
Templates exposing architecture support for warp matrix multiply-add (WMMA) operations.
Basic include for CUTLASS.
[predicated_tile_iterator_2dthreadtile.h](predicated tile iterator__2dthreadtile_8h.html)
Templates implementing loading of tiles from pitch-linear rank=2 tensors.
Generated by 1.8.11