Back to Cutlass

CUTLASS: gemm_splitk_parallel.h Source File

docs/device_2gemm__splitk__parallel_8h_source.html

4.4.2 · 87.0 KB
Original Source

CUTLASS — CUDA Templates for Linear Algebra Subroutines and Solvers

device/gemm_splitk_parallel.h

[Go to the documentation of this file.](device_2gemm__splitk__parallel_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include "cutlass/cutlass.h"

32 #include "cutlass/numeric_types.h"

33 #include "cutlass/arch/arch.h"

34 #include "cutlass/device_kernel.h"

35

36 #include "cutlass/gemm/threadblock/threadblock_swizzle.h"

37 #include "cutlass/gemm/kernel/gemm.h"

38

39 #include "[cutlass/gemm/kernel/default_gemm_splitk_parallel.h](default gemm splitk__parallel_8h.html)"

40 #include "[cutlass/gemm/device/default_gemm_configuration.h](default gemm configuration_8h.html)"

41

42 #include "cutlass/epilogue/thread/conversion_op.h"

43 #include "[cutlass/reduction/kernel/reduce_split_k.h](reduce split k_8h.html)"

44 #include "cutlass/reduction/thread/reduction_operators.h"

45

47

48 namespace cutlass {

49 namespace gemm {

50 namespace device {

51

53

58 template <

60typename ElementA_,

62typename LayoutA_,

64typename ElementB_,

66typename LayoutB_,

68typename ElementC_,

70typename LayoutC_,

72typename ElementAccumulator_ = ElementC_,

74typename OperatorClass_ = arch::OpClassSimt,

76typename ArchTag_ = arch::Sm70,

78typename ThreadblockShape_ = typename DefaultGemmConfiguration<

79 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

80 ElementAccumulator_>::ThreadblockShape,

82typename WarpShape_ = typename DefaultGemmConfiguration<

83 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

84 ElementAccumulator_>::WarpShape,

86typename InstructionShape_ = typename DefaultGemmConfiguration<

87 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

88 ElementAccumulator_>::InstructionShape,

90typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<

91 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

92 ElementAccumulator_>::EpilogueOutputOp,

94typename ConvertScaledOp_ = cutlass::epilogue::thread::Convert<

95 ElementAccumulator_,

96 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,

97 ElementAccumulator_,

98 ElementAccumulator_>::EpilogueOutputOp::kCount,

99 ElementAccumulator_>,

101typename ReductionOp_ = cutlass::reduction::thread::ReduceAdd<

102 ElementAccumulator_, typename EpilogueOutputOp_::ElementAccumulator,

103 EpilogueOutputOp_::kCount>,

105typename ThreadblockSwizzle_ =

106 threadblock::GemmSplitKHorizontalThreadblockSwizzle,

108int Stages =

109 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,

110 ElementC_, ElementAccumulator_>::kStages,

112int kAlignmentA =

113 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,

114 ElementC_, ElementAccumulator_>::kAlignmentA,

116int kAlignmentB =

117 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,

118 ElementC_, ElementAccumulator_>::kAlignmentB,

120typename Operator_ = typename DefaultGemmConfiguration<

121 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

122 ElementAccumulator_>::Operator>

123 class GemmSplitKParallel {

124public:

125

126using ElementA = ElementA_;

127using LayoutA = LayoutA_;

128using ElementB = ElementB_;

129using LayoutB = LayoutB_;

130using ElementC = ElementC_;

131using LayoutC = LayoutC_;

132using ElementAccumulator = ElementAccumulator_;

133using OperatorClass = OperatorClass_;

134using ArchTag = ArchTag_;

135using ThreadblockShape = ThreadblockShape_;

136using WarpShape = WarpShape_;

137using InstructionShape = InstructionShape_;

138using ConvertScaledOp = ConvertScaledOp_;

139using EpilogueOutputOp = EpilogueOutputOp_;

140using ReductionOp = ReductionOp_;

141using ThreadblockSwizzle = ThreadblockSwizzle_;

142using Operator = Operator_;

143static int const kStages = Stages;

144

146using GemmKernel = typename kernel::DefaultGemmSplitKParallel<

147ElementA,

148LayoutA,

149 kAlignmentA,

150ElementB,

151LayoutB,

152 kAlignmentB,

153ElementAccumulator,

154LayoutC,

155ElementAccumulator,

156OperatorClass,

157ArchTag,

158ThreadblockShape,

159WarpShape,

160InstructionShape,

161ConvertScaledOp,

162ThreadblockSwizzle,

163kStages,

164Operator

165 >::GemmKernel;

166

168using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<

169cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,

170EpilogueOutputOp,

171ReductionOp

172 >;

173

174//

175//

176//

177

179struct Arguments {

180

181//

182// Data members

183//

184

185GemmCoord problem_size;

186TensorRef<ElementA const, LayoutA> ref_A;

187TensorRef<ElementB const, LayoutB> ref_B;

188TensorRef<ElementC const, LayoutC> ref_C;

189TensorRef<ElementC, LayoutC> ref_D;

190typename EpilogueOutputOp::Params epilogue;

191int split_k_slices;

192typename ConvertScaledOp::Params convert;

193typename ReductionOp::Params reduction;

194

195//

196// Methods

197//

198

200CUTLASS_HOST_DEVICE

201Arguments() { }

202

204CUTLASS_HOST_DEVICE

205Arguments(

206GemmCoord problem_size_,

207TensorRef<ElementA const, LayoutA> ref_A_,

208TensorRef<ElementB const, LayoutB> ref_B_,

209TensorRef<ElementC const, LayoutC> ref_C_,

210TensorRef<ElementC, LayoutC> ref_D_,

211typename EpilogueOutputOp::Params epilogue_ =

212typename EpilogueOutputOp::Params(),

213int split_k_slices = 1,

214typename ConvertScaledOp::Params convert_ =

215typename ConvertScaledOp::Params(),

216typename ReductionOp::Params reduction_ =

217typename ReductionOp::Params()

218 ):

219 problem_size(problem_size_),

220 ref_A(ref_A_),

221 ref_B(ref_B_),

222 ref_C(ref_C_),

223 ref_D(ref_D_),

224 epilogue(epilogue_),

225 split_k_slices(split_k_slices),

226 convert(convert_),

227 reduction(reduction_) { }

228 };

229

230 private:

231

233typename GemmKernel::Params gemm_params_;

234

236typename ReductionKernel::Params reduction_params_;

237

238 public:

239

241GemmSplitKParallel() { }

242

244static Status can_implement(Arguments const &args) {

245

246// TODO

247

248return Status::kSuccess;

249 }

250

252static size_t get_workspace_size(Arguments const &args) {

253

254// Determine grid shape

255 ThreadblockSwizzle threadblock_swizzle;

256

257cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(

258 args.problem_size,

259 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},

260 args.split_k_slices);

261

262return sizeof(ElementAccumulator_) * size_t(args.problem_size.m()) * size_t(args.problem_size.n()) * grid_shape.k();

263 }

264

266Status initialize(Arguments const &args, void *workspace) {

267

268// Determine grid shape

269 ThreadblockSwizzle threadblock_swizzle;

270

271cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(

272 args.problem_size,

273 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},

274 args.split_k_slices);

275

276// Define a reference to the workspace - this is an aligned region in device memory.

277if (!workspace) {

278return Status::kErrorWorkspaceNull;

279 }

280

281TensorRef<ElementAccumulator_, layout::RowMajor> ref_workspace(

282 static_cast<ElementAccumulator_ *>(workspace),

283 args.problem_size.n());

284

285 int64_t partition_stride = int64_t(args.problem_size.m()) * int64_t(args.problem_size.n());

286

287// Initialize the Params structure

288 gemm_params_ = typename GemmKernel::Params{

289 args.problem_size,

290 grid_shape,

291 args.ref_A.non_const_ref(),

292 args.ref_B.non_const_ref(),

293 ref_workspace,

294 args.convert,

295 partition_stride

296 };

297

298 reduction_params_ = typename ReductionKernel::Params(

299 args.problem_size.mn(),

300 grid_shape.k(),

301 partition_stride,

302 ref_workspace,

303 args.ref_D,

304 args.ref_C.non_const_ref(),

305 args.epilogue

306 );

307

308return Status::kSuccess;

309 }

310

312Status update(Arguments const &args, void *workspace = nullptr) {

313

314if (!workspace) {

315return Status::kErrorWorkspaceNull;

316 }

317

318 gemm_params_.ref_A.reset(args.ref_A.data());

319 gemm_params_.ref_B.reset(args.ref_B.data());

320 gemm_params_.ref_D.reset(workspace);

321

322 reduction_params_.ref_D.reset(args.ref_D.data());

323 reduction_params_.ref_C.reset(args.ref_C.data());

324

325return Status::kSuccess;

326 }

327

329Status run(cudaStream_t stream = nullptr) {

330

331//

332// Launch GEMM kernel

333//

334

335 ThreadblockSwizzle threadblock_swizzle;

336

337 dim3 grid = threadblock_swizzle.get_grid_shape(gemm_params_.grid_tiled_shape);

338 dim3 block(GemmKernel::kThreadCount, 1, 1);

339

340 cudaError_t result;

341

342int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

343if (smem_size >= (48 << 10)) {

344

345 result = cudaFuncSetAttribute(

346 Kernel<GemmKernel>,

347 cudaFuncAttributeMaxDynamicSharedMemorySize,

348 smem_size);

349

350if (result != cudaSuccess) {

351return Status::kErrorInternal;

352 }

353

354 result = cudaFuncSetAttribute(

355 Kernel<GemmKernel>,

356 cudaFuncAttributePreferredSharedMemoryCarveout, 100);

357

358if (result != cudaSuccess) {

359return Status::kErrorInternal;

360 }

361 }

362

363 Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(gemm_params_);

364

365 result = cudaGetLastError();

366

367if (result != cudaSuccess) {

368return Status::kErrorInternal;

369 }

370

371//

372// Launch reduction kernel

373//

374

375 block = ReductionKernel::block_shape();

376 grid = ReductionKernel::grid_shape(gemm_params_.problem_size.mn());

377

378 Kernel<ReductionKernel><<< grid, block, 0, stream >>>(reduction_params_);

379

380 result = cudaGetLastError();

381

382if (result != cudaSuccess) {

383return Status::kErrorInternal;

384 }

385

386return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;

387 }

388

390Status operator()(cudaStream_t stream = nullptr) {

391return run(stream);

392 }

393

395Status operator()(

396Arguments const &args,

397void *workspace = nullptr,

398 cudaStream_t stream = nullptr) {

399

400Status status = initialize(args, workspace);

401

402if (status == Status::kSuccess) {

403 status = run(stream);

404 }

405

406return status;

407 }

408 };

409

411

413 template <

415typename ElementA_,

417typename LayoutA_,

419typename ElementB_,

421typename LayoutB_,

423typename ElementC_,

425typename ElementAccumulator_,

427typename OperatorClass_,

429typename ArchTag_,

431typename ThreadblockShape_,

433typename WarpShape_,

435typename InstructionShape_,

437typename EpilogueOutputOp_,

439typename ConvertScaledOp_,

441typename ReductionOp_,

443typename ThreadblockSwizzle_,

445int Stages, int kAlignmentA, int kAlignmentB,

447typename Operator_>

[448](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html) class GemmSplitKParallel<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,

449 layout::ColumnMajor, ElementAccumulator_,

450 OperatorClass_, ArchTag_, ThreadblockShape_,

451 WarpShape_, InstructionShape_, EpilogueOutputOp_,

452 ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_,

453 Stages, kAlignmentA, kAlignmentB, Operator_> {

454public:

455

[456](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a05fc35f2f2fc3c329eccb6af24981caf)using ElementA = ElementA_;

[457](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ab0f19b729484a5d7e384af1a310f3f8c)using LayoutA = LayoutA_;

[458](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#af5f036e046e05c2a19cfd99673f9835c)using ElementB = ElementB_;

[459](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ad3783855d4101f59892e1af5024288ff)using LayoutB = LayoutB_;

[460](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#adbe8a410fe634ab05b8cf69356b79b26)using [ElementC](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#adbe8a410fe634ab05b8cf69356b79b26) = ElementC_;

[461](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a33d738b2e304c974a9b77be0b176fb59)using LayoutC = layout::ColumnMajor;

[462](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a37df0372c002340106a6f1651348084e)using ElementAccumulator = ElementAccumulator_;

[463](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a498ec1c01a7bfd6f2e401450991ed8be)using OperatorClass = OperatorClass_;

[464](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ab83912e2e116c176d3f733ccdee06a1b)using ArchTag = ArchTag_;

[465](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#add39b0bee00309be7dfca383dbda0cab)using ThreadblockShape = ThreadblockShape_;

[466](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a93db32fb628949381ff8d18b2a765624)using WarpShape = WarpShape_;

[467](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a8c6c83e045a18b7a3c004e039509576e)using InstructionShape = InstructionShape_;

[468](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aa69d9364cc5247ea353608d5c0600fe7)using ConvertScaledOp = ConvertScaledOp_;

[469](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#af5a360d190ca3e8a9df879eaf8e65dd9)using EpilogueOutputOp = EpilogueOutputOp_;

[470](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a2d8d3a504dd8807ed09e25f37a658783)using [ReductionOp](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a2d8d3a504dd8807ed09e25f37a658783) = ReductionOp_;

[471](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a1f04e5294e4238442cb23666564db958)using ThreadblockSwizzle = ThreadblockSwizzle_;

[472](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a696cd49441ddb490d32a374135731c68)using [Operator](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a696cd49441ddb490d32a374135731c68) = Operator_;

[473](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a87961d33bf1aff6a6cbb5a6bc022493e)static int const kStages = Stages;

474

475using UnderlyingOperator = GemmSplitKParallel<

476ElementB,

477typename layout::LayoutTranspose<LayoutB>::type,

478ElementA,

479typename layout::LayoutTranspose<LayoutA>::type,

480ElementC,

481layout::RowMajor,

482ElementAccumulator,

483OperatorClass,

484ArchTag,

485ThreadblockShape,

486WarpShape,

487InstructionShape,

488EpilogueOutputOp,

489ConvertScaledOp,

490ReductionOp,

491ThreadblockSwizzle,

492 Stages,

493 kAlignmentA,

494 kAlignmentB,

495[Operator](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a696cd49441ddb490d32a374135731c68)

[496](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#adb3cad6256057addcd5cf96f469fd679) >;

497

[498](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aaaf935dfdc267b089de7b304be979d4d)using [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aaaf935dfdc267b089de7b304be979d4d) = typename UnderlyingOperator::Arguments;

[499](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a949dbf8f84e6350649a171bf3b45478a)using [GemmKernel](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a949dbf8f84e6350649a171bf3b45478a) = typename UnderlyingOperator::GemmKernel;

[500](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aa69c465611c07990cdc79605c16b04ff)using [ReductionKernel](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aa69c465611c07990cdc79605c16b04ff) = typename UnderlyingOperator::ReductionKernel;

501

[503](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html)struct Arguments {

504

505//

506// Data members

507//

508

[509](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#adee4f1a66aa6b6cb0400f6159ec52eb9)GemmCoord [problem_size](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#adee4f1a66aa6b6cb0400f6159ec52eb9);

[510](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a8b10e75e5d6cd348dacc085f5264ee95)TensorRef<ElementA const, LayoutA> [ref_A](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a8b10e75e5d6cd348dacc085f5264ee95);

[511](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a9a22df7c4d515a48e03fd6f16e074217)TensorRef<ElementB const, LayoutB> [ref_B](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a9a22df7c4d515a48e03fd6f16e074217);

[512](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a2aabb13f196a087b77245c67c8664b7b)TensorRef<ElementC const, LayoutC> [ref_C](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a2aabb13f196a087b77245c67c8664b7b);

[513](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a850da307d8741296e515add0f716eaf9)TensorRef<ElementC, LayoutC> [ref_D](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a850da307d8741296e515add0f716eaf9);

[514](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a04818e67f94c5440ac6c367798e17fc2)typename EpilogueOutputOp::Params [epilogue](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a04818e67f94c5440ac6c367798e17fc2);

[515](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#aff78ac3c99bb15cf8a7d7a1ece736cd1)int [split_k_slices](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#aff78ac3c99bb15cf8a7d7a1ece736cd1);

[516](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a48ced96adaf371f03c1c9a50db9f50f2)typename ConvertScaledOp::Params [convert](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a48ced96adaf371f03c1c9a50db9f50f2);

[517](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a63048fa3419753d96a60eaee28f6cfe4)typename ReductionOp::Params [reduction](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a63048fa3419753d96a60eaee28f6cfe4);

518

519//

520// Methods

521//

522

524CUTLASS_HOST_DEVICE

[525](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#acf6c5b216c0c82f0c7797627d651743f)[Arguments](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#acf6c5b216c0c82f0c7797627d651743f)() { }

526

528CUTLASS_HOST_DEVICE

[529](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a37c45d8dc800de6a631b8a096704559a)[Arguments](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a37c45d8dc800de6a631b8a096704559a)(

530GemmCoord problem_size_,

531TensorRef<ElementA const, LayoutA> ref_A_,

532TensorRef<ElementB const, LayoutB> ref_B_,

533TensorRef<ElementC const, LayoutC> ref_C_,

534TensorRef<ElementC, LayoutC> ref_D_,

535typename EpilogueOutputOp::Params epilogue_ =

536typename EpilogueOutputOp::Params(),

537int split_k_slices = 1,

538typename ConvertScaledOp::Params convert_ =

539typename ConvertScaledOp::Params(),

540typename ReductionOp::Params reduction_ =

541typename ReductionOp::Params()

542 ):

543 problem_size(problem_size_),

544 ref_A(ref_A_),

545 ref_B(ref_B_),

546 ref_C(ref_C_),

547 ref_D(ref_D_),

548 epilogue(epilogue_),

549 split_k_slices(split_k_slices),

550 convert(convert_),

551 reduction(reduction_) { }

552 };

553

554 private:

555

557UnderlyingOperator underlying_operator_;

558

559 public:

560

[562](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ad0c614a548bcade989eb25633b45bb0f)[GemmSplitKParallel](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ad0c614a548bcade989eb25633b45bb0f)() { }

563

[565](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a535863339cab9879474e31f2fd543804)static [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aaaf935dfdc267b089de7b304be979d4d) [to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a535863339cab9879474e31f2fd543804)(Arguments const &args) {

566return [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aaaf935dfdc267b089de7b304be979d4d)(

567 {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},

568 {args.ref_B.data(), args.ref_B.stride(0)},

569 {args.ref_A.data(), args.ref_A.stride(0)},

570 {args.ref_C.data(), args.ref_C.stride(0)},

571 {args.ref_D.data(), args.ref_D.stride(0)},

572 args.epilogue,

573 args.split_k_slices,

574 args.convert,

575 args.reduction

576 );

577 }

578

[580](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a465591fbfde2a9aa6330d9adcbf82bd6)static Status [can_implement](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a465591fbfde2a9aa6330d9adcbf82bd6)(Arguments const &args) {

581

582return UnderlyingOperator::can_implement(to_underlying_arguments(args));

583 }

584

[586](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a4d9c70e23eef0a15d849b5b0ebadfcdd)static size_t [get_workspace_size](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a4d9c70e23eef0a15d849b5b0ebadfcdd)(Arguments const &args) {

587

588return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));

589 }

590

[592](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a4d9f086305f76d7f885bf032f3d2c7c9)Status [initialize](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a4d9f086305f76d7f885bf032f3d2c7c9)(Arguments const &args, void *workspace) {

593

594return underlying_operator_.initialize(to_underlying_arguments(args), workspace);

595 }

596

[598](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a44facae3996ed3da5fdb4398e469b773)Status [update](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a44facae3996ed3da5fdb4398e469b773)(Arguments const &args, void *workspace = nullptr) {

599

600return underlying_operator_.update(to_underlying_arguments(args), workspace);

601 }

602

[604](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a1de7cf5d8bad27b3ff6c803dbc572077)Status [run](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a1de7cf5d8bad27b3ff6c803dbc572077)(cudaStream_t stream = nullptr) {

605

606return underlying_operator_.run(stream);

607 }

608

[610](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a72f5de19ad97e08241157d5106f2f66a)Status [operator()](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a72f5de19ad97e08241157d5106f2f66a)(cudaStream_t stream = nullptr) {

611return run(stream);

612 }

613

[615](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ad6d811ca346ce6467a291497edc85623)Status [operator()](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ad6d811ca346ce6467a291497edc85623)(

616Arguments const &args,

617void *workspace = nullptr,

618 cudaStream_t stream = nullptr) {

619

620Status status = initialize(args, workspace);

621

622if (status == Status::kSuccess) {

623 status = run(stream);

624 }

625

626return status;

627 }

628 };

629

631

632 } // namespace device

633 } // namespace gemm

634 } // namespace cutlass

635

cutlass::epilogue::thread::Convert

Definition: conversion_op.h:53

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::WarpShape

WarpShape WarpShape

Definition: device/gemm_splitk_parallel.h:136

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::GemmKernel](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a949dbf8f84e6350649a171bf3b45478a)

typename UnderlyingOperator::GemmKernel GemmKernel

Definition: device/gemm_splitk_parallel.h:499

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_D](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a850da307d8741296e515add0f716eaf9)

TensorRef< ElementC, LayoutC > ref_D

Definition: device/gemm_splitk_parallel.h:513

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::LayoutB

typename layout::LayoutTranspose< LayoutA >::type LayoutB

Definition: device/gemm_splitk_parallel.h:129

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_C](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a2aabb13f196a087b77245c67c8664b7b)

TensorRef< ElementC const, LayoutC > ref_C

Definition: device/gemm_splitk_parallel.h:512

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::Operator

Operator Operator

Definition: device/gemm_splitk_parallel.h:142

cutlass::MatrixShape

Describes the size of a matrix tile.

Definition: matrix_shape.h:42

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::ReductionKernel](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aa69c465611c07990cdc79605c16b04ff)

typename UnderlyingOperator::ReductionKernel ReductionKernel

Definition: device/gemm_splitk_parallel.h:500

cutlass

Definition: aligned_buffer.h:35

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::problem_size](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#adee4f1a66aa6b6cb0400f6159ec52eb9)

GemmCoord problem_size

Definition: device/gemm_splitk_parallel.h:509

cutlass::gemm::kernel::DefaultGemmSplitKParallel

Definition: default_gemm_splitk_parallel.h:88

cutlass::gemm::device::GemmSplitKParallel::operator()

Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)

