Back to Cutlass

CUTLASS: tensor_copy.h Source File

docs/tensor__copy_8h_source.html

4.4.216.7 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

tensor_copy.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

25 /*! \file

26 \brief Defines host-side elementwise operations on TensorView.

27 */

28

29 #pragma once

30

31 // Standard Library includes

32 #include <utility>

33

34 // Cutlass includes

35 #include "cutlass/cutlass.h"

36 #include "tensor_foreach.h"

37

38 namespace cutlass {

39 namespace reference {

40 namespace host {

41

43

44 namespace detail {

45

/// Stateless functor that converts a value of type SrcElement into DstElement.
///
/// The conversion is spelled as a function-style cast, `DstElement(value)`, so it
/// accepts whatever conversion that expression allows for the two element types.
template <typename DstElement, typename SrcElement>
struct TrivialConvert {

  /// Default constructor; the functor holds no data members.
  TrivialConvert() { }

  /// Returns `value` converted to DstElement.
  DstElement operator()(SrcElement value) const {
    return DstElement(value);
  }
};

59

61 template <

62typename DstElement,

63typename DstLayout,

64typename SrcElement,

65typename SrcLayout,

66typename F

67 >

68 struct TensorCopyIf {

69

70using DstTensorView = TensorView<DstElement, DstLayout>;

71using SrcTensorView = TensorView<SrcElement, SrcLayout>;

72

73//

74// Data members

75//

76

77DstTensorView dst;

78SrcTensorView src;

79 F convert;

80

81//

82// Methods

83//

84

85TensorCopyIf() { }

86

87TensorCopyIf(

88DstTensorView const &dst_,

89SrcTensorView const &src_,

90 F const &convert_): dst(dst_), src(src_), convert(convert_) {}

91

93void operator()(Coord<DstLayout::kRank> const &coord) {

94if (dst.contains(coord) && src.contains(coord)) {

95 dst.at(coord) = convert(src.at(coord));

96 }

97 }

98 };

99

100 } // namespace detail

101

103

105 template <

106typename DstElement,

107typename DstLayout,

108typename SrcElement,

109typename SrcLayout,

110typename F

111 >

112 void TensorCopy(

113TensorView<DstElement, DstLayout> dst,

114TensorView<SrcElement, SrcLayout> src,

115 F const &transform) {

116

117using CopyIf = detail::TensorCopyIf<

118 DstElement,

119 DstLayout,

120 SrcElement,

121 SrcLayout,

122 F>;

123

124 CopyIf copy_if(dst, src, transform);

125

126TensorForEach(dst.extent(), copy_if);

127 }

128

129

131

134 template <

135typename DstElement,

136typename DstLayout,

137typename SrcElement,

138typename SrcLayout,

139typename F

140 >

141 void TensorCopy(

142TensorView<DstElement, DstLayout> dst,

143TensorRef<SrcElement, SrcLayout> src,

144 F const &transform) {

145

146using CopyIf = detail::TensorCopyIf<

147 DstElement,

148 DstLayout,

149 SrcElement,

150 SrcLayout,

151 F>;

152

153TensorView<SrcElement, SrcLayout> src_view(src, dst.extent());

154

155 CopyIf copy_if(dst, src_view, transform);

156

157TensorForEach(dst.extent(), copy_if);

158 }

159

162 template <

163typename DstElement,

164typename DstLayout,

165typename SrcElement,

166typename SrcLayout,

167typename F

168 >

169 void TensorCopy(

170TensorRef<DstElement, DstLayout> dst,

171TensorView<SrcElement, SrcLayout> src,

172 F const &transform) {

173

174using CopyIf = detail::TensorCopyIf<

175 DstElement,

176 DstLayout,

177 SrcElement,

178 SrcLayout,

179 F>;

180

181TensorView<DstElement, DstLayout> dst_view(dst, src.extent());

182

183 CopyIf copy_if(dst_view, src, transform);

184

185TensorForEach(src.extent(), copy_if);

186 }

187

189

192 template <

193typename DstElement,

194typename DstLayout,

195typename SrcElement,

196typename SrcLayout

197 >

198 void TensorCopy(

199TensorView<DstElement, DstLayout> dst,

200TensorView<SrcElement, SrcLayout> src) {

201

202detail::TrivialConvert<DstElement, SrcElement> convert;

203

204TensorCopy(dst, src, convert);

205 }

206

208

211 template <

212typename DstElement,

213typename DstLayout,

214typename SrcElement,

215typename SrcLayout,

216typename F

217 >

218 void TensorCopy(

219TensorView<DstElement, DstLayout> dst,

220TensorRef<SrcElement, SrcLayout> src) {

221

222detail::TrivialConvert<DstElement, SrcElement> convert;

223

224TensorCopy(dst, src, convert);

225 }

226

228

231 template <

232typename DstElement,

233typename DstLayout,

234typename SrcElement,

235typename SrcLayout

236 >

237 void TensorCopy(

238TensorRef<DstElement, DstLayout> dst,

239TensorView<SrcElement, SrcLayout> src) {

240

241detail::TrivialConvert<DstElement, SrcElement> convert;

242

243TensorCopy(dst, src, convert);

244 }

245

247

248 } // namespace host

249 } // namespace reference

250 } // namespace cutlass

