docs/mma__tensor__op__sm70_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
mma_tensor_op_sm70.h
[Go to the documentation of this file.](mma__tensor__op__sm70_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
32 #pragma once
33
34 #include "cutlass/cutlass.h"
35 #include "cutlass/array.h"
36
37 #include "cutlass/numeric_types.h"
38 #include "cutlass/matrix_shape.h"
39
40 #include "cutlass/arch/mma.h"
41
42 #include "cutlass/gemm/gemm.h"
43 #include "cutlass/gemm/warp/mma.h"
44
45 #include "[cutlass/gemm/warp/mma_tensor_op_policy.h](mma__tensor__op__policy_8h.html)"
46 #include "[cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h](mma__tensor__op__tile__iterator__sm70_8h.html)"
47
49
50 namespace cutlass {
51 namespace gemm {
52 namespace warp {
53
55
57 template <
59typename Shape_,
61typename ElementA_,
63typename LayoutA_,
65typename ElementB_,
67typename LayoutB_,
69typename ElementC_,
71typename LayoutC_,
73typename Policy_,
75typename Enable = bool
76 >
77 class MmaVoltaTensorOp {
78 public:
81
84
87
90
93
96
99
102
104using OperatorClass = arch::OpClassTensorOp;
105
107static int const kThreadCount = 32;
108
110using InterleavedTileShape = GemmShape<32, 32, 4>;
111
112static_assert(!(Shape::kM % InterleavedTileShape::kM) &&
113 !(Shape::kN % InterleavedTileShape::kN),
114"Shape must be a multiple of InterleavedTileShape.");
115 public:
116
118using IteratorA = MmaVoltaTensorOpMultiplicandTileIterator<
119MatrixShape<Shape::kM, Shape::kK>,
120Operand::kA,
121ElementA,
122LayoutA,
123MatrixShape<
124 Policy::Operator::Shape::kM,
125 Policy::Operator::Shape::kK
126 >,
127 Policy::OpDelta::kRow,
128 kThreadCount
129 >;
130
132using FragmentA = typename IteratorA::Fragment;
133
135using IteratorB = MmaVoltaTensorOpMultiplicandTileIterator<
136MatrixShape<Shape::kK, Shape::kN>,
137Operand::kB,
138ElementB,
139LayoutB,
140MatrixShape<
141 Policy::Operator::Shape::kK,
142 Policy::Operator::Shape::kN
143 >,
144 Policy::OpDelta::kRow,
145 kThreadCount
146 >;
147
149using FragmentB = typename IteratorB::Fragment;
150
152using IteratorC = MmaVoltaTensorOpAccumulatorTileIterator<
153MatrixShape<Shape::kM, Shape::kN>,
154ElementC,
155LayoutC,
156typename Policy::Operator::Shape,
157typename Policy::OpDelta
158 >;
159
161using FragmentC = typename IteratorC::Fragment;
162
163 private:
164
165static_assert(
166 !(Shape::kM % Policy::Operator::Shape::kM) &&
167 !(Shape::kN % Policy::Operator::Shape::kN),
168"Shape of warp-level Mma must be divisible by operator shape.");
169
171using MmaIterations = MatrixShape<
172 InterleavedTileShape::kM / Policy::Operator::Shape::kM,
173InterleavedTileShape::kN / Policy::Operator::Shape::kN
174 >;
175using TileIterations = MatrixShape<
176 Shape::kM / InterleavedTileShape::kM,
177 Shape::kN / InterleavedTileShape::kN
178 >;
179
180// Whether matrix B is reordered
181bool reorder_B_;
182
183 public:
184
186typename Policy::Operator mma;
187
188 public:
189
190//
191// Methods
192//
193
195 CUTLASS_DEVICE
196MmaVoltaTensorOp() {}
197
199 CUTLASS_DEVICE
200void operator()(
201FragmentC &D,
202FragmentA const &A,
203FragmentB const &B,
204FragmentC const &C,
205int const &partitionN_idx = 0) {
206
207using MmaOperandA = typename Policy::Operator::FragmentA;
208using MmaOperandB = typename Policy::Operator::FragmentB;
209using MmaOperandC = typename Policy::Operator::FragmentC;
210
211 D = C;
212
213 MmaOperandA const *ptr_A = reinterpret_cast<MmaOperandA const *>(&A);
214 MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);
215 MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);
216
218for (int outer_col = 0; outer_col < TileIterations::kColumn; ++outer_col) {
220for (int inner_col = 0; inner_col < MmaIterations::kColumn; ++inner_col) {
222for (int outer_row = 0; outer_row < TileIterations::kRow; ++outer_row) {
224
225for (int inner_row = 0; inner_row < MmaIterations::kRow; ++inner_row) {
226
227int op_col = inner_col + MmaIterations::kColumn * outer_col;
228
229// Column-major serpentine sequence to maximize reuse of A operand.
230int inner_row_serp = inner_row;
231int outer_row_serp = outer_row;
232if (op_col & 1) {
233 inner_row_serp = MmaIterations::kRow - inner_row - 1;
234 outer_row_serp = TileIterations::kRow - outer_row - 1;
235 }
236int op_row = inner_row_serp + MmaIterations::kRow * outer_row_serp;
237int op_idx = inner_row_serp + MmaIterations::kRow *
238 (inner_col + MmaIterations::kColumn *
239 (outer_row_serp + TileIterations::kRow * outer_col));
240mma(
241 ptr_D[op_idx],
242 ptr_A[op_row],
243 ptr_B[op_col],
244 ptr_D[op_idx]);
245
246 }
247 }
248 }
249 }
250 }
251 };
252
254
255 } // namespace warp
256 } // namespace gemm
257 } // namespace cutlass
cutlass::gemm::warp::MmaVoltaTensorOp::Policy
Policy_ Policy
Shape of the warp in units of thread (concept: MmaLanePolicySimt)
Definition: mma_tensor_op_sm70.h:101
cutlass::gemm::warp::MmaVoltaTensorOp::FragmentB
typename IteratorB::Fragment FragmentB
Storage for B tile.
Definition: mma_tensor_op_sm70.h:149
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
cutlass::gemm::warp::MmaVoltaTensorOp::LayoutB
LayoutB_ LayoutB
Layout of multiplicand B.
Definition: mma_tensor_op_sm70.h:92
cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator
Definition: mma_tensor_op_tile_iterator_sm70.h:70
Definition: aligned_buffer.h:35
static int const kColumn
columns of a matrix
Definition: matrix_shape.h:44
Defines common types used for all GEMM-like operators.
cutlass::gemm::warp::MmaVoltaTensorOp::Shape
Shape_ Shape
Shape of warp-level matrix operation (concept: GemmShape)
Definition: mma_tensor_op_sm70.h:80
cutlass::gemm::warp::MmaVoltaTensorOp::OperatorClass
arch::OpClassTensorOp OperatorClass
Indicates class of matrix operator.
Definition: mma_tensor_op_sm70.h:104
cutlass::gemm::warp::MmaVoltaTensorOp::LayoutA
LayoutA_ LayoutA
Layout of multiplicand A.
Definition: mma_tensor_op_sm70.h:86
cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator::Fragment
Array< Element, Shape::kCount/kThreads > Fragment
Fragment object holding a thread's part of a tile.
Definition: mma_tensor_op_tile_iterator_sm70.h:1213
cutlass::gemm::warp::MmaVoltaTensorOp::ElementA
ElementA_ ElementA
Data type of multiplicand A.
Definition: mma_tensor_op_sm70.h:83
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Templates exposing architecture support for warp-level multiply-add operations.
Templates exposing architecture support for multiply-add operations.
[mma_tensor_op_tile_iterator_sm70.h](mma__tensor__op__tile__iterator__sm70_8h.html)
Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
Defines a Shape template for matrix tiles.
cutlass::gemm::warp::MmaVoltaTensorOp::ElementB
ElementB_ ElementB
Data type of multiplicand B.
Definition: mma_tensor_op_sm70.h:89
cutlass::gemm::warp::MmaVoltaTensorOp::operator()
CUTLASS_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C, int const &partitionN_idx=0)
Performs a warp-level matrix multiply-accumulate operation.
Definition: mma_tensor_op_sm70.h:200
cutlass::gemm::warp::MmaVoltaTensorOp::mma
Policy::Operator mma
Underlying matrix multiply operator (concept: arch::Mma)
Definition: mma_tensor_op_sm70.h:186
cutlass::gemm::warp::MmaVoltaTensorOp
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Definition: mma_tensor_op_sm70.h:77
static int const kRow
rows of a matrix
Definition: matrix_shape.h:43
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
#define static_assert(__e, __m)
Definition: platform.h:153
cutlass::gemm::warp::MmaVoltaTensorOp::FragmentC
typename IteratorC::Fragment FragmentC
Storage for C tile.
Definition: mma_tensor_op_sm70.h:161
cutlass::gemm::warp::MmaVoltaTensorOp::LayoutC
LayoutC_ LayoutC
Layout of accumulator matrix C.
Definition: mma_tensor_op_sm70.h:98
cutlass::gemm::warp::MmaVoltaTensorOp::kThreadCount
static int const kThreadCount
Number of threads participating in warp-level matrix product.
Definition: mma_tensor_op_sm70.h:107
cutlass::gemm::warp::MmaVoltaTensorOp::MmaVoltaTensorOp
CUTLASS_DEVICE MmaVoltaTensorOp()
Ctor.
Definition: mma_tensor_op_sm70.h:196
cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator
Definition: mma_tensor_op_tile_iterator_sm70.h:1135
cutlass::gemm::warp::MmaVoltaTensorOp::ElementC
ElementC_ ElementC
Data type of accumulator matrix C.
Definition: mma_tensor_op_sm70.h:95
A multiplicand.
cutlass::gemm::warp::MmaVoltaTensorOp::FragmentA
typename IteratorA::Fragment FragmentA
Storage for A tile.
Definition: mma_tensor_op_sm70.h:132
Basic include for CUTLASS.
[mma_tensor_op_policy.h](mma__tensor__op__policy_8h.html)
Policy describing implementation details of warp-level GEMM targeting Tensor Cores.
static int const kN
Definition: include/cutlass/gemm/gemm.h:59
Generated by 1.8.11