docs/kernel_2gemm__batched_8h_source.html
CUTLASS — CUDA Templates for Linear Algebra Subroutines and Solvers
kernel/gemm_batched.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/cutlass.h"
32
33 #include "cutlass/gemm/gemm.h"
34 #include "cutlass/matrix_coord.h"
35
37
38 namespace cutlass {
39 namespace gemm {
40 namespace kernel {
41
43
/// Kernel-level batched GEMM. Computes params.batch_count independent GEMMs that
/// share one problem size; operand pointers for batch b are offset by stride_* * b
/// elements. Each CTA covers one (m, n) threadblock tile and loops over batch
/// indices in steps of gridDim.z to accommodate the limited range of CUDA's grid
/// Z dimension.
template <
  typename Mma_,                  ///< Threadblock-scoped matrix multiply-accumulate (concept: threadblock Mma)
  typename Epilogue_,             ///< Epilogue writing the output tile (concept: Epilogue)
  typename ThreadblockSwizzle_    ///< Threadblock swizzling function
>
struct GemmBatched {

  using Mma = Mma_;
  using Epilogue = Epilogue_;
  using OutputOp = typename Epilogue::OutputOp;
  using ThreadblockSwizzle = ThreadblockSwizzle_;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename Mma::WarpCount;

  /// Threads per CTA: 32 lanes per warp times the Mma's warp count
  static int const kThreadCount = 32 * WarpCount::kCount;

  /// Parameters structure — host-constructed, passed by value to the kernel
  struct Params {
    cutlass::gemm::GemmCoord problem_size;                    ///< GEMM problem dimensions (M, N, K)
    cutlass::gemm::GemmCoord grid_tiled_shape;                ///< grid extent in units of threadblock tiles
    typename Mma::IteratorA::Params params_A;                 ///< precomputed iterator state for operand A
    typename Mma::IteratorA::TensorRef ref_A;                 ///< reference to batch 0 of operand A
    int64_t stride_A;                                         ///< element offset between consecutive A batches
    typename Mma::IteratorB::Params params_B;                 ///< precomputed iterator state for operand B
    typename Mma::IteratorB::TensorRef ref_B;                 ///< reference to batch 0 of operand B
    int64_t stride_B;                                         ///< element offset between consecutive B batches
    typename Epilogue::OutputTileIterator::Params params_C;   ///< precomputed iterator state for source C
    typename Epilogue::OutputTileIterator::TensorRef ref_C;   ///< reference to batch 0 of source C
    int64_t stride_C;                                         ///< element offset between consecutive C batches
    typename Epilogue::OutputTileIterator::Params params_D;   ///< precomputed iterator state for destination D
    typename Epilogue::OutputTileIterator::TensorRef ref_D;   ///< reference to batch 0 of destination D
    int64_t stride_D;                                         ///< element offset between consecutive D batches
    typename OutputOp::Params epilogue;                       ///< epilogue operator parameters
    int batch_count;                                          ///< number of independent GEMMs
    int gemm_k_iterations;                                    ///< mainloop trip count: ceil(K / Mma::Shape::kK)

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params() { }

    /// Constructs Params from problem description and per-operand references/strides.
    /// gemm_k_iterations is derived as ceil(problem_size.k() / Mma::Shape::kK).
    CUTLASS_HOST_DEVICE
    Params(
      cutlass::gemm::GemmCoord const & problem_size_,
      cutlass::gemm::GemmCoord const & grid_tiled_shape_,
      typename Mma::IteratorA::TensorRef ref_A_,
      int64_t stride_A_,
      typename Mma::IteratorB::TensorRef ref_B_,
      int64_t stride_B_,
      typename Epilogue::OutputTileIterator::TensorRef ref_C_,
      int64_t stride_C_,
      typename Epilogue::OutputTileIterator::TensorRef ref_D_,
      int64_t stride_D_,
      typename OutputOp::Params epilogue_,
      int batch_count_
    ):
      problem_size(problem_size_),
      grid_tiled_shape(grid_tiled_shape_),
      params_A(ref_A_.layout()),
      ref_A(ref_A_),
      stride_A(stride_A_),
      params_B(ref_B_.layout()),
      ref_B(ref_B_),
      stride_B(stride_B_),
      params_C(ref_C_.layout()),
      ref_C(ref_C_),
      stride_C(stride_C_),
      params_D(ref_D_.layout()),
      ref_D(ref_D_),
      stride_D(stride_D_),
      epilogue(epilogue_),
      batch_count(batch_count_),
      gemm_k_iterations((problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK) {

    }
  };

  /// Shared memory storage structure. Mainloop and epilogue phases do not run
  /// concurrently, so their shared-memory footprints are aliased in a union.
  union SharedStorage {
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  };

  //
  // Methods
  //

  CUTLASS_HOST_DEVICE
  GemmBatched() { }

  /// Executes one batched GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {

    // Compute threadblock location
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

    // Early exit if CTA is out of range
    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
      params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

      return;
    }

    // Each CTA handles multiple batch indices to accommodate limited range of CUDA grid's Z dimension
    for (int batch_idx = threadblock_swizzle.get_batch_idx();
      batch_idx < params.batch_count;
      batch_idx += gridDim.z) {

      // Compute initial location in logical coordinates
      cutlass::MatrixCoord tb_offset_A{
        threadblock_tile_offset.m() * Mma::Shape::kM,
        0
      };

      cutlass::MatrixCoord tb_offset_B{
        0,
        threadblock_tile_offset.n() * Mma::Shape::kN
      };

      // Compute position within threadblock
      int thread_idx = threadIdx.x;

      // Construct iterators to A and B operands
      typename Mma::IteratorA iterator_A(
        params.params_A,
        params.ref_A.data(),
        params.problem_size.mk(),
        thread_idx,
        tb_offset_A);

      // Advance to this CTA's batch
      iterator_A.add_pointer_offset(params.stride_A * batch_idx);

      typename Mma::IteratorB iterator_B(
        params.params_B,
        params.ref_B.data(),
        params.problem_size.kn(),
        thread_idx,
        tb_offset_B);

      iterator_B.add_pointer_offset(params.stride_B * batch_idx);

      //
      // Main loop
      //

      // Construct thread-scoped matrix multiply
      int warp_idx = threadIdx.x / 32;
      int lane_idx = threadIdx.x % 32;

      Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

      typename Mma::FragmentC accumulators;

      accumulators.clear();

      // Compute threadblock-scoped matrix multiply-add
      mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

      //
      // Epilogue
      //

      OutputOp output_op(params.epilogue);

      //
      // Masked tile iterators constructed from members
      //

      threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

      // assume identity swizzle
      MatrixCoord threadblock_offset(
        threadblock_tile_offset.m() * Mma::Shape::kM,
        threadblock_tile_offset.n() * Mma::Shape::kN
      );

      // Tile iterator reading the source tile (C)
      typename Epilogue::OutputTileIterator iterator_C(
        params.params_C,
        params.ref_C.data(),
        params.problem_size.mn(),
        thread_idx,
        threadblock_offset
      );

      iterator_C.add_pointer_offset(params.stride_C * batch_idx);

      // Tile iterator writing to output tile
      typename Epilogue::OutputTileIterator iterator_D(
        params.params_D,
        params.ref_D.data(),
        params.problem_size.mn(),
        thread_idx,
        threadblock_offset
      );

      iterator_D.add_pointer_offset(params.stride_D * batch_idx);

      Epilogue epilogue(
        shared_storage.epilogue,
        thread_idx,
        warp_idx,
        lane_idx);

      // run efficient epilogue
      epilogue(output_op, iterator_D, accumulators, iterator_C);
    }
  }
};
261
263
264 } // namespace kernel
265 } // namespace gemm
266 } // namespace cutlass
267
cutlass::gemm::kernel::GemmBatched::operator()
CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage)
Executes one GEMM.
Definition: kernel/gemm_batched.h:138
Definition: aligned_buffer.h:35
cutlass::gemm::kernel::GemmBatched::Params::Params
CUTLASS_HOST_DEVICE Params()
Definition: kernel/gemm_batched.h:85
cutlass::gemm::kernel::GemmBatched::OutputOp
typename Epilogue::OutputOp OutputOp
Definition: kernel/gemm_batched.h:53
cutlass::gemm::kernel::GemmBatched::Params::ref_D
Epilogue::OutputTileIterator::TensorRef ref_D
Definition: kernel/gemm_batched.h:74
Definition: include/cutlass/gemm/gemm.h:94
CUTLASS_HOST_DEVICE Coord< 2 > mn() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:171
Defines common types used for all GEMM-like operators.
cutlass::gemm::kernel::GemmBatched::Params::ref_B
Mma::IteratorB::TensorRef ref_B
Definition: kernel/gemm_batched.h:68
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
cutlass::gemm::kernel::GemmBatched::Params::gemm_k_iterations
int gemm_k_iterations
Definition: kernel/gemm_batched.h:78
cutlass::gemm::kernel::GemmBatched::Params::ref_C
Epilogue::OutputTileIterator::TensorRef ref_C
Definition: kernel/gemm_batched.h:71
cutlass::gemm::kernel::GemmBatched::GemmBatched
CUTLASS_HOST_DEVICE GemmBatched()
Definition: kernel/gemm_batched.h:134
cutlass::gemm::kernel::GemmBatched::Epilogue
Epilogue_ Epilogue
Definition: kernel/gemm_batched.h:52
cutlass::gemm::kernel::GemmBatched::SharedStorage
Shared memory storage structure.
Definition: kernel/gemm_batched.h:124
cutlass::gemm::kernel::GemmBatched::Params::grid_tiled_shape
cutlass::gemm::GemmCoord grid_tiled_shape
Definition: kernel/gemm_batched.h:63
cutlass::gemm::kernel::GemmBatched::SharedStorage::main_loop
Mma::SharedStorage main_loop
Definition: kernel/gemm_batched.h:125
cutlass::gemm::kernel::GemmBatched::kThreadCount
static int const kThreadCount
Definition: kernel/gemm_batched.h:58
cutlass::gemm::kernel::GemmBatched::Params
Parameters structure.
Definition: kernel/gemm_batched.h:61
cutlass::gemm::kernel::GemmBatched::WarpCount
typename Mma::WarpCount WarpCount
Warp count (concept: GemmShape)
Definition: kernel/gemm_batched.h:57
cutlass::gemm::kernel::GemmBatched::Params::params_D
Epilogue::OutputTileIterator::Params params_D
Definition: kernel/gemm_batched.h:73
cutlass::gemm::kernel::GemmBatched::Params::params_C
Epilogue::OutputTileIterator::Params params_C
Definition: kernel/gemm_batched.h:70
cutlass::gemm::kernel::GemmBatched::Params::epilogue
OutputOp::Params epilogue
Definition: kernel/gemm_batched.h:76
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
cutlass::gemm::kernel::GemmBatched::Params::stride_C
int64_t stride_C
Definition: kernel/gemm_batched.h:72
CUTLASS_HOST_DEVICE Coord< 2 > mk() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:177
cutlass::gemm::kernel::GemmBatched::Mma
Mma_ Mma
Definition: kernel/gemm_batched.h:51
cutlass::gemm::kernel::GemmBatched::Params::problem_size
cutlass::gemm::GemmCoord problem_size
Definition: kernel/gemm_batched.h:62
cutlass::gemm::kernel::GemmBatched::Params::params_A
Mma::IteratorA::Params params_A
Definition: kernel/gemm_batched.h:64
Defines a canonical coordinate for rank=2 matrices offering named indices.
cutlass::gemm::kernel::GemmBatched::Params::batch_count
int batch_count
Definition: kernel/gemm_batched.h:77
cutlass::gemm::kernel::GemmBatched::Params::params_B
Mma::IteratorB::Params params_B
Definition: kernel/gemm_batched.h:67
CUTLASS_HOST_DEVICE Coord< 2 > kn() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:195
cutlass::gemm::kernel::GemmBatched::Params::stride_B
int64_t stride_B
Definition: kernel/gemm_batched.h:69
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
cutlass::gemm::kernel::GemmBatched::Params::ref_A
Mma::IteratorA::TensorRef ref_A
Definition: kernel/gemm_batched.h:65
cutlass::gemm::kernel::GemmBatched::Params::Params
CUTLASS_HOST_DEVICE Params(cutlass::gemm::GemmCoord const &problem_size_, cutlass::gemm::GemmCoord const &grid_tiled_shape_, typename Mma::IteratorA::TensorRef ref_A_, int64_t stride_A_, typename Mma::IteratorB::TensorRef ref_B_, int64_t stride_B_, typename Epilogue::OutputTileIterator::TensorRef ref_C_, int64_t stride_C_, typename Epilogue::OutputTileIterator::TensorRef ref_D_, int64_t stride_D_, typename OutputOp::Params epilogue_, int batch_count_)
Definition: kernel/gemm_batched.h:88
cutlass::gemm::kernel::GemmBatched::Params::stride_A
int64_t stride_A
Definition: kernel/gemm_batched.h:66
cutlass::gemm::kernel::GemmBatched
Definition: kernel/gemm_batched.h:49
cutlass::gemm::kernel::GemmBatched::SharedStorage::epilogue
Epilogue::SharedStorage epilogue
Definition: kernel/gemm_batched.h:126
cutlass::gemm::kernel::GemmBatched::Params::stride_D
int64_t stride_D
Definition: kernel/gemm_batched.h:75
cutlass::gemm::kernel::GemmBatched::ThreadblockSwizzle
ThreadblockSwizzle_ ThreadblockSwizzle
Definition: kernel/gemm_batched.h:54
Basic include for CUTLASS.
Definition: matrix_coord.h:39
Generated by 1.8.11