Back to Cutlass

CUTLASS: direct_epilogue_tensor_op.h Source File

docs/direct__epilogue__tensor__op_8h_source.html

4.4.224.3 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

direct_epilogue_tensor_op.h

[Go to the documentation of this file.](direct__epilogue__tensor__op_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

30 #pragma once

31

32 #include "cutlass/cutlass.h"

33 #include "cutlass/numeric_types.h"

34 #include "cutlass/array.h"

35

36 #include "cutlass/gemm/gemm.h"

37

39

40 namespace cutlass {

41 namespace epilogue {

42 namespace threadblock {

43

45

47 template <

48typename Shape_,

49typename Operator_,

50int PartitionsK,

51typename Element_,

52typename OutputOp_,

53typename ConvertOp_

54 >

55 class DirectEpilogueTensorOp {

56 public:

57

58using Shape = Shape_;

59using Operator = Operator_;

60

62using WarpCount = gemm::GemmShape<

63 Shape::kM / Operator::Shape::kM,

64 Shape::kN / Operator::Shape::kN,

65 PartitionsK,

66 >;

67

68static_assert(PartitionsK == 1,

69"Direct epilogue cannot be used with when the threadblock tile is partitioned along the K dimension.");

70

72using FragmentC = typename Operator::FragmentC;

73

75using Element = Element_;

76

78using Layout = layout::RowMajor;

79

81using OutputOp = OutputOp_;

82

84using ConvertOp = ConvertOp_;

85

87using TensorRef = TensorRef<Element, Layout::kRank, Layout>;

88

89 public:

90

92struct Params {

93

94//

95// Data members

96//

97

98TensorRef destination_ref;

99TensorRef source_ref;

100

101typename OutputOp::Params output_op;

102typename ConvertOp::Params convert_op;

103

104//

105// Methods

106//

107

109CUTLASS_HOST_DEVICE

110Params(

111TensorRef destination_ref_,

112TensorRef source_ref_,

113typename OutputOp::Params output_op_,

114typename ConvertOp::Params convert_op_

115 ):

116 destination_ref(destination_ref_),

117 source_ref(source_ref_),

118 output_op(output_op_),

119 convert_op(convert_op_) {

120

121 }

122

124CUTLASS_HOST_DEVICE

125Params(

126TensorRef destination_ref_,

127TensorRef source_ref_,

128typename OutputOp::Params output_op_

129 ):

130Params(

131 destination_ref,

132 source_ref,

133 output_op,

134ConvertOp::Params()

135 ) { }

136 };

137

139struct SharedStorage { };

140

141 private:

142

143OutputOp output_op;

144ConvertOp convert_op;

145

146TensorRef destination_ref_;

147TensorRef source_ref_;

148

149MatrixCoord warp_origin_;

150

151 public:

152

154 CUTLASS_DEVICE

155DirectEpilogueTensorOp(

156Params const &params,

157SharedStorage &shared_storage,

158int thread_idx,

159int warp_idx,

160int lane_idx

161 ):

162 output_op(params.output_op),

163 convert_op(params.convert_op),

164 destination_ref_(params.destination_ref),

165 source_ref_(params.source_ref) {

166

167

168// Compute warp location within threadblock tile by mapping the warp_id to three coordinates:

169//

170// _m: the warp's position within the threadblock along the M dimension

171// _n: the warp's position within the threadblock along the N dimension

172// _k: the warp's position within the threadblock along the K dimension

173

174int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);

175int warp_m = warp_mn % WarpCount::kM;

176int warp_n = warp_mn / WarpCount::kM;

177

178 warp_origin_ = MatrixCoord{

179 warp_m * Operator::Shape::kM,

180 warp_n * Operator::Shape::kN

181 };

182

183 destination_ref_.add_coord_offset(warp_origin_);

184 source_ref_.add_coord_offset(warp_origin_);

185 }

186

188 CUTLASS_DEVICE

189void operator()(

190gemm::GemmCoord problem_size,

191gemm::GemmCoord tb_tile_coord,

192FragmentC const &accumulators) {

193

194MatrixCoord thread_origin =

195MatrixCoord{tb_tile_coord.m() * Shape::kM, tb_tile_coord.n() * Shape::kN} + warp_origin_;

196

198using MmaIterations = MatrixShape<

199 Operator::Shape::kM / Operator::Policy::Operator::Shape::kM,

200 Operator::Shape::kN / Operator::Policy::Operator::Shape::kN

201 >;

202

203// Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire

204// shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements

205// of that row. The accumulators within one row are assumed to be consecutive.

206int const kElementsPerAccess = Operator::Policy::Operator::Shape::kN / 4;

207int const kRowsPerTile = 8;

208int const kAccumulatorRows = Operator::Policy::Operator::Shape::kM / kRowsPerTile;

209

210CUTLASS_PRAGMA_UNROLL

211for (int mma_n = 0; mma_n < MmaIterations::kN; ++mma_n) {

212CUTLASS_PRAGMA_UNROLL

213for (int mma_m = 0; mma_m < MmaIterations::kM; ++mma_m) {

214

215int mma_accum_start = kAccumulatorRows * kElementsPerAccess *

216 (mma_m * MmaIterations::kN + mma_n);

217

218CUTLASS_PRAGMA_UNROLL

219for (int row = 0; row < kAccumulatorRows; ++row) {

220CUTLASS_PRAGMA_UNROLL

221for (int col = 0; col < kElementsPerAccess; ++col) {

222

223int accum_m = mma_m * Operator::Policy::Operator::Shape::kM + row * kRowsPerTile;

224int accum_n = mma_n * Operator::Policy::Operator::Shape::kN + col;

225int idx = mma_accum_start + row * kElementsPerAccess + col;

226

227MatrixCoord accum_coord = MatrixCoord{accum_m, accum_n};

228

229MatrixCoord thread_coord = thread_origin + accum_coord;

230

231if (thread_coord < MatrixCoord{problem_size.m(), problem_size.n()}) {

232

233typename ConvertOp::result_type converted_accum = convert_op(accumulators[idx]);

234

235typename OutputOp::result_type output = output_op(converted_accum, source_ref_.at(accum_coord));

236

237 destination_ref_.at(accum_coord) = output;

238 }

239 }

240 }

241 }

242 }

243 }

244 };

245

247

248 } // namespace threadblock

249 } // namespace epilogue

250 } // namespace cutlass

