Back to Cutlass

CUTLASS: gemm.h Source File

docs/include_2cutlass_2gemm_2device_2gemm_8h_source.html

4.4.276.9 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

include/cutlass/gemm/device/gemm.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include "cutlass/cutlass.h"

32 #include "cutlass/numeric_types.h"

33 #include "cutlass/arch/arch.h"

34 #include "cutlass/device_kernel.h"

35

36 #include "cutlass/gemm/threadblock/threadblock_swizzle.h"

37 #include "cutlass/gemm/kernel/gemm.h"

38

39 #include "cutlass/gemm/kernel/default_gemm.h"

40 #include "cutlass/gemm/device/default_gemm_configuration.h"

41

43

44 namespace cutlass {

45 namespace gemm {

46 namespace device {

47

49

113

116

119

122

125

128

131

134

137

140

143

146

149

152

155

159 template <

161typename ElementA_,

163typename LayoutA_,

165typename ElementB_,

167typename LayoutB_,

169typename ElementC_,

171typename LayoutC_,

173typename ElementAccumulator_ = ElementC_,

175typename OperatorClass_ = arch::OpClassSimt,

177typename ArchTag_ = arch::Sm70,

179typename ThreadblockShape_ = typename DefaultGemmConfiguration<

180 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

181 ElementAccumulator_>::ThreadblockShape,

183typename WarpShape_ = typename DefaultGemmConfiguration<

184 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

185 ElementAccumulator_>::WarpShape,

187typename InstructionShape_ = typename DefaultGemmConfiguration<

188 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

189 ElementAccumulator_>::InstructionShape,

191typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<

192 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

193 ElementAccumulator_>::EpilogueOutputOp,

195typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle,

197int Stages =

198 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,

199 ElementC_, ElementAccumulator_>::kStages,

201int AlignmentA =

202 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,

203 ElementC_, ElementAccumulator_>::kAlignmentA,

205int AlignmentB =

206 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,

207 ElementC_, ElementAccumulator_>::kAlignmentB,

209bool SplitKSerial = false,

211typename Operator_ = typename DefaultGemmConfiguration<

212 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

213 ElementAccumulator_>::Operator,

215bool IsBetaZero = false>

216 class Gemm {

217public:

218

219using ElementA = ElementA_;

220using LayoutA = LayoutA_;

221using TensorRefA = TensorRef<ElementA const, LayoutA>;

222using ElementB = ElementB_;

223using LayoutB = LayoutB_;

224using TensorRefB = TensorRef<ElementB const, LayoutB>;

225using ElementC = ElementC_;

226using LayoutC = LayoutC_;

227using TensorRefC = TensorRef<ElementC const, LayoutC>;

228using TensorRefD = TensorRef<ElementC, LayoutC>;

229using ElementAccumulator = ElementAccumulator_;

230using OperatorClass = OperatorClass_;

231using ArchTag = ArchTag_;

232using ThreadblockShape = ThreadblockShape_;

233using WarpShape = WarpShape_;

234using InstructionShape = InstructionShape_;

235using EpilogueOutputOp = EpilogueOutputOp_;

236using ThreadblockSwizzle = ThreadblockSwizzle_;

237using Operator = Operator_;

238static int const kStages = Stages;

239static int const kAlignmentA = AlignmentA;

240static int const kAlignmentB = AlignmentB;

241static int const kAlignmentC = EpilogueOutputOp::kCount;

242static bool const kSplitKSerial = SplitKSerial;

243static bool const kIsBetaZero = IsBetaZero;

244

246using GemmKernel = typename kernel::DefaultGemm<

247ElementA,

248LayoutA,

249kAlignmentA,

250ElementB,

251LayoutB,

252kAlignmentB,

253ElementC,

254LayoutC,

255ElementAccumulator,

256OperatorClass,

257ArchTag,

258ThreadblockShape,

259WarpShape,

260InstructionShape,

261EpilogueOutputOp,

262ThreadblockSwizzle,

263kStages,

264kSplitKSerial,

265Operator,

266 kIsBetaZero

267 >::GemmKernel;

268

270struct Arguments {

271

272//

273// Data members

274//

275

276GemmCoord problem_size;

277TensorRef<ElementA const, LayoutA> ref_A;

278TensorRef<ElementB const, LayoutB> ref_B;

279TensorRef<ElementC const, LayoutC> ref_C;

280TensorRef<ElementC, LayoutC> ref_D;

281typename EpilogueOutputOp::Params epilogue;

282int split_k_slices;

283

284//

285// Methods

286//

287

289CUTLASS_HOST_DEVICE

290Arguments(): problem_size(0, 0, 0), split_k_slices(1) {

291

292 }

293

295CUTLASS_HOST_DEVICE

296Arguments(

297GemmCoord problem_size_,

298TensorRef<ElementA const, LayoutA> ref_A_,

299TensorRef<ElementB const, LayoutB> ref_B_,

300TensorRef<ElementC const, LayoutC> ref_C_,

301TensorRef<ElementC, LayoutC> ref_D_,

302typename EpilogueOutputOp::Params epilogue_ =

303typename EpilogueOutputOp::Params(),

304int split_k_slices = 1

305 ):

306 problem_size(problem_size_),

307 ref_A(ref_A_),

308 ref_B(ref_B_),

309 ref_C(ref_C_),

310 ref_D(ref_D_),

311 epilogue(epilogue_),

312 split_k_slices(split_k_slices) {

313

314 }

315 };

316

317 private:

318

320typename GemmKernel::Params params_;

321

322 public:

323

325Gemm() { }

326

328static Status can_implement(Arguments const &args) {

329

330if (!kSplitKSerial && args.split_k_slices > 1) {

331return Status::kErrorInvalidProblem;

332 }

333

334Status status = GemmKernel::can_implement(

335 args.problem_size,

336 args.ref_A.non_const_ref(),

337 args.ref_B.non_const_ref(),

338 args.ref_C.non_const_ref(),

339 args.ref_D

340 );

341

342if (status != Status::kSuccess) {

343return status;

344 }

345

346return Status::kSuccess;

347 }

348

350static size_t get_workspace_size(Arguments const &args) {

351

352if (kSplitKSerial && args.split_k_slices > 1) {

353

354// Determine grid shape

355 ThreadblockSwizzle threadblock_swizzle;

356

357cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(

358 args.problem_size,

359 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},

360 args.split_k_slices);

361

362return sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());

363 }

364

365return 0;

366 }

367

369Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

370

371// Determine grid shape

372 ThreadblockSwizzle threadblock_swizzle;

373

374cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(

375 args.problem_size,

376 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},

377 args.split_k_slices);

378

379if (kSplitKSerial) {

380if (args.split_k_slices > 1) {

381if (!workspace) {

382return Status::kErrorWorkspaceNull;

383 }

384

385size_t bytes = get_workspace_size(args);

386

387 cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);

388

389if (result != cudaSuccess) {

390return Status::kErrorInternal;

391 }

392 }

393 }

