Back to Cutlass

CUTLASS: tensor_view_io.h Source File

docs/tensor__view__io_8h_source.html

4.4.29.3 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

tensor_view_io.h

[Go to the documentation of this file.](tensor__view__io_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

25 #pragma once

26

27 #include "cutlass/core_io.h"

28 #include "cutlass/tensor_view.h"

29

30 namespace cutlass {

31

33

34 namespace detail {

35

37 template <

38typename Element,

39typename Layout

40 >

41 inline std::ostream & TensorView_WriteLeastSignificantRank(

42 std::ostream& out,

43TensorView<Element, Layout> const& view,

44Coord<Layout::kRank> const &start_coord,

45int rank,

46 std::streamsize width) {

47

48for (int idx = 0; idx < view.extent(rank); ++idx) {

49

50Coord<Layout::kRank> coord(start_coord);

51 coord[rank] = idx;

52

53if (idx) {

54 out.width(0);

55 out << ", ";

56 }

57if (idx || coord) {

58 out.width(width);

59 }

60 out << ScalarIO<Element>(view.at(coord));

61 }

62

63return out;

64 }

65

67 template <

68typename Element,

69typename Layout

70 >

71 inline std::ostream & TensorView_WriteRank(

72 std::ostream& out,

73TensorView<Element, Layout> const& view,

74Coord<Layout::kRank> const &start_coord,

75int rank,

76 std::streamsize width) {

77

78// If called on the least significant rank, write the result as a row

79if (rank + 1 == Layout::kRank) {

80return TensorView_WriteLeastSignificantRank(out, view, start_coord, rank, width);

81 }

82

83// Otherwise, write a sequence of rows and newlines

84for (int idx = 0; idx < view.extent(rank); ++idx) {

85

86Coord<Layout::kRank> coord(start_coord);

87 coord[rank] = idx;

88

89if (rank + 2 == Layout::kRank) {

90// Write least significant ranks asa matrix with rows delimited by ";\n"

91 out << (idx ? ";\n" : "");

92TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width);

93 }

94else {

95// Higher ranks are separated by newlines

96 out << (idx ? "\n" : "");

97TensorView_WriteRank(out, view, coord, rank + 1, width);

98 }

99 }

100

101return out;

102 }

103

104 } // namespace detail

105

107

109 template <

110typename Element,

111typename Layout

112 >

113 inline std::ostream& TensorViewWrite(

114 std::ostream& out,

115TensorView<Element, Layout> const& view) {

116

117// Prints a TensorView according to the following conventions:

118// - least significant rank is printed as rows separated by ";\n"

119// - all greater ranks are delimited with newlines

120//

121// The result is effectively a whitespace-delimited series of 2D matrices.

122

123return detail::TensorView_WriteRank(out, view, Coord<Layout::kRank>(), 0, out.width());

124 }

125

127 template <

128typename Element,

129typename Layout

130 >

131 inline std::ostream& operator<<(

132 std::ostream& out,

133TensorView<Element, Layout> const& view) {

134

135// Prints a TensorView according to the following conventions:

136// - least significant rank is printed as rows separated by ";\n"

137// - all greater ranks are delimited with newlines

138//

139// The result is effectively a whitespace-delimited series of 2D matrices.

140

141return TensorViewWrite(out, view);

142 }

143

145

146 } // namespace cutlass

core_io.h

Helpers for printing cutlass/core objects.

cutlass

Definition: aligned_buffer.h:35

cutlass::TensorView::extent

CUTLASS_HOST_DEVICE TensorCoord const & extent() const

Returns the extent of the view (the size along each logical dimension).

Definition: tensor_view.h:167

tensor_view.h

Defines a structure containing strides and a pointer to tensor data.

cutlass::TensorViewWrite

std::ostream & TensorViewWrite(std::ostream &out, TensorView< Element, Layout > const &view)

Prints human-readable representation of a TensorView to an ostream.

Definition: tensor_view_io.h:113

cutlass::detail::TensorView_WriteLeastSignificantRank

std::ostream & TensorView_WriteLeastSignificantRank(std::ostream &out, TensorView< Element, Layout > const &view, Coord< Layout::kRank > const &start_coord, int rank, std::streamsize width)

Helper to write the least significant rank of a TensorView.

Definition: tensor_view_io.h:41

cutlass::TensorView< Element, Layout >

cutlass::Coord< Layout::kRank >

cutlass::TensorRef::at

CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const

Returns a reference to the element at a given Coord.

Definition: tensor_ref.h:307

cutlass::operator<<

std::ostream & operator<<(std::ostream &out, complex< T > const &z)

Definition: complex.h:291

cutlass::detail::TensorView_WriteRank

std::ostream & TensorView_WriteRank(std::ostream &out, TensorView< Element, Layout > const &view, Coord< Layout::kRank > const &start_coord, int rank, std::streamsize width)

Helper to write a rank of a TensorView.

Definition: tensor_view_io.h:71


Generated by 1.8.11