Initializes the operator from `args` and `workspace`, then runs the kernel.

Definition: device/gemm_splitk_parallel.h:395

cutlass::gemm::device::GemmSplitKParallel::kStages

static int const kStages

Definition: device/gemm_splitk_parallel.h:143

cutlass::reduction::kernel::ReduceSplitK::block_shape

static CUTLASS_HOST_DEVICE dim3 block_shape()

Determines the threadblock shape.

Definition: reduce_split_k.h:138

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::ElementC](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#adbe8a410fe634ab05b8cf69356b79b26)

ElementC_ ElementC

Definition: device/gemm_splitk_parallel.h:460

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_A](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a8b10e75e5d6cd348dacc085f5264ee95)

TensorRef< ElementA const, LayoutA > ref_A

Definition: device/gemm_splitk_parallel.h:510

reduction_operators.h

Kernel performing a reduction over densely packed tensors in global memory.

cutlass::gemm::GemmCoord

Definition: include/cutlass/gemm/gemm.h:94

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::epilogue](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a04818e67f94c5440ac6c367798e17fc2)

EpilogueOutputOp::Params epilogue

Definition: device/gemm_splitk_parallel.h:514

conversion_op.h

Functor performing conversion operations used by epilogues.

cutlass::gemm::device::GemmSplitKParallel::Arguments::split_k_slices

