Back to Cutlass

CUTLASS: gemm_splitk_parallel.h Source File

docs/kernel_2gemm__splitk__parallel_8h_source.html

4.4.227.1 KB
Original Source

CUTLASS: CUDA Templates for Linear Algebra Subroutines and Solvers

kernel/gemm_splitk_parallel.h

[Go to the documentation of this file.](kernel_2gemm__splitk__parallel_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include "cutlass/cutlass.h"

32

33 #include "cutlass/gemm/gemm.h"

34 #include "cutlass/matrix_coord.h"

35

37

38 namespace cutlass {

39 namespace gemm {

40 namespace kernel {

41

43

44 template <

45typename Mma_,

46typename Epilogue_,

47typename ThreadblockSwizzle_

48 >

49 struct GemmSplitKParallel {

50

51using Mma = Mma_;

52using Epilogue = Epilogue_;

53using OutputOp = typename Epilogue::OutputOp;

54using ThreadblockSwizzle = ThreadblockSwizzle_;

55

57using WarpCount = typename Mma::WarpCount;

58static int const kThreadCount = 32 * WarpCount::kCount;

59

60static int const kAlignmentK = Mma::Operator::Shape::kK;

61

63struct Params {

64cutlass::gemm::GemmCoord problem_size;

65cutlass::gemm::GemmCoord grid_tiled_shape;

66typename Mma::IteratorA::Params params_A;

67typename Mma::IteratorA::TensorRef ref_A;

68typename Mma::IteratorB::Params params_B;

69typename Mma::IteratorB::TensorRef ref_B;

70typename Epilogue::OutputTileIterator::Params params_D;

71typename Epilogue::OutputTileIterator::TensorRef ref_D;

72typename OutputOp::Params output_op;

73 int64_t splitk_slice_stride;

74int gemm_k_size;

75

76//

77// Methods

78//

79

80CUTLASS_HOST_DEVICE

81Params() { }

82

83CUTLASS_HOST_DEVICE

84Params(

85cutlass::gemm::GemmCoord const & problem_size,

86cutlass::gemm::GemmCoord const & grid_tiled_shape,

87typename Mma::IteratorA::TensorRef ref_A,

88typename Mma::IteratorB::TensorRef ref_B,

89typename Epilogue::OutputTileIterator::TensorRef ref_D,

90typename OutputOp::Params output_op,

91 int64_t splitk_slice_stride

92 ):

93 problem_size(problem_size),

94 grid_tiled_shape(grid_tiled_shape),

95 params_A(ref_A.layout()),

96 ref_A(ref_A),

97 params_B(ref_B.layout()),

98 ref_B(ref_B),

99 params_D(ref_D.layout()),

100 ref_D(ref_D),

101 output_op(output_op),

102 splitk_slice_stride(splitk_slice_stride) {

103

104int full_gemm_k_iterations = problem_size.k() / Mma::Shape::kK;

105int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k();

106

107 gemm_k_size = gemm_k_iterations * Mma::Shape::kK;

108 }

109 };

110

112union SharedStorage {

113typename Mma::SharedStorage main_loop;

114typename Epilogue::SharedStorage epilogue;

115 };

116

117//

118// Methods

119//

120

121CUTLASS_HOST_DEVICE

122GemmSplitKParallel() { }

123

125 CUTLASS_DEVICE

126void operator()(Params const &params, SharedStorage &shared_storage) {

127

128// Compute threadblock location

129ThreadblockSwizzle threadblock_swizzle;

130

131cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

132

133// Early exit if CTA is out of range

134if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||

135 params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {

136

137return;

138 }

139

140// Compute initial location in logical coordinates

141cutlass::MatrixCoord tb_offset_A{

142 threadblock_tile_offset.m() * Mma::Shape::kM,

143 threadblock_tile_offset.k() * params.gemm_k_size,

144 };

145

146cutlass::MatrixCoord tb_offset_B{

147 threadblock_tile_offset.k() * params.gemm_k_size,

148 threadblock_tile_offset.n() * Mma::Shape::kN

149 };

150

151// Problem size is a function of threadblock index in the K dimension

152int problem_size_k;

153if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) {

154 problem_size_k = params.problem_size.k();

155 }

156else {

157 problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;

158 }

159

160// Compute threadblock-scoped matrix multiply-add

161int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;

162

163// Compute position within threadblock

164int thread_idx = threadIdx.x;

165

166// Construct iterators to A and B operands

167typename Mma::IteratorA iterator_A(

168 params.params_A,

169 params.ref_A.data(),

170 {params.problem_size.m(), problem_size_k},

171 thread_idx,

172 tb_offset_A);

173

174typename Mma::IteratorB iterator_B(

175 params.params_B,

176 params.ref_B.data(),

177 {problem_size_k, params.problem_size.n()},

178 thread_idx,

179 tb_offset_B);

180

181int warp_idx = threadIdx.x / 32;

182int lane_idx = threadIdx.x % 32;

183

184

185//

186// Main loop

187//

188

189// Construct thread-scoped matrix multiply

190Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

191

192typename Mma::FragmentC accumulators;

193

194 accumulators.clear();

195

196 mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

197

198//

199// Epilogue

200//

201

202OutputOp output_op(params.output_op);

203

204//

205// Masked tile iterators constructed from members

206//

207

208 threadblock_tile_offset = threadblock_swizzle.get_tile_offset();

209

210//assume identity swizzle

211MatrixCoord threadblock_offset(

212 threadblock_tile_offset.m() * Mma::Shape::kM,

213 threadblock_tile_offset.n() * Mma::Shape::kN

214 );

215

216// Tile iterator writing to output tile

217typename Epilogue::OutputTileIterator iterator_D(

218 params.params_D,

219 params.ref_D.data(),

220 params.problem_size.mn(),

221 thread_idx,

222 threadblock_offset

223 );

224

225 iterator_D.add_pointer_offset(params.splitk_slice_stride * threadblock_tile_offset.k());

226

227// Execute the epilogue

228Epilogue epilogue(

229 shared_storage.epilogue,

230 thread_idx,

231 warp_idx,

232 lane_idx);

233

234// Run efficient epilogue

235 epilogue(output_op, iterator_D, accumulators, iterator_D);

236 }

237 };

238

240

241 } // namespace kernel

242 } // namespace gemm

243 } // namespace cutlass

