Back to Cutlass

CUTLASS: gemm_batched.h Source File

docs/kernel_2gemm__batched_8h_source.html

4.4.228.7 KB
Original Source

CUTLASS — CUDA Templates for Linear Algebra Subroutines and Solvers

kernel/gemm_batched.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include "cutlass/cutlass.h"

32

33 #include "cutlass/gemm/gemm.h"

34 #include "cutlass/matrix_coord.h"

35

37

38 namespace cutlass {

39 namespace gemm {

40 namespace kernel {

41

43

44 template <

45typename Mma_,

46typename Epilogue_,

47typename ThreadblockSwizzle_

48 >

49 struct GemmBatched {

50

51using Mma = Mma_;

52using Epilogue = Epilogue_;

53using OutputOp = typename Epilogue::OutputOp;

54using ThreadblockSwizzle = ThreadblockSwizzle_;

55

57using WarpCount = typename Mma::WarpCount;

58static int const kThreadCount = 32 * WarpCount::kCount;

59

61struct Params {

62cutlass::gemm::GemmCoord problem_size;

63cutlass::gemm::GemmCoord grid_tiled_shape;

64typename Mma::IteratorA::Params params_A;

65typename Mma::IteratorA::TensorRef ref_A;

66 int64_t stride_A;

67typename Mma::IteratorB::Params params_B;

68typename Mma::IteratorB::TensorRef ref_B;

69 int64_t stride_B;

70typename Epilogue::OutputTileIterator::Params params_C;

71typename Epilogue::OutputTileIterator::TensorRef ref_C;

72 int64_t stride_C;

73typename Epilogue::OutputTileIterator::Params params_D;

74typename Epilogue::OutputTileIterator::TensorRef ref_D;

75 int64_t stride_D;

76typename OutputOp::Params epilogue;

77int batch_count;

78int gemm_k_iterations;

79

80//

81// Methods

82//

83

84CUTLASS_HOST_DEVICE

85Params() { }

86

87CUTLASS_HOST_DEVICE

88Params(

89cutlass::gemm::GemmCoord const & problem_size_,

90cutlass::gemm::GemmCoord const & grid_tiled_shape_,

91typename Mma::IteratorA::TensorRef ref_A_,

92 int64_t stride_A_,

93typename Mma::IteratorB::TensorRef ref_B_,

94 int64_t stride_B_,

95typename Epilogue::OutputTileIterator::TensorRef ref_C_,

96 int64_t stride_C_,

97typename Epilogue::OutputTileIterator::TensorRef ref_D_,

98 int64_t stride_D_,

99typename OutputOp::Params epilogue_,

100int batch_count_

101 ):

102 problem_size(problem_size_),

103 grid_tiled_shape(grid_tiled_shape_),

104 params_A(ref_A_.layout()),

105 ref_A(ref_A_),

106 stride_A(stride_A_),

107 params_B(ref_B_.layout()),

108 ref_B(ref_B_),

109 stride_B(stride_B_),

110 params_C(ref_C_.layout()),

111 ref_C(ref_C_),

112 stride_C(stride_C_),

113 params_D(ref_D_.layout()),

114 ref_D(ref_D_),

115 stride_D(stride_D_),

116 epilogue(epilogue_),

117 batch_count(batch_count_),

118 gemm_k_iterations((problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK) {

119

120 }

121 };

122

124union SharedStorage {

125typename Mma::SharedStorage main_loop;

126typename Epilogue::SharedStorage epilogue;

127 };

128

129//

130// Methods

131//

132

133CUTLASS_HOST_DEVICE

134GemmBatched() { }

135

137 CUTLASS_DEVICE

138void operator()(Params const &params, SharedStorage &shared_storage) {

139

140// Compute threadblock location

141ThreadblockSwizzle threadblock_swizzle;

142

143cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

144

145// Early exit if CTA is out of range

146if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||

147 params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

148

149return;

150 }

151

152

153// Each CTA handles multiple batch indices to accommodate limited range of CUDA grid's Z dimension

154for (int batch_idx = threadblock_swizzle.get_batch_idx();

155 batch_idx < params.batch_count;

156 batch_idx += gridDim.z) {

157

158// Compute initial location in logical coordinates

159cutlass::MatrixCoord tb_offset_A{

160 threadblock_tile_offset.m() * Mma::Shape::kM,

161 0

162 };

163

164cutlass::MatrixCoord tb_offset_B{

165 0,

166 threadblock_tile_offset.n() * Mma::Shape::kN

167 };

168

169// Compute position within threadblock

170int thread_idx = threadIdx.x;

171

172// Construct iterators to A and B operands

173typename Mma::IteratorA iterator_A(

174 params.params_A,

175 params.ref_A.data(),

176 params.problem_size.mk(),

177 thread_idx,

178 tb_offset_A);

179

180 iterator_A.add_pointer_offset(params.stride_A * batch_idx);

181

182typename Mma::IteratorB iterator_B(

183 params.params_B,

184 params.ref_B.data(),

185 params.problem_size.kn(),

186 thread_idx,

187 tb_offset_B);

188

189 iterator_B.add_pointer_offset(params.stride_B * batch_idx);

190

191

192//

193// Main loop

194//

195

196// Construct thread-scoped matrix multiply

197int warp_idx = threadIdx.x / 32;

198int lane_idx = threadIdx.x % 32;

199

200Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

201

202typename Mma::FragmentC accumulators;

203

204 accumulators.clear();

205

206

207// Compute threadblock-scoped matrix multiply-add

208 mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

209

210//

211// Epilogue

212//

213

214OutputOp output_op(params.epilogue);

215

216//

217// Masked tile iterators constructed from members

218//

219

220 threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

221

222//assume identity swizzle

223MatrixCoord threadblock_offset(

224 threadblock_tile_offset.m() * Mma::Shape::kM,

225 threadblock_tile_offset.n() * Mma::Shape::kN

226 );

227

228// Tile iterator writing to output tile

229typename Epilogue::OutputTileIterator iterator_C(

230 params.params_C,

231 params.ref_C.data(),

232 params.problem_size.mn(),

233 thread_idx,

234 threadblock_offset

235 );

236

237 iterator_C.add_pointer_offset(params.stride_C * batch_idx);

238

239// Tile iterator writing to output tile

240typename Epilogue::OutputTileIterator iterator_D(

241 params.params_D,

242 params.ref_D.data(),

243 params.problem_size.mn(),

244 thread_idx,

245 threadblock_offset

246 );

247

248 iterator_D.add_pointer_offset(params.stride_D * batch_idx);

249

250Epilogue epilogue(

251 shared_storage.epilogue,

252 thread_idx,

253 warp_idx,

254 lane_idx);

255

256// run efficient epilogue

257epilogue(output_op, iterator_D, accumulators, iterator_C);

258 }

259 }

260 };

261

263

264 } // namespace kernel

265 } // namespace gemm

266 } // namespace cutlass

