Back to Cutlass

CUTLASS: mma_base.h Source File

docs/mma__base_8h_source.html

4.4.217.9 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

mma_base.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include "cutlass/aligned_buffer.h"

32 #include "cutlass/arch/memory.h"

33 #include "cutlass/array.h"

34 #include "cutlass/cutlass.h"

35 #include "cutlass/gemm/gemm.h"

36 #include "cutlass/matrix_shape.h"

37 #include "cutlass/numeric_types.h"

39

40 namespace cutlass {

41 namespace gemm {

42 namespace threadblock {

43

45

/// Policy object describing MmaTensorOp.
///
/// Pure compile-time bundle: it re-exports its template arguments under stable names so that
/// consumers can query them through a single policy type.
template <
    /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt)
    typename Operator_,
    /// Padding used for A operand in shared memory
    typename SmemPaddingA_,
    /// Padding used for B operand in shared memory
    typename SmemPaddingB_,
    /// Number of partitions of K dimension
    int PartitionsK = 1>
struct MmaPolicy {
  /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt)
  using Operator = Operator_;

  /// Padding used for A operand in shared memory
  using SmemPaddingA = SmemPaddingA_;

  /// Padding used for B operand in shared memory
  using SmemPaddingB = SmemPaddingB_;

  /// Number of partitions of K dimension
  static int const kPartitionsK = PartitionsK;
};

69

71

74 template <

76typename Shape_,

78typename Policy_,

80int Stages,

82typename Enable = bool>

83 class MmaBase {

84public:

86using Shape = Shape_;

87

89using Policy = Policy_;

90

91//

92// Dependent types

93//

94

96using Operator = typename Policy::Operator;

97

100using WarpGemm = typename Policy::Operator::Shape;

101

103using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,

104 Shape::kN / WarpGemm::kN,

105 Shape::kK / WarpGemm::kK>;

106

108static int const kWarpGemmIterations =

109 (WarpGemm::kK / Operator::Policy::MmaShape::kK);

110

112static int const kStages = Stages;

113

115using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;

116

118using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

119

120//

121// Nested structs

122//

123

125class SharedStorage {

126public:

127//

128// Type definitions

129//

130

132using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,

133 Shape::kK * kStages +

134 Policy::SmemPaddingA::kColumn>;

135

137using ShapeB =

138MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,

139 Shape::kN + Policy::SmemPaddingB::kColumn>;

140

141public:

142//

143// Data members

144//

145

147AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;

148

150AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

151

152public:

153

154//

155// Methods

156//

157

159 CUTLASS_DEVICE

160static typename Operator::LayoutA LayoutA() {

161return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});

162 }

163

165CUTLASS_HOST_DEVICE

166static typename Operator::LayoutB LayoutB() {

167return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});

168 }

169

171CUTLASS_HOST_DEVICE

172TensorRefA operand_A_ref() {

173return TensorRefA{operand_A.data(), LayoutA()};

174 }

175

177CUTLASS_HOST_DEVICE

178TensorRefB operand_B_ref() {

179return TensorRefB{operand_B.data(), LayoutB()};

180 }

181 };

182

183protected:

184

185//

186// Data members

187//

188

190typename Operator::IteratorA warp_tile_iterator_A_;

191

193typename Operator::IteratorB warp_tile_iterator_B_;

194

195 public:

196

198 CUTLASS_DEVICE

199MmaBase(

201 SharedStorage &shared_storage,

203int thread_idx,

205int warp_idx,

207int lane_idx

208 ):

209 warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),

210 warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {

211

212 }

213 };

214

216

217 } // namespace threadblock

218 } // namespace gemm

219 } // namespace cutlass

220

cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, 1 >::Policy

Policy_ Policy

Definition: mma_base.h:89

cutlass::MatrixShape

Describes the size of a matrix tile.

Definition: matrix_shape.h:42

cutlass

Definition: aligned_buffer.h:35

memory.h

Architecture-specific operators on memory.