cutlass

Definition: aligned_buffer.h:35

cutlass::TensorView::extent

CUTLASS_HOST_DEVICE TensorCoord const & extent() const

Returns the extent of the view (the size along each logical dimension).

Definition: tensor_view.h:167

cutlass::reference::host::TensorCopy

void TensorCopy(TensorView< DstElement, DstLayout > dst, TensorView< SrcElement, SrcLayout > src, F const &transform)

Copies elements from one tensor view into another, satisfying bounds of each tensor.

Definition: tensor_copy.h:112

cutlass::reference::host::detail::TensorCopyIf::TensorCopyIf

TensorCopyIf()

Definition: tensor_copy.h:85

cutlass::reference::host::detail::TrivialConvert::operator()

DstElement operator()(SrcElement src) const

Definition: tensor_copy.h:55

cutlass::reference::host::detail::TensorCopyIf::operator()

void operator()(Coord< DstLayout::kRank > const &coord)

Copies based on destination and source bounds.

Definition: tensor_copy.h:93

cutlass::reference::host::detail::TrivialConvert

Helper to convert between types.

Definition: tensor_copy.h:51

cutlass::reference::host::detail::TensorCopyIf::TensorCopyIf

TensorCopyIf(DstTensorView const &dst_, SrcTensorView const &src_, F const &convert_)

Definition: tensor_copy.h:87

cutlass::TensorView< DstElement, DstLayout >

cutlass::TensorRef< SrcElement, SrcLayout >

cutlass::reference::host::detail::TensorCopyIf

Helper to conditionally copy between tensor views.

Definition: tensor_copy.h:68

cutlass::Coord

Statically-sized array specifying Coords within a tensor.

Definition: coord.h:43

cutlass::reference::host::detail::TensorCopyIf::convert

F convert

Definition: tensor_copy.h:79

cutlass::TensorRef::at

CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const

Returns a reference to the element at a given Coord.

Definition: tensor_ref.h:307

cutlass::reference::host::TensorForEach

void TensorForEach(Coord< Rank > extent, Func &func)

Iterates over the index space of a tensor.

Definition: host/tensor_foreach.h:87

cutlass::TensorView::contains

CUTLASS_HOST_DEVICE bool contains(TensorCoord const &coord) const

Determines whether a location is within a tensor.

Definition: tensor_view.h:175

cutlass::reference::host::detail::TrivialConvert::TrivialConvert

TrivialConvert()

Definition: tensor_copy.h:53

cutlass::reference::host::detail::TensorCopyIf::src

SrcTensorView src

Definition: tensor_copy.h:78

cutlass::reference::host::detail::TensorCopyIf::dst

DstTensorView dst

Definition: tensor_copy.h:77

cutlass.h

Basic include for CUTLASS.

tensor_foreach.h


Generated by Doxygen 1.8.11