docs/device_2gemm__batched_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
device/gemm_batched.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/cutlass.h"
32 #include "cutlass/numeric_types.h"
33 #include "cutlass/arch/arch.h"
34 #include "cutlass/device_kernel.h"
35
36 #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
37 #include "cutlass/gemm/kernel/gemm_batched.h"
38
39 #include "cutlass/gemm/kernel/default_gemm.h"
 40 #include "cutlass/gemm/device/default_gemm_configuration.h"
41
43
44 namespace cutlass {
45 namespace gemm {
46 namespace device {
47
49
113
116
119
122
125
128
131
134
137
140
143
146
149
152
155
159 template <
161typename ElementA_,
163typename LayoutA_,
165typename ElementB_,
167typename LayoutB_,
169typename ElementC_,
171typename LayoutC_,
173typename ElementAccumulator_ = ElementC_,
175typename OperatorClass_ = arch::OpClassSimt,
177typename ArchTag_ = arch::Sm70,
179typename ThreadblockShape_ = typename DefaultGemmConfiguration<
180 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
181 ElementAccumulator_>::ThreadblockShape,
183typename WarpShape_ = typename DefaultGemmConfiguration<
184 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
185 ElementAccumulator_>::WarpShape,
187typename InstructionShape_ = typename DefaultGemmConfiguration<
188 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
189 ElementAccumulator_>::InstructionShape,
191typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
192 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
193 ElementAccumulator_>::EpilogueOutputOp,
195typename ThreadblockSwizzle_ = threadblock::GemmBatchedIdentityThreadblockSwizzle,
197int Stages =
198 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
199 ElementC_, ElementAccumulator_>::kStages,
201int AlignmentA =
202 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
203 ElementC_, ElementAccumulator_>::kAlignmentA,
205int AlignmentB =
206 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
207 ElementC_, ElementAccumulator_>::kAlignmentB,
209typename Operator_ = typename DefaultGemmConfiguration<
210 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
211 ElementAccumulator_>::Operator
212 >
213 class GemmBatched {
214public:
215
 216using ElementA = ElementA_;
 217using LayoutA = LayoutA_;
218using TensorRefA = TensorRef<ElementA const, LayoutA>;
 219using ElementB = ElementB_;
 220using LayoutB = LayoutB_;
221using TensorRefB = TensorRef<ElementB const, LayoutB>;
 222using ElementC = ElementC_;
 223using LayoutC = LayoutC_;
224using TensorRefC = TensorRef<ElementC const, LayoutC>;
225using TensorRefD = TensorRef<ElementC, LayoutC>;
226using ElementAccumulator = ElementAccumulator_;
 227using OperatorClass = OperatorClass_;
 228using ArchTag = ArchTag_;
229using ThreadblockShape = ThreadblockShape_;
230using WarpShape = WarpShape_;
231using InstructionShape = InstructionShape_;
232using EpilogueOutputOp = EpilogueOutputOp_;
233using ThreadblockSwizzle = ThreadblockSwizzle_;
234static int const kStages = Stages;
235static int const kAlignmentA = AlignmentA;
236static int const kAlignmentB = AlignmentB;
237static int const kAlignmentC = EpilogueOutputOp::kCount;
238using Operator = Operator_;
239
241using DefaultGemmKernel = typename kernel::DefaultGemm<
242ElementA,
243LayoutA,
244kAlignmentA,
245ElementB,
246LayoutB,
247kAlignmentB,
248ElementC,
 249LayoutC,
 250ElementAccumulator,
251OperatorClass,
252ArchTag,
253ThreadblockShape,
254WarpShape,
255InstructionShape,
 256EpilogueOutputOp,
 257ThreadblockSwizzle,
258kStages,
259false,
260Operator,
 261false
 262 >::GemmKernel;
 263
264using GemmKernel = kernel::GemmBatched<typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue, ThreadblockSwizzle>;
265
268
269//
270// Data members
271//
272
274TensorRef<ElementA const, LayoutA> ref_A;
276TensorRef<ElementB const, LayoutB> ref_B;
278TensorRef<ElementC const, LayoutC> ref_C;
280TensorRef<ElementC, LayoutC> ref_D;
282typename EpilogueOutputOp::Params epilogue;
283int batch_count;
284
285//
286// Methods
287//
288
292
296GemmCoord problem_size_,
297TensorRef<ElementA const, LayoutA> ref_A_,
298 int64_t stride_A_,
299TensorRef<ElementB const, LayoutB> ref_B_,
300 int64_t stride_B_,
301TensorRef<ElementC const, LayoutC> ref_C_,
302 int64_t stride_C_,
303TensorRef<ElementC, LayoutC> ref_D_,
304 int64_t stride_D_,
305typename EpilogueOutputOp::Params epilogue_,
306int batch_count_
307 ):
308 problem_size(problem_size_),
309 ref_A(ref_A_),
310 stride_A(stride_A_),
311 ref_B(ref_B_),
312 stride_B(stride_B_),
313 ref_C(ref_C_),
314 stride_C(stride_C_),
315 ref_D(ref_D_),
316 stride_D(stride_D_),
317 epilogue(epilogue_),
318 batch_count(batch_count_) { }
319 };
320
321 private:
322
324typename GemmKernel::Params params_;
325
326 public:
327
329GemmBatched() { }
330
332static Status can_implement(Arguments const &args) {
333
334if (! TensorRef_aligned(args.ref_A, kAlignmentA) || (args.stride_A % kAlignmentA)) {
335return Status::kErrorMisalignedOperand;
336 }
337
338if (! TensorRef_aligned(args.ref_B, kAlignmentB) || (args.stride_B % kAlignmentB)) {
339return Status::kErrorMisalignedOperand;
340 }
341
342if (! TensorRef_aligned(args.ref_C, kAlignmentC) || (args.stride_C % kAlignmentC)) {
343return Status::kErrorMisalignedOperand;
344 }
345
346if (! TensorRef_aligned(args.ref_D, kAlignmentC) || (args.stride_D % kAlignmentC)) {
347return Status::kErrorMisalignedOperand;
348 }
349
350if ((args.problem_size.m() % kAlignmentA) || (args.problem_size.k() % kAlignmentA) ||
351 (args.problem_size.n() % kAlignmentB) || (args.problem_size.k() % kAlignmentB) ||
352 (args.problem_size.m() % kAlignmentC) || (args.problem_size.n() % kAlignmentC)) {
353
354return Status::kErrorMisalignedOperand;
355 }
356
357return Status::kSuccess;
358 }
359
361static size_t get_workspace_size(Arguments const &args) {
362return 0;
363 }
364
366Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
367
368// Determine grid shape
369 ThreadblockSwizzle threadblock_swizzle;
370
371cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
372 args.problem_size,
373 args.batch_count,
374 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK});
375
376// Initialize the Params structure
377 params_ = typename GemmKernel::Params{
378 args.problem_size,
379 grid_shape,
380 args.ref_A.non_const_ref(),
381 args.stride_A,
382 args.ref_B.non_const_ref(),
383 args.stride_B,
384 args.ref_C.non_const_ref(),
385 args.stride_C,
386 args.ref_D,
387 args.stride_D,
388 args.epilogue,
389 args.batch_count
390 };
391
392return Status::kSuccess;
393 }
394
396Status update(Arguments const &args, void *workspace = nullptr) {
397
398 params_.ref_A.reset(args.ref_A.non_const_ref().data());
399 params_.ref_B.reset(args.ref_B.non_const_ref().data());
400 params_.ref_C.reset(args.ref_C.non_const_ref().data());
401 params_.ref_D.reset(args.ref_D.data());
402
403return Status::kSuccess;
404 }
405
407Status run(cudaStream_t stream = nullptr) {
408
409 ThreadblockSwizzle threadblock_swizzle;
410
411 dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
412 dim3 block(GemmKernel::kThreadCount, 1, 1);
413
414 cudaError_t result;
415
416int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
417if (smem_size >= (48 << 10)) {
418 result = cudaFuncSetAttribute(Kernel<GemmKernel>,
419 cudaFuncAttributeMaxDynamicSharedMemorySize,
420 smem_size);
421
422if (result != cudaSuccess) {
423return Status::kErrorInternal;
424 }
425
426 result = cudaFuncSetAttribute(
427 Kernel<GemmKernel>,
428 cudaFuncAttributePreferredSharedMemoryCarveout, 100);
429
430if (result != cudaSuccess) {
431return Status::kErrorInternal;
432 }
433 }
434
435 cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
436
437 result = cudaGetLastError();
438
439return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
440 }
441
443Status operator()(cudaStream_t stream = nullptr) {
444return run(stream);
445 }
446
449Arguments const &args,
450void *workspace = nullptr,
451 cudaStream_t stream = nullptr) {
452
453Status status = initialize(args, workspace);
454
455if (status == Status::kSuccess) {
456 status = run(stream);
457 }
458
459return status;
460 }
461 };
462
464
466 template <
468typename ElementA_,
470typename LayoutA_,
472typename ElementB_,
474typename LayoutB_,
476typename ElementC_,
478typename ElementAccumulator_,
480typename OperatorClass_,
482typename ArchTag_,
484typename ThreadblockShape_,
486typename WarpShape_,
488typename InstructionShape_,
490typename EpilogueOutputOp_,
492typename ThreadblockSwizzle_,
494int Stages,
496int AlignmentA,
498int AlignmentB,
499typename Operator_
500 >
 501 class GemmBatched<
