docs/linear__combination__clamp_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
linear_combination_clamp.h
[Go to the documentation of this file.](linear__combination__clamp_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
30 #pragma once
31
32 #include "cutlass/cutlass.h"
33 #include "cutlass/numeric_types.h"
34 #include "cutlass/array.h"
35 #include "cutlass/functional.h"
36 #include "cutlass/numeric_conversion.h"
37
39
40 namespace cutlass {
41 namespace epilogue {
42 namespace thread {
43
45
51 template <
52typename ElementOutput_,
53int Count,
54typename ElementAccumulator_ = ElementOutput_,
55typename ElementCompute_ = ElementOutput_,
56FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
57 >
58 class LinearCombinationClamp {
59 public:
60
61using ElementOutput = ElementOutput_;
62using ElementAccumulator = ElementAccumulator_;
63using ElementCompute = ElementCompute_;
64
65static int const kCount = Count;
66
67using FragmentOutput = Array<ElementOutput, kCount>;
68using FragmentAccumulator = Array<ElementAccumulator, kCount>;
69using ComputeFragment = Array<ElementCompute, kCount>;
70
71static FloatRoundStyle const kRound = Round;
72
75
78ElementCompute const *alpha_ptr;
79ElementCompute const *beta_ptr;
80
81//
82// Methods
83//
84
87 alpha(ElementCompute(1)),
88 beta(ElementCompute(0)),
89 alpha_ptr(nullptr),
90 beta_ptr(nullptr) { }
91
94ElementCompute alpha,
95ElementCompute beta
96 ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
97
98 }
99
102ElementCompute const *alpha_ptr,
103ElementCompute const *beta_ptr
104 ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
105
106 }
107 };
108
109 private:
110
111//
112// Data members
113//
114
115ElementCompute alpha_;
116ElementCompute beta_;
117
118 public:
119
122LinearCombinationClamp(Params const &params) {
123
124 alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
125 beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
126 }
127
130bool is_source_needed() const {
131return beta_ != ElementCompute(0);
132 }
133
136void set_k_partition(int k_partition) {
137if (k_partition) {
138 beta_ = ElementCompute(1);
139 }
140 }
141
145FragmentAccumulator const &accumulator,
146FragmentOutput const &source,
147ElementCompute uniform = ElementCompute(0)) const {
148
149// Convert source to internal compute numeric type
150NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
151NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
152
153ComputeFragment converted_source = source_converter(source);
154ComputeFragment converted_accumulator = accumulator_converter(accumulator);
155
156// Perform binary operations
157
158ComputeFragment intermediate;
159
160multiplies<ComputeFragment> mul_add_source;
161multiply_add<ComputeFragment> mul_add_accumulator;
162
163minimum<ComputeFragment> min_accumulator;
164maximum<ComputeFragment> max_accumulator;
165
166 intermediate = mul_add_source(beta_, converted_source); // X = beta * C (uniform is not applied here)
167 intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
168
170ElementCompute const kClamp = ElementCompute(1 << (sizeof_bits<ElementOutput>::value - 1));
171
172 intermediate = max_accumulator(intermediate, -kClamp);
173 intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1));
174
175// Convert to destination numeric type
176NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
177
178return destination_converter(intermediate);
179 }
180
181 };
182
184
185 // Conditional guards to enable partial specialization for packed integers
186 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)
187
193 template <
194typename ElementOutput_,
195int Count,
196FloatRoundStyle Round
197 >
198 class LinearCombinationClamp<ElementOutput_, Count, int, float, Round> {
199 public:
200
201using ElementOutput = ElementOutput_;
202using ElementAccumulator = int;
203using ElementCompute = float;
204
205static int const kCount = Count;
206
207using FragmentOutput = Array<ElementOutput, kCount>;
208using FragmentAccumulator = Array<ElementAccumulator, kCount>;
209using ComputeFragment = Array<ElementCompute, kCount>;
210
211static FloatRoundStyle const kRound = Round;
212
214struct Params {
215
216ElementCompute alpha;
217ElementCompute beta;
218ElementCompute const *alpha_ptr;
219ElementCompute const *beta_ptr;
220
221//
222// Methods
223//
224
226Params():
227alpha(ElementCompute(1)),
228beta(ElementCompute(0)),
229alpha_ptr(nullptr),
230beta_ptr(nullptr) { }
231
233Params(
234ElementCompute alpha,
235ElementCompute beta
236 ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
237
238 }
239
241Params(
242ElementCompute const *alpha_ptr,
243ElementCompute const *beta_ptr
244 ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
245
246 }
247 };
248
249 private:
250
251//
252// Data members
253//
254
255ElementCompute alpha_;
256ElementCompute beta_;
257
258 public:
259
262LinearCombinationClamp(Params const &params) {
263
264 alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
265 beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
266 }
267
270bool is_source_needed() const {
271return beta_ != ElementCompute(0);
272 }
273
276void set_k_partition(int k_partition) {
277if (k_partition) {
278 beta_ = ElementCompute(1);
279 }
280 }
281
285FragmentAccumulator const &accumulator,
286FragmentOutput const &source,
287ElementCompute uniform = ElementCompute(0)) const {
288
289// Convert source to internal compute numeric type
290NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
291NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
292
293ComputeFragment converted_source = source_converter(source);
294ComputeFragment converted_accumulator = accumulator_converter(accumulator);
295
296// Compute linear scaling in floating point
297ComputeFragment intermediate;
298
299multiplies<ComputeFragment> mul_add_source;
300multiply_add<ComputeFragment> mul_add_accumulator;
301
302// Float min-max
303 intermediate = mul_add_source(beta_, converted_source); // X = beta * C (uniform is not applied here)
304 intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
305
306// Convert floats back to INT
307FragmentAccumulator scaled_accumulator;
308
310for (int i = 0; i < kCount; ++i) {
311 scaled_accumulator[i] = static_cast<int>(intermediate[i]);
312 }
313
314// Convert to destination numeric type
315NumericArrayConverter<ElementOutput, int, kCount, Round> destination_converter;
316
317return destination_converter(scaled_accumulator);
318 }
319 };
320
321 #endif // Conditional guards to enable partial specialization for packed integers
322
324
325 } // namespace thread
326 } // namespace epilogue
327 } // namespace cutlass
Fused multiply-add.
Definition: functional.h:92
cutlass::epilogue::thread::LinearCombinationClamp::ElementCompute
ElementCompute_ ElementCompute
Definition: linear_combination_clamp.h:63
Definition: aligned_buffer.h:35
cutlass::epilogue::thread::LinearCombinationClamp::Params::beta
ElementCompute beta
scales source tensor
Definition: linear_combination_clamp.h:77
cutlass::epilogue::thread::LinearCombinationClamp::Params::Params
CUTLASS_HOST_DEVICE Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
Definition: linear_combination_clamp.h:101
cutlass::epilogue::thread::LinearCombinationClamp::Params::Params
CUTLASS_HOST_DEVICE Params(ElementCompute alpha, ElementCompute beta)
Definition: linear_combination_clamp.h:93
cutlass::epilogue::thread::LinearCombinationClamp
Definition: linear_combination_clamp.h:58
Definition: functional.h:298
Definition: functional.h:235
cutlass::epilogue::thread::LinearCombinationClamp::kCount
static int const kCount
Definition: linear_combination_clamp.h:65
cutlass::epilogue::thread::LinearCombinationClamp::operator()
CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source, ElementCompute uniform=ElementCompute(0)) const
Computes linear scaling: D = alpha * accumulator + beta * source.
Definition: linear_combination_clamp.h:144
cutlass::epilogue::thread::LinearCombinationClamp::Params::Params
CUTLASS_HOST_DEVICE Params()
Definition: linear_combination_clamp.h:86
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Boost-like numeric conversion operator for CUTLASS numeric types.
Defines the size of an element in bits.
Definition: numeric_types.h:42
#define nullptr
nullptr
Definition: platform.h:144
cutlass::epilogue::thread::LinearCombinationClamp::LinearCombinationClamp
CUTLASS_HOST_DEVICE LinearCombinationClamp(Params const &params)
Constructs the function object, possibly loading from pointers in host memory.
Definition: linear_combination_clamp.h:122
Definition: functional.h:64
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
cutlass::epilogue::thread::LinearCombinationClamp::ComputeFragment
Array< ElementCompute, kCount > ComputeFragment
Definition: linear_combination_clamp.h:69
cutlass::epilogue::thread::LinearCombinationClamp::ElementOutput
ElementOutput_ ElementOutput
Definition: linear_combination_clamp.h:61
cutlass::epilogue::thread::LinearCombinationClamp::FragmentOutput
Array< ElementOutput, kCount > FragmentOutput
Definition: linear_combination_clamp.h:67
cutlass::epilogue::thread::LinearCombinationClamp::ElementAccumulator
ElementAccumulator_ ElementAccumulator
Definition: linear_combination_clamp.h:62
cutlass::FloatRoundStyle::round_to_nearest
round to nearest even
cutlass::epilogue::thread::LinearCombinationClamp::Params::beta_ptr
ElementCompute const * beta_ptr
pointer to source scalar - if not null, loads it from memory
Definition: linear_combination_clamp.h:79
cutlass::epilogue::thread::LinearCombinationClamp::set_k_partition
CUTLASS_HOST_DEVICE void set_k_partition(int k_partition)
Functionally required for serial reduction in the epilogue.
Definition: linear_combination_clamp.h:136
FloatRoundStyle
Definition: numeric_conversion.h:43
cutlass::NumericArrayConverter
Conversion operator for Array.
Definition: numeric_conversion.h:294
cutlass::epilogue::thread::LinearCombinationClamp::Params
Host-constructable parameters structure.
Definition: linear_combination_clamp.h:74
cutlass::epilogue::thread::LinearCombinationClamp::kRound
static FloatRoundStyle const kRound
Definition: linear_combination_clamp.h:71
cutlass::epilogue::thread::LinearCombinationClamp::is_source_needed
CUTLASS_HOST_DEVICE bool is_source_needed() const
Returns true if source is needed.
Definition: linear_combination_clamp.h:130
Basic include for CUTLASS.
cutlass::epilogue::thread::LinearCombinationClamp::Params::alpha_ptr
ElementCompute const * alpha_ptr
pointer to accumulator scalar - if not null, loads it from memory
Definition: linear_combination_clamp.h:78
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
cutlass::epilogue::thread::LinearCombinationClamp::Params::alpha
ElementCompute alpha
scales accumulators
Definition: linear_combination_clamp.h:76
cutlass::epilogue::thread::LinearCombinationClamp::FragmentAccumulator
Array< ElementAccumulator, kCount > FragmentAccumulator
Definition: linear_combination_clamp.h:68
Generated by Doxygen 1.8.11