kernel/gemm_splitk_parallel.h
[Go to the documentation of this file.](kernel_2gemm__splitk__parallel_8h.html)
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a GEMM kernel that computes partial products over parallel split-K slices.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Mma_,                  ///< Threadblock-scoped matrix multiply-accumulate
  typename Epilogue_,             ///< Epilogue writing the per-slice partial product
  typename ThreadblockSwizzle_    ///< Threadblock swizzling function
>
struct GemmSplitKParallel {

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using OutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// K-dimension alignment, in elements, taken from the warp-level operator shape
  static int const kAlignmentK = Mma::Operator::Shape::kK;

  /// Parameters structure
  struct Params {
    cutlass::gemm::GemmCoord problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;
    typename Mma::IteratorA::Params params_A;
    typename Mma::IteratorA::TensorRef ref_A;
    typename Mma::IteratorB::Params params_B;
    typename Mma::IteratorB::TensorRef ref_B;
    typename Epilogue::OutputTileIterator::Params params_D;
    typename Epilogue::OutputTileIterator::TensorRef ref_D;
    typename OutputOp::Params output_op;
    int64_t splitk_slice_stride;   // element stride between consecutive split-K slices in the workspace
    int gemm_k_size;               // K extent assigned to each slice, a multiple of Mma::Shape::kK

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params() { }

    CUTLASS_HOST_DEVICE
    Params(
      cutlass::gemm::GemmCoord const & problem_size,
      cutlass::gemm::GemmCoord const & grid_tiled_shape,
      typename Mma::IteratorA::TensorRef ref_A,
      typename Mma::IteratorB::TensorRef ref_B,
      typename Epilogue::OutputTileIterator::TensorRef ref_D,
      typename OutputOp::Params output_op,
      int64_t splitk_slice_stride
    ):
      problem_size(problem_size),
      grid_tiled_shape(grid_tiled_shape),
      params_A(ref_A.layout()),
      ref_A(ref_A),
      params_B(ref_B.layout()),
      ref_B(ref_B),
      params_D(ref_D.layout()),
      ref_D(ref_D),
      output_op(output_op),
      splitk_slice_stride(splitk_slice_stride) {

      int full_gemm_k_iterations = problem_size.k() / Mma::Shape::kK;
      int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k();

      gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
    }
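
    // Worked example (illustrative numbers, not from this header): for
    // problem_size.k() = 1024, Mma::Shape::kK = 32, and grid_tiled_shape.k() = 3,
    //   full_gemm_k_iterations = 1024 / 32 = 32
    //   gemm_k_iterations      = 32 / 3   = 10
    //   gemm_k_size            = 10 * 32  = 320
    // so slices 0 and 1 each cover 320 elements of K, and operator() below
    // extends the final slice to cover the remaining 384.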
  };

  /// Shared memory storage structure
  union SharedStorage {
    // The main loop and the epilogue run in sequence, so their allocations may alias.
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  };

  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  GemmSplitKParallel() { }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
      params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

      return;
    }

    // Compute initial location in logical coordinates
    cutlass::MatrixCoord tb_offset_A{
      threadblock_tile_offset.m() * Mma::Shape::kM,
      threadblock_tile_offset.k() * params.gemm_k_size,
    };

    cutlass::MatrixCoord tb_offset_B{
      threadblock_tile_offset.k() * params.gemm_k_size,
      threadblock_tile_offset.n() * Mma::Shape::kN
    };

    // Problem size is a function of threadblock index in the K dimension
    int problem_size_k;
    if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) {
      problem_size_k = params.problem_size.k();
    }
    else {
      problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
    }

    // Compute threadblock-scoped matrix multiply-add
    int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
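
    // tb_offset_A.column() is this slice's starting K coordinate, so the
    // division above rounds the slice's K range [k * gemm_k_size, problem_size_k)
    // up to whole Mma::Shape::kK tiles; the final slice's problem_size_k is the
    // full K extent, so it absorbs any remainder left by the even division in Params.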

    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename Mma::IteratorA iterator_A(
      params.params_A,
      params.ref_A.data(),
      {params.problem_size.m(), problem_size_k},
      thread_idx,
      tb_offset_A);

    typename Mma::IteratorB iterator_B(
      params.params_B,
      params.ref_B.data(),
      {problem_size_k, params.problem_size.n()},
      thread_idx,
      tb_offset_B);

    int warp_idx = threadIdx.x / 32;
    int lane_idx = threadIdx.x % 32;

    //
    // Main loop
    //

    // Construct thread-scoped matrix multiply
    Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

    typename Mma::FragmentC accumulators;

    accumulators.clear();

    mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
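
    // At this point the accumulators hold only this slice's partial sum over
    // its K range; in the split-K parallel scheme, a separate reduction pass
    // launched outside this kernel is expected to combine the slices.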

    //
    // Epilogue
    //

    OutputOp output_op(params.output_op);

    //
    // Masked tile iterators constructed from members
    //

    threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

    // assume identity swizzle
    MatrixCoord threadblock_offset(
      threadblock_tile_offset.m() * Mma::Shape::kM,
      threadblock_tile_offset.n() * Mma::Shape::kN
    );

    // Tile iterator writing to output tile
    typename Epilogue::OutputTileIterator iterator_D(
      params.params_D,
      params.ref_D.data(),
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
    );

    iterator_D.add_pointer_offset(params.splitk_slice_stride * threadblock_tile_offset.k());
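
    // Slice k of the split-K grid writes its partial product at an offset of
    // k * splitk_slice_stride elements into the workspace, so no two slices'
    // epilogues overlap.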

    // Execute the epilogue
    Epilogue epilogue(
      shared_storage.epilogue,
      thread_idx,
      warp_idx,
      lane_idx);

    // Run efficient epilogue
    epilogue(output_op, iterator_D, accumulators, iterator_D);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass
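
The kernel above computes only per-slice partial products; a host-side wrapper chooses the grid shape, sizes the inter-slice workspace, and launches a reduction over the slices afterward. The following standalone C++ sketch mirrors that launch arithmetic, following the Params constructor above; the tile extents kTileM/kTileN, the operator K extent kMmaShapeK, the split count kSplits, and the workspace layout are assumed illustrative values, not taken from this header or from any CUTLASS API.

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Hypothetical threadblock tile, warp-level operator K extent, and split count.
      int const kTileM = 128, kTileN = 128, kMmaShapeK = 32;
      int const kSplits = 4;

      // GEMM problem extents (M, N, K).
      int M = 512, N = 768, K = 4096;

      // Grid shape in threadblock tiles; its K extent is the number of split-K
      // slices, mirroring grid_tiled_shape in Params above.
      int grid_m = (M + kTileM - 1) / kTileM;
      int grid_n = (N + kTileN - 1) / kTileN;
      int grid_k = kSplits;

      // Per-slice K extent rounded to whole kMmaShapeK tiles, exactly as the
      // Params constructor computes it; the final slice absorbs the remainder.
      int gemm_k_size = (K / kMmaShapeK / grid_k) * kMmaShapeK;

      // Each slice emits a full M-by-N partial product, so the workspace stacks
      // kSplits such matrices; the element stride between them is one natural
      // choice for splitk_slice_stride.
      int64_t splitk_slice_stride = int64_t(M) * N;
      int64_t workspace_elements  = splitk_slice_stride * kSplits;

      printf("grid=(%d,%d,%d) gemm_k_size=%d workspace=%lld elements\n",
             grid_m, grid_n, grid_k, gemm_k_size, (long long)workspace_elements);
      return 0;
    }

With these numbers the sketch prints grid=(4,6,4), gemm_k_size=1024, and a workspace of 1572864 elements. Treating the slice stride as a whole M-by-N matrix keeps each slice's epilogue write pattern identical to a plain GEMM with only the base pointer differing, which is what the add_pointer_offset call in operator() implements.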