Back to Cutlass

CUTLASS: mma_pipelined.h Source File

docs/mma__pipelined_8h_source.html

4.4.226.9 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

mma_pipelined.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include "cutlass/cutlass.h"

32 #include "cutlass/array.h"

33 #include "cutlass/aligned_buffer.h"

34 #include "cutlass/numeric_conversion.h"

35

36 #include "cutlass/numeric_types.h"

37 #include "cutlass/matrix_shape.h"

38

39 #include "cutlass/gemm/gemm.h"

40 #include "cutlass/gemm/threadblock/mma_base.h"

41

43

44 namespace cutlass {

45 namespace gemm {

46 namespace threadblock {

47

49

51 template <

53typename Shape_,

55// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)

56typename IteratorA_,

59typename SmemIteratorA_,

61// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)

62typename IteratorB_,

65typename SmemIteratorB_,

67typename ElementC_,

69typename LayoutC_,

71typename Policy_,

73typename TransformA_ = NumericArrayConverter<

74typename SmemIteratorA_::Element,

75typename IteratorA_::Element,

76 IteratorA_::Fragment::kElements>,

79typename TransformB_ = NumericArrayConverter<

80typename SmemIteratorB_::Element,

81typename IteratorB_::Element,

82 IteratorB_::Fragment::kElements>,

84typename Enable = bool

85 >

86 class MmaPipelined : public MmaBase<Shape_, Policy_, 2> {

87 public:

88

90using Base = MmaBase<Shape_, Policy_, 2>;

91

92using Shape = Shape_;

93using IteratorA = IteratorA_;

94using IteratorB = IteratorB_;

95using ElementC = ElementC_;

96using LayoutC = LayoutC_;

97using Policy = Policy_;

98

99using SmemIteratorA = SmemIteratorA_;

100using SmemIteratorB = SmemIteratorB_;

101

102using TransformA = TransformA_;

103using TransformB = TransformB_;

104

105//

106// Dependent types

107//

108

110using FragmentA = typename IteratorA::Fragment;

111

113using FragmentB = typename IteratorB::Fragment;

114

116using FragmentC = typename Policy::Operator::FragmentC;

117

119using Operator = typename Policy::Operator;

120

121// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline)

122static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");

123

124 private:

125

126using WarpFragmentA = typename Operator::FragmentA;

127using WarpFragmentB = typename Operator::FragmentB;

128

129 protected:

130

132SmemIteratorA smem_iterator_A_;

133

135SmemIteratorB smem_iterator_B_;

136

137 public:

138

140 CUTLASS_DEVICE

141MmaPipelined(

142typename Base::SharedStorage &shared_storage,

143int thread_idx,

144int warp_idx,

145int lane_idx

146 ):

147Base(shared_storage, thread_idx, warp_idx, lane_idx),

148 smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),

149 smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {

150

151// Compute warp location within threadblock tile by mapping the warp_id to

152// three coordinates:

153// _m: the warp's position within the threadblock along the M dimension

154// _n: the warp's position within the threadblock along the N dimension

155// _k: the warp's position within the threadblock along the K dimension

156

157int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);

158int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

159

160int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;

161int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

162

163// Add per-warp offsets in units of warp-level tiles

164 this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});

165 this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});

166 }

167

169 CUTLASS_DEVICE

