docs/output__tile__thread__map_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
output_tile_thread_map.h
[Go to the documentation of this file.](output__tile__thread__map_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
31 #pragma once
32
33 #include "cutlass/cutlass.h"
34 #include "cutlass/numeric_types.h"
35 #include "cutlass/array.h"
36 #include "cutlass/layout/matrix.h"
37 #include "cutlass/matrix_shape.h"
38 #include "cutlass/tensor_ref.h"
39 #include "cutlass/fast_math.h"
40
42
43 namespace cutlass {
44 namespace epilogue {
45 namespace threadblock {
46
48
/// Tuple of compile-time extents locating a point in the output tile.
template <int Column, int Row, int Group, int Cluster, int Tile>
struct OutputTileShape {
  static int const kColumn  = Column;   ///< extent in the column dimension
  static int const kRow     = Row;      ///< extent in the row dimension
  static int const kGroup   = Group;    ///< extent in the group dimension
  static int const kCluster = Cluster;  ///< extent in the cluster dimension
  static int const kTile    = Tile;     ///< extent in the tile dimension

  /// Total number of elements spanned by the shape
  static int const kCount = Column * Row * Group * Cluster * Tile;
};
66
68
69 template <
70typename ThreadMap_,
71typename Shape_,
72typename Iterations_,
73typename Delta_,
74typename Count_
75 >
76 struct OutputTileThreadMap {
77
79using ThreadMap = ThreadMap_;
80
82static int const kThreads = ThreadMap::kThreads;
83
85static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
86
89
91using Iterations = Iterations_;
92
95
98
101static MatrixCoord initial_offset(int thread_idx) {
102
103using Index = typename layout::PitchLinearCoord::Index;
104
105layout::PitchLinearCoord coord = ThreadMap::initial_offset(thread_idx);
106
107 Index cluster = coord.strided() / (Shape::kGroup * Shape::kRow);
108 Index cluster_residual = coord.strided() % (Shape::kGroup * Shape::kRow);
109
110 Index group = cluster_residual / (Shape::kRow);
111 Index row = cluster_residual % (Shape::kRow);
112
113return MatrixCoord{
114 row + group * Shape::kRow * Count::kRow
115 + cluster * Shape::kGroup * Count::kGroup * Shape::kRow * Count::kRow,
116 coord.contiguous()
117 };
118 }
119 };
120
122
123 namespace detail {
124
/// RowArrangement determines how one or more warps cover a region of
/// consecutive rows. Specializations handle the non-tiled
/// (Is2dTile == false) and 2-D tiled (Is2dTile == true) cases.
template <
  typename Shape,
  int WarpsRemaining,
  int ElementsPerAccess,
  int ElementSize,
  bool Is2dTile
>
struct RowArrangement;

/// Specialization for the non-tiled case: every lane of a warp walks a single
/// row, stepping across columns a full warp-width of vector accesses at a time.
template <
  typename Shape,
  int WarpsRemaining,
  int ElementsPerAccess,
  int ElementSize
>
struct RowArrangement<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, false> {

  static int const kWarpSize = 32;
  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kElementSize = ElementSize;

  // One warp spans a single row of accesses
  static int const kAccessWidth = kWarpSize;
  static int const kAccessRows = 1;

  // A single pass over rows; columns advance by a warp's worth of vectors
  static int const kIterationsRow = 1;
  static int const kDeltaRow = 1;
  static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize;
  static int const kDeltaColumn = kWarpSize * kElementsPerAccess;

  // Remaining warps partition the column dimension only
  static int const kWarpPartitionsRow = 1;
  static int const kWarpPartitionsColumn = WarpsRemaining;
};
157
/// Specialization for the 2-D tiled case: each warp covers a rectangular
/// (kAccessRows x kAccessWidth) footprint of vector accesses.
template <
  typename Shape,
  int WarpsRemaining,
  int ElementsPerAccess,
  int ElementSize
>
struct RowArrangement<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, true> {

  // Target size of one coalesced memory transaction. Given the divide-by-8
  // below, this is in bytes while kElementSize is in bits -- presumably 128 B
  // to match a full cache line; TODO confirm against the original header.
  static int const kMemoryAccessSize = 128;
  static int const kWarpSize = 32;

  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kElementSize = ElementSize;

  struct Detail {
    // Rows of the shape assigned to each remaining warp
    static int const kShapeRow = Shape::kRow / WarpsRemaining;
    // Width of the shape measured in vector accesses
    static int const kShapeWidth = Shape::kColumn / kElementsPerAccess;

    // Number of vector accesses that span kMemoryAccessSize
    // (kElementsPerAccess * kElementSize / 8 is the byte size of one access)
    static int const kTargetMemoryAccessWidth =
      kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8);

    // Rows a warp would cover if its accesses were the target width
    static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth;
  };

  // Width (in accesses) of a warp's footprint: when more rows are targeted
  // than exist, spread lanes across the available rows; otherwise clamp to
  // the shape width and the widest access that stays within one transaction.
  static int const kAccessWidth =
    (Detail::kTargetAccessRows > Detail::kShapeRow ?
      kWarpSize / Detail::kShapeRow
      : const_min(
          Detail::kShapeWidth,
          const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8))
        ));

  // Rows covered simultaneously by one warp (remaining lanes stack vertically)
  static int const kAccessRows =
    (Detail::kTargetAccessRows > Detail::kShapeRow ?
      Detail::kShapeRow
      : const_min(Shape::kRow, kWarpSize / kAccessWidth));

  // Iteration counts and strides implied by the footprint above
  static int const kIterationsRow = Detail::kShapeRow / kAccessRows;
  static int const kDeltaRow = kAccessRows;

  static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth;
  static int const kDeltaColumn = kAccessWidth * kElementsPerAccess;

  static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access");
  static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" );
  static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" );

  // This arrangement performs no further warp partitioning of the row region
  static int const kWarpPartitionsRow = 1;
  static int const kWarpPartitionsColumn = 1;
};
209
210 }
211
213
221 template <
222typename Shape_,
223typename Count_,
224int Threads,
225int ElementsPerAccess,
226int ElementSize
227 >
228 struct OutputTileOptimalThreadMap {
229
232
233static int const kWarpSize = 32;
234static int const kThreads = Threads;
235static int const kWarpCount = kThreads / kWarpSize;
236
237static int const kElementsPerAccess = ElementsPerAccess;
238static int const kElementSize = ElementSize;
239
240//
241// Metaprogram computation
242//
243
245
246// Clusters
247static int const kIterationsCluster =
248 ((Shape::kCluster > kWarpCount) ?
249 Shape::kCluster / kWarpCount
250 : 1);
251
252static int const kDeltaCluster =
253 ((Shape::kCluster > kWarpCount) ?
254 Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster
255 : 1);
256
257static int const kCompactedDeltaCluster =
258 ((Shape::kCluster > kWarpCount) ?
259 Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster
260 : 1);
261
262static int const kWarpPartitionsCluster =
263 ((Shape::kCluster > kWarpCount) ?
264 kWarpCount
265 : kWarpCount / Shape::kCluster);
266
267static int const kWarpsRemainingForGroups =
268 ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster);
269
270// Groups
271static int const kIterationsGroup =
272 ((Shape::kGroup > kWarpsRemainingForGroups) ?
273 Shape::kGroup / kWarpsRemainingForGroups
274 : 1);
275
276static int const kDeltaGroup =
277 ((Shape::kGroup > kWarpsRemainingForGroups) ?
278 Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup
279 : 1);
280
281static int const kCompactedDeltaGroup =
282 ((Shape::kGroup > kWarpsRemainingForGroups) ?
283 Shape::kRow * Shape::kGroup / kIterationsGroup
284 : 1);
285
286static int const kWarpPartitionsGroup =
287 ((Shape::kGroup > kWarpsRemainingForGroups) ?
288 1
289 : kWarpsRemainingForGroups / Shape::kGroup);
290
291static int const kWarpsRemainingForRows =
292 ((Shape::kGroup > kWarpsRemainingForGroups) ?
293 1
294 : kWarpsRemainingForGroups / Shape::kGroup);
295
296// Rows
297using RowArrangement = detail::RowArrangement<
298Shape,
299 kWarpsRemainingForRows,
300 kElementsPerAccess,
301 kElementSize,
302 (Shape::kRow > kWarpsRemainingForRows)
303 >;
304
305// Warp partitions
306using WarpPartitions = OutputTileShape<
307 RowArrangement::kWarpPartitionsColumn,
308 RowArrangement::kWarpPartitionsRow,
309 kWarpPartitionsGroup,
310 kWarpPartitionsCluster,
311 1>;
312
313static int const kAccessWidth = RowArrangement::kAccessWidth;
314static int const kAccessRows = RowArrangement::kAccessRows;
315 };
316
317//
318// Output
319//
320
321using Iterations = OutputTileShape<
322 Detail::RowArrangement::kIterationsColumn,
323 Detail::RowArrangement::kIterationsRow,
324 Detail::kIterationsGroup,
325 Detail::kIterationsCluster,
326 1>;
327
328using Delta = OutputTileShape<
329 Detail::RowArrangement::kDeltaColumn,
330 Detail::RowArrangement::kDeltaRow,
331 Detail::kDeltaGroup,
332 Detail::kDeltaCluster,
333 1>;
334
337static MatrixCoord initial_offset(int thread_idx) {
338
339int warp_idx = thread_idx / kWarpSize;
340int lane_idx = thread_idx % kWarpSize;
341
342// Compute warp location
343int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster;
344int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster;
345
346int group_idx = residual_cluster / Detail::WarpPartitions::kGroup;
347int residual_group = residual_cluster % Detail::WarpPartitions::kGroup;
348
349int row_idx = residual_group / Detail::WarpPartitions::kRow;
350int col_idx = residual_group % Detail::WarpPartitions::kRow;
351
352// Compute per-lane offset
353int lane_row_offset = lane_idx / Detail::kAccessWidth;
354int lane_col_offset = lane_idx % Detail::kAccessWidth;
355
356// Compute coordinate in output space
357int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup;
358int group_offset = group_idx * Shape::kRow * Count::kRow;
359int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows;
360int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess;
361
362return MatrixCoord(
363 cluster_offset + group_offset + row_offset + lane_row_offset,
364 (column_offset + lane_col_offset) * kElementsPerAccess
365 );
366 }
367
369struct CompactedThreadMap {
370
371
373
374using Iterations = OutputTileShape<
375 Detail::RowArrangement::kIterationsColumn,
376 Detail::RowArrangement::kIterationsRow,
377 Detail::kIterationsGroup,
378 Detail::kIterationsCluster,
379 1>;
380
381using Delta = OutputTileShape<
382 Detail::RowArrangement::kDeltaColumn,
383 Detail::RowArrangement::kDeltaRow,
384 Detail::kCompactedDeltaGroup,
385 Detail::kCompactedDeltaCluster,
386 1>;
387
389static int const kElementsPerAccess = ElementsPerAccess;
390
392static int const kThreads = Threads;
393
396static MatrixCoord initial_offset(int thread_idx) {
397
398int warp_idx = thread_idx / kWarpSize;
399int lane_idx = thread_idx % kWarpSize;
400
401// Compute warp location
402int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster;
403int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster;
404
405int group_idx = residual_cluster / Detail::WarpPartitions::kGroup;
406int residual_group = residual_cluster % Detail::WarpPartitions::kGroup;
407
408int row_idx = residual_group / Detail::WarpPartitions::kRow;
409int col_idx = residual_group % Detail::WarpPartitions::kRow;
410
411// Compute per-lane offset
412int lane_row_offset = lane_idx / Detail::kAccessWidth;
413int lane_col_offset = lane_idx % Detail::kAccessWidth;
414
415// Compute coordinate in output space
416int cluster_offset = cluster_idx * Shape::kRow * Shape::kGroup;
417int group_offset = group_idx * Shape::kRow;
418int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows;
419int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess;
420
421MatrixCoord coord(
422 cluster_offset + group_offset + row_offset + lane_row_offset,
423 (column_offset + lane_col_offset) * kElementsPerAccess
424 );
425
426return coord;
427 }
428 };
429 };
430
432
440 template <typename WarpCount_, typename MmaCount_, int Threads,
441int ElementsPerAccess, int ElementSize>
442 struct InterleavedOutputTileThreadMap {
443using WarpCount = WarpCount_;
444using MmaCount = MmaCount_;
445
446static int const kWarpSize = 32;
447static int const kThreads = Threads;
448static int const kWarpCount = kThreads / kWarpSize;
449
450static int const kElementsPerAccess = ElementsPerAccess;
451static int const kElementSize = ElementSize;
452
453//
454// Metaprogram computation
455//
456
458
459//
460// Output
461//
462
463using Iterations = MmaCount;
464
465using Delta = layout::PitchLinearShape<kWarpSize * kElementsPerAccess, 1>;
466
469static layout::PitchLinearCoord initial_offset(int thread_idx) {
470int warp_idx = thread_idx / kWarpSize;
471int lane_idx = thread_idx % kWarpSize;
472
473// Compute warp location
474layout::PitchLinearCoord warp_footprint{
475 Delta::kContiguous * Iterations::kContiguous,
476 Delta::kStrided * Iterations::kStrided};
477
478layout::PitchLinearCoord warp_offset{warp_idx % WarpCount::kContiguous,
479 warp_idx / WarpCount::kContiguous};
480
481// Compute per-lane offset
482layout::PitchLinearCoord thread_offset_in_warp{
483 lane_idx * kElementsPerAccess, 0};
484
485layout::PitchLinearCoord thread_offset_in_threadblock_tile =
486 warp_footprint * warp_offset + thread_offset_in_warp;
487
488return thread_offset_in_threadblock_tile;
489 }
490 };
491
493
494 } // namespace threadblock
495 } // namespace epilogue
496 } // namespace cutlass
cutlass::layout::PitchLinearCoord::Index
int Index
Integer-valued index.
Definition: pitch_linear.h:56
cutlass::epilogue::threadblock::OutputTileThreadMap::ThreadMap
ThreadMap_ ThreadMap
Conventional thread map (concept: ThreadMap)
Definition: output_tile_thread_map.h:79
cutlass::epilogue::threadblock::OutputTileOptimalThreadMap
Definition: output_tile_thread_map.h:228
Definition: aligned_buffer.h:35
cutlass::layout::PitchLinearCoord
Coordinate in pitch-linear space.
Definition: pitch_linear.h:52
Defines a structure containing strides, bounds, and a pointer to tensor data.
cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::Count
Count_ Count
Definition: output_tile_thread_map.h:231
cutlass::epilogue::threadblock::OutputTileShape::kGroup
static int const kGroup
Definition: output_tile_thread_map.h:60
cutlass::epilogue::threadblock::OutputTileShape
Tuple defining point in output tile.
Definition: output_tile_thread_map.h:57
cutlass::epilogue::threadblock::InterleavedOutputTileThreadMap::WarpCount
WarpCount_ WarpCount
Definition: output_tile_thread_map.h:443
cutlass::epilogue::threadblock::OutputTileThreadMap::Iterations
Iterations_ Iterations
Iterations performed by each thread.
Definition: output_tile_thread_map.h:91
cutlass::epilogue::threadblock::OutputTileShape::kColumn
static int const kColumn
Definition: output_tile_thread_map.h:58
cutlass::epilogue::threadblock::detail::RowArrangement
RowArrangement determines how one or more warps cover a region of consecutive rows.
Definition: output_tile_thread_map.h:133
cutlass::epilogue::threadblock::InterleavedOutputTileThreadMap
Definition: output_tile_thread_map.h:442
cutlass::layout::PitchLinearShape
Template defining a shape used by pitch-linear operators.
Definition: pitch_linear.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::CompactedThreadMap
Compacted thread map in which the 4D region is contiguous.
Definition: output_tile_thread_map.h:369
cutlass::epilogue::threadblock::OutputTileThreadMap::Count
Count_ Count
Number of iterator iterations.
Definition: output_tile_thread_map.h:97
cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::CompactedThreadMap::initial_offset
static CUTLASS_HOST_DEVICE MatrixCoord initial_offset(int thread_idx)
Function to compute each thread's initial offset.
Definition: output_tile_thread_map.h:396
Defines a Shape template for matrix tiles.
cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::Shape
Shape_ Shape
Definition: output_tile_thread_map.h:230
cutlass::epilogue::threadblock::InterleavedOutputTileThreadMap::initial_offset
static CUTLASS_HOST_DEVICE layout::PitchLinearCoord initial_offset(int thread_idx)
Initial offset function.
Definition: output_tile_thread_map.h:469
cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::Detail::RowArrangement
detail::RowArrangement< Shape, kWarpsRemainingForRows, kElementsPerAccess, kElementSize,(Shape::kRow > kWarpsRemainingForRows) > RowArrangement
Definition: output_tile_thread_map.h:303
cutlass::epilogue::threadblock::InterleavedOutputTileThreadMap::Iterations
MmaCount Iterations
Definition: output_tile_thread_map.h:463
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
cutlass::layout::PitchLinearCoord::contiguous
CUTLASS_HOST_DEVICE Index const & contiguous() const
Returns the contiguous dimension.
Definition: pitch_linear.h:89
#define static_assert(__e, __m)
Definition: platform.h:153
cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::CompactedThreadMap::Shape
Shape_ Shape
Definition: output_tile_thread_map.h:372
cutlass::epilogue::threadblock::OutputTileThreadMap::Delta
Delta_ Delta
Delta between accesses.
Definition: output_tile_thread_map.h:94
cutlass::epilogue::threadblock::OutputTileThreadMap::initial_offset
static CUTLASS_HOST_DEVICE MatrixCoord initial_offset(int thread_idx)
Initial offset function.
Definition: output_tile_thread_map.h:101
cutlass::epilogue::threadblock::InterleavedOutputTileThreadMap::Detail
Definition: output_tile_thread_map.h:457
cutlass::epilogue::threadblock::OutputTileShape::kRow
static int const kRow
Definition: output_tile_thread_map.h:59
Defines layout functions used by TensorRef and derived classes.
Math utilities.
cutlass::epilogue::threadblock::OutputTileThreadMap
Definition: output_tile_thread_map.h:76
cutlass::epilogue::threadblock::OutputTileThreadMap::Shape
Shape_ Shape
Shape of the tile.
Definition: output_tile_thread_map.h:88
cutlass::epilogue::threadblock::OutputTileShape::kTile
static int const kTile
Definition: output_tile_thread_map.h:62
cutlass::epilogue::threadblock::OutputTileShape::kCount
static int const kCount
Definition: output_tile_thread_map.h:64
cutlass::epilogue::threadblock::InterleavedOutputTileThreadMap::MmaCount
MmaCount_ MmaCount
Definition: output_tile_thread_map.h:444
cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::initial_offset
static CUTLASS_HOST_DEVICE MatrixCoord initial_offset(int thread_idx)
Initial offset function.
Definition: output_tile_thread_map.h:337
CUTLASS_HOST_DEVICE constexpr int const_min(int a, int b)
Definition: fast_math.h:219
Basic include for CUTLASS.
Definition: matrix_coord.h:39
cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::Detail
Definition: output_tile_thread_map.h:244
cutlass::epilogue::threadblock::OutputTileShape::kCluster
static int const kCluster
Definition: output_tile_thread_map.h:61
cutlass::layout::PitchLinearCoord::strided
CUTLASS_HOST_DEVICE Index const & strided() const
Returns the strided dimension of the coordinate.
Definition: pitch_linear.h:97
Generated by 1.8.11