int split_k_slices

Definition: device/gemm_splitk_parallel.h:191

cutlass::gemm::device::GemmSplitKParallel::Arguments::reduction

ReductionOp::Params reduction

Definition: device/gemm_splitk_parallel.h:193

cutlass::reduction::thread::ReduceAdd

Mixed-precision reduction.

Definition: reduction_operators.h:50

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::run](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a1de7cf5d8bad27b3ff6c803dbc572077)

Status run(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: device/gemm_splitk_parallel.h:604

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::update](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a44facae3996ed3da5fdb4398e469b773)

Status update(Arguments const &args, void *workspace=nullptr)

Lightweight update given a subset of arguments.

Definition: device/gemm_splitk_parallel.h:598

cutlass::reduction::kernel::ReduceSplitK::Params

Params structure.

Definition: reduce_split_k.h:80

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::InstructionShape

InstructionShape InstructionShape

Definition: device/gemm_splitk_parallel.h:137

cutlass::gemm::device::GemmSplitKParallel::Arguments::Arguments

CUTLASS_HOST_DEVICE Arguments()

Default ctor.

Definition: device/gemm_splitk_parallel.h:201

cutlass::gemm::GemmCoord::k

CUTLASS_HOST_DEVICE Index const & k() const

Returns the GEMM K coordinate.

