Back to Cutlass

CUTLASS: gemm_batched.h Source File

docs/device_2gemm__batched_8h_source.html

4.4.293.8 KB
Original Source

CUTLASS — CUDA Templates for Linear Algebra Subroutines and Solvers

device/gemm_batched.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include "cutlass/cutlass.h"

32 #include "cutlass/numeric_types.h"

33 #include "cutlass/arch/arch.h"

34 #include "cutlass/device_kernel.h"

35

36 #include "cutlass/gemm/threadblock/threadblock_swizzle.h"

37 #include "cutlass/gemm/kernel/gemm_batched.h"

38

39 #include "cutlass/gemm/kernel/default_gemm.h"

40 #include "[cutlass/gemm/device/default_gemm_configuration.h](default gemm configuration_8h.html)"

41

43

44 namespace cutlass {

45 namespace gemm {

46 namespace device {

47

49

113

116

119

122

125

128

131

134

137

140

143

146

149

152

155

159 template <

161typename ElementA_,

163typename LayoutA_,

165typename ElementB_,

167typename LayoutB_,

169typename ElementC_,

171typename LayoutC_,

173typename ElementAccumulator_ = ElementC_,

175typename OperatorClass_ = arch::OpClassSimt,

177typename ArchTag_ = arch::Sm70,

179typename ThreadblockShape_ = typename DefaultGemmConfiguration<

180 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

181 ElementAccumulator_>::ThreadblockShape,

183typename WarpShape_ = typename DefaultGemmConfiguration<

184 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

185 ElementAccumulator_>::WarpShape,

187typename InstructionShape_ = typename DefaultGemmConfiguration<

188 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

189 ElementAccumulator_>::InstructionShape,

191typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<

192 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

193 ElementAccumulator_>::EpilogueOutputOp,

195typename ThreadblockSwizzle_ = threadblock::GemmBatchedIdentityThreadblockSwizzle,

197int Stages =

198 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,

199 ElementC_, ElementAccumulator_>::kStages,

201int AlignmentA =

202 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,

203 ElementC_, ElementAccumulator_>::kAlignmentA,

205int AlignmentB =

206 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,

207 ElementC_, ElementAccumulator_>::kAlignmentB,

209typename Operator_ = typename DefaultGemmConfiguration<

210 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

211 ElementAccumulator_>::Operator

212 >

213 class GemmBatched {

214public:

215

216using ElementA = ElementA_;

217using LayoutA = LayoutA_;

218using TensorRefA = TensorRef<ElementA const, LayoutA>;

219using ElementB = ElementB_;

220using LayoutB = LayoutB_;

221using TensorRefB = TensorRef<ElementB const, LayoutB>;

222using ElementC = ElementC_;

223using LayoutC = LayoutC_;

224using TensorRefC = TensorRef<ElementC const, LayoutC>;

225using TensorRefD = TensorRef<ElementC, LayoutC>;

226using ElementAccumulator = ElementAccumulator_;

227using OperatorClass = OperatorClass_;

228using ArchTag = ArchTag_;

229using ThreadblockShape = ThreadblockShape_;

230using WarpShape = WarpShape_;

231using InstructionShape = InstructionShape_;

232using EpilogueOutputOp = EpilogueOutputOp_;

233using ThreadblockSwizzle = ThreadblockSwizzle_;

234static int const kStages = Stages;

235static int const kAlignmentA = AlignmentA;

236static int const kAlignmentB = AlignmentB;

237static int const kAlignmentC = EpilogueOutputOp::kCount;

238using Operator = Operator_;

239

241using DefaultGemmKernel = typename kernel::DefaultGemm<

242ElementA,

243LayoutA,

244kAlignmentA,

245ElementB,

246LayoutB,

247kAlignmentB,

248ElementC,

249LayoutC,

250ElementAccumulator,

251OperatorClass,

252ArchTag,

253ThreadblockShape,

254WarpShape,

255InstructionShape,

256EpilogueOutputOp,

257ThreadblockSwizzle,

258kStages,

259false,

260Operator,

261false

262 >::GemmKernel;

263

264using GemmKernel = kernel::GemmBatched<typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue, ThreadblockSwizzle>;

265

267struct Arguments {

268

269//

270// Data members

271//

272

273GemmCoord problem_size;

274TensorRef<ElementA const, LayoutA> ref_A;

275 int64_t stride_A;

276TensorRef<ElementB const, LayoutB> ref_B;

277 int64_t stride_B;

278TensorRef<ElementC const, LayoutC> ref_C;

279 int64_t stride_C;

280TensorRef<ElementC, LayoutC> ref_D;

281 int64_t stride_D;

282typename EpilogueOutputOp::Params epilogue;

283int batch_count;

284

285//

286// Methods

287//

288

290CUTLASS_HOST_DEVICE

291Arguments() { }

292

294CUTLASS_HOST_DEVICE

295Arguments(

296GemmCoord problem_size_,

297TensorRef<ElementA const, LayoutA> ref_A_,

298 int64_t stride_A_,

299TensorRef<ElementB const, LayoutB> ref_B_,

300 int64_t stride_B_,

301TensorRef<ElementC const, LayoutC> ref_C_,

302 int64_t stride_C_,

303TensorRef<ElementC, LayoutC> ref_D_,

304 int64_t stride_D_,

305typename EpilogueOutputOp::Params epilogue_,

306int batch_count_

307 ):

308 problem_size(problem_size_),

309 ref_A(ref_A_),

310 stride_A(stride_A_),

311 ref_B(ref_B_),

312 stride_B(stride_B_),

313 ref_C(ref_C_),

314 stride_C(stride_C_),

315 ref_D(ref_D_),

316 stride_D(stride_D_),

317 epilogue(epilogue_),

318 batch_count(batch_count_) { }

319 };

320

321 private:

322

324typename GemmKernel::Params params_;

325

326 public:

327

329GemmBatched() { }

330

332static Status can_implement(Arguments const &args) {

333

334if (! TensorRef_aligned(args.ref_A, kAlignmentA) || (args.stride_A % kAlignmentA)) {

335return Status::kErrorMisalignedOperand;

336 }

337

338if (! TensorRef_aligned(args.ref_B, kAlignmentB) || (args.stride_B % kAlignmentB)) {

339return Status::kErrorMisalignedOperand;

340 }

341

342if (! TensorRef_aligned(args.ref_C, kAlignmentC) || (args.stride_C % kAlignmentC)) {

343return Status::kErrorMisalignedOperand;

344 }

345

346if (! TensorRef_aligned(args.ref_D, kAlignmentC) || (args.stride_D % kAlignmentC)) {

347return Status::kErrorMisalignedOperand;

348 }

349

350if ((args.problem_size.m() % kAlignmentA) || (args.problem_size.k() % kAlignmentA) ||

351 (args.problem_size.n() % kAlignmentB) || (args.problem_size.k() % kAlignmentB) ||

352 (args.problem_size.m() % kAlignmentC) || (args.problem_size.n() % kAlignmentC)) {

353

354return Status::kErrorMisalignedOperand;

355 }

356

357return Status::kSuccess;

358 }

359

361static size_t get_workspace_size(Arguments const &args) {

362return 0;

363 }

364

366Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

367

368// Determine grid shape

369 ThreadblockSwizzle threadblock_swizzle;

370

371cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(

372 args.problem_size,

373 args.batch_count,

374 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK});

375

376// Initialize the Params structure

377 params_ = typename GemmKernel::Params{

378 args.problem_size,

379 grid_shape,

380 args.ref_A.non_const_ref(),

381 args.stride_A,

382 args.ref_B.non_const_ref(),

383 args.stride_B,

384 args.ref_C.non_const_ref(),

385 args.stride_C,

386 args.ref_D,

387 args.stride_D,

388 args.epilogue,

389 args.batch_count

390 };

391

392return Status::kSuccess;

393 }

394

396Status update(Arguments const &args, void *workspace = nullptr) {

397

398 params_.ref_A.reset(args.ref_A.non_const_ref().data());

399 params_.ref_B.reset(args.ref_B.non_const_ref().data());

400 params_.ref_C.reset(args.ref_C.non_const_ref().data());

401 params_.ref_D.reset(args.ref_D.data());

402

403return Status::kSuccess;

404 }

405

407Status run(cudaStream_t stream = nullptr) {

408

409 ThreadblockSwizzle threadblock_swizzle;

410

411 dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);

412 dim3 block(GemmKernel::kThreadCount, 1, 1);

413

414 cudaError_t result;

415

416int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

417if (smem_size >= (48 << 10)) {

418 result = cudaFuncSetAttribute(Kernel<GemmKernel>,

419 cudaFuncAttributeMaxDynamicSharedMemorySize,

420 smem_size);

421

422if (result != cudaSuccess) {

423return Status::kErrorInternal;

424 }

425

426 result = cudaFuncSetAttribute(

427 Kernel<GemmKernel>,

428 cudaFuncAttributePreferredSharedMemoryCarveout, 100);

429

430if (result != cudaSuccess) {

431return Status::kErrorInternal;

432 }

433 }

434

435 cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

436

437 result = cudaGetLastError();

438

439return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;

440 }

441

443Status operator()(cudaStream_t stream = nullptr) {

444return run(stream);

445 }

446

448Status operator()(

449Arguments const &args,

450void *workspace = nullptr,

451 cudaStream_t stream = nullptr) {

452

453Status status = initialize(args, workspace);

454

455if (status == Status::kSuccess) {

456 status = run(stream);

457 }

458

459return status;

460 }

461 };

