docs/tensor__copy_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
tensor_copy.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
25 /* \file
26 \brief Defines host-side reference functions that copy elements between TensorView and TensorRef objects, with optional per-element conversion.
27 */
28
29 #pragma once
30
31 // Standard Library includes
32 #include <utility>
33
34 // Cutlass includes
35 #include "cutlass/cutlass.h"
36 #include "tensor_foreach.h"
37
38 namespace cutlass {
39 namespace reference {
40 namespace host {
41
43
44 namespace detail {
45
/// Helper functor that converts a SrcElement into a DstElement using a
/// functional-style cast, i.e. `DstElement(value)`.
template <
  typename DstElement,
  typename SrcElement
>
struct TrivialConvert {

  /// Stateless functor; nothing to initialize.
  TrivialConvert() { }

  /// Returns `value` converted to DstElement.
  DstElement operator()(SrcElement value) const {
    return DstElement(value);
  }
};
59
61 template <
62typename DstElement,
63typename DstLayout,
64typename SrcElement,
65typename SrcLayout,
66typename F
67 >
68 struct TensorCopyIf {
69
70using DstTensorView = TensorView<DstElement, DstLayout>;
71using SrcTensorView = TensorView<SrcElement, SrcLayout>;
72
73//
74// Data members
75//
76
80
81//
82// Methods
83//
84
85TensorCopyIf() { }
86
88DstTensorView const &dst_,
89SrcTensorView const &src_,
90 F const &convert_): dst(dst_), src(src_), convert(convert_) {}
91
93void operator()(Coord<DstLayout::kRank> const &coord) {
94if (dst.contains(coord) && src.contains(coord)) {
95 dst.at(coord) = convert(src.at(coord));
96 }
97 }
98 };
99
100 } // namespace detail
101
103
105 template <
106typename DstElement,
107typename DstLayout,
108typename SrcElement,
109typename SrcLayout,
110typename F
111 >
112 void TensorCopy(
113TensorView<DstElement, DstLayout> dst,
114TensorView<SrcElement, SrcLayout> src,
115 F const &transform) {
116
117using CopyIf = detail::TensorCopyIf<
118 DstElement,
119 DstLayout,
120 SrcElement,
121 SrcLayout,
122 F>;
123
124 CopyIf copy_if(dst, src, transform);
125
126TensorForEach(dst.extent(), copy_if);
127 }
128
129
131
134 template <
135typename DstElement,
136typename DstLayout,
137typename SrcElement,
138typename SrcLayout,
139typename F
140 >
141 void TensorCopy(
142TensorView<DstElement, DstLayout> dst,
143TensorRef<SrcElement, SrcLayout> src,
144 F const &transform) {
145
146using CopyIf = detail::TensorCopyIf<
147 DstElement,
148 DstLayout,
149 SrcElement,
150 SrcLayout,
151 F>;
152
153TensorView<SrcElement, SrcLayout> src_view(src, dst.extent());
154
155 CopyIf copy_if(dst, src_view, transform);
156
157TensorForEach(dst.extent(), copy_if);
158 }
159
162 template <
163typename DstElement,
164typename DstLayout,
165typename SrcElement,
166typename SrcLayout,
167typename F
168 >
169 void TensorCopy(
170TensorRef<DstElement, DstLayout> dst,
171TensorView<SrcElement, SrcLayout> src,
172 F const &transform) {
173
174using CopyIf = detail::TensorCopyIf<
175 DstElement,
176 DstLayout,
177 SrcElement,
178 SrcLayout,
179 F>;
180
181TensorView<DstElement, DstLayout> dst_view(dst, src.extent());
182
183 CopyIf copy_if(dst_view, src, transform);
184
185TensorForEach(src.extent(), copy_if);
186 }
187
189
192 template <
193typename DstElement,
194typename DstLayout,
195typename SrcElement,
196typename SrcLayout
197 >
198 void TensorCopy(
199TensorView<DstElement, DstLayout> dst,
200TensorView<SrcElement, SrcLayout> src) {
201
202detail::TrivialConvert<DstElement, SrcElement> convert;
203
204TensorCopy(dst, src, convert);
205 }
206
208
211 template <
212typename DstElement,
213typename DstLayout,
214typename SrcElement,
215typename SrcLayout,
216typename F
217 >
218 void TensorCopy(
219TensorView<DstElement, DstLayout> dst,
220TensorRef<SrcElement, SrcLayout> src) {
221
222detail::TrivialConvert<DstElement, SrcElement> convert;
223
224TensorCopy(dst, src, convert);
225 }
226
228
231 template <
232typename DstElement,
233typename DstLayout,
234typename SrcElement,
235typename SrcLayout
236 >
237 void TensorCopy(
238TensorRef<DstElement, DstLayout> dst,
239TensorView<SrcElement, SrcLayout> src) {
240
241detail::TrivialConvert<DstElement, SrcElement> convert;
242
243TensorCopy(dst, src, convert);
244 }
245
247
248 } // namespace host
249 } // namespace reference
250 } // namespace cutlass
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE TensorCoord const & extent() const
Returns the extent of the view (the size along each logical dimension).
Definition: tensor_view.h:167
cutlass::reference::host::TensorCopy
void TensorCopy(TensorView< DstElement, DstLayout > dst, TensorView< SrcElement, SrcLayout > src, F const &transform)
Copies elements from one tensor view into another, satisfying bounds of each tensor.
Definition: tensor_copy.h:112
cutlass::reference::host::detail::TensorCopyIf::TensorCopyIf
TensorCopyIf()
Definition: tensor_copy.h:85
cutlass::reference::host::detail::TrivialConvert::operator()
DstElement operator()(SrcElement src) const
Definition: tensor_copy.h:55
cutlass::reference::host::detail::TensorCopyIf::operator()
void operator()(Coord< DstLayout::kRank > const &coord)
Copies based on destination and source bounds.
Definition: tensor_copy.h:93
cutlass::reference::host::detail::TrivialConvert
Helper to convert between types.
Definition: tensor_copy.h:51
cutlass::reference::host::detail::TensorCopyIf::TensorCopyIf
TensorCopyIf(DstTensorView const &dst_, SrcTensorView const &src_, F const &convert_)
Definition: tensor_copy.h:87
cutlass::TensorView< DstElement, DstLayout >
cutlass::TensorRef< SrcElement, SrcLayout >
cutlass::reference::host::detail::TensorCopyIf
Helper to conditionally copy between tensor views.
Definition: tensor_copy.h:68
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:43
cutlass::reference::host::detail::TensorCopyIf::convert
F convert
Definition: tensor_copy.h:79
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
cutlass::reference::host::TensorForEach
void TensorForEach(Coord< Rank > extent, Func &func)
Iterates over the index space of a tensor.
Definition: host/tensor_foreach.h:87
CUTLASS_HOST_DEVICE bool contains(TensorCoord const &coord) const
Determines whether a location is within a tensor.
Definition: tensor_view.h:175
cutlass::reference::host::detail::TrivialConvert::TrivialConvert
TrivialConvert()
Definition: tensor_copy.h:53
cutlass::reference::host::detail::TensorCopyIf::src
SrcTensorView src
Definition: tensor_copy.h:78
cutlass::reference::host::detail::TensorCopyIf::dst
DstTensorView dst
Definition: tensor_copy.h:77
Basic include for CUTLASS.
Generated by 1.8.11