394else {

395

396if (args.split_k_slices > 1) {

397return Status::kErrorInvalidProblem;

398 }

399 }

400

401// Initialize the Params structure

402 params_ = typename GemmKernel::Params{

403 args.problem_size,

404 grid_shape,

405 args.ref_A.non_const_ref(),

406 args.ref_B.non_const_ref(),

407 args.ref_C.non_const_ref(),

408 args.ref_D,

409 args.epilogue,

410static_cast<int *>(workspace)

411 };

412

413return Status::kSuccess;

414 }

415

417Status update(Arguments const &args, void *workspace = nullptr) {

418

419if (kSplitKSerial && args.split_k_slices > 1) {

420if (!workspace) {

421return Status::kErrorWorkspaceNull;

422 }

423 }

424

425 params_.ref_A.reset(args.ref_A.non_const_ref().data());

426 params_.ref_B.reset(args.ref_B.non_const_ref().data());

427 params_.ref_C.reset(args.ref_C.non_const_ref().data());

428 params_.ref_D.reset(args.ref_D.data());

429 params_.semaphore = static_cast<int *>(workspace);

430

431return Status::kSuccess;

432 }

433

435Status run(cudaStream_t stream = nullptr) {

436

437 ThreadblockSwizzle threadblock_swizzle;

438

439 dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);

