Back to Cutlass

CUTLASS: mma_simt.h Source File

docs/mma__simt_8h_source.html

4.4.221.0 KB
Original Source

CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers

mma_simt.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include "cutlass/cutlass.h"

32 #include "cutlass/array.h"

33 #include "cutlass/numeric_types.h"

34 #include "cutlass/matrix_shape.h"

35 #include "cutlass/gemm/gemm.h"

36 #include "cutlass/gemm/warp/mma.h"

37

38 #include "cutlass/gemm/thread/mma.h"

39

40 #include "[cutlass/gemm/warp/mma_simt_tile_iterator.h](mma simt tile__iterator_8h.html)"

41 #include "[cutlass/gemm/warp/mma_simt_policy.h](mma simt policy_8h.html)"

42

44

45 namespace cutlass {

46 namespace gemm {

47 namespace warp {

48

50

52 template <

54typename Shape_,

56typename ElementA_,

58typename LayoutA_,

60typename ElementB_,

62typename LayoutB_,

64typename ElementC_,

66typename LayoutC_,

68typename Policy_,

70int PartitionsK = 1,

72typename Enable = bool

73 >

74 class MmaSimt {

75 public:

77using Shape = Shape_;

78

80using ElementA = ElementA_;

81

83using LayoutA = LayoutA_;

84

86using ElementB = ElementB_;

87

89using LayoutB = LayoutB_;

90

92using ElementC = ElementC_;

93

95using LayoutC = LayoutC_;

96

98using Policy = Policy_;

99

101using OperatorClass = arch::OpClassSimt;

102

103using ThreadLayoutA = typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA >::value,

104layout::ColumnMajor,

105typename platform::conditional < platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value,

106layout::RowMajor,

107LayoutA>::type

108 >::type;

109

110using ThreadLayoutB = typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutB >::value,

111 layout::ColumnMajor,

112typename platform::conditional < platform::is_same< layout::RowMajorInterleaved<4>, LayoutB >::value,

113 layout::RowMajor,

114LayoutB>::type

115 >::type;

116

117static constexpr bool use_dp4a = (platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA>::value ||

118platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value) &&

119platform::is_same< ElementA, int8_t >::value &&

120platform::is_same< ElementB, int8_t >::value;

121

122using dp4a_type = typename platform::conditional< use_dp4a , int8_t, bool >::type;

123

125using ThreadMma = thread::Mma<

126GemmShape<

127 Shape::kM / Policy::WarpShape::kRow,

128 Shape::kN / Policy::WarpShape::kColumn,

129 Policy::LaneMmaShape::kK>,

130ElementA,

131ThreadLayoutA,

132ElementB,

133ThreadLayoutB,

134ElementC,

135LayoutC,

136 arch::OpMultiplyAdd,

137dp4a_type

138 >;

139

140 public:

141

143using IteratorA = MmaSimtTileIterator<

144MatrixShape<Shape::kM, Policy::LaneMmaShape::kK>,

145Operand::kA,

146ElementA,

147LayoutA,

148Policy,

149 PartitionsK,

150 Shape::kK

151 >;

152

154using FragmentA = typename IteratorA::Fragment;

155

157using IteratorB = MmaSimtTileIterator<

158MatrixShape<Policy::LaneMmaShape::kK, Shape::kN>,

159Operand::kB,

160ElementB,

161LayoutB,

162Policy,

163 PartitionsK,

164 Shape::kK

165 >;

166

168using FragmentB = typename IteratorB::Fragment;

169

171using IteratorC = MmaSimtTileIterator<

172MatrixShape<Shape::kM, Shape::kN>,

173Operand::kC,

174ElementC,

175LayoutC,

176 Policy

177 >;

178

180using FragmentC = typename ThreadMma::FragmentC;

181

182 public:

183

184//

185// Methods

186//

187

189 CUTLASS_DEVICE

190MmaSimt() {}

191

193 CUTLASS_DEVICE

194void operator()(

195FragmentC &d,

196FragmentA const &a,

197FragmentB const &b,

198FragmentC const &c, int group_idx = 0) const {

199

200ThreadMma mma;

201

202 mma(d, a, b, c);

203 }

204 };

205

207

208 } // namespace warp

209 } // namespace gemm

210 } // namespace cutlass

[mma_simt_policy.h](mma__simt__policy_8h.html)

Describes the lane policy used by warp-level matrix multiply operators targeting SIMT instructions...

cutlass::MatrixShape

Describes the size of a matrix tile.

Definition: matrix_shape.h:42

cutlass::gemm::warp::MmaSimt::ElementC

ElementC_ ElementC

Data type of accumulator matrix C.

Definition: mma_simt.h:92

cutlass

Definition: aligned_buffer.h:35

constexpr

#define constexpr