462

464

466 template <

468typename ElementA_,

470typename LayoutA_,

472typename ElementB_,

474typename LayoutB_,

476typename ElementC_,

478typename ElementAccumulator_,

480typename OperatorClass_,

482typename ArchTag_,

484typename ThreadblockShape_,

486typename WarpShape_,

488typename InstructionShape_,

490typename EpilogueOutputOp_,

492typename ThreadblockSwizzle_,

494int Stages,

496int AlignmentA,

498int AlignmentB,

499typename Operator_

500 >

[501](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html) class GemmBatched<

502 ElementA_,

503 LayoutA_,

504 ElementB_,

505 LayoutB_,

506 ElementC_,

507 layout::ColumnMajor,

508 ElementAccumulator_,

509 OperatorClass_,

510 ArchTag_,

511 ThreadblockShape_,

512 WarpShape_,

513 InstructionShape_,

514 EpilogueOutputOp_,

515 ThreadblockSwizzle_,

516 Stages,

517 AlignmentA,

518 AlignmentB,

519 Operator_

520 > {

521 public:

522

[523](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a52b9261576b5633e901719f7c21d3369)using [ElementA](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a52b9261576b5633e901719f7c21d3369) = ElementA_;

[524](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#af623ca54d9554cdfafc09af7a22cdd62)using [LayoutA](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#af623ca54d9554cdfafc09af7a22cdd62) = LayoutA_;

[525](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a18266ad32200d3a72aba6e17a6297a3a)using TensorRefA = TensorRef<ElementA const, LayoutA>;

[526](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3fd5c64783f88a7533801fef7d1375ad)using [ElementB](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3fd5c64783f88a7533801fef7d1375ad) = ElementB_;

[527](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a4aaaa6ca0e4b9f983fe37b4105fd058f)using [LayoutB](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a4aaaa6ca0e4b9f983fe37b4105fd058f) = LayoutB_;

[528](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a5595a5e74a0fb536794edf94cd5c7b7f)using TensorRefB = TensorRef<ElementB const, LayoutB>;

[529](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#aef19ab5158e41856723852b3e307cc5d)using [ElementC](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#aef19ab5158e41856723852b3e307cc5d) = ElementC_;

[530](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#aed31a68c08cbfe9bf32d788be3f41679)using LayoutC = layout::ColumnMajor;

[531](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a04e1ec5b0634d45b9ae6811c0ea9f528)using TensorRefC = TensorRef<ElementC const, LayoutC>;

[532](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#acd52c5c939493b3446af9682a2f7793c)using TensorRefD = TensorRef<ElementC, LayoutC>;

[533](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ae7f006ea8bc324d31de9dfbebc1b9327)using [ElementAccumulator](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ae7f006ea8bc324d31de9dfbebc1b9327) = ElementAccumulator_;

[534](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a37600c0bf3570bc4b21c26b2b64fc54a)using [OperatorClass](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a37600c0bf3570bc4b21c26b2b64fc54a) = OperatorClass_;

[535](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a681b145a9701109f9d72059bb874895b)using [ArchTag](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a681b145a9701109f9d72059bb874895b) = ArchTag_;

[536](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a657e50fb03ea4d16f7b904920d9aa000)using [ThreadblockShape](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a657e50fb03ea4d16f7b904920d9aa000) = ThreadblockShape_;

[537](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3760f803bd2b31b3fdf47741caa950fa)using [WarpShape](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3760f803bd2b31b3fdf47741caa950fa) = WarpShape_;

[538](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ae073edad6dd4447d7f99c94f4cd0c1c8)using [InstructionShape](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ae073edad6dd4447d7f99c94f4cd0c1c8) = InstructionShape_;

[539](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a494be150d3b809a4ecf66df682481905)using [EpilogueOutputOp](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a494be150d3b809a4ecf66df682481905) = EpilogueOutputOp_;

[540](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#af8b282788223086b80fbb097b22459ec)using [ThreadblockSwizzle](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#af8b282788223086b80fbb097b22459ec) = ThreadblockSwizzle_;

[541](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ab7f6a87909a3c2d45de71367a0d6eae3)static int const kStages = Stages;

542

[543](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a4b924723475dcef72e0130ce1bb43956)static int const kAlignmentA = AlignmentA;

[544](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a8f5d41976058b08562aa1819687d79a2)static int const kAlignmentB = AlignmentB;

[545](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a79d27ed8dc23cc975f287ec0f041ddf9)static int const kAlignmentC = EpilogueOutputOp::kCount;

[546](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a5a77d26d895197ff5224dac759e05766)static bool const kSplitKSerial = false;

547

548//

549using UnderlyingOperator = GemmBatched<

550ElementB,

551typename layout::LayoutTranspose<LayoutB>::type,

552ElementA,

553typename layout::LayoutTranspose<LayoutA>::type,

554ElementC,

555layout::RowMajor,

556ElementAccumulator,

557OperatorClass,

558ArchTag,

559ThreadblockShape,

560WarpShape,

561InstructionShape,

562EpilogueOutputOp,

563ThreadblockSwizzle,

564 Stages,

565kAlignmentB,

566 kAlignmentA

[567](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a55141da9e85b0c3556e531a2a6c19126) >;

568

[569](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ab5f57fac13e42a08d351ac48c2cc9992)using [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ab5f57fac13e42a08d351ac48c2cc9992) = typename UnderlyingOperator::Arguments;

[570](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3947c9b192bec2fad631334f31632353)using [GemmKernel](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3947c9b192bec2fad631334f31632353) = typename UnderlyingOperator::GemmKernel;

571

[573](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html)struct Arguments {

574

575//

576// Data members

577//

578

[579](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ad0469cc3e961d21e212d026bccf6fe1a)GemmCoord [problem_size](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#ad0469cc3e961d21e212d026bccf6fe1a);

[580](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a1727630fc0525724df28a75ccf2580b9)TensorRef<ElementA const, LayoutA> [ref_A](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#a1727630fc0525724df28a75ccf2580b9);

[581](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ac8830c9ed0e0a8bd7aa2aa4382550a2f) int64_t [stride_A](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#ac8830c9ed0e0a8bd7aa2aa4382550a2f);

[582](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ad7d2b82b83d7503b9f920ce3bdcdffa5)TensorRef<ElementB const, LayoutB> [ref_B](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#ad7d2b82b83d7503b9f920ce3bdcdffa5);

[583](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a302101a4e5c00c843b3c525ddb94c117) int64_t [stride_B](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#a302101a4e5c00c843b3c525ddb94c117);

[584](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#aa9e30e41627595590421d8b53941b2b2)TensorRef<ElementC const, LayoutC> [ref_C](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#aa9e30e41627595590421d8b53941b2b2);

[585](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a9f8a044d7b7439192dfe2bf488558ed3) int64_t [stride_C](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#a9f8a044d7b7439192dfe2bf488558ed3);

[586](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a17c4e381e91229a8ef15b18ee5ec073d)TensorRef<ElementC, LayoutC> [ref_D](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#a17c4e381e91229a8ef15b18ee5ec073d);

[587](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ac181dba327e605b6cde9de5c7f176e7c) int64_t [stride_D](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#ac181dba327e605b6cde9de5c7f176e7c);

[588](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#af9c2fa1e0cc0456197c2cc0840c89982)typename EpilogueOutputOp::Params [epilogue](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#af9c2fa1e0cc0456197c2cc0840c89982);

[589](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#adb66f3083f56c15578b139b7935452b5)int [batch_count](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#adb66f3083f56c15578b139b7935452b5);

590

591//

592// Methods

593//

594

596CUTLASS_HOST_DEVICE

[597](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ae86daa985279c77e57e682b64a68d330)[Arguments](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#ae86daa985279c77e57e682b64a68d330)() { }

598

600CUTLASS_HOST_DEVICE

[601](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a2129a4dccbd73f8c0f26b08ce5a5cb28)[Arguments](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#a2129a4dccbd73f8c0f26b08ce5a5cb28)(

602GemmCoord problem_size_,

603TensorRef<ElementA const, LayoutA> ref_A_,

604 int64_t stride_A_,

605TensorRef<ElementB const, LayoutB> ref_B_,

606 int64_t stride_B_,

607TensorRef<ElementC const, LayoutC> ref_C_,

608 int64_t stride_C_,

609TensorRef<ElementC, LayoutC> ref_D_,

610 int64_t stride_D_,

611typename EpilogueOutputOp::Params epilogue_,

612int batch_count_

613 ):

614 problem_size(problem_size_),

615 ref_A(ref_A_),

616 stride_A(stride_A_),

617 ref_B(ref_B_),

618 stride_B(stride_B_),

619 ref_C(ref_C_),

620 stride_C(stride_C_),

621 ref_D(ref_D_),

622 stride_D(stride_D_),

623 epilogue(epilogue_),

624 batch_count(batch_count_) { }

625 };

626

627 private:

628

629UnderlyingOperator underlying_operator_;

630

631 public:

632

[634](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a75922fd7bcd77fbc714cd87681f692bf)[GemmBatched](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a75922fd7bcd77fbc714cd87681f692bf)() { }

635

[637](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ac4ef1ac1e0876aaee5bff50dc09fe8a9)static [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ab5f57fac13e42a08d351ac48c2cc9992) [to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ac4ef1ac1e0876aaee5bff50dc09fe8a9)(Arguments const &args) {

638return [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ab5f57fac13e42a08d351ac48c2cc9992)(

639 {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},

640 {args.ref_B.data(), args.ref_B.stride(0)},

641 args.stride_B,

642 {args.ref_A.data(), args.ref_A.stride(0)},

643 args.stride_A,

644 {args.ref_C.data(), args.ref_C.stride(0)},

645 args.stride_C,

646 {args.ref_D.data(), args.ref_D.stride(0)},

647 args.stride_D,

648 args.epilogue,

649 args.batch_count

650 );

651 }

652

[654](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#abbd82c0f989a9d07e5e222db96386701)static Status [can_implement](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#abbd82c0f989a9d07e5e222db96386701)(Arguments const &args) {

655

656return UnderlyingOperator::can_implement(to_underlying_arguments(args));

657 }

658

[660](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3687659e826ba7f38bb060ad6020a739)static size_t [get_workspace_size](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3687659e826ba7f38bb060ad6020a739)(Arguments const &args) {

661

662return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));

663 }

664

[666](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a428d8b1c4ac36040145a59d8e4cff3d2)Status [initialize](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a428d8b1c4ac36040145a59d8e4cff3d2)(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

667

668return underlying_operator_.initialize(to_underlying_arguments(args), workspace);

669 }

670

[672](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a9f0c7054068175c1891e4820857603c3)Status [update](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a9f0c7054068175c1891e4820857603c3)(Arguments const &args, void *workspace = nullptr) {

673

674return underlying_operator_.update(to_underlying_arguments(args), workspace);

675 }

676

[678](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#abcae3d15f1ec2ee7ae93690c82fbee8a)Status [run](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#abcae3d15f1ec2ee7ae93690c82fbee8a)(cudaStream_t stream = nullptr) {

679

680return underlying_operator_.run(stream);

681 }

682

[684](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a00805989734182945f982cab23a5dca8)Status [operator()](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a00805989734182945f982cab23a5dca8)(cudaStream_t stream = nullptr) {

685return run(stream);

686 }

687

[689](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a53ca4db66d0d2c96d9036d8eb7c6072b)Status [operator()](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a53ca4db66d0d2c96d9036d8eb7c6072b)(

690Arguments const &args,

691void *workspace = nullptr,

692 cudaStream_t stream = nullptr) {

693

694Status status = initialize(args, workspace);

695

696if (status == Status::kSuccess) {

697 status = run(stream);

698 }

699

700return status;

701 }

702

703 };

704

706

707 } // namespace device

708 } // namespace gemm

709 } // namespace cutlass

710

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ElementC](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#aef19ab5158e41856723852b3e307cc5d)

ElementC_ ElementC

Definition: device/gemm_batched.h:529

cutlass::gemm::kernel::DefaultGemm

Definition: default_gemm.h:116

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ThreadblockSwizzle](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#af8b282788223086b80fbb097b22459ec)

ThreadblockSwizzle_ ThreadblockSwizzle

Definition: device/gemm_batched.h:540

cutlass::gemm::device::GemmBatched::kAlignmentB

static int const kAlignmentB

Definition: device/gemm_batched.h:236

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::ref_A](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a1727630fc0525724df28a75ccf2580b9)

TensorRef< ElementA const, LayoutA > ref_A

Definition: device/gemm_batched.h:580

cutlass::gemm::device::GemmBatched::Arguments::ref_D

TensorRef< ElementC, LayoutC > ref_D

Definition: device/gemm_batched.h:280

cutlass::gemm::device::GemmBatched::Arguments::problem_size

GemmCoord problem_size

Definition: device/gemm_batched.h:273

cutlass

Definition: aligned_buffer.h:35

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ElementB](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3fd5c64783f88a7533801fef7d1375ad)

ElementB_ ElementB

Definition: device/gemm_batched.h:526

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::operator()](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a53ca4db66d0d2c96d9036d8eb7c6072b)

Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)

Initializes GEMM state from the given arguments and, if initialization succeeds, runs the kernel.

Definition: device/gemm_batched.h:689

cutlass::gemm::device::GemmBatched::Arguments::stride_D

int64_t stride_D

Definition: device/gemm_batched.h:281

cutlass::gemm::kernel::GemmBatched::Params::ref_D

Epilogue::OutputTileIterator::TensorRef ref_D

Definition: kernel/gemm_batched.h:74

cutlass::gemm::device::GemmBatched::Arguments::Arguments

CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, int64_t stride_A_, TensorRef< ElementB const, LayoutB > ref_B_, int64_t stride_B_, TensorRef< ElementC const, LayoutC > ref_C_, int64_t stride_C_, TensorRef< ElementC, LayoutC > ref_D_, int64_t stride_D_, typename EpilogueOutputOp::Params epilogue_, int batch_count_)

Constructs an Arguments structure.

Definition: device/gemm_batched.h:295

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::LayoutB](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a4aaaa6ca0e4b9f983fe37b4105fd058f)

LayoutB_ LayoutB

Definition: device/gemm_batched.h:527

cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::DefaultGemmKernel

typename kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, kStages, false, Operator, false >::GemmKernel DefaultGemmKernel

Define the kernel.

Definition: device/gemm_batched.h:262

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::can_implement](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#abbd82c0f989a9d07e5e222db96386701)

static Status can_implement(Arguments const &args)

Determines whether the GEMM can execute the given problem.

Definition: device/gemm_batched.h:654

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ThreadblockShape](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a657e50fb03ea4d16f7b904920d9aa000)

ThreadblockShape_ ThreadblockShape

Definition: device/gemm_batched.h:536

cutlass::gemm::device::GemmBatched::operator()

Status operator()(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: device/gemm_batched.h:443

cutlass::gemm::GemmCoord

Definition: include/cutlass/gemm/gemm.h:94

cutlass::gemm::device::GemmBatched::Arguments

Argument structure.

Definition: device/gemm_batched.h:267

cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::Operator

typename DefaultGemmConfiguration< OperatorClass, ArchTag, ElementB, ElementA, ElementC,ElementAccumulator >::Operator Operator

Definition: device/gemm_batched.h:238

cutlass::gemm::kernel::GemmBatched::Params::ref_B

Mma::IteratorB::TensorRef ref_B

Definition: kernel/gemm_batched.h:68

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::stride_C](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a9f8a044d7b7439192dfe2bf488558ed3)

int64_t stride_C

Definition: device/gemm_batched.h:585

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::OperatorClass](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a37600c0bf3570bc4b21c26b2b64fc54a)

OperatorClass_ OperatorClass

Definition: device/gemm_batched.h:534

cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::LayoutB

typename layout::LayoutTranspose< LayoutA >::type LayoutB

Definition: device/gemm_batched.h:220

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ArchTag](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a681b145a9701109f9d72059bb874895b)

ArchTag_ ArchTag

Definition: device/gemm_batched.h:535

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::GemmKernel](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3947c9b192bec2fad631334f31632353)

typename UnderlyingOperator::GemmKernel GemmKernel

Definition: device/gemm_batched.h:570

cutlass::gemm::device::GemmBatched::get_workspace_size

static size_t get_workspace_size(Arguments const &args)

Gets the workspace size.

Definition: device/gemm_batched.h:361

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::stride_D](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ac181dba327e605b6cde9de5c7f176e7c)

int64_t stride_D

Definition: device/gemm_batched.h:587

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::WarpShape](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3760f803bd2b31b3fdf47741caa950fa)

