Back to Cutlass

CUTLASS: epilogue_base.h Source File

docs/epilogue__base_8h_source.html

4.4.222.4 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

epilogue_base.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

33 #pragma once

34

35 #include <assert.h>

36

37 #include "cutlass/cutlass.h"

38 #include "cutlass/matrix_shape.h"

39 #include "cutlass/numeric_types.h"

40 #include "cutlass/array.h"

41 #include "cutlass/layout/vector.h"

42 #include "cutlass/layout/tensor.h"

43 #include "cutlass/tensor_coord.h"

44 #include "cutlass/aligned_buffer.h"

45

46 #include "cutlass/gemm/gemm.h"

47

48 #include "cutlass/transform/pitch_linear_thread_map.h"

49

51

52 namespace cutlass {

53 namespace epilogue {

54 namespace threadblock {

55

57

/// Base class for epilogues defining warp-level.
///
/// Holds a reference to the epilogue's shared storage plus a WarpTileIterator over it; the
/// constructor offsets that iterator by the calling warp's tile coordinate.
59 template <

// Tile shape; kM and kN are divided by WarpMmaOperator::Shape to obtain WarpCount.
60typename Shape_,

// Warp-level operator type; only its Shape::kM / Shape::kN are read in this class.
61typename WarpMmaOperator_,

// Exposed as kPartitionsK and used as the K extent of WarpCount.
62int PartitionsK,

// Supplies AccumulatorTile (and through it ElementAccumulator).
63typename AccumulatorFragmentIterator_,

// Supplies Element, TensorRef, Layout and Shape for the shared storage; type of warp_tile_iterator_.
64typename WarpTileIterator_,

// kRow / kColumn are added to SharedStorage::Shape to form SharedStorage::StorageShape.
65typename Padding_

66 >

67 class EpilogueBase {

68 public:

69

// Re-export the template arguments under unsuffixed names.
70using Shape = Shape_;

71using WarpMmaOperator = WarpMmaOperator_;

72static int const kPartitionsK = PartitionsK;

73using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;

74using WarpTileIterator = WarpTileIterator_;

75using Padding = Padding_;

76

// Fixed to row-major. Note that SharedStorage declares its own Layout taken from WarpTileIterator.
78using Layout = layout::RowMajor;

79

/// The complete warp-level accumulator tile.
81using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;

82

/// Accumulator element.
84using ElementAccumulator = typename AccumulatorTile::Element;

85

86

// Warp grid: Shape divided by WarpMmaOperator::Shape along M and N, and kPartitionsK along K.
88using WarpCount = gemm::GemmShape<

89 Shape::kM / WarpMmaOperator::Shape::kM,

90 Shape::kN / WarpMmaOperator::Shape::kN,

91 kPartitionsK

92 >;

93

94 public:

95

/// Shared storage allocation needed by the epilogue.
97struct SharedStorage {

98

99//

100// Type definitions

101//

102

/// Element type of shared memory.
104using Element = typename WarpTileIterator::Element;

105

/// Tensor reference to shared memory allocation.
107using TensorRef = typename WarpTileIterator::TensorRef;

108

/// Layout of shared memory allocation.
110using Layout = typename WarpTileIterator::Layout;

111

// Unpadded extent: one WarpTileIterator::Shape tile per warp, with the K partitions stacked
// along the row dimension (rows scale with WarpCount::kM * WarpCount::kK).
113using Shape = MatrixShape<

114WarpCount::kM * WarpTileIterator::Shape::kRow * WarpCount::kK,

115WarpCount::kN * WarpTileIterator::Shape::kColumn

116 >;

117

// Allocated extent: Shape enlarged by Padding in both dimensions.
119using StorageShape = MatrixShape<

120Shape::kRow + Padding::kRow,

121Shape::kColumn + Padding::kColumn

122 >;

123

124//

125// Data members

126//

127

// Backing buffer of StorageShape::kCount elements (padding included).
128AlignedBuffer<Element, StorageShape::kCount> storage;

129

130//

131// Methods

132//

133

/// Returns a pointer to the shared memory buffer.
135 CUTLASS_DEVICE

136Element *data() {

137return storage.data();

138 }

139

/// Returns a tensor reference to the shared memory buffer.
/// The layout is Layout::packed over the padded StorageShape extent.
141 CUTLASS_DEVICE

142TensorRef reference() {

143return TensorRef(

144 storage.data(),

145Layout::packed({StorageShape::kRow, StorageShape::kColumn}));

146 }

147

// Debug helper: the thread with threadIdx.x == 0 prints the Shape::kRow x Shape::kColumn
// region, each element converted to int, one row per output line. Padding is not printed.
// The trailing __syncthreads() is executed by every calling thread, not just the printer.
// NOTE(review): presumably every thread of the block must call this for the barrier to
// complete -- confirm against callers.
148 CUTLASS_DEVICE

149void debug_print() {

150if (threadIdx.x == 0) {

151

152 #pragma unroll 1

153for (int r = 0; r < Shape::kRow; ++r) {

154

155 #pragma unroll 1

156for (int c = 0; c < Shape::kColumn; ++c) {

157

// Row pitch is the padded column count, so padding columns are skipped over.
// NOTE(review): this linear indexing assumes a row-major arrangement -- confirm it
// matches WarpTileIterator::Layout.
158 printf("%d ", int(storage.data()[r * StorageShape::kColumn + c]));

159 }

160 printf("\n");

161 }

162 }

163 __syncthreads();

164 }

165 };

166

167 protected:

168

169//

170// Data members

171//

172

// Reference to the caller-provided SharedStorage; not owned by this object.
173SharedStorage &shared_storage_;

174

/// Stores a warp's fragment of accumulators to SMEM.
176WarpTileIterator warp_tile_iterator_;

177

178 public:

179

/// Constructor.
///
/// shared_storage: storage the iterator is built over (bound by reference).
/// thread_idx:     not read anywhere in this constructor.
/// warp_idx:       linear warp index, decomposed below into (m, n, k) over WarpCount.
/// lane_idx:       forwarded to the WarpTileIterator constructor.
181 CUTLASS_DEVICE

182EpilogueBase(

183SharedStorage &shared_storage,

184int thread_idx,

185int warp_idx,

186int lane_idx

187 ):

188 shared_storage_(shared_storage),

189 warp_tile_iterator_(shared_storage.reference(), lane_idx) {

190

191// Compute warp location within threadblock tile by mapping the warp_id to three coordinates:

192//

193// _m: the warp's position within the threadblock along the M dimension

194// _n: the warp's position within the threadblock along the N dimension

195// _k: the warp's position within the threadblock along the K dimension

196

// warp_idx is decomposed with M varying fastest, then N, then K.
197int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN);

198int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);

199int warp_m = warp_mn % WarpCount::kM;

200int warp_n = warp_mn / WarpCount::kM;

201

// Tile coordinate: each K partition is placed WarpCount::kM tile-rows below the previous one,
// consistent with SharedStorage::Shape stacking K partitions along rows.
202MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n};

203

204 warp_tile_iterator_.add_tile_offset(warp_offset);

205 }

206 };