170void operator()(

171int gemm_k_iterations,

172FragmentC &accum,

173IteratorA iterator_A,

174IteratorB iterator_B,

175FragmentC const &src_accum,

176TransformA transform_A = TransformA(),

177TransformB transform_B = TransformB()) {

178

179//

180// Prologue

181//

182

183// Perform accumulation in the 'd' output operand

184 accum = src_accum;

185

186FragmentA tb_frag_A;

187FragmentB tb_frag_B;

188

189 tb_frag_A.clear();

190 tb_frag_B.clear();

191

192// The last kblock is loaded in the prolog

193 iterator_A.load(tb_frag_A);

194 iterator_B.load(tb_frag_B);

195

196 ++iterator_A;

197 ++iterator_B;

198

199 this->smem_iterator_A_.store(transform_A(tb_frag_A));

200 this->smem_iterator_B_.store(transform_B(tb_frag_B));

201

202 ++this->smem_iterator_A_;

203 ++this->smem_iterator_B_;

204

205 __syncthreads();

206

207// Pair of fragments used to overlap shared memory loads and math instructions

208 WarpFragmentA warp_frag_A[2];

209 WarpFragmentB warp_frag_B[2];

210

211 this->warp_tile_iterator_A_.set_kgroup_index(0);

212 this->warp_tile_iterator_B_.set_kgroup_index(0);

213

214 this->warp_tile_iterator_A_.load(warp_frag_A[0]);

215 this->warp_tile_iterator_B_.load(warp_frag_B[0]);

216

217 ++this->warp_tile_iterator_A_;

218 ++this->warp_tile_iterator_B_;

219

220Operator warp_mma;

221

222int smem_write_stage_idx = 1;

223

224// Avoid reading out of bounds

225if (gemm_k_iterations <= 1) {

226 iterator_A.clear_mask();

227 iterator_B.clear_mask();

228 }

229

230// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing

231// shared memory loads (which have the tightest latency requirement).

232

233//

234// Mainloop

235//

236

237// Note: The main loop does not support Base::kWarpGemmIterations == 2.

238CUTLASS_GEMM_LOOP

239for (; gemm_k_iterations > 0; --gemm_k_iterations) {

240//

241// Loop over GEMM K dimension

242//

243

244CUTLASS_PRAGMA_UNROLL

245for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {

246

247// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group

248// as the case may be.

249

250if (warp_mma_k == Base::kWarpGemmIterations - 1) {

251

252// Write fragments to shared memory

253 this->smem_iterator_A_.store(transform_A(tb_frag_A));

254

255 this->smem_iterator_B_.store(transform_B(tb_frag_B));

256

257 __syncthreads();

258

259 ++this->smem_iterator_B_;

260 ++this->smem_iterator_A_;

261

262// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory

263if (smem_write_stage_idx == 1) {

264 this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});

265 this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});

266 }

267else {

268 this->warp_tile_iterator_A_.add_tile_offset(

269 {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});

270 this->warp_tile_iterator_B_.add_tile_offset(

271 {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations,

272 0});

273 }

274

275 smem_write_stage_idx ^= 1;

276 }

277

278 this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);

279 this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);

280

281 this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);

282 this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);

283

284 ++this->warp_tile_iterator_A_;

285 ++this->warp_tile_iterator_B_;

286

287if (warp_mma_k == 0) {

288

289 iterator_A.load(tb_frag_A);

290 iterator_B.load(tb_frag_B);

291

292 ++iterator_A;

293 ++iterator_B;

294

295// Avoid reading out of bounds if this was the last loop iteration

296if (gemm_k_iterations <= 2) {

297 iterator_A.clear_mask();

298 iterator_B.clear_mask();

299 }

300 }

301

302 warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);

303 }

304 }

305

306 }

307 };

308

310

311 } // namespace threadblock

312 } // namespace gemm

313 } // namespace cutlass

cutlass::gemm::GemmShape::kM

static int const kM

Definition: include/cutlass/gemm/gemm.h:58

cutlass::gemm::threadblock::MmaPipelined::LayoutC

LayoutC_ LayoutC

Layout of accumulator matrix.

Definition: mma_pipelined.h:96

cutlass::gemm::threadblock::MmaPipelined::TransformB

TransformB_ TransformB

Definition: mma_pipelined.h:103

cutlass

Definition: aligned_buffer.h:35

cutlass::gemm::threadblock::MmaPipelined::Policy

Policy_ Policy

Policy describing tuning details.

Definition: mma_pipelined.h:97

cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, 2 >::warp_tile_iterator_B_

Operator::IteratorB warp_tile_iterator_B_

Iterator to load a warp-scoped tile of B operand from shared memory.

Definition: mma_base.h:193

cutlass::gemm::threadblock::MmaPipelined

Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.

Definition: mma_pipelined.h:86

cutlass::gemm::threadblock::MmaPipelined::IteratorB

IteratorB_ IteratorB

Iterates over tiles of B operand in global memory.

Definition: mma_pipelined.h:94

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::gemm::threadblock::MmaPipelined::operator()