WarpShape_ WarpShape

Definition: device/gemm_batched.h:537

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::epilogue](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#af9c2fa1e0cc0456197c2cc0840c89982)

EpilogueOutputOp::Params epilogue

Definition: device/gemm_batched.h:588

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::update](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a9f0c7054068175c1891e4820857603c3)

Status update(Arguments const &args, void *workspace=nullptr)

Lightweight update given a subset of arguments.

Definition: device/gemm_batched.h:672

cutlass::gemm::device::GemmBatched::Arguments::stride_A

int64_t stride_A

Definition: device/gemm_batched.h:275

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::InstructionShape](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ae073edad6dd4447d7f99c94f4cd0c1c8)

InstructionShape_ InstructionShape

Definition: device/gemm_batched.h:538

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::initialize](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a428d8b1c4ac36040145a59d8e4cff3d2)

Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)

Initializes GEMM state from arguments.

Definition: device/gemm_batched.h:666

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::run](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#abcae3d15f1ec2ee7ae93690c82fbee8a)

Status run(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: device/gemm_batched.h:678

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ElementAccumulator](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ae7f006ea8bc324d31de9dfbebc1b9327)

ElementAccumulator_ ElementAccumulator

Definition: device/gemm_batched.h:533

cutlass::gemm::kernel::GemmBatched::Params::ref_C

Epilogue::OutputTileIterator::TensorRef ref_C

Definition: kernel/gemm_batched.h:71

cutlass::layout::ColumnMajor

Mapping function for column-major matrices.

Definition: layout/matrix.h:142

cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::OperatorClass

OperatorClass OperatorClass

Definition: device/gemm_batched.h:227

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::stride_B](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a302101a4e5c00c843b3c525ddb94c117)