207

209

210 } // namespace threadblock

211 } // namespace epilogue

212 } // namespace cutlass

213

cutlass::gemm::GemmShape::kM

static int const kM

Definition: include/cutlass/gemm/gemm.h:58

cutlass::MatrixShape

Describes the size of a matrix tile.

Definition: matrix_shape.h:42

cutlass

Definition: aligned_buffer.h:35

cutlass::epilogue::threadblock::EpilogueBase::SharedStorage::debug_print

CUTLASS_DEVICE void debug_print()

Definition: epilogue_base.h:149

cutlass::MatrixShape::kColumn

static int const kColumn

columns of a matrix

Definition: matrix_shape.h:44

cutlass::epilogue::threadblock::EpilogueBase::warp_tile_iterator_

WarpTileIterator warp_tile_iterator_

Stores a warp's fragment of accumulators to SMEM.

Definition: epilogue_base.h:176

cutlass::epilogue::threadblock::EpilogueBase::shared_storage_

SharedStorage & shared_storage_

Definition: epilogue_base.h:173

[pitch_linear_thread_map.h](pitch__linear__thread__map_8h.html)

Templates implementing how threads are mapped to a given tile.

cutlass::epilogue::threadblock::EpilogueBase::WarpMmaOperator

WarpMmaOperator_ WarpMmaOperator