CUTLASS_DEVICE void operator()(int gemm_k_iterations, FragmentC &accum, IteratorA iterator_A, IteratorB iterator_B, FragmentC const &src_accum, TransformA transform_A=TransformA(), TransformB transform_B=TransformB())

Perform a threadblock-scoped matrix multiply-accumulate.

Definition: mma_pipelined.h:170

cutlass::gemm::threadblock::MmaPipelined::IteratorA

IteratorA_ IteratorA

Iterates over tiles of A operand in global memory.

Definition: mma_pipelined.h:93

cutlass::gemm::threadblock::MmaPipelined::FragmentB

typename IteratorB::Fragment FragmentB

Fragment of operand B loaded from global memory.

Definition: mma_pipelined.h:113

cutlass::gemm::threadblock::MmaPipelined::SmemIteratorA

SmemIteratorA_ SmemIteratorA

Definition: mma_pipelined.h:99

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

CUTLASS_PRAGMA_UNROLL

#define CUTLASS_PRAGMA_UNROLL

Definition: cutlass.h:110

numeric_conversion.h

Boost-like numeric conversion operator for CUTLASS numeric types.

matrix_shape.h

Defines a Shape template for matrix tiles.

cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, 2 >::kWarpGemmIterations

static int const kWarpGemmIterations

Number of warp-level GEMM operations.

Definition: mma_base.h:108

mma_base.h

Template for a double-buffered threadblock-scoped GEMM kernel.

aligned_buffer.h

AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...

cutlass::gemm::threadblock::MmaPipelined::Shape

Shape_ Shape

Size of the Gemm problem - concept: gemm::GemmShape<>

Definition: mma_pipelined.h:92

cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, 2 >::kStages

static int const kStages

Number of stages.

Definition: mma_base.h:112

cutlass::gemm::threadblock::MmaPipelined::FragmentA

typename IteratorA::Fragment FragmentA

Fragment of operand A loaded from global memory.

Definition: mma_pipelined.h:110

numeric_types.h

Top-level include for all CUTLASS numeric types.

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::gemm::threadblock::MmaBase

Definition: mma_base.h:83

cutlass::gemm::threadblock::MmaPipelined::FragmentC

typename Policy::Operator::FragmentC FragmentC

Fragment of accumulator tile.

Definition: mma_pipelined.h:116

cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, 2 >::warp_tile_iterator_A_

Operator::IteratorA warp_tile_iterator_A_

Iterator to load a warp-scoped tile of A operand from shared memory.

Definition: mma_base.h:190

CUTLASS_GEMM_LOOP

#define CUTLASS_GEMM_LOOP

Definition: cutlass.h:112

cutlass::gemm::threadblock::MmaPipelined::ElementC

ElementC_ ElementC

Data type of accumulator matrix.

Definition: mma_pipelined.h:95

cutlass::gemm::threadblock::MmaPipelined::smem_iterator_A_

SmemIteratorA smem_iterator_A_

Iterator to write threadblock-scoped tile of A operand to shared memory.

Definition: mma_pipelined.h:132

cutlass::gemm::threadblock::MmaPipelined::SmemIteratorB

SmemIteratorB_ SmemIteratorB

Definition: mma_pipelined.h:100

cutlass::gemm::threadblock::MmaPipelined::smem_iterator_B_

SmemIteratorB smem_iterator_B_

Iterator to write threadblock-scoped tile of B operand to shared memory.

Definition: mma_pipelined.h:135

cutlass::gemm::threadblock::MmaPipelined::MmaPipelined

CUTLASS_DEVICE MmaPipelined(typename Base::SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)

Construct from tensor references.

Definition: mma_pipelined.h:141

cutlass.h

Basic include for CUTLASS.

cutlass::gemm::threadblock::MmaPipelined::TransformA

TransformA_ TransformA

Definition: mma_pipelined.h:102

cutlass::gemm::threadblock::MmaPipelined::Operator

typename Policy::Operator Operator

Warp-level Mma.

Definition: mma_pipelined.h:119

cutlass::gemm::GemmShape::kN

static int const kN

Definition: include/cutlass/gemm/gemm.h:59


Generated by 1.8.11