Definition: include/cutlass/gemm/gemm.h:145

cutlass::gemm::device::GemmSplitKParallel::Arguments::convert

ConvertScaledOp::Params convert

Definition: device/gemm_splitk_parallel.h:192

cutlass::layout::ColumnMajor

Mapping function for column-major matrices.

Definition: layout/matrix.h:142

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::GemmSplitKParallel](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ad0c614a548bcade989eb25633b45bb0f)

GemmSplitKParallel()

Constructs the GEMM.

Definition: device/gemm_splitk_parallel.h:562

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ElementC

ElementC ElementC

Definition: device/gemm_splitk_parallel.h:130

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ThreadblockShape

ThreadblockShape ThreadblockShape

Definition: device/gemm_splitk_parallel.h:135

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_B](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a9a22df7c4d515a48e03fd6f16e074217)

TensorRef< ElementB const, LayoutB > ref_B

Definition: device/gemm_splitk_parallel.h:511

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::EpilogueOutputOp

EpilogueOutputOp EpilogueOutputOp

Definition: device/gemm_splitk_parallel.h:139

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::get_workspace_size](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a4d9c70e23eef0a15d849b5b0ebadfcdd)

static size_t get_workspace_size(Arguments const &args)

Gets the workspace size.

Definition: device/gemm_splitk_parallel.h:586

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ArchTag

