docs/mma__tensor__op_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
mma_tensor_op.h
[Go to the documentation of this file.](mma__tensor__op_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
30 #pragma once
31
32 #include "cutlass/cutlass.h"
33 #include "cutlass/array.h"
34
35 #include "cutlass/numeric_types.h"
36 #include "cutlass/matrix_shape.h"
37
38 #include "cutlass/arch/memory_sm75.h"
39 #include "cutlass/arch/mma_sm75.h"
40 #include "cutlass/gemm/gemm.h"
41 #include "cutlass/gemm/warp/mma.h"
42
43 #include "[cutlass/gemm/warp/mma_tensor_op_policy.h](mma tensor op__policy_8h.html)"
44
45 #include "[cutlass/gemm/warp/mma_tensor_op_tile_iterator.h](mma tensor op tile iterator_8h.html)"
47
48 namespace cutlass {
49 namespace gemm {
50 namespace warp {
51
53
// Warp-level matrix multiply-accumulate built on a Tensor Core instruction
// (OperatorClass is arch::OpClassTensorOp). Template parameters:
//   Shape_                 - warp-level tile shape (provides kM, kN, kK)
//   ElementA_, LayoutA_    - element type and layout of multiplicand A
//   ElementB_, LayoutB_    - element type and layout of multiplicand B
//   ElementC_, LayoutC_    - element type and layout of accumulator C
//   Policy_                - provides Policy::Operator (the instruction-level MMA)
//                            and Policy::OpDelta
//   PartitionsK_           - number of partitions along K; forwarded to the A/B iterators
//   AccumulatorsInRowMajor - selects how accumulator fragments are indexed in operator()
//   PartitionsN_           - number of partitions along N of multiplicand B
//   Enable                 - not referenced in this listing; NOTE(review): presumably a
//                            hook for enable_if-style partial specialization — confirm
55 template <
57typename Shape_,
59typename ElementA_,
61typename LayoutA_,
63typename ElementB_,
65typename LayoutB_,
67typename ElementC_,
69typename LayoutC_,
71typename Policy_,
73int PartitionsK_ = 1,
76bool AccumulatorsInRowMajor = false,
78int PartitionsN_ = 1,
80typename Enable = bool
81 >
82 class MmaTensorOp {
83 public:
// NOTE: the member aliases Shape (source line 85), ElementA (88), LayoutA (91),
// ElementB (94), LayoutB (97), LayoutC (103) and Policy (106) are declared in the
// header but are not rendered in this listing; each aliases its trailing-underscore
// template parameter.
86
89
92
95
98
// Data type of accumulator matrix C.
100using ElementC = ElementC_;
101
104
107
// Indicates class of matrix operator (Tensor Core operation).
109using OperatorClass = arch::OpClassTensorOp;
110
// Number of threads participating in the warp-level matrix product.
112static int const kThreadCount = 32;
113
// Number of partitions along the K dimension.
115static int const kPartitionsK = PartitionsK_;
116
// Number of partitions along the N dimension of multiplicand B.
118static int const kPartitionsN = PartitionsN_;
119
120 public:
121
// Iterates over the kM x kK tile of operand A in steps of the instruction's
// kM x kK shape.
123using IteratorA = MmaTensorOpMultiplicandTileIterator<
124MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
125MatrixShape<Policy::Operator::Shape::kM, Policy::Operator::Shape::kK>,
126 Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
127
// Storage for the A tile.
129using FragmentA = typename IteratorA::Fragment;
130
// Iterates over the kK x kN tile of operand B in steps of the instruction's
// kK x kN shape.
132using IteratorB = MmaTensorOpMultiplicandTileIterator<
133MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
134MatrixShape<Policy::Operator::Shape::kK, Policy::Operator::Shape::kN>,
135 Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
136
// Storage for the B tile.
138using FragmentB = typename IteratorB::Fragment;
139
// Iterates over the full kM x kN accumulator tile.
141using IteratorC = MmaTensorOpAccumulatorTileIterator<
142MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
143typename Policy::Operator::Shape, typename Policy::OpDelta>;
144
// Storage for the C (accumulator) tile.
146using FragmentC = typename IteratorC::Fragment;
147
148 private:
149
// Only M and N divisibility by the instruction shape is checked here; divisibility
// of K, and of the N iteration count by kPartitionsN, is not asserted.
150static_assert(
151 !(Shape::kM % Policy::Operator::Shape::kM) &&
152 !(Shape::kN % Policy::Operator::Shape::kN),
153"Shape of warp-level Mma must be divisible by operator shape.");
154
// Number of instruction-level MMAs issued per call:
//   kRow    = Shape::kM / Operator::Shape::kM
//   kColumn = Shape::kN / Operator::Shape::kN / kPartitionsN, clamped to at least 1
156using MmaIterations = MatrixShape<
157 Shape::kM / Policy::Operator::Shape::kM,
158 (Shape::kN / Policy::Operator::Shape::kN / kPartitionsN > 0) ?
159 Shape::kN / Policy::Operator::Shape::kN / kPartitionsN :
160 1
161 >;
162
163 public:
164
// Underlying instruction-level matrix multiply operator (concept: arch::Mma).
166typename Policy::Operator mma;
167
168 public:
169
170//
171// Methods
172//
173
// Default constructor; performs no initialization.
175 CUTLASS_DEVICE
176MmaTensorOp() {}
177
// Performs a warp-level matrix multiply-accumulate: D = A * B + C.
//   D              - output accumulator fragment (overwritten with C, then accumulated into)
//   A, B           - multiplicand fragments
//   C              - input accumulator fragment
//   partitionN_idx - which N-partition of B to consume (only used by the
//                    column-major accumulator branch below)
179 CUTLASS_DEVICE
180void operator()(
181FragmentC &D,
182FragmentA const &A,
183FragmentB const &B,
184FragmentC const &C,
185int const &partitionN_idx = 0) const {
186
// Per-instruction operand fragment types of the underlying MMA.
187using MmaOperandA = typename Policy::Operator::FragmentA;
188using MmaOperandB = typename Policy::Operator::FragmentB;
189using MmaOperandC = typename Policy::Operator::FragmentC;
190
// Start from C and accumulate every instruction-level product into D in place.
191 D = C;
192
// View each warp-level fragment as an array of instruction-level operands.
193 MmaOperandA const *ptr_A = reinterpret_cast<MmaOperandA const *>(&A);
194 MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);
195 MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);
196
197// The offset of multiplicand B for current partition
// (in units of MmaOperandB, evaluated left to right in integer arithmetic).
198const int n_off = partitionN_idx * FragmentB::kElements / MmaOperandB::kElements / kPartitionsN;
199// Serpentine visitation order maximizing reuse of Rb
// ptr_B[n] stays fixed across the inner m loop; reversing m on odd n makes the
// last A operand of one column the first A operand of the next.
201for (int n = 0; n < MmaIterations::kColumn; ++n) {
202
204for (int m = 0; m < MmaIterations::kRow; ++m) {
205
206int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
207
// Accumulators indexed row-major over the MmaIterations grid.
// NOTE(review): n_off / partitionN_idx is not applied in this branch, and the row
// stride is the per-partition kColumn; with kPartitionsN > 1 this looks inconsistent
// with the column-major branch below — confirm intended usage.
208if (AccumulatorsInRowMajor) { // matrix B is reordered
209mma(
210 ptr_D[n + m_serpentine * MmaIterations::kColumn],
211 ptr_A[m_serpentine],
212 ptr_B[n],
213 ptr_D[n + m_serpentine * MmaIterations::kColumn]);
214 } else {
// Accumulators indexed column-major; both the B operand and the accumulator
// column are shifted by this partition's offset n_off.
215mma(
216 ptr_D[m_serpentine + (n + n_off) * MmaIterations::kRow],
217 ptr_A[m_serpentine],
218 ptr_B[n + n_off],
219 ptr_D[m_serpentine + (n + n_off) * MmaIterations::kRow]);
220 }
221 }
222 }
223 }
224 };
225
227
228 } // namespace warp
229 } // namespace gemm
230 } // namespace cutlass
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
cutlass::gemm::warp::MmaTensorOp::FragmentA
typename IteratorA::Fragment FragmentA
Storage for A tile.
Definition: mma_tensor_op.h:129
Definition: aligned_buffer.h:35
cutlass::gemm::warp::MmaTensorOp::LayoutB
LayoutB_ LayoutB
Layout of multiplicand B.
Definition: mma_tensor_op.h:97
cutlass::gemm::warp::MmaTensorOp::MmaTensorOp
CUTLASS_DEVICE MmaTensorOp()
Ctor.
Definition: mma_tensor_op.h:176
Architecture-specific operators on memory added for SM75.
[mma_tensor_op_tile_iterator.h](mma__tensor__op__tile__iterator_8h.html)
Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
Defines common types used for all GEMM-like operators.
cutlass::gemm::warp::MmaTensorOp::kThreadCount
static int const kThreadCount
Number of threads participating in warp-level matrix product.
Definition: mma_tensor_op.h:112
cutlass::gemm::warp::MmaTensorOp::kPartitionsN
static int const kPartitionsN
Number of partitions along the N dimension of multiplicand B.
Definition: mma_tensor_op.h:118
cutlass::gemm::warp::MmaTensorOp
Structure to compute the warp-level matrix product targeting Tensor Cores.
Definition: mma_tensor_op.h:82
cutlass::gemm::warp::MmaTensorOp::LayoutA
LayoutA_ LayoutA
Layout of multiplicand A.
Definition: mma_tensor_op.h:91
cutlass::gemm::warp::MmaTensorOp::FragmentB
typename IteratorB::Fragment FragmentB
Storage for B tile.
Definition: mma_tensor_op.h:138
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Templates exposing architecture support for warp-level multiply-add operations.
Defines a Shape template for matrix tiles.
cutlass::gemm::warp::MmaTensorOp::FragmentC
typename IteratorC::Fragment FragmentC
Storage for C tile.
Definition: mma_tensor_op.h:146
cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator
Definition: mma_tensor_op_tile_iterator.h:1794
cutlass::gemm::warp::MmaTensorOp::operator()
CUTLASS_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C, int const &partitionN_idx=0) const
Performs a warp-level matrix multiply-accumulate operation.
Definition: mma_tensor_op.h:180
cutlass::gemm::warp::MmaTensorOp::mma
Policy::Operator mma
Underlying matrix multiply operator (concept: arch::Mma)
Definition: mma_tensor_op.h:166
cutlass::gemm::warp::MmaTensorOp::ElementC
ElementC_ ElementC
Data type of accumulator matrix C.
Definition: mma_tensor_op.h:100
Top-level include for all CUTLASS numeric types.
cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator
Definition: mma_tensor_op_tile_iterator.h:75
#define static_assert(__e, __m)
Definition: platform.h:153
cutlass::gemm::warp::MmaTensorOp::LayoutC
LayoutC_ LayoutC
Layout of accumulator matrix C.
Definition: mma_tensor_op.h:103
cutlass::gemm::warp::MmaTensorOp::Policy
Policy_ Policy
Policy describing implementation details of the warp-level GEMM targeting Tensor Cores (provides Operator and OpDelta; see mma_tensor_op_policy.h)
Definition: mma_tensor_op.h:106
cutlass::gemm::warp::MmaTensorOp::Shape
Shape_ Shape
Shape of warp-level matrix operation (concept: GemmShape)
Definition: mma_tensor_op.h:85
cutlass::gemm::warp::MmaTensorOp::ElementB
ElementB_ ElementB
Data type of multiplicand B.
Definition: mma_tensor_op.h:94
cutlass::gemm::warp::MmaTensorOp::kPartitionsK
static int const kPartitionsK
Number of partitions along K dimension.
Definition: mma_tensor_op.h:115
cutlass::gemm::warp::MmaTensorOp::OperatorClass
arch::OpClassTensorOp OperatorClass
Indicates class of matrix operator.
Definition: mma_tensor_op.h:109
cutlass::gemm::warp::MmaTensorOp::ElementA
ElementA_ ElementA
Data type of multiplicand A.
Definition: mma_tensor_op.h:88
Matrix multiply for SM75.
A multiplicand.
Basic include for CUTLASS.
[mma_tensor_op_policy.h](mma__tensor__op__policy_8h.html)
Policy describing implementation details of warp-level GEMM targeting Tensor Cores.
Generated by 1.8.11