Definition: epilogue_base.h:71

cutlass::epilogue::threadblock::EpilogueBase::SharedStorage

Shared storage allocation needed by the epilogue.

Definition: epilogue_base.h:97

cutlass::epilogue::threadblock::EpilogueBase::SharedStorage::data

CUTLASS_DEVICE Element * data()

Returns a pointer to the shared memory buffer.

Definition: epilogue_base.h:136

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::epilogue::threadblock::EpilogueBase::ElementAccumulator

typename AccumulatorTile::Element ElementAccumulator

Accumulator element.

Definition: epilogue_base.h:84

cutlass::epilogue::threadblock::EpilogueBase::SharedStorage::TensorRef

typename WarpTileIterator::TensorRef TensorRef

Tensor reference to shared memory allocation.

Definition: epilogue_base.h:107

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

cutlass::gemm::GemmShape::kK

static int const kK

Definition: include/cutlass/gemm/gemm.h:60

tensor.h

Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...

matrix_shape.h

Defines a Shape template for matrix tiles.

cutlass::epilogue::threadblock::EpilogueBase::kPartitionsK

static int const kPartitionsK

Definition: epilogue_base.h:72

cutlass::epilogue::threadblock::EpilogueBase::SharedStorage::Element

typename WarpTileIterator::Element Element

Element type of shared memory.

Definition: epilogue_base.h:104

tensor_coord.h

Defines a canonical coordinate for rank=4 tensors offering named indices.

cutlass::epilogue::threadblock::EpilogueBase::SharedStorage::storage

AlignedBuffer< Element, StorageShape::kCount > storage

Definition: epilogue_base.h:128

cutlass::MatrixShape::kRow

static int const kRow

rows of a matrix

Definition: matrix_shape.h:43

aligned_buffer.h

AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::AlignedBuffer

Modifies semantics of cutlass::Array<> to provide guaranteed alignment.

Definition: aligned_buffer.h:45

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

cutlass::AlignedBuffer::data

CUTLASS_HOST_DEVICE pointer data()

Definition: aligned_buffer.h:84

cutlass::epilogue::threadblock::EpilogueBase::SharedStorage::Layout

typename WarpTileIterator::Layout Layout

Layout of shared memory allocation.

Definition: epilogue_base.h:110

cutlass::epilogue::threadblock::EpilogueBase::AccumulatorFragmentIterator

AccumulatorFragmentIterator_ AccumulatorFragmentIterator

Definition: epilogue_base.h:73

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

vector.h

Defines layout functions used for rank=1 vectors.

cutlass::epilogue::threadblock::EpilogueBase::Shape

Shape_ Shape

Definition: epilogue_base.h:70

cutlass::epilogue::threadblock::EpilogueBase

Base class for epilogues defining warp-level.

Definition: epilogue_base.h:67

cutlass::layout::RowMajor::packed

static CUTLASS_HOST_DEVICE RowMajor packed(MatrixCoord const &extent)

Helper returns a layout to a tightly packed tensor.

Definition: layout/matrix.h:93

cutlass::epilogue::threadblock::EpilogueBase::Padding

Padding_ Padding

Definition: epilogue_base.h:75

cutlass::epilogue::threadblock::EpilogueBase::EpilogueBase

CUTLASS_DEVICE EpilogueBase(SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)

Constructor.

Definition: epilogue_base.h:182

cutlass::epilogue::threadblock::EpilogueBase::WarpTileIterator

WarpTileIterator_ WarpTileIterator

Definition: epilogue_base.h:74

cutlass::epilogue::threadblock::EpilogueBase::AccumulatorTile

typename AccumulatorFragmentIterator::AccumulatorTile AccumulatorTile

The complete warp-level accumulator tile.

Definition: epilogue_base.h:81

cutlass.h

Basic include for CUTLASS.

cutlass::MatrixCoord

Definition: matrix_coord.h:39

cutlass::epilogue::threadblock::EpilogueBase::SharedStorage::reference

CUTLASS_DEVICE TensorRef reference()

Returns a tensor reference to the shared memory buffer.

Definition: epilogue_base.h:142

cutlass::gemm::GemmShape::kN

static int const kN

Definition: include/cutlass/gemm/gemm.h:59


Generated by 1.8.11