docs/tensor__view_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
tensor_view.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
37 #pragma once
38
39 #if !defined(__CUDACC_RTC__)
40 #include <cmath>
41 #endif
42
43 #include "cutlass/cutlass.h"
44 #include "cutlass/tensor_ref.h"
45
46 namespace cutlass {
47
49
50 template <
52typename Element_,
54typename Layout_
55 >
56 class TensorView : public TensorRef<Element_, Layout_> {
57public:
58
60using Base = cutlass::TensorRef<Element_, Layout_>;
61
64
66using ConstTensorRef = typename Base::ConstTensorRef;
67
70
73
75using Reference = Element &;
76
78static int const kRank = Layout::kRank;
79
81using Index = typename Layout::Index;
82
84using LongIndex = typename Layout::LongIndex;
85
87using TensorCoord = typename Layout::TensorCoord;
88
90using Stride = typename Layout::Stride;
91
93using ConstTensorView = TensorView<
94typename platform::remove_const<Element>::type const,
96
98using NonConstTensorView = TensorView<
99typename platform::remove_const<Element>::type,
101
105static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
106
107private:
108
110TensorCoord extent_;
111
112public:
113
114//
115// Methods
116//
117
120TensorView(TensorCoord const &extent = TensorCoord()): extent_(extent) {
121
122 }
123
127Element *ptr,
129TensorCoord const &extent
130 ):
131Base(ptr, layout), extent_(extent) {
132
133 }
134
139TensorCoord const &extent
140 ):
141Base(ref), extent_(extent) {
142
143 }
144
148NonConstTensorView const &view
149 ):
150Base(view), extent_(view.extent_) { }
151
154void reset(Element* ptr, Layout const &layout, TensorCoord size) {
155Base::reset(ptr, layout);
156 this->resize(extent_);
157 }
158
161void resize(TensorCoord extent) {
162 this->extent_ = extent;
163 }
164
167TensorCoord const& extent() const { return extent_; }
168
171Index extent(int dim) const { return extent_.at(dim); }
172
175bool contains(TensorCoord const& coord) const {
177for (int dim = 0; dim < kRank; ++dim) {
178if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) {
179return false;
180 }
181 }
182return true;
183 }
184
188return TensorRef(this->data(), this->layout());
189 }
190
193ConstTensorRef const_ref() const {
194return ConstTensorRef(this->data(), this->layout());
195 }
196
199ConstTensorView const_view() const {
200return ConstTensorView(const_ref(), extent_);
201 }
202
206TensorCoord extent,
207TensorCoord const& location = TensorCoord()
208 ) const {
209
210return TensorView(ref(), extent.clamp(extent_ - location)).add_coord_offset(location);
211 }
212
216return Base::layout().capacity(extent_);
217 }
218
222TensorCoord const& b
223 ) const {
224
225TensorView result(*this);
226 result.add_pointer_offset(this->offset(b));
227return result;
228 }
229
233TensorCoord const& b
234 ) {
235
236 this->add_pointer_offset(this->offset(b));
237return *this;
238 }
239
243TensorCoord const& b
244 ) const {
245
246TensorRef result(*this);
247 result.add_pointer_offset(-this->offset(b));
248return result;
249 }
250
254TensorCoord const& b
255 ) {
256
257 this->add_pointer_offset(-this->offset(b));
258return *this;
259 }
260 };
261
263
265 template <
266typename Element,
267typename Layout
268 >
269 CUTLASS_HOST_DEVICE TensorView<Element, Layout> make_TensorView(
270Element *ptr,
272typename Layout::TensorCoord const &extent) {
273
274return TensorView<Element, Layout>(ptr, layout, extent);
275 }
276
278
279 } // namespace cutlass
cutlass::TensorView< Element, Layout >::operator+=
CUTLASS_HOST_DEVICE TensorView & operator+=(TensorCoord const &b)
Offsets this TensorView in place by a given amount and returns a reference to it.
Definition: tensor_view.h:232
Definition: aligned_buffer.h:35
cutlass::TensorView< Element, Layout >::capacity
CUTLASS_HOST_DEVICE size_t capacity() const
Returns the number of scalar elements needed to store the tensor.
Definition: tensor_view.h:215
Defines a structure containing strides, bounds, and a pointer to tensor data.
cutlass::platform::remove_const::type
T type
Definition: platform.h:351
CUTLASS_HOST_DEVICE Element * data() const
Returns the pointer to referenced data.
Definition: tensor_ref.h:254
CUTLASS_HOST_DEVICE TensorCoord const & extent() const
Returns the extent of the view (the size along each logical dimension).
Definition: tensor_view.h:167
static int const kRank
Logical rank of tensor index space.
Definition: tensor_view.h:78
CUTLASS_HOST_DEVICE void resize(TensorCoord extent)
Changes the size of the view without affecting pointer or layout.
Definition: tensor_view.h:161
cutlass::TensorView< Element, Layout >::operator+
CUTLASS_HOST_DEVICE TensorView operator+(TensorCoord const &b) const
Returns a TensorView offset by a given amount.
Definition: tensor_view.h:221
cutlass::TensorView< Element, Layout >::operator-=
CUTLASS_HOST_DEVICE TensorView & operator-=(TensorCoord const &b)
Offsets this TensorView in place by the negation of a given amount and returns a reference to it.
Definition: tensor_view.h:253
cutlass::TensorView< Element, Layout >::operator-
CUTLASS_HOST_DEVICE TensorView operator-(TensorCoord const &b) const
Returns a TensorView offset by the negation of a given amount.
Definition: tensor_view.h:242
cutlass::TensorView::TensorRef
Base TensorRef
Underlying TensorRef type.
Definition: tensor_view.h:69
cutlass::TensorRef::ConstTensorRef
TensorRef< typename platform::remove_const< Element >::type const, Layout > ConstTensorRef
TensorRef to constant data.
Definition: tensor_ref.h:179
cutlass::TensorRef::add_coord_offset
CUTLASS_HOST_DEVICE TensorRef & add_coord_offset(TensorCoord const &coord)
Adds an offset to each pointer.
Definition: tensor_ref.h:326
cutlass::TensorView< Element, Layout >::Element
Element Element
Data type of individual access.
Definition: tensor_view.h:72
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
cutlass::TensorView::ConstTensorView
TensorView< typename platform::remove_const< Element >::type const, Layout > ConstTensorView
TensorView pointing to constant memory.
Definition: tensor_view.h:95
Definition: tensor_view.h:56
cutlass::TensorView< Element, Layout >::reset
CUTLASS_HOST_DEVICE void reset(Element *ptr, Layout const &layout, TensorCoord size)
Updates the pointer and layout object.
Definition: tensor_view.h:154
cutlass::TensorView< Element, Layout >::TensorCoord
typename Layout::TensorCoord TensorCoord
Coordinate in logical tensor space.
Definition: tensor_view.h:87
cutlass::TensorView< Element, Layout >::Reference
Element & Reference
Reference type to an element.
Definition: tensor_view.h:75
CUTLASS_HOST_DEVICE void reset(Element *ptr=nullptr)
Updates only the pointer.
Definition: tensor_ref.h:235
Definition: tensor_ref.h:146
CUTLASS_HOST_DEVICE TensorRef ref() const
Returns a TensorRef pointing to the first element of the tensor.
Definition: tensor_view.h:187
cutlass::TensorView< Element, Layout >::Stride
typename Layout::Stride Stride
Coordinate in storage n-D array.
Definition: tensor_view.h:90
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE LongIndex offset(TensorCoord const &coord) const
Computes the offset of an index from the origin of the tensor.
Definition: tensor_ref.h:301
#define static_assert(__e, __m)
Definition: platform.h:153
cutlass::TensorView< Element, Layout >::extent
CUTLASS_HOST_DEVICE Index extent(int dim) const
Returns the extent along a particular logical dimension.
Definition: tensor_view.h:171
cutlass::TensorView< Element, Layout >::subview
CUTLASS_HOST_DEVICE TensorView subview(TensorCoord extent, TensorCoord const &location=TensorCoord()) const
Returns a TensorView of the requested extent starting at the given location, with the extent clamped to the bounds of this view.
Definition: tensor_view.h:205
cutlass::TensorView< Element, Layout >::TensorView
CUTLASS_HOST_DEVICE TensorView(TensorRef const &ref, TensorCoord const &extent)
Constructs a TensorView object.
Definition: tensor_view.h:137
cutlass::TensorView< Element, Layout >::TensorView
CUTLASS_HOST_DEVICE TensorView(NonConstTensorView const &view)
Converting constructor from a TensorView to non-constant data.
Definition: tensor_view.h:147
cutlass::TensorView< Element, Layout >::Index
typename Layout::Index Index
Index type.
Definition: tensor_view.h:81
cutlass::TensorRef< Element_, Layout_ > Base
Base tensor reference.
Definition: tensor_view.h:60
cutlass::TensorView< Element, Layout >::const_view
CUTLASS_HOST_DEVICE ConstTensorView const_view() const
Returns a TensorView to const data.
Definition: tensor_view.h:199
cutlass::TensorView< Element, Layout >::ConstTensorRef
typename Base::ConstTensorRef ConstTensorRef
TensorRef pointing to constant memory.
Definition: tensor_view.h:66
CUTLASS_HOST_DEVICE Layout & layout()
Returns the layout object.
Definition: tensor_ref.h:265
cutlass::TensorView< Element, Layout >::TensorView
CUTLASS_HOST_DEVICE TensorView(Element *ptr, Layout const &layout, TensorCoord const &extent)
Constructs a TensorView object.
Definition: tensor_view.h:126
cutlass::TensorView< Element, Layout >::TensorView
CUTLASS_HOST_DEVICE TensorView(TensorCoord const &extent=TensorCoord())
Constructs a TensorView object.
Definition: tensor_view.h:120
CUTLASS_HOST_DEVICE TensorView< Element, Layout > make_TensorView(Element *ptr, Layout const &layout, typename Layout::TensorCoord const &extent)
Constructs a TensorView, deducing types from arguments.
Definition: tensor_view.h:269
cutlass::TensorView< Element, Layout >::contains
CUTLASS_HOST_DEVICE bool contains(TensorCoord const &coord) const
Determines whether a location is within a tensor.
Definition: tensor_view.h:175
cutlass::TensorRef::add_pointer_offset
CUTLASS_HOST_DEVICE TensorRef & add_pointer_offset(LongIndex offset_)
Adds an offset to each pointer.
Definition: tensor_ref.h:319
Basic include for CUTLASS.
cutlass::TensorView< Element, Layout >::const_ref
CUTLASS_HOST_DEVICE ConstTensorRef const_ref() const
Returns a TensorRef to const data pointing to the first element of the tensor.
Definition: tensor_view.h:193
cutlass::TensorView< Element, Layout >::LongIndex
typename Layout::LongIndex LongIndex
Long index used for pointer offsets.
Definition: tensor_view.h:84
cutlass::TensorView< Element, Layout >::Layout
Layout Layout
Mapping function from logical coordinate to internal n-D array.
Definition: tensor_view.h:63
Generated by 1.8.11