440 dim3 block(GemmKernel::kThreadCount, 1, 1);

441

442 cudaError_t result;

443

444int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

445if (smem_size >= (48 << 10)) {

446 result = cudaFuncSetAttribute(Kernel<GemmKernel>,

447 cudaFuncAttributeMaxDynamicSharedMemorySize,

448 smem_size);

449

450if (result != cudaSuccess) {

451return Status::kErrorInternal;

452 }

453

454 result = cudaFuncSetAttribute(

455 Kernel<GemmKernel>,

456 cudaFuncAttributePreferredSharedMemoryCarveout, 100);

457

458if (result != cudaSuccess) {

459return Status::kErrorInternal;

460 }

461 }

462

463 cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

464

465 result = cudaGetLastError();

466

467return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;

468 }

469

471Status operator()(cudaStream_t stream = nullptr) {

472return run(stream);

473 }

474

476Status operator()(

477Arguments const &args,

478void *workspace = nullptr,

479 cudaStream_t stream = nullptr) {

480

481Status status = initialize(args, workspace);

482

483if (status == Status::kSuccess) {

484 status = run(stream);

485 }

486

487return status;

488 }

489 };

490

492

494 template <

496typename ElementA_,

498typename LayoutA_,

500typename ElementB_,

502typename LayoutB_,

504typename ElementC_,

506typename ElementAccumulator_,

508typename OperatorClass_,

510typename ArchTag_,

512typename ThreadblockShape_,

514typename WarpShape_,

516typename InstructionShape_,

518typename EpilogueOutputOp_,

520typename ThreadblockSwizzle_,

522int Stages,

524int AlignmentA,

526int AlignmentB,

528bool SplitKSerial,

530typename Operator_,

532bool IsBetaZero>

[533](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html) class Gemm<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,

534 layout::ColumnMajor, // partially specialized on LayoutC

535 ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_,

536 WarpShape_, InstructionShape_, EpilogueOutputOp_,

537 ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial,