244

cutlass::gemm::kernel::GemmSplitKParallel::operator()

CUTLASS_DEVICE void operator()(Params const &params, SharedStorage &shared_storage)

Executes one GEMM.

Definition: kernel/gemm_splitk_parallel.h:126

cutlass::gemm::kernel::GemmSplitKParallel::GemmSplitKParallel

CUTLASS_HOST_DEVICE GemmSplitKParallel()

Definition: kernel/gemm_splitk_parallel.h:122

cutlass

Definition: aligned_buffer.h:35

cutlass::gemm::kernel::GemmSplitKParallel::Epilogue

Epilogue_ Epilogue

Definition: kernel/gemm_splitk_parallel.h:52

cutlass::gemm::kernel::GemmSplitKParallel::Params::problem_size

cutlass::gemm::GemmCoord problem_size

Definition: kernel/gemm_splitk_parallel.h:64

cutlass::gemm::kernel::GemmSplitKParallel::SharedStorage

Shared memory storage structure.

Definition: kernel/gemm_splitk_parallel.h:112

cutlass::gemm::kernel::GemmSplitKParallel::SharedStorage::epilogue

Epilogue::SharedStorage epilogue

Definition: kernel/gemm_splitk_parallel.h:114

cutlass::gemm::GemmCoord

Definition: include/cutlass/gemm/gemm.h:94

cutlass::gemm::GemmCoord::mn

CUTLASS_HOST_DEVICE Coord< 2 > mn() const

Obtains a Coord<2> from GemmCoord.

Definition: include/cutlass/gemm/gemm.h:171

cutlass::gemm::kernel::GemmSplitKParallel::Params::grid_tiled_shape

cutlass::gemm::GemmCoord grid_tiled_shape

Definition: kernel/gemm_splitk_parallel.h:65

cutlass::gemm::kernel::GemmSplitKParallel::kThreadCount

static int const kThreadCount

Definition: kernel/gemm_splitk_parallel.h:58

cutlass::gemm::kernel::GemmSplitKParallel::SharedStorage::main_loop

Mma::SharedStorage main_loop

Definition: kernel/gemm_splitk_parallel.h:113

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::gemm::GemmCoord::n

