docs/device_2gemm__splitk__parallel_8h_source.html
CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
device/gemm_splitk_parallel.h
[Go to the documentation of this file.](device_2gemm__splitk__parallel_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/cutlass.h"
32 #include "cutlass/numeric_types.h"
33 #include "cutlass/arch/arch.h"
34 #include "cutlass/device_kernel.h"
35
36 #include "cutlass/gemm/threadblock/threadblock_swizzle.h"
37 #include "cutlass/gemm/kernel/gemm.h"
38
39 #include "[cutlass/gemm/kernel/default_gemm_splitk_parallel.h](default gemm splitk__parallel_8h.html)"
40 #include "[cutlass/gemm/device/default_gemm_configuration.h](default gemm configuration_8h.html)"
41
42 #include "cutlass/epilogue/thread/conversion_op.h"
43 #include "[cutlass/reduction/kernel/reduce_split_k.h](reduce split k_8h.html)"
44 #include "cutlass/reduction/thread/reduction_operators.h"
45
47
48 namespace cutlass {
49 namespace gemm {
50 namespace device {
51
53
// GemmSplitKParallel (primary template, row-major C/D): device-level "parallel" split-K GEMM.
// run() launches two kernels on the same stream: GemmKernel, which writes one partial
// accumulator tile per K partition into a caller-provided workspace, followed by
// ReductionKernel, which combines those partials with ReductionOp and applies
// EpilogueOutputOp (with C as the source operand) to produce D.
// NOTE(review): this is a rendered documentation listing. The renderer omitted some source
// lines; its line numbers skip 127/129/131/134, which presumably hold the
// LayoutA/LayoutB/LayoutC/ArchTag aliases used below. Consult the header itself.
58 template <
60typename ElementA_,
62typename LayoutA_,
64typename ElementB_,
66typename LayoutB_,
68typename ElementC_,
70typename LayoutC_,
72typename ElementAccumulator_ = ElementC_,
74typename OperatorClass_ = arch::OpClassSimt,
76typename ArchTag_ = arch::Sm70,
78typename ThreadblockShape_ = typename DefaultGemmConfiguration<
79 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
80 ElementAccumulator_>::ThreadblockShape,
82typename WarpShape_ = typename DefaultGemmConfiguration<
83 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
84 ElementAccumulator_>::WarpShape,
86typename InstructionShape_ = typename DefaultGemmConfiguration<
87 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
88 ElementAccumulator_>::InstructionShape,
90typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
91 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
92 ElementAccumulator_>::EpilogueOutputOp,
// Default ConvertScaledOp_: a Convert with ElementAccumulator_ as both input and output
// element, so the GEMM kernel stores partial sums at accumulator precision (initialize()
// types the workspace as ElementAccumulator_). Its vector width comes from a
// DefaultGemmConfiguration instantiated with ElementAccumulator_ in the ElementC position.
94typename ConvertScaledOp_ = cutlass::epilogue::thread::Convert<
95 ElementAccumulator_,
96 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
97 ElementAccumulator_,
98 ElementAccumulator_>::EpilogueOutputOp::kCount,
99 ElementAccumulator_>,
// Default ReductionOp_: ReduceAdd over ElementAccumulator_ partials, vectorized by
// EpilogueOutputOp_::kCount.
101typename ReductionOp_ = cutlass::reduction::thread::ReduceAdd<
102 ElementAccumulator_, typename EpilogueOutputOp_::ElementAccumulator,
103 EpilogueOutputOp_::kCount>,
105typename ThreadblockSwizzle_ =
106 threadblock::GemmSplitKHorizontalThreadblockSwizzle,
108int Stages =
109 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
110 ElementC_, ElementAccumulator_>::kStages,
// Access alignment of the A and B operands (presumably in elements). Defaults come from
// DefaultGemmConfiguration.
112int kAlignmentA =
113 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
114 ElementC_, ElementAccumulator_>::kAlignmentA,
116int kAlignmentB =
117 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
118 ElementC_, ElementAccumulator_>::kAlignmentB,
120typename Operator_ = typename DefaultGemmConfiguration<
121 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
122 ElementAccumulator_>::Operator>
123 class GemmSplitKParallel {
124public:
125
// Re-export the template parameters as member types.
126using ElementA = ElementA_;
128using ElementB = ElementB_;
130using ElementC = ElementC_;
132using ElementAccumulator = ElementAccumulator_;
133using OperatorClass = OperatorClass_;
135using ThreadblockShape = ThreadblockShape_;
136using WarpShape = WarpShape_;
137using InstructionShape = InstructionShape_;
138using ConvertScaledOp = ConvertScaledOp_;
139using EpilogueOutputOp = EpilogueOutputOp_;
140using ReductionOp = ReductionOp_;
141using ThreadblockSwizzle = ThreadblockSwizzle_;
142using Operator = Operator_;
143static int const kStages = Stages;
144
// GEMM kernel: computes the split-K partial products. initialize() gives it the workspace
// as its output and args.convert as its output-op params, so each partial is written
// through ConvertScaledOp.
// NOTE(review): the renderer omitted line numbers 153, 155, 162 and 165 from this alias.
// They are presumably ElementC, ElementAccumulator, ThreadblockSwizzle and the closing
// `>::GemmKernel;` (confirm against the header).
146using GemmKernel = typename kernel::DefaultGemmSplitKParallel<
147ElementA,
148LayoutA,
149 kAlignmentA,
150ElementB,
151LayoutB,
152 kAlignmentB,
154LayoutC,
156OperatorClass,
157ArchTag,
158ThreadblockShape,
159WarpShape,
160InstructionShape,
161ConvertScaledOp,
163kStages,
164Operator
166
// Reduction kernel: combines the per-partition partials using ReductionOp and applies
// EpilogueOutputOp. Its shape parameter is a MatrixShape (a matrix tile size) of
// 4 x (32 * EpilogueOutputOp::kCount).
168using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
169cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
170EpilogueOutputOp,
171ReductionOp
172 >;
173
174//
175//
176//
177
// Arguments: user-facing description of one GEMM problem. It is consumed by
// can_implement(), get_workspace_size(), initialize() and update().
// NOTE(review): the renderer omitted the `struct Arguments {` header (around line numbers
// 178-179), the problem_size member (185) and the constructor's opening lines (199-205).
180
181//
182// Data members
183//
184
186TensorRef<ElementA const, LayoutA> ref_A;
187TensorRef<ElementB const, LayoutB> ref_B;
188TensorRef<ElementC const, LayoutC> ref_C;
189TensorRef<ElementC, LayoutC> ref_D;
190typename EpilogueOutputOp::Params epilogue;
191int split_k_slices;
192typename ConvertScaledOp::Params convert;
193typename ReductionOp::Params reduction;
194
195//
196// Methods
197//
198
202
206GemmCoord problem_size_,
207TensorRef<ElementA const, LayoutA> ref_A_,
208TensorRef<ElementB const, LayoutB> ref_B_,
209TensorRef<ElementC const, LayoutC> ref_C_,
210TensorRef<ElementC, LayoutC> ref_D_,
211typename EpilogueOutputOp::Params epilogue_ =
212typename EpilogueOutputOp::Params(),
// NOTE(review): this parameter shadows the member of the same name. The mem-initializer
// `split_k_slices(split_k_slices)` below still copies the parameter into the member.
// A trailing-underscore name, as the other parameters use, would avoid the confusion.
213int split_k_slices = 1,
214typename ConvertScaledOp::Params convert_ =
215typename ConvertScaledOp::Params(),
216typename ReductionOp::Params reduction_ =
217typename ReductionOp::Params()
218 ):
219 problem_size(problem_size_),
220 ref_A(ref_A_),
221 ref_B(ref_B_),
222 ref_C(ref_C_),
223 ref_D(ref_D_),
224 epilogue(epilogue_),
225 split_k_slices(split_k_slices),
226 convert(convert_),
227 reduction(reduction_) { }
228 };
229
230 private:
231
// Kernel parameters. initialize() computes them, update() refreshes their pointers, and
// run() passes them to the two kernel launches.
233typename GemmKernel::Params gemm_params_;
234
236typename ReductionKernel::Params reduction_params_;
237
238 public:
239
// Default constructor. The params stay default-initialized until initialize() is called.
241GemmSplitKParallel() { }
242
// Reports whether the given arguments are supported by this operator.
244static Status can_implement(Arguments const &args) {
245
// NOTE(review): no argument validation is implemented yet (see TODO). This always returns
// kSuccess, so inputs that are presumably unsupported (e.g. split_k_slices < 1) are not
// rejected here.
246// TODO
247
248return Status::kSuccess;
249 }
250
// Returns the number of bytes of device workspace that initialize() requires: one
// M x N array of ElementAccumulator_ for each K partition of the swizzled grid
// (grid_shape.k(), derived from args.split_k_slices).
252static size_t get_workspace_size(Arguments const &args) {
253
254// Determine grid shape
255 ThreadblockSwizzle threadblock_swizzle;
256
257cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
258 args.problem_size,
259 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
260 args.split_k_slices);
261
// Sizes are widened to size_t before multiplying to avoid int overflow for large M*N.
262return sizeof(ElementAccumulator_) * size_t(args.problem_size.m()) * size_t(args.problem_size.n()) * grid_shape.k();
263 }
264
// Computes the tiled grid shape and fills gemm_params_ and reduction_params_ from args.
// workspace must be non-null (otherwise kErrorWorkspaceNull is returned). It is viewed as
// row-major M x N arrays of ElementAccumulator_, one per K partition, so it should
// presumably be at least get_workspace_size(args) bytes. The size is not checked here.
266Status initialize(Arguments const &args, void *workspace) {
267
268// Determine grid shape
269 ThreadblockSwizzle threadblock_swizzle;
270
271cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
272 args.problem_size,
273 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
274 args.split_k_slices);
275
276// Define a reference to the workspace - this is an aligned region in device memory.
277if (!workspace) {
278return Status::kErrorWorkspaceNull;
279 }
280
// Leading dimension N: each partition is a row-major M x N tile.
281TensorRef<ElementAccumulator_, layout::RowMajor> ref_workspace(
282 static_cast<ElementAccumulator_ *>(workspace),
283 args.problem_size.n());
284
// Element distance between consecutive K-partition partials in the workspace.
285 int64_t partition_stride = int64_t(args.problem_size.m()) * int64_t(args.problem_size.n());
286
// The GEMM kernel writes its partials to the workspace, not to D.
287// Initialize the Params structure
288 gemm_params_ = typename GemmKernel::Params{
289 args.problem_size,
290 grid_shape,
291 args.ref_A.non_const_ref(),
292 args.ref_B.non_const_ref(),
293 ref_workspace,
294 args.convert,
295 partition_stride
296 };
297
// The reduction kernel reads grid_shape.k() partials from the workspace and writes D,
// with C as the epilogue source.
// NOTE(review): args.reduction (ReductionOp::Params) is not passed here, so any
// user-supplied reduction params are ignored.
298 reduction_params_ = typename ReductionKernel::Params(
299 args.problem_size.mn(),
300 grid_shape.k(),
301 partition_stride,
302 ref_workspace,
303 args.ref_D,
304 args.ref_C.non_const_ref(),
305 args.epilogue
306 );
307
308return Status::kSuccess;
309 }
310
312Status update(Arguments const &args, void *workspace = nullptr) {
313
314if (!workspace) {
315return Status::kErrorWorkspaceNull;
316 }
317
318 gemm_params_.ref_A.reset(args.ref_A.data());
319 gemm_params_.ref_B.reset(args.ref_B.data());
320 gemm_params_.ref_D.reset(workspace);
321
322 reduction_params_.ref_D.reset(args.ref_D.data());
323 reduction_params_.ref_C.reset(args.ref_C.data());
324
325return Status::kSuccess;
326 }
327
// Launches the GEMM kernel and then the reduction kernel on `stream`, using the params
// set by initialize(). Both launches are asynchronous. Only launch failures reported by
// cudaGetLastError (and attribute-setting failures) are surfaced, as kErrorInternal.
329Status run(cudaStream_t stream = nullptr) {
330
331//
332// Launch GEMM kernel
333//
334
335 ThreadblockSwizzle threadblock_swizzle;
336
337 dim3 grid = threadblock_swizzle.get_grid_shape(gemm_params_.grid_tiled_shape);
338 dim3 block(GemmKernel::kThreadCount, 1, 1);
339
340 cudaError_t result;
341
// Shared storage of 48 KB or more requires opting the kernel into a larger dynamic
// shared-memory size, plus a preference for the maximum shared-memory carveout.
342int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
343if (smem_size >= (48 << 10)) {
344
345 result = cudaFuncSetAttribute(
346 Kernel<GemmKernel>,
347 cudaFuncAttributeMaxDynamicSharedMemorySize,
348 smem_size);
349
350if (result != cudaSuccess) {
351return Status::kErrorInternal;
352 }
353
354 result = cudaFuncSetAttribute(
355 Kernel<GemmKernel>,
356 cudaFuncAttributePreferredSharedMemoryCarveout, 100);
357
358if (result != cudaSuccess) {
359return Status::kErrorInternal;
360 }
361 }
362
363 Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(gemm_params_);
364
365 result = cudaGetLastError();
366
367if (result != cudaSuccess) {
368return Status::kErrorInternal;
369 }
370
371//
372// Launch reduction kernel
373//
374
// The reduction grid covers the M x N output, independent of the split-K count.
375 block = ReductionKernel::block_shape();
376 grid = ReductionKernel::grid_shape(gemm_params_.problem_size.mn());
377
378 Kernel<ReductionKernel><<< grid, block, 0, stream >>>(reduction_params_);
379
380 result = cudaGetLastError();
381
382if (result != cudaSuccess) {
383return Status::kErrorInternal;
384 }
385
// NOTE(review): result is always cudaSuccess here (checked just above), so this
// expression always yields kSuccess.
386return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
387 }
388
// Convenience overload: equivalent to run(stream).
390Status operator()(cudaStream_t stream = nullptr) {
391return run(stream);
392 }
393
// Convenience overload: initialize(args, workspace) and, only if that succeeds, run(stream).
// Returns the first non-success status.
// NOTE(review): the renderer omitted the signature's first line (number 395,
// `Status operator()(`).
396Arguments const &args,
397void *workspace = nullptr,
398 cudaStream_t stream = nullptr) {
399
400Status status = initialize(args, workspace);
401
402if (status == Status::kSuccess) {
403 status = run(stream);
404 }
405
406return status;
407 }
408 };
409
411
// Partial specialization for column-major C/D. It solves the transposed problem with the
// row-major primary template (UnderlyingOperator): because D^T = B^T * A^T, the A/B
// operands and the M/N extents are swapped (see to_underlying_arguments()).
413 template <
415typename ElementA_,
417typename LayoutA_,
419typename ElementB_,
421typename LayoutB_,
423typename ElementC_,
425typename ElementAccumulator_,
427typename OperatorClass_,
429typename ArchTag_,
431typename ThreadblockShape_,
433typename WarpShape_,
435typename InstructionShape_,
437typename EpilogueOutputOp_,
439typename ConvertScaledOp_,
441typename ReductionOp_,
443typename ThreadblockSwizzle_,
445int Stages, int kAlignmentA, int kAlignmentB,
447typename Operator_>
[448](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html) class GemmSplitKParallel<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
449 layout::ColumnMajor, ElementAccumulator_,
450 OperatorClass_, ArchTag_, ThreadblockShape_,
451 WarpShape_, InstructionShape_, EpilogueOutputOp_,
452 ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_,
453 Stages, kAlignmentA, kAlignmentB, Operator_> {
454public:
455
// Re-export the template parameters. LayoutC is fixed to column-major by this specialization.
[456](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a05fc35f2f2fc3c329eccb6af24981caf)using ElementA = ElementA_;
[457](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ab0f19b729484a5d7e384af1a310f3f8c)using LayoutA = LayoutA_;
[458](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#af5f036e046e05c2a19cfd99673f9835c)using ElementB = ElementB_;
[459](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ad3783855d4101f59892e1af5024288ff)using LayoutB = LayoutB_;
[460](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#adbe8a410fe634ab05b8cf69356b79b26)using [ElementC](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#adbe8a410fe634ab05b8cf69356b79b26) = ElementC_;
[461](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a33d738b2e304c974a9b77be0b176fb59)using LayoutC = layout::ColumnMajor;
[462](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a37df0372c002340106a6f1651348084e)using ElementAccumulator = ElementAccumulator_;
[463](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a498ec1c01a7bfd6f2e401450991ed8be)using OperatorClass = OperatorClass_;
[464](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ab83912e2e116c176d3f733ccdee06a1b)using ArchTag = ArchTag_;
[465](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#add39b0bee00309be7dfca383dbda0cab)using ThreadblockShape = ThreadblockShape_;
[466](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a93db32fb628949381ff8d18b2a765624)using WarpShape = WarpShape_;
[467](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a8c6c83e045a18b7a3c004e039509576e)using InstructionShape = InstructionShape_;
[468](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aa69d9364cc5247ea353608d5c0600fe7)using ConvertScaledOp = ConvertScaledOp_;
[469](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#af5a360d190ca3e8a9df879eaf8e65dd9)using EpilogueOutputOp = EpilogueOutputOp_;
[470](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a2d8d3a504dd8807ed09e25f37a658783)using [ReductionOp](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a2d8d3a504dd8807ed09e25f37a658783) = ReductionOp_;
[471](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a1f04e5294e4238442cb23666564db958)using ThreadblockSwizzle = ThreadblockSwizzle_;
[472](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a696cd49441ddb490d32a374135731c68)using [Operator](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a696cd49441ddb490d32a374135731c68) = Operator_;
[473](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a87961d33bf1aff6a6cbb5a6bc022493e)static int const kStages = Stages;
474
475using UnderlyingOperator = GemmSplitKParallel<
476ElementB,
477typename layout::LayoutTranspose<LayoutB>::type,
478ElementA,
479typename layout::LayoutTranspose<LayoutA>::type,
480ElementC,
481layout::RowMajor,
483OperatorClass,
484ArchTag,
485ThreadblockShape,
486WarpShape,
487InstructionShape,
488EpilogueOutputOp,
489ConvertScaledOp,
490ReductionOp,
492 Stages,
493 kAlignmentA,
494 kAlignmentB,
495[Operator](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a696cd49441ddb490d32a374135731c68)
[496](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#adb3cad6256057addcd5cf96f469fd679) >;
497
// Arguments and kernel types of the wrapped row-major operator.
[498](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aaaf935dfdc267b089de7b304be979d4d)using [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aaaf935dfdc267b089de7b304be979d4d) = typename UnderlyingOperator::Arguments;
[499](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a949dbf8f84e6350649a171bf3b45478a)using [GemmKernel](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a949dbf8f84e6350649a171bf3b45478a) = typename UnderlyingOperator::GemmKernel;
[500](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aa69c465611c07990cdc79605c16b04ff)using [ReductionKernel](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aa69c465611c07990cdc79605c16b04ff) = typename UnderlyingOperator::ReductionKernel;
501
// Arguments for the column-major problem, expressed in this specialization's own
// (un-transposed) terms. to_underlying_arguments() converts them for UnderlyingOperator.
[503](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html)struct Arguments {
504
505//
506// Data members
507//
508
[509](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#adee4f1a66aa6b6cb0400f6159ec52eb9)GemmCoord [problem_size](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#adee4f1a66aa6b6cb0400f6159ec52eb9);
[510](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a8b10e75e5d6cd348dacc085f5264ee95)TensorRef<ElementA const, LayoutA> [ref_A](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a8b10e75e5d6cd348dacc085f5264ee95);
[511](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a9a22df7c4d515a48e03fd6f16e074217)TensorRef<ElementB const, LayoutB> [ref_B](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a9a22df7c4d515a48e03fd6f16e074217);
[512](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a2aabb13f196a087b77245c67c8664b7b)TensorRef<ElementC const, LayoutC> [ref_C](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a2aabb13f196a087b77245c67c8664b7b);
[513](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a850da307d8741296e515add0f716eaf9)TensorRef<ElementC, LayoutC> [ref_D](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a850da307d8741296e515add0f716eaf9);
[514](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a04818e67f94c5440ac6c367798e17fc2)typename EpilogueOutputOp::Params [epilogue](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a04818e67f94c5440ac6c367798e17fc2);
[515](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#aff78ac3c99bb15cf8a7d7a1ece736cd1)int [split_k_slices](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#aff78ac3c99bb15cf8a7d7a1ece736cd1);
[516](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a48ced96adaf371f03c1c9a50db9f50f2)typename ConvertScaledOp::Params [convert](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a48ced96adaf371f03c1c9a50db9f50f2);
[517](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a63048fa3419753d96a60eaee28f6cfe4)typename ReductionOp::Params [reduction](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a63048fa3419753d96a60eaee28f6cfe4);
518
519//
520// Methods
521//
522
[525](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#acf6c5b216c0c82f0c7797627d651743f)[Arguments](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#acf6c5b216c0c82f0c7797627d651743f)() { }
526
[529](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a37c45d8dc800de6a631b8a096704559a)[Arguments](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a37c45d8dc800de6a631b8a096704559a)(
530GemmCoord problem_size_,
531TensorRef<ElementA const, LayoutA> ref_A_,
532TensorRef<ElementB const, LayoutB> ref_B_,
533TensorRef<ElementC const, LayoutC> ref_C_,
534TensorRef<ElementC, LayoutC> ref_D_,
535typename EpilogueOutputOp::Params epilogue_ =
536typename EpilogueOutputOp::Params(),
// NOTE(review): this parameter shadows the member of the same name. The mem-initializer
// `split_k_slices(split_k_slices)` below still copies the parameter into the member.
// A trailing-underscore name, as the other parameters use, would avoid the confusion.
537int split_k_slices = 1,
538typename ConvertScaledOp::Params convert_ =
539typename ConvertScaledOp::Params(),
540typename ReductionOp::Params reduction_ =
541typename ReductionOp::Params()
542 ):
543 problem_size(problem_size_),
544 ref_A(ref_A_),
545 ref_B(ref_B_),
546 ref_C(ref_C_),
547 ref_D(ref_D_),
548 epilogue(epilogue_),
549 split_k_slices(split_k_slices),
550 convert(convert_),
551 reduction(reduction_) { }
552 };
553
554 private:
555
// All state lives in the wrapped row-major operator.
557UnderlyingOperator underlying_operator_;
558
559 public:
560
[562](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ad0c614a548bcade989eb25633b45bb0f)[GemmSplitKParallel](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ad0c614a548bcade989eb25633b45bb0f)() { }
563
// Maps column-major arguments onto the row-major underlying problem:
// - M and N are swapped.
// - B is passed as the underlying A operand and A as the underlying B operand.
// - Each TensorRef is rebuilt from its data pointer and leading stride (stride(0)).
// - The epilogue, split-K count, conversion and reduction params pass through unchanged.
[565](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a535863339cab9879474e31f2fd543804)static [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aaaf935dfdc267b089de7b304be979d4d) [to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a535863339cab9879474e31f2fd543804)(Arguments const &args) {
566return [UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aaaf935dfdc267b089de7b304be979d4d)(
567 {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
568 {args.ref_B.data(), args.ref_B.stride(0)},
569 {args.ref_A.data(), args.ref_A.stride(0)},
570 {args.ref_C.data(), args.ref_C.stride(0)},
571 {args.ref_D.data(), args.ref_D.stride(0)},
572 args.epilogue,
573 args.split_k_slices,
574 args.convert,
575 args.reduction
576 );
577 }
578
// Static queries, forwarded to UnderlyingOperator on the transposed arguments.
[580](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a465591fbfde2a9aa6330d9adcbf82bd6)static Status [can_implement](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a465591fbfde2a9aa6330d9adcbf82bd6)(Arguments const &args) {
581
582return UnderlyingOperator::can_implement(to_underlying_arguments(args));
583 }
584
[586](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a4d9c70e23eef0a15d849b5b0ebadfcdd)static size_t [get_workspace_size](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a4d9c70e23eef0a15d849b5b0ebadfcdd)(Arguments const &args) {
587
588return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
589 }
590
// Stateful operations, forwarded to the wrapped operator (arguments transposed first).
[592](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a4d9f086305f76d7f885bf032f3d2c7c9)Status [initialize](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a4d9f086305f76d7f885bf032f3d2c7c9)(Arguments const &args, void *workspace) {
593
594return underlying_operator_.initialize(to_underlying_arguments(args), workspace);
595 }
596
[598](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a44facae3996ed3da5fdb4398e469b773)Status [update](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a44facae3996ed3da5fdb4398e469b773)(Arguments const &args, void *workspace = nullptr) {
599
600return underlying_operator_.update(to_underlying_arguments(args), workspace);
601 }
602
[604](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a1de7cf5d8bad27b3ff6c803dbc572077)Status [run](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a1de7cf5d8bad27b3ff6c803dbc572077)(cudaStream_t stream = nullptr) {
605
606return underlying_operator_.run(stream);
607 }
608
// Convenience overloads: run(stream) alone, or initialize(args, workspace) followed by
// run(stream) when initialization succeeds (the first non-success status is returned).
[610](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a72f5de19ad97e08241157d5106f2f66a)Status [operator()](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a72f5de19ad97e08241157d5106f2f66a)(cudaStream_t stream = nullptr) {
611return run(stream);
612 }
613
[615](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ad6d811ca346ce6467a291497edc85623)Status [operator()](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ad6d811ca346ce6467a291497edc85623)(
616Arguments const &args,
617void *workspace = nullptr,
618 cudaStream_t stream = nullptr) {
619
620Status status = initialize(args, workspace);
621
622if (status == Status::kSuccess) {
623 status = run(stream);
624 }
625
626return status;
627 }
628 };
629
631
632 } // namespace device
633 } // namespace gemm
634 } // namespace cutlass
635
cutlass::epilogue::thread::Convert
Definition: conversion_op.h:53
WarpShape WarpShape
Definition: device/gemm_splitk_parallel.h:136
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::GemmKernel](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a949dbf8f84e6350649a171bf3b45478a)
typename UnderlyingOperator::GemmKernel GemmKernel
Definition: device/gemm_splitk_parallel.h:499
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_D](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a850da307d8741296e515add0f716eaf9)
TensorRef< ElementC, LayoutC > ref_D
Definition: device/gemm_splitk_parallel.h:513
typename layout::LayoutTranspose< LayoutA >::type LayoutB
Definition: device/gemm_splitk_parallel.h:129
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_C](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a2aabb13f196a087b77245c67c8664b7b)
TensorRef< ElementC const, LayoutC > ref_C
Definition: device/gemm_splitk_parallel.h:512
Operator Operator
Definition: device/gemm_splitk_parallel.h:142
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::ReductionKernel](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aa69c465611c07990cdc79605c16b04ff)
typename UnderlyingOperator::ReductionKernel ReductionKernel
Definition: device/gemm_splitk_parallel.h:500
Definition: aligned_buffer.h:35
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::problem_size](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#adee4f1a66aa6b6cb0400f6159ec52eb9)
GemmCoord problem_size
Definition: device/gemm_splitk_parallel.h:509
cutlass::gemm::kernel::DefaultGemmSplitKParallel
Definition: default_gemm_splitk_parallel.h:88
cutlass::gemm::device::GemmSplitKParallel::operator()
Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_splitk_parallel.h:395
cutlass::gemm::device::GemmSplitKParallel::kStages
static int const kStages
Definition: device/gemm_splitk_parallel.h:143
cutlass::reduction::kernel::ReduceSplitK::block_shape
static CUTLASS_HOST_DEVICE dim3 block_shape()
Determines the threadblock shape.
Definition: reduce_split_k.h:138
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::ElementC](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#adbe8a410fe634ab05b8cf69356b79b26)
ElementC_ ElementC
Definition: device/gemm_splitk_parallel.h:460
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_A](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a8b10e75e5d6cd348dacc085f5264ee95)
TensorRef< ElementA const, LayoutA > ref_A
Definition: device/gemm_splitk_parallel.h:510
Kernel performing a reduction over densely packed tensors in global memory.
Definition: include/cutlass/gemm/gemm.h:94
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::epilogue](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a04818e67f94c5440ac6c367798e17fc2)
EpilogueOutputOp::Params epilogue
Definition: device/gemm_splitk_parallel.h:514
Functor performing conversion operations used by epilogues.
cutlass::gemm::device::GemmSplitKParallel::Arguments::split_k_slices
int split_k_slices
Definition: device/gemm_splitk_parallel.h:191
cutlass::gemm::device::GemmSplitKParallel::Arguments::reduction
ReductionOp::Params reduction
Definition: device/gemm_splitk_parallel.h:193
cutlass::reduction::thread::ReduceAdd
Mixed-precision reduction.
Definition: reduction_operators.h:50
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::run](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a1de7cf5d8bad27b3ff6c803dbc572077)
Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_splitk_parallel.h:604
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::update](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a44facae3996ed3da5fdb4398e469b773)
Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: device/gemm_splitk_parallel.h:598
cutlass::reduction::kernel::ReduceSplitK::Params
Params structure.
Definition: reduce_split_k.h:80
InstructionShape InstructionShape
Definition: device/gemm_splitk_parallel.h:137
cutlass::gemm::device::GemmSplitKParallel::Arguments::Arguments
CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: device/gemm_splitk_parallel.h:201
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
cutlass::gemm::device::GemmSplitKParallel::Arguments::convert
ConvertScaledOp::Params convert
Definition: device/gemm_splitk_parallel.h:192
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::GemmSplitKParallel](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ad0c614a548bcade989eb25633b45bb0f)
GemmSplitKParallel()
Constructs the GEMM.
Definition: device/gemm_splitk_parallel.h:562
ElementC ElementC
Definition: device/gemm_splitk_parallel.h:130
ThreadblockShape ThreadblockShape
Definition: device/gemm_splitk_parallel.h:135
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_B](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a9a22df7c4d515a48e03fd6f16e074217)
TensorRef< ElementB const, LayoutB > ref_B
Definition: device/gemm_splitk_parallel.h:511
EpilogueOutputOp EpilogueOutputOp
Definition: device/gemm_splitk_parallel.h:139
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::get_workspace_size](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a4d9c70e23eef0a15d849b5b0ebadfcdd)
static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: device/gemm_splitk_parallel.h:586
ArchTag ArchTag
Definition: device/gemm_splitk_parallel.h:134
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::operator()](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#ad6d811ca346ce6467a291497edc85623)
Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_splitk_parallel.h:615
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::operator()](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a72f5de19ad97e08241157d5106f2f66a)
Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_splitk_parallel.h:610
cutlass::gemm::device::GemmSplitKParallel::can_implement
static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: device/gemm_splitk_parallel.h:244
cutlass::gemm::device::GemmSplitKParallel::Arguments::ref_D
TensorRef< ElementC, LayoutC > ref_D
Definition: device/gemm_splitk_parallel.h:189
cutlass::gemm::device::GemmSplitKParallel::GemmSplitKParallel
GemmSplitKParallel()
Constructs the GEMM.
Definition: device/gemm_splitk_parallel.h:241
cutlass::layout::LayoutTranspose
Defines transposes of matrix layouts.
Definition: layout/matrix.h:921
cutlass::gemm::device::GemmSplitKParallel::Arguments::problem_size
GemmCoord problem_size
Definition: device/gemm_splitk_parallel.h:185
cutlass::TensorRef< ElementA const, LayoutA >
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::ReductionOp](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a2d8d3a504dd8807ed09e25f37a658783)
ReductionOp_ ReductionOp
Definition: device/gemm_splitk_parallel.h:470
typename kernel::DefaultGemmSplitKParallel< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, ConvertScaledOp, ThreadblockSwizzle, kStages, Operator >::GemmKernel GemmKernel
GEMM kernel.
Definition: device/gemm_splitk_parallel.h:165
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::reduction](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a63048fa3419753d96a60eaee28f6cfe4)
ReductionOp::Params reduction
Definition: device/gemm_splitk_parallel.h:517
typename layout::LayoutTranspose< LayoutB >::type LayoutA
Definition: device/gemm_splitk_parallel.h:127
cutlass::gemm::device::GemmSplitKParallel::operator()
Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_splitk_parallel.h:390
ReductionOp ReductionOp
Definition: device/gemm_splitk_parallel.h:140
ElementA ElementB
Definition: device/gemm_splitk_parallel.h:128
cutlass::Status::kErrorInternal
An error within CUTLASS occurred.
Template for generic CUTLASS kernel.
[reduce_split_k.h](reduce__split__k_8h.html)
Kernel performing a reduction over densely packed tensors in global memory.
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::UnderlyingArguments](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#aaaf935dfdc267b089de7b304be979d4d)
typename UnderlyingOperator::Arguments UnderlyingArguments
Definition: device/gemm_splitk_parallel.h:498
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Operator](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a696cd49441ddb490d32a374135731c68)
Operator_ Operator
Definition: device/gemm_splitk_parallel.h:472
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::convert](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a48ced96adaf371f03c1c9a50db9f50f2)
ConvertScaledOp::Params convert
Definition: device/gemm_splitk_parallel.h:516
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
cutlass::reduction::kernel::ReduceSplitK
Definition: reduce_split_k.h:55
cutlass::reduction::kernel::ReduceSplitK::grid_shape
static CUTLASS_HOST_DEVICE dim3 grid_shape(cutlass::MatrixCoord problem_size)
Computes the grid size given a chosen threadblock shape.
Definition: reduce_split_k.h:128
[default_gemm_configuration.h](default__gemm__configuration_8h.html)
Definitions for GEMM structures.
cutlass::TensorRef::non_const_ref
CUTLASS_HOST_DEVICE NonConstTensorRef non_const_ref() const
Definition: tensor_ref.h:229
cutlass::gemm::device::GemmSplitKParallel::get_workspace_size
static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: device/gemm_splitk_parallel.h:252
cutlass::gemm::device::GemmSplitKParallel
Definition: device/gemm_splitk_parallel.h:123
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
cutlass::gemm::device::GemmSplitKParallel::Arguments::ref_C
TensorRef< ElementC const, LayoutC > ref_C
Definition: device/gemm_splitk_parallel.h:188
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::initialize](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a4d9f086305f76d7f885bf032f3d2c7c9)
Status initialize(Arguments const &args, void *workspace)
Initializes GEMM state from arguments.
Definition: device/gemm_splitk_parallel.h:592
ConvertScaledOp ConvertScaledOp
Definition: device/gemm_splitk_parallel.h:138
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::can_implement](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a465591fbfde2a9aa6330d9adcbf82bd6)
static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: device/gemm_splitk_parallel.h:580
cutlass::gemm::device::GemmSplitKParallel::Arguments::Arguments
CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1, typename ConvertScaledOp::Params convert_=typename ConvertScaledOp::Params(), typename ReductionOp::Params reduction_=typename ReductionOp::Params())
Constructs an Arguments structure.
Definition: device/gemm_splitk_parallel.h:205
ElementAccumulator ElementAccumulator
Definition: device/gemm_splitk_parallel.h:132
cutlass::Status::kErrorWorkspaceNull
The given workspace is null when it is required to be non-null.
Operation was successful.
cutlass::gemm::device::GemmSplitKParallel::Arguments::ref_B
TensorRef< ElementB const, LayoutB > ref_B
Definition: device/gemm_splitk_parallel.h:187
Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...
Defines tags for architecture-specific configurations.
ThreadblockSwizzle ThreadblockSwizzle
Definition: device/gemm_splitk_parallel.h:141
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::split_k_slices](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#aff78ac3c99bb15cf8a7d7a1ece736cd1)
int split_k_slices
Definition: device/gemm_splitk_parallel.h:515
cutlass::gemm::device::GemmSplitKParallel::update
Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: device/gemm_splitk_parallel.h:312
cutlass::gemm::device::GemmSplitKParallel::run
Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_splitk_parallel.h:329
cutlass::gemm::device::GemmSplitKParallel::Arguments::epilogue
EpilogueOutputOp::Params epilogue
Definition: device/gemm_splitk_parallel.h:190
[default_gemm_splitk_parallel.h](default__gemm__splitk__parallel_8h.html)
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropr...
cutlass::gemm::device::GemmSplitKParallel::Arguments
Argument structure.
Definition: device/gemm_splitk_parallel.h:179
cutlass::gemm::device::GemmSplitKParallel::LayoutC
LayoutC_ LayoutC
Definition: device/gemm_splitk_parallel.h:131
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#a37c45d8dc800de6a631b8a096704559a)
CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1, typename ConvertScaledOp::Params convert_=typename ConvertScaledOp::Params(), typename ReductionOp::Params reduction_=typename ReductionOp::Params())
Constructs an Arguments structure.
Definition: device/gemm_splitk_parallel.h:529
cutlass::gemm::device::GemmSplitKParallel::initialize
Status initialize(Arguments const &args, void *workspace)
Initializes GEMM state from arguments.
Definition: device/gemm_splitk_parallel.h:266
cutlass::gemm::device::GemmSplitKParallel::Arguments::ref_A
TensorRef< ElementA const, LayoutA > ref_A
Definition: device/gemm_splitk_parallel.h:186
Basic include for CUTLASS.
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::Arguments](structcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01Elementafcb1aeaf2035a7ac769d7acc233423b.html#acf6c5b216c0c82f0c7797627d651743f)
CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: device/gemm_splitk_parallel.h:525
[cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::to_underlying_arguments](classcutlass_1_1gemm_1_1device_1_1GemmSplitKParallel_3_01ElementA 00_01LayoutA 00_01ElementBbe7c1f7154ad5b5bf9d4d28301e2b457.html#a535863339cab9879474e31f2fd543804)
static UnderlyingArguments to_underlying_arguments(Arguments const &args)
Helper to construct a transposed equivalent for the underlying GEMM operator.
Definition: device/gemm_splitk_parallel.h:565
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
ElementB ElementA
Definition: device/gemm_splitk_parallel.h:126
OperatorClass OperatorClass
Definition: device/gemm_splitk_parallel.h:133
Generated by 1.8.11