Back to Cutlass

CUTLASS: tensor_compare.h Source File

docs/host_2tensor__compare_8h_source.html

4.4.219.6 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

host/tensor_compare.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

25 /* \file

26 \brief Defines host-side elementwise operations on TensorView.

27 */

28

29 #pragma once

30

31 // Standard Library includes

32 #include <utility>

33

34 // Cutlass includes

35 #include "cutlass/cutlass.h"

36 #include "cutlass/util/distribution.h"

37 //#include "cutlass/util/type_traits.h"

38 #include "tensor_foreach.h"

39

40 namespace cutlass {

41 namespace reference {

42 namespace host {

43

46

47 namespace detail {

48

49 template <

50typename Element,

51typename Layout>

52 struct TensorEqualsFunc {

53

54//

55// Data members

56//

57

58TensorView<Element, Layout> lhs;

59TensorView<Element, Layout> rhs;

60bool result;

61

63TensorEqualsFunc(): result(true) { }

64

66TensorEqualsFunc(

67TensorView<Element, Layout> const &lhs_,

68TensorView<Element, Layout> const &rhs_

69 ) :

70 lhs(lhs_), rhs(rhs_), result(true) { }

71

73void operator()(Coord<Layout::kRank> const &coord) {

74

75 Element lhs_ = lhs.at(coord);

76 Element rhs_ = rhs.at(coord);

77

78if (lhs_ != rhs_) {

79 result = false;

80 }

81 }

82

84operator bool() const {

85return result;

86 }

87 };

88

89 } // namespace detail

90

92

94 template <

95typename Element,

96typename Layout>

97 bool TensorEquals(

98TensorView<Element, Layout> const &lhs,

99TensorView<Element, Layout> const &rhs) {

100

101// Extents must be identical

102if (lhs.extent() != rhs.extent()) {

103return false;

104 }

105

106detail::TensorEqualsFunc<Element, Layout> func(lhs, rhs);

107TensorForEach(

108 lhs.extent(),

109 func

110 );

111

112return bool(func);

113 }

114

117

119 template <

120typename Element,

121typename Layout>

122 bool TensorNotEquals(

123TensorView<Element, Layout> const &lhs,

124TensorView<Element, Layout> const &rhs) {

125

126// Extents must be identical

127if (lhs.extent() != rhs.extent()) {

128return true;

129 }

130

131detail::TensorEqualsFunc<Element, Layout> func(lhs, rhs);

132TensorForEach(

133 lhs.extent(),

134 func

135 );

136

137return !bool(func);

138 }

139

142

143 namespace detail {

144

145 template <

146typename Element,

147typename Layout>

148 struct TensorContainsFunc {

149

150//

151// Data members

152//

153

154TensorView<Element, Layout> view;

155 Element value;

156bool contains;

157Coord<Layout::kRank> location;

158

159//

160// Methods

161//

162

164TensorContainsFunc(): contains(false) { }

165

167TensorContainsFunc(

168TensorView<Element, Layout> const &view_,

169 Element value_

170 ) :

171 view(view_), value(value_), contains(false) { }

172

174void operator()(Coord<Layout::kRank> const &coord) {

175

176if (view.at(coord) == value) {

177if (!contains) {

178 location = coord;

179 }

180 contains = true;

181 }

182 }

183

185operator bool() const {

186return contains;

187 }

188 };

189

190 } // namespace detail

191

193

195 template <

196typename Element,

197typename Layout>

198 bool TensorContains(

199TensorView<Element, Layout> const & view,

200 Element value) {

201

202detail::TensorContainsFunc<Element, Layout> func(

203 view,

204 value

205 );

206

207TensorForEach(

208 view.extent(),

209 func

210 );

211

212return bool(func);

213 }

214

216

220 template <

221typename Element,

222typename Layout>

223 std::pair<bool, Coord<Layout::kRank> > TensorFind(

224TensorView<Element, Layout> const & view,

225 Element value) {

226

227detail::TensorContainsFunc<Element, Layout> func(

228 view,

229 value

230 );

231

232TensorForEach(

233 view.extent(),

234 func

235 );

236

237return std::make_pair(bool(func), func.location);

238 }

239

242

243 } // namespace host

244 } // namespace reference

245 } // namespace cutlass

cutlass::reference::host::detail::TensorContainsFunc::TensorContainsFunc

TensorContainsFunc()

Ctor.

Definition: host/tensor_compare.h:164

cutlass

Definition: aligned_buffer.h:35

cutlass::reference::host::detail::TensorContainsFunc

Functor that records whether a value occurs in a tensor view and the coordinate of its first occurrence.

Definition: host/tensor_compare.h:148

cutlass::reference::host::detail::TensorContainsFunc::TensorContainsFunc

TensorContainsFunc(TensorView< Element, Layout > const &view_, Element value_)

Ctor.

Definition: host/tensor_compare.h:167

cutlass::TensorView::extent

CUTLASS_HOST_DEVICE TensorCoord const & extent() const

Returns the extent of the view (the size along each logical dimension).

Definition: tensor_view.h:167

cutlass::platform::make_pair

CUTLASS_HOST_DEVICE std::pair< T1, T2 > make_pair(T1 t, T2 u)

Definition: platform.h:232

cutlass::reference::host::detail::TensorContainsFunc::operator()

void operator()(Coord< Layout::kRank > const &coord)

Visits a coordinate.

Definition: host/tensor_compare.h:174

cutlass::reference::host::TensorEquals

bool TensorEquals(TensorView< Element, Layout > const &lhs, TensorView< Element, Layout > const &rhs)

Returns true if two tensor views are equal.

Definition: host/tensor_compare.h:97

cutlass::reference::host::detail::TensorContainsFunc::location

Coord< Layout::kRank > location

Definition: host/tensor_compare.h:157

cutlass::reference::host::detail::TensorContainsFunc::view

TensorView< Element, Layout > view

Definition: host/tensor_compare.h:154

cutlass::reference::host::detail::TensorEqualsFunc::lhs

TensorView< Element, Layout > lhs

Definition: host/tensor_compare.h:58

cutlass::reference::host::detail::TensorEqualsFunc::operator()

void operator()(Coord< Layout::kRank > const &coord)

Visits a coordinate.

Definition: host/tensor_compare.h:73

cutlass::reference::host::TensorContains

bool TensorContains(TensorView< Element, Layout > const &view, Element value)

Returns true if a value is present in a tensor.

Definition: host/tensor_compare.h:198

cutlass::TensorView< Element, Layout >

cutlass::reference::host::detail::TensorContainsFunc::contains

bool contains

Definition: host/tensor_compare.h:156

cutlass::reference::host::detail::TensorEqualsFunc::TensorEqualsFunc

TensorEqualsFunc()

Ctor.

Definition: host/tensor_compare.h:63

distribution.h

This header contains a class to parametrize a statistical distribution function.

cutlass::reference::host::TensorNotEquals

bool TensorNotEquals(TensorView< Element, Layout > const &lhs, TensorView< Element, Layout > const &rhs)

Returns true if two tensor views are NOT equal.

Definition: host/tensor_compare.h:122

cutlass::Coord< Layout::kRank >

cutlass::reference::host::detail::TensorContainsFunc::value

Element value

Definition: host/tensor_compare.h:155

cutlass::TensorRef::at

CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const

Returns a reference to the element at a given Coord.

Definition: tensor_ref.h:307

cutlass::reference::host::detail::TensorEqualsFunc::TensorEqualsFunc

TensorEqualsFunc(TensorView< Element, Layout > const &lhs_, TensorView< Element, Layout > const &rhs_)

Ctor.

Definition: host/tensor_compare.h:66

cutlass::reference::host::TensorForEach

void TensorForEach(Coord< Rank > extent, Func &func)

Iterates over the index space of a tensor.

Definition: host/tensor_foreach.h:87

cutlass::reference::host::detail::TensorEqualsFunc::result

bool result

Definition: host/tensor_compare.h:60

cutlass::reference::host::TensorFind

std::pair< bool, Coord< Layout::kRank > > TensorFind(TensorView< Element, Layout > const &view, Element value)

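Returns a pair: whether the value occurs in the tensor view, and the coordinate of its first occurrence (in TensorForEach visiting order).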

Definition: host/tensor_compare.h:223

cutlass.h

Basic include for CUTLASS.

cutlass::reference::host::detail::TensorEqualsFunc::rhs

TensorView< Element, Layout > rhs

Definition: host/tensor_compare.h:59

cutlass::reference::host::detail::TensorEqualsFunc

Functor that checks two tensor views for elementwise equality.

Definition: host/tensor_compare.h:52

tensor_foreach.h


Generated by 1.8.11