ArchTag ArchTag

Definition: device/gemm_splitk_parallel.h:134

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::operator()](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ad6d811ca346ce6467a291497edc85623)

Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: device/gemm_splitk_parallel.h:615

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::operator()](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a72f5de19ad97e08241157d5106f2f66a)

Status operator()(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: device/gemm_splitk_parallel.h:610

cutlass::gemm::device::GemmSplitKParallel::can_implement

static Status can_implement(Arguments const &args)

Determines whether the GEMM can execute the given problem.

Definition: device/gemm_splitk_parallel.h:244

cutlass::gemm::device::GemmSplitKParallel::Arguments::ref_D

TensorRef< ElementC, LayoutC > ref_D

Definition: device/gemm_splitk_parallel.h:189

cutlass::gemm::device::GemmSplitKParallel::GemmSplitKParallel

GemmSplitKParallel()

Constructs the GEMM.

Definition: device/gemm_splitk_parallel.h:241

cutlass::layout::LayoutTranspose

Defines transposes of matrix layouts.

Definition: layout/matrix.h:921

cutlass::gemm::device::GemmSplitKParallel::Arguments::problem_size

GemmCoord problem_size

Definition: device/gemm_splitk_parallel.h:185

cutlass::TensorRef< ElementA const, LayoutA >

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::ReductionOp](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a2d8d3a504dd8807ed09e25f37a658783)

