docs/host_2tensor__compare_8h_source.html
CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
host/tensor_compare.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
25 /* \file
26 \brief Defines host-side comparison and search operations on TensorView.
27 */
28
29 #pragma once
30
31 // Standard Library includes
32 #include <utility>
33
34 // Cutlass includes
35 #include "cutlass/cutlass.h"
36 #include "cutlass/util/distribution.h"
37 //#include "cutlass/util/type_traits.h"
38 #include "tensor_foreach.h"
39
40 namespace cutlass {
41 namespace reference {
42 namespace host {
43
46
47 namespace detail {
48
/// Functor applied by TensorForEach (see TensorEquals): at each visited
/// coordinate it compares lhs.at(coord) with rhs.at(coord) and clears `result`
/// when they differ. Converting the functor to bool yields `result`.
/// NOTE(review): this rendered listing omits some source lines. Per the member
/// index at the bottom of this page, line 60 declares `bool result` and line 66
/// begins the two-argument TensorEqualsFunc constructor.
49 template <
50typename Element,
51typename Layout>
52 struct TensorEqualsFunc {
53
54//
55// Data members
56//
57
58TensorView<Element, Layout> lhs;
59TensorView<Element, Layout> rhs;
61
/// Default constructor: result starts true (nothing has compared unequal yet).
63TensorEqualsFunc(): result(true) { }
64
/// Two-view constructor: initializes lhs/rhs from the given views; result starts true.
67TensorView<Element, Layout> const &lhs_,
68TensorView<Element, Layout> const &rhs_
 69 ) :
 70 lhs(lhs_), rhs(rhs_), result(true) { }
71
/// Compares the two views at one coordinate. `result` is only ever set to
/// false here, so a single mismatch makes the whole comparison fail.
73void operator()(Coord<Layout::kRank> const &coord) {
74
 75 Element lhs_ = lhs.at(coord);
 76 Element rhs_ = rhs.at(coord);
77
// Exact comparison via Element's operator!= (no tolerance is applied).
78if (lhs_ != rhs_) {
 79 result = false;
 80 }
 81 }
82
/// True iff no visited coordinate compared unequal.
84operator bool() const {
85return result;
 86 }
 87 };
88
89 } // namespace detail
90
92
/// Returns true if lhs and rhs have the same extent and no coordinate within
/// that extent has lhs.at(coord) != rhs.at(coord). Elements are compared
/// exactly with Element's operator!=.
94 template <
95typename Element,
96typename Layout>
97 bool TensorEquals(
98TensorView<Element, Layout> const &lhs,
99TensorView<Element, Layout> const &rhs) {
100
101// Extents must be identical
102if (lhs.extent() != rhs.extent()) {
103return false;
 104 }
105
// Visit every coordinate of the (shared) extent; the functor latches any mismatch.
106detail::TensorEqualsFunc<Element, Layout> func(lhs, rhs);
107TensorForEach(
 108 lhs.extent(),
 109 func
 110 );
111
112return bool(func);
 113 }
114
117
/// Logical negation of TensorEquals: returns true if the extents differ or if
/// some coordinate has lhs.at(coord) != rhs.at(coord).
119 template <
120typename Element,
121typename Layout>
122 bool TensorNotEquals(
123TensorView<Element, Layout> const &lhs,
124TensorView<Element, Layout> const &rhs) {
125
126// Views with different extents are never equal
127if (lhs.extent() != rhs.extent()) {
128return true;
 129 }
130
// Same elementwise check as TensorEquals; the result is inverted on return.
131detail::TensorEqualsFunc<Element, Layout> func(lhs, rhs);
132TensorForEach(
 133 lhs.extent(),
 134 func
 135 );
136
137return !bool(func);
 138 }
139
142
143 namespace detail {
144
/// Functor applied by TensorForEach (see TensorContains / TensorFind): tests
/// view.at(coord) == value at each visited coordinate. `contains` becomes true
/// on any match, and `location` holds the first matching coordinate in the
/// order TensorForEach visits them.
/// NOTE(review): this rendered listing omits some source lines. Per the member
/// index at the bottom of this page, line 155 declares `Element value`, line 156
/// declares `bool contains`, and line 167 begins the two-argument constructor.
145 template <
146typename Element,
147typename Layout>
148 struct TensorContainsFunc {
149
150//
151// Data members
152//
153
154TensorView<Element, Layout> view;
// Assigned only on the first match. The constructors shown do not initialize
// it, so until a match it holds a default-constructed Coord.
157Coord<Layout::kRank> location;
158
159//
160// Methods
161//
162
/// Default constructor: no match recorded yet.
164TensorContainsFunc(): contains(false) { }
165
/// Search constructor: look for `value_` in `view_`; no match recorded yet.
168TensorView<Element, Layout> const &view_,
 169 Element value_
 170 ) :
 171 view(view_), value(value_), contains(false) { }
172
/// Checks one coordinate; later matches do not overwrite `location`.
174void operator()(Coord<Layout::kRank> const &coord) {
175
176if (view.at(coord) == value) {
177if (!contains) {
 178 location = coord;
 179 }
 180 contains = true;
 181 }
 182 }
183
/// True iff at least one visited element compared equal to `value`.
185operator bool() const {
186return contains;
 187 }
 188 };
189
190 } // namespace detail
191
193
/// Returns true if at least one coordinate within view.extent() holds an
/// element that compares equal (operator==) to `value`.
195 template <
196typename Element,
197typename Layout>
198 bool TensorContains(
199TensorView<Element, Layout> const & view,
 200 Element value) {
201
202detail::TensorContainsFunc<Element, Layout> func(
 203 view,
 204 value
 205 );
206
// Scan the whole extent; the functor latches whether any element matched.
207TensorForEach(
 208 view.extent(),
 209 func
 210 );
211
212return bool(func);
 213 }
214
216
/// Searches `view` for `value` and returns {found, location}:
///   - found:    true if some element compares equal (operator==) to value;
///   - location: the first matching coordinate, in the order TensorForEach visits them.
/// When found is false, location is the functor's default-constructed Coord and
/// must not be treated as a match position.
220 template <
221typename Element,
222typename Layout>
223 std::pair<bool, Coord<Layout::kRank> > TensorFind(
224TensorView<Element, Layout> const & view,
 225 Element value) {
226
227detail::TensorContainsFunc<Element, Layout> func(
 228 view,
 229 value
 230 );
231
232TensorForEach(
 233 view.extent(),
 234 func
 235 );
236
237return std::make_pair(bool(func), func.location);
 238 }
239
242
243 } // namespace host
244 } // namespace reference
245 } // namespace cutlass
cutlass::reference::host::detail::TensorContainsFunc::TensorContainsFunc
TensorContainsFunc()
Ctor.
Definition: host/tensor_compare.h:164
Definition: aligned_buffer.h:35
cutlass::reference::host::detail::TensorContainsFunc
Functor that reports whether a value occurs in a tensor view and records the coordinate of its first occurrence.
Definition: host/tensor_compare.h:148
cutlass::reference::host::detail::TensorContainsFunc::TensorContainsFunc
TensorContainsFunc(TensorView< Element, Layout > const &view_, Element value_)
Ctor.
Definition: host/tensor_compare.h:167
CUTLASS_HOST_DEVICE TensorCoord const & extent() const
Returns the extent of the view (the size along each logical dimension).
Definition: tensor_view.h:167
CUTLASS_HOST_DEVICE std::pair< T1, T2 > make_pair(T1 t, T2 u)
Definition: platform.h:232
cutlass::reference::host::detail::TensorContainsFunc::operator()
void operator()(Coord< Layout::kRank > const &coord)
Visits a coordinate.
Definition: host/tensor_compare.h:174
cutlass::reference::host::TensorEquals
bool TensorEquals(TensorView< Element, Layout > const &lhs, TensorView< Element, Layout > const &rhs)
Returns true if two tensor views are equal.
Definition: host/tensor_compare.h:97
cutlass::reference::host::detail::TensorContainsFunc::location
Coord< Layout::kRank > location
Definition: host/tensor_compare.h:157
cutlass::reference::host::detail::TensorContainsFunc::view
TensorView< Element, Layout > view
Definition: host/tensor_compare.h:154
cutlass::reference::host::detail::TensorEqualsFunc::lhs
TensorView< Element, Layout > lhs
Definition: host/tensor_compare.h:58
cutlass::reference::host::detail::TensorEqualsFunc::operator()
void operator()(Coord< Layout::kRank > const &coord)
Visits a coordinate.
Definition: host/tensor_compare.h:73
cutlass::reference::host::TensorContains
bool TensorContains(TensorView< Element, Layout > const &view, Element value)
Returns true if a value is present in a tensor.
Definition: host/tensor_compare.h:198
cutlass::TensorView< Element, Layout >
cutlass::reference::host::detail::TensorContainsFunc::contains
bool contains
Definition: host/tensor_compare.h:156
cutlass::reference::host::detail::TensorEqualsFunc::TensorEqualsFunc
TensorEqualsFunc()
Ctor.
Definition: host/tensor_compare.h:63
This header contains a class to parametrize a statistical distribution function.
cutlass::reference::host::TensorNotEquals
bool TensorNotEquals(TensorView< Element, Layout > const &lhs, TensorView< Element, Layout > const &rhs)
Returns true if two tensor views are NOT equal.
Definition: host/tensor_compare.h:122
cutlass::Coord< Layout::kRank >
cutlass::reference::host::detail::TensorContainsFunc::value
Element value
Definition: host/tensor_compare.h:155
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
cutlass::reference::host::detail::TensorEqualsFunc::TensorEqualsFunc
TensorEqualsFunc(TensorView< Element, Layout > const &lhs_, TensorView< Element, Layout > const &rhs_)
Ctor.
Definition: host/tensor_compare.h:66
cutlass::reference::host::TensorForEach
void TensorForEach(Coord< Rank > extent, Func &func)
Iterates over the index space of a tensor.
Definition: host/tensor_foreach.h:87
cutlass::reference::host::detail::TensorEqualsFunc::result
bool result
Definition: host/tensor_compare.h:60
cutlass::reference::host::TensorFind
std::pair< bool, Coord< Layout::kRank > > TensorFind(TensorView< Element, Layout > const &view, Element value)
Returns whether a value is present in a tensor, together with the coordinate of its first occurrence.
Definition: host/tensor_compare.h:223
Basic include for CUTLASS.
cutlass::reference::host::detail::TensorEqualsFunc::rhs
TensorView< Element, Layout > rhs
Definition: host/tensor_compare.h:59
cutlass::reference::host::detail::TensorEqualsFunc
Functor that tests two tensor views for elementwise equality.
Definition: host/tensor_compare.h:52
Generated by 1.8.11