docs/mma__tensor__op__wmma_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
mma_tensor_op_wmma.h
[Go to the documentation of this file.](mma__tensor__op__wmma_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
30 #pragma once
31
32 #include "cutlass/cutlass.h"
33 #include "cutlass/arch/wmma.h"
34
35 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
36
37 #include "cutlass/wmma_array.h"
38 #include "cutlass/numeric_types.h"
39 #include "cutlass/matrix_shape.h"
40
41 #include "cutlass/arch/memory_sm75.h"
42 #include "cutlass/arch/mma_sm75.h"
43 #include "cutlass/gemm/gemm.h"
44 #include "cutlass/gemm/warp/mma.h"
45
46 #include "[cutlass/gemm/warp/mma_tensor_op_policy.h](mma__tensor__op__policy_8h.html)"
47
48 #include "[cutlass/gemm/warp/mma_tensor_op_tile_iterator_wmma.h](mma__tensor__op__tile__iterator__wmma_8h.html)"
49
51
52 namespace cutlass {
53 namespace gemm {
54 namespace warp {
55
57
59 template <
61typename Shape_,
63typename ElementA_,
65typename LayoutA_,
67typename ElementB_,
69typename LayoutB_,
71typename ElementC_,
73typename LayoutC_,
75typename Policy_,
77int PartitionsK_ = 1,
79int PartitionsN_ = 1,
81typename Enable = bool
82 >
83 class MmaTensorOpWmma {
84 public:
86using Shape = Shape_;
87
89using ElementA = ElementA_;
90
92using LayoutA = LayoutA_;
93
95using ElementB = ElementB_;
96
98using LayoutB = LayoutB_;
99
101using ElementC = ElementC_;
102
104using LayoutC = LayoutC_;
105
107using Policy = Policy_;
108
110using OperatorClass = arch::OpClassTensorOp;
111
113static int const kThreadCount = 32;
114
116static int const kPartitionsK = PartitionsK_;
117
119static int const kPartitionsN = PartitionsN_;
120
121 public:
122
124using IteratorA = MmaTensorOpWmmaMultiplicandTileIterator<
125 MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
126 Policy::OpDelta::kRow, kThreadCount, Policy>;
127
129using FragmentA = typename IteratorA::Fragment;
130
132using IteratorB = MmaTensorOpWmmaMultiplicandTileIterator<
133 MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
134 Policy::OpDelta::kRow, kThreadCount, Policy>;
135
137using FragmentB = typename IteratorB::Fragment;
138
140using IteratorC = MmaTensorOpWmmaAccumulatorTileIterator<
141 MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
142typename Policy::OpDelta, Policy>;
143
145using FragmentC = typename IteratorC::Fragment;
146
147 private:
148
149static_assert(
150 !(Shape::kM % Policy::Operator::Shape::kM) &&
151 !(Shape::kN % Policy::Operator::Shape::kN),
152"Shape of warp-level Wmma must be divisible by operator shape (wmma native size)");
153
155using WmmaIterations = MatrixShape<
156 Shape::kM / Policy::Operator::Shape::kM,
157 (Shape::kN / Policy::Operator::Shape::kN / kPartitionsN > 0) ?
158 Shape::kN / Policy::Operator::Shape::kN / kPartitionsN :
159 1
160 >;
161
162 public:
163
165typename Policy::Operator wmma;
166
167 public:
168
169//
170// Methods
171//
172
174 CUTLASS_DEVICE
175 MmaTensorOpWmma() {}
176
178 CUTLASS_DEVICE
179void operator()(
180 FragmentC &D,
181 FragmentA const &A,
182 FragmentB const &B,
183 FragmentC const &C,
184int const &partitionN_idx = 0) const {
185
187for (int n = 0; n < WmmaIterations::kColumn; ++n) {
189for (int m = 0; m < WmmaIterations::kRow; ++m) {
190
191// accumulate wmma mma
192 wmma(D[m * WmmaIterations::kColumn + n], A[m], B[n], C[m * WmmaIterations::kColumn + n]);
193 }
194 }
195 }
196
197 };
198
200
201 } // namespace warp
202 } // namespace gemm
203 } // namespace cutlass
204
205 #endif // if defined(CUTLASS_ARCH_WMMA_ENABLED)
206
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
Definition: aligned_buffer.h:35
Architecture-specific operators on memory added for SM75.
Defines common types used for all GEMM-like operators.
[mma_tensor_op_tile_iterator_wmma.h](mma__tensor__op__tile__iterator__wmma_8h.html)
Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Templates exposing architecture support for warp-level multiply-add operations.
Defines a Shape template for matrix tiles.
Top-level include for all CUTLASS numeric types.
#define static_assert(__e, __m)
Definition: platform.h:153
Matrix multiply for SM75.
A multiplicand.
Templates exposing architecture support for warp matrix multiply-add (WMMA) operations.
Basic include for CUTLASS.
[mma_tensor_op_policy.h](mma__tensor__op__policy_8h.html)
Policy describing implementation details of warp-level GEMM targeting Tensor Cores.
Generated by 1.8.11