538 Operator_, IsBetaZero> {

539public:

540

[541](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a09db4f8f255d272e7350394d568f4a01)using ElementA = ElementA_;

[542](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a5212eb5b3af32e5bc43cc4179bb346ef)using LayoutA = LayoutA_;

[543](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a82da12cbdf6f75499d315ee530f5330e)using TensorRefA = TensorRef<ElementA const, LayoutA>;

[544](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a0f904b72a3ff91f7ae6ad1a91e915b6d)using ElementB = ElementB_;

[545](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a1841e0e97e59862c7a92fc8d2ab7c9bc)using LayoutB = LayoutB_;

[546](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#ab4bd8a2bb7be0fa2b583cf34b63b62eb)using TensorRefB = TensorRef<ElementB const, LayoutB>;

[547](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#ad58a37fecfeb982d20fc209a0df4c1fa)using ElementC = ElementC_;

[548](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#afe4685fea6a4603a7459bbe9923c9cb3)using LayoutC = layout::ColumnMajor;

[549](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#aff1ad6d93937a9e4b261eb69322449e7)using TensorRefC = TensorRef<ElementC const, LayoutC>;

[550](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#abaa02d78437ae0f42260848d722c134f)using TensorRefD = TensorRef<ElementC, LayoutC>;

[551](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#acbc61142b95f4d33bc0b8518857ab7be)using ElementAccumulator = ElementAccumulator_;

[552](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#aa816137b589b8bf2204ace73d49b7ded)using OperatorClass = OperatorClass_;

[553](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a1b502a4097e745c12d0d628d080ba447)using ArchTag = ArchTag_;

[554](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#ab798f409ba80eab4a0140fdf43e768ee)using ThreadblockShape = ThreadblockShape_;

[555](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#abdc15293a8b083372e5395049440d01c)using WarpShape = WarpShape_;

[556](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a71b84f983b94b50a48bd0890f1e0ed59)using InstructionShape = InstructionShape_;

[557](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#ad1d60f7381ae03803a078a26604bd8be)using EpilogueOutputOp = EpilogueOutputOp_;

[558](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#af55d56acaa01ce303c22d6e9e0b0f895)using ThreadblockSwizzle = ThreadblockSwizzle_;

[559](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a029f48f17ec3fb98067bfacd7e06f3d2)using Operator = Operator_;

[560](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a5de0cfa9c3831daebbdc8326c239dd33)static int const kStages = Stages;

[561](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a78660ed036162b8455546ff5718968d0)static int const kAlignmentA = AlignmentA;

[562](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#aa39221ab9fa4248c613b7222f764072e)static int const kAlignmentB = AlignmentB;

[563](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a0b609010f97cb53cf4d8f1ecb4bb0b79)static bool const kSplitKSerial = SplitKSerial;

[564](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a98d1d07f32f29b29e883775fcd276833)static bool const kIsBetaZero = IsBetaZero;

565

566using UnderlyingOperator = Gemm<

567ElementB,

568typename layout::LayoutTranspose<LayoutB>::type,

569ElementA,

570typename layout::LayoutTranspose<LayoutA>::type,

571ElementC,

572layout::RowMajor,

573ElementAccumulator,

574OperatorClass,

575ArchTag,

576ThreadblockShape,

577WarpShape,

578InstructionShape,

579EpilogueOutputOp,

580ThreadblockSwizzle,

581 Stages,

582kAlignmentB,

583kAlignmentA,

584 SplitKSerial,

585Operator,

586 kIsBetaZero

[587](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#aa48a0b3645f2ef103cb1b2d41218d865) >;

588

[589](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7615ad046304360243729c29c65e878a)using [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7615ad046304360243729c29c65e878a) = typename UnderlyingOperator::Arguments;

[590](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a2bdbc5e737f9bfd1e09a7cfb30e60e29)using [GemmKernel](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a2bdbc5e737f9bfd1e09a7cfb30e60e29) = typename UnderlyingOperator::GemmKernel;

[591](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#ad74a049e26f4b9224362b4d1c93ca14b)static int const kAlignmentC = UnderlyingOperator::kAlignmentC;

592

[594](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html)struct Arguments {

595

596//

597// Data members

598//

599

[600](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#acd02e86dfff866eade08415e0043ccc3)GemmCoord [problem_size](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#acd02e86dfff866eade08415e0043ccc3);

[601](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a9bdaf3563983efcca649460be169b334)TensorRef<ElementA const, LayoutA> [ref_A](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a9bdaf3563983efcca649460be169b334);

[602](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ab77204c1010b17c6643d26a89f41c3d0)TensorRef<ElementB const, LayoutB> [ref_B](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ab77204c1010b17c6643d26a89f41c3d0);

[603](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a590b8da88ae9350042838451e3e37a22)TensorRef<ElementC const, LayoutC> [ref_C](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a590b8da88ae9350042838451e3e37a22);

[604](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ab1d4d5865786a415f87db1def1b029e7)TensorRef<ElementC, LayoutC> [ref_D](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ab1d4d5865786a415f87db1def1b029e7);

[605](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a426f402c08be99849a4477a07f010a5e)typename EpilogueOutputOp::Params [epilogue](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a426f402c08be99849a4477a07f010a5e);

[606](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#aaef8450711318fa1a53fe3cb72b59263)int [split_k_slices](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#aaef8450711318fa1a53fe3cb72b59263);

607

608//

609// Methods

610//

611

613CUTLASS_HOST_DEVICE

[614](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ac6c397a181a52c0dbb39bf3710ee4658)[Arguments](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ac6c397a181a52c0dbb39bf3710ee4658)() { }

615

617CUTLASS_HOST_DEVICE

[618](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a331de1adfdcbea6d0137afe64a4f6f4c)[Arguments](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a331de1adfdcbea6d0137afe64a4f6f4c)(

619GemmCoord problem_size_,

620TensorRef<ElementA const, LayoutA> ref_A_,

621TensorRef<ElementB const, LayoutB> ref_B_,

622TensorRef<ElementC const, LayoutC> ref_C_,

623TensorRef<ElementC, LayoutC> ref_D_,

624typename EpilogueOutputOp::Params epilogue_ =

625typename EpilogueOutputOp::Params(),

626int split_k_slices = 1

627 ):

628 problem_size(problem_size_),

629 ref_A(ref_A_),

630 ref_B(ref_B_),

631 ref_C(ref_C_),

632 ref_D(ref_D_),

633 epilogue(epilogue_),

634 split_k_slices(split_k_slices) { }

635 };

636

637 private:

638

639UnderlyingOperator underlying_operator_;

640

641 public:

642

[644](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#abcacf502806db50eb17a6d925aee16d5)[Gemm](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#abcacf502806db50eb17a6d925aee16d5)() { }

645

[647](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#aa9313915a6129f0c43b43ef3698b3ee4)static [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7615ad046304360243729c29c65e878a) [to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#aa9313915a6129f0c43b43ef3698b3ee4)(Arguments const &args) {

648return [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7615ad046304360243729c29c65e878a)(

649 {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},

650 {args.ref_B.data(), args.ref_B.stride(0)},

651 {args.ref_A.data(), args.ref_A.stride(0)},

652 {args.ref_C.data(), args.ref_C.stride(0)},

653 {args.ref_D.data(), args.ref_D.stride(0)},

654 args.epilogue,

655 args.split_k_slices

656 );

657 }

658

[660](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a662bcbcb6164c803ab490c86e69b9ee1)static Status [can_implement](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a662bcbcb6164c803ab490c86e69b9ee1)(Arguments const &args) {

661

662return UnderlyingOperator::can_implement(to_underlying_arguments(args));

663 }

664

[666](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a1469133c30fde6b28296e3ff6951e7a4)static size_t [get_workspace_size](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a1469133c30fde6b28296e3ff6951e7a4)(Arguments const &args) {

667

668return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));

669 }

670

[672](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7a14474e4238d2fac92ad71c6de087d8)Status [initialize](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7a14474e4238d2fac92ad71c6de087d8)(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

673

674return underlying_operator_.initialize(to_underlying_arguments(args), workspace);

675 }

676

[678](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a2b6c5275c173d73cffe8e6b6b1ccf2c1)Status [update](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a2b6c5275c173d73cffe8e6b6b1ccf2c1)(Arguments const &args, void *workspace = nullptr) {

679

680return underlying_operator_.update(to_underlying_arguments(args), workspace);

681 }

682

[684](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a5f4f93ca97b358b4410f3d0b1e0a6387)Status [run](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a5f4f93ca97b358b4410f3d0b1e0a6387)(cudaStream_t stream = nullptr) {

685

686return underlying_operator_.run(stream);

687 }

688

[690](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a384db4125183e504fafc5a946b7ba757)Status [operator()](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a384db4125183e504fafc5a946b7ba757)(cudaStream_t stream = nullptr) {

691return run(stream);

692 }

693

[695](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a6115aa957b3ba8ad9e54b7efeefaacd1)Status [operator()](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a6115aa957b3ba8ad9e54b7efeefaacd1)(

696Arguments const &args,

697void *workspace = nullptr,

698 cudaStream_t stream = nullptr) {

699

700Status status = initialize(args, workspace);

701

702if (status == Status::kSuccess) {

703 status = run(stream);

704 }

705

706return status;

707 }

708 };

709

711

712 } // namespace device

713 } // namespace gemm

714 } // namespace cutlass

715

cutlass::gemm::kernel::DefaultGemm

Definition: default_gemm.h:116

cutlass::gemm::device::Gemm::kStages

static int const kStages

Definition: include/cutlass/gemm/device/gemm.h:238

cutlass::gemm::device::Gemm::Arguments::problem_size

GemmCoord problem_size

Definition: include/cutlass/gemm/device/gemm.h:276

cutlass

Definition: aligned_buffer.h:35

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::split_k_slices](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#aaef8450711318fa1a53fe3cb72b59263)

int split_k_slices

Definition: include/cutlass/gemm/device/gemm.h:606

cutlass::Status::kErrorInvalidProblem

Specified problem size is not supported by operator.

cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::ElementA

ElementB ElementA

Definition: include/cutlass/gemm/device/gemm.h:219

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::ref_A](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a9bdaf3563983efcca649460be169b334)

TensorRef< ElementA const, LayoutA > ref_A

Definition: include/cutlass/gemm/device/gemm.h:601

cutlass::gemm::device::Gemm::get_workspace_size

static size_t get_workspace_size(Arguments const &args)

Gets the workspace size.

Definition: include/cutlass/gemm/device/gemm.h:350

cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::ThreadblockSwizzle

ThreadblockSwizzle ThreadblockSwizzle

Definition: include/cutlass/gemm/device/gemm.h:236

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ac6c397a181a52c0dbb39bf3710ee4658)

CUTLASS_HOST_DEVICE Arguments()

Default ctor.

Definition: include/cutlass/gemm/device/gemm.h:614

cutlass::gemm::device::Gemm::can_implement

static Status can_implement(Arguments const &args)

Determines whether the GEMM can execute the given problem.

Definition: include/cutlass/gemm/device/gemm.h:328

cutlass::gemm::device::Gemm::Arguments::Arguments

CUTLASS_HOST_DEVICE Arguments()

Default ctor.

Definition: include/cutlass/gemm/device/gemm.h:290

cutlass::gemm::device::Gemm::Arguments::Arguments

CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1)

Constructs an Arguments structure.

Definition: include/cutlass/gemm/device/gemm.h:296

cutlass::gemm::GemmCoord

Definition: include/cutlass/gemm/gemm.h:94

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::get_workspace_size](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a1469133c30fde6b28296e3ff6951e7a4)

static size_t get_workspace_size(Arguments const &args)

Gets the workspace size.

Definition: include/cutlass/gemm/device/gemm.h:666

cutlass::gemm::device::Gemm

Definition: include/cutlass/gemm/device/gemm.h:216

cutlass::gemm::GemmCoord::n

CUTLASS_HOST_DEVICE Index const & n() const

Returns the GEMM N coordinate.

Definition: include/cutlass/gemm/gemm.h:137

cutlass::gemm::device::Gemm::kSplitKSerial

static bool const kSplitKSerial

Definition: include/cutlass/gemm/device/gemm.h:242

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::update](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a2b6c5275c173d73cffe8e6b6b1ccf2c1)

Status update(Arguments const &args, void *workspace=nullptr)

Lightweight update given a subset of arguments.

Definition: include/cutlass/gemm/device/gemm.h:678

cutlass::gemm::device::Gemm::Arguments::epilogue

EpilogueOutputOp::Params epilogue

Definition: include/cutlass/gemm/device/gemm.h:281

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::epilogue](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a426f402c08be99849a4477a07f010a5e)

EpilogueOutputOp::Params epilogue

Definition: include/cutlass/gemm/device/gemm.h:605

cutlass::gemm::device::Gemm::Arguments::ref_A

TensorRef< ElementA const, LayoutA > ref_A

Definition: include/cutlass/gemm/device/gemm.h:277

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Gemm](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#abcacf502806db50eb17a6d925aee16d5)

Gemm()

Constructs the GEMM.

Definition: include/cutlass/gemm/device/gemm.h:644

cutlass::layout::ColumnMajor

Mapping function for column-major matrices.

Definition: layout/matrix.h:142

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::can_implement](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a662bcbcb6164c803ab490c86e69b9ee1)

static Status can_implement(Arguments const &args)

Determines whether the GEMM can execute the given problem.

Definition: include/cutlass/gemm/device/gemm.h:660

default_gemm.h

Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropr...

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::operator()](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a384db4125183e504fafc5a946b7ba757)

Status operator()(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: include/cutlass/gemm/device/gemm.h:690

cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::InstructionShape

InstructionShape InstructionShape

Definition: include/cutlass/gemm/device/gemm.h:234

cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::Operator

Operator Operator

Definition: include/cutlass/gemm/device/gemm.h:237

cutlass::gemm::device::Gemm::update

Status update(Arguments const &args, void *workspace=nullptr)

Lightweight update given a subset of arguments.

Definition: include/cutlass/gemm/device/gemm.h:417

cutlass::gemm::device::Gemm::Arguments::split_k_slices

int split_k_slices

Definition: include/cutlass/gemm/device/gemm.h:282

cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::ElementC

ElementC ElementC

Definition: include/cutlass/gemm/device/gemm.h:225

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::run](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a5f4f93ca97b358b4410f3d0b1e0a6387)

Status run(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: include/cutlass/gemm/device/gemm.h:684

cutlass::gemm::device::Gemm::operator()

Status operator()(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: include/cutlass/gemm/device/gemm.h:471

cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::OperatorClass

OperatorClass OperatorClass

Definition: include/cutlass/gemm/device/gemm.h:230

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7615ad046304360243729c29c65e878a)

typename UnderlyingOperator::Arguments UnderlyingArguments

Definition: include/cutlass/gemm/device/gemm.h:589

cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::ElementAccumulator

ElementAccumulator ElementAccumulator

Definition: include/cutlass/gemm/device/gemm.h:229

cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::GemmKernel

typename kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, kStages, kSplitKSerial, Operator, kIsBetaZero >::GemmKernel GemmKernel

Define the kernel.

Definition: include/cutlass/gemm/device/gemm.h:267

cutlass::gemm::device::Gemm::kAlignmentB

static int const kAlignmentB

Definition: include/cutlass/gemm/device/gemm.h:240

cutlass::layout::LayoutTranspose

Defines transposes of matrix layouts.

Definition: layout/matrix.h:921

cutlass::gemm::device::Gemm::kAlignmentA

static int const kAlignmentA

Definition: include/cutlass/gemm/device/gemm.h:239

cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::ThreadblockShape

ThreadblockShape ThreadblockShape

Definition: include/cutlass/gemm/device/gemm.h:232

cutlass::gemm::device::Gemm::Arguments::ref_C

TensorRef< ElementC const, LayoutC > ref_C

Definition: include/cutlass/gemm/device/gemm.h:279

cutlass::TensorRef< ElementA const, LayoutA >

cutlass::gemm::device::Gemm::Arguments::ref_B

TensorRef< ElementB const, LayoutB > ref_B

Definition: include/cutlass/gemm/device/gemm.h:278

cutlass::Status::kErrorInternal

An error within CUTLASS occurred.

device_kernel.h

Template for generic CUTLASS kernel.

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

numeric_types.h

Top-level include for all CUTLASS numeric types.

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::problem_size](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#acd02e86dfff866eade08415e0043ccc3)

GemmCoord problem_size

Definition: include/cutlass/gemm/device/gemm.h:600

cutlass::gemm::device::Gemm::operator()

Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: include/cutlass/gemm/device/gemm.h:476

cutlass::gemm::device::Gemm::Gemm

Gemm()

Constructs the GEMM.

Definition: include/cutlass/gemm/device/gemm.h:325

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a331de1adfdcbea6d0137afe64a4f6f4c)

CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1)

Constructs an Arguments structure.

Definition: include/cutlass/gemm/device/gemm.h:618

cutlass::gemm::device::Gemm::LayoutC

LayoutC_ LayoutC

Definition: include/cutlass/gemm/device/gemm.h:226

cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::LayoutB

typename layout::LayoutTranspose< LayoutA >::type LayoutB

Definition: include/cutlass/gemm/device/gemm.h:223

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::ref_C](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a590b8da88ae9350042838451e3e37a22)

TensorRef< ElementC const, LayoutC > ref_C

Definition: include/cutlass/gemm/device/gemm.h:603

cutlass::gemm::device::Gemm::Arguments

Argument structure.

Definition: include/cutlass/gemm/device/gemm.h:270

[default_gemm_configuration.h](default__gemm__configuration_8h.html)

Definitions for GEMM structures.

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::initialize](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7a14474e4238d2fac92ad71c6de087d8)

Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)