Definition: platform.h:137

[mma_simt_tile_iterator.h](mma__simt__tile__iterator_8h.html)

Defines the tile iterators (MmaSimtTileIterator) used by warp-level matrix multiply operators targeting SIMT instructions...

cutlass::platform::conditional::type

T type

Definition: platform.h:326

cutlass::platform::is_same

std::is_same (false specialization)

Definition: platform.h:394

cutlass::gemm::warp::MmaSimt::FragmentC

typename ThreadMma::FragmentC FragmentC

Storage for C tile.

Definition: mma_simt.h:180

cutlass::gemm::warp::MmaSimt::Shape

Shape_ Shape

Shape of warp-level matrix operation (concept: GemmShape)

Definition: mma_simt.h:77

cutlass::gemm::warp::MmaSimt

Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.

Definition: mma_simt.h:74

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::gemm::warp::MmaSimt::operator()

CUTLASS_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c, int group_idx=0) const

Performs a warp-level matrix multiply-accumulate operation.

Definition: mma_simt.h:194

cutlass::gemm::warp::MmaSimt::use_dp4a

static constexpr bool use_dp4a

Definition: mma_simt.h:117

cutlass::gemm::warp::MmaSimt::LayoutC

LayoutC_ LayoutC

Layout of accumulator matrix C.

Definition: mma_simt.h:95

cutlass::layout::ColumnMajor

Mapping function for column-major matrices.

Definition: layout/matrix.h:142

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

cutlass::gemm::Operand::kC

Accumulator operand C.

mma.h

Templates exposing architecture support for warp-level multiply-add operations.

cutlass::gemm::warp::MmaSimtTileIterator

Definition: mma_simt_tile_iterator.h:69

cutlass::gemm::Operand::kA

A multiplicand.

matrix_shape.h

Defines a Shape template for matrix tiles.

cutlass::gemm::warp::MmaSimt::OperatorClass

arch::OpClassSimt OperatorClass

Indicates class of matrix operator.

Definition: mma_simt.h:101

cutlass::gemm::warp::MmaSimt::ThreadLayoutB

typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved< 4 >, LayoutB >::value, layout::ColumnMajor, typename platform::conditional< platform::is_same< layout::RowMajorInterleaved< 4 >, LayoutB >::value, layout::RowMajor, LayoutB >::type >::type ThreadLayoutB

Definition: mma_simt.h:115

cutlass::gemm::warp::MmaSimt::LayoutA

LayoutA_ LayoutA

Layout of multiplicand A.

Definition: mma_simt.h:83

mma.h

Templates exposing architecture support for warp-level multiply-add operations.

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

cutlass::platform::conditional

std::conditional (true specialization)

Definition: platform.h:325

cutlass::gemm::warp::MmaSimt::Policy

Policy_ Policy

Lane policy (concept: MmaLanePolicySimt) that supplies the warp's arrangement of threads (WarpShape) and the per-lane MMA shape (LaneMmaShape)

Definition: mma_simt.h:98

cutlass::gemm::warp::MmaSimt::FragmentA

typename IteratorA::Fragment FragmentA

Storage for A tile.

Definition: mma_simt.h:154

cutlass::gemm::warp::MmaSimt::dp4a_type

typename platform::conditional< use_dp4a, int8_t, bool >::type dp4a_type

Definition: mma_simt.h:122

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

cutlass::gemm::thread::Mma

Structure to compute the matrix product.

Definition: gemm/thread/mma.h:66

cutlass::gemm::warp::MmaSimt::ElementA

ElementA_ ElementA

Data type of multiplicand A.

Definition: mma_simt.h:80

cutlass::gemm::Operand::kB

B multiplicand.

cutlass::gemm::warp::MmaSimt::ThreadLayoutA

typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved< 4 >, LayoutA >::value, layout::ColumnMajor, typename platform::conditional< platform::is_same< layout::RowMajorInterleaved< 4 >, LayoutA >::value, layout::RowMajor, LayoutA >::type >::type ThreadLayoutA

Definition: mma_simt.h:108

cutlass::gemm::warp::MmaSimt::ElementB

ElementB_ ElementB

Data type of multiplicand B.

Definition: mma_simt.h:86

cutlass.h

Basic include for CUTLASS.

cutlass::gemm::warp::MmaSimt::LayoutB

LayoutB_ LayoutB

Layout of multiplicand B.

Definition: mma_simt.h:89

cutlass::gemm::warp::MmaSimt::MmaSimt

CUTLASS_DEVICE MmaSimt()

Ctor.

Definition: mma_simt.h:190

cutlass::gemm::warp::MmaSimt::FragmentB

typename IteratorB::Fragment FragmentB

Storage for B tile.

Definition: mma_simt.h:168


Generated by Doxygen 1.8.11