docs/gemm_2threadblock_2threadblock__swizzle_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
gemm/threadblock/threadblock_swizzle.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
30 #pragma once
31
32 #include "cutlass/cutlass.h"
33
34 #include "cutlass/gemm/gemm.h"
35
37
38 namespace cutlass {
39 namespace gemm {
40 namespace threadblock {
41
43
/// Returns the built-in threadIdx.x as an int.
/// The page's cross-reference index describes this family of helpers as
/// rematerializing the built-in value to reduce register liveness.
45 CUTLASS_DEVICE
46 int RematerializeThreadIdxX() {
47return threadIdx.x;
48 }
49
/// Returns the built-in threadIdx.y as an int.
/// The page's cross-reference index describes this family of helpers as
/// rematerializing the built-in value to reduce register liveness.
51 CUTLASS_DEVICE
52 int RematerializeThreadIdxY() {
53return threadIdx.y;
54 }
55
/// Returns the built-in threadIdx.z as an int.
/// The page's cross-reference index describes this family of helpers as
/// rematerializing the built-in value to reduce register liveness.
57 CUTLASS_DEVICE
58 int RematerializeThreadIdxZ() {
59return threadIdx.z;
60 }
61
/// Returns the built-in blockIdx.x as an int.
/// The page's cross-reference index describes this family of helpers as
/// rematerializing the built-in value to reduce register liveness.
63 CUTLASS_DEVICE
64 int RematerializeBlockIdxX() {
65return blockIdx.x;
66 }
67
/// Returns the built-in blockIdx.y as an int.
/// The page's cross-reference index describes this family of helpers as
/// rematerializing the built-in value to reduce register liveness.
69 CUTLASS_DEVICE
70 int RematerializeBlockIdxY() {
71return blockIdx.y;
72 }
73
/// Returns the built-in blockIdx.z as an int.
/// The page's cross-reference index describes this family of helpers as
/// rematerializing the built-in value to reduce register liveness.
75 CUTLASS_DEVICE
76 int RematerializeBlockIdxZ() {
77return blockIdx.z;
78 }
79
/// Returns the built-in blockDim.x as an int.
/// The page's cross-reference index describes this family of helpers as
/// rematerializing the built-in value to reduce register liveness.
81 CUTLASS_DEVICE
82 int RematerializeBlockDimX() {
83return blockDim.x;
84 }
85
/// Returns the built-in blockDim.y as an int.
/// The page's cross-reference index describes this family of helpers as
/// rematerializing the built-in value to reduce register liveness.
87 CUTLASS_DEVICE
88 int RematerializeBlockDimY() {
89return blockDim.y;
90 }
91
/// Returns the built-in blockDim.z as an int.
/// The page's cross-reference index describes this family of helpers as
/// rematerializing the built-in value to reduce register liveness.
93 CUTLASS_DEVICE
94 int RematerializeBlockDimZ() {
95return blockDim.z;
96 }
97
99
/// Threadblock swizzling function for GEMMs.
///
/// Maps between a problem's extent in logical tiles, the CUDA grid shape, and
/// the tile offset of the calling threadblock, using the member constant kTile.
//
// NOTE(review): this listing skips some numbered lines of the original header.
// Listing line 106 is not shown; the cross-reference index on this page gives
// it as `int const kTile`. Its value is not visible here.
101 struct GemmIdentityThreadblockSwizzle {
102
// Default constructor; the body is empty.
104GemmIdentityThreadblockSwizzle() { }
105
107
/// Returns the shape of the problem in units of logical tiles.
///
/// The first two coordinates are the ceiling division of problem_size.m() and
/// problem_size.n() by the matching tile_size extents; split_k_slices is
/// passed through unchanged as the third coordinate.
//
// NOTE(review): the signature line (listing line 110) is not shown; the index
// gives it as `CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord
// problem_size, GemmCoord tile_size, int split_k_slices) const`.
111GemmCoord problem_size,
112GemmCoord tile_size,
113int split_k_slices) const {
114
115return GemmCoord(
116 (problem_size.m() + tile_size.m() - 1) / tile_size.m(),
117 (problem_size.n() + tile_size.n() - 1) / tile_size.n(),
118 split_k_slices);
119 }
120
/// Computes CUDA grid dimensions given a size in units of logical tiles.
///
/// grid.x is tiled_shape.m() multiplied by kTile, grid.y is tiled_shape.n()
/// ceiling-divided by kTile, and grid.z is tiled_shape.k().
123 dim3 get_grid_shape(GemmCoord tiled_shape) const {
124return dim3(tiled_shape.m() * kTile, (tiled_shape.n() + kTile - 1) / kTile, tiled_shape.k());
125 }
126
/// Obtains the threadblock offset (in units of threadblock-scoped tiles).
///
/// Undoes the packing applied by get_grid_shape(): the first coordinate is
/// blockIdx.x / kTile and the second is blockIdx.y * kTile + blockIdx.x % kTile.
128 CUTLASS_DEVICE
129GemmCoord get_tile_offset() const {
130
131int block_idx_x = RematerializeBlockIdxX();
132int block_idx_y = RematerializeBlockIdxY();
133
134return GemmCoord{
135 (block_idx_x / kTile),
136 (block_idx_y * kTile) + (block_idx_x % kTile),
// NOTE(review): listing line 137 (presumably the third coordinate) is not
// shown in this listing; confirm against the header.
138 };
139 }
140 };
141
143
/// Threadblock swizzling function for GEMMs.
///
/// Unlike the identity mapping, this variant places the N tile count on
/// grid.x and the M tile count on grid.y.
145 struct GemmHorizontalThreadblockSwizzle {
146
// Default constructor; the body is empty.
148GemmHorizontalThreadblockSwizzle() { }
149
/// Returns the shape of the problem in units of logical tiles.
///
/// The first two coordinates are the ceiling division of problem_size.m() and
/// problem_size.n() by the matching tile_size extents; split_k_slices is
/// passed through unchanged as the third coordinate.
//
// NOTE(review): the signature line (listing line 152) is not shown; the index
// gives it as `CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord
// problem_size, GemmCoord tile_size, int split_k_slices) const`.
153GemmCoord problem_size,
154GemmCoord tile_size,
155int split_k_slices) const {
156
157return GemmCoord(
158 (problem_size.m() + tile_size.m() - 1) / tile_size.m(),
159 (problem_size.n() + tile_size.n() - 1) / tile_size.n(),
160 split_k_slices);
161 }
162
/// Computes CUDA grid dimensions given a size in units of logical tiles.
///
/// grid.x = tiled_shape.n(), grid.y = tiled_shape.m(), grid.z = tiled_shape.k().
165 dim3 get_grid_shape(GemmCoord tiled_shape) const {
166return dim3(tiled_shape.n(), tiled_shape.m(), tiled_shape.k());
167 }
168
/// Obtains the threadblock offset (in units of threadblock-scoped tiles).
///
/// The first coordinate comes from blockIdx.y and the second from blockIdx.x,
/// mirroring the x/y swap in get_grid_shape().
170 CUTLASS_DEVICE
171GemmCoord get_tile_offset() const {
172return GemmCoord{
173RematerializeBlockIdxY(),
174RematerializeBlockIdxX(),
// NOTE(review): listing line 175 (presumably the third coordinate) is not
// shown in this listing; confirm against the header.
176 };
177 }
178 };
179
181
/// Threadblock swizzling function for batched GEMMs.
///
/// Tiles map directly onto grid.x / grid.y, and the batch index is read from
/// blockIdx.z.
183 struct GemmBatchedIdentityThreadblockSwizzle {
184
/// Returns the shape of the problem in units of logical tiles.
///
/// The first two coordinates are the ceiling division of problem_size.m() and
/// problem_size.n() by the matching tile_size extents. The third coordinate
/// is batch_count reduced modulo 65536 (1 << 16).
//
// NOTE(review): the signature line (listing line 187) is not shown; the index
// gives it as `CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord
// problem_size, int batch_count, GemmCoord tile_size) const`.
// NOTE(review): the modulo presumably keeps grid.z within CUDA's z-dimension
// limit. A batch_count that is an exact multiple of 65536 yields 0 here;
// confirm callers handle that case.
188GemmCoord problem_size,
189int batch_count,
190GemmCoord tile_size) const {
191
192return GemmCoord(
193 (problem_size.m() + tile_size.m() - 1) / tile_size.m(),
194 (problem_size.n() + tile_size.n() - 1) / tile_size.n(),
195 batch_count % (1 << 16));
196 }
197
/// Computes CUDA grid dimensions given a size in units of logical tiles.
///
/// grid.x = tiled_shape.m(), grid.y = tiled_shape.n(), grid.z = tiled_shape.k().
200 dim3 get_grid_shape(GemmCoord tiled_shape) const {
201return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k());
202 }
203
/// Obtains the threadblock offset (in units of threadblock-scoped tiles).
///
/// Returns {blockIdx.x, blockIdx.y, 0}; the third coordinate is always 0
/// because blockIdx.z is reported separately by get_batch_idx().
205 CUTLASS_DEVICE
206GemmCoord get_tile_offset() const {
207return GemmCoord{
208RematerializeBlockIdxX(),
209RematerializeBlockIdxY(),
210 0
211 };
212 }
213
/// Gets the batch index, read from blockIdx.z.
215 CUTLASS_DEVICE
216int get_batch_idx() const {
217return RematerializeBlockIdxZ();
218 }
219 };
220
222
/// Threadblock swizzling function for split-K GEMMs.
///
/// Tiles map directly onto grid.x / grid.y, and the partition count is carried
/// in the third coordinate / grid.z.
224 struct GemmSplitKIdentityThreadblockSwizzle {
225
/// Returns the shape of the problem in units of logical tiles.
///
/// The first two coordinates are the ceiling division of problem_size.m() and
/// problem_size.n() by the matching tile_size extents; partitions is passed
/// through unchanged as the third coordinate.
//
// NOTE(review): the signature line (listing line 228) is not shown; the index
// gives it as `CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord
// problem_size, GemmCoord tile_size, int partitions) const`.
229GemmCoord problem_size,
230GemmCoord tile_size,
231int partitions) const {
232
233return GemmCoord(
234 (problem_size.m() + tile_size.m() - 1) / tile_size.m(),
235 (problem_size.n() + tile_size.n() - 1) / tile_size.n(),
236 partitions);
237 }
238
/// Computes CUDA grid dimensions given a size in units of logical tiles.
///
/// grid.x = tiled_shape.m(), grid.y = tiled_shape.n(), grid.z = tiled_shape.k().
241 dim3 get_grid_shape(GemmCoord tiled_shape) const {
242return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k());
243 }
244
245
/// Obtains the threadblock offset (in units of threadblock-scoped tiles).
///
/// The first coordinate comes from blockIdx.x and the second from blockIdx.y.
247 CUTLASS_DEVICE
248GemmCoord get_tile_offset() const {
249return GemmCoord{
250RematerializeBlockIdxX(),
251RematerializeBlockIdxY(),
// NOTE(review): listing line 252 (presumably the third coordinate) is not
// shown in this listing; confirm against the header.
253 };
254 }
255 };
256
258
/// Threadblock swizzling function for split-K GEMMs.
///
/// Places the N tile count on grid.x and the M tile count on grid.y; the
/// partition count is carried in the third coordinate / grid.z.
260 struct GemmSplitKHorizontalThreadblockSwizzle {
261
/// Returns the shape of the problem in units of logical tiles.
///
/// The first two coordinates are the ceiling division of problem_size.m() and
/// problem_size.n() by the matching tile_size extents; partitions is passed
/// through unchanged as the third coordinate.
//
// NOTE(review): the signature line (listing line 264) is not shown; the index
// gives it as `CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord
// problem_size, GemmCoord tile_size, int partitions) const`.
265GemmCoord problem_size,
266GemmCoord tile_size,
267int partitions) const {
268
269return GemmCoord(
270 (problem_size.m() + tile_size.m() - 1) / tile_size.m(),
271 (problem_size.n() + tile_size.n() - 1) / tile_size.n(),
272 partitions);
273 }
274
/// Computes CUDA grid dimensions given a size in units of logical tiles.
///
/// grid.x = tiled_shape.n(), grid.y = tiled_shape.m(), grid.z = tiled_shape.k().
277 dim3 get_grid_shape(GemmCoord tiled_shape) const {
278return dim3(tiled_shape.n(), tiled_shape.m(), tiled_shape.k());
279 }
280
281
/// Obtains the threadblock offset (in units of threadblock-scoped tiles).
///
/// The first coordinate comes from blockIdx.y and the second from blockIdx.x,
/// mirroring the x/y swap in get_grid_shape().
283 CUTLASS_DEVICE
284GemmCoord get_tile_offset() const {
285return GemmCoord{
286RematerializeBlockIdxY(),
287RematerializeBlockIdxX(),
// NOTE(review): listing line 288 (presumably the third coordinate) is not
// shown in this listing; confirm against the header.
289 };
290 }
291 };
292
294
/// Threadblock swizzling function for batched GEMVs.
///
/// Coordinates are built in (M, N, K, batch) order. The M extent is fixed at
/// 1; N maps to grid.x, batch to grid.y and K to grid.z.
296 struct GemvBatchedStridedThreadblockDefaultSwizzle {
297
/// Returns the shape of the problem in units of logical tiles.
///
/// M is fixed at 1; the N, K and batch extents are each the ceiling division
/// of the problem_size extent by the matching tile_size extent.
300BatchedGemmCoord get_tiled_shape(
301BatchedGemmCoord problem_size,
302BatchedGemmCoord tile_size) const {
303
304return BatchedGemmCoord(
305 1, // M is always 1
306 (problem_size.n() + tile_size.n() - 1) / tile_size.n(),
307 (problem_size.k() + tile_size.k() - 1) / tile_size.k(),
308 (problem_size.batch() + tile_size.batch() - 1) / tile_size.batch());
309 }
310
/// Computes CUDA grid dimensions given a size in units of logical tiles.
///
/// grid.x = tiled_shape.n(), grid.y = tiled_shape.batch(), grid.z = tiled_shape.k().
313 dim3 get_grid_shape(BatchedGemmCoord tiled_shape) const {
314return dim3(tiled_shape.n(), tiled_shape.batch(), tiled_shape.k());
315 }
316
/// Obtains the threadblock offset (in units of threadblock-scoped tiles).
///
/// Inverse of the get_grid_shape() mapping: N from blockIdx.x, K from
/// blockIdx.z, batch from blockIdx.y.
318 CUTLASS_DEVICE
319BatchedGemmCoord get_tile_offset() const {
320return BatchedGemmCoord{
321 0, // M offset is always 0 (the M extent is always 1)
322RematerializeBlockIdxX(),
323RematerializeBlockIdxZ(),
324RematerializeBlockIdxY(),
325 };
326 }
327
/// Gets the batch tile index, read from blockIdx.y.
329 CUTLASS_DEVICE
330int get_batch_tile_idx() const {
331return RematerializeBlockIdxY();
332 }
333
/// Gets the absolute batch index: blockDim.y * blockIdx.y + threadIdx.y.
335 CUTLASS_DEVICE
336int get_batch_idx() const {
337return RematerializeBlockDimY()*RematerializeBlockIdxY() + RematerializeThreadIdxY();
338 }
339 };
340
342
343 } // namespace threadblock
344 } // namespace gemm
345 } // namespace cutlass
346
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle::kTile
int const kTile
Definition: gemm/threadblock/threadblock_swizzle.h:106
Definition: aligned_buffer.h:35
cutlass::gemm::threadblock::RematerializeThreadIdxY
CUTLASS_DEVICE int RematerializeThreadIdxY()
Helper to rematerialize thread Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:52
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle::get_grid_shape
CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:200
Definition: include/cutlass/gemm/gemm.h:94
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle::get_tile_offset
CUTLASS_DEVICE GemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:206
cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle::get_grid_shape
CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:241
cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle::get_tiled_shape
CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int partitions) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:264
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
cutlass::gemm::threadblock::RematerializeBlockDimX
CUTLASS_DEVICE int RematerializeBlockDimX()
Helper to rematerialize block Dim. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:82
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle::get_grid_shape
CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:123
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle::get_tiled_shape
CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int split_k_slices) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:152
cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle::get_grid_shape
CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:165
cutlass::gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle::get_grid_shape
CUTLASS_HOST_DEVICE dim3 get_grid_shape(BatchedGemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:313
cutlass::gemm::threadblock::RematerializeThreadIdxX
CUTLASS_DEVICE int RematerializeThreadIdxX()
Helper to rematerialize thread Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:46
cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle
Threadblock swizzling function for GEMMs.
Definition: gemm/threadblock/threadblock_swizzle.h:145
cutlass::gemm::BatchedGemmCoord
Definition: include/cutlass/gemm/gemm.h:260
cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle::GemmHorizontalThreadblockSwizzle
CUTLASS_HOST_DEVICE GemmHorizontalThreadblockSwizzle()
Definition: gemm/threadblock/threadblock_swizzle.h:148
cutlass::gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle::get_batch_tile_idx
CUTLASS_DEVICE int get_batch_tile_idx() const
Gets the batch tile index.
Definition: gemm/threadblock/threadblock_swizzle.h:330
cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle::get_grid_shape
CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const
Computes CUDA grid dimensions given a size in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:277
cutlass::gemm::threadblock::RematerializeThreadIdxZ
CUTLASS_DEVICE int RematerializeThreadIdxZ()
Helper to rematerialize thread Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:58
cutlass::gemm::BatchedGemmCoord::batch
CUTLASS_HOST_DEVICE Index const & batch() const
Returns the GEMM batch coordinate.
Definition: include/cutlass/gemm/gemm.h:322
cutlass::gemm::threadblock::RematerializeBlockDimY
CUTLASS_DEVICE int RematerializeBlockDimY()
Helper to rematerialize block Dim. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:88
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle::GemmIdentityThreadblockSwizzle
CUTLASS_HOST_DEVICE GemmIdentityThreadblockSwizzle()
Definition: gemm/threadblock/threadblock_swizzle.h:104
cutlass::gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle::get_batch_idx
CUTLASS_DEVICE int get_batch_idx() const
Gets the absolute batch index.
Definition: gemm/threadblock/threadblock_swizzle.h:336
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle::get_tiled_shape
CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int split_k_slices) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:110
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
cutlass::gemm::BatchedGemmCoord::k
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:314
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle::get_tile_offset
CUTLASS_DEVICE GemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:129
cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle
Threadblock swizzling function for split-K GEMMs.
Definition: gemm/threadblock/threadblock_swizzle.h:260
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle::get_tiled_shape
CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, int batch_count, GemmCoord tile_size) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:187
cutlass::gemm::threadblock::RematerializeBlockIdxY
CUTLASS_DEVICE int RematerializeBlockIdxY()
Helper to rematerialize block Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:70
cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle::get_tiled_shape
CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int partitions) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:228
cutlass::gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle::get_tiled_shape
CUTLASS_HOST_DEVICE BatchedGemmCoord get_tiled_shape(BatchedGemmCoord problem_size, BatchedGemmCoord tile_size) const
Returns the shape of the problem in units of logical tiles.
Definition: gemm/threadblock/threadblock_swizzle.h:300
cutlass::gemm::threadblock::RematerializeBlockDimZ
CUTLASS_DEVICE int RematerializeBlockDimZ()
Helper to rematerialize block Dim. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:94
cutlass::gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle::get_tile_offset
CUTLASS_DEVICE BatchedGemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:319
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle
Threadblock swizzling function for GEMMs.
Definition: gemm/threadblock/threadblock_swizzle.h:101
cutlass::gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle
Threadblock swizzling function for batched GEMVs.
Definition: gemm/threadblock/threadblock_swizzle.h:296
cutlass::gemm::BatchedGemmCoord::n
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:306
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle::get_batch_idx
CUTLASS_DEVICE int get_batch_idx() const
Gets the batch index.
Definition: gemm/threadblock/threadblock_swizzle.h:216
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
cutlass::gemm::threadblock::RematerializeBlockIdxZ
CUTLASS_DEVICE int RematerializeBlockIdxZ()
Helper to rematerialize block Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:76
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle
Threadblock swizzling function for batched GEMMs.
Definition: gemm/threadblock/threadblock_swizzle.h:183
cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle::get_tile_offset
CUTLASS_DEVICE GemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:284
cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle::get_tile_offset
CUTLASS_DEVICE GemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:171
cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle
Threadblock swizzling function for split-K GEMMs.
Definition: gemm/threadblock/threadblock_swizzle.h:224
Basic include for CUTLASS.
cutlass::gemm::threadblock::RematerializeBlockIdxX
CUTLASS_DEVICE int RematerializeBlockIdxX()
Helper to rematerialize block Idx. Reduces register liveness.
Definition: gemm/threadblock/threadblock_swizzle.h:64
cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle::get_tile_offset
CUTLASS_DEVICE GemmCoord get_tile_offset() const
Obtains the threadblock offset (in units of threadblock-scoped tiles)
Definition: gemm/threadblock/threadblock_swizzle.h:248
Generated by 1.8.11