ReductionOp_ ReductionOp

Definition: device/gemm_splitk_parallel.h:470

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::GemmKernel

typename kernel::DefaultGemmSplitKParallel< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, ConvertScaledOp, ThreadblockSwizzle, kStages, Operator >::GemmKernel GemmKernel

GEMM kernel.

Definition: device/gemm_splitk_parallel.h:165

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::reduction](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a63048fa3419753d96a60eaee28f6cfe4)

ReductionOp::Params reduction

Definition: device/gemm_splitk_parallel.h:517

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::LayoutA

typename layout::LayoutTranspose< LayoutB >::type LayoutA

Definition: device/gemm_splitk_parallel.h:127

cutlass::gemm::device::GemmSplitKParallel::operator()

Status operator()(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: device/gemm_splitk_parallel.h:390

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ReductionOp

ReductionOp ReductionOp

Definition: device/gemm_splitk_parallel.h:140

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ElementB

ElementA ElementB

Definition: device/gemm_splitk_parallel.h:128

cutlass::Status::kErrorInternal

An error within CUTLASS occurred.

device_kernel.h

Template for generic CUTLASS kernel.

[reduce_split_k.h](reduce__split__k_8h.html)

Kernel performing a reduction over densely packed tensors in global memory.

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aaaf935dfdc267b089de7b304be979d4d)