Initializes GEMM state from arguments.

Definition: include/cutlass/gemm/device/gemm.h:672

cutlass::gemm::device::Gemm::run

Status run(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: include/cutlass/gemm/device/gemm.h:435

cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::EpilogueOutputOp

EpilogueOutputOp EpilogueOutputOp

Definition: include/cutlass/gemm/device/gemm.h:235

cutlass::gemm::device::Gemm::initialize

Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)

Initializes GEMM state from arguments.

Definition: include/cutlass/gemm/device/gemm.h:369

cutlass::Status::kErrorWorkspaceNull

The given workspace is null when it is required to be non-null.

cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::WarpShape

WarpShape WarpShape

Definition: include/cutlass/gemm/device/gemm.h:233

cutlass::Status::kSuccess

Operation was successful.

cutlass::gemm::device::Gemm::kAlignmentC

static int const kAlignmentC

Definition: include/cutlass/gemm/device/gemm.h:241

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#aa9313915a6129f0c43b43ef3698b3ee4)

static UnderlyingArguments to_underlying_arguments(Arguments const &args)

Helper to construct a transposed equivalent for the underlying GEMM operator.

Definition: include/cutlass/gemm/device/gemm.h:647

cutlass::gemm::GemmCoord::m

CUTLASS_HOST_DEVICE Index const & m() const