267

cutlass::gemm::kernel::GemmBatched::operator()

CUTLASS_DEVICE void operator()(Params const &params, SharedStorage &shared_storage)

Executes the batched GEMM for this threadblock's output tile, looping over batch indices in steps of gridDim.z.

Definition: kernel/gemm_batched.h:138

cutlass

Definition: aligned_buffer.h:35

cutlass::gemm::kernel::GemmBatched::Params::Params

CUTLASS_HOST_DEVICE Params()

Definition: kernel/gemm_batched.h:85

cutlass::gemm::kernel::GemmBatched::OutputOp

typename Epilogue::OutputOp OutputOp

Definition: kernel/gemm_batched.h:53

cutlass::gemm::kernel::GemmBatched::Params::ref_D

Epilogue::OutputTileIterator::TensorRef ref_D

Definition: kernel/gemm_batched.h:74

cutlass::gemm::GemmCoord

Definition: include/cutlass/gemm/gemm.h:94

cutlass::gemm::GemmCoord::mn

CUTLASS_HOST_DEVICE Coord< 2 > mn() const

Obtains a Coord<2> from GemmCoord.

Definition: include/cutlass/gemm/gemm.h:171

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::gemm::kernel::GemmBatched::Params::ref_B

Mma::IteratorB::TensorRef ref_B

Definition: kernel/gemm_batched.h:68

cutlass::gemm::GemmCoord::n

CUTLASS_HOST_DEVICE Index const & n() const

Returns the GEMM N coordinate.

Definition: include/cutlass/gemm/gemm.h:137

cutlass::gemm::kernel::GemmBatched::Params::gemm_k_iterations

int gemm_k_iterations

Definition: kernel/gemm_batched.h:78

cutlass::gemm::kernel::GemmBatched::Params::ref_C

Epilogue::OutputTileIterator::TensorRef ref_C

Definition: kernel/gemm_batched.h:71

cutlass::gemm::kernel::GemmBatched::GemmBatched

CUTLASS_HOST_DEVICE GemmBatched()

Definition: kernel/gemm_batched.h:134

cutlass::gemm::kernel::GemmBatched::Epilogue

Epilogue_ Epilogue

Definition: kernel/gemm_batched.h:52

cutlass::gemm::kernel::GemmBatched::SharedStorage

Shared memory storage structure.

Definition: kernel/gemm_batched.h:124

cutlass::gemm::kernel::GemmBatched::Params::grid_tiled_shape

cutlass::gemm::GemmCoord grid_tiled_shape

