Back to Cutlass

CUTLASS: gemm_complex.h Source File

docs/include_2cutlass_2gemm_2device_2gemm__complex_8h_source.html

4.4.274.9 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

include/cutlass/gemm/device/gemm_complex.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include "cutlass/cutlass.h"

32 #include "cutlass/numeric_types.h"

33 #include "cutlass/arch/arch.h"

34 #include "cutlass/device_kernel.h"

35

36 #include "cutlass/gemm/threadblock/threadblock_swizzle.h"

37 #include "cutlass/gemm/kernel/gemm.h"

38

39 #include "cutlass/gemm/kernel/default_gemm_complex.h"

40 #include "cutlass/gemm/device/default_gemm_configuration.h"

41

43

44 namespace cutlass {

45 namespace gemm {

46 namespace device {

47

49

113

116

119

122

125

128

131

134

137

140

143

146

149

152

155

159 template <

161typename ElementA_,

163typename LayoutA_,

165typename ElementB_,

167typename LayoutB_,

169typename ElementC_,

171typename LayoutC_,

173typename ElementAccumulator_ = ElementC_,

175typename OperatorClass_ = arch::OpClassSimt,

177typename ArchTag_ = arch::Sm70,

179typename ThreadblockShape_ = typename DefaultGemmConfiguration<

180 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

181 ElementAccumulator_>::ThreadblockShape,

183typename WarpShape_ = typename DefaultGemmConfiguration<

184 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

185 ElementAccumulator_>::WarpShape,

187typename InstructionShape_ = typename DefaultGemmConfiguration<

188 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

189 ElementAccumulator_>::InstructionShape,

191typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<

192 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,

193 ElementAccumulator_>::EpilogueOutputOp,

195typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle,

197int Stages =

198 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,

199 ElementC_, ElementAccumulator_>::kStages,

201ComplexTransform TransformA = ComplexTransform::kNone,

203ComplexTransform TransformB = ComplexTransform::kNone,

205bool SplitKSerial = false

206 >

207 class GemmComplex {

208public:

209

210using ElementA = ElementA_;

211using LayoutA = LayoutA_;

212using TensorRefA = TensorRef<ElementA const, LayoutA>;

213using ElementB = ElementB_;

214using LayoutB = LayoutB_;

215using TensorRefB = TensorRef<ElementB const, LayoutB>;

216using ElementC = ElementC_;

217using LayoutC = LayoutC_;

218using TensorRefC = TensorRef<ElementC const, LayoutC>;

219using TensorRefD = TensorRef<ElementC, LayoutC>;

220using ElementAccumulator = ElementAccumulator_;

221using OperatorClass = OperatorClass_;

222using ArchTag = ArchTag_;

223using ThreadblockShape = ThreadblockShape_;

224using WarpShape = WarpShape_;

225using InstructionShape = InstructionShape_;

226using EpilogueOutputOp = EpilogueOutputOp_;

227using ThreadblockSwizzle = ThreadblockSwizzle_;

228static int const kStages = Stages;

229static ComplexTransform const kTransformA = TransformA;

230static ComplexTransform const kTransformB = TransformB;

231static bool const kSplitKSerial = SplitKSerial;

232

234using GemmKernel = typename kernel::DefaultGemmComplex<

235ElementA,

236LayoutA,

237ElementB,

238LayoutB,

239ElementC,

240LayoutC,

241ElementAccumulator,

242OperatorClass,

243ArchTag,

244ThreadblockShape,

245WarpShape,

246InstructionShape,

247EpilogueOutputOp,

248ThreadblockSwizzle,

249kStages,

250kTransformA,

251kTransformB,

252 kSplitKSerial

253 >::GemmKernel;

254

256struct Arguments {

257

258//

259// Data members

260//

261

262GemmCoord problem_size;

263TensorRef<ElementA const, LayoutA> ref_A;

264TensorRef<ElementB const, LayoutB> ref_B;

265TensorRef<ElementC const, LayoutC> ref_C;

266TensorRef<ElementC, LayoutC> ref_D;

267typename EpilogueOutputOp::Params epilogue;

268int split_k_slices;

269

270//

271// Methods

272//

273

275CUTLASS_HOST_DEVICE

276Arguments(): problem_size(0, 0, 0), split_k_slices(1) {

277

278 }

279

281CUTLASS_HOST_DEVICE

282Arguments(

283GemmCoord problem_size_,

284TensorRef<ElementA const, LayoutA> ref_A_,

285TensorRef<ElementB const, LayoutB> ref_B_,

286TensorRef<ElementC const, LayoutC> ref_C_,

287TensorRef<ElementC, LayoutC> ref_D_,

288typename EpilogueOutputOp::Params epilogue_ =

289typename EpilogueOutputOp::Params(),

290int split_k_slices = 1

291 ):

292 problem_size(problem_size_),

293 ref_A(ref_A_),

294 ref_B(ref_B_),

295 ref_C(ref_C_),

296 ref_D(ref_D_),

297 epilogue(epilogue_),

298 split_k_slices(split_k_slices) {

299

300 }

301 };

302

303 private:

304

306typename GemmKernel::Params params_;

307

308 public:

309

311GemmComplex() { }

312

314static Status can_implement(Arguments const &args) {

315

316if (!kSplitKSerial && args.split_k_slices > 1) {

317return Status::kErrorInvalidProblem;

318 }

319

320return Status::kSuccess;

321 }

322

324static size_t get_workspace_size(Arguments const &args) {

325

326if (kSplitKSerial && args.split_k_slices > 1) {

327

328// Determine grid shape

329 ThreadblockSwizzle threadblock_swizzle;

330

331cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(

332 args.problem_size,

333 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},

334 args.split_k_slices);

335

336return sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());

337 }

338

339return 0;

340 }

341

343Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

344

345// Determine grid shape

346 ThreadblockSwizzle threadblock_swizzle;

347

348cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(

349 args.problem_size,

350 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},

351 args.split_k_slices);

352

353if (kSplitKSerial) {

354if (args.split_k_slices > 1) {

355if (!workspace) {

356return Status::kErrorWorkspaceNull;

357 }

358

359size_t bytes = get_workspace_size(args);

360

361 cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);

362

363if (result != cudaSuccess) {

364return Status::kErrorInternal;

365 }

366 }

367 }