CUTLASS_HOST_DEVICE Index const & n() const

Returns the GEMM N coordinate.

Definition: include/cutlass/gemm/gemm.h:137

cutlass::gemm::kernel::GemmSplitKParallel::Params

Parameters structure.

Definition: kernel/gemm_splitk_parallel.h:63

cutlass::gemm::kernel::GemmSplitKParallel::WarpCount

typename Mma::WarpCount WarpCount

Warp count (concept: GemmShape)

Definition: kernel/gemm_splitk_parallel.h:57

cutlass::gemm::kernel::GemmSplitKParallel::ThreadblockSwizzle

ThreadblockSwizzle_ ThreadblockSwizzle

Definition: kernel/gemm_splitk_parallel.h:54

cutlass::gemm::kernel::GemmSplitKParallel::Params::Params

CUTLASS_HOST_DEVICE Params(cutlass::gemm::GemmCoord const &problem_size, cutlass::gemm::GemmCoord const &grid_tiled_shape, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Epilogue::OutputTileIterator::TensorRef ref_D, typename OutputOp::Params output_op, int64_t splitk_slice_stride)

Definition: kernel/gemm_splitk_parallel.h:84

cutlass::gemm::kernel::GemmSplitKParallel::Params::output_op

OutputOp::Params output_op

Definition: kernel/gemm_splitk_parallel.h:72

cutlass::gemm::GemmCoord::k

CUTLASS_HOST_DEVICE Index const & k() const

Returns the GEMM K coordinate.

Definition: include/cutlass/gemm/gemm.h:145

cutlass::gemm::kernel::GemmSplitKParallel::Params::ref_A

Mma::IteratorA::TensorRef ref_A

Definition: kernel/gemm_splitk_parallel.h:67

cutlass::gemm::kernel::GemmSplitKParallel::Params::ref_B

Mma::IteratorB::TensorRef ref_B

Definition: kernel/gemm_splitk_parallel.h:69

cutlass::gemm::kernel::GemmSplitKParallel::Params::gemm_k_size

int gemm_k_size

Definition: kernel/gemm_splitk_parallel.h:74

cutlass::gemm::kernel::GemmSplitKParallel::Params::ref_D

Epilogue::OutputTileIterator::TensorRef ref_D

Definition: kernel/gemm_splitk_parallel.h:71

cutlass::gemm::kernel::GemmSplitKParallel::Params::Params

CUTLASS_HOST_DEVICE Params()

Definition: kernel/gemm_splitk_parallel.h:81

cutlass::gemm::kernel::GemmSplitKParallel::Params::params_D

Epilogue::OutputTileIterator::Params params_D

Definition: kernel/gemm_splitk_parallel.h:70

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

cutlass::gemm::kernel::GemmSplitKParallel::Params::params_A

Mma::IteratorA::Params params_A

Definition: kernel/gemm_splitk_parallel.h:66

cutlass::gemm::kernel::GemmSplitKParallel::kAlignmentK

static int const kAlignmentK

Definition: kernel/gemm_splitk_parallel.h:60

matrix_coord.h

Defines a canonical coordinate for rank=2 matrices offering named indices.

cutlass::gemm::kernel::GemmSplitKParallel

Definition: kernel/gemm_splitk_parallel.h:49

cutlass::gemm::GemmCoord::m

CUTLASS_HOST_DEVICE Index const & m() const

Returns the GEMM M coordinate.

Definition: include/cutlass/gemm/gemm.h:129

cutlass::gemm::kernel::GemmSplitKParallel::Mma

Mma_ Mma

Definition: kernel/gemm_splitk_parallel.h:51

cutlass::gemm::kernel::GemmSplitKParallel::Params::params_B

Mma::IteratorB::Params params_B

Definition: kernel/gemm_splitk_parallel.h:68

cutlass::gemm::kernel::GemmSplitKParallel::Params::splitk_slice_stride

int64_t splitk_slice_stride

Definition: kernel/gemm_splitk_parallel.h:73

cutlass.h

Basic include for CUTLASS.

cutlass::MatrixCoord

Definition: matrix_coord.h:39

cutlass::gemm::kernel::GemmSplitKParallel::OutputOp

typename Epilogue::OutputOp OutputOp

Definition: kernel/gemm_splitk_parallel.h:53


Generated by Doxygen 1.8.11