Back to Cutlass

CUTLASS: epilogue.h Source File

docs/epilogue_8h_source.html

4.4.227.3 KB
Original Source

CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers

epilogue.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

33 #pragma once

34

35 #include <assert.h>

36

37 #include "cutlass/cutlass.h"

38 #include "cutlass/numeric_types.h"

39 #include "cutlass/array.h"

40 #include "cutlass/layout/vector.h"

41 #include "cutlass/layout/tensor.h"

42 #include "cutlass/tensor_coord.h"

43 #include "cutlass/aligned_buffer.h"

44 #include "cutlass/functional.h"

45

46 #include "cutlass/gemm/gemm.h"

47

48 #include "[cutlass/transform/pitch_linear_thread_map.h](pitch linear thread__map_8h.html)"

49 #include "[cutlass/transform/threadblock/regular_tile_iterator.h](regular tile iterator_8h.html)"

50

51 #include "cutlass/epilogue/threadblock/epilogue_base.h"

52 #include "[cutlass/epilogue/threadblock/predicated_tile_iterator.h](epilogue_2threadblock_2predicated tile iterator_8h.html)"

53

55

56 namespace cutlass {

57 namespace epilogue {

58 namespace threadblock {

59

61

63 template <

64typename Shape_,

65typename WarpMmaOperator_,

66int PartitionsK,

67typename OutputTileIterator_,

68typename AccumulatorFragmentIterator_,

69typename WarpTileIterator_,

70typename SharedLoadIterator_,

71typename OutputOp_,

72typename Padding_

73 >

74 class Epilogue :

75public EpilogueBase<

76 Shape_,

77 WarpMmaOperator_,

78 PartitionsK,

79 AccumulatorFragmentIterator_,

80 WarpTileIterator_,

81 Padding_> {

82

83 public:

84

85using Base = EpilogueBase<

86 Shape_,

87 WarpMmaOperator_,

88 PartitionsK,

89 AccumulatorFragmentIterator_,

90 WarpTileIterator_,

91 Padding_>;

92

93using Shape = Shape_;

94using WarpMmaOperator = WarpMmaOperator_;

95static int const kPartitionsK = PartitionsK;

96using OutputTileIterator = OutputTileIterator_;

97using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;

98using WarpTileIterator = WarpTileIterator_;

99using SharedLoadIterator = SharedLoadIterator_;

100using OutputOp = OutputOp_;

101using Padding = Padding_;

102

104using Layout = layout::RowMajor;

105using LongIndex = typename Layout::LongIndex;

106

108using AccumulatorTile = typename Base::AccumulatorTile;

109

111using ElementAccumulator = typename WarpTileIterator::Element;

112

113

115using ElementOutput = typename OutputTileIterator::Element;

116

118static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

119

121using TensorRef = typename OutputTileIterator::TensorRef;

122

124using SyncTensorRef = typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;

125

127using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;

128

130using OutputAccessType = Array<

131typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

132

134using AccumulatorAccessType = Array<typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess>;

135

137using WarpCount = typename Base::WarpCount;

138

139 public:

140

141

142static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,

143"Mismatch between shared load iterator and output tile iterator.");

144

145static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero.");

146

147static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),

148"Divisibility");

149

150 private:

151

153SharedLoadIterator shared_load_iterator_;

154

155 public:

156

158 CUTLASS_DEVICE

159Epilogue(

160typename Base::SharedStorage &shared_storage,

161int thread_idx,

162int warp_idx,

163int lane_idx

164 ):

165Base(shared_storage, thread_idx, warp_idx, lane_idx),

166 shared_load_iterator_(shared_storage.reference(), thread_idx) { }

167

169 CUTLASS_DEVICE

170void operator()(

171OutputOp const &output_op,

172OutputTileIterator destination_iterator,

173AccumulatorTile const &accumulators,

174OutputTileIterator source_iterator) {

175

176

177typename OutputTileIterator::Fragment source_fragment;

178

179if (!output_op.is_source_needed()) {

180 source_iterator.clear_mask();

181 }

182

183 source_fragment.clear();

184

185//

186// Iterator over warp-level accumulator fragment

187//

188

189AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

190

191//

192// Iterate over accumulator tile

193//

194

195CUTLASS_PRAGMA_UNROLL

196for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {

197

198//

199// Load the source

200//

201

202 source_iterator.load(source_fragment);

203 ++source_iterator;

204

205//

206// Convert and store fragment

207//

208

209 __syncthreads();

210

211typename AccumulatorFragmentIterator::Fragment accum_fragment;

212

213 accum_fragment_iterator.load(accum_fragment);

214 ++accum_fragment_iterator;

215

216 this->warp_tile_iterator_.store(accum_fragment);

217

218 __syncthreads();

219

220//

221// Load fragments from shared memory

222//

223

224typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK];

225

226 shared_load_iterator_.load(aligned_accum_fragment[0]);

227

228// If the number of k-slices is > 1 - perform a reduction amongst the k-slices

229if (kPartitionsK > 1)

230 {

231plus <typename SharedLoadIterator::Fragment> add_fragments;

232const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;

233

234CUTLASS_PRAGMA_UNROLL

235for ( int i = 1; i < kPartitionsK; ++i) {

236 shared_load_iterator_.add_tile_offset({tile_row_offset , 0});

237 shared_load_iterator_.load(aligned_accum_fragment[i]);

238 aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);

239 }

240

241 shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});

