docs/interleaved__epilogue_8h_source.html
CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
interleaved_epilogue.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
33 #pragma once
34
#include <assert.h>

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/layout/vector.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/aligned_buffer.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"

#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
52
54
55 namespace cutlass {
56 namespace epilogue {
57 namespace threadblock {
58
60
62 template <
64typename Shape_,
66typename WarpMmaOperator_,
68int PartitionsK,
70typename OutputTileIterator_,
72typename AccumulatorFragmentIterator_,
74typename OutputOp_,
76int InterleavedK,
78bool IsBetaZero = false>
79 class InterleavedEpilogue {
80public:
82using WarpMmaOperator = WarpMmaOperator_;
83static int const kPartitionsK = PartitionsK;
84using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
85using OutputTileIterator = OutputTileIterator_;
87
89using Layout = layout::ColumnMajorInterleaved<InterleavedK>;
90
92using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
93
95using ElementAccumulator = typename AccumulatorTile::Element;
96
98using ElementOutput = typename OutputTileIterator::Element;
99
101static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
102
104using TensorRef = typename OutputTileIterator::TensorRef;
105
107using SyncTensorRef =
108typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
109
111using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
112
114using OutputAccessType = Array<typename OutputTileIterator::Element,
115 OutputTileIterator::kElementsPerAccess>;
116
118using AccumulatorAccessType =
119 Array<ElementAccumulator, OutputTileIterator::kElementsPerAccess>;
120
122using WarpCount =
123gemm::GemmShape<Shape::kM / WarpMmaOperator::Shape::kM,
124 Shape::kN / WarpMmaOperator::Shape::kN, kPartitionsK>;
125
126public:
127static_assert(OutputTileIterator::kElementsPerAccess,
128"This must not be zero.");
129
130static_assert(!(OutputTileIterator::Fragment::kElements %
131 OutputTileIterator::kElementsPerAccess),
132"Divisibility");
133
135struct SharedStorage {};
136
137
138public:
140 CUTLASS_DEVICE
142SharedStorage &shared_storage,
143int thread_idx,
144int warp_idx,
145int lane_idx
146 ) {}
147
149 CUTLASS_DEVICE
150void operator()(
151OutputOp const &output_op,
152OutputTileIterator destination_iterator,
153AccumulatorTile const &accumulators,
154OutputTileIterator source_iterator) {
155
156//
157// Predicated tile iterators constructed from members
158//
159
160if (IsBetaZero && output_op.is_source_needed())
161 assert(0);
162
163typename OutputTileIterator::Fragment source_fragment;
164
165if (!IsBetaZero) {
166if (!output_op.is_source_needed()) {
167 source_iterator.clear_mask();
168 }
169 }
170
171 source_fragment.clear();
172
173//
174// Iterator over warp-level accumulator fragment
175//
176
177AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
178
179//
180// Iterate over accumulator tile
181//
182
184for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
185//
186// Load the source
187//
188
189if (!IsBetaZero) {
190 source_iterator.set_iteration_index(iter);
191 source_iterator.load(source_fragment);
192 ++source_iterator;
193 }
194
195//
196// Convert fragment
197//
198
199typename AccumulatorFragmentIterator::Fragment accum_fragment;
200
201 accum_fragment_iterator.load(accum_fragment);
202 ++accum_fragment_iterator;
203
204//
205// Compute the output result
206//
207
208typename OutputTileIterator::Fragment output_fragment;
209 apply_output_operator_(output_op, output_fragment, accum_fragment, source_fragment);
210
211//
212// Store the final result
213//
214
215 destination_iterator.set_iteration_index(iter);
216 destination_iterator.store(output_fragment);
217 ++destination_iterator;
218 }
219 }
220
221private:
223 CUTLASS_DEVICE
224void apply_output_operator_(
225OutputOp const &output_op,
226typename OutputTileIterator::Fragment &output_fragment,
227typename AccumulatorFragmentIterator::Fragment const
228 &aligned_accum_fragment,
229typename OutputTileIterator::Fragment const &source_fragment) {
230OutputAccessType *output_frag_ptr =
231reinterpret_cast<OutputAccessType *>(&output_fragment);
232
233AccumulatorAccessType const *compute_frag_ptr =
234reinterpret_cast<AccumulatorAccessType const *>(
235 &aligned_accum_fragment);
236
237OutputAccessType const *source_frag_ptr =
238reinterpret_cast<OutputAccessType const *>(&source_fragment);
239
240int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
241 OutputTileIterator::kElementsPerAccess;
242
244for (int i = 0; i < kOutputOpIterations; ++i) {
245// Call the output operator
246 output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
247 }
248 }
249 };
250
252
253 } // namespace threadblock
254 } // namespace epilogue
255 } // namespace cutlass
256
cutlass::epilogue::threadblock::InterleavedEpilogue::Shape
Shape_ Shape
Definition: interleaved_epilogue.h:81
Definition: aligned_buffer.h:35
cutlass::epilogue::threadblock::InterleavedEpilogue::ElementAccumulator
typename AccumulatorTile::Element ElementAccumulator
Accumulator element.
Definition: interleaved_epilogue.h:95
[pitch_linear_thread_map.h](pitch__linear__thread__map_8h.html)
Templates implementing how threads are mapped to a given tile.
cutlass::epilogue::threadblock::InterleavedEpilogue::InterleavedEpilogue
CUTLASS_DEVICE InterleavedEpilogue(SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Constructor.
Definition: interleaved_epilogue.h:141
cutlass::epilogue::threadblock::InterleavedEpilogue::AccumulatorTile
typename AccumulatorFragmentIterator::AccumulatorTile AccumulatorTile
The complete warp-level accumulator tile.
Definition: interleaved_epilogue.h:92
[predicated_tile_iterator.h](epilogue_2threadblock_2predicated__tile__iterator_8h.html)
Epilogue for threadblock scoped GEMMs using Tensor Ops.
cutlass::epilogue::threadblock::InterleavedEpilogue::AccumulatorAccessType
Array< ElementAccumulator, OutputTileIterator::kElementsPerAccess > AccumulatorAccessType
Array type used by output functor.
Definition: interleaved_epilogue.h:119
cutlass::epilogue::threadblock::InterleavedEpilogue
Epilogue operator without splitk.
Definition: interleaved_epilogue.h:79
cutlass::epilogue::threadblock::InterleavedEpilogue::OutputAccessType
Array< typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess > OutputAccessType
Array type used to output.
Definition: interleaved_epilogue.h:115
cutlass::epilogue::threadblock::InterleavedEpilogue::OutputOp
OutputOp_ OutputOp
Definition: interleaved_epilogue.h:86
Defines common types used for all GEMM-like operators.
cutlass::epilogue::threadblock::InterleavedEpilogue::ConstTensorRef
typename OutputTileIterator::ConstTensorRef ConstTensorRef
Const tensor reference to source tensor.
Definition: interleaved_epilogue.h:111
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...
cutlass::epilogue::threadblock::InterleavedEpilogue::TensorRef
typename OutputTileIterator::TensorRef TensorRef
Tensor reference to destination tensor.
Definition: interleaved_epilogue.h:104
cutlass::epilogue::threadblock::InterleavedEpilogue::SharedStorage
Shared storage allocation needed by the epilogue.
Definition: interleaved_epilogue.h:135
cutlass::epilogue::threadblock::InterleavedEpilogue::WarpMmaOperator
WarpMmaOperator_ WarpMmaOperator
Definition: interleaved_epilogue.h:82
cutlass::epilogue::threadblock::InterleavedEpilogue::SyncTensorRef
typename cutlass::TensorRef< int, cutlass::layout::PackedVectorLayout > SyncTensorRef
Tensor reference to sync tensor.
Definition: interleaved_epilogue.h:108
Definition: tensor_ref.h:146
Defines a canonical coordinate for rank=4 tensors offering named indices.
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...
cutlass::epilogue::threadblock::InterleavedEpilogue::OutputTileIterator
OutputTileIterator_ OutputTileIterator
Definition: interleaved_epilogue.h:85
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
#define static_assert(__e, __m)
Definition: platform.h:153
cutlass::epilogue::threadblock::InterleavedEpilogue::operator()
CUTLASS_DEVICE void operator()(OutputOp const &output_op, OutputTileIterator destination_iterator, AccumulatorTile const &accumulators, OutputTileIterator source_iterator)
Streams the result to global memory.
Definition: interleaved_epilogue.h:150
cutlass::epilogue::threadblock::InterleavedEpilogue::kElementsPerAccess
static int const kElementsPerAccess
Output access size.
Definition: interleaved_epilogue.h:101
Defines layout functions used for rank=1 vectors.
[regular_tile_iterator.h](regular__tile__iterator_8h.html)
Templates implementing storing of tiles from pitch-linear rank=2 tensors.
Epilogue for threadblock scoped GEMMs using Tensor Ops.
cutlass::layout::ColumnMajorInterleaved
Definition: layout/matrix.h:343
cutlass::epilogue::threadblock::InterleavedEpilogue::kPartitionsK
static int const kPartitionsK
Definition: interleaved_epilogue.h:83
cutlass::epilogue::threadblock::InterleavedEpilogue::AccumulatorFragmentIterator
AccumulatorFragmentIterator_ AccumulatorFragmentIterator
Definition: interleaved_epilogue.h:84
cutlass::epilogue::threadblock::InterleavedEpilogue::ElementOutput
typename OutputTileIterator::Element ElementOutput
Output element.
Definition: interleaved_epilogue.h:98
Basic include for CUTLASS.
Generated by 1.8.11