typename UnderlyingOperator::Arguments UnderlyingArguments

Definition: device/gemm_splitk_parallel.h:498

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Operator](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a696cd49441ddb490d32a374135731c68)

Operator_ Operator

Definition: device/gemm_splitk_parallel.h:472

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::convert](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a48ced96adaf371f03c1c9a50db9f50f2)

ConvertScaledOp::Params convert

Definition: device/gemm_splitk_parallel.h:516

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::reduction::kernel::ReduceSplitK

Definition: reduce_split_k.h:55

cutlass::reduction::kernel::ReduceSplitK::grid_shape

static CUTLASS_HOST_DEVICE dim3 grid_shape(cutlass::MatrixCoord problem_size)

Computes the grid size given a chosen threadblock shape.

Definition: reduce_split_k.h:128

[default_gemm_configuration.h](default__gemm__configuration_8h.html)

Definitions for GEMM structures.

cutlass::TensorRef::non_const_ref

CUTLASS_HOST_DEVICE NonConstTensorRef non_const_ref() const

Definition: tensor_ref.h:229

cutlass::gemm::device::GemmSplitKParallel::get_workspace_size

static size_t get_workspace_size(Arguments const &args)

Gets the workspace size.

Definition: device/gemm_splitk_parallel.h:252

cutlass::gemm::device::GemmSplitKParallel

Definition: device/gemm_splitk_parallel.h:123

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

cutlass::gemm::device::GemmSplitKParallel::Arguments::ref_C

TensorRef< ElementC const, LayoutC > ref_C

