Back to Cutlass

CUTLASS: gemm.h Source File

docs/include_2cutlass_2gemm_2kernel_2gemm_8h_source.html

4.4.2 — 35.6 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

include/cutlass/gemm/kernel/gemm.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

25

30 #pragma once

31

32 #include "cutlass/cutlass.h"

33

34 #include "cutlass/gemm/gemm.h"

35 #include "cutlass/matrix_coord.h"

36 #include "cutlass/semaphore.h"

37

39

40 namespace cutlass {

41 namespace gemm {

42 namespace kernel {

43

45

46 template <

47typename Mma_,

48typename Epilogue_,

49typename ThreadblockSwizzle_,

50bool SplitKSerial

51 >

52 struct Gemm {

53

54using Mma = Mma_;

55using Epilogue = Epilogue_;

56using OutputOp = typename Epilogue::OutputOp;

57using ThreadblockSwizzle = ThreadblockSwizzle_;

58static bool const kSplitKSerial = SplitKSerial;

59

61using WarpCount = typename Mma::WarpCount;

62static int const kThreadCount = 32 * WarpCount::kCount;

63

65struct Params {

66cutlass::gemm::GemmCoord problem_size;

67cutlass::gemm::GemmCoord grid_tiled_shape;

68typename Mma::IteratorA::Params params_A;

69typename Mma::IteratorA::TensorRef ref_A;

70typename Mma::IteratorB::Params params_B;

71typename Mma::IteratorB::TensorRef ref_B;

72typename Epilogue::OutputTileIterator::Params params_C;

73typename Epilogue::OutputTileIterator::TensorRef ref_C;

74typename Epilogue::OutputTileIterator::Params params_D;

75typename Epilogue::OutputTileIterator::TensorRef ref_D;

76typename OutputOp::Params output_op;

77int *semaphore;

78int gemm_k_iterations;

79int gemm_k_size;

80

81//

82// Methods

83//

84

85CUTLASS_HOST_DEVICE

86Params() { }

87

88CUTLASS_HOST_DEVICE

89Params(

90cutlass::gemm::GemmCoord const & problem_size,

91cutlass::gemm::GemmCoord const & grid_tiled_shape,

92typename Mma::IteratorA::TensorRef ref_A,

93typename Mma::IteratorB::TensorRef ref_B,

94typename Epilogue::OutputTileIterator::TensorRef ref_C,

95typename Epilogue::OutputTileIterator::TensorRef ref_D,

96typename OutputOp::Params output_op = typename OutputOp::Params(),

97int *semaphore = nullptr

98 ):

99 problem_size(problem_size),

100 grid_tiled_shape(grid_tiled_shape),

101 params_A(ref_A.layout()),

102 ref_A(ref_A),

103 params_B(ref_B.layout()),

104 ref_B(ref_B),

105 params_C(ref_C.layout()),

106 ref_C(ref_C),

107 params_D(ref_D.layout()),

108 ref_D(ref_D),

109 output_op(output_op),

110 semaphore(semaphore) {

111

112int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;

113int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();

114

115 gemm_k_size = gemm_k_iterations * Mma::Shape::kK;

116 }

117 };

118

120union SharedStorage {

121typename Mma::SharedStorage main_loop;

122typename Epilogue::SharedStorage epilogue;

123 };

124

125//

126// Methods

127//

128

129CUTLASS_HOST_DEVICE

130Gemm() { }

131

133static Status can_implement(

134cutlass::gemm::GemmCoord const & problem_size,

135typename Mma::IteratorA::TensorRef ref_A,

136typename Mma::IteratorB::TensorRef ref_B,

137typename Epilogue::OutputTileIterator::TensorRef ref_C,

138typename Epilogue::OutputTileIterator::TensorRef ref_D) {

139

140static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;

141static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;

142static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

143

144if (! TensorRef_aligned(ref_A, kAlignmentA)) {

145return Status::kErrorMisalignedOperand;

146 }

147

148if (! TensorRef_aligned(ref_B, kAlignmentB)) {

149return Status::kErrorMisalignedOperand;

150 }

151

152if (! TensorRef_aligned(ref_C, kAlignmentC)) {

153return Status::kErrorMisalignedOperand;

154 }

155

156if (! TensorRef_aligned(ref_D, kAlignmentC)) {

157return Status::kErrorMisalignedOperand;

158 }

159

160if ((problem_size.m() % kAlignmentA) || (problem_size.k() % kAlignmentA) ||

161 (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) ||

162 (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) {

163

164return Status::kErrorMisalignedOperand;

165 }

166

167return Status::kSuccess;

168 }

169

171 CUTLASS_DEVICE

172void operator()(Params const &params, SharedStorage &shared_storage) {

173

174// Compute threadblock location

175ThreadblockSwizzle threadblock_swizzle;

176

177cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

178

179// Early exit if CTA is out of range

180if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||

181 params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

182

183return;

184 }

185

186// Compute initial location in logical coordinates

187cutlass::MatrixCoord tb_offset_A{

188 threadblock_tile_offset.m() * Mma::Shape::kM,

189 threadblock_tile_offset.k() * params.gemm_k_size,

190 };

191

192cutlass::MatrixCoord tb_offset_B{

193 threadblock_tile_offset.k() * params.gemm_k_size,

194 threadblock_tile_offset.n() * Mma::Shape::kN

195 };

196

197// Problem size is a function of threadblock index in the K dimension

198int problem_size_k = min(

199 params.problem_size.k(),

200 (threadblock_tile_offset.k() + 1) * params.gemm_k_size);

201

202// Compute threadblock-scoped matrix multiply-add

203int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;

204

205// Compute position within threadblock

206int thread_idx = threadIdx.x;

207

208// Construct iterators to A and B operands

209typename Mma::IteratorA iterator_A(

210 params.params_A,

211 params.ref_A.data(),

212 {params.problem_size.m(), problem_size_k},

213 thread_idx,

214 tb_offset_A);

215

216typename Mma::IteratorB iterator_B(

217 params.params_B,

218 params.ref_B.data(),

219 {problem_size_k, params.problem_size.n()},

220 thread_idx,

221 tb_offset_B);

222

223int warp_idx = threadIdx.x / 32;

224int lane_idx = threadIdx.x % 32;

225

226//

227// Main loop

228//

229

230// Construct thread-scoped matrix multiply

231Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

232

233typename Mma::FragmentC accumulators;

234

235 accumulators.clear();

236

237if (!kSplitKSerial || gemm_k_iterations > 0) {

238// Compute threadblock-scoped matrix multiply-add

239 mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

240 }

241

242//

243// Epilogue

244//

245

246OutputOp output_op(params.output_op);

247

248//

249// Masked tile iterators constructed from members

250//

251

252 threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

253

254//assume identity swizzle

255MatrixCoord threadblock_offset(

256 threadblock_tile_offset.m() * Mma::Shape::kM,

257 threadblock_tile_offset.n() * Mma::Shape::kN

258 );

259

260int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

261

262// Construct the semaphore.

263Semaphore semaphore(params.semaphore + block_idx, thread_idx);

264

265// If performing a reduction via split-K, fetch the initial synchronization

266if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

267

268// Fetch the synchronization lock initially but do not block.

269 semaphore.fetch();

270

271// Indicate which position in a serial reduction the output operator is currently updating

272 output_op.set_k_partition(threadblock_tile_offset.k());

273 }

274

275// Tile iterator loading from source tensor.

276typename Epilogue::OutputTileIterator iterator_C(

277 params.params_C,

278 params.ref_C.data(),

279 params.problem_size.mn(),

280 thread_idx,

281 threadblock_offset

282 );

283

284// Tile iterator writing to destination tensor.

285typename Epilogue::OutputTileIterator iterator_D(

286 params.params_D,

287 params.ref_D.data(),

288 params.problem_size.mn(),

289 thread_idx,

290 threadblock_offset

291 );

292

293Epilogue epilogue(

294 shared_storage.epilogue,

295 thread_idx,

296 warp_idx,

297 lane_idx);

298

299// Wait on the semaphore - this latency may have been covered by iterator construction

300if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

301

302// For subsequent threadblocks, the source matrix is held in the 'D' tensor.

303if (threadblock_tile_offset.k()) {

304 iterator_C = iterator_D;

305 }

306

307 semaphore.wait(threadblock_tile_offset.k());

308

309 __threadfence();

310 }

311

312// Execute the epilogue operator to update the destination tensor.

313 epilogue(output_op, iterator_D, accumulators, iterator_C);

314

315//

316// Release the semaphore

317//

318

319if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {

320

321int lock = 0;

322if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {

323

324// The final threadblock resets the semaphore for subsequent grids.

325 lock = 0;

326 }

327else {

328// Otherwise, the semaphore is incremented

329 lock = threadblock_tile_offset.k() + 1;

330 }

331

332 __threadfence();

333 semaphore.release(lock);

334 }

335 }

336 };

337

339

340 } // namespace kernel

