Back to Cutlass

CUTLASS: interleaved_epilogue.h Source File

docs/interleaved__epilogue_8h_source.html

4.4.222.1 KB
Original Source

CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers

interleaved_epilogue.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

33 #pragma once

34

35 #include <assert.h>

36

37 #include "cutlass/cutlass.h"

38 #include "cutlass/numeric_types.h"

39 #include "cutlass/array.h"

40 #include "cutlass/layout/vector.h"

41 #include "cutlass/layout/tensor.h"

42 #include "cutlass/tensor_coord.h"

43 #include "cutlass/aligned_buffer.h"

44

45 #include "cutlass/gemm/gemm.h"

46

47 #include "[cutlass/transform/pitch_linear_thread_map.h](pitch linear thread__map_8h.html)"

48 #include "[cutlass/transform/threadblock/regular_tile_iterator.h](regular tile iterator_8h.html)"

49

50 #include "cutlass/epilogue/threadblock/epilogue_base.h"

51 #include "[cutlass/epilogue/threadblock/predicated_tile_iterator.h](epilogue_2threadblock_2predicated tile iterator_8h.html)"

52

54

55 namespace cutlass {

56 namespace epilogue {

57 namespace threadblock {

58

60

62 template <

64typename Shape_,

66typename WarpMmaOperator_,

68int PartitionsK,

70typename OutputTileIterator_,

72typename AccumulatorFragmentIterator_,

74typename OutputOp_,

76int InterleavedK,

78bool IsBetaZero = false>

79 class InterleavedEpilogue {

80public:

81using Shape = Shape_;

82using WarpMmaOperator = WarpMmaOperator_;

83static int const kPartitionsK = PartitionsK;

84using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;

85using OutputTileIterator = OutputTileIterator_;

86using OutputOp = OutputOp_;

87

89using Layout = layout::ColumnMajorInterleaved<InterleavedK>;

90

92using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;

93

95using ElementAccumulator = typename AccumulatorTile::Element;

96

98using ElementOutput = typename OutputTileIterator::Element;

99

101static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

102

104using TensorRef = typename OutputTileIterator::TensorRef;

105

107using SyncTensorRef =

108typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;

109

111using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;

112

114using OutputAccessType = Array<typename OutputTileIterator::Element,

115 OutputTileIterator::kElementsPerAccess>;

116

118using AccumulatorAccessType =

119 Array<ElementAccumulator, OutputTileIterator::kElementsPerAccess>;

120

122using WarpCount =

123gemm::GemmShape<Shape::kM / WarpMmaOperator::Shape::kM,

124 Shape::kN / WarpMmaOperator::Shape::kN, kPartitionsK>;

125

126public:

127static_assert(OutputTileIterator::kElementsPerAccess,

128"This must not be zero.");

129

130static_assert(!(OutputTileIterator::Fragment::kElements %

131 OutputTileIterator::kElementsPerAccess),

132"Divisibility");

133

135struct SharedStorage {};

136

137

138public:

140 CUTLASS_DEVICE

141InterleavedEpilogue(

142SharedStorage &shared_storage,

143int thread_idx,

144int warp_idx,

145int lane_idx

146 ) {}

147

149 CUTLASS_DEVICE

150void operator()(

151OutputOp const &output_op,

152OutputTileIterator destination_iterator,

153AccumulatorTile const &accumulators,

154OutputTileIterator source_iterator) {

155

156//

157// Predicated tile iterators constructed from members

158//

159

160if (IsBetaZero && output_op.is_source_needed())

161 assert(0);

162

163typename OutputTileIterator::Fragment source_fragment;

164

165if (!IsBetaZero) {

166if (!output_op.is_source_needed()) {

167 source_iterator.clear_mask();

168 }

169 }

170

171 source_fragment.clear();

172

173//

174// Iterator over warp-level accumulator fragment

175//

176

177AccumulatorFragmentIterator accum_fragment_iterator(accumulators);

178

179//

180// Iterate over accumulator tile

181//

182

183CUTLASS_PRAGMA_UNROLL

184for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {

185//

186// Load the source

187//

188

189if (!IsBetaZero) {

190 source_iterator.set_iteration_index(iter);

191 source_iterator.load(source_fragment);

192 ++source_iterator;

193 }

194

195//

196// Convert fragment

197//

198

199typename AccumulatorFragmentIterator::Fragment accum_fragment;

200

201 accum_fragment_iterator.load(accum_fragment);

202 ++accum_fragment_iterator;

203

204//

205// Compute the output result

206//

207

208typename OutputTileIterator::Fragment output_fragment;

209 apply_output_operator_(output_op, output_fragment, accum_fragment, source_fragment);

210

211//

212// Store the final result

213//

214

215 destination_iterator.set_iteration_index(iter);

216 destination_iterator.store(output_fragment);

217 ++destination_iterator;

218 }

219 }

220

221private:

223 CUTLASS_DEVICE

224void apply_output_operator_(

225OutputOp const &output_op,

226typename OutputTileIterator::Fragment &output_fragment,

227typename AccumulatorFragmentIterator::Fragment const

228 &aligned_accum_fragment,

229typename OutputTileIterator::Fragment const &source_fragment) {

230OutputAccessType *output_frag_ptr =

231reinterpret_cast<OutputAccessType *>(&output_fragment);

232

233AccumulatorAccessType const *compute_frag_ptr =

234reinterpret_cast<AccumulatorAccessType const *>(

235 &aligned_accum_fragment);

236

237OutputAccessType const *source_frag_ptr =

238reinterpret_cast<OutputAccessType const *>(&source_fragment);

239

240int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /

241 OutputTileIterator::kElementsPerAccess;

242

243CUTLASS_PRAGMA_UNROLL

244for (int i = 0; i < kOutputOpIterations; ++i) {

245// Call the output operator

246 output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);

247 }

248 }

249 };

250

252

253 } // namespace threadblock

254 } // namespace epilogue