Definition: kernel/gemm_batched.h:63

cutlass::gemm::kernel::GemmBatched::SharedStorage::main_loop

Mma::SharedStorage main_loop

Definition: kernel/gemm_batched.h:125

cutlass::gemm::kernel::GemmBatched::kThreadCount

static int const kThreadCount

Definition: kernel/gemm_batched.h:58

cutlass::gemm::kernel::GemmBatched::Params

Parameters structure.

Definition: kernel/gemm_batched.h:61

cutlass::gemm::kernel::GemmBatched::WarpCount

typename Mma::WarpCount WarpCount

Warp count (concept: GemmShape)

Definition: kernel/gemm_batched.h:57

cutlass::gemm::kernel::GemmBatched::Params::params_D

Epilogue::OutputTileIterator::Params params_D

Definition: kernel/gemm_batched.h:73

cutlass::gemm::kernel::GemmBatched::Params::params_C

Epilogue::OutputTileIterator::Params params_C

Definition: kernel/gemm_batched.h:70

cutlass::gemm::kernel::GemmBatched::Params::epilogue

OutputOp::Params epilogue

Definition: kernel/gemm_batched.h:76

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

cutlass::gemm::kernel::GemmBatched::Params::stride_C

int64_t stride_C

Definition: kernel/gemm_batched.h:72

cutlass::gemm::GemmCoord::mk

CUTLASS_HOST_DEVICE Coord< 2 > mk() const

Obtains a Coord<2> from GemmCoord.

Definition: include/cutlass/gemm/gemm.h:177

cutlass::gemm::kernel::GemmBatched::Mma

Mma_ Mma

Definition: kernel/gemm_batched.h:51

cutlass::gemm::kernel::GemmBatched::Params::problem_size

cutlass::gemm::GemmCoord problem_size

Definition: kernel/gemm_batched.h:62

cutlass::gemm::kernel::GemmBatched::Params::params_A

Mma::IteratorA::Params params_A

Definition: kernel/gemm_batched.h:64

matrix_coord.h

Defines a canonical coordinate for rank=2 matrices offering named indices.

cutlass::gemm::kernel::GemmBatched::Params::batch_count

int batch_count

Definition: kernel/gemm_batched.h:77

cutlass::gemm::kernel::GemmBatched::Params::params_B

Mma::IteratorB::Params params_B

Definition: kernel/gemm_batched.h:67

cutlass::gemm::GemmCoord::kn

CUTLASS_HOST_DEVICE Coord< 2 > kn() const

Obtains a Coord<2> from GemmCoord.

Definition: include/cutlass/gemm/gemm.h:195

cutlass::gemm::kernel::GemmBatched::Params::stride_B

int64_t stride_B

Definition: kernel/gemm_batched.h:69

cutlass::gemm::GemmCoord::m

CUTLASS_HOST_DEVICE Index const & m() const

Returns the GEMM M coordinate.

Definition: include/cutlass/gemm/gemm.h:129

cutlass::gemm::kernel::GemmBatched::Params::ref_A

Mma::IteratorA::TensorRef ref_A

Definition: kernel/gemm_batched.h:65

cutlass::gemm::kernel::GemmBatched::Params::Params

CUTLASS_HOST_DEVICE Params(cutlass::gemm::GemmCoord const &problem_size_, cutlass::gemm::GemmCoord const &grid_tiled_shape_, typename Mma::IteratorA::TensorRef ref_A_, int64_t stride_A_, typename Mma::IteratorB::TensorRef ref_B_, int64_t stride_B_, typename Epilogue::OutputTileIterator::TensorRef ref_C_, int64_t stride_C_, typename Epilogue::OutputTileIterator::TensorRef ref_D_, int64_t stride_D_, typename OutputOp::Params epilogue_, int batch_count_)

Definition: kernel/gemm_batched.h:88

cutlass::gemm::kernel::GemmBatched::Params::stride_A

int64_t stride_A

Definition: kernel/gemm_batched.h:66

cutlass::gemm::kernel::GemmBatched

Definition: kernel/gemm_batched.h:49

cutlass::gemm::kernel::GemmBatched::SharedStorage::epilogue

Epilogue::SharedStorage epilogue

Definition: kernel/gemm_batched.h:126

cutlass::gemm::kernel::GemmBatched::Params::stride_D

int64_t stride_D

Definition: kernel/gemm_batched.h:75

cutlass::gemm::kernel::GemmBatched::ThreadblockSwizzle

ThreadblockSwizzle_ ThreadblockSwizzle

Definition: kernel/gemm_batched.h:54

cutlass.h

Basic include for CUTLASS.

cutlass::MatrixCoord

Definition: matrix_coord.h:39


Generated by Doxygen 1.8.11