368else {

369

370if (args.split_k_slices > 1) {

371return Status::kErrorInvalidProblem;

372 }

373 }

374

375// Initialize the Params structure

376 params_ = typename GemmKernel::Params{

377 args.problem_size,

378 grid_shape,

379 args.ref_A.non_const_ref(),

380 args.ref_B.non_const_ref(),

381 args.ref_C.non_const_ref(),

382 args.ref_D,

383 args.epilogue,

384static_cast<int *>(workspace)

385 };

386

387return Status::kSuccess;

388 }

389

391Status update(Arguments const &args, void *workspace = nullptr) {

392

393if (kSplitKSerial && args.split_k_slices > 1) {

394if (!workspace) {

395return Status::kErrorWorkspaceNull;

396 }

397 }

398

399 params_.ref_A.reset(args.ref_A.non_const_ref().data());

400 params_.ref_B.reset(args.ref_B.non_const_ref().data());

401 params_.ref_C.reset(args.ref_C.non_const_ref().data());

402 params_.ref_D.reset(args.ref_D.data());

403 params_.semaphore = static_cast<int *>(workspace);

404

405return Status::kSuccess;

406 }

407

409Status run(cudaStream_t stream = nullptr) {

410

411 ThreadblockSwizzle threadblock_swizzle;

412

413 dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);

414 dim3 block(GemmKernel::kThreadCount, 1, 1);

415

