docs/linear__combination_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
linear_combination.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/cutlass.h"
32 #include "cutlass/numeric_types.h"
33 #include "cutlass/array.h"
34 #include "cutlass/functional.h"
35 #include "cutlass/numeric_conversion.h"
36
38
39 namespace cutlass {
40 namespace epilogue {
41 namespace thread {
42
44
49 template <
50typename ElementOutput_,
51int Count,
52typename ElementAccumulator_ = ElementOutput_,
53typename ElementCompute_ = ElementOutput_,
54FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
55 >
56 class LinearCombination {
57 public:
58
59using ElementOutput = ElementOutput_;
60using ElementAccumulator = ElementAccumulator_;
61using ElementCompute = ElementCompute_;
62
63static int const kCount = Count;
64
65using FragmentOutput = Array<ElementOutput, kCount>;
66using FragmentAccumulator = Array<ElementAccumulator, kCount>;
67using ComputeFragment = Array<ElementCompute, kCount>;
68
69static FloatRoundStyle const kRound = Round;
70
73
76ElementCompute const *alpha_ptr;
77ElementCompute const *beta_ptr;
78
79//
80// Methods
81//
82
85 alpha(ElementCompute(1)),
86 beta(ElementCompute(0)),
87 alpha_ptr(nullptr),
88 beta_ptr(nullptr) { }
89
92ElementCompute alpha,
93ElementCompute beta
94 ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
95
96 }
97
100ElementCompute const *alpha_ptr,
101ElementCompute const *beta_ptr
102 ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
103
104 }
105 };
106
107 private:
108
109//
110// Data members
111//
112
113ElementCompute alpha_;
114ElementCompute beta_;
115
116 public:
117
120LinearCombination(Params const &params) {
121
122 alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
123 beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
124 }
125
128bool is_source_needed() const {
129return beta_ != ElementCompute(0);
130 }
131
134void set_k_partition(int k_partition) {
135if (k_partition) {
136 beta_ = ElementCompute(1);
137 }
138 }
139
143FragmentAccumulator const &accumulator,
144FragmentOutput const &source) const {
145
146// Convert source and accumulator to internal compute numeric type
147NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
148NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
149
150ComputeFragment converted_source = source_converter(source);
151ComputeFragment converted_accumulator = accumulator_converter(accumulator);
152
153// Perform binary operations
154
155ComputeFragment intermediate;
156
157multiplies<ComputeFragment> mul_add_source;
158multiply_add<ComputeFragment> mul_add_accumulator;
159
160 intermediate = mul_add_source(beta_, converted_source); // X = beta * C
161 intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
162
163// Convert to destination numeric type
164NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
165
166return destination_converter(intermediate);
167 }
168 };
169
171
172 } // namespace thread
173 } // namespace epilogue
174 } // namespace cutlass
Fused multiply-add.
Definition: functional.h:92
Definition: aligned_buffer.h:35
cutlass::epilogue::thread::LinearCombination
Definition: linear_combination.h:56
cutlass::epilogue::thread::LinearCombination::kCount
static int const kCount
Definition: linear_combination.h:63
cutlass::epilogue::thread::LinearCombination::Params::alpha
ElementCompute alpha
scales accumulators
Definition: linear_combination.h:74
cutlass::epilogue::thread::LinearCombination::Params::alpha_ptr
ElementCompute const * alpha_ptr
pointer to accumulator scalar - if not null, loads it from memory
Definition: linear_combination.h:76
cutlass::epilogue::thread::LinearCombination::FragmentAccumulator
Array< ElementAccumulator, kCount > FragmentAccumulator
Definition: linear_combination.h:66
cutlass::epilogue::thread::LinearCombination::set_k_partition
CUTLASS_HOST_DEVICE void set_k_partition(int k_partition)
Functionally required for serial reduction in the epilogue.
Definition: linear_combination.h:134
cutlass::epilogue::thread::LinearCombination::ComputeFragment
Array< ElementCompute, kCount > ComputeFragment
Definition: linear_combination.h:67
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
cutlass::epilogue::thread::LinearCombination::kRound
static FloatRoundStyle const kRound
Definition: linear_combination.h:69
cutlass::epilogue::thread::LinearCombination::ElementAccumulator
ElementAccumulator_ ElementAccumulator
Definition: linear_combination.h:60
cutlass::epilogue::thread::LinearCombination::operator()
CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source) const
Computes linear scaling: D = alpha * accumulator + beta * source.
Definition: linear_combination.h:142
cutlass::epilogue::thread::LinearCombination::ElementOutput
ElementOutput_ ElementOutput
Definition: linear_combination.h:59
Boost-like numeric conversion operator for CUTLASS numeric types.
#define nullptr
nullptr
Definition: platform.h:144
cutlass::epilogue::thread::LinearCombination::Params::Params
CUTLASS_HOST_DEVICE Params()
Definition: linear_combination.h:84
cutlass::epilogue::thread::LinearCombination::LinearCombination
CUTLASS_HOST_DEVICE LinearCombination(Params const &params)
Constructs the function object, possibly loading from pointers in host memory.
Definition: linear_combination.h:120
Definition: functional.h:64
cutlass::epilogue::thread::LinearCombination::Params::Params
CUTLASS_HOST_DEVICE Params(ElementCompute alpha, ElementCompute beta)
Definition: linear_combination.h:91
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
cutlass::epilogue::thread::LinearCombination::FragmentOutput
Array< ElementOutput, kCount > FragmentOutput
Definition: linear_combination.h:65
cutlass::FloatRoundStyle::round_to_nearest
round to nearest even
cutlass::epilogue::thread::LinearCombination::Params::beta
ElementCompute beta
scales source tensor
Definition: linear_combination.h:75
cutlass::epilogue::thread::LinearCombination::Params::Params
CUTLASS_HOST_DEVICE Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
Definition: linear_combination.h:99
FloatRoundStyle
Definition: numeric_conversion.h:43
cutlass::epilogue::thread::LinearCombination::ElementCompute
ElementCompute_ ElementCompute
Definition: linear_combination.h:61
cutlass::NumericArrayConverter
Conversion operator for Array.
Definition: numeric_conversion.h:294
Basic include for CUTLASS.
cutlass::epilogue::thread::LinearCombination::Params
Host-constructable parameters structure.
Definition: linear_combination.h:72
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
cutlass::epilogue::thread::LinearCombination::is_source_needed
CUTLASS_HOST_DEVICE bool is_source_needed() const
Returns true if source is needed.
Definition: linear_combination.h:128
cutlass::epilogue::thread::LinearCombination::Params::beta_ptr
ElementCompute const * beta_ptr
pointer to source scalar - if not null, loads it from memory
Definition: linear_combination.h:77
Generated by 1.8.11