docs/include_2cutlass_2gemm_2device_2gemm__complex_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
include/cutlass/gemm/device/gemm_complex.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/cutlass.h"
32 #include "cutlass/numeric_types.h"
33 #include "cutlass/arch/arch.h"
34 #include "cutlass/device_kernel.h"
35
36 #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
37 #include "cutlass/gemm/kernel/gemm.h"
38
39 #include "cutlass/gemm/kernel/default_gemm_complex.h"
40 #include "[cutlass/gemm/device/default_gemm_configuration.h](default gemm configuration_8h.html)"
41
43
44 namespace cutlass {
45 namespace gemm {
46 namespace device {
47
49
113
116
119
122
125
128
131
134
137
140
143
146
149
152
155
159 template <
161typename ElementA_,
163typename LayoutA_,
165typename ElementB_,
167typename LayoutB_,
169typename ElementC_,
171typename LayoutC_,
173typename ElementAccumulator_ = ElementC_,
175typename OperatorClass_ = arch::OpClassSimt,
177typename ArchTag_ = arch::Sm70,
179typename ThreadblockShape_ = typename DefaultGemmConfiguration<
180 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
181 ElementAccumulator_>::ThreadblockShape,
183typename WarpShape_ = typename DefaultGemmConfiguration<
184 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
185 ElementAccumulator_>::WarpShape,
187typename InstructionShape_ = typename DefaultGemmConfiguration<
188 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
189 ElementAccumulator_>::InstructionShape,
191typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
192 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
193 ElementAccumulator_>::EpilogueOutputOp,
195typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle,
197int Stages =
198 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
199 ElementC_, ElementAccumulator_>::kStages,
201ComplexTransform TransformA = ComplexTransform::kNone,
203ComplexTransform TransformB = ComplexTransform::kNone,
205bool SplitKSerial = false
206 >
207 class GemmComplex {
208public:
209
210using ElementA = ElementA_;
212using TensorRefA = TensorRef<ElementA const, LayoutA>;
213using ElementB = ElementB_;
215using TensorRefB = TensorRef<ElementB const, LayoutB>;
216using ElementC = ElementC_;
218using TensorRefC = TensorRef<ElementC const, LayoutC>;
219using TensorRefD = TensorRef<ElementC, LayoutC>;
220using ElementAccumulator = ElementAccumulator_;
221using OperatorClass = OperatorClass_;
223using ThreadblockShape = ThreadblockShape_;
224using WarpShape = WarpShape_;
225using InstructionShape = InstructionShape_;
226using EpilogueOutputOp = EpilogueOutputOp_;
227using ThreadblockSwizzle = ThreadblockSwizzle_;
228static int const kStages = Stages;
229static ComplexTransform const kTransformA = TransformA;
230static ComplexTransform const kTransformB = TransformB;
231static bool const kSplitKSerial = SplitKSerial;
232
234using GemmKernel = typename kernel::DefaultGemmComplex<
235ElementA,
236LayoutA,
237ElementB,
238LayoutB,
239ElementC,
240LayoutC,
242OperatorClass,
243ArchTag,
244ThreadblockShape,
245WarpShape,
246InstructionShape,
247EpilogueOutputOp,
249kStages,
250kTransformA,
251kTransformB,
252 kSplitKSerial
254
257
258//
259// Data members
260//
261
263TensorRef<ElementA const, LayoutA> ref_A;
264TensorRef<ElementB const, LayoutB> ref_B;
265TensorRef<ElementC const, LayoutC> ref_C;
266TensorRef<ElementC, LayoutC> ref_D;
267typename EpilogueOutputOp::Params epilogue;
268int split_k_slices;
269
270//
271// Methods
272//
273
276Arguments(): problem_size(0, 0, 0), split_k_slices(1) {
277
278 }
279
283GemmCoord problem_size_,
284TensorRef<ElementA const, LayoutA> ref_A_,
285TensorRef<ElementB const, LayoutB> ref_B_,
286TensorRef<ElementC const, LayoutC> ref_C_,
287TensorRef<ElementC, LayoutC> ref_D_,
288typename EpilogueOutputOp::Params epilogue_ =
289typename EpilogueOutputOp::Params(),
290int split_k_slices = 1
291 ):
292 problem_size(problem_size_),
293 ref_A(ref_A_),
294 ref_B(ref_B_),
295 ref_C(ref_C_),
296 ref_D(ref_D_),
297 epilogue(epilogue_),
298 split_k_slices(split_k_slices) {
299
300 }
301 };
302
303 private:
304
306typename GemmKernel::Params params_;
307
308 public:
309
311GemmComplex() { }
312
314static Status can_implement(Arguments const &args) {
315
316if (!kSplitKSerial && args.split_k_slices > 1) {
317return Status::kErrorInvalidProblem;
318 }
319
320return Status::kSuccess;
321 }
322
324static size_t get_workspace_size(Arguments const &args) {
325
326if (kSplitKSerial && args.split_k_slices > 1) {
327
328// Determine grid shape
329 ThreadblockSwizzle threadblock_swizzle;
330
331cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
332 args.problem_size,
333 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
334 args.split_k_slices);
335
336return sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
337 }
338
339return 0;
340 }
341
343Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
344
345// Determine grid shape
346 ThreadblockSwizzle threadblock_swizzle;
347
348cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
349 args.problem_size,
350 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
351 args.split_k_slices);
352
353if (kSplitKSerial) {
354if (args.split_k_slices > 1) {
355if (!workspace) {
356return Status::kErrorWorkspaceNull;
357 }
358
359size_t bytes = get_workspace_size(args);
360
361 cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
362
363if (result != cudaSuccess) {
364return Status::kErrorInternal;
365 }
366 }
367 }
368else {
369
370if (args.split_k_slices > 1) {
371return Status::kErrorInvalidProblem;
372 }
373 }
374
375// Initialize the Params structure
376 params_ = typename GemmKernel::Params{
377 args.problem_size,
378 grid_shape,
379 args.ref_A.non_const_ref(),
380 args.ref_B.non_const_ref(),
381 args.ref_C.non_const_ref(),
382 args.ref_D,
383 args.epilogue,
384static_cast<int *>(workspace)
385 };
386
387return Status::kSuccess;
388 }
389
391Status update(Arguments const &args, void *workspace = nullptr) {
392
393if (kSplitKSerial && args.split_k_slices > 1) {
394if (!workspace) {
395return Status::kErrorWorkspaceNull;
396 }
397 }
398
399 params_.ref_A.reset(args.ref_A.non_const_ref().data());
400 params_.ref_B.reset(args.ref_B.non_const_ref().data());
401 params_.ref_C.reset(args.ref_C.non_const_ref().data());
402 params_.ref_D.reset(args.ref_D.data());
403 params_.semaphore = static_cast<int *>(workspace);
404
405return Status::kSuccess;
406 }
407
409Status run(cudaStream_t stream = nullptr) {
410
411 ThreadblockSwizzle threadblock_swizzle;
412
413 dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
414 dim3 block(GemmKernel::kThreadCount, 1, 1);
415
416 cudaError_t result;
417
418int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
419if (smem_size >= (48 << 10)) {
420 result = cudaFuncSetAttribute(Kernel<GemmKernel>,
421 cudaFuncAttributeMaxDynamicSharedMemorySize,
422 smem_size);
423
424if (result != cudaSuccess) {
425return Status::kErrorInternal;
426 }
427
428 result = cudaFuncSetAttribute(
429 Kernel<GemmKernel>,
430 cudaFuncAttributePreferredSharedMemoryCarveout, 100);
431
432if (result != cudaSuccess) {
433return Status::kErrorInternal;
434 }
435 }
436
437 cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
438
439 result = cudaGetLastError();
440
441return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
442 }
443
445Status operator()(cudaStream_t stream = nullptr) {
446return run(stream);
447 }
448
451Arguments const &args,
452void *workspace = nullptr,
453 cudaStream_t stream = nullptr) {
454
455Status status = initialize(args, workspace);
456
457if (status == Status::kSuccess) {
458 status = run(stream);
459 }
460
461return status;
462 }
463 };
464
466
468 template <
470typename ElementA_,
472typename LayoutA_,
474typename ElementB_,
476typename LayoutB_,
478typename ElementC_,
480typename ElementAccumulator_,
482typename OperatorClass_,
484typename ArchTag_,
486typename ThreadblockShape_,
488typename WarpShape_,
490typename InstructionShape_,
492typename EpilogueOutputOp_,
494typename ThreadblockSwizzle_,
496int Stages,
498ComplexTransform TransformA,
500ComplexTransform TransformB,
502bool SplitKSerial
503 >
[504](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html) class GemmComplex<
505 ElementA_,
506 LayoutA_,
507 ElementB_,
508 LayoutB_,
509 ElementC_,
510 layout::ColumnMajor, // partially specialized on LayoutC
511 ElementAccumulator_,
512 OperatorClass_,
513 ArchTag_,
514 ThreadblockShape_,
515 WarpShape_,
516 InstructionShape_,
517 EpilogueOutputOp_,
518 ThreadblockSwizzle_,
519 Stages,
520 TransformA,
521 TransformB,
522 SplitKSerial
523 > {
524 public:
525
[526](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#aa6621903fd434110b57220f2b2fb97cb)using ElementA = ElementA_;
[527](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a9b48f3a933f3b37814f9b70503b7684a)using LayoutA = LayoutA_;
[528](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#ae8995862cf7e42bc086f5941c1aa5d35)using TensorRefA = TensorRef<ElementA const, LayoutA>;
[529](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a3570f3ed978cba7f66d1310ce66a56b3)using ElementB = ElementB_;
[530](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a90e18e93d96cd07f03201134d3c1b5a0)using LayoutB = LayoutB_;
[531](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a3c170badab35f7754939a4cd9d8258fe)using TensorRefB = TensorRef<ElementB const, LayoutB>;
[532](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a77d1c52156347656311764de09456670)using ElementC = ElementC_;
[533](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#af2b903fa011363e7049d5f0807b77731)using LayoutC = layout::ColumnMajor;
[534](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#ae77f478fd7cff440628fb38e230f2609)using TensorRefC = TensorRef<ElementC const, LayoutC>;
[535](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a5b68c920af70250817a7791d91ab77f5)using TensorRefD = TensorRef<ElementC, LayoutC>;
[536](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a233bda6413e491449aa29b3222c60904)using ElementAccumulator = ElementAccumulator_;
[537](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a4d170f269f81dafe07770197c3864a6b)using OperatorClass = OperatorClass_;
[538](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a15200a21650efa7f582747dbbad044ca)using ArchTag = ArchTag_;
[539](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#ae09f224faa27d9735ab77899d36dbc96)using ThreadblockShape = ThreadblockShape_;
[540](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a09a291542ae92f0fb97a4c0a5ee25db4)using WarpShape = WarpShape_;
[541](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#af691b4304896f601c34eaedf78493ed5)using InstructionShape = InstructionShape_;
[542](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#ab2b9c1976d62f70a32d93f55f79a2401)using EpilogueOutputOp = EpilogueOutputOp_;
[543](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a02473fb6e60eed4bf79b510d7096b4c5)using ThreadblockSwizzle = ThreadblockSwizzle_;
[544](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a689afffc991cf4e6aab7d6e4f5fe4d46)static int const kStages = Stages;
[545](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#afe14a91a30bea2204d4351591df7b5cc)static bool const kSplitKSerial = SplitKSerial;
546
547using UnderlyingOperator = GemmComplex<
548ElementB,
549typename layout::LayoutTranspose<LayoutB>::type,
550ElementA,
551typename layout::LayoutTranspose<LayoutA>::type,
552ElementC,
553layout::RowMajor,
555OperatorClass,
556ArchTag,
557ThreadblockShape,
558WarpShape,
559InstructionShape,
560EpilogueOutputOp,
562 Stages,
563 TransformA,
564 TransformB,
565 SplitKSerial
[566](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#aa55140ff232b12c3a4bf1e5093282354) >;
567
[568](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a171a1b3ddd40fb9f318e51fa28029f4b)using [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a171a1b3ddd40fb9f318e51fa28029f4b) = typename UnderlyingOperator::Arguments;
[569](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#abe65836275404d572a7e1e2108c72982)using [GemmKernel](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#abe65836275404d572a7e1e2108c72982) = typename UnderlyingOperator::GemmKernel;
570
[572](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html)struct Arguments {
573
574//
575// Data members
576//
577
[578](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a29159f430d4a733ec3fac550d0458e18)GemmCoord [problem_size](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#a29159f430d4a733ec3fac550d0458e18);
[579](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#ac8e9298e3786e9391d740faa4d0566f2)TensorRef<ElementA const, LayoutA> [ref_A](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#ac8e9298e3786e9391d740faa4d0566f2);
[580](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#ab706387c660af35ae2b9579165eec85d)TensorRef<ElementB const, LayoutB> [ref_B](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#ab706387c660af35ae2b9579165eec85d);
[581](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a3a59aa793429bc57d796b40fa4fab622)TensorRef<ElementC const, LayoutC> [ref_C](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#a3a59aa793429bc57d796b40fa4fab622);
[582](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a2904e3ad7a47b3d85ea60d94eeebe84b)TensorRef<ElementC, LayoutC> [ref_D](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#a2904e3ad7a47b3d85ea60d94eeebe84b);
[583](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#ad1f435bf8b7003afad9b803adf9fcb89)typename EpilogueOutputOp::Params [epilogue](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#ad1f435bf8b7003afad9b803adf9fcb89);
[584](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#aec118721190212e7e61c7d17d4c93d1c)int [split_k_slices](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#aec118721190212e7e61c7d17d4c93d1c);
585
586//
587// Methods
588//
589
[592](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a710950fddfc99fc79302cbfe959bb201)[Arguments](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#a710950fddfc99fc79302cbfe959bb201)() { }
593
[596](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a8886db2fcca9a63381861662d318ad12)[Arguments](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_a3923967cafb5cb9774c320dc24baa77.html#a8886db2fcca9a63381861662d318ad12)(
597GemmCoord problem_size_,
598TensorRef<ElementA const, LayoutA> ref_A_,
599TensorRef<ElementB const, LayoutB> ref_B_,
600TensorRef<ElementC const, LayoutC> ref_C_,
601TensorRef<ElementC, LayoutC> ref_D_,
602typename EpilogueOutputOp::Params epilogue_ =
603typename EpilogueOutputOp::Params(),
604int split_k_slices = 1
605 ):
606 problem_size(problem_size_),
607 ref_A(ref_A_),
608 ref_B(ref_B_),
609 ref_C(ref_C_),
610 ref_D(ref_D_),
611 epilogue(epilogue_),
612 split_k_slices(split_k_slices) { }
613 };
614
615 private:
616
617UnderlyingOperator underlying_operator_;
618
619 public:
620
[622](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a9ce748bfc112dd4bb942c5e7c95845df)[GemmComplex](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a9ce748bfc112dd4bb942c5e7c95845df)() { }
623
[625](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a3dd09eeeae6c4faeddc4abc8bb57b177)static [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a171a1b3ddd40fb9f318e51fa28029f4b) [to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a3dd09eeeae6c4faeddc4abc8bb57b177)(Arguments const &args) {
626return [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a171a1b3ddd40fb9f318e51fa28029f4b)(
627 {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
628 {args.ref_B.data(), args.ref_B.stride(0)},
629 {args.ref_A.data(), args.ref_A.stride(0)},
630 {args.ref_C.data(), args.ref_C.stride(0)},
631 {args.ref_D.data(), args.ref_D.stride(0)},
632 args.epilogue,
633 args.split_k_slices
634 );
635 }
636
[638](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#adb94d2e6dd70b46bea6b5b433e14fea9)static Status [can_implement](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#adb94d2e6dd70b46bea6b5b433e14fea9)(Arguments const &args) {
639
640return UnderlyingOperator::can_implement(to_underlying_arguments(args));
641 }
642
[644](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a75342fc4122c07d1382b31ee5f188210)static size_t [get_workspace_size](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a75342fc4122c07d1382b31ee5f188210)(Arguments const &args) {
645
646return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
647 }
648
[650](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a5c3286631f254746c9eb788b780cdca3)Status [initialize](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a5c3286631f254746c9eb788b780cdca3)(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
651
652return underlying_operator_.initialize(to_underlying_arguments(args), workspace);
653 }
654
[656](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a08446a157a60f7f1e23315c1ece09bce)Status [update](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a08446a157a60f7f1e23315c1ece09bce)(Arguments const &args, void *workspace = nullptr) {
657
658return underlying_operator_.update(to_underlying_arguments(args), workspace);
659 }
660
[662](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a4111bba1e9d2000fcc9bba2f114ee801)Status [run](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a4111bba1e9d2000fcc9bba2f114ee801)(cudaStream_t stream = nullptr) {
663
664return underlying_operator_.run(stream);
665 }
666
[668](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a375220f643161478c1fb5bcd24f8b5cd)Status [operator()](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a375220f643161478c1fb5bcd24f8b5cd)(cudaStream_t stream = nullptr) {
669return run(stream);
670 }
671
[673](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a50ff89a3c0b3735b669cf4e3b755918a)Status [operator()](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA 00_01ElementB 00_07c56401b4df75709ae636675d9980a9a.html#a50ff89a3c0b3735b669cf4e3b755918a)(
674Arguments const &args,
675void *workspace = nullptr,
676 cudaStream_t stream = nullptr) {
677
678Status status = initialize(args, workspace);
679
680if (status == Status::kSuccess) {
681 status = run(stream);
682 }
683
684return status;
685 }
686 };
687
689
690 } // namespace device
691 } // namespace gemm
692 } // namespace cutlass
693
cutlass::gemm::device::GemmComplex::operator()
Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm_complex.h:450
cutlass::gemm::device::GemmComplex::kTransformA
static ComplexTransform const kTransformA
Definition: include/cutlass/gemm/device/gemm_complex.h:229
cutlass::gemm::device::GemmComplex::Arguments::ref_A
TensorRef< ElementA const, LayoutA > ref_A
Definition: include/cutlass/gemm/device/gemm_complex.h:263
cutlass::gemm::device::GemmComplex
Definition: include/cutlass/gemm/device/gemm_complex.h:207
Definition: aligned_buffer.h:35
ComplexTransform
Enumerated type describing a transformation on a complex value.
Definition: complex.h:43
ElementA ElementB
Definition: include/cutlass/gemm/device/gemm_complex.h:213
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::ref_C](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a3a59aa793429bc57d796b40fa4fab622)
TensorRef< ElementC const, LayoutC > ref_C
Definition: include/cutlass/gemm/device/gemm_complex.h:581
cutlass::gemm::device::GemmComplex::run
Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm_complex.h:409
cutlass::Status::kErrorInvalidProblem
Specified problem size is not supported by operator.
cutlass::gemm::device::GemmComplex::Arguments::problem_size
GemmCoord problem_size
Definition: include/cutlass/gemm/device/gemm_complex.h:262
cutlass::gemm::device::GemmComplex::can_implement
static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: include/cutlass/gemm/device/gemm_complex.h:314
EpilogueOutputOp EpilogueOutputOp
Definition: include/cutlass/gemm/device/gemm_complex.h:226
cutlass::gemm::device::GemmComplex::kStages
static int const kStages
Definition: include/cutlass/gemm/device/gemm_complex.h:228
ArchTag ArchTag
Definition: include/cutlass/gemm/device/gemm_complex.h:222
cutlass::gemm::device::GemmComplex::Arguments::ref_B
TensorRef< ElementB const, LayoutB > ref_B
Definition: include/cutlass/gemm/device/gemm_complex.h:264
Definition: include/cutlass/gemm/gemm.h:94
cutlass::gemm::device::GemmComplex::Arguments::epilogue
EpilogueOutputOp::Params epilogue
Definition: include/cutlass/gemm/device/gemm_complex.h:267
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::operator()](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a50ff89a3c0b3735b669cf4e3b755918a)
Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm_complex.h:673
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
cutlass::gemm::device::GemmComplex::initialize
Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Initializes GEMM state from arguments.
Definition: include/cutlass/gemm/device/gemm_complex.h:343
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::split_k_slices](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#aec118721190212e7e61c7d17d4c93d1c)
int split_k_slices
Definition: include/cutlass/gemm/device/gemm_complex.h:584
cutlass::ComplexTransform::kNone
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a3dd09eeeae6c4faeddc4abc8bb57b177)
static UnderlyingArguments to_underlying_arguments(Arguments const &args)
Helper to construct a transposed equivalent for the underlying GEMM operator.
Definition: include/cutlass/gemm/device/gemm_complex.h:625
typename layout::LayoutTranspose< LayoutB >::type LayoutA
Definition: include/cutlass/gemm/device/gemm_complex.h:211
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
cutlass::gemm::device::GemmComplex::Arguments
Argument structure.
Definition: include/cutlass/gemm/device/gemm_complex.h:256
cutlass::gemm::device::GemmComplex::get_workspace_size
static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: include/cutlass/gemm/device/gemm_complex.h:324
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::initialize](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a5c3286631f254746c9eb788b780cdca3)
Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Initializes GEMM state from arguments.
Definition: include/cutlass/gemm/device/gemm_complex.h:650
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::run](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a4111bba1e9d2000fcc9bba2f114ee801)
Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm_complex.h:662
cutlass::gemm::device::GemmComplex::LayoutC
LayoutC_ LayoutC
Definition: include/cutlass/gemm/device/gemm_complex.h:217
cutlass::gemm::device::GemmComplex::Arguments::Arguments
CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: include/cutlass/gemm/device/gemm_complex.h:276
cutlass::layout::LayoutTranspose
Defines transposes of matrix layouts.
Definition: layout/matrix.h:921
ElementC ElementC
Definition: include/cutlass/gemm/device/gemm_complex.h:216
cutlass::TensorRef< ElementA const, LayoutA >
ElementAccumulator ElementAccumulator
Definition: include/cutlass/gemm/device/gemm_complex.h:220
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::get_workspace_size](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a75342fc4122c07d1382b31ee5f188210)
static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: include/cutlass/gemm/device/gemm_complex.h:644
cutlass::Status::kErrorInternal
An error within CUTLASS occurred.
cutlass::gemm::device::GemmComplex::operator()
Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm_complex.h:445
Template for generic CUTLASS kernel.
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::GemmComplex](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a9ce748bfc112dd4bb942c5e7c95845df)
GemmComplex()
Constructs the GEMM.
Definition: include/cutlass/gemm/device/gemm_complex.h:622
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::operator()](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a375220f643161478c1fb5bcd24f8b5cd)
Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm_complex.h:668
ThreadblockSwizzle ThreadblockSwizzle
Definition: include/cutlass/gemm/device/gemm_complex.h:227
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
cutlass::gemm::device::GemmComplex::GemmComplex
GemmComplex()
Constructs the GEMM.
Definition: include/cutlass/gemm/device/gemm_complex.h:311
[default_gemm_configuration.h](default__gemm__configuration_8h.html)
Definitions for GEMM structures.
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::GemmKernel](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#abe65836275404d572a7e1e2108c72982)
typename UnderlyingOperator::GemmKernel GemmKernel
Definition: include/cutlass/gemm/device/gemm_complex.h:569
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::problem_size](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a29159f430d4a733ec3fac550d0458e18)
GemmCoord problem_size
Definition: include/cutlass/gemm/device/gemm_complex.h:578
typename layout::LayoutTranspose< LayoutA >::type LayoutB
Definition: include/cutlass/gemm/device/gemm_complex.h:214
cutlass::gemm::device::GemmComplex::Arguments::ref_D
TensorRef< ElementC, LayoutC > ref_D
Definition: include/cutlass/gemm/device/gemm_complex.h:266
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
cutlass::gemm::device::GemmComplex::Arguments::ref_C
TensorRef< ElementC const, LayoutC > ref_C
Definition: include/cutlass/gemm/device/gemm_complex.h:265
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a171a1b3ddd40fb9f318e51fa28029f4b)
typename UnderlyingOperator::Arguments UnderlyingArguments
Definition: include/cutlass/gemm/device/gemm_complex.h:568
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::ref_B](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#ab706387c660af35ae2b9579165eec85d)
TensorRef< ElementB const, LayoutB > ref_B
Definition: include/cutlass/gemm/device/gemm_complex.h:580
WarpShape WarpShape
Definition: include/cutlass/gemm/device/gemm_complex.h:224
cutlass::gemm::device::GemmComplex::update
Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: include/cutlass/gemm/device/gemm_complex.h:391
cutlass::gemm::device::GemmComplex::kSplitKSerial
static bool const kSplitKSerial
Definition: include/cutlass/gemm/device/gemm_complex.h:231
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::update](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#a08446a157a60f7f1e23315c1ece09bce)
Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: include/cutlass/gemm/device/gemm_complex.h:656
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a710950fddfc99fc79302cbfe959bb201)
CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: include/cutlass/gemm/device/gemm_complex.h:592
cutlass::Status::kErrorWorkspaceNull
The given workspace is null when it is required to be non-null.
Operation was successful.
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::can_implement](classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html#adb94d2e6dd70b46bea6b5b433e14fea9)
static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: include/cutlass/gemm/device/gemm_complex.h:638
Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...
Defines tags for architecture-specific configurations.
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::ref_D](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a2904e3ad7a47b3d85ea60d94eeebe84b)
TensorRef< ElementC, LayoutC > ref_D
Definition: include/cutlass/gemm/device/gemm_complex.h:582
cutlass::gemm::device::GemmComplex::Arguments::split_k_slices
int split_k_slices
Definition: include/cutlass/gemm/device/gemm_complex.h:268
InstructionShape InstructionShape
Definition: include/cutlass/gemm/device/gemm_complex.h:225
OperatorClass OperatorClass
Definition: include/cutlass/gemm/device/gemm_complex.h:221
cutlass::gemm::device::GemmComplex::Arguments::Arguments
CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1)
Constructs an Arguments structure.
Definition: include/cutlass/gemm/device/gemm_complex.h:282
typename kernel::DefaultGemmComplex< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, kStages, kTransformA, kTransformB, kSplitKSerial >::GemmKernel GemmKernel
Define the kernel.
Definition: include/cutlass/gemm/device/gemm_complex.h:253
ThreadblockShape ThreadblockShape
Definition: include/cutlass/gemm/device/gemm_complex.h:223
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::epilogue](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#ad1f435bf8b7003afad9b803adf9fcb89)
EpilogueOutputOp::Params epilogue
Definition: include/cutlass/gemm/device/gemm_complex.h:583
ElementB ElementA
Definition: include/cutlass/gemm/device/gemm_complex.h:210
cutlass::gemm::device::GemmComplex::kTransformB
static ComplexTransform const kTransformB
Definition: include/cutlass/gemm/device/gemm_complex.h:230
Basic include for CUTLASS.
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#a8886db2fcca9a63381861662d318ad12)
CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1)
Constructs an Arguments structure.
Definition: include/cutlass/gemm/device/gemm_complex.h:596
[cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::ref_A](structcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA 00_01LayoutA 00_01ElementB___00_a3923967cafb5cb9774c320dc24baa77.html#ac8e9298e3786e9391d740faa4d0566f2)
TensorRef< ElementA const, LayoutA > ref_A
Definition: include/cutlass/gemm/device/gemm_complex.h:579
Generated by 1.8.11