416 cudaError_t result;

417

418int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

419if (smem_size >= (48 << 10)) {

420 result = cudaFuncSetAttribute(Kernel<GemmKernel>,

421 cudaFuncAttributeMaxDynamicSharedMemorySize,

422 smem_size);

423

424if (result != cudaSuccess) {

425return Status::kErrorInternal;

426 }

427

428 result = cudaFuncSetAttribute(

429 Kernel<GemmKernel>,

430 cudaFuncAttributePreferredSharedMemoryCarveout, 100);

431

432if (result != cudaSuccess) {

433return Status::kErrorInternal;

434 }

435 }

436

437 cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

438

439 result = cudaGetLastError();

440

441return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;

442 }

443

445Status operator()(cudaStream_t stream = nullptr) {

446return run(stream);

447 }

448

450Status operator()(

451Arguments const &args,

452void *workspace = nullptr,

453 cudaStream_t stream = nullptr) {

454

455Status status = initialize(args, workspace);

456

457if (status == Status::kSuccess) {

458 status = run(stream);

459 }

460

461return status;

462 }

463 };

464

466

468 template <

470typename ElementA_,

472typename LayoutA_,

474typename ElementB_,

476typename LayoutB_,

478typename ElementC_,

480typename ElementAccumulator_,

482typename OperatorClass_,

484typename ArchTag_,

486typename ThreadblockShape_,

488typename WarpShape_,

490typename InstructionShape_,

492typename EpilogueOutputOp_,

494typename ThreadblockSwizzle_,

496int Stages,

498ComplexTransform TransformA,

500ComplexTransform TransformB,

502bool SplitKSerial

503 >

[504](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html) class GemmComplex<

505 ElementA_,

506 LayoutA_,

507 ElementB_,

508 LayoutB_,

509 ElementC_,

510 layout::ColumnMajor, // partially specialized on LayoutC

511 ElementAccumulator_,

512 OperatorClass_,

513 ArchTag_,

514 ThreadblockShape_,

515 WarpShape_,

516 InstructionShape_,

517 EpilogueOutputOp_,

518 ThreadblockSwizzle_,

519 Stages,

520 TransformA,

521 TransformB,

522 SplitKSerial