cutlass::gemm::threadblock::MmaBase::SharedStorage::operand_B

AlignedBuffer< typename Operator::ElementB, ShapeB::kCount > operand_B

Buffer for B operand.

Definition: mma_base.h:150

cutlass::gemm::threadblock::MmaBase::warp_tile_iterator_B_

Operator::IteratorB warp_tile_iterator_B_

Iterator to load a warp-scoped tile of B operand from shared memory.

Definition: mma_base.h:193

cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, 1 >::WarpGemm

typename Policy::Operator::Shape WarpGemm

Definition: mma_base.h:100

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::gemm::threadblock::MmaBase::SharedStorage

Shared storage object needed by threadblock-scoped GEMM.

Definition: mma_base.h:125

cutlass::gemm::threadblock::MmaBase::Shape

Shape_ Shape

Policy describing tuning details.

Definition: mma_base.h:88

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

cutlass::gemm::threadblock::MmaPolicy::Operator

Operator_ Operator

Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) ...

Definition: mma_base.h:58

cutlass::gemm::threadblock::MmaPolicy::SmemPaddingA

SmemPaddingA_ SmemPaddingA

Padding used for A operand in shared memory.

Definition: mma_base.h:61

matrix_shape.h

Defines a Shape template for matrix tiles.

cutlass::gemm::threadblock::MmaBase::SharedStorage::LayoutB

static CUTLASS_HOST_DEVICE Operator::LayoutB LayoutB()

Returns a layout object for the B matrix.

Definition: mma_base.h:166

cutlass::gemm::threadblock::MmaPolicy

Policy object describing MmaTensorOp.

Definition: mma_base.h:56

cutlass::TensorRef

Definition: tensor_ref.h:146

aligned_buffer.h

AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::AlignedBuffer

Modifies semantics of cutlass::Array<> to provide guaranteed alignment.

Definition: aligned_buffer.h:45

cutlass::gemm::threadblock::MmaBase::SharedStorage::operand_A_ref

CUTLASS_HOST_DEVICE TensorRefA operand_A_ref()

Returns a TensorRef to the A operand.

Definition: mma_base.h:172

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

cutlass::AlignedBuffer::data

CUTLASS_HOST_DEVICE pointer data()

Definition: aligned_buffer.h:84

cutlass::gemm::threadblock::MmaBase::MmaBase

CUTLASS_DEVICE MmaBase(SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)

Construct from tensor references.

Definition: mma_base.h:199

cutlass::gemm::threadblock::MmaBase

Definition: mma_base.h:83

cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, 1 >::Operator

typename Policy::Operator Operator

Warp-level Mma.

Definition: mma_base.h:96

cutlass::gemm::threadblock::MmaBase::warp_tile_iterator_A_

Operator::IteratorA warp_tile_iterator_A_

Iterator to load a warp-scoped tile of A operand from shared memory.

Definition: mma_base.h:190

cutlass::gemm::threadblock::MmaBase::SharedStorage::operand_A

AlignedBuffer< typename Operator::ElementA, ShapeA::kCount > operand_A

Buffer for A operand.

Definition: mma_base.h:147

cutlass::gemm::threadblock::MmaBase::SharedStorage::LayoutA

static CUTLASS_DEVICE Operator::LayoutA LayoutA()

Returns a layout object for the A matrix.

Definition: mma_base.h:160

cutlass::gemm::threadblock::MmaPolicy::SmemPaddingB

SmemPaddingB_ SmemPaddingB

Padding used for B operand in shared memory.

Definition: mma_base.h:64

cutlass::gemm::threadblock::MmaPolicy::kPartitionsK

static int const kPartitionsK

Number of partitions of K dimension.

Definition: mma_base.h:67

cutlass::gemm::threadblock::MmaBase::SharedStorage::operand_B_ref

CUTLASS_HOST_DEVICE TensorRefB operand_B_ref()

Returns a TensorRef to the B operand.

Definition: mma_base.h:178

cutlass.h

Basic include for CUTLASS.


Generated by 1.8.11