242 }

243

244//

245// Compute the output result

246//

247

248typename OutputTileIterator::Fragment output_fragment;

249

250 apply_output_operator_(output_fragment, output_op, aligned_accum_fragment[0], source_fragment);

251

252

253//

254// Store the final result

255//

256

257 destination_iterator.store(output_fragment);

258 ++destination_iterator;

259

260 }

261 }

262

263 private:

264

266 CUTLASS_DEVICE

267void apply_output_operator_(

268typename OutputTileIterator::Fragment &output_fragment,

269OutputOp const &output_op,

270typename SharedLoadIterator::Fragment const &aligned_accum_fragment,

271typename OutputTileIterator::Fragment const &source_fragment) {

272

273OutputAccessType *output_frag_ptr =

274reinterpret_cast<OutputAccessType *>(&output_fragment);

275

276AccumulatorAccessType const *compute_frag_ptr =

277reinterpret_cast<AccumulatorAccessType const *>(&aligned_accum_fragment);

278

279OutputAccessType const *source_frag_ptr =

280reinterpret_cast<OutputAccessType const *>(&source_fragment);

281

282int const kOutputOpIterations =

283 OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;

284

285CUTLASS_PRAGMA_UNROLL

286for (int i = 0; i < kOutputOpIterations; ++i) {

287

288// Call the output operator

289 output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);

290 }

291 }

292 };

293

295

296 } // namespace threadblock

297 } // namespace epilogue

298 } // namespace cutlass

299

cutlass::layout::RowMajor::LongIndex

int64_t LongIndex

Long index type used for offsets.

Definition: layout/matrix.h:62

cutlass::epilogue::threadblock::Epilogue::LongIndex

typename Layout::LongIndex LongIndex

Definition: epilogue.h:105

cutlass::epilogue::threadblock::Epilogue::WarpCount

typename Base::WarpCount WarpCount

Number of warps.

Definition: epilogue.h:137

cutlass::epilogue::threadblock::SharedLoadIterator::Fragment

Array< Element, ThreadMap::Iterations::kColumn *ThreadMap::Iterations::kRow *ThreadMap::Iterations::kGroup *ThreadMap::Iterations::kCluster *ThreadMap::kElementsPerAccess > Fragment

Fragment object.

Definition: shared_load_iterator.h:91

cutlass

Definition: aligned_buffer.h:35

cutlass::epilogue::threadblock::EpilogueBase::warp_tile_iterator_

WarpTileIterator warp_tile_iterator_

Stores a warp's fragment of accumulators to SMEM.

Definition: epilogue_base.h:176

[pitch_linear_thread_map.h](pitch__linear__thread__map_8h.html)

Templates implementing how threads are mapped to a given tile.

cutlass::epilogue::threadblock::EpilogueBase::SharedStorage

Shared storage allocation needed by the epilogue.

Definition: epilogue_base.h:97

cutlass::epilogue::threadblock::Epilogue::operator()

CUTLASS_DEVICE void operator()(OutputOp const &output_op, OutputTileIterator destination_iterator, AccumulatorTile const &accumulators, OutputTileIterator source_iterator)

Streams the result to global memory.

Definition: epilogue.h:170

cutlass::epilogue::threadblock::Epilogue::OutputTileIterator

OutputTileIterator_ OutputTileIterator

Definition: epilogue.h:96

[predicated_tile_iterator.h](epilogue_2threadblock_2predicated__tile__iterator_8h.html)

Epilogue for threadblock scoped GEMMs using Tensor Ops.

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::epilogue::threadblock::Epilogue::Epilogue

CUTLASS_DEVICE Epilogue(typename Base::SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)

Constructor.

Definition: epilogue.h:159