523 > {

524 public:

525

[526](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#aa6621903fd434110b57220f2b2fb97cb)using ElementA = ElementA_;

[527](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a9b48f3a933f3b37814f9b70503b7684a)using LayoutA = LayoutA_;

[528](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#ae8995862cf7e42bc086f5941c1aa5d35)using TensorRefA = TensorRef<ElementA const, LayoutA>;

[529](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a3570f3ed978cba7f66d1310ce66a56b3)using ElementB = ElementB_;

[530](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a90e18e93d96cd07f03201134d3c1b5a0)using LayoutB = LayoutB_;

[531](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a3c170badab35f7754939a4cd9d8258fe)using TensorRefB = TensorRef<ElementB const, LayoutB>;

[532](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a77d1c52156347656311764de09456670)using ElementC = ElementC_;

[533](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#af2b903fa011363e7049d5f0807b77731)using LayoutC = layout::ColumnMajor;

[534](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#ae77f478fd7cff440628fb38e230f2609)using TensorRefC = TensorRef<ElementC const, LayoutC>;

[535](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a5b68c920af70250817a7791d91ab77f5)using TensorRefD = TensorRef<ElementC, LayoutC>;

[536](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a233bda6413e491449aa29b3222c60904)using ElementAccumulator = ElementAccumulator_;

[537](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a4d170f269f81dafe07770197c3864a6b)using OperatorClass = OperatorClass_;

[538](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a15200a21650efa7f582747dbbad044ca)using ArchTag = ArchTag_;

[539](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#ae09f224faa27d9735ab77899d36dbc96)using ThreadblockShape = ThreadblockShape_;

[540](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a09a291542ae92f0fb97a4c0a5ee25db4)using WarpShape = WarpShape_;

[541](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#af691b4304896f601c34eaedf78493ed5)using InstructionShape = InstructionShape_;

[542](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#ab2b9c1976d62f70a32d93f55f79a2401)using EpilogueOutputOp = EpilogueOutputOp_;

[543](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a02473fb6e60eed4bf79b510d7096b4c5)using ThreadblockSwizzle = ThreadblockSwizzle_;

[544](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a689afffc991cf4e6aab7d6e4f5fe4d46)static int const kStages = Stages;

[545](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#afe14a91a30bea2204d4351591df7b5cc)static bool const kSplitKSerial = SplitKSerial;

546

547using UnderlyingOperator = GemmComplex<

548ElementB,

549typename layout::LayoutTranspose<LayoutB>::type,

550ElementA,

551typename layout::LayoutTranspose<LayoutA>::type,

552ElementC,

553layout::RowMajor,

554ElementAccumulator,

555OperatorClass,

556ArchTag,

557ThreadblockShape,

558WarpShape,

559InstructionShape,

560EpilogueOutputOp,

561ThreadblockSwizzle,

562 Stages,

563 TransformA,

564 TransformB,

565 SplitKSerial

[566](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#aa55140ff232b12c3a4bf1e5093282354) >;

567

[568](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a171a1b3ddd40fb9f318e51fa28029f4b)using [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a171a1b3ddd40fb9f318e51fa28029f4b) = typename UnderlyingOperator::Arguments;

[569](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#abe65836275404d572a7e1e2108c72982)using [GemmKernel](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#abe65836275404d572a7e1e2108c72982) = typename UnderlyingOperator::GemmKernel;

570

[572](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html)struct Arguments {

573

574//

575// Data members

576//

577

[578](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a29159f430d4a733ec3fac550d0458e18)GemmCoord [problem_size](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#a29159f430d4a733ec3fac550d0458e18);

[579](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#ac8e9298e3786e9391d740faa4d0566f2)TensorRef<ElementA const, LayoutA> [ref_A](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#ac8e9298e3786e9391d740faa4d0566f2);

[580](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#ab706387c660af35ae2b9579165eec85d)TensorRef<ElementB const, LayoutB> [ref_B](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#ab706387c660af35ae2b9579165eec85d);

[581](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a3a59aa793429bc57d796b40fa4fab622)TensorRef<ElementC const, LayoutC> [ref_C](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#a3a59aa793429bc57d796b40fa4fab622);

[582](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a2904e3ad7a47b3d85ea60d94eeebe84b)TensorRef<ElementC, LayoutC> [ref_D](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#a2904e3ad7a47b3d85ea60d94eeebe84b);

[583](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#ad1f435bf8b7003afad9b803adf9fcb89)typename EpilogueOutputOp::Params [epilogue](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#ad1f435bf8b7003afad9b803adf9fcb89);

[584](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#aec118721190212e7e61c7d17d4c93d1c)int [split_k_slices](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#aec118721190212e7e61c7d17d4c93d1c);

585

586//

587// Methods

588//

589

591CUTLASS_HOST_DEVICE

[592](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a710950fddfc99fc79302cbfe959bb201)[Arguments](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#a710950fddfc99fc79302cbfe959bb201)() { }

593

595CUTLASS_HOST_DEVICE

[596](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a8886db2fcca9a63381861662d318ad12)[Arguments](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#a8886db2fcca9a63381861662d318ad12)(

597GemmCoord problem_size_,

598TensorRef<ElementA const, LayoutA> ref_A_,

599TensorRef<ElementB const, LayoutB> ref_B_,

600TensorRef<ElementC const, LayoutC> ref_C_,

601TensorRef<ElementC, LayoutC> ref_D_,

602typename EpilogueOutputOp::Params epilogue_ =

603typename EpilogueOutputOp::Params(),

604int split_k_slices = 1

605 ):

606 problem_size(problem_size_),

607 ref_A(ref_A_),

608 ref_B(ref_B_),

609 ref_C(ref_C_),

610 ref_D(ref_D_),

611 epilogue(epilogue_),

612 split_k_slices(split_k_slices) { }

613 };

614

615 private:

616

617UnderlyingOperator underlying_operator_;

618

619 public:

620

[622](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a9ce748bfc112dd4bb942c5e7c95845df)[GemmComplex](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a9ce748bfc112dd4bb942c5e7c95845df)() { }

623

[625](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a3dd09eeeae6c4faeddc4abc8bb57b177)static [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a171a1b3ddd40fb9f318e51fa28029f4b) [to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a3dd09eeeae6c4faeddc4abc8bb57b177)(Arguments const &args) {

626return [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a171a1b3ddd40fb9f318e51fa28029f4b)(

627 {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},

628 {args.ref_B.data(), args.ref_B.stride(0)},

629 {args.ref_A.data(), args.ref_A.stride(0)},

630 {args.ref_C.data(), args.ref_C.stride(0)},

631 {args.ref_D.data(), args.ref_D.stride(0)},

632 args.epilogue,

633 args.split_k_slices

634 );

635 }

636

[638](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#adb94d2e6dd70b46bea6b5b433e14fea9)static Status [can_implement](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#adb94d2e6dd70b46bea6b5b433e14fea9)(Arguments const &args) {

639

640return UnderlyingOperator::can_implement(to_underlying_arguments(args));

641 }

642

[644](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a75342fc4122c07d1382b31ee5f188210)static size_t [get_workspace_size](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a75342fc4122c07d1382b31ee5f188210)(Arguments const &args) {

645

646return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));

647 }

648

[650](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a5c3286631f254746c9eb788b780cdca3)Status [initialize](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a5c3286631f254746c9eb788b780cdca3)(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

651

652return underlying_operator_.initialize(to_underlying_arguments(args), workspace);

653 }

654

[656](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a08446a157a60f7f1e23315c1ece09bce)Status [update](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a08446a157a60f7f1e23315c1ece09bce)(Arguments const &args, void *workspace = nullptr) {

657

658return underlying_operator_.update(to_underlying_arguments(args), workspace);

659 }

660

[662](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a4111bba1e9d2000fcc9bba2f114ee801)Status [run](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a4111bba1e9d2000fcc9bba2f114ee801)(cudaStream_t stream = nullptr) {

663

664return underlying_operator_.run(stream);

665 }

666

[668](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a375220f643161478c1fb5bcd24f8b5cd)Status [operator()](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a375220f643161478c1fb5bcd24f8b5cd)(cudaStream_t stream = nullptr) {

669return run(stream);

670 }

671

[673](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a50ff89a3c0b3735b669cf4e3b755918a)Status [operator()](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a50ff89a3c0b3735b669cf4e3b755918a)(

674Arguments const &args,

675void *workspace = nullptr,

676 cudaStream_t stream = nullptr) {

677

678Status status = initialize(args, workspace);

679

680if (status == Status::kSuccess) {

681 status = run(stream);

682 }

683

684return status;

685 }

686 };

687

689

690 } // namespace device

691 } // namespace gemm

692 } // namespace cutlass

693

cutlass::gemm::device::GemmComplex::operator()

Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: include/cutlass/gemm/device/gemm_complex.h:450

cutlass::gemm::device::GemmComplex::kTransformA

static ComplexTransform const kTransformA

Definition: include/cutlass/gemm/device/gemm_complex.h:229

cutlass::gemm::device::GemmComplex::Arguments::ref_A

TensorRef< ElementA const, LayoutA > ref_A

Definition: include/cutlass/gemm/device/gemm_complex.h:263

cutlass::gemm::device::GemmComplex

Definition: include/cutlass/gemm/device/gemm_complex.h:207

cutlass

Definition: aligned_buffer.h:35

cutlass::ComplexTransform

ComplexTransform

Enumerated type describing a transformation on a complex value.

Definition: complex.h:43

cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::ElementB

ElementA ElementB

Definition: include/cutlass/gemm/device/gemm_complex.h:213

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::ref_C](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a3a59aa793429bc57d796b40fa4fab622)

TensorRef< ElementC const, LayoutC > ref_C

Definition: include/cutlass/gemm/device/gemm_complex.h:581

cutlass::gemm::device::GemmComplex::run

Status run(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: include/cutlass/gemm/device/gemm_complex.h:409

cutlass::Status::kErrorInvalidProblem

Specified problem size is not supported by operator.

cutlass::gemm::device::GemmComplex::Arguments::problem_size

GemmCoord problem_size

Definition: include/cutlass/gemm/device/gemm_complex.h:262

cutlass::gemm::device::GemmComplex::can_implement

static Status can_implement(Arguments const &args)

Determines whether the GEMM can execute the given problem.

Definition: include/cutlass/gemm/device/gemm_complex.h:314

cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::EpilogueOutputOp

EpilogueOutputOp EpilogueOutputOp

Definition: include/cutlass/gemm/device/gemm_complex.h:226

cutlass::gemm::device::GemmComplex::kStages

static int const kStages

Definition: include/cutlass/gemm/device/gemm_complex.h:228

cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::ArchTag

ArchTag ArchTag

Definition: include/cutlass/gemm/device/gemm_complex.h:222

cutlass::gemm::device::GemmComplex::Arguments::ref_B

TensorRef< ElementB const, LayoutB > ref_B

Definition: include/cutlass/gemm/device/gemm_complex.h:264

cutlass::gemm::GemmCoord

Definition: include/cutlass/gemm/gemm.h:94

cutlass::gemm::device::GemmComplex::Arguments::epilogue

EpilogueOutputOp::Params epilogue

Definition: include/cutlass/gemm/device/gemm_complex.h:267

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::operator()](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a50ff89a3c0b3735b669cf4e3b755918a)

Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: include/cutlass/gemm/device/gemm_complex.h:673

cutlass::gemm::GemmCoord::n

CUTLASS_HOST_DEVICE Index const & n() const

Returns the GEMM N coordinate.

Definition: include/cutlass/gemm/gemm.h:137

cutlass::gemm::device::GemmComplex::initialize

Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)

Initializes GEMM state from arguments.

Definition: include/cutlass/gemm/device/gemm_complex.h:343

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::split_k_slices](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#aec118721190212e7e61c7d17d4c93d1c)

int split_k_slices

Definition: include/cutlass/gemm/device/gemm_complex.h:584

cutlass::ComplexTransform::kNone

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a3dd09eeeae6c4faeddc4abc8bb57b177)

static UnderlyingArguments to_underlying_arguments(Arguments const &args)

Helper to construct a transposed equivalent for the underlying GEMM operator.

Definition: include/cutlass/gemm/device/gemm_complex.h:625

cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::LayoutA

typename layout::LayoutTranspose< LayoutB >::type LayoutA

Definition: include/cutlass/gemm/device/gemm_complex.h:211

cutlass::layout::ColumnMajor

Mapping function for column-major matrices.

Definition: layout/matrix.h:142

cutlass::gemm::device::GemmComplex::Arguments

Argument structure.

Definition: include/cutlass/gemm/device/gemm_complex.h:256

cutlass::gemm::device::GemmComplex::get_workspace_size

static size_t get_workspace_size(Arguments const &args)

Gets the workspace size.

Definition: include/cutlass/gemm/device/gemm_complex.h:324

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::initialize](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a5c3286631f254746c9eb788b780cdca3)

Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)

Initializes GEMM state from arguments.

Definition: include/cutlass/gemm/device/gemm_complex.h:650

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::run](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a4111bba1e9d2000fcc9bba2f114ee801)

Status run(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: include/cutlass/gemm/device/gemm_complex.h:662

cutlass::gemm::device::GemmComplex::LayoutC

LayoutC_ LayoutC

Definition: include/cutlass/gemm/device/gemm_complex.h:217

cutlass::gemm::device::GemmComplex::Arguments::Arguments

CUTLASS_HOST_DEVICE Arguments()

Default ctor.

Definition: include/cutlass/gemm/device/gemm_complex.h:276

cutlass::layout::LayoutTranspose

Defines transposes of matrix layouts.

Definition: layout/matrix.h:921

cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::ElementC

ElementC ElementC

Definition: include/cutlass/gemm/device/gemm_complex.h:216

cutlass::TensorRef< ElementA const, LayoutA >

cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::ElementAccumulator

ElementAccumulator ElementAccumulator

Definition: include/cutlass/gemm/device/gemm_complex.h:220

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::get_workspace_size](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a75342fc4122c07d1382b31ee5f188210)

static size_t get_workspace_size(Arguments const &args)

Gets the workspace size.

Definition: include/cutlass/gemm/device/gemm_complex.h:644

cutlass::Status::kErrorInternal

An error within CUTLASS occurred.

cutlass::gemm::device::GemmComplex::operator()

Status operator()(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: include/cutlass/gemm/device/gemm_complex.h:445

device_kernel.h

Template for generic CUTLASS kernel.

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::GemmComplex](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a9ce748bfc112dd4bb942c5e7c95845df)

GemmComplex()

Constructs the GEMM.

Definition: include/cutlass/gemm/device/gemm_complex.h:622

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::operator()](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a375220f643161478c1fb5bcd24f8b5cd)

Status operator()(cudaStream_t stream=nullptr)

Runs the kernel using initialized state.

Definition: include/cutlass/gemm/device/gemm_complex.h:668

cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::ThreadblockSwizzle

ThreadblockSwizzle ThreadblockSwizzle

Definition: include/cutlass/gemm/device/gemm_complex.h:227

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::gemm::device::GemmComplex::GemmComplex

GemmComplex()

Constructs the GEMM.

Definition: include/cutlass/gemm/device/gemm_complex.h:311

[default_gemm_configuration.h](default__gemm__configuration_8h.html)

Definitions for GEMM structures.

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::GemmKernel](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#abe65836275404d572a7e1e2108c72982)

typename UnderlyingOperator::GemmKernel GemmKernel

Definition: include/cutlass/gemm/device/gemm_complex.h:569

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::problem_size](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a29159f430d4a733ec3fac550d0458e18)

GemmCoord problem_size

Definition: include/cutlass/gemm/device/gemm_complex.h:578

cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::LayoutB

typename layout::LayoutTranspose< LayoutA >::type LayoutB

Definition: include/cutlass/gemm/device/gemm_complex.h:214

cutlass::gemm::device::GemmComplex::Arguments::ref_D

TensorRef< ElementC, LayoutC > ref_D

Definition: include/cutlass/gemm/device/gemm_complex.h:266

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

cutlass::gemm::device::GemmComplex::Arguments::ref_C

TensorRef< ElementC const, LayoutC > ref_C

Definition: include/cutlass/gemm/device/gemm_complex.h:265

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a171a1b3ddd40fb9f318e51fa28029f4b)

typename UnderlyingOperator::Arguments UnderlyingArguments

Definition: include/cutlass/gemm/device/gemm_complex.h:568

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::ref_B](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#ab706387c660af35ae2b9579165eec85d)

TensorRef< ElementB const, LayoutB > ref_B

Definition: include/cutlass/gemm/device/gemm_complex.h:580

cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::WarpShape

WarpShape WarpShape

Definition: include/cutlass/gemm/device/gemm_complex.h:224

cutlass::gemm::device::GemmComplex::update

Status update(Arguments const &args, void *workspace=nullptr)

Lightweight update given a subset of arguments.

Definition: include/cutlass/gemm/device/gemm_complex.h:391

cutlass::gemm::device::GemmComplex::kSplitKSerial

static bool const kSplitKSerial

Definition: include/cutlass/gemm/device/gemm_complex.h:231

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::update](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a08446a157a60f7f1e23315c1ece09bce)

Status update(Arguments const &args, void *workspace=nullptr)

Lightweight update given a subset of arguments.

Definition: include/cutlass/gemm/device/gemm_complex.h:656

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a710950fddfc99fc79302cbfe959bb201)

CUTLASS_HOST_DEVICE Arguments()

Default ctor.

Definition: include/cutlass/gemm/device/gemm_complex.h:592

cutlass::Status::kErrorWorkspaceNull

The given workspace is null when it is required to be non-null.

cutlass::Status::kSuccess

Operation was successful.

cutlass::gemm::GemmCoord::m

CUTLASS_HOST_DEVICE Index const & m() const

Returns the GEMM M coordinate.

Definition: include/cutlass/gemm/gemm.h:129

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::can_implement](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#adb94d2e6dd70b46bea6b5b433e14fea9)

static Status can_implement(Arguments const &args)

Determines whether the GEMM can execute the given problem.

Definition: include/cutlass/gemm/device/gemm_complex.h:638

threadblock_swizzle.h

Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...

arch.h

Defines tags for architecture-specific configurations.

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::ref_D](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a2904e3ad7a47b3d85ea60d94eeebe84b)

TensorRef< ElementC, LayoutC > ref_D

Definition: include/cutlass/gemm/device/gemm_complex.h:582

cutlass::gemm::device::GemmComplex::Arguments::split_k_slices

int split_k_slices

Definition: include/cutlass/gemm/device/gemm_complex.h:268

cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::InstructionShape

InstructionShape InstructionShape

Definition: include/cutlass/gemm/device/gemm_complex.h:225

cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::OperatorClass

OperatorClass OperatorClass

Definition: include/cutlass/gemm/device/gemm_complex.h:221

cutlass::gemm::device::GemmComplex::Arguments::Arguments

CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1)