int64_t stride_B

Definition: device/gemm_batched.h:583

default_gemm.h

Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropriate threadblock-scoped epilogue.

cutlass::gemm::device::GemmBatched::kAlignmentC

static int const kAlignmentC

Definition: device/gemm_batched.h:237

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::ref_D](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a17c4e381e91229a8ef15b18ee5ec073d)

TensorRef< ElementC, LayoutC > ref_D

Definition: device/gemm_batched.h:586

cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::ThreadblockShape

ThreadblockShape ThreadblockShape

Definition: device/gemm_batched.h:229

cutlass::gemm::kernel::GemmBatched::SharedStorage

Shared memory storage structure.

Definition: kernel/gemm_batched.h:124

cutlass::gemm::device::GemmBatched::Arguments::Arguments

CUTLASS_HOST_DEVICE Arguments()

Default ctor.

Definition: device/gemm_batched.h:291

cutlass::gemm::kernel::GemmBatched::Params::grid_tiled_shape

cutlass::gemm::GemmCoord grid_tiled_shape

Definition: kernel/gemm_batched.h:63

cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::InstructionShape

InstructionShape InstructionShape

Definition: device/gemm_batched.h:231

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::GemmBatched](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a75922fd7bcd77fbc714cd87681f692bf)

GemmBatched()