341 } // namespace gemm

342 } // namespace cutlass

343

cutlass::gemm::kernel::Gemm::Params::ref_C

Epilogue::OutputTileIterator::TensorRef ref_C

Definition: include/cutlass/gemm/kernel/gemm.h:73

cutlass

Definition: aligned_buffer.h:35

cutlass::gemm::kernel::Gemm::SharedStorage::epilogue

Epilogue::SharedStorage epilogue

Definition: include/cutlass/gemm/kernel/gemm.h:122

cutlass::gemm::kernel::Gemm::Params::params_D

Epilogue::OutputTileIterator::Params params_D

Definition: include/cutlass/gemm/kernel/gemm.h:74

cutlass::gemm::kernel::Gemm::Params::params_A

Mma::IteratorA::Params params_A

Definition: include/cutlass/gemm/kernel/gemm.h:68

cutlass::gemm::kernel::Gemm::Epilogue

Epilogue_ Epilogue

Definition: include/cutlass/gemm/kernel/gemm.h:55

cutlass::gemm::kernel::Gemm::Params::params_B

Mma::IteratorB::Params params_B

Definition: include/cutlass/gemm/kernel/gemm.h:70

cutlass::gemm::kernel::Gemm::Params::Params

CUTLASS_HOST_DEVICE Params(cutlass::gemm::GemmCoord const &problem_size, cutlass::gemm::GemmCoord const &grid_tiled_shape, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Epilogue::OutputTileIterator::TensorRef ref_C, typename Epilogue::OutputTileIterator::TensorRef ref_D, typename OutputOp::Params output_op=typename OutputOp::Params(), int *semaphore=nullptr)