Constructs an Arguments structure.

Definition: include/cutlass/gemm/device/gemm_complex.h:282

cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::GemmKernel

typename kernel::DefaultGemmComplex< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, kStages, kTransformA, kTransformB, kSplitKSerial >::GemmKernel GemmKernel

Define the kernel.

Definition: include/cutlass/gemm/device/gemm_complex.h:253

cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::ThreadblockShape

ThreadblockShape ThreadblockShape

Definition: include/cutlass/gemm/device/gemm_complex.h:223

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::epilogue](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#ad1f435bf8b7003afad9b803adf9fcb89)

EpilogueOutputOp::Params epilogue

Definition: include/cutlass/gemm/device/gemm_complex.h:583

cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::ElementA

ElementB ElementA

Definition: include/cutlass/gemm/device/gemm_complex.h:210

cutlass::gemm::device::GemmComplex::kTransformB

static ComplexTransform const kTransformB

Definition: include/cutlass/gemm/device/gemm_complex.h:230

cutlass.h

Basic include for CUTLASS.

cutlass::Status

Status

Status code returned by CUTLASS operations.

Definition: cutlass.h:39

gemm.h

Template for a pipelined GEMM kernel. Does not compute batching or support split-K.

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a8886db2fcca9a63381861662d318ad12)

CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1)

Constructs an Arguments structure.

Definition: include/cutlass/gemm/device/gemm_complex.h:596

[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::ref_A](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#ac8e9298e3786e9391d740faa4d0566f2)

TensorRef< ElementA const, LayoutA > ref_A

Definition: include/cutlass/gemm/device/gemm_complex.h:579


Generated by 1.8.11