Definition: device/gemm_splitk_parallel.h:188

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::initialize](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a4d9f086305f76d7f885bf032f3d2c7c9)

Status initialize(Arguments const &args, void *workspace)

Initializes GEMM state from arguments.

Definition: device/gemm_splitk_parallel.h:592

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ConvertScaledOp

ConvertScaledOp ConvertScaledOp

Definition: device/gemm_splitk_parallel.h:138

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::can_implement](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a465591fbfde2a9aa6330d9adcbf82bd6)

static Status can_implement(Arguments const &args)

Determines whether the GEMM can execute the given problem.

Definition: device/gemm_splitk_parallel.h:580

cutlass::gemm::device::GemmSplitKParallel::Arguments::Arguments

CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1, typename ConvertScaledOp::Params convert_=typename ConvertScaledOp::Params(), typename ReductionOp::Params reduction_=typename ReductionOp::Params())

Constructs an Arguments structure.

Definition: device/gemm_splitk_parallel.h:205

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ElementAccumulator

ElementAccumulator ElementAccumulator

Definition: device/gemm_splitk_parallel.h:132

cutlass::Status::kErrorWorkspaceNull

The given workspace is null when it is required to be non-null.

cutlass::Status::kSuccess

Operation was successful.

cutlass::gemm::device::GemmSplitKParallel::Arguments::ref_B

TensorRef< ElementB const, LayoutB > ref_B

Definition: device/gemm_splitk_parallel.h:187

threadblock_swizzle.h

Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...

arch.h

Defines tags for architecture-specific configurations.

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ThreadblockSwizzle

ThreadblockSwizzle ThreadblockSwizzle

Definition: device/gemm_splitk_parallel.h:141

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::split_k_slices](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#aff78ac3c99bb15cf8a7d7a1ece736cd1)

int split_k_slices

Definition: device/gemm_splitk_parallel.h:515

cutlass::gemm::device::GemmSplitKParallel::update

Status update(Arguments const &args, void *workspace=nullptr)

Lightweight update given a subset of arguments.

Definition: device/gemm_splitk_parallel.h:312

cutlass::gemm::device::GemmSplitKParallel::run

Status run(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: device/gemm_splitk_parallel.h:329

cutlass::gemm::device::GemmSplitKParallel::Arguments::epilogue

EpilogueOutputOp::Params epilogue

Definition: device/gemm_splitk_parallel.h:190

[default_gemm_splitk_parallel.h](default__gemm__splitk__parallel_8h.html)

Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropr...

cutlass::gemm::device::GemmSplitKParallel::Arguments

Argument structure.

Definition: device/gemm_splitk_parallel.h:179

cutlass::gemm::device::GemmSplitKParallel::LayoutC

LayoutC_ LayoutC

Definition: device/gemm_splitk_parallel.h:131

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a37c45d8dc800de6a631b8a096704559a)

CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1, typename ConvertScaledOp::Params convert_=typename ConvertScaledOp::Params(), typename ReductionOp::Params reduction_=typename ReductionOp::Params())

Constructs an Arguments structure.

Definition: device/gemm_splitk_parallel.h:529

cutlass::gemm::device::GemmSplitKParallel::initialize

Status initialize(Arguments const &args, void *workspace)

Initializes GEMM state from arguments.

Definition: device/gemm_splitk_parallel.h:266

cutlass::gemm::device::GemmSplitKParallel::Arguments::ref_A

TensorRef< ElementA const, LayoutA > ref_A

Definition: device/gemm_splitk_parallel.h:186

cutlass.h

Basic include for CUTLASS.

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#acf6c5b216c0c82f0c7797627d651743f)

CUTLASS_HOST_DEVICE Arguments()

Default ctor.

Definition: device/gemm_splitk_parallel.h:525

[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA___00_01LayoutA___00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a535863339cab9879474e31f2fd543804)

static UnderlyingArguments to_underlying_arguments(Arguments const &args)

Helper to construct a transposed equivalent for the underlying GEMM operator.

Definition: device/gemm_splitk_parallel.h:565

cutlass::Status

Status

Status code returned by CUTLASS operations.

Definition: cutlass.h:39

gemm.h

Template for a pipelined GEMM kernel. Does not compute batching or support split-K.

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ElementA

ElementB ElementA

Definition: device/gemm_splitk_parallel.h:126

cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::OperatorClass

OperatorClass OperatorClass

Definition: device/gemm_splitk_parallel.h:133


Generated by 1.8.11