mma_pipelined.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/numeric_conversion.h"

#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
  /// Size of the Gemm problem - concept: gemm::GemmShape<>
  typename Shape_,
  /// Iterates over tiles of A operand in global memory
  //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
  typename IteratorA_,
  /// Iterates over tiles of A operand in shared memory
  /// (concept: WriteableTileIterator | RandomAccessTileIterator)
  typename SmemIteratorA_,
  /// Iterates over tiles of B operand in global memory
  //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
  typename IteratorB_,
  /// Iterates over tiles of B operand in shared memory
  /// (concept: WriteableTileIterator | RandomAccessTileIterator)
  typename SmemIteratorB_,
  /// Data type of accumulator matrix
  typename ElementC_,
  /// Layout of accumulator matrix
  typename LayoutC_,
  /// Policy describing tuning details (concept: MmaPolicy)
  typename Policy_,
  /// Transformation applied to A operand
  typename TransformA_ = NumericArrayConverter<
    typename SmemIteratorA_::Element,
    typename IteratorA_::Element,
    IteratorA_::Fragment::kElements>,
  /// Transformation applied to B operand
  typename TransformB_ = NumericArrayConverter<
    typename SmemIteratorB_::Element,
    typename IteratorB_::Element,
    IteratorB_::Fragment::kElements>,
  /// Used for partial specialization
  typename Enable = bool
>
class MmaPipelined : public MmaBase<Shape_, Policy_, 2> {
public:

  ///< Base class
  using Base = MmaBase<Shape_, Policy_, 2>;

  using Shape = Shape_;             ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using IteratorA = IteratorA_;     ///< Iterates over tiles of A operand in global memory
  using IteratorB = IteratorB_;     ///< Iterates over tiles of B operand in global memory
  using ElementC = ElementC_;       ///< Data type of accumulator matrix
  using LayoutC = LayoutC_;         ///< Layout of accumulator matrix
  using Policy = Policy_;           ///< Policy describing tuning details

  using SmemIteratorA = SmemIteratorA_;
  using SmemIteratorB = SmemIteratorB_;

  using TransformA = TransformA_;
  using TransformB = TransformB_;

  //
  // Dependent types
  //

  /// Fragment of operand A loaded from global memory
  using FragmentA = typename IteratorA::Fragment;

  /// Fragment of operand B loaded from global memory
  using FragmentB = typename IteratorB::Fragment;

  /// Fragment of accumulator tile
  using FragmentC = typename Policy::Operator::FragmentC;

  /// Warp-level Mma
  using Operator = typename Policy::Operator;

  // Statically assert that kStages for MmaPipelined is two (double-buffered pipeline)
  static_assert((Base::kStages == 2), "MmaPipelined requires kStages set to value 2");

private:

  using WarpFragmentA = typename Operator::FragmentA;
  using WarpFragmentB = typename Operator::FragmentB;

protected:

  /// Iterator to write threadblock-scoped tile of A operand to shared memory
  SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;

public:

  /// Construct from tensor references
  CUTLASS_DEVICE
  MmaPipelined(
    typename Base::SharedStorage &shared_storage,  ///< Shared storage needed for internal use by threadblock-scoped GEMM
    int thread_idx,                                ///< ID within the threadblock
    int warp_idx,                                  ///< ID of warp
    int lane_idx                                   ///< ID of each thread within a warp
  ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
    smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {

    // Compute warp location within threadblock tile by mapping the warp_idx to three
    // coordinates:
    //   _m: the warp's position within the threadblock along the M dimension
    //   _n: the warp's position within the threadblock along the N dimension
    //   _k: the warp's position within the threadblock along the K dimension

    int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
    int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

    int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
    int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
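
    // Worked example (assuming a 2x2x1 warp arrangement, i.e. WarpCount::kM == 2,
    // WarpCount::kN == 2, WarpCount::kK == 1): warp_idx == 3 gives warp_idx_mn == 3 and
    // warp_idx_k == 0, which decompose into warp_idx_m == 1, warp_idx_n == 1.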

    // Add per-warp offsets in units of warp-level tiles
    this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
    this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate
  CUTLASS_DEVICE
  void operator()(
    int gemm_k_iterations,                     ///< number of iterations of the mainloop
    FragmentC &accum,                          ///< destination accumulator tile
    IteratorA iterator_A,                      ///< iterator over A operand in global memory
    IteratorB iterator_B,                      ///< iterator over B operand in global memory
    FragmentC const &src_accum,                ///< source accumulator tile
    TransformA transform_A = TransformA(),     ///< transformation applied to A fragment
    TransformB transform_B = TransformB()) {   ///< transformation applied to B fragment

    //
    // Prologue
    //

    // Perform accumulation in the 'd' output operand
    accum = src_accum;

    FragmentA tb_frag_A;
    FragmentB tb_frag_B;

    tb_frag_A.clear();
    tb_frag_B.clear();

    // The last k-block is loaded in the prologue
    iterator_A.load(tb_frag_A);
    iterator_B.load(tb_frag_B);

    ++iterator_A;
    ++iterator_B;

    this->smem_iterator_A_.store(transform_A(tb_frag_A));
    this->smem_iterator_B_.store(transform_B(tb_frag_B));

    ++this->smem_iterator_A_;
    ++this->smem_iterator_B_;

    __syncthreads();
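
    // The barrier makes the prologue's shared memory stores visible to every warp in the
    // threadblock before the first warp-level fragments are loaded from shared memory below.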

    // Pair of fragments used to overlap shared memory loads and math instructions
    WarpFragmentA warp_frag_A[2];
    WarpFragmentB warp_frag_B[2];

    this->warp_tile_iterator_A_.set_kgroup_index(0);
    this->warp_tile_iterator_B_.set_kgroup_index(0);

    this->warp_tile_iterator_A_.load(warp_frag_A[0]);
    this->warp_tile_iterator_B_.load(warp_frag_B[0]);

    ++this->warp_tile_iterator_A_;
    ++this->warp_tile_iterator_B_;

    Operator warp_mma;

    int smem_write_stage_idx = 1;
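
    // The prologue stored stage 0 and then advanced the shared memory iterators, so the
    // next threadblock-scoped store targets stage 1; the index is toggled after each flush.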

    // Avoid reading out of bounds
    if (gemm_k_iterations <= 1) {
      iterator_A.clear_mask();
      iterator_B.clear_mask();
    }
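    // (clear_mask() zeroes the iterators' predicate masks, so subsequent load() calls
    // become no-ops; the prologue has already fetched the final k-block.)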

    // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
    // shared memory loads (which have the tightest latency requirement).

    //
    // Mainloop
    //

    // Note: The main loop does not support Base::kWarpGemmIterations == 2.
    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations > 0; --gemm_k_iterations) {
      //
      // Loop over GEMM K dimension
      //

      CUTLASS_PRAGMA_UNROLL
      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {

        // Load warp-level tiles from shared memory, wrapping to the k offset if this is
        // the last group in the tile.

        if (warp_mma_k == Base::kWarpGemmIterations - 1) {

          // Write fragments to shared memory
          this->smem_iterator_A_.store(transform_A(tb_frag_A));

          this->smem_iterator_B_.store(transform_B(tb_frag_B));

          __syncthreads();

          ++this->smem_iterator_B_;
          ++this->smem_iterator_A_;

          // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
          if (smem_write_stage_idx == 1) {
            this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
            this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
          }
          else {
            this->warp_tile_iterator_A_.add_tile_offset(
                {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
            this->warp_tile_iterator_B_.add_tile_offset(
                {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations,
                 0});
          }

          smem_write_stage_idx ^= 1;
        }

        this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
        this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);

        this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
        this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        if (warp_mma_k == 0) {

          iterator_A.load(tb_frag_A);
          iterator_B.load(tb_frag_B);

          ++iterator_A;
          ++iterator_B;

          // Avoid reading out of bounds if this was the last loop iteration
          if (gemm_k_iterations <= 2) {
            iterator_A.clear_mask();
            iterator_B.clear_mask();
          }
        }

        warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
      }
    }

  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
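
For context, a minimal sketch of how a threadblock-scoped kernel typically drives this class. The kernel fragment below is an assumption for illustration: the concrete Mma instantiation, the gemm_k_iterations count, and the iterator_A / iterator_B constructions are normally assembled by CUTLASS's higher-level kernel templates rather than written by hand.

  // Hypothetical kernel fragment, with Mma an instantiation of MmaPipelined:
  __shared__ typename Mma::SharedStorage shared_storage;

  int thread_idx = threadIdx.x;
  int warp_idx = threadIdx.x / 32;   // warp's id within the threadblock
  int lane_idx = threadIdx.x % 32;   // thread's id within the warp

  Mma mma(shared_storage, thread_idx, warp_idx, lane_idx);

  typename Mma::FragmentC accum;
  accum.clear();

  // The accumulator tile serves as both source and destination; the default
  // TransformA / TransformB converters are applied to the fragments of A and B.
  mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);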