502 ElementA_,
503 LayoutA_,
504 ElementB_,
505 LayoutB_,
506 ElementC_,
507 layout::ColumnMajor,
508 ElementAccumulator_,
509 OperatorClass_,
510 ArchTag_,
511 ThreadblockShape_,
512 WarpShape_,
513 InstructionShape_,
514 EpilogueOutputOp_,
515 ThreadblockSwizzle_,
516 Stages,
517 AlignmentA,
518 AlignmentB,
519 Operator_
520 > {
521 public:
522
 523using ElementA = ElementA_;
 524using LayoutA = LayoutA_;
 525using TensorRefA = TensorRef<ElementA const, LayoutA>;
 526using ElementB = ElementB_;
 527using LayoutB = LayoutB_;
 528using TensorRefB = TensorRef<ElementB const, LayoutB>;
 529using ElementC = ElementC_;
 530using LayoutC = layout::ColumnMajor;
 531using TensorRefC = TensorRef<ElementC const, LayoutC>;
 532using TensorRefD = TensorRef<ElementC, LayoutC>;
 533using ElementAccumulator = ElementAccumulator_;
 534using OperatorClass = OperatorClass_;
 535using ArchTag = ArchTag_;
 536using ThreadblockShape = ThreadblockShape_;
 537using WarpShape = WarpShape_;
 538using InstructionShape = InstructionShape_;
 539using EpilogueOutputOp = EpilogueOutputOp_;
 540using ThreadblockSwizzle = ThreadblockSwizzle_;
 541static int const kStages = Stages;
 542
 543static int const kAlignmentA = AlignmentA;
 544static int const kAlignmentB = AlignmentB;
 545static int const kAlignmentC = EpilogueOutputOp::kCount;
 546static bool const kSplitKSerial = false;
547
548//
549using UnderlyingOperator = GemmBatched<
550ElementB,
551typename layout::LayoutTranspose<LayoutB>::type,
552ElementA,
553typename layout::LayoutTranspose<LayoutA>::type,
554ElementC,
 555layout::RowMajor,
 556ElementAccumulator,
557OperatorClass,
558ArchTag,
559ThreadblockShape,
560WarpShape,
561InstructionShape,
 562EpilogueOutputOp,
 563ThreadblockSwizzle,
564 Stages,
565kAlignmentB,
566 kAlignmentA
 567 >;
568
 569using UnderlyingArguments = typename UnderlyingOperator::Arguments;
 570using GemmKernel = typename UnderlyingOperator::GemmKernel;
571
[573](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html)struct Arguments {
574
575//
576// Data members
577//
578
[579](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ad0469cc3e961d21e212d026bccf6fe1a)GemmCoord [problem_size](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#ad0469cc3e961d21e212d026bccf6fe1a);
[580](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a1727630fc0525724df28a75ccf2580b9)TensorRef<ElementA const, LayoutA> [ref_A](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#a1727630fc0525724df28a75ccf2580b9);
[581](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ac8830c9ed0e0a8bd7aa2aa4382550a2f) int64_t [stride_A](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#ac8830c9ed0e0a8bd7aa2aa4382550a2f);
[582](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ad7d2b82b83d7503b9f920ce3bdcdffa5)TensorRef<ElementB const, LayoutB> [ref_B](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#ad7d2b82b83d7503b9f920ce3bdcdffa5);
[583](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a302101a4e5c00c843b3c525ddb94c117) int64_t [stride_B](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#a302101a4e5c00c843b3c525ddb94c117);
[584](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#aa9e30e41627595590421d8b53941b2b2)TensorRef<ElementC const, LayoutC> [ref_C](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#aa9e30e41627595590421d8b53941b2b2);
[585](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a9f8a044d7b7439192dfe2bf488558ed3) int64_t [stride_C](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#a9f8a044d7b7439192dfe2bf488558ed3);
[586](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a17c4e381e91229a8ef15b18ee5ec073d)TensorRef<ElementC, LayoutC> [ref_D](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#a17c4e381e91229a8ef15b18ee5ec073d);
[587](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ac181dba327e605b6cde9de5c7f176e7c) int64_t [stride_D](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#ac181dba327e605b6cde9de5c7f176e7c);
[588](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#af9c2fa1e0cc0456197c2cc0840c89982)typename EpilogueOutputOp::Params [epilogue](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#af9c2fa1e0cc0456197c2cc0840c89982);
[589](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#adb66f3083f56c15578b139b7935452b5)int [batch_count](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#adb66f3083f56c15578b139b7935452b5);
590
591//
592// Methods
593//
594
[597](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ae86daa985279c77e57e682b64a68d330)[Arguments](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#ae86daa985279c77e57e682b64a68d330)() { }
598
[601](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a2129a4dccbd73f8c0f26b08ce5a5cb28)[Arguments](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_213d78696663f4231cd52c6a277c60e5.html#a2129a4dccbd73f8c0f26b08ce5a5cb28)(
602GemmCoord problem_size_,
603TensorRef<ElementA const, LayoutA> ref_A_,
604 int64_t stride_A_,
605TensorRef<ElementB const, LayoutB> ref_B_,
606 int64_t stride_B_,
607TensorRef<ElementC const, LayoutC> ref_C_,
608 int64_t stride_C_,
609TensorRef<ElementC, LayoutC> ref_D_,
610 int64_t stride_D_,
611typename EpilogueOutputOp::Params epilogue_,
612int batch_count_
613 ):
614 problem_size(problem_size_),
615 ref_A(ref_A_),
616 stride_A(stride_A_),
617 ref_B(ref_B_),
618 stride_B(stride_B_),
619 ref_C(ref_C_),
620 stride_C(stride_C_),
621 ref_D(ref_D_),
622 stride_D(stride_D_),
623 epilogue(epilogue_),
624 batch_count(batch_count_) { }
625 };
626
627 private:
628
629UnderlyingOperator underlying_operator_;
630
631 public:
632
 634GemmBatched() { }
635
[637](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ac4ef1ac1e0876aaee5bff50dc09fe8a9)static [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ab5f57fac13e42a08d351ac48c2cc9992) [to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ac4ef1ac1e0876aaee5bff50dc09fe8a9)(Arguments const &args) {
638return [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ab5f57fac13e42a08d351ac48c2cc9992)(
639 {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
640 {args.ref_B.data(), args.ref_B.stride(0)},
641 args.stride_B,
642 {args.ref_A.data(), args.ref_A.stride(0)},
643 args.stride_A,
644 {args.ref_C.data(), args.ref_C.stride(0)},
645 args.stride_C,
646 {args.ref_D.data(), args.ref_D.stride(0)},
647 args.stride_D,
648 args.epilogue,
649 args.batch_count
650 );
651 }
652
[654](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#abbd82c0f989a9d07e5e222db96386701)static Status [can_implement](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#abbd82c0f989a9d07e5e222db96386701)(Arguments const &args) {
655
656return UnderlyingOperator::can_implement(to_underlying_arguments(args));
657 }
658
[660](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3687659e826ba7f38bb060ad6020a739)static size_t [get_workspace_size](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3687659e826ba7f38bb060ad6020a739)(Arguments const &args) {
661
662return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
663 }
664
[666](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a428d8b1c4ac36040145a59d8e4cff3d2)Status [initialize](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a428d8b1c4ac36040145a59d8e4cff3d2)(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
667
668return underlying_operator_.initialize(to_underlying_arguments(args), workspace);
669 }
670
[672](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a9f0c7054068175c1891e4820857603c3)Status [update](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a9f0c7054068175c1891e4820857603c3)(Arguments const &args, void *workspace = nullptr) {
673
674return underlying_operator_.update(to_underlying_arguments(args), workspace);
675 }
676
[678](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#abcae3d15f1ec2ee7ae93690c82fbee8a)Status [run](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#abcae3d15f1ec2ee7ae93690c82fbee8a)(cudaStream_t stream = nullptr) {
679
680return underlying_operator_.run(stream);
681 }
682
[684](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a00805989734182945f982cab23a5dca8)Status [operator()](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a00805989734182945f982cab23a5dca8)(cudaStream_t stream = nullptr) {
685return run(stream);
686 }
687
[689](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a53ca4db66d0d2c96d9036d8eb7c6072b)Status [operator()](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA 00_01ElementB 00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a53ca4db66d0d2c96d9036d8eb7c6072b)(
690Arguments const &args,
691void *workspace = nullptr,
692 cudaStream_t stream = nullptr) {
693
694Status status = initialize(args, workspace);
695
696if (status == Status::kSuccess) {
697 status = run(stream);
698 }
699
700return status;
701 }
702
703 };
704
706
707 } // namespace device
708 } // namespace gemm
709 } // namespace cutlass
710
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ElementC](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#aef19ab5158e41856723852b3e307cc5d)
ElementC_ ElementC
Definition: device/gemm_batched.h:529
cutlass::gemm::kernel::DefaultGemm
Definition: default_gemm.h:116
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ThreadblockSwizzle](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#af8b282788223086b80fbb097b22459ec)
ThreadblockSwizzle_ ThreadblockSwizzle
Definition: device/gemm_batched.h:540
cutlass::gemm::device::GemmBatched::kAlignmentB
static int const kAlignmentB
Definition: device/gemm_batched.h:236
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::ref_A](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a1727630fc0525724df28a75ccf2580b9)
TensorRef< ElementA const, LayoutA > ref_A
Definition: device/gemm_batched.h:580
cutlass::gemm::device::GemmBatched::Arguments::ref_D
TensorRef< ElementC, LayoutC > ref_D
Definition: device/gemm_batched.h:280
cutlass::gemm::device::GemmBatched::Arguments::problem_size
GemmCoord problem_size
Definition: device/gemm_batched.h:273
Definition: aligned_buffer.h:35
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ElementB](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3fd5c64783f88a7533801fef7d1375ad)
ElementB_ ElementB
Definition: device/gemm_batched.h:526
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::operator()](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a53ca4db66d0d2c96d9036d8eb7c6072b)
Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_batched.h:689
cutlass::gemm::device::GemmBatched::Arguments::stride_D
int64_t stride_D
Definition: device/gemm_batched.h:281
cutlass::gemm::kernel::GemmBatched::Params::ref_D
Epilogue::OutputTileIterator::TensorRef ref_D
Definition: kernel/gemm_batched.h:74
cutlass::gemm::device::GemmBatched::Arguments::Arguments
CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, int64_t stride_A_, TensorRef< ElementB const, LayoutB > ref_B_, int64_t stride_B_, TensorRef< ElementC const, LayoutC > ref_C_, int64_t stride_C_, TensorRef< ElementC, LayoutC > ref_D_, int64_t stride_D_, typename EpilogueOutputOp::Params epilogue_, int batch_count_)
Constructs an Arguments structure.
Definition: device/gemm_batched.h:295
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::LayoutB](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a4aaaa6ca0e4b9f983fe37b4105fd058f)
LayoutB_ LayoutB
Definition: device/gemm_batched.h:527
typename kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, kStages, false, Operator, false >::GemmKernel DefaultGemmKernel
Define the kernel.
Definition: device/gemm_batched.h:262
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::can_implement](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#abbd82c0f989a9d07e5e222db96386701)
static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: device/gemm_batched.h:654
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ThreadblockShape](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a657e50fb03ea4d16f7b904920d9aa000)
ThreadblockShape_ ThreadblockShape
Definition: device/gemm_batched.h:536
cutlass::gemm::device::GemmBatched::operator()
Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_batched.h:443
Definition: include/cutlass/gemm/gemm.h:94
cutlass::gemm::device::GemmBatched::Arguments
Argument structure.
Definition: device/gemm_batched.h:267
typename DefaultGemmConfiguration< OperatorClass, ArchTag, ElementB, ElementA, ElementC,ElementAccumulator >::Operator Operator
Definition: device/gemm_batched.h:238
cutlass::gemm::kernel::GemmBatched::Params::ref_B
Mma::IteratorB::TensorRef ref_B
Definition: kernel/gemm_batched.h:68
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::stride_C](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a9f8a044d7b7439192dfe2bf488558ed3)
int64_t stride_C
Definition: device/gemm_batched.h:585
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::OperatorClass](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a37600c0bf3570bc4b21c26b2b64fc54a)
OperatorClass_ OperatorClass
Definition: device/gemm_batched.h:534
typename layout::LayoutTranspose< LayoutA >::type LayoutB
Definition: device/gemm_batched.h:220
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ArchTag](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a681b145a9701109f9d72059bb874895b)
ArchTag_ ArchTag
Definition: device/gemm_batched.h:535
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::GemmKernel](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3947c9b192bec2fad631334f31632353)
typename UnderlyingOperator::GemmKernel GemmKernel
Definition: device/gemm_batched.h:570
cutlass::gemm::device::GemmBatched::get_workspace_size
static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: device/gemm_batched.h:361
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::stride_D](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ac181dba327e605b6cde9de5c7f176e7c)
int64_t stride_D
Definition: device/gemm_batched.h:587
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::WarpShape](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3760f803bd2b31b3fdf47741caa950fa)
WarpShape_ WarpShape
Definition: device/gemm_batched.h:537
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::epilogue](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#af9c2fa1e0cc0456197c2cc0840c89982)
EpilogueOutputOp::Params epilogue
Definition: device/gemm_batched.h:588
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::update](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a9f0c7054068175c1891e4820857603c3)
Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: device/gemm_batched.h:672
cutlass::gemm::device::GemmBatched::Arguments::stride_A
int64_t stride_A
Definition: device/gemm_batched.h:275
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::InstructionShape](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ae073edad6dd4447d7f99c94f4cd0c1c8)
InstructionShape_ InstructionShape
Definition: device/gemm_batched.h:538
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::initialize](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a428d8b1c4ac36040145a59d8e4cff3d2)
Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Initializes GEMM state from arguments.
Definition: device/gemm_batched.h:666
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::run](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#abcae3d15f1ec2ee7ae93690c82fbee8a)
Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_batched.h:678
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ElementAccumulator](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ae7f006ea8bc324d31de9dfbebc1b9327)
ElementAccumulator_ ElementAccumulator
Definition: device/gemm_batched.h:533
cutlass::gemm::kernel::GemmBatched::Params::ref_C
Epilogue::OutputTileIterator::TensorRef ref_C
Definition: kernel/gemm_batched.h:71
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
OperatorClass OperatorClass
Definition: device/gemm_batched.h:227
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::stride_B](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a302101a4e5c00c843b3c525ddb94c117)
int64_t stride_B
Definition: device/gemm_batched.h:583
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropr...
cutlass::gemm::device::GemmBatched::kAlignmentC
static int const kAlignmentC
Definition: device/gemm_batched.h:237
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::ref_D](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a17c4e381e91229a8ef15b18ee5ec073d)
TensorRef< ElementC, LayoutC > ref_D
Definition: device/gemm_batched.h:586
ThreadblockShape ThreadblockShape
Definition: device/gemm_batched.h:229
cutlass::gemm::kernel::GemmBatched::SharedStorage
Shared memory storage structure.
Definition: kernel/gemm_batched.h:124
cutlass::gemm::device::GemmBatched::Arguments::Arguments
CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: device/gemm_batched.h:291
cutlass::gemm::kernel::GemmBatched::Params::grid_tiled_shape
cutlass::gemm::GemmCoord grid_tiled_shape
Definition: kernel/gemm_batched.h:63
InstructionShape InstructionShape
Definition: device/gemm_batched.h:231
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::GemmBatched](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a75922fd7bcd77fbc714cd87681f692bf)
GemmBatched()
Constructs the GEMM.
Definition: device/gemm_batched.h:634
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::problem_size](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ad0469cc3e961d21e212d026bccf6fe1a)
GemmCoord problem_size
Definition: device/gemm_batched.h:579
cutlass::gemm::device::GemmBatched::kStages
static int const kStages
Definition: device/gemm_batched.h:234
cutlass::gemm::device::GemmBatched::update
Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: device/gemm_batched.h:396
cutlass::gemm::kernel::GemmBatched::kThreadCount
static int const kThreadCount
Definition: kernel/gemm_batched.h:58
cutlass::gemm::device::GemmBatched::can_implement
static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: device/gemm_batched.h:332
cutlass::gemm::device::GemmBatched::Arguments::stride_C
int64_t stride_C
Definition: device/gemm_batched.h:279
cutlass::gemm::kernel::GemmBatched::Params
Parameters structure.
Definition: kernel/gemm_batched.h:61
cutlass::layout::LayoutTranspose
Defines transposes of matrix layouts.
Definition: layout/matrix.h:921
cutlass::Status::kErrorMisalignedOperand
operands fail alignment requirements.
cutlass::TensorRef< ElementA const, LayoutA >
ElementC ElementC
Definition: device/gemm_batched.h:222
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ab5f57fac13e42a08d351ac48c2cc9992)
typename UnderlyingOperator::Arguments UnderlyingArguments
Definition: device/gemm_batched.h:569
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#ac4ef1ac1e0876aaee5bff50dc09fe8a9)
static UnderlyingArguments to_underlying_arguments(Arguments const &args)
Helper to construct a transposed equivalent for the underying GEMM operator.
Definition: device/gemm_batched.h:637
cutlass::Status::kErrorInternal
An error within CUTLASS occurred.
cutlass::gemm::device::GemmBatched::Arguments::ref_B
TensorRef< ElementB const, LayoutB > ref_B
Definition: device/gemm_batched.h:276
cutlass::gemm::device::GemmBatched::kAlignmentA
static int const kAlignmentA
Definition: device/gemm_batched.h:235
Template for generic CUTLASS kernel.
cutlass::gemm::device::GemmBatched::initialize
Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Initializes GEMM state from arguments.
Definition: device/gemm_batched.h:366
ThreadblockSwizzle ThreadblockSwizzle
Definition: device/gemm_batched.h:233
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::EpilogueOutputOp](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a494be150d3b809a4ecf66df682481905)
EpilogueOutputOp_ EpilogueOutputOp
Definition: device/gemm_batched.h:539
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
cutlass::gemm::device::GemmBatched::GemmBatched
GemmBatched()
Constructs the GEMM.
Definition: device/gemm_batched.h:329
Top-level include for all CUTLASS numeric types.
EpilogueOutputOp EpilogueOutputOp
Definition: device/gemm_batched.h:232
cutlass::gemm::device::GemmBatched::Arguments::batch_count
int batch_count
Definition: device/gemm_batched.h:283
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::operator()](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a00805989734182945f982cab23a5dca8)
Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_batched.h:684
[default_gemm_configuration.h](default gemm configuration_8h.html)
Definitions for GEMM structures.
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ElementA](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a52b9261576b5633e901719f7c21d3369)
ElementA_ ElementA
Definition: device/gemm_batched.h:523
typename layout::LayoutTranspose< LayoutB >::type LayoutA
Definition: device/gemm_batched.h:217
cutlass::gemm::kernel::GemmBatched::Params::problem_size
cutlass::gemm::GemmCoord problem_size
Definition: kernel/gemm_batched.h:62
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
ElementA ElementB
Definition: device/gemm_batched.h:219
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::get_workspace_size](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#a3687659e826ba7f38bb060ad6020a739)
static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: device/gemm_batched.h:660
ElementB ElementA
Definition: device/gemm_batched.h:216
cutlass::gemm::device::GemmBatched::Arguments::ref_C
TensorRef< ElementC const, LayoutC > ref_C
Definition: device/gemm_batched.h:278
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#a2129a4dccbd73f8c0f26b08ce5a5cb28)
CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, int64_t stride_A_, TensorRef< ElementB const, LayoutB > ref_B_, int64_t stride_B_, TensorRef< ElementC const, LayoutC > ref_C_, int64_t stride_C_, TensorRef< ElementC, LayoutC > ref_D_, int64_t stride_D_, typename EpilogueOutputOp::Params epilogue_, int batch_count_)
Constructs an Arguments structure.
Definition: device/gemm_batched.h:601
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ae86daa985279c77e57e682b64a68d330)
CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: device/gemm_batched.h:597
bool TensorRef_aligned(TensorRef< Element, Layout > const &ref, int alignment)
Definition: tensor_ref.h:382
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::ref_B](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ad7d2b82b83d7503b9f920ce3bdcdffa5)
TensorRef< ElementB const, LayoutB > ref_B
Definition: device/gemm_batched.h:582
Operation was successful.
cutlass::gemm::kernel::GemmBatched::Params::ref_A
Mma::IteratorA::TensorRef ref_A
Definition: kernel/gemm_batched.h:65
cutlass::gemm::device::GemmBatched::run
Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_batched.h:407
Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...
Defines tags for architecture-specific configurations.
cutlass::gemm::kernel::GemmBatched
Definition: kernel/gemm_batched.h:49
ElementAccumulator ElementAccumulator
Definition: device/gemm_batched.h:226
cutlass::gemm::device::GemmBatched::Arguments::stride_B
int64_t stride_B
Definition: device/gemm_batched.h:277
WarpShape WarpShape
Definition: device/gemm_batched.h:230
ArchTag ArchTag
Definition: device/gemm_batched.h:228
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::LayoutA](classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html#af623ca54d9554cdfafc09af7a22cdd62)
LayoutA_ LayoutA
Definition: device/gemm_batched.h:524
cutlass::gemm::device::GemmBatched::operator()
Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_batched.h:448
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::ref_C](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#aa9e30e41627595590421d8b53941b2b2)
TensorRef< ElementC const, LayoutC > ref_C
Definition: device/gemm_batched.h:584
cutlass::gemm::device::GemmBatched::Arguments::epilogue
EpilogueOutputOp::Params epilogue
Definition: device/gemm_batched.h:282
Basic include for CUTLASS.
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::batch_count](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#adb66f3083f56c15578b139b7935452b5)
int batch_count
Definition: device/gemm_batched.h:589
cutlass::gemm::device::GemmBatched::GemmKernel
kernel::GemmBatched< typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue, ThreadblockSwizzle > GemmKernel
Definition: device/gemm_batched.h:264
cutlass::gemm::device::GemmBatched
Definition: device/gemm_batched.h:213
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
cutlass::gemm::device::GemmBatched::LayoutC
LayoutC_ LayoutC
Definition: device/gemm_batched.h:223
cutlass::gemm::device::GemmBatched::Arguments::ref_A
TensorRef< ElementA const, LayoutA > ref_A
Definition: device/gemm_batched.h:274
[cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::stride_A](structcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA 00_01LayoutA 00_01ElementB___00_213d78696663f4231cd52c6a277c60e5.html#ac8830c9ed0e0a8bd7aa2aa4382550a2f)
int64_t stride_A
Definition: device/gemm_batched.h:581
Generated by 1.8.11