cutlass::epilogue::threadblock::Epilogue::Shape

Shape_ Shape

Definition: epilogue.h:93

cutlass::epilogue::threadblock::Epilogue::TensorRef

typename OutputTileIterator::TensorRef TensorRef

Tensor reference to destination tensor.

Definition: epilogue.h:121

cutlass::epilogue::threadblock::EpilogueBase::WarpCount

gemm::GemmShape< Shape::kM/WarpMmaOperator::Shape::kM, Shape::kN/WarpMmaOperator::Shape::kN, kPartitionsK > WarpCount

Number of warps.

Definition: epilogue_base.h:92

cutlass::plus

Definition: functional.h:46

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

CUTLASS_PRAGMA_UNROLL

#define CUTLASS_PRAGMA_UNROLL

Definition: cutlass.h:110

tensor.h

Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...

cutlass::epilogue::threadblock::Epilogue::kPartitionsK

static int const kPartitionsK

Definition: epilogue.h:95

cutlass::epilogue::threadblock::Epilogue::OutputOp

OutputOp_ OutputOp

Definition: epilogue.h:100

cutlass::TensorRef

Definition: tensor_ref.h:146

cutlass::epilogue::threadblock::Epilogue::Padding

Padding_ Padding

Definition: epilogue.h:101

tensor_coord.h

Defines a canonical coordinate for rank=4 tensors offering named indices.

aligned_buffer.h

AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...

cutlass::epilogue::threadblock::Epilogue::AccumulatorFragmentIterator

AccumulatorFragmentIterator_ AccumulatorFragmentIterator

Definition: epilogue.h:97

numeric_types.h

Top-level include for all CUTLASS numeric types.

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::epilogue::threadblock::Epilogue::ConstTensorRef

typename OutputTileIterator::ConstTensorRef ConstTensorRef

Const tensor reference to source tensor.

Definition: epilogue.h:127

cutlass::epilogue::threadblock::Epilogue::WarpTileIterator

WarpTileIterator_ WarpTileIterator

Definition: epilogue.h:98

cutlass::epilogue::threadblock::Epilogue::SharedLoadIterator

SharedLoadIterator_ SharedLoadIterator

Definition: epilogue.h:99

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

cutlass::epilogue::threadblock::Epilogue

Epilogue operator without splitk.

Definition: epilogue.h:74

cutlass::epilogue::threadblock::Epilogue::ElementAccumulator

typename WarpTileIterator::Element ElementAccumulator

Accumulator element.

Definition: epilogue.h:111

vector.h

Defines layout functions used for rank=1 vectors.

[regular_tile_iterator.h](regular__tile__iterator_8h.html)

Templates implementing storing of tiles from pitch-linear rank=2 tensors.

epilogue_base.h

Epilogue for threadblock scoped GEMMs using Tensor Ops.

cutlass::epilogue::threadblock::EpilogueBase

Base class for epilogues defining warp-level.

Definition: epilogue_base.h:67

cutlass::epilogue::threadblock::Epilogue::WarpMmaOperator

WarpMmaOperator_ WarpMmaOperator

Definition: epilogue.h:94

cutlass::epilogue::threadblock::Epilogue::AccumulatorTile

typename Base::AccumulatorTile AccumulatorTile

The complete warp-level accumulator tile.

Definition: epilogue.h:108

cutlass::epilogue::threadblock::Epilogue::OutputAccessType

Array< typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess > OutputAccessType

Array type used to output.

Definition: epilogue.h:131

cutlass::epilogue::threadblock::Epilogue::kElementsPerAccess

static int const kElementsPerAccess

Output access size.

Definition: epilogue.h:118

cutlass::epilogue::threadblock::EpilogueBase::AccumulatorTile

typename AccumulatorFragmentIterator::AccumulatorTile AccumulatorTile

The complete warp-level accumulator tile.

Definition: epilogue_base.h:81

cutlass::epilogue::threadblock::Epilogue::ElementOutput

typename OutputTileIterator::Element ElementOutput

Output element.

Definition: epilogue.h:115

cutlass.h

Basic include for CUTLASS.

cutlass::epilogue::threadblock::Epilogue::SyncTensorRef

typename cutlass::TensorRef< int, cutlass::layout::PackedVectorLayout > SyncTensorRef

Tensor reference to sync tensor.

Definition: epilogue.h:124

functional.h

Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...

cutlass::epilogue::threadblock::Epilogue::AccumulatorAccessType

Array< typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess > AccumulatorAccessType

Array type used by output functor.

Definition: epilogue.h:134


Generated by 1.8.11