docs/include_2cutlass_2gemm_2device_2gemm_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
include/cutlass/gemm/device/gemm.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/cutlass.h"
32 #include "cutlass/numeric_types.h"
33 #include "cutlass/arch/arch.h"
34 #include "cutlass/device_kernel.h"
35
36 #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
37 #include "cutlass/gemm/kernel/gemm.h"
38
39 #include "cutlass/gemm/kernel/default_gemm.h"
40 #include "[cutlass/gemm/device/default_gemm_configuration.h](default__gemm__configuration_8h.html)"
41
43
44 namespace cutlass {
45 namespace gemm {
46 namespace device {
47
49
113
116
119
122
125
128
131
134
137
140
143
146
149
152
155
159 template <
161typename ElementA_,
163typename LayoutA_,
165typename ElementB_,
167typename LayoutB_,
169typename ElementC_,
171typename LayoutC_,
173typename ElementAccumulator_ = ElementC_,
175typename OperatorClass_ = arch::OpClassSimt,
177typename ArchTag_ = arch::Sm70,
179typename ThreadblockShape_ = typename DefaultGemmConfiguration<
180 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
181 ElementAccumulator_>::ThreadblockShape,
183typename WarpShape_ = typename DefaultGemmConfiguration<
184 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
185 ElementAccumulator_>::WarpShape,
187typename InstructionShape_ = typename DefaultGemmConfiguration<
188 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
189 ElementAccumulator_>::InstructionShape,
191typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
192 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
193 ElementAccumulator_>::EpilogueOutputOp,
195typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle,
197int Stages =
198 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
199 ElementC_, ElementAccumulator_>::kStages,
201int AlignmentA =
202 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
203 ElementC_, ElementAccumulator_>::kAlignmentA,
205int AlignmentB =
206 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
207 ElementC_, ElementAccumulator_>::kAlignmentB,
209bool SplitKSerial = false,
211typename Operator_ = typename DefaultGemmConfiguration<
212 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
213 ElementAccumulator_>::Operator,
215bool IsBetaZero = false>
217public:
218
219using ElementA = ElementA_;
221using TensorRefA = TensorRef<ElementA const, LayoutA>;
222using ElementB = ElementB_;
224using TensorRefB = TensorRef<ElementB const, LayoutB>;
225using ElementC = ElementC_;
227using TensorRefC = TensorRef<ElementC const, LayoutC>;
228using TensorRefD = TensorRef<ElementC, LayoutC>;
229using ElementAccumulator = ElementAccumulator_;
230using OperatorClass = OperatorClass_;
232using ThreadblockShape = ThreadblockShape_;
233using WarpShape = WarpShape_;
234using InstructionShape = InstructionShape_;
235using EpilogueOutputOp = EpilogueOutputOp_;
236using ThreadblockSwizzle = ThreadblockSwizzle_;
237using Operator = Operator_;
238static int const kStages = Stages;
239static int const kAlignmentA = AlignmentA;
240static int const kAlignmentB = AlignmentB;
241static int const kAlignmentC = EpilogueOutputOp::kCount;
242static bool const kSplitKSerial = SplitKSerial;
243static bool const kIsBetaZero = IsBetaZero;
244
246using GemmKernel = typename kernel::DefaultGemm<
247ElementA,
248LayoutA,
249kAlignmentA,
250ElementB,
251LayoutB,
252kAlignmentB,
253ElementC,
254LayoutC,
256OperatorClass,
257ArchTag,
258ThreadblockShape,
259WarpShape,
260InstructionShape,
261EpilogueOutputOp,
263kStages,
264kSplitKSerial,
265Operator,
266 kIsBetaZero
268
271
272//
273// Data members
274//
275
277TensorRef<ElementA const, LayoutA> ref_A;
278TensorRef<ElementB const, LayoutB> ref_B;
279TensorRef<ElementC const, LayoutC> ref_C;
280TensorRef<ElementC, LayoutC> ref_D;
281typename EpilogueOutputOp::Params epilogue;
282int split_k_slices;
283
284//
285// Methods
286//
287
290Arguments(): problem_size(0, 0, 0), split_k_slices(1) {
291
292 }
293
297GemmCoord problem_size_,
298TensorRef<ElementA const, LayoutA> ref_A_,
299TensorRef<ElementB const, LayoutB> ref_B_,
300TensorRef<ElementC const, LayoutC> ref_C_,
301TensorRef<ElementC, LayoutC> ref_D_,
302typename EpilogueOutputOp::Params epilogue_ =
303typename EpilogueOutputOp::Params(),
304int split_k_slices = 1
305 ):
306 problem_size(problem_size_),
307 ref_A(ref_A_),
308 ref_B(ref_B_),
309 ref_C(ref_C_),
310 ref_D(ref_D_),
311 epilogue(epilogue_),
312 split_k_slices(split_k_slices) {
313
314 }
315 };
316
317 private:
318
320typename GemmKernel::Params params_;
321
322 public:
323
326
328static Status can_implement(Arguments const &args) {
329
330if (!kSplitKSerial && args.split_k_slices > 1) {
331return Status::kErrorInvalidProblem;
332 }
333
334Status status = GemmKernel::can_implement(
335 args.problem_size,
336 args.ref_A.non_const_ref(),
337 args.ref_B.non_const_ref(),
338 args.ref_C.non_const_ref(),
339 args.ref_D
340 );
341
342if (status != Status::kSuccess) {
343return status;
344 }
345
346return Status::kSuccess;
347 }
348
350static size_t get_workspace_size(Arguments const &args) {
351
352if (kSplitKSerial && args.split_k_slices > 1) {
353
354// Determine grid shape
355 ThreadblockSwizzle threadblock_swizzle;
356
357cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
358 args.problem_size,
359 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
360 args.split_k_slices);
361
362return sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
363 }
364
365return 0;
366 }
367
369Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
370
371// Determine grid shape
372 ThreadblockSwizzle threadblock_swizzle;
373
374cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
375 args.problem_size,
376 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
377 args.split_k_slices);
378
379if (kSplitKSerial) {
380if (args.split_k_slices > 1) {
381if (!workspace) {
382return Status::kErrorWorkspaceNull;
383 }
384
385size_t bytes = get_workspace_size(args);
386
387 cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
388
389if (result != cudaSuccess) {
390return Status::kErrorInternal;
391 }
392 }
393 }
394else {
395
396if (args.split_k_slices > 1) {
397return Status::kErrorInvalidProblem;
398 }
399 }
400
401// Initialize the Params structure
402 params_ = typename GemmKernel::Params{
403 args.problem_size,
404 grid_shape,
405 args.ref_A.non_const_ref(),
406 args.ref_B.non_const_ref(),
407 args.ref_C.non_const_ref(),
408 args.ref_D,
409 args.epilogue,
410static_cast<int *>(workspace)
411 };
412
413return Status::kSuccess;
414 }
415
417Status update(Arguments const &args, void *workspace = nullptr) {
418
419if (kSplitKSerial && args.split_k_slices > 1) {
420if (!workspace) {
421return Status::kErrorWorkspaceNull;
422 }
423 }
424
425 params_.ref_A.reset(args.ref_A.non_const_ref().data());
426 params_.ref_B.reset(args.ref_B.non_const_ref().data());
427 params_.ref_C.reset(args.ref_C.non_const_ref().data());
428 params_.ref_D.reset(args.ref_D.data());
429 params_.semaphore = static_cast<int *>(workspace);
430
431return Status::kSuccess;
432 }
433
435Status run(cudaStream_t stream = nullptr) {
436
437 ThreadblockSwizzle threadblock_swizzle;
438
439 dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
440 dim3 block(GemmKernel::kThreadCount, 1, 1);
441
442 cudaError_t result;
443
444int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
445if (smem_size >= (48 << 10)) {
446 result = cudaFuncSetAttribute(Kernel<GemmKernel>,
447 cudaFuncAttributeMaxDynamicSharedMemorySize,
448 smem_size);
449
450if (result != cudaSuccess) {
451return Status::kErrorInternal;
452 }
453
454 result = cudaFuncSetAttribute(
455 Kernel<GemmKernel>,
456 cudaFuncAttributePreferredSharedMemoryCarveout, 100);
457
458if (result != cudaSuccess) {
459return Status::kErrorInternal;
460 }
461 }
462
463 cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
464
465 result = cudaGetLastError();
466
467return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
468 }
469
471Status operator()(cudaStream_t stream = nullptr) {
472return run(stream);
473 }
474
477Arguments const &args,
478void *workspace = nullptr,
479 cudaStream_t stream = nullptr) {
480
481Status status = initialize(args, workspace);
482
483if (status == Status::kSuccess) {
484 status = run(stream);
485 }
486
487return status;
488 }
489 };
490
492
494 template <
496typename ElementA_,
498typename LayoutA_,
500typename ElementB_,
502typename LayoutB_,
504typename ElementC_,
506typename ElementAccumulator_,
508typename OperatorClass_,
510typename ArchTag_,
512typename ThreadblockShape_,
514typename WarpShape_,
516typename InstructionShape_,
518typename EpilogueOutputOp_,
520typename ThreadblockSwizzle_,
522int Stages,
524int AlignmentA,
526int AlignmentB,
528bool SplitKSerial,
530typename Operator_,
532bool IsBetaZero>
[533](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html) class Gemm<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
534 layout::ColumnMajor, // partially specialized on LayoutC
535 ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_,
536 WarpShape_, InstructionShape_, EpilogueOutputOp_,
537 ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial,
538 Operator_, IsBetaZero> {
539public:
540
[541](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a09db4f8f255d272e7350394d568f4a01)using ElementA = ElementA_;
[542](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a5212eb5b3af32e5bc43cc4179bb346ef)using LayoutA = LayoutA_;
[543](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a82da12cbdf6f75499d315ee530f5330e)using TensorRefA = TensorRef<ElementA const, LayoutA>;
[544](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a0f904b72a3ff91f7ae6ad1a91e915b6d)using ElementB = ElementB_;
[545](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a1841e0e97e59862c7a92fc8d2ab7c9bc)using LayoutB = LayoutB_;
[546](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#ab4bd8a2bb7be0fa2b583cf34b63b62eb)using TensorRefB = TensorRef<ElementB const, LayoutB>;
[547](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#ad58a37fecfeb982d20fc209a0df4c1fa)using ElementC = ElementC_;
[548](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#afe4685fea6a4603a7459bbe9923c9cb3)using LayoutC = layout::ColumnMajor;
[549](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#aff1ad6d93937a9e4b261eb69322449e7)using TensorRefC = TensorRef<ElementC const, LayoutC>;
[550](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#abaa02d78437ae0f42260848d722c134f)using TensorRefD = TensorRef<ElementC, LayoutC>;
[551](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#acbc61142b95f4d33bc0b8518857ab7be)using ElementAccumulator = ElementAccumulator_;
[552](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#aa816137b589b8bf2204ace73d49b7ded)using OperatorClass = OperatorClass_;
[553](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a1b502a4097e745c12d0d628d080ba447)using ArchTag = ArchTag_;
[554](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#ab798f409ba80eab4a0140fdf43e768ee)using ThreadblockShape = ThreadblockShape_;
[555](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#abdc15293a8b083372e5395049440d01c)using WarpShape = WarpShape_;
[556](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a71b84f983b94b50a48bd0890f1e0ed59)using InstructionShape = InstructionShape_;
[557](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#ad1d60f7381ae03803a078a26604bd8be)using EpilogueOutputOp = EpilogueOutputOp_;
[558](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#af55d56acaa01ce303c22d6e9e0b0f895)using ThreadblockSwizzle = ThreadblockSwizzle_;
[559](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a029f48f17ec3fb98067bfacd7e06f3d2)using Operator = Operator_;
[560](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a5de0cfa9c3831daebbdc8326c239dd33)static int const kStages = Stages;
[561](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a78660ed036162b8455546ff5718968d0)static int const kAlignmentA = AlignmentA;
[562](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#aa39221ab9fa4248c613b7222f764072e)static int const kAlignmentB = AlignmentB;
[563](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a0b609010f97cb53cf4d8f1ecb4bb0b79)static bool const kSplitKSerial = SplitKSerial;
[564](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a98d1d07f32f29b29e883775fcd276833)static bool const kIsBetaZero = IsBetaZero;
565
566using UnderlyingOperator = Gemm<
567ElementB,
568typename layout::LayoutTranspose<LayoutB>::type,
569ElementA,
570typename layout::LayoutTranspose<LayoutA>::type,
571ElementC,
572layout::RowMajor,
574OperatorClass,
575ArchTag,
576ThreadblockShape,
577WarpShape,
578InstructionShape,
579EpilogueOutputOp,
581 Stages,
582kAlignmentB,
583kAlignmentA,
584 SplitKSerial,
585Operator,
586 kIsBetaZero
[587](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#aa48a0b3645f2ef103cb1b2d41218d865) >;
588
[589](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7615ad046304360243729c29c65e878a)using [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7615ad046304360243729c29c65e878a) = typename UnderlyingOperator::Arguments;
[590](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a2bdbc5e737f9bfd1e09a7cfb30e60e29)using [GemmKernel](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a2bdbc5e737f9bfd1e09a7cfb30e60e29) = typename UnderlyingOperator::GemmKernel;
[591](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#ad74a049e26f4b9224362b4d1c93ca14b)static int const kAlignmentC = UnderlyingOperator::kAlignmentC;
592
[594](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html)struct Arguments {
595
596//
597// Data members
598//
599
[600](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#acd02e86dfff866eade08415e0043ccc3)GemmCoord [problem_size](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#acd02e86dfff866eade08415e0043ccc3);
[601](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a9bdaf3563983efcca649460be169b334)TensorRef<ElementA const, LayoutA> [ref_A](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a9bdaf3563983efcca649460be169b334);
[602](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ab77204c1010b17c6643d26a89f41c3d0)TensorRef<ElementB const, LayoutB> [ref_B](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ab77204c1010b17c6643d26a89f41c3d0);
[603](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a590b8da88ae9350042838451e3e37a22)TensorRef<ElementC const, LayoutC> [ref_C](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a590b8da88ae9350042838451e3e37a22);
[604](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ab1d4d5865786a415f87db1def1b029e7)TensorRef<ElementC, LayoutC> [ref_D](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ab1d4d5865786a415f87db1def1b029e7);
[605](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a426f402c08be99849a4477a07f010a5e)typename EpilogueOutputOp::Params [epilogue](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a426f402c08be99849a4477a07f010a5e);
[606](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#aaef8450711318fa1a53fe3cb72b59263)int [split_k_slices](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#aaef8450711318fa1a53fe3cb72b59263);
607
608//
609// Methods
610//
611
[614](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ac6c397a181a52c0dbb39bf3710ee4658)[Arguments](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ac6c397a181a52c0dbb39bf3710ee4658)() { }
615
[618](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a331de1adfdcbea6d0137afe64a4f6f4c)[Arguments](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a331de1adfdcbea6d0137afe64a4f6f4c)(
619GemmCoord problem_size_,
620TensorRef<ElementA const, LayoutA> ref_A_,
621TensorRef<ElementB const, LayoutB> ref_B_,
622TensorRef<ElementC const, LayoutC> ref_C_,
623TensorRef<ElementC, LayoutC> ref_D_,
624typename EpilogueOutputOp::Params epilogue_ =
625typename EpilogueOutputOp::Params(),
626int split_k_slices = 1
627 ):
628 problem_size(problem_size_),
629 ref_A(ref_A_),
630 ref_B(ref_B_),
631 ref_C(ref_C_),
632 ref_D(ref_D_),
633 epilogue(epilogue_),
634 split_k_slices(split_k_slices) { }
635 };
636
637 private:
638
639UnderlyingOperator underlying_operator_;
640
641 public:
642
[644](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#abcacf502806db50eb17a6d925aee16d5)[Gemm](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#abcacf502806db50eb17a6d925aee16d5)() { }
645
[647](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#aa9313915a6129f0c43b43ef3698b3ee4)static [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7615ad046304360243729c29c65e878a) [to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#aa9313915a6129f0c43b43ef3698b3ee4)(Arguments const &args) {
648return [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7615ad046304360243729c29c65e878a)(
649 {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
650 {args.ref_B.data(), args.ref_B.stride(0)},
651 {args.ref_A.data(), args.ref_A.stride(0)},
652 {args.ref_C.data(), args.ref_C.stride(0)},
653 {args.ref_D.data(), args.ref_D.stride(0)},
654 args.epilogue,
655 args.split_k_slices
656 );
657 }
658
[660](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a662bcbcb6164c803ab490c86e69b9ee1)static Status [can_implement](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a662bcbcb6164c803ab490c86e69b9ee1)(Arguments const &args) {
661
662return UnderlyingOperator::can_implement(to_underlying_arguments(args));
663 }
664
[666](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a1469133c30fde6b28296e3ff6951e7a4)static size_t [get_workspace_size](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a1469133c30fde6b28296e3ff6951e7a4)(Arguments const &args) {
667
668return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
669 }
670
[672](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7a14474e4238d2fac92ad71c6de087d8)Status [initialize](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7a14474e4238d2fac92ad71c6de087d8)(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
673
674return underlying_operator_.initialize(to_underlying_arguments(args), workspace);
675 }
676
[678](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a2b6c5275c173d73cffe8e6b6b1ccf2c1)Status [update](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a2b6c5275c173d73cffe8e6b6b1ccf2c1)(Arguments const &args, void *workspace = nullptr) {
679
680return underlying_operator_.update(to_underlying_arguments(args), workspace);
681 }
682
[684](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a5f4f93ca97b358b4410f3d0b1e0a6387)Status [run](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a5f4f93ca97b358b4410f3d0b1e0a6387)(cudaStream_t stream = nullptr) {
685
686return underlying_operator_.run(stream);
687 }
688
[690](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a384db4125183e504fafc5a946b7ba757)Status [operator()](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a384db4125183e504fafc5a946b7ba757)(cudaStream_t stream = nullptr) {
691return run(stream);
692 }
693
[695](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a6115aa957b3ba8ad9e54b7efeefaacd1)Status [operator()](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA 00_01ElementB 00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a6115aa957b3ba8ad9e54b7efeefaacd1)(
696Arguments const &args,
697void *workspace = nullptr,
698 cudaStream_t stream = nullptr) {
699
700Status status = initialize(args, workspace);
701
702if (status == Status::kSuccess) {
703 status = run(stream);
704 }
705
706return status;
707 }
708 };
709
711
712 } // namespace device
713 } // namespace gemm
714 } // namespace cutlass
715
cutlass::gemm::kernel::DefaultGemm
Definition: default_gemm.h:116
cutlass::gemm::device::Gemm::kStages
static int const kStages
Definition: include/cutlass/gemm/device/gemm.h:238
cutlass::gemm::device::Gemm::Arguments::problem_size
GemmCoord problem_size
Definition: include/cutlass/gemm/device/gemm.h:276
Definition: aligned_buffer.h:35
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::split_k_slices](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#aaef8450711318fa1a53fe3cb72b59263)
int split_k_slices
Definition: include/cutlass/gemm/device/gemm.h:606
cutlass::Status::kErrorInvalidProblem
Specified problem size is not supported by operator.
ElementB ElementA
Definition: include/cutlass/gemm/device/gemm.h:219
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::ref_A](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a9bdaf3563983efcca649460be169b334)
TensorRef< ElementA const, LayoutA > ref_A
Definition: include/cutlass/gemm/device/gemm.h:601
cutlass::gemm::device::Gemm::get_workspace_size
static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: include/cutlass/gemm/device/gemm.h:350
ThreadblockSwizzle ThreadblockSwizzle
Definition: include/cutlass/gemm/device/gemm.h:236
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ac6c397a181a52c0dbb39bf3710ee4658)
CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: include/cutlass/gemm/device/gemm.h:614
cutlass::gemm::device::Gemm::can_implement
static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: include/cutlass/gemm/device/gemm.h:328
cutlass::gemm::device::Gemm::Arguments::Arguments
CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: include/cutlass/gemm/device/gemm.h:290
cutlass::gemm::device::Gemm::Arguments::Arguments
CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1)
Constructs an Arguments structure.
Definition: include/cutlass/gemm/device/gemm.h:296
Definition: include/cutlass/gemm/gemm.h:94
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::get_workspace_size](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a1469133c30fde6b28296e3ff6951e7a4)
static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: include/cutlass/gemm/device/gemm.h:666
Definition: include/cutlass/gemm/device/gemm.h:216
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
cutlass::gemm::device::Gemm::kSplitKSerial
static bool const kSplitKSerial
Definition: include/cutlass/gemm/device/gemm.h:242
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::update](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a2b6c5275c173d73cffe8e6b6b1ccf2c1)
Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: include/cutlass/gemm/device/gemm.h:678
cutlass::gemm::device::Gemm::Arguments::epilogue
EpilogueOutputOp::Params epilogue
Definition: include/cutlass/gemm/device/gemm.h:281
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::epilogue](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a426f402c08be99849a4477a07f010a5e)
EpilogueOutputOp::Params epilogue
Definition: include/cutlass/gemm/device/gemm.h:605
cutlass::gemm::device::Gemm::Arguments::ref_A
TensorRef< ElementA const, LayoutA > ref_A
Definition: include/cutlass/gemm/device/gemm.h:277
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Gemm](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#abcacf502806db50eb17a6d925aee16d5)
Gemm()
Constructs the GEMM.
Definition: include/cutlass/gemm/device/gemm.h:644
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::can_implement](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a662bcbcb6164c803ab490c86e69b9ee1)
static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: include/cutlass/gemm/device/gemm.h:660
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropr...
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::operator()](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a384db4125183e504fafc5a946b7ba757)
Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm.h:690
InstructionShape InstructionShape
Definition: include/cutlass/gemm/device/gemm.h:234
Operator Operator
Definition: include/cutlass/gemm/device/gemm.h:237
cutlass::gemm::device::Gemm::update
Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: include/cutlass/gemm/device/gemm.h:417
cutlass::gemm::device::Gemm::Arguments::split_k_slices
int split_k_slices
Definition: include/cutlass/gemm/device/gemm.h:282
ElementC ElementC
Definition: include/cutlass/gemm/device/gemm.h:225
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::run](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a5f4f93ca97b358b4410f3d0b1e0a6387)
Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm.h:684
cutlass::gemm::device::Gemm::operator()
Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm.h:471
OperatorClass OperatorClass
Definition: include/cutlass/gemm/device/gemm.h:230
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7615ad046304360243729c29c65e878a)
typename UnderlyingOperator::Arguments UnderlyingArguments
Definition: include/cutlass/gemm/device/gemm.h:589
ElementAccumulator ElementAccumulator
Definition: include/cutlass/gemm/device/gemm.h:229
typename kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, kStages, kSplitKSerial, Operator, kIsBetaZero >::GemmKernel GemmKernel
Define the kernel.
Definition: include/cutlass/gemm/device/gemm.h:267
cutlass::gemm::device::Gemm::kAlignmentB
static int const kAlignmentB
Definition: include/cutlass/gemm/device/gemm.h:240
cutlass::layout::LayoutTranspose
Defines transposes of matrix layouts.
Definition: layout/matrix.h:921
cutlass::gemm::device::Gemm::kAlignmentA
static int const kAlignmentA
Definition: include/cutlass/gemm/device/gemm.h:239
ThreadblockShape ThreadblockShape
Definition: include/cutlass/gemm/device/gemm.h:232
cutlass::gemm::device::Gemm::Arguments::ref_C
TensorRef< ElementC const, LayoutC > ref_C
Definition: include/cutlass/gemm/device/gemm.h:279
cutlass::TensorRef< ElementA const, LayoutA >
cutlass::gemm::device::Gemm::Arguments::ref_B
TensorRef< ElementB const, LayoutB > ref_B
Definition: include/cutlass/gemm/device/gemm.h:278
cutlass::Status::kErrorInternal
An error within CUTLASS occurred.
Template for generic CUTLASS kernel.
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::problem_size](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#acd02e86dfff866eade08415e0043ccc3)
GemmCoord problem_size
Definition: include/cutlass/gemm/device/gemm.h:600
cutlass::gemm::device::Gemm::operator()
Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm.h:476
cutlass::gemm::device::Gemm::Gemm
Gemm()
Constructs the GEMM.
Definition: include/cutlass/gemm/device/gemm.h:325
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a331de1adfdcbea6d0137afe64a4f6f4c)
CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1)
Constructs an Arguments structure.
Definition: include/cutlass/gemm/device/gemm.h:618
cutlass::gemm::device::Gemm::LayoutC
LayoutC_ LayoutC
Definition: include/cutlass/gemm/device/gemm.h:226
typename layout::LayoutTranspose< LayoutA >::type LayoutB
Definition: include/cutlass/gemm/device/gemm.h:223
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::ref_C](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#a590b8da88ae9350042838451e3e37a22)
TensorRef< ElementC const, LayoutC > ref_C
Definition: include/cutlass/gemm/device/gemm.h:603
cutlass::gemm::device::Gemm::Arguments
Argument structure.
Definition: include/cutlass/gemm/device/gemm.h:270
[default_gemm_configuration.h](default gemm configuration_8h.html)
Definitions for GEMM structures.
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::initialize](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a7a14474e4238d2fac92ad71c6de087d8)
Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Initializes GEMM state from arguments.
Definition: include/cutlass/gemm/device/gemm.h:672
cutlass::gemm::device::Gemm::run
Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm.h:435
EpilogueOutputOp EpilogueOutputOp
Definition: include/cutlass/gemm/device/gemm.h:235
cutlass::gemm::device::Gemm::initialize
Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Initializes GEMM state from arguments.
Definition: include/cutlass/gemm/device/gemm.h:369
cutlass::Status::kErrorWorkspaceNull
The given workspace is null when it is required to be non-null.
WarpShape WarpShape
Definition: include/cutlass/gemm/device/gemm.h:233
Operation was successful.
cutlass::gemm::device::Gemm::kAlignmentC
static int const kAlignmentC
Definition: include/cutlass/gemm/device/gemm.h:241
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#aa9313915a6129f0c43b43ef3698b3ee4)
static UnderlyingArguments to_underlying_arguments(Arguments const &args)
Helper to construct a transposed equivalent for the underlying GEMM operator.
Definition: include/cutlass/gemm/device/gemm.h:647
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
ElementA ElementB
Definition: include/cutlass/gemm/device/gemm.h:222
Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...
Defines tags for architecture-specific configurations.
cutlass::gemm::device::Gemm::Arguments::ref_D
TensorRef< ElementC, LayoutC > ref_D
Definition: include/cutlass/gemm/device/gemm.h:280
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::ref_D](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ab1d4d5865786a415f87db1def1b029e7)
TensorRef< ElementC, LayoutC > ref_D
Definition: include/cutlass/gemm/device/gemm.h:604
cutlass::gemm::device::Gemm::kIsBetaZero
static bool const kIsBetaZero
Definition: include/cutlass/gemm/device/gemm.h:243
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::GemmKernel](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a2bdbc5e737f9bfd1e09a7cfb30e60e29)
typename UnderlyingOperator::GemmKernel GemmKernel
Definition: include/cutlass/gemm/device/gemm.h:590
ArchTag ArchTag
Definition: include/cutlass/gemm/device/gemm.h:231
Basic include for CUTLASS.
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::operator()](classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html#a6115aa957b3ba8ad9e54b7efeefaacd1)
Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm.h:695
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
[cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::ref_B](structcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA 00_01LayoutA 00_01ElementB___00_01Layou1b211cc9c97c022d8fe10f2dd32c8709.html#ab77204c1010b17c6643d26a89f41c3d0)
TensorRef< ElementB const, LayoutB > ref_B
Definition: include/cutlass/gemm/device/gemm.h:602
typename layout::LayoutTranspose< LayoutB >::type LayoutA
Definition: include/cutlass/gemm/device/gemm.h:220
Generated by 1.8.11