255 } // namespace cutlass

256

cutlass::epilogue::threadblock::InterleavedEpilogue::Shape

Shape_ Shape

Definition: interleaved_epilogue.h:81

cutlass

Definition: aligned_buffer.h:35

cutlass::epilogue::threadblock::InterleavedEpilogue::ElementAccumulator

typename AccumulatorTile::Element ElementAccumulator

Accumulator element.

Definition: interleaved_epilogue.h:95

[pitch_linear_thread_map.h](pitch__linear__thread__map_8h.html)

Templates implementing how threads are mapped to a given tile.

cutlass::epilogue::threadblock::InterleavedEpilogue::InterleavedEpilogue

CUTLASS_DEVICE InterleavedEpilogue(SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)

Constructor.

Definition: interleaved_epilogue.h:141

cutlass::epilogue::threadblock::InterleavedEpilogue::AccumulatorTile

typename AccumulatorFragmentIterator::AccumulatorTile AccumulatorTile

The complete warp-level accumulator tile.

Definition: interleaved_epilogue.h:92

[predicated_tile_iterator.h](epilogue_2threadblock_2predicated__tile__iterator_8h.html)

Epilogue for threadblock scoped GEMMs using Tensor Ops.

cutlass::epilogue::threadblock::InterleavedEpilogue::AccumulatorAccessType

Array< ElementAccumulator, OutputTileIterator::kElementsPerAccess > AccumulatorAccessType

Array type used by output functor.

Definition: interleaved_epilogue.h:119

cutlass::epilogue::threadblock::InterleavedEpilogue

Epilogue operator without split-K.

Definition: interleaved_epilogue.h:79

cutlass::epilogue::threadblock::InterleavedEpilogue::OutputAccessType

Array< typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess > OutputAccessType

Array type used to output.

Definition: interleaved_epilogue.h:115

cutlass::epilogue::threadblock::InterleavedEpilogue::OutputOp

OutputOp_ OutputOp

Definition: interleaved_epilogue.h:86

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::epilogue::threadblock::InterleavedEpilogue::ConstTensorRef

typename OutputTileIterator::ConstTensorRef ConstTensorRef

Const tensor reference to source tensor.

Definition: interleaved_epilogue.h:111

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

CUTLASS_PRAGMA_UNROLL

#define CUTLASS_PRAGMA_UNROLL

Definition: cutlass.h:110

tensor.h

Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...

cutlass::epilogue::threadblock::InterleavedEpilogue::TensorRef

typename OutputTileIterator::TensorRef TensorRef

Tensor reference to destination tensor.

Definition: interleaved_epilogue.h:104

cutlass::epilogue::threadblock::InterleavedEpilogue::SharedStorage

Shared storage allocation needed by the epilogue.

Definition: interleaved_epilogue.h:135

cutlass::epilogue::threadblock::InterleavedEpilogue::WarpMmaOperator

WarpMmaOperator_ WarpMmaOperator

Definition: interleaved_epilogue.h:82

cutlass::epilogue::threadblock::InterleavedEpilogue::SyncTensorRef

typename cutlass::TensorRef< int, cutlass::layout::PackedVectorLayout > SyncTensorRef

Tensor reference to sync tensor.

Definition: interleaved_epilogue.h:108

cutlass::TensorRef

Definition: tensor_ref.h:146

tensor_coord.h

Defines a canonical coordinate for rank=4 tensors offering named indices.

aligned_buffer.h

AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...

cutlass::epilogue::threadblock::InterleavedEpilogue::OutputTileIterator

OutputTileIterator_ OutputTileIterator

Definition: interleaved_epilogue.h:85

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::epilogue::threadblock::InterleavedEpilogue::operator()

CUTLASS_DEVICE void operator()(OutputOp const &output_op, OutputTileIterator destination_iterator, AccumulatorTile const &accumulators, OutputTileIterator source_iterator)

Streams the result to global memory.

Definition: interleaved_epilogue.h:150

cutlass::epilogue::threadblock::InterleavedEpilogue::kElementsPerAccess

static int const kElementsPerAccess

Output access size.

Definition: interleaved_epilogue.h:101

vector.h

Defines layout functions used for rank=1 vectors.

[regular_tile_iterator.h](regular__tile__iterator_8h.html)

Templates implementing storing of tiles from pitch-linear rank=2 tensors.

epilogue_base.h

Epilogue for threadblock scoped GEMMs using Tensor Ops.

cutlass::layout::ColumnMajorInterleaved

Definition: layout/matrix.h:343

cutlass::epilogue::threadblock::InterleavedEpilogue::kPartitionsK

static int const kPartitionsK

Definition: interleaved_epilogue.h:83

cutlass::epilogue::threadblock::InterleavedEpilogue::AccumulatorFragmentIterator

AccumulatorFragmentIterator_ AccumulatorFragmentIterator

Definition: interleaved_epilogue.h:84

cutlass::epilogue::threadblock::InterleavedEpilogue::ElementOutput

typename OutputTileIterator::Element ElementOutput

Output element.

Definition: interleaved_epilogue.h:98

cutlass.h

Basic include for CUTLASS.


Generated by 1.8.11