docs/host_2tensor__elementwise_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
host/tensor_elementwise.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
25 /* \file
26 \brief Defines host-side elementwise operations on TensorView.
27 */
28
29 #pragma once
30
31 // Cutlass includes
32 #include "cutlass/cutlass.h"
33 #include "cutlass/functional.h"
34
35 #include "tensor_foreach.h"
36
37 namespace cutlass {
38 namespace reference {
39 namespace host {
40
43
44 namespace detail {
45
47
49 template <
50typename ElementA,
51typename LayoutA,
52typename ElementB,
53typename LayoutB,
54typename ElementD,
55typename LayoutD,
56typename BinaryFunc>
57 struct TensorFuncBinaryOp {
58
59//
60// Data members
61//
62
64TensorView<ElementD, LayoutD> view_d;
65TensorRef<ElementA, LayoutA> ref_a;
66TensorRef<ElementB, LayoutB> ref_b;
68
69//
70// Methods
71//
72
74TensorFuncBinaryOp() { }
75
78TensorView<ElementD, LayoutD> const & view_d_,
79TensorRef<ElementA, LayoutA> const & ref_a_,
80TensorRef<ElementB, LayoutB> const & ref_b_,
81 BinaryFunc func = BinaryFunc()
82 ):
83 view_d(view_d_), view_a(view_a_), view_b(view_b_), func(func) { }
84
86void operator()(Coord<LayoutD::kRank> const &coord) const {
88 ElementD(view_a.at(coord)),
89 ElementD(view_b.at(coord))
90 );
91 }
92 };
93
94 } // namespace detail
95
98
100 template <
101typename ElementD,
102typename LayoutD,
103typename ElementA,
104typename LayoutA,
105typename ElementB,
106typename LayoutB
107 >
109TensorView<ElementD, LayoutD> d,
110TensorRef<ElementA, LayoutA> a,
111TensorRef<ElementB, LayoutB> b
112 ) {
113
114detail::TensorFuncBinaryOp<
115 ElementD,
116 LayoutD,
117 ElementA,
118 LayoutA,
119 ElementB,
120 LayoutB,
122 > func(d, a, b);
123
124TensorForEach(
125 d.extent(),
126func);
127 }
128
130 template <
131typename ElementD,
132typename LayoutD,
133typename ElementA,
134typename LayoutA
135 >
137TensorView<ElementD, LayoutD> d,
138TensorRef<ElementA, LayoutA> a
139 ) {
140TensorAdd(d, d, a);
141 }
142
144
146 template <
147typename ElementD,
148typename LayoutD,
149typename ElementA,
150typename LayoutA,
151typename ElementB,
152typename LayoutB
153 >
155TensorView<ElementD, LayoutD> d,
156TensorRef<ElementA, LayoutA> a,
157TensorRef<ElementB, LayoutB> b
158 ) {
159
160detail::TensorFuncBinaryOp<
161 ElementD,
162 LayoutD,
163 ElementA,
164 LayoutA,
165 ElementB,
166 LayoutB,
168 > func(d, a, b);
169
170TensorForEach(
171 d.extent(),
172func);
173 }
174
176 template <
177typename ElementD,
178typename LayoutD,
179typename ElementA,
180typename LayoutA,
181typename ElementB,
182typename LayoutB
183 >
185TensorView<ElementD, LayoutD> d,
186TensorRef<ElementA, LayoutA> a
187 ) {
188
189TensorSub(d, d, a);
190 }
191
193
195 template <
196typename ElementD,
197typename LayoutD,
198typename ElementA,
199typename LayoutA,
200typename ElementB,
201typename LayoutB
202 >
204TensorView<ElementD, LayoutD> d,
205TensorRef<ElementA, LayoutA> a,
206TensorRef<ElementB, LayoutB> b
207 ) {
208
209detail::TensorFuncBinaryOp<
210 ElementD,
211 LayoutD,
212 ElementA,
213 LayoutA,
214 ElementB,
215 LayoutB,
216cutlass::multiplies<ElementD>
217 > func(d, a, b);
218
219TensorForEach(
220 d.extent(),
221func);
222 }
223
225 template <
226typename ElementD,
227typename LayoutD,
228typename ElementA,
229typename LayoutA
230 >
232TensorView<ElementD, LayoutD> d,
233TensorRef<ElementA, LayoutA> a
234 ) {
235TensorMul(d, d, a);
236 }
237
239
241 template <
242typename ElementD,
243typename LayoutD,
244typename ElementA,
245typename LayoutA,
246typename ElementB,
247typename LayoutB
248 >
250TensorView<ElementD, LayoutD> d,
251TensorRef<ElementA, LayoutA> a,
252TensorRef<ElementB, LayoutB> b
253 ) {
254
255detail::TensorFuncBinaryOp<
256 ElementD,
257 LayoutD,
258 ElementA,
259 LayoutA,
260 ElementB,
261 LayoutB,
263 > func(d, a, b);
264
265TensorForEach(
266 d.extent(),
267func);
268 }
269
271 template <
272typename ElementD,
273typename LayoutD,
274typename ElementA,
275typename LayoutA
276 >
278TensorView<ElementD, LayoutD> d,
279TensorRef<ElementA, LayoutA> a
280 ) {
281TensorMul(d, d, a);
282 }
283
284
286
288 template <
289typename ElementD,
290typename LayoutD,
291typename ElementA,
292typename LayoutA,
293typename ElementB,
294typename LayoutB
295 >
296 void TensorModulus(
297TensorView<ElementD, LayoutD> d,
298TensorRef<ElementA, LayoutA> a,
299TensorRef<ElementB, LayoutB> b
300 ) {
301
302detail::TensorFuncBinaryOp<
303 ElementD,
304 LayoutD,
305 ElementA,
306 LayoutA,
307 ElementB,
308 LayoutB,
309 cutlass::modulus<ElementD>
310 > func(d, a, b);
311
312TensorForEach(
313 d.extent(),
314func);
315 }
316
318 template <
319typename ElementD,
320typename LayoutD,
321typename ElementA,
322typename LayoutA
323 >
324 void TensorModulus(
325TensorView<ElementD, LayoutD> d,
326TensorRef<ElementA, LayoutA> a
327 ) {
328TensorMul(d, d, a);
329 }
330
332
333 } // namespace host
334 } // namespace reference
335 } // namespace cutlass
cutlass::reference::host::detail::TensorFuncBinaryOp::operator()
void operator()(Coord< LayoutD::kRank > const &coord) const
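Applies the binary functor to the operands at the given coordinate and stores the result in the destination tensor.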
Definition: host/tensor_elementwise.h:86
Definition: aligned_buffer.h:35
cutlass::reference::host::detail::TensorFuncBinaryOp
Helper to apply a binary operator in place.
Definition: host/tensor_elementwise.h:57
cutlass::reference::host::TensorAdd
void TensorAdd(TensorView< ElementD, LayoutD > d, TensorRef< ElementA, LayoutA > a, TensorRef< ElementB, LayoutB > b)
Adds two tensors and stores in the destination tensor: d = a + b.
Definition: host/tensor_elementwise.h:108
CUTLASS_HOST_DEVICE TensorCoord const & extent() const
Returns the extent of the view (the size along each logical dimension).
Definition: tensor_view.h:167
cutlass::reference::host::TensorDiv
void TensorDiv(TensorView< ElementD, LayoutD > d, TensorRef< ElementA, LayoutA > a, TensorRef< ElementB, LayoutB > b)
Divides two tensors and stores in the destination tensor: d = a ./ b.
Definition: host/tensor_elementwise.h:249
cutlass::reference::host::detail::TensorFuncBinaryOp::func
BinaryFunc func
Definition: host/tensor_elementwise.h:67
Definition: functional.h:46
cutlass::reference::host::detail::TensorFuncBinaryOp::TensorFuncBinaryOp
TensorFuncBinaryOp(TensorView< ElementD, LayoutD > const &view_d_, TensorRef< ElementA, LayoutA > const &ref_a_, TensorRef< ElementB, LayoutB > const &ref_b_, BinaryFunc func=BinaryFunc())
Constructor.
Definition: host/tensor_elementwise.h:77
cutlass::TensorView< ElementD, LayoutD >
cutlass::TensorRef< ElementA, LayoutA >
cutlass::reference::host::detail::TensorFuncBinaryOp::ref_b
TensorRef< ElementB, LayoutB > ref_b
Definition: host/tensor_elementwise.h:66
Definition: functional.h:64
cutlass::reference::host::TensorSub
void TensorSub(TensorView< ElementD, LayoutD > d, TensorRef< ElementA, LayoutA > a, TensorRef< ElementB, LayoutB > b)
Subtracts two tensors and stores in the destination tensor: d = a - b.
Definition: host/tensor_elementwise.h:154
cutlass::reference::host::detail::TensorFuncBinaryOp::view_d
TensorView< ElementD, LayoutD > view_d
View of the destination tensor.
Definition: host/tensor_elementwise.h:64
Definition: functional.h:73
cutlass::reference::host::TensorMul
void TensorMul(TensorView< ElementD, LayoutD > d, TensorRef< ElementA, LayoutA > a, TensorRef< ElementB, LayoutB > b)
Multiplies two tensors and stores in the destination tensor: d = a .* b.
Definition: host/tensor_elementwise.h:203
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:43
cutlass::reference::host::TensorModulus
void TensorModulus(TensorView< ElementD, LayoutD > d, TensorRef< ElementA, LayoutA > a, TensorRef< ElementB, LayoutB > b)
Computes the elementwise modulus of two tensors and stores in the destination tensor: d = a % b.
Definition: host/tensor_elementwise.h:296
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
cutlass::reference::host::TensorForEach
void TensorForEach(Coord< Rank > extent, Func &func)
Iterates over the index space of a tensor.
Definition: host/tensor_foreach.h:87
Definition: functional.h:55
cutlass::reference::host::detail::TensorFuncBinaryOp::TensorFuncBinaryOp
TensorFuncBinaryOp()
Constructor.
Definition: host/tensor_elementwise.h:74
Basic include for CUTLASS.
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
cutlass::reference::host::detail::TensorFuncBinaryOp::ref_a
TensorRef< ElementA, LayoutA > ref_a
Definition: host/tensor_elementwise.h:65
Generated by 1.8.11