Definition: include/cutlass/gemm/kernel/gemm.h:89

cutlass::gemm::GemmCoord

Definition: include/cutlass/gemm/gemm.h:94

cutlass::gemm::GemmCoord::mn

CUTLASS_HOST_DEVICE Coord< 2 > mn() const

Obtains a Coord<2> from GemmCoord.

Definition: include/cutlass/gemm/gemm.h:171

cutlass::gemm::kernel::Gemm::Params::params_C

Epilogue::OutputTileIterator::Params params_C

Definition: include/cutlass/gemm/kernel/gemm.h:72

cutlass::gemm::kernel::Gemm::kThreadCount

static int const kThreadCount

Definition: include/cutlass/gemm/kernel/gemm.h:62

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::Semaphore::fetch

CUTLASS_DEVICE void fetch()

Permit fetching the synchronization mechanism early.

Definition: semaphore.h:68

cutlass::gemm::GemmCoord::n

CUTLASS_HOST_DEVICE Index const & n() const

Returns the GEMM N coordinate.

Definition: include/cutlass/gemm/gemm.h:137

cutlass::gemm::kernel::Gemm::Params::grid_tiled_shape

cutlass::gemm::GemmCoord grid_tiled_shape

Definition: include/cutlass/gemm/kernel/gemm.h:67

cutlass::gemm::kernel::Gemm::Params::gemm_k_iterations

int gemm_k_iterations

Definition: include/cutlass/gemm/kernel/gemm.h:78

cutlass::gemm::kernel::Gemm::Params::ref_B

Mma::IteratorB::TensorRef ref_B

Definition: include/cutlass/gemm/kernel/gemm.h:71

cutlass::gemm::kernel::Gemm::can_implement

static Status can_implement(cutlass::gemm::GemmCoord const &problem_size, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Epilogue::OutputTileIterator::TensorRef ref_C, typename Epilogue::OutputTileIterator::TensorRef ref_D)

Determines whether kernel satisfies alignment.

Definition: include/cutlass/gemm/kernel/gemm.h:133

cutlass::gemm::kernel::Gemm::Gemm

CUTLASS_HOST_DEVICE Gemm()

Definition: include/cutlass/gemm/kernel/gemm.h:130

cutlass::gemm::GemmCoord::k

CUTLASS_HOST_DEVICE Index const & k() const

Returns the GEMM K coordinate.

Definition: include/cutlass/gemm/gemm.h:145

cutlass::gemm::kernel::Gemm::kSplitKSerial

static bool const kSplitKSerial