Returns the GEMM M coordinate.

Definition: include/cutlass/gemm/gemm.h:129

cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::ElementB

ElementA ElementB

Definition: include/cutlass/gemm/device/gemm.h:222

threadblock_swizzle.h

Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...

arch.h

Defines tags for architecture-specific configurations.

cutlass::gemm::device::Gemm::Arguments::ref_D

TensorRef< ElementC, LayoutC > ref_D

Definition: include/cutlass/gemm/device/gemm.h:280

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::ref_D](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ab1d4d5865786a415f87db1def1b029e7)

TensorRef< ElementC, LayoutC > ref_D

Definition: include/cutlass/gemm/device/gemm.h:604

cutlass::gemm::device::Gemm::kIsBetaZero

static bool const kIsBetaZero

Definition: include/cutlass/gemm/device/gemm.h:243

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::GemmKernel](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a2bdbc5e737f9bfd1e09a7cfb30e60e29)

typename UnderlyingOperator::GemmKernel GemmKernel

Definition: include/cutlass/gemm/device/gemm.h:590

cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::ArchTag

ArchTag ArchTag

Definition: include/cutlass/gemm/device/gemm.h:231

cutlass.h

Basic include for CUTLASS.

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::operator()](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a6115aa957b3ba8ad9e54b7efeefaacd1)

Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: include/cutlass/gemm/device/gemm.h:695

cutlass::Status

Status

Status code returned by CUTLASS operations.

Definition: cutlass.h:39

gemm.h

Template for a pipelined GEMM kernel. Does not compute batching or support split-K.

[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::ref_B](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ab77204c1010b17c6643d26a89f41c3d0)

TensorRef< ElementB const, LayoutB > ref_B

Definition: include/cutlass/gemm/device/gemm.h:602

cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::LayoutA

typename layout::LayoutTranspose< LayoutB >::type LayoutA

Definition: include/cutlass/gemm/device/gemm.h:220


Generated by 1.8.11