Constructs the GEMM.

Definition: device/gemm_batched.h:634

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::problem_size](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ad0469cc3e961d21e212d026bccf6fe1a)

GemmCoord problem_size

Definition: device/gemm_batched.h:579

cutlass::gemm::device::GemmBatched::kStages

static int const kStages

Definition: device/gemm_batched.h:234

cutlass::gemm::device::GemmBatched::update

Status update(Arguments const &args, void *workspace=nullptr)

Lightweight update given a subset of arguments.

Definition: device/gemm_batched.h:396

cutlass::gemm::kernel::GemmBatched::kThreadCount

static int const kThreadCount

Definition: kernel/gemm_batched.h:58

cutlass::gemm::device::GemmBatched::can_implement

static Status can_implement(Arguments const &args)

Determines whether the GEMM can execute the given problem.

Definition: device/gemm_batched.h:332

cutlass::gemm::device::GemmBatched::Arguments::stride_C

int64_t stride_C

Definition: device/gemm_batched.h:279

cutlass::gemm::kernel::GemmBatched::Params

Parameters structure.

Definition: kernel/gemm_batched.h:61

cutlass::layout::LayoutTranspose

Defines transposes of matrix layouts.

Definition: layout/matrix.h:921

cutlass::Status::kErrorMisalignedOperand