Definition: include/cutlass/gemm/kernel/gemm.h:58

cutlass::gemm::kernel::Gemm::OutputOp

typename Epilogue::OutputOp OutputOp

Definition: include/cutlass/gemm/kernel/gemm.h:56

cutlass::gemm::kernel::Gemm::Params

Parameters structure.

Definition: include/cutlass/gemm/kernel/gemm.h:65

cutlass::gemm::kernel::Gemm::Params::output_op

OutputOp::Params output_op

Definition: include/cutlass/gemm/kernel/gemm.h:76

cutlass::Status::kErrorMisalignedOperand

operands fail alignment requirements.

cutlass::gemm::kernel::Gemm::SharedStorage

Shared memory storage structure.

Definition: include/cutlass/gemm/kernel/gemm.h:120

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

cutlass::platform::min

CUTLASS_HOST_DEVICE constexpr const T & min(const T &a, const T &b)

std::min

Definition: platform.h:183

cutlass::gemm::kernel::Gemm::Params::gemm_k_size

int gemm_k_size

Definition: include/cutlass/gemm/kernel/gemm.h:79

cutlass::gemm::kernel::Gemm::Params::semaphore

int * semaphore

Definition: include/cutlass/gemm/kernel/gemm.h:77

cutlass::gemm::kernel::Gemm::operator()

CUTLASS_DEVICE void operator()(Params const &params, SharedStorage &shared_storage)

Executes one GEMM.

Definition: include/cutlass/gemm/kernel/gemm.h:172

cutlass::Semaphore

CTA-wide semaphore for inter-CTA synchronization.

Definition: semaphore.h:48

semaphore.h

Implementation of a CTA-wide semaphore for inter-CTA synchronization.

matrix_coord.h

Defines a canonical coordinate for rank=2 matrices offering named indices.

cutlass::Semaphore::release

CUTLASS_DEVICE void release(int status=0)

Updates the lock with the given result.

Definition: semaphore.h:98

cutlass::gemm::kernel::Gemm::Params::problem_size

cutlass::gemm::GemmCoord problem_size

Definition: include/cutlass/gemm/kernel/gemm.h:66

cutlass::gemm::kernel::Gemm::ThreadblockSwizzle

ThreadblockSwizzle_ ThreadblockSwizzle

Definition: include/cutlass/gemm/kernel/gemm.h:57

cutlass::gemm::kernel::Gemm

Definition: include/cutlass/gemm/kernel/gemm.h:52

cutlass::gemm::kernel::Gemm::Params::ref_A

Mma::IteratorA::TensorRef ref_A

Definition: include/cutlass/gemm/kernel/gemm.h:69

cutlass::TensorRef_aligned

bool TensorRef_aligned(TensorRef< Element, Layout > const &ref, int alignment)

Definition: tensor_ref.h:382

cutlass::Semaphore::wait

CUTLASS_DEVICE void wait(int status=0)

Waits until the semaphore is equal to the given value.

Definition: semaphore.h:81

cutlass::Status::kSuccess

Operation was successful.

cutlass::gemm::GemmCoord::m

CUTLASS_HOST_DEVICE Index const & m() const

Returns the GEMM M coordinate.

Definition: include/cutlass/gemm/gemm.h:129

cutlass::gemm::kernel::Gemm::Mma

Mma_ Mma

Definition: include/cutlass/gemm/kernel/gemm.h:54

cutlass::gemm::kernel::Gemm::WarpCount

typename Mma::WarpCount WarpCount

Warp count (concept: GemmShape)

Definition: include/cutlass/gemm/kernel/gemm.h:61

cutlass.h

Basic include for CUTLASS.

cutlass::MatrixCoord

Definition: matrix_coord.h:39

cutlass::gemm::kernel::Gemm::Params::Params

CUTLASS_HOST_DEVICE Params()

Definition: include/cutlass/gemm/kernel/gemm.h:86

cutlass::Status

Status

Status code returned by CUTLASS operations.

Definition: cutlass.h:39

cutlass::gemm::kernel::Gemm::SharedStorage::main_loop

Mma::SharedStorage main_loop

Definition: include/cutlass/gemm/kernel/gemm.h:121

cutlass::gemm::kernel::Gemm::Params::ref_D

Epilogue::OutputTileIterator::TensorRef ref_D

Definition: include/cutlass/gemm/kernel/gemm.h:75


Generated by Doxygen 1.8.11