docs/epilogue_8h_source.html
CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
epilogue.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
33 #pragma once
34
35 #include <assert.h>
36
37 #include "cutlass/cutlass.h"
38 #include "cutlass/numeric_types.h"
39 #include "cutlass/array.h"
40 #include "cutlass/layout/vector.h"
41 #include "cutlass/layout/tensor.h"
42 #include "cutlass/tensor_coord.h"
43 #include "cutlass/aligned_buffer.h"
44 #include "cutlass/functional.h"
45
46 #include "cutlass/gemm/gemm.h"
47
48 #include "[cutlass/transform/pitch_linear_thread_map.h](pitch linear thread__map_8h.html)"
49 #include "[cutlass/transform/threadblock/regular_tile_iterator.h](regular tile iterator_8h.html)"
50
51 #include "cutlass/epilogue/threadblock/epilogue_base.h"
52 #include "[cutlass/epilogue/threadblock/predicated_tile_iterator.h](epilogue_2threadblock_2predicated tile iterator_8h.html)"
53
55
56 namespace cutlass {
57 namespace epilogue {
58 namespace threadblock {
59
61
63 template <
64typename Shape_,
65typename WarpMmaOperator_,
66int PartitionsK,
67typename OutputTileIterator_,
68typename AccumulatorFragmentIterator_,
69typename WarpTileIterator_,
70typename SharedLoadIterator_,
71typename OutputOp_,
72typename Padding_
73 >
75public EpilogueBase<
76 Shape_,
77 WarpMmaOperator_,
78 PartitionsK,
79 AccumulatorFragmentIterator_,
80 WarpTileIterator_,
81 Padding_> {
82
83 public:
84
85using Base = EpilogueBase<
86 Shape_,
87 WarpMmaOperator_,
88 PartitionsK,
89 AccumulatorFragmentIterator_,
90 WarpTileIterator_,
91 Padding_>;
92
94using WarpMmaOperator = WarpMmaOperator_;
95static int const kPartitionsK = PartitionsK;
96using OutputTileIterator = OutputTileIterator_;
97using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
98using WarpTileIterator = WarpTileIterator_;
99using SharedLoadIterator = SharedLoadIterator_;
100using OutputOp = OutputOp_;
102
104using Layout = layout::RowMajor;
105using LongIndex = typename Layout::LongIndex;
106
108using AccumulatorTile = typename Base::AccumulatorTile;
109
111using ElementAccumulator = typename WarpTileIterator::Element;
112
113
115using ElementOutput = typename OutputTileIterator::Element;
116
118static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
119
121using TensorRef = typename OutputTileIterator::TensorRef;
122
124using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
125
127using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
128
130using OutputAccessType = Array<
131typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
132
134using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
135
137using WarpCount = typename Base::WarpCount;
138
139 public:
140
141
142static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
143"Mismatch between shared load iterator and output tile iterator.");
144
145static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");
146
147static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
148"Divisibility");
149
150 private:
151
153SharedLoadIterator shared_load_iterator_;
154
155 public:
156
158 CUTLASS_DEVICE
160typename Base::SharedStorage &shared_storage,
161int thread_idx,
162int warp_idx,
163int lane_idx
164 ):
165Base(shared_storage, thread_idx, warp_idx, lane_idx),
166 shared_load_iterator_(shared_storage.reference(), thread_idx) { }
167
169 CUTLASS_DEVICE
170void operator()(
171OutputOp const &output_op,
172OutputTileIterator destination_iterator,
173AccumulatorTile const &accumulators,
174OutputTileIterator source_iterator) {
175
176
177typename OutputTileIterator::Fragment source_fragment;
178
179if (!output_op.is_source_needed()) {
180 source_iterator.clear_mask();
181 }
182
183 source_fragment.clear();
184
185//
186// Iterator over warp-level accumulator fragment
187//
188
189AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
190
191//
192// Iterate over accumulator tile
193//
194
196for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
197
198//
199// Load the source
200//
201
202 source_iterator.load(source_fragment);
203 ++source_iterator;
204
205//
206// Convert and store fragment
207//
208
209 __syncthreads();
210
211typename AccumulatorFragmentIterator::Fragment accum_fragment;
212
213 accum_fragment_iterator.load(accum_fragment);
214 ++accum_fragment_iterator;
215
216 this->warp_tile_iterator_.store(accum_fragment);
217
218 __syncthreads();
219
220//
221// Load fragments from shared memory
222//
223
224typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];
225
226 shared_load_iterator_.load(aligned_accum_fragment[0]);
227
228// If the number of k-slices is > 1 - perform a reduction amongst the k-slices
229if (kPartitionsK > 1)
230 {
231plus <typename SharedLoadIterator::Fragment> add_fragments;
232const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;
233
235for ( int i = 1; i < kPartitionsK; ++i) {
236 shared_load_iterator_.add_tile_offset({tile_row_offset , 0});
237 shared_load_iterator_.load(aligned_accum_fragment[i]);
238 aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
239 }
240
241 shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
242 }
243
244//
245// Compute the output result
246//
247
248typename OutputTileIterator::Fragment output_fragment;
249
250 apply_output_operator_(output_fragment, output_op, aligned_accum_fragment[0], source_fragment);
251
252
253//
254// Store the final result
255//
256
257 destination_iterator.store(output_fragment);
258 ++destination_iterator;
259
260 }
261 }
262
263 private:
264
266 CUTLASS_DEVICE
267void apply_output_operator_(
268typename OutputTileIterator::Fragment &output_fragment,
269OutputOp const &output_op,
270typename SharedLoadIterator::Fragment const &aligned_accum_fragment,
271typename OutputTileIterator::Fragment const &source_fragment) {
272
273OutputAccessType *output_frag_ptr =
274reinterpret_cast<OutputAccessType *>(&output_fragment);
275
276AccumulatorAccessType const *compute_frag_ptr =
277reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);
278
279OutputAccessType const *source_frag_ptr =
280reinterpret_cast<OutputAccessType const *>(&source_fragment);
281
282int const kOutputOpIterations =
283 OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
284
286for (int i = 0; i < kOutputOpIterations; ++i) {
287
288// Call the output operator
289 output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
290 }
291 }
292 };
293
295
296 } // namespace threadblock
297 } // namespace epilogue
298 } // namespace cutlass
299
cutlass::layout::RowMajor::LongIndex
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:62
cutlass::epilogue::threadblock::Epilogue::LongIndex
typename Layout::LongIndex LongIndex
Definition: epilogue.h:105
cutlass::epilogue::threadblock::Epilogue::WarpCount
typename Base::WarpCount WarpCount
Number of warps.
Definition: epilogue.h:137
cutlass::epilogue::threadblock::SharedLoadIterator::Fragment
Array< Element, ThreadMap::Iterations::kColumn *ThreadMap::Iterations::kRow *ThreadMap::Iterations::kGroup *ThreadMap::Iterations::kCluster *ThreadMap::kElementsPerAccess > Fragment
Fragment object.
Definition: shared_load_iterator.h:91
Definition: aligned_buffer.h:35
cutlass::epilogue::threadblock::EpilogueBase::warp_tile_iterator_
WarpTileIterator warp_tile_iterator_
Stores a warp's fragment of accumulators to SMEM.
Definition: epilogue_base.h:176
[pitch_linear_thread_map.h](pitch__linear__thread__map_8h.html)
Templates implementing how threads are mapped to a given tile.
cutlass::epilogue::threadblock::EpilogueBase::SharedStorage
Shared storage allocation needed by the epilogue.
Definition: epilogue_base.h:97
cutlass::epilogue::threadblock::Epilogue::operator()
CUTLASS_DEVICE void operator()(OutputOp const &output_op, OutputTileIterator destination_iterator, AccumulatorTile const &accumulators, OutputTileIterator source_iterator)
Streams the result to global memory.
Definition: epilogue.h:170
cutlass::epilogue::threadblock::Epilogue::OutputTileIterator
OutputTileIterator_ OutputTileIterator
Definition: epilogue.h:96
[predicated_tile_iterator.h](epilogue_2threadblock_2predicated__tile__iterator_8h.html)
Epilogue for threadblock-scoped GEMMs using Tensor Ops.
Defines common types used for all GEMM-like operators.
cutlass::epilogue::threadblock::Epilogue::Epilogue
CUTLASS_DEVICE Epilogue(typename Base::SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Constructor.
Definition: epilogue.h:159
cutlass::epilogue::threadblock::Epilogue::Shape
Shape_ Shape
Definition: epilogue.h:93
cutlass::epilogue::threadblock::Epilogue::TensorRef
typename OutputTileIterator::TensorRef TensorRef
Tensor reference to destination tensor.
Definition: epilogue.h:121
cutlass::epilogue::threadblock::EpilogueBase::WarpCount
gemm::GemmShape< Shape::kM/WarpMmaOperator::Shape::kM, Shape::kN/WarpMmaOperator::Shape::kN, kPartitionsK > WarpCount
Number of warps.
Definition: epilogue_base.h:92
Definition: functional.h:46
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...
cutlass::epilogue::threadblock::Epilogue::kPartitionsK
static int const kPartitionsK
Definition: epilogue.h:95
cutlass::epilogue::threadblock::Epilogue::OutputOp
OutputOp_ OutputOp
Definition: epilogue.h:100
Definition: tensor_ref.h:146
cutlass::epilogue::threadblock::Epilogue::Padding
Padding_ Padding
Definition: epilogue.h:101
Defines a canonical coordinate for rank=4 tensors offering named indices.
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared memory.
cutlass::epilogue::threadblock::Epilogue::AccumulatorFragmentIterator
AccumulatorFragmentIterator_ AccumulatorFragmentIterator
Definition: epilogue.h:97
Top-level include for all CUTLASS numeric types.
#define static_assert(__e, __m)
Definition: platform.h:153
cutlass::epilogue::threadblock::Epilogue::ConstTensorRef
typename OutputTileIterator::ConstTensorRef ConstTensorRef
Const tensor reference to source tensor.
Definition: epilogue.h:127
cutlass::epilogue::threadblock::Epilogue::WarpTileIterator
WarpTileIterator_ WarpTileIterator
Definition: epilogue.h:98
cutlass::epilogue::threadblock::Epilogue::SharedLoadIterator
SharedLoadIterator_ SharedLoadIterator
Definition: epilogue.h:99
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
cutlass::epilogue::threadblock::Epilogue
Epilogue operator without split-K.
Definition: epilogue.h:74
cutlass::epilogue::threadblock::Epilogue::ElementAccumulator
typename WarpTileIterator::Element ElementAccumulator
Accumulator element.
Definition: epilogue.h:111
Defines layout functions used for rank=1 vectors.
[regular_tile_iterator.h](regular__tile__iterator_8h.html)
Templates implementing storing of tiles from pitch-linear rank=2 tensors.
Epilogue for threadblock-scoped GEMMs using Tensor Ops.
cutlass::epilogue::threadblock::EpilogueBase
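Base class for epilogues, defining the warp-level types and state they share (warp count, accumulator tile, warp tile iterator, and shared storage).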
Definition: epilogue_base.h:67
cutlass::epilogue::threadblock::Epilogue::WarpMmaOperator
WarpMmaOperator_ WarpMmaOperator
Definition: epilogue.h:94
cutlass::epilogue::threadblock::Epilogue::AccumulatorTile
typename Base::AccumulatorTile AccumulatorTile
The complete warp-level accumulator tile.
Definition: epilogue.h:108
cutlass::epilogue::threadblock::Epilogue::OutputAccessType
Array< typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess > OutputAccessType
Array type used to output.
Definition: epilogue.h:131
cutlass::epilogue::threadblock::Epilogue::kElementsPerAccess
static int const kElementsPerAccess
Output access size.
Definition: epilogue.h:118
cutlass::epilogue::threadblock::EpilogueBase::AccumulatorTile
typename AccumulatorFragmentIterator::AccumulatorTile AccumulatorTile
The complete warp-level accumulator tile.
Definition: epilogue_base.h:81
cutlass::epilogue::threadblock::Epilogue::ElementOutput
typename OutputTileIterator::Element ElementOutput
Output element.
Definition: epilogue.h:115
Basic include for CUTLASS.
cutlass::epilogue::threadblock::Epilogue::SyncTensorRef
typename cutlass::TensorRef< int, cutlass::layout::PackedVectorLayout > SyncTensorRef
Tensor reference to sync tensor.
Definition: epilogue.h:124
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
cutlass::epilogue::threadblock::Epilogue::AccumulatorAccessType
Array< typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess > AccumulatorAccessType
Array type used by output functor.
Definition: epilogue.h:134
Generated by Doxygen 1.8.11