Operands fail alignment requirements.

cutlass::TensorRef< ElementA const, LayoutA >

cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::ElementC

ElementC ElementC

Definition: device/gemm_batched.h:222

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ab5f57fac13e42a08d351ac48c2cc9992)

typename UnderlyingOperator::Arguments UnderlyingArguments

Definition: device/gemm_batched.h:569

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ac4ef1ac1e0876aaee5bff50dc09fe8a9)

static UnderlyingArguments to_underlying_arguments(Arguments const &args)

Helper to construct a transposed equivalent for the underlying GEMM operator.

Definition: device/gemm_batched.h:637

cutlass::Status::kErrorInternal

An error within CUTLASS occurred.

cutlass::gemm::device::GemmBatched::Arguments::ref_B

TensorRef< ElementB const, LayoutB > ref_B

Definition: device/gemm_batched.h:276

cutlass::gemm::device::GemmBatched::kAlignmentA

static int const kAlignmentA

Definition: device/gemm_batched.h:235

device_kernel.h

Template for generic CUTLASS kernel.

cutlass::gemm::device::GemmBatched::initialize

Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)

Initializes GEMM state from arguments.

Definition: device/gemm_batched.h:366

cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::ThreadblockSwizzle

ThreadblockSwizzle ThreadblockSwizzle