251

cutlass::epilogue::threadblock::DirectEpilogueTensorOp

Epilogue operator.

Definition: direct_epilogue_tensor_op.h:55

cutlass::gemm::GemmShape::kM

static int const kM

Definition: include/cutlass/gemm/gemm.h:58

cutlass::MatrixShape

Describes the size of a matrix tile.

Definition: matrix_shape.h:42

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Params::Params

CUTLASS_HOST_DEVICE Params(TensorRef destination_ref_, TensorRef source_ref_, typename OutputOp::Params output_op_)

Constructs a Params object.

Definition: direct_epilogue_tensor_op.h:125

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Params

Parameters structure for host-constructible state.

Definition: direct_epilogue_tensor_op.h:92

cutlass

Definition: aligned_buffer.h:35

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::operator()

CUTLASS_DEVICE void operator()(gemm::GemmCoord problem_size, gemm::GemmCoord tb_tile_coord, FragmentC const &accumulators)

Streams the result to global memory.

Definition: direct_epilogue_tensor_op.h:189

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Params::destination_ref

TensorRef destination_ref

Definition: direct_epilogue_tensor_op.h:98

cutlass::gemm::GemmCoord

Definition: include/cutlass/gemm/gemm.h:94

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::gemm::GemmCoord::n

CUTLASS_HOST_DEVICE Index const & n() const

Returns the GEMM N coordinate.

Definition: include/cutlass/gemm/gemm.h:137

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::TensorRef

TensorRef< Element, Layout::kRank, Layout > TensorRef

Reference to source and destination tensors.

Definition: direct_epilogue_tensor_op.h:87

cutlass::TensorRef::add_coord_offset

CUTLASS_HOST_DEVICE TensorRef & add_coord_offset(TensorCoord const &coord)

Adds an offset to each pointer.

Definition: tensor_ref.h:326

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

CUTLASS_PRAGMA_UNROLL

#define CUTLASS_PRAGMA_UNROLL

Definition: cutlass.h:110

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Params::source_ref

TensorRef source_ref

Definition: direct_epilogue_tensor_op.h:99

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::DirectEpilogueTensorOp

CUTLASS_DEVICE DirectEpilogueTensorOp(Params const &params, SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)

Constructor.

Definition: direct_epilogue_tensor_op.h:155

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Params::output_op

OutputOp::Params output_op

Definition: direct_epilogue_tensor_op.h:101

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Params::convert_op

ConvertOp::Params convert_op

Definition: direct_epilogue_tensor_op.h:102

cutlass::TensorRef< Element, Layout::kRank, Layout >

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Params::Params

CUTLASS_HOST_DEVICE Params(TensorRef destination_ref_, TensorRef source_ref_, typename OutputOp::Params output_op_, typename ConvertOp::Params convert_op_)

Constructs a Params object.

Definition: direct_epilogue_tensor_op.h:110

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::SharedStorage

Shared storage allocation needed by the epilogue.

Definition: direct_epilogue_tensor_op.h:139

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::FragmentC

typename Operator::FragmentC FragmentC

Accumulator tile is really the warp-scoped tile.

Definition: direct_epilogue_tensor_op.h:72

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Operator

Operator_ Operator

Definition: direct_epilogue_tensor_op.h:59

cutlass::TensorRef::at

CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const

Returns a reference to the element at a given Coord.

Definition: tensor_ref.h:307

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::OutputOp

OutputOp_ OutputOp

Function operator computing final output.

Definition: direct_epilogue_tensor_op.h:81

cutlass::gemm::GemmCoord::m

CUTLASS_HOST_DEVICE Index const & m() const

Returns the GEMM M coordinate.

Definition: include/cutlass/gemm/gemm.h:129

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::ConvertOp

ConvertOp_ ConvertOp

Conversion operator applied to each accumulator element before the output operator.

Definition: direct_epilogue_tensor_op.h:84

cutlass.h

Basic include for CUTLASS.

cutlass::MatrixCoord

Definition: matrix_coord.h:39

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Element

Element_ Element

Data type of output tensor.

Definition: direct_epilogue_tensor_op.h:75

cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Shape

Shape_ Shape

Definition: direct_epilogue_tensor_op.h:58

cutlass::gemm::GemmShape::kN

static int const kN

Definition: include/cutlass/gemm/gemm.h:59


Generated by 1.8.11