docs/tensor__view__io_8h_source.html
CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
tensor_view_io.h
[Go to the documentation of this file.](tensor__view__io_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
25 #pragma once
26
27 #include "cutlass/core_io.h"
28 #include "cutlass/tensor_view.h"
29
30 namespace cutlass {
31
33
34 namespace detail {
35
37 template <
38typename Element,
39typename Layout
40 >
41 inline std::ostream & TensorView_WriteLeastSignificantRank(
42 std::ostream& out,
43TensorView<Element, Layout> const& view,
44Coord<Layout::kRank> const &start_coord,
45int rank,
46 std::streamsize width) {
47
48for (int idx = 0; idx < view.extent(rank); ++idx) {
49
50Coord<Layout::kRank> coord(start_coord);
51 coord[rank] = idx;
52
53if (idx) {
54 out.width(0);
55 out << ", ";
56 }
57if (idx || coord) {
58 out.width(width);
59 }
60 out << ScalarIO<Element>(view.at(coord));
61 }
62
63return out;
64 }
65
67 template <
68typename Element,
69typename Layout
70 >
71 inline std::ostream & TensorView_WriteRank(
72 std::ostream& out,
73TensorView<Element, Layout> const& view,
74Coord<Layout::kRank> const &start_coord,
75int rank,
76 std::streamsize width) {
77
78// If called on the least significant rank, write the result as a row
79if (rank + 1 == Layout::kRank) {
80return TensorView_WriteLeastSignificantRank(out, view, start_coord, rank, width);
81 }
82
83// Otherwise, write a sequence of rows and newlines
84for (int idx = 0; idx < view.extent(rank); ++idx) {
85
86Coord<Layout::kRank> coord(start_coord);
87 coord[rank] = idx;
88
89if (rank + 2 == Layout::kRank) {
90// Write least significant ranks asa matrix with rows delimited by ";\n"
91 out << (idx ? ";\n" : "");
92TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width);
93 }
94else {
95// Higher ranks are separated by newlines
96 out << (idx ? "\n" : "");
97TensorView_WriteRank(out, view, coord, rank + 1, width);
98 }
99 }
100
101return out;
102 }
103
104 } // namespace detail
105
107
109 template <
110typename Element,
111typename Layout
112 >
113 inline std::ostream& TensorViewWrite(
114 std::ostream& out,
115TensorView<Element, Layout> const& view) {
116
117// Prints a TensorView according to the following conventions:
118// - least significant rank is printed as rows separated by ";\n"
119// - all greater ranks are delimited with newlines
120//
121// The result is effectively a whitespace-delimited series of 2D matrices.
122
123return detail::TensorView_WriteRank(out, view, Coord<Layout::kRank>(), 0, out.width());
124 }
125
127 template <
128typename Element,
129typename Layout
130 >
131 inline std::ostream& operator<<(
132 std::ostream& out,
133TensorView<Element, Layout> const& view) {
134
135// Prints a TensorView according to the following conventions:
136// - least significant rank is printed as rows separated by ";\n"
137// - all greater ranks are delimited with newlines
138//
139// The result is effectively a whitespace-delimited series of 2D matrices.
140
141return TensorViewWrite(out, view);
142 }
143
145
146 } // namespace cutlass
Helpers for printing cutlass/core objects.
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE TensorCoord const & extent() const
Returns the extent of the view (the size along each logical dimension).
Definition: tensor_view.h:167
Defines a structure containing strides and a pointer to tensor data.
std::ostream & TensorViewWrite(std::ostream &out, TensorView< Element, Layout > const &view)
Prints human-readable representation of a TensorView to an ostream.
Definition: tensor_view_io.h:113
cutlass::detail::TensorView_WriteLeastSignificantRank
std::ostream & TensorView_WriteLeastSignificantRank(std::ostream &out, TensorView< Element, Layout > const &view, Coord< Layout::kRank > const &start_coord, int rank, std::streamsize width)
Helper to write the least significant rank of a TensorView.
Definition: tensor_view_io.h:41
cutlass::TensorView< Element, Layout >
cutlass::Coord< Layout::kRank >
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
std::ostream & operator<<(std::ostream &out, complex< T > const &z)
Definition: complex.h:291
cutlass::detail::TensorView_WriteRank
std::ostream & TensorView_WriteRank(std::ostream &out, TensorView< Element, Layout > const &view, Coord< Layout::kRank > const &start_coord, int rank, std::streamsize width)
Helper to write a rank of a TensorView.
Definition: tensor_view_io.h:71
Generated by Doxygen 1.8.11