Definition: device/gemm_batched.h:233

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::EpilogueOutputOp](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a494be150d3b809a4ecf66df682481905)

EpilogueOutputOp_ EpilogueOutputOp

Definition: device/gemm_batched.h:539

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

cutlass::gemm::device::GemmBatched::GemmBatched

GemmBatched()

Constructs the GEMM.

Definition: device/gemm_batched.h:329

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::EpilogueOutputOp

EpilogueOutputOp EpilogueOutputOp

Definition: device/gemm_batched.h:232

cutlass::gemm::device::GemmBatched::Arguments::batch_count

int batch_count

Definition: device/gemm_batched.h:283

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::operator()](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a00805989734182945f982cab23a5dca8)

Status operator()(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: device/gemm_batched.h:684

[default_gemm_configuration.h](default gemm configuration_8h.html)

Definitions for GEMM structures.

gemm_batched.h

Template for a pipelined batched GEMM kernel. Does not support split-K.

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ElementA](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a52b9261576b5633e901719f7c21d3369)

ElementA_ ElementA

Definition: device/gemm_batched.h:523

cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::LayoutA

typename layout::LayoutTranspose< LayoutB >::type LayoutA

Definition: device/gemm_batched.h:217

cutlass::gemm::kernel::GemmBatched::Params::problem_size

