docs/include_2cutlass_2gemm_2kernel_2gemm_8h_source.html
CUTLASS: CUDA Templates for Linear Algebra Subroutines and Solvers
include/cutlass/gemm/kernel/gemm.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
25
30 #pragma once
31
32 #include "cutlass/cutlass.h"
33
34 #include "cutlass/gemm/gemm.h"
35 #include "cutlass/matrix_coord.h"
36 #include "cutlass/semaphore.h"
37
39
40 namespace cutlass {
41 namespace gemm {
42 namespace kernel {
43
45
46 template <
47typename Mma_,
48typename Epilogue_,
49typename ThreadblockSwizzle_,
50bool SplitKSerial
51 >
53
56using OutputOp = typename Epilogue::OutputOp;
57using ThreadblockSwizzle = ThreadblockSwizzle_;
58static bool const kSplitKSerial = SplitKSerial;
59
61using WarpCount = typename Mma::WarpCount;
62static int const kThreadCount = 32 * WarpCount::kCount;
63
66cutlass::gemm::GemmCoord problem_size;
67cutlass::gemm::GemmCoord grid_tiled_shape;
68typename Mma::IteratorA::Params params_A;
69typename Mma::IteratorA::TensorRef ref_A;
70typename Mma::IteratorB::Params params_B;
71typename Mma::IteratorB::TensorRef ref_B;
72typename Epilogue::OutputTileIterator::Params params_C;
73typename Epilogue::OutputTileIterator::TensorRef ref_C;
74typename Epilogue::OutputTileIterator::Params params_D;
75typename Epilogue::OutputTileIterator::TensorRef ref_D;
76typename OutputOp::Params output_op;
78int gemm_k_iterations;
79int gemm_k_size;
80
81//
82// Methods
83//
84
87
90cutlass::gemm::GemmCoord const & problem_size,
91cutlass::gemm::GemmCoord const & grid_tiled_shape,
92typename Mma::IteratorA::TensorRef ref_A,
93typename Mma::IteratorB::TensorRef ref_B,
94typename Epilogue::OutputTileIterator::TensorRef ref_C,
95typename Epilogue::OutputTileIterator::TensorRef ref_D,
96typename OutputOp::Params output_op = typename OutputOp::Params(),
97int *semaphore = nullptr
98 ):
99 problem_size(problem_size),
100 grid_tiled_shape(grid_tiled_shape),
101 params_A(ref_A.layout()),
102 ref_A(ref_A),
103 params_B(ref_B.layout()),
104 ref_B(ref_B),
105 params_C(ref_C.layout()),
106 ref_C(ref_C),
107 params_D(ref_D.layout()),
108 ref_D(ref_D),
109 output_op(output_op),
110 semaphore(semaphore) {
111
112int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
113int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
114
115 gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
116 }
117 };
118
120union SharedStorage {
121typename Mma::SharedStorage main_loop;
122typename Epilogue::SharedStorage epilogue;
123 };
124
125//
126// Methods
127//
128
131
133static Status can_implement(
134cutlass::gemm::GemmCoord const & problem_size,
135typename Mma::IteratorA::TensorRef ref_A,
136typename Mma::IteratorB::TensorRef ref_B,
137typename Epilogue::OutputTileIterator::TensorRef ref_C,
138typename Epilogue::OutputTileIterator::TensorRef ref_D) {
139
140static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
141static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
142static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
143
144if (! TensorRef_aligned(ref_A, kAlignmentA)) {
145return Status::kErrorMisalignedOperand;
146 }
147
148if (! TensorRef_aligned(ref_B, kAlignmentB)) {
149return Status::kErrorMisalignedOperand;
150 }
151
152if (! TensorRef_aligned(ref_C, kAlignmentC)) {
153return Status::kErrorMisalignedOperand;
154 }
155
156if (! TensorRef_aligned(ref_D, kAlignmentC)) {
157return Status::kErrorMisalignedOperand;
158 }
159
160if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||
161 (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||
162 (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {
163
164return Status::kErrorMisalignedOperand;
165 }
166
167return Status::kSuccess;
168 }
169
171 CUTLASS_DEVICE
172void operator()(Params const ¶ms, SharedStorage &shared_storage) {
173
174// Compute threadblock location
175ThreadblockSwizzle threadblock_swizzle;
176
177cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
178
179// Early exit if CTA is out of range
180if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
181 params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
182
183return;
184 }
185
186// Compute initial location in logical coordinates
187cutlass::MatrixCoord tb_offset_A{
188 threadblock_tile_offset.m() * Mma::Shape::kM,
189 threadblock_tile_offset.k() * params.gemm_k_size,
190 };
191
192cutlass::MatrixCoord tb_offset_B{
193 threadblock_tile_offset.k() * params.gemm_k_size,
194 threadblock_tile_offset.n() * Mma::Shape::kN
195 };
196
197// Problem size is a function of threadblock index in the K dimension
198int problem_size_k = min(
199 params.problem_size.k(),
200 (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
201
202// Compute threadblock-scoped matrix multiply-add
203int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
204
205// Compute position within threadblock
206int thread_idx = threadIdx.x;
207
208// Construct iterators to A and B operands
209typename Mma::IteratorA iterator_A(
210 params.params_A,
211 params.ref_A.data(),
212 {params.problem_size.m(), problem_size_k},
213 thread_idx,
214 tb_offset_A);
215
216typename Mma::IteratorB iterator_B(
217 params.params_B,
218 params.ref_B.data(),
219 {problem_size_k, params.problem_size.n()},
220 thread_idx,
221 tb_offset_B);
222
223int warp_idx = threadIdx.x / 32;
224int lane_idx = threadIdx.x % 32;
225
226//
227// Main loop
228//
229
230// Construct thread-scoped matrix multiply
231Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
232
233typename Mma::FragmentC accumulators;
234
235 accumulators.clear();
236
237if (!kSplitKSerial || gemm_k_iterations > 0) {
238// Compute threadblock-scoped matrix multiply-add
239 mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
240 }
241
242//
243// Epilogue
244//
245
246OutputOp output_op(params.output_op);
247
248//
249// Masked tile iterators constructed from members
250//
251
252 threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
253
254//assume identity swizzle
255MatrixCoord threadblock_offset(
256 threadblock_tile_offset.m() * Mma::Shape::kM,
257 threadblock_tile_offset.n() * Mma::Shape::kN
258 );
259
260int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
261
262// Construct the semaphore.
263Semaphore semaphore(params.semaphore + block_idx, thread_idx);
264
265// If performing a reduction via split-K, fetch the initial synchronization
266if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
267
268// Fetch the synchronization lock initially but do not block.
269 semaphore.fetch();
270
271// Indicate which position in a serial reduction the output operator is currently updating
272 output_op.set_k_partition(threadblock_tile_offset.k());
273 }
274
275// Tile iterator loading from source tensor.
276typename Epilogue::OutputTileIterator iterator_C(
277 params.params_C,
278 params.ref_C.data(),
279 params.problem_size.mn(),
280 thread_idx,
281 threadblock_offset
282 );
283
284// Tile iterator writing to destination tensor.
285typename Epilogue::OutputTileIterator iterator_D(
286 params.params_D,
287 params.ref_D.data(),
288 params.problem_size.mn(),
289 thread_idx,
290 threadblock_offset
291 );
292
293Epilogue epilogue(
294 shared_storage.epilogue,
295 thread_idx,
296 warp_idx,
297 lane_idx);
298
299// Wait on the semaphore - this latency may have been covered by iterator construction
300if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
301
302// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
303if (threadblock_tile_offset.k()) {
304 iterator_C = iterator_D;
305 }
306
307 semaphore.wait(threadblock_tile_offset.k());
308
309 __threadfence();
310 }
311
312// Execute the epilogue operator to update the destination tensor.
313 epilogue(output_op, iterator_D, accumulators, iterator_C);
314
315//
316// Release the semaphore
317//
318
319if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
320
321int lock = 0;
322if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
323
324// The final threadblock resets the semaphore for subsequent grids.
325 lock = 0;
326 }
327else {
328// Otherwise, the semaphore is incremented
329 lock = threadblock_tile_offset.k() + 1;
330 }
331
332 __threadfence();
333 semaphore.release(lock);
334 }
335 }
336 };
337
339
340 } // namespace kernel
341 } // namespace gemm
342 } // namespace cutlass
343
cutlass::gemm::kernel::Gemm::Params::ref_C
Epilogue::OutputTileIterator::TensorRef ref_C
Definition: include/cutlass/gemm/kernel/gemm.h:73
Definition: aligned_buffer.h:35
cutlass::gemm::kernel::Gemm::SharedStorage::epilogue
Epilogue::SharedStorage epilogue
Definition: include/cutlass/gemm/kernel/gemm.h:122
cutlass::gemm::kernel::Gemm::Params::params_D
Epilogue::OutputTileIterator::Params params_D
Definition: include/cutlass/gemm/kernel/gemm.h:74
cutlass::gemm::kernel::Gemm::Params::params_A
Mma::IteratorA::Params params_A
Definition: include/cutlass/gemm/kernel/gemm.h:68
cutlass::gemm::kernel::Gemm::Epilogue
Epilogue_ Epilogue
Definition: include/cutlass/gemm/kernel/gemm.h:55
cutlass::gemm::kernel::Gemm::Params::params_B
Mma::IteratorB::Params params_B
Definition: include/cutlass/gemm/kernel/gemm.h:70
cutlass::gemm::kernel::Gemm::Params::Params
CUTLASS_HOST_DEVICE Params(cutlass::gemm::GemmCoord const &problem_size, cutlass::gemm::GemmCoord const &grid_tiled_shape, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Epilogue::OutputTileIterator::TensorRef ref_C, typename Epilogue::OutputTileIterator::TensorRef ref_D, typename OutputOp::Params output_op=typename OutputOp::Params(), int *semaphore=nullptr)
Definition: include/cutlass/gemm/kernel/gemm.h:89
Definition: include/cutlass/gemm/gemm.h:94
CUTLASS_HOST_DEVICE Coord< 2 > mn() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:171
cutlass::gemm::kernel::Gemm::Params::params_C
Epilogue::OutputTileIterator::Params params_C
Definition: include/cutlass/gemm/kernel/gemm.h:72
cutlass::gemm::kernel::Gemm::kThreadCount
static int const kThreadCount
Definition: include/cutlass/gemm/kernel/gemm.h:62
Defines common types used for all GEMM-like operators.
CUTLASS_DEVICE void fetch()
Permit fetching the synchronization mechanism early.
Definition: semaphore.h:68
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
cutlass::gemm::kernel::Gemm::Params::grid_tiled_shape
cutlass::gemm::GemmCoord grid_tiled_shape
Definition: include/cutlass/gemm/kernel/gemm.h:67
cutlass::gemm::kernel::Gemm::Params::gemm_k_iterations
int gemm_k_iterations
Definition: include/cutlass/gemm/kernel/gemm.h:78
cutlass::gemm::kernel::Gemm::Params::ref_B
Mma::IteratorB::TensorRef ref_B
Definition: include/cutlass/gemm/kernel/gemm.h:71
cutlass::gemm::kernel::Gemm::can_implement
static Status can_implement(cutlass::gemm::GemmCoord const &problem_size, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Epilogue::OutputTileIterator::TensorRef ref_C, typename Epilogue::OutputTileIterator::TensorRef ref_D)
Determines whether kernel satisfies alignment.
Definition: include/cutlass/gemm/kernel/gemm.h:133
cutlass::gemm::kernel::Gemm::Gemm
CUTLASS_HOST_DEVICE Gemm()
Definition: include/cutlass/gemm/kernel/gemm.h:130
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
cutlass::gemm::kernel::Gemm::kSplitKSerial
static bool const kSplitKSerial
Definition: include/cutlass/gemm/kernel/gemm.h:58
cutlass::gemm::kernel::Gemm::OutputOp
typename Epilogue::OutputOp OutputOp
Definition: include/cutlass/gemm/kernel/gemm.h:56
cutlass::gemm::kernel::Gemm::Params
Parameters structure.
Definition: include/cutlass/gemm/kernel/gemm.h:65
cutlass::gemm::kernel::Gemm::Params::output_op
OutputOp::Params output_op
Definition: include/cutlass/gemm/kernel/gemm.h:76
cutlass::Status::kErrorMisalignedOperand
operands fail alignment requirements.
cutlass::gemm::kernel::Gemm::SharedStorage
Shared memory storage structure.
Definition: include/cutlass/gemm/kernel/gemm.h:120
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE constexpr const T & min(const T &a, const T &b)
std::min
Definition: platform.h:183
cutlass::gemm::kernel::Gemm::Params::gemm_k_size
int gemm_k_size
Definition: include/cutlass/gemm/kernel/gemm.h:79
cutlass::gemm::kernel::Gemm::Params::semaphore
int * semaphore
Definition: include/cutlass/gemm/kernel/gemm.h:77
cutlass::gemm::kernel::Gemm::operator()
CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage)
Executes one GEMM.
Definition: include/cutlass/gemm/kernel/gemm.h:172
CTA-wide semaphore for inter-CTA synchronization.
Definition: semaphore.h:48
Implementation of a CTA-wide semaphore for inter-CTA synchronization.
Defines a canonical coordinate for rank=2 matrices offering named indices.
CUTLASS_DEVICE void release(int status=0)
Updates the lock with the given result.
Definition: semaphore.h:98
cutlass::gemm::kernel::Gemm::Params::problem_size
cutlass::gemm::GemmCoord problem_size
Definition: include/cutlass/gemm/kernel/gemm.h:66
cutlass::gemm::kernel::Gemm::ThreadblockSwizzle
ThreadblockSwizzle_ ThreadblockSwizzle
Definition: include/cutlass/gemm/kernel/gemm.h:57
Definition: include/cutlass/gemm/kernel/gemm.h:52
cutlass::gemm::kernel::Gemm::Params::ref_A
Mma::IteratorA::TensorRef ref_A
Definition: include/cutlass/gemm/kernel/gemm.h:69
bool TensorRef_aligned(TensorRef< Element, Layout > const &ref, int alignment)
Definition: tensor_ref.h:382
CUTLASS_DEVICE void wait(int status=0)
Waits until the semaphore is equal to the given value.
Definition: semaphore.h:81
Operation was successful.
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
cutlass::gemm::kernel::Gemm::Mma
Mma_ Mma
Definition: include/cutlass/gemm/kernel/gemm.h:54
cutlass::gemm::kernel::Gemm::WarpCount
typename Mma::WarpCount WarpCount
Warp count (concept: GemmShape)
Definition: include/cutlass/gemm/kernel/gemm.h:61
Basic include for CUTLASS.
Definition: matrix_coord.h:39
cutlass::gemm::kernel::Gemm::Params::Params
CUTLASS_HOST_DEVICE Params()
Definition: include/cutlass/gemm/kernel/gemm.h:86
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
cutlass::gemm::kernel::Gemm::SharedStorage::main_loop
Mma::SharedStorage main_loop
Definition: include/cutlass/gemm/kernel/gemm.h:121
cutlass::gemm::kernel::Gemm::Params::ref_D
Epilogue::OutputTileIterator::TensorRef ref_D
Definition: include/cutlass/gemm/kernel/gemm.h:75
Generated by 1.8.11