Back to Cutlass

CUTLASS: default_epilogue_wmma_tensor_op.h Source File

docs/default__epilogue__wmma__tensor__op_8h_source.html

4.4.219.8 KB
Original Source

CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers

default_epilogue_wmma_tensor_op.h

[Go to the documentation of this file.](default__epilogue__wmma__tensor__op_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

33 #pragma once

34

35 #include "cutlass/cutlass.h"

36 #include "cutlass/numeric_types.h"

37 #include "cutlass/array.h"

38

39 #include "cutlass/gemm/gemm.h"

40

41 #include "cutlass/epilogue/thread/linear_combination.h"

42 #include "cutlass/epilogue/thread/conversion_op.h"

43 #include "cutlass/epilogue/thread/reduction_op.h"

44

45 #include "[cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h](regular tile iterator pitch linear_8h.html)"

46

47 #include "[cutlass/epilogue/warp/fragment_iterator_wmma_tensor_op.h](fragment iterator wmma tensor op_8h.html)"

48 #include "[cutlass/epilogue/warp/tile_iterator_wmma_tensor_op.h](tile iterator wmma tensor op_8h.html)"

49 #include "[cutlass/epilogue/threadblock/default_thread_map_wmma_tensor_op.h](default thread map wmma tensor__op_8h.html)"

50 #include "[cutlass/epilogue/threadblock/predicated_tile_iterator.h](epilogue_2threadblock_2predicated tile iterator_8h.html)"

51 #include "[cutlass/epilogue/threadblock/shared_load_iterator.h](shared load iterator_8h.html)"

52

53 #include "cutlass/epilogue/threadblock/epilogue.h"

54

56

57 namespace cutlass {

58 namespace epilogue {

59 namespace threadblock {

60

62

64 template <

65typename Shape_,

66typename WarpMmaTensorOp_,

67int PartitionsK,

68typename OutputOp_,

69int ElementsPerAccess

70 >

71 struct DefaultEpilogueWmmaTensorOp {

72

73using Shape = Shape_;

74using WarpMmaTensorOp = WarpMmaTensorOp_;

75static int const kPartitionsK = PartitionsK;

76using OutputOp = OutputOp_;

77static int const kElementsPerAccess = ElementsPerAccess;

78

79using ElementOutput = typename OutputOp::ElementOutput;

80using LayoutC = typename WarpMmaTensorOp::LayoutC;

81using ElementAccumulator = typename WarpMmaTensorOp::ElementC;

82

83//

84// Thread map

85//

86

87using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapWmmaTensorOp<

88Shape,

89typename WarpMmaTensorOp::Shape,

90typename WarpMmaTensorOp::Policy::Operator::Shape,

91kPartitionsK,

92ElementOutput,

93 kElementsPerAccess

94 >::Type;

95

96using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator<

97OutputTileThreadMap,

98 ElementOutput

99 >;

100

101using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorWmmaTensorOp<

102typename WarpMmaTensorOp::Shape,

103typename WarpMmaTensorOp::Policy::Operator::Shape,

104typename WarpMmaTensorOp::Policy::Operator::ElementC,

105typename WarpMmaTensorOp::Policy::Operator::FragmentC,

106LayoutC

107 >;

108

109using WarpTileIterator = cutlass::epilogue::warp::TileIteratorWmmaTensorOp<

110typename WarpMmaTensorOp::Shape,

111typename WarpMmaTensorOp::Policy::Operator::Shape,

112typename WarpMmaTensorOp::Policy::Operator::FragmentC,

113LayoutC

114 >;

115

116using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIterator<

117typename OutputTileThreadMap::CompactedThreadMap,

118ElementAccumulator

119 >;

120

122using Padding = typename WarpTileIterator::Padding;

123

124//

125// Define the epilogue

126//

127using Epilogue = cutlass::epilogue::threadblock::Epilogue<

128Shape,

129WarpMmaTensorOp,

130kPartitionsK,

131OutputTileIterator,

132AccumulatorFragmentIterator,

133WarpTileIterator,

134SharedLoadIterator,

135OutputOp,

136Padding

137 >;

138 };

139

140

142

143 } // namespace threadblock

144 } // namespace epilogue

145 } // namespace cutlass

146

[regular_tile_iterator_pitch_linear.h](regular__tile__iterator__pitch__linear_8h.html)

Templates implementing loading of tiles from pitch-linear rank=2 tensors.

cutlass

Definition: aligned_buffer.h:35

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp::Shape

Shape_ Shape

Definition: default_epilogue_wmma_tensor_op.h:73

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp::OutputTileIterator

cutlass::epilogue::threadblock::PredicatedTileIterator< OutputTileThreadMap, ElementOutput > OutputTileIterator

Definition: default_epilogue_wmma_tensor_op.h:99

[predicated_tile_iterator.h](epilogue_2threadblock_2predicated__tile__iterator_8h.html)

Epilogue for threadblock scoped GEMMs using Tensor Ops.

gemm.h

Defines common types used for all GEMM-like operators.

conversion_op.h

Functor performing conversion operations used by epilogues.

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp::LayoutC

typename WarpMmaTensorOp::LayoutC LayoutC

Definition: default_epilogue_wmma_tensor_op.h:80

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp::AccumulatorFragmentIterator

cutlass::epilogue::warp::FragmentIteratorWmmaTensorOp< typename WarpMmaTensorOp::Shape, typename WarpMmaTensorOp::Policy::Operator::Shape, typename WarpMmaTensorOp::Policy::Operator::ElementC, typename WarpMmaTensorOp::Policy::Operator::FragmentC, LayoutC > AccumulatorFragmentIterator

Definition: default_epilogue_wmma_tensor_op.h:107

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp::OutputOp

OutputOp_ OutputOp

Definition: default_epilogue_wmma_tensor_op.h:76

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp::ElementOutput

typename OutputOp::ElementOutput ElementOutput

Definition: default_epilogue_wmma_tensor_op.h:79

linear_combination.h

Functor performing linear combination operations used by epilogues.

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp::kPartitionsK

static int const kPartitionsK

Definition: default_epilogue_wmma_tensor_op.h:75

[tile_iterator_wmma_tensor_op.h](tile__iterator__wmma__tensor__op_8h.html)

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp::Padding

typename WarpTileIterator::Padding Padding

Hard-coded padding elements added.

Definition: default_epilogue_wmma_tensor_op.h:122

[shared_load_iterator.h](shared__load__iterator_8h.html)

Epilogue for threadblock scoped GEMMs using Tensor Ops.

cutlass::epilogue::warp::FragmentIteratorWmmaTensorOp

Definition: fragment_iterator_wmma_tensor_op.h:63

cutlass::epilogue::threadblock::DefaultThreadMapWmmaTensorOp

Defines the optimal thread map for Wmma TensorOp accumulator layouts.

Definition: default_thread_map_wmma_tensor_op.h:53

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp::OutputTileThreadMap

typename cutlass::epilogue::threadblock::DefaultThreadMapWmmaTensorOp< Shape, typename WarpMmaTensorOp::Shape, typename WarpMmaTensorOp::Policy::Operator::Shape, kPartitionsK, ElementOutput, kElementsPerAccess >::Type OutputTileThreadMap

Definition: default_epilogue_wmma_tensor_op.h:94

cutlass::epilogue::threadblock::Epilogue

Epilogue operator without splitk.

Definition: epilogue.h:74

[fragment_iterator_wmma_tensor_op.h](fragment__iterator__wmma__tensor__op_8h.html)

This defines a "fragment" iterator for visiting the fragments of an accumulator tile that participate...

epilogue.h

Epilogue for threadblock scoped GEMMs using Tensor Ops.

cutlass::epilogue::threadblock::PredicatedTileIterator

Definition: epilogue/threadblock/predicated_tile_iterator.h:65

[default_thread_map_wmma_tensor_op.h](default__thread__map__wmma__tensor__op_8h.html)

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp

Defines sensible defaults for epilogues for WMMA TensorOps.

Definition: default_epilogue_wmma_tensor_op.h:71

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp::ElementAccumulator

typename WarpMmaTensorOp::ElementC ElementAccumulator

Definition: default_epilogue_wmma_tensor_op.h:81

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp::kElementsPerAccess

static int const kElementsPerAccess

Definition: default_epilogue_wmma_tensor_op.h:77

cutlass::epilogue::warp::TileIteratorWmmaTensorOp

Template for reading and writing tiles of accumulators to shared memory.

Definition: tile_iterator_wmma_tensor_op.h:56

cutlass::epilogue::threadblock::SharedLoadIterator

Definition: shared_load_iterator.h:61

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp::SharedLoadIterator

cutlass::epilogue::threadblock::SharedLoadIterator< typename OutputTileThreadMap::CompactedThreadMap, ElementAccumulator > SharedLoadIterator

Definition: default_epilogue_wmma_tensor_op.h:119

reduction_op.h

Functor performing reduction operations used by epilogues.

cutlass.h

Basic include for CUTLASS.

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp::WarpMmaTensorOp

WarpMmaTensorOp_ WarpMmaTensorOp

Definition: default_epilogue_wmma_tensor_op.h:74

cutlass::epilogue::threadblock::DefaultEpilogueWmmaTensorOp::WarpTileIterator

cutlass::epilogue::warp::TileIteratorWmmaTensorOp< typename WarpMmaTensorOp::Shape, typename WarpMmaTensorOp::Policy::Operator::Shape, typename WarpMmaTensorOp::Policy::Operator::FragmentC, LayoutC > WarpTileIterator

Definition: default_epilogue_wmma_tensor_op.h:114


Generated by Doxygen 1.8.11