cutlass::gemm::GemmCoord problem_size

Definition: kernel/gemm_batched.h:62

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::ElementB

ElementA ElementB

Definition: device/gemm_batched.h:219

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::get_workspace_size](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3687659e826ba7f38bb060ad6020a739)

static size_t get_workspace_size(Arguments const &args)

Gets the workspace size.

Definition: device/gemm_batched.h:660

cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::ElementA

ElementB ElementA

Definition: device/gemm_batched.h:216

cutlass::gemm::device::GemmBatched::Arguments::ref_C

TensorRef< ElementC const, LayoutC > ref_C

Definition: device/gemm_batched.h:278

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a2129a4dccbd73f8c0f26b08ce5a5cb28)

CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, int64_t stride_A_, TensorRef< ElementB const, LayoutB > ref_B_, int64_t stride_B_, TensorRef< ElementC const, LayoutC > ref_C_, int64_t stride_C_, TensorRef< ElementC, LayoutC > ref_D_, int64_t stride_D_, typename EpilogueOutputOp::Params epilogue_, int batch_count_)

Constructs an Arguments structure.

Definition: device/gemm_batched.h:601

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ae86daa985279c77e57e682b64a68d330)

CUTLASS_HOST_DEVICE Arguments()

Default ctor.

Definition: device/gemm_batched.h:597

cutlass::TensorRef_aligned

bool TensorRef_aligned(TensorRef< Element, Layout > const &ref, int alignment)

Definition: tensor_ref.h:382

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::ref_B](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ad7d2b82b83d7503b9f920ce3bdcdffa5)

TensorRef< ElementB const, LayoutB > ref_B

Definition: device/gemm_batched.h:582

cutlass::Status::kSuccess

Operation was successful.

cutlass::gemm::kernel::GemmBatched::Params::ref_A

Mma::IteratorA::TensorRef ref_A

Definition: kernel/gemm_batched.h:65

cutlass::gemm::device::GemmBatched::run

Status run(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: device/gemm_batched.h:407

threadblock_swizzle.h

Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems.

arch.h

Defines tags for architecture-specific configurations.

cutlass::gemm::kernel::GemmBatched

Definition: kernel/gemm_batched.h:49

cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::ElementAccumulator

ElementAccumulator ElementAccumulator

Definition: device/gemm_batched.h:226

cutlass::gemm::device::GemmBatched::Arguments::stride_B

int64_t stride_B

Definition: device/gemm_batched.h:277

cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::WarpShape

WarpShape WarpShape

Definition: device/gemm_batched.h:230

cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::ArchTag

ArchTag ArchTag

Definition: device/gemm_batched.h:228

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::LayoutA](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#af623ca54d9554cdfafc09af7a22cdd62)

LayoutA_ LayoutA

Definition: device/gemm_batched.h:524

cutlass::gemm::device::GemmBatched::operator()

Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: device/gemm_batched.h:448

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::ref_C](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#aa9e30e41627595590421d8b53941b2b2)

TensorRef< ElementC const, LayoutC > ref_C

Definition: device/gemm_batched.h:584

cutlass::gemm::device::GemmBatched::Arguments::epilogue

EpilogueOutputOp::Params epilogue

Definition: device/gemm_batched.h:282

cutlass.h

Basic include for CUTLASS.

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::batch_count](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#adb66f3083f56c15578b139b7935452b5)

int batch_count

Definition: device/gemm_batched.h:589

cutlass::gemm::device::GemmBatched::GemmKernel

kernel::GemmBatched< typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue, ThreadblockSwizzle > GemmKernel

Definition: device/gemm_batched.h:264

cutlass::gemm::device::GemmBatched

Definition: device/gemm_batched.h:213

cutlass::Status

Status

Status code returned by CUTLASS operations.

Definition: cutlass.h:39

cutlass::gemm::device::GemmBatched::LayoutC

LayoutC_ LayoutC

Definition: device/gemm_batched.h:223

cutlass::gemm::device::GemmBatched::Arguments::ref_A

TensorRef< ElementA const, LayoutA > ref_A

Definition: device/gemm_batched.h:274

[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::stride_A](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ac8830c9ed0e0a8bd7aa2aa4382550a2f)

int64_t stride_A

Definition: device/gemm_batched.h:581


Generated by Doxygen 1.8.11