docs/reduce_8h_source.html
CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
reduce.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/cutlass.h"
32 #include "cutlass/numeric_types.h"
33 #include "cutlass/array.h"
34 #include "cutlass/half.h"
35 #include "cutlass/functional.h"
36
37 namespace cutlass {
38 namespace reduction {
39 namespace thread {
40
42 template <typename Op, typename T>
44
46
48 template <typename T>
49 struct Reduce< plus<T>, T > {
50
52 T operator()(T lhs, T const &rhs) const {
53plus<T> _op;
54return _op(lhs, rhs);
55 }
56 };
57
59
61 template <typename T, int N>
62 struct Reduce < plus<T>, Array<T, N>> {
63
65 Array<T, 1> operator()(Array<T, N> const &in) const {
66
67 Array<T, 1> result;
68Reduce< plus<T>, T > scalar_reduce;
69 result.clear();
70
72for (auto i = 0; i < N; ++i) {
73 result[0] = scalar_reduce(result[0], in[i]);
74 }
75
76return result;
77 }
78 };
79
81
83 template <int N>
[84](structcutlass_1_1reduction_1_1thread_1_1Reduce_3_01plus_3_01half t_01_4_00_01Array_3_01half t_00_01N_01_4_01_4.html) struct Reduce < plus<half_t>, Array<half_t, N> > {
85
[87](structcutlass_1_1reduction_1_1thread_1_1Reduce_3_01plus_3_01half t_01_4_00_01Array_3_01half t_00_01N_01_4_01_4.html#afe8938d7e9d806511480831668ef1563) Array<half_t, 1> [operator()](structcutlass_1_1reduction_1_1thread_1_1Reduce_3_01plus_3_01half t_01_4_00_01Array_3_01half t_00_01N_01_4_01_4.html#afe8938d7e9d806511480831668ef1563)(Array<half_t, N> const &input) {
88
89 Array<half_t, 1> result;
90
91// If there is only 1 element - there is nothing to reduce
92if( N ==1 ){
93
94 result[0] = input.front();
95
96 } else {
97
98 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)
99
100 __half result_d;
101 Array<half_t, 1> const *in_ptr_half = reinterpret_cast<Array<half_t, 1> const *>(&input);
102 Array<half_t, 2> const *in_ptr_half2 = reinterpret_cast<Array<half_t, 2> const *>(&input);
103 __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2);
104
105// Set initial result = first half2, in case N==2
106 __half2 tmp_result = x_in_half2[0];
107
109for (int i = 1; i < N/2; ++i) {
110
111 tmp_result = __hadd2(x_in_half2[i], tmp_result);
112
113 }
114
115 result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result));
116
117// One final step is needed for odd "N" (to add the (N-1)th element)
118if( N%2 ){
119
120 __half last_element;
121 Array<half_t, 1> tmp_last;
122 Array<half_t, 1> *tmp_last_ptr = &tmp_last;
123 tmp_last_ptr[0] = in_ptr_half[N-1];
124 last_element = reinterpret_cast<__half const &>(tmp_last);
125
126 result_d = __hadd(result_d, last_element);
127
128 }
129
130 Array<half_t, 1> *result_ptr = &result;
131 *result_ptr = reinterpret_cast<Array<half_t, 1> &>(result_d);
132
133 #else
134
135Reduce< plus<half_t>, half_t > scalar_reduce;
136 result.clear();
137
139for (auto i = 0; i < N; ++i) {
140
141 result[0] = scalar_reduce(result[0], input[i]);
142
143 }
144
145 #endif
146 }
147
148return result;
149
150 }
151 };
152
153
155
157 template <int N>
[158](structcutlass_1_1reduction_1_1thread_1_1Reduce_3_01plus_3_01half t_01_4_00_01AlignedArray_3_01half t_00_01N_01_4_01_4.html) struct Reduce < plus<half_t>, AlignedArray<half_t, N> > {
159
[161](structcutlass_1_1reduction_1_1thread_1_1Reduce_3_01plus_3_01half t_01_4_00_01AlignedArray_3_01half t_00_01N_01_4_01_4.html#a37155e3dcc591e896df2d80c8daff4d4) Array<half_t, 1> [operator()](structcutlass_1_1reduction_1_1thread_1_1Reduce_3_01plus_3_01half t_01_4_00_01AlignedArray_3_01half t_00_01N_01_4_01_4.html#a37155e3dcc591e896df2d80c8daff4d4)(AlignedArray<half_t, N> const &input) {
162
163 Array<half_t, 1> result;
164
165// If there is only 1 element - there is nothing to reduce
166if( N ==1 ){
167
168 result[0] = input.front();
169
170 } else {
171
172 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)
173
174 __half result_d;
175AlignedArray<half_t, 1> const *in_ptr_half = reinterpret_cast<AlignedArray<half_t, 1> const *>(&input);
176AlignedArray<half_t, 2> const *in_ptr_half2 = reinterpret_cast<AlignedArray<half_t, 2> const *>(&input);
177 __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2);
178
179// Set initial result = first half2, in case N==2
180 __half2 tmp_result = x_in_half2[0];
181
183for (int i = 1; i < N/2; ++i) {
184
185 tmp_result = __hadd2(x_in_half2[i], tmp_result);
186
187 }
188
189 result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result));
190
191// One final step is needed for odd "N" (to add the (N-1)th element)
192if( N%2 ){
193
194 __half last_element;
195AlignedArray<half_t, 1> tmp_last;
196AlignedArray<half_t, 1> *tmp_last_ptr = &tmp_last;
197 tmp_last_ptr[0] = in_ptr_half[N-1];
198 last_element = reinterpret_cast<__half const &>(tmp_last);
199
200 result_d = __hadd(result_d, last_element);
201
202 }
203
204 Array<half_t, 1> *result_ptr = &result;
205 *result_ptr = reinterpret_cast<Array<half_t, 1> &>(result_d);
206
207 #else
208
209Reduce< plus<half_t>, half_t > scalar_reduce;
210 result.clear();
211
213for (auto i = 0; i < N; ++i) {
214
215 result[0] = scalar_reduce(result[0], input[i]);
216
217 }
218
219 #endif
220 }
221
222return result;
223
224 }
225 };
226 }
227 }
228 }
Definition: aligned_buffer.h:35
Defines a class for using IEEE half-precision floating-point types in host or device code...
Aligned array type.
Definition: array.h:511
IEEE half-precision floating-point type.
Definition: half.h:126
[cutlass::reduction::thread::Reduce< plus< half_t >, AlignedArray< half_t, N > >::operator()](structcutlass_1_1reduction_1_1thread_1_1Reduce_3_01plus_3_01half__t_01_4_00_01AlignedArray_3_01half__t_00_01N_01_4_01_4.html#a37155e3dcc591e896df2d80c8daff4d4)
CUTLASS_HOST_DEVICE Array< half_t, 1 > operator()(AlignedArray< half_t, N > const &input)
Definition: reduce.h:161
Definition: functional.h:46
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
cutlass::reduction::thread::Reduce< plus< T >, Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, 1 > operator()(Array< T, N > const &in) const
Definition: reduce.h:65
[cutlass::reduction::thread::Reduce< plus< half_t >, Array< half_t, N > >::operator()](structcutlass_1_1reduction_1_1thread_1_1Reduce_3_01plus_3_01half__t_01_4_00_01Array_3_01half__t_00_01N_01_4_01_4.html#afe8938d7e9d806511480831668ef1563)
CUTLASS_HOST_DEVICE Array< half_t, 1 > operator()(Array< half_t, N > const &input)
Definition: reduce.h:87
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Basic include for CUTLASS.
cutlass::reduction::thread::Reduce< plus< T >, T >::operator()
CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const
Definition: reduce.h:52
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
cutlass::reduction::thread::Reduce
Structure to compute the thread level reduction.
Definition: reduce.h:43
Generated by Doxygen 1.8.11