docs/library_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
library.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
40 #pragma once
41
43
44 #include <vector>
45 #include <string>
46 #include <cstdint>
47 #include <cuda_runtime.h>
48
49 #include "cutlass/cutlass.h"
50 #include "cutlass/matrix_coord.h"
51 #include "cutlass/tensor_coord.h"
52 #include "cutlass/layout/tensor.h"
53
54 #include "cutlass/gemm/gemm.h"
56
57 namespace cutlass {
58 namespace library {
59
61
63 enum class LayoutTypeID {
64kUnknown,
65kColumnMajor,
66kRowMajor,
71kTensorNCHW,
72kTensorNHWC,
73kInvalid
74 };
75
/// Numeric data type.
///
/// NOTE(review): the enumerant names appear to encode a category prefix
/// (B = binary, U = unsigned integer, S = signed integer, F = floating point,
/// with a leading C for the complex counterpart) followed by a bit width.
/// This is inferred from the names only; confirm against sizeof_bits() and
/// to_string() in the implementation.
///
/// kInvalid is the default type used by the description constructors in this
/// header (MathInstructionDescription::element_accumulator,
/// TensorDescription::element, GemmDescription::element_epilogue).
77 enum class NumericTypeID {
78kUnknown,
79kVoid,
80kB1,
81kU4,
82kU8,
83kU16,
84kU32,
85kU64,
86kS4,
87kS8,
88kS16,
89kS32,
90kS64,
91kF16,
92kF32,
93kF64,
94kCF16,
95kCF32,
96kCF64,
97kCS4,
98kCS8,
99kCS16,
100kCS32,
101kCS64,
102kCU4,
103kCU8,
104kCU16,
105kCU32,
106kCU64,
107 kInvalid
108 };
109
/// Enumerated type describing a transformation on a complex value.
/// GemmDescription carries one per operand (transform_A, transform_B), and both
/// default to kNone.
/// NOTE(review): kConjugate presumably applies complex conjugation to the
/// operand; confirm in the kernels that consume it.
111 enum class ComplexTransform {
112kNone,
113kConjugate
114 };
115
/// Enumeration indicating the kind of operation.
/// Only GEMM is enumerated here. kInvalid is the default used by
/// OperationDescription's constructor.
117 enum class OperationKind {
118kGemm,
119 kInvalid
120 };
121
/// Enumeration indicating whether scalars are in host or device memory.
/// Selected per call through GemmArguments::pointer_mode and
/// GemmArrayArguments::pointer_mode, which state whether the alpha/beta
/// pointers refer to host or device memory.
123 enum class ScalarPointerMode {
124kHost,
125kDevice,
126 kInvalid
127 };
128
/// Describes how reductions are performed across threadblocks.
/// GemmDescription::split_k_mode defaults to kNone.
/// NOTE(review): the exact semantics of kSerial, kParallel and kParallelSerial
/// are defined by the kernels that implement them, which are not visible here.
130 enum class SplitKMode {
131kNone,
132kSerial,
133kParallel,
134kParallelSerial,
135 kInvalid
136 };
137
/// Indicates the classification of the math instruction.
/// Stored in MathInstructionDescription::opcode_class, which defaults to kInvalid.
/// NOTE(review): presumably kSimt means scalar (CUDA-core) instructions,
/// kTensorOp means Tensor Core mma instructions, and kWmmaTensorOp means Tensor
/// Cores accessed via the WMMA API. Confirm against the kernel generator.
139 enum class OpcodeClassID {
140kSimt,
141kTensorOp,
142kWmmaTensorOp,
143 kInvalid
144 };
145
147
150kGemm,
151kBatched,
152kArray,
153kPlanarComplex,
155 kInvalid
156 };
157
159
/// Lexical cast from string.
/// Only the specializations declared in this header are visible in this
/// listing: OperationKind, NumericTypeID, LayoutTypeID and OpcodeClassID.
/// No definition of the primary template appears here.
161 template <typename T> T from_string(std::string const &);
162
/// Converts an OperationKind enumerant to a string.
/// NOTE(review): `pretty` presumably selects a human-readable spelling rather
/// than the short identifier form; confirm in the implementation.
164 char const *to_string(OperationKind type, bool pretty = false);
165
/// Parses an OperationKind enumerant from a string.
167 template <> OperationKind from_string<OperationKind>(std::string const &str);
168
/// Converts a NumericTypeID enumerant to a string (`pretty` as noted above).
170 char const *to_string(NumericTypeID type, bool pretty = false);
171
/// Parses a NumericTypeID enumerant from a string.
173 template <> NumericTypeID from_string<NumericTypeID>(std::string const &str);
174
/// Returns the size of a data type in bits.
176 int sizeof_bits(NumericTypeID type);
177
/// Returns true if the numeric type is a complex data type, or false if it is
/// real-valued.
179 bool is_complex_type(NumericTypeID type);
180
/// Returns the real-valued type underlying a type. The result differs from
/// `type` only if `type` is complex.
182 NumericTypeID get_real_type(NumericTypeID type);
183
/// Returns true if numeric type is integer.
185 bool is_integer_type(NumericTypeID type);
186
/// Returns true if numeric type is signed.
188 bool is_signed_type(NumericTypeID type);
189
/// Returns true if numeric type is a signed integer.
191 bool is_signed_integer(NumericTypeID type);
192
/// Returns true if numeric type is an unsigned integer.
194 bool is_unsigned_integer(NumericTypeID type);
195
/// Returns true if numeric type is a floating-point type.
197 bool is_float_type(NumericTypeID type);
198
/// Converts a Status code to a string.
/// NOTE(review): `pretty` presumably selects a human-readable spelling; confirm.
200 char const *to_string(Status status, bool pretty = false);
201
/// Converts a LayoutTypeID enumerant to a string.
203 char const *to_string(LayoutTypeID layout, bool pretty = false);
204
/// Parses a LayoutTypeID enumerant from a string.
206 template <> LayoutTypeID from_string<LayoutTypeID>(std::string const &str);
207
/// Returns the rank of a layout's stride based on the LayoutTypeID.
209 int get_layout_stride_rank(LayoutTypeID layout_id);
210
/// Converts an OpcodeClassID enumerant to a string.
212 char const *to_string(OpcodeClassID type, bool pretty = false);
213
/// Parses an OpcodeClassID enumerant from a string.
215 template <>
216 OpcodeClassID from_string<OpcodeClassID>(std::string const &str);
217
/// Lexical cast from int64_t to string.
219 std::string lexical_cast(int64_t int_value);
220
/// Parses `str` as a value of numeric type `type`.
/// NOTE(review): presumably `bytes` receives the binary representation and the
/// result reports whether parsing succeeded; confirm in the implementation.
222 bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string const &str);
223
/// Formats the value held in `bytes`, interpreted as `type`, as a string.
/// NOTE(review): `bytes` is taken by non-const reference, although a
/// to-string conversion should not need to modify it. Consider
/// `std::vector<uint8_t> const &` in library.h; existing callers would still
/// compile.
225 std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type);
226
/// Casts from a signed int64 to the destination type. Returns true if
/// successful. The result is presumably written into `bytes`.
228 bool cast_from_int64(std::vector<uint8_t> &bytes, NumericTypeID type, int64_t src);
229
/// Casts from an unsigned int64 to the destination type. Returns true if
/// successful. The result is presumably written into `bytes`.
231 bool cast_from_uint64(std::vector<uint8_t> &bytes, NumericTypeID type, uint64_t src);
232
/// Casts from a real value represented as a double to the destination type.
/// Returns true if successful. The result is presumably written into `bytes`.
234 bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double src);
235
237
/// Describes the math instruction an operation is built on: its shape, its
/// accumulator type and its opcode class.
238 struct MathInstructionDescription {
239
/// Shape of the target math instruction.
241cutlass::gemm::GemmCoord instruction_shape;
242
/// Describes the data type of the internal accumulator.
244NumericTypeID element_accumulator;
245
/// Classification of math instruction.
247OpcodeClassID opcode_class;
248
249//
250// Methods
251//
252
/// Constructs a description. Every parameter has a default (a
/// default-constructed GemmCoord, NumericTypeID::kInvalid and
/// OpcodeClassID::kInvalid), so this also serves as the default constructor.
253MathInstructionDescription(
254cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(),
255NumericTypeID element_accumulator = NumericTypeID::kInvalid,
256OpcodeClassID opcode_class = OpcodeClassID::kInvalid
257 ):
// The parameters deliberately share the members' names. In the
// mem-initializer list, each member is initialized from its same-named parameter.
258 instruction_shape(instruction_shape), element_accumulator(element_accumulator), opcode_class(opcode_class) {}
259
260 };
261
263 struct TileDescription {
264
266cutlass::gemm::GemmCoord threadblock_shape;
267
270
272cutlass::gemm::GemmCoord warp_count;
273
275MathInstructionDescription math_instruction;
276
278int minimum_compute_capability;
279
281int maximum_compute_capability;
282
283//
284// Methods
285//
286
288cutlass::gemm::GemmCoord threadblock_shape = cutlass::gemm::GemmCoord(),
289int threadblock_stages = 0,
290cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(),
291MathInstructionDescription math_instruction = MathInstructionDescription(),
292int minimum_compute_capability = 0,
293int maximum_compute_capability = 0
294 ):
295 threadblock_shape(threadblock_shape),
296 threadblock_stages(threadblock_stages),
297 warp_count(warp_count),
298 math_instruction(math_instruction),
299 minimum_compute_capability(minimum_compute_capability),
300 maximum_compute_capability(maximum_compute_capability) { }
301 };
302
304 struct OperationDescription {
305
308
311
313TileDescription tile_description;
314
315//
316// Methods
317//
319char const * name = "unknown",
320OperationKind kind = OperationKind::kInvalid,
321TileDescription const & tile_description = TileDescription()
322 ):
323 name(name), kind(kind), tile_description(tile_description) { }
324 };
325
327 struct TensorDescription {
328
331
334
337
339int log_extent_range;
340
342int log_stride_range;
343
344//
345// Methods
346//
348NumericTypeID element = NumericTypeID::kInvalid,
349LayoutTypeID layout = LayoutTypeID::kInvalid,
350int alignment = 1,
351int log_extent_range = 24,
352int log_stride_range = 24
353 ):
354 element(element),
355 layout(layout),
356 alignment(alignment),
357 log_extent_range(log_extent_range),
358 log_stride_range(log_stride_range) { }
359 };
360
362
364 struct GemmDescription : public OperationDescription {
365
368
371
374
377
379NumericTypeID element_epilogue;
380
383
385ComplexTransform transform_A;
386
388ComplexTransform transform_B;
389
390//
391// Methods
392//
393
395GemmKind gemm_kind = GemmKind::kGemm,
396TensorDescription const &A = TensorDescription(),
397TensorDescription const &B = TensorDescription(),
398TensorDescription const &C = TensorDescription(),
399NumericTypeID element_epilogue = NumericTypeID::kInvalid,
400SplitKMode split_k_mode = SplitKMode::kNone,
401ComplexTransform transform_A = ComplexTransform::kNone,
402ComplexTransform transform_B = ComplexTransform::kNone
403 ):
404 gemm_kind(gemm_kind),
405 A(A),
406 B(B),
407 C(C),
408 element_epilogue(element_epilogue),
409 split_k_mode(split_k_mode),
410 transform_A(transform_A),
411 transform_B(transform_B) {}
412 };
413
416
419 public:
420
421virtual ~Operation() { }
422
423virtual OperationDescription const & description() const = 0;
424
425virtual Status can_implement(
426void const *configuration,
427void const *arguments) const = 0;
428
429virtual uint64_t get_host_workspace_size(
430void const *configuration) const = 0;
431
432virtual uint64_t get_device_workspace_size(
433void const *configuration) const = 0;
434
435virtual Status initialize(
436void const *configuration,
437void *host_workspace,
438void *device_workspace,
439 cudaStream_t stream = nullptr) const = 0;
440
441virtual Status run(
442void const *arguments,
443void *host_workspace,
444void *device_workspace = nullptr,
445 cudaStream_t stream = nullptr) const = 0;
446 };
447
449
451 //
452 // OperationKind: Gemm
453 // GemmKind: Gemm
454 //
455 struct GemmConfiguration {
456
458gemm::GemmCoord problem_size;
459
462
465
468
471
473int split_k_slices;
474 };
475
477 struct GemmArguments {
478
481
484
487
490
493
496
498ScalarPointerMode pointer_mode;
499 };
500
502
504 //
505 // OperationKind: Gemm
506 // GemmKind: Batched
507
508 struct GemmBatchedConfiguration {
509
511gemm::GemmCoord problem_size;
512
515
518
521
524
526 int64_t batch_stride_A;
527
529 int64_t batch_stride_B;
530
532 int64_t batch_stride_C;
533
535 int64_t batch_stride_D;
536
538int batch_count;
539 };
540
/// Batched GEMM takes the same argument structure as basic GEMM.
542 using GemmBatchedArguments = GemmArguments;
543
545
547 //
548 // OperationKind: Gemm
549 // GemmKind: Array
550
551 struct GemmArrayConfiguration {
552
553gemm::GemmCoord problem_size;
554
559
560int batch_count;
561 };
562
564 struct GemmArrayArguments {
571ScalarPointerMode pointer_mode;
572 };
573
575
577 //
578 // OperationKind: Gemm
579 // GemmKind: Planar complex
580
581 struct GemmPlanarComplexConfiguration {
582
583gemm::GemmCoord problem_size;
584
589
590 int64_t imag_stride_A;
591 int64_t imag_stride_B;
592 int64_t imag_stride_C;
593 int64_t imag_stride_D;
594 };
595
/// Planar-complex GEMM takes the same argument structure as basic GEMM.
/// NOTE(review): "Argments" is a misspelling of "Arguments" (compare
/// GemmBatchedArguments and GemmPlanarComplexBatchedArguments). Fix this in
/// library.h by adding `using GemmPlanarComplexArguments = GemmArguments;` and
/// keeping this name as a deprecated alias, so existing callers still compile.
596 using GemmPlanarComplexArgments = GemmArguments;
597
599
601 //
602 // OperationKind: Gemm
603 // GemmKind: Planar complex batched
604 //
605 struct GemmPlanarComplexBatchedConfiguration {
606
607gemm::GemmCoord problem_size;
608
613
614 int64_t imag_stride_A;
615 int64_t imag_stride_B;
616 int64_t imag_stride_C;
617 int64_t imag_stride_D;
618
619 int64_t batched_stride_A;
620 int64_t batched_stride_B;
621 int64_t batched_stride_C;
622 int64_t batched_stride_D;
623 };
624
/// Batched planar-complex GEMM takes the same argument structure as basic GEMM.
625 using GemmPlanarComplexBatchedArguments = GemmArguments;
626
628
629 } // namespace library
630 } // namespace cutlass
631
cutlass::library::GemmPlanarComplexBatchedConfiguration::lda
int64_t lda
Definition: library.h:609
cutlass::library::NumericTypeID::kCS8
cutlass::library::TensorDescription::alignment
int alignment
Alignment restriction on pointers, strides, and extents.
Definition: library.h:336
cutlass::library::NumericTypeID::kCU32
cutlass::library::GemmArrayArguments::A
void const *const * A
Definition: library.h:565
cutlass::library::GemmKind::kPlanarComplexBatched
cutlass::library::NumericTypeID::kCS16
cutlass::library::OpcodeClassID::kWmmaTensorOp
cutlass::library::Operation::~Operation
virtual ~Operation()
Definition: library.h:421
cutlass::library::OperationDescription
High-level description of an operation.
Definition: library.h:304
cutlass::library::GemmKind::kPlanarComplex
cutlass::library::LayoutTypeID::kColumnMajor
char const * to_string(OperationKind type, bool pretty=false)
Converts an OperationKind enumerant to a string.
Definition: aligned_buffer.h:35
cutlass::library::is_complex_type
bool is_complex_type(NumericTypeID type)
Returns true if the numeric type is a complex data type or false if real-valued.
cutlass::library::GemmArrayArguments::D
void *const * D
Definition: library.h:568
cutlass::library::TensorDescription::layout
LayoutTypeID layout
Enumerant identifying the layout function for the tensor.
Definition: library.h:333
cutlass::library::GemmDescription::gemm_kind
GemmKind gemm_kind
Indicates the kind of GEMM performed.
Definition: library.h:367
cutlass::library::GemmPlanarComplexConfiguration::ldc
int64_t ldc
Definition: library.h:587
cutlass::library::GemmArguments
Arguments for GEMM.
Definition: library.h:477
cutlass::library::GemmKind::kArray
cutlass::library::GemmArrayConfiguration::batch_count
int batch_count
Definition: library.h:560
cutlass::library::ComplexTransform
ComplexTransform
Enumerated type describing a transformation on a complex value.
Definition: library.h:111
cutlass::library::LayoutTypeID::kColumnMajorInterleavedK16
cutlass::library::GemmArrayArguments::C
void const *const * C
Definition: library.h:567
cutlass::library::GemmPlanarComplexConfiguration::problem_size
gemm::GemmCoord problem_size
Definition: library.h:583
cutlass::library::GemmArrayConfiguration
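Configuration for array GEMM, in which multiple matrix products are computed. Unlike GemmBatchedConfiguration, the leading dimensions are given as pointers (int64_t const *) rather than scalars.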
Definition: library.h:551
cutlass::library::is_signed_integer
bool is_signed_integer(NumericTypeID type)
Returns true if numeric type is a signed integer.
GemmKind
Enumeration indicating what kind of GEMM operation to perform.
Definition: library.h:149
cutlass::library::NumericTypeID::kCF64
cutlass::library::NumericTypeID::kCU64
cutlass::library::get_real_type
NumericTypeID get_real_type(NumericTypeID type)
Returns the real-valued type underlying a type (only different from 'type' if complex) ...
cutlass::library::from_string< OperationKind >
OperationKind from_string< OperationKind >(std::string const &str)
Parses an OperationKind enumerant from a string.
Definition: include/cutlass/gemm/gemm.h:94
cutlass::library::get_layout_stride_rank
int get_layout_stride_rank(LayoutTypeID layout_id)
Returns the rank of a layout's stride based on the LayoutTypeID.
cutlass::library::GemmBatchedConfiguration::ldb
int64_t ldb
Leading dimension of B matrix.
Definition: library.h:517
cutlass::library::NumericTypeID::kCF32
cutlass::library::NumericTypeID::kCU8
cutlass::library::ComplexTransform::kNone
cutlass::library::GemmArrayConfiguration::ldc
int64_t const * ldc
Definition: library.h:557
cutlass::library::OpcodeClassID::kTensorOp
cutlass::library::NumericTypeID::kS64
cutlass::library::GemmPlanarComplexBatchedConfiguration::batched_stride_B
int64_t batched_stride_B
Definition: library.h:620
cutlass::library::GemmPlanarComplexConfiguration
Complex valued GEMM in which real and imaginary parts are separated by a stride.
Definition: library.h:581
cutlass::library::TensorDescription::log_stride_range
int log_stride_range
log2() of the maximum value each relevant stride may have
Definition: library.h:342
cutlass::library::ComplexTransform::kConjugate
cutlass::library::NumericTypeID::kB1
Defines common types used for all GEMM-like operators.
cutlass::library::GemmDescription::transform_A
ComplexTransform transform_A
Transformation on A operand.
Definition: library.h:385
cutlass::library::NumericTypeID::kS16
cutlass::library::SplitKMode::kSerial
cutlass::library::GemmPlanarComplexBatchedConfiguration::imag_stride_B
int64_t imag_stride_B
Definition: library.h:615
cutlass::library::GemmDescription::GemmDescription
GemmDescription(GemmKind gemm_kind=GemmKind::kGemm, TensorDescription const &A=TensorDescription(), TensorDescription const &B=TensorDescription(), TensorDescription const &C=TensorDescription(), NumericTypeID element_epilogue=NumericTypeID::kInvalid, SplitKMode split_k_mode=SplitKMode::kNone, ComplexTransform transform_A=ComplexTransform::kNone, ComplexTransform transform_B=ComplexTransform::kNone)
Definition: library.h:394
int sizeof_bits(NumericTypeID type)
Returns the size of a data type in bits.
Base class for all device-wide operations.
Definition: library.h:418
cutlass::library::GemmPlanarComplexConfiguration::imag_stride_A
int64_t imag_stride_A
Definition: library.h:590
cutlass::library::NumericTypeID::kCS64
cutlass::library::from_string< NumericTypeID >
NumericTypeID from_string< NumericTypeID >(std::string const &str)
Parses a NumericType enumerant from a string.
cutlass::library::LayoutTypeID
LayoutTypeID
Layout type identifier.
Definition: library.h:63
cutlass::library::OpcodeClassID
OpcodeClassID
Indicates the classification of the math instruction.
Definition: library.h:139
cutlass::library::GemmPlanarComplexBatchedConfiguration::ldc
int64_t ldc
Definition: library.h:611
cutlass::library::lexical_cast
std::string lexical_cast(int64_t int_value)
Lexical cast from int64_t to string.
cutlass::library::GemmArguments::pointer_mode
ScalarPointerMode pointer_mode
Enumerant indicating whether alpha/beta point to host or device memory.
Definition: library.h:498
cutlass::library::GemmArrayConfiguration::ldb
int64_t const * ldb
Definition: library.h:556
cutlass::library::GemmArrayConfiguration::problem_size
gemm::GemmCoord problem_size
Definition: library.h:553
cutlass::library::GemmPlanarComplexBatchedConfiguration::batched_stride_C
int64_t batched_stride_C
Definition: library.h:621
cutlass::library::OperationDescription::OperationDescription
OperationDescription(char const *name="unknown", OperationKind kind=OperationKind::kInvalid, TileDescription const &tile_description=TileDescription())
Definition: library.h:318
cutlass::library::TileDescription::maximum_compute_capability
int maximum_compute_capability
Maximum compute capability (e.g. 70, 75) of a device eligible to run the operation.
Definition: library.h:281
cutlass::library::NumericTypeID::kS8
cutlass::library::GemmConfiguration
Configuration for basic GEMM operations.
Definition: library.h:455
cutlass::library::GemmPlanarComplexBatchedConfiguration::imag_stride_D
int64_t imag_stride_D
Definition: library.h:617
cutlass::library::MathInstructionDescription
Definition: library.h:238
cutlass::library::GemmArguments::B
void const * B
Pointer to B matrix.
Definition: library.h:483
cutlass::library::GemmPlanarComplexBatchedConfiguration::imag_stride_A
int64_t imag_stride_A
Definition: library.h:614
cutlass::library::NumericTypeID::kU16
cutlass::library::LayoutTypeID::kRowMajorInterleavedK16
cutlass::library::GemmDescription::A
TensorDescription A
Describes the A operand.
Definition: library.h:370
cutlass::library::TileDescription
Structure describing the tiled structure of a GEMM-like computation.
Definition: library.h:263
cutlass::library::GemmConfiguration::split_k_slices
int split_k_slices
Number of partitions of K dimension.
Definition: library.h:473
cutlass::library::GemmPlanarComplexConfiguration::imag_stride_B
int64_t imag_stride_B
Definition: library.h:591
cutlass::library::from_string< OpcodeClassID >
OpcodeClassID from_string< OpcodeClassID >(std::string const &str)
Parses an OpcodeClassID enumerant from a string.
cutlass::library::NumericTypeID::kU8
cutlass::library::GemmArguments::A
void const * A
Pointer to A matrix.
Definition: library.h:480
Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...
cutlass::library::GemmConfiguration::ldd
int64_t ldd
Leading dimension of D matrix.
Definition: library.h:470
cutlass::library::GemmDescription::transform_B
ComplexTransform transform_B
Transformation on B operand.
Definition: library.h:388
cutlass::library::GemmPlanarComplexBatchedConfiguration::ldb
int64_t ldb
Definition: library.h:610
cutlass::library::is_signed_type
bool is_signed_type(NumericTypeID type)
Returns true if numeric type is signed.
cutlass::library::NumericTypeID::kCU4
cutlass::library::GemmDescription::element_epilogue
NumericTypeID element_epilogue
Describes the data type of the scalars passed to the epilogue.
Definition: library.h:379
cutlass::library::GemmArrayConfiguration::lda
int64_t const * lda
Definition: library.h:555
cutlass::library::SplitKMode::kParallel
cutlass::library::GemmArrayConfiguration::ldd
int64_t const * ldd
Definition: library.h:558
cutlass::library::TileDescription::minimum_compute_capability
int minimum_compute_capability
Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation.
Definition: library.h:278
cutlass::library::GemmPlanarComplexConfiguration::ldd
int64_t ldd
Definition: library.h:588
cutlass::library::GemmBatchedConfiguration::batch_stride_C
int64_t batch_stride_C
Stride between instances of the C matrix in memory.
Definition: library.h:532
cutlass::library::GemmArrayArguments::B
void const *const * B
Definition: library.h:566
cutlass::library::NumericTypeID
NumericTypeID
Numeric data type.
Definition: library.h:77
cutlass::library::GemmPlanarComplexConfiguration::lda
int64_t lda
Definition: library.h:585
cutlass::library::TileDescription::warp_count
cutlass::gemm::GemmCoord warp_count
Number of warps in each logical dimension.
Definition: library.h:272
cutlass::library::GemmConfiguration::lda
int64_t lda
Leading dimension of A matrix.
Definition: library.h:461
cutlass::library::is_float_type
bool is_float_type(NumericTypeID type)
Returns true if numeric type is a floating-point type.
cutlass::library::LayoutTypeID::kColumnMajorInterleavedK4
cutlass::library::TileDescription::TileDescription
TileDescription(cutlass::gemm::GemmCoord threadblock_shape=cutlass::gemm::GemmCoord(), int threadblock_stages=0, cutlass::gemm::GemmCoord warp_count=cutlass::gemm::GemmCoord(), MathInstructionDescription math_instruction=MathInstructionDescription(), int minimum_compute_capability=0, int maximum_compute_capability=0)
Definition: library.h:287
cutlass::library::cast_from_double
bool cast_from_double(std::vector< uint8_t > &bytes, NumericTypeID type, double src)
Casts from a real value represented as a double to the destination type. Returns true if successful...
cutlass::library::NumericTypeID::kU32
cutlass::library::MathInstructionDescription::element_accumulator
NumericTypeID element_accumulator
Describes the data type of the internal accumulator.
Definition: library.h:244
Defines a canonical coordinate for rank=4 tensors offering named indices.
cutlass::library::GemmDescription::B
TensorDescription B
Describes the B operand.
Definition: library.h:373
cutlass::library::LayoutTypeID::kInvalid
cutlass::library::GemmArrayArguments::alpha
void const * alpha
Definition: library.h:569
cutlass::library::NumericTypeID::kU4
cutlass::library::GemmArguments::beta
void const * beta
Host or device pointer to beta scalar.
Definition: library.h:495
cutlass::library::LayoutTypeID::kTensorNCHW
cutlass::library::GemmBatchedConfiguration::ldd
int64_t ldd
Leading dimension of D matrix.
Definition: library.h:523
cutlass::library::MathInstructionDescription::opcode_class
OpcodeClassID opcode_class
Classification of math instruction.
Definition: library.h:247
cutlass::library::NumericTypeID::kCF16
cutlass::library::GemmBatchedConfiguration::problem_size
gemm::GemmCoord problem_size
GEMM problem size.
Definition: library.h:511
cutlass::library::GemmConfiguration::ldc
int64_t ldc
Leading dimension of C matrix.
Definition: library.h:467
cutlass::library::LayoutTypeID::kTensorNHWC
cutlass::library::ScalarPointerMode::kHost
cutlass::library::GemmArguments::D
void * D
Pointer to D matrix.
Definition: library.h:489
cutlass::library::cast_from_uint64
bool cast_from_uint64(std::vector< uint8_t > &bytes, NumericTypeID type, uint64_t src)
Casts from an unsigned int64 to the destination type. Returns true if successful. ...
cutlass::library::GemmDescription::C
TensorDescription C
Describes the source and destination matrices.
Definition: library.h:376
cutlass::library::GemmPlanarComplexConfiguration::ldb
int64_t ldb
Definition: library.h:586
cutlass::library::GemmPlanarComplexBatchedConfiguration::imag_stride_C
int64_t imag_stride_C
Definition: library.h:616
cutlass::library::GemmPlanarComplexBatchedConfiguration
Batched complex valued GEMM in which real and imaginary parts are separated by a stride.
Definition: library.h:605
cutlass::library::OperationKind::kGemm
cutlass::library::GemmPlanarComplexBatchedConfiguration::batched_stride_D
int64_t batched_stride_D
Definition: library.h:622
cutlass::library::GemmBatchedConfiguration
Configuration for batched GEMM in which multiple matrix products are computed.
Definition: library.h:508
cutlass::library::GemmKind::kBatched
cutlass::library::LayoutTypeID::kRowMajor
cutlass::library::GemmBatchedConfiguration::batch_stride_A
int64_t batch_stride_A
Stride between instances of the A matrix in memory.
Definition: library.h:526
cutlass::library::is_integer_type
bool is_integer_type(NumericTypeID type)
Returns true if numeric type is integer.
cutlass::library::ScalarPointerMode
ScalarPointerMode
Enumeration indicating whether scalars are in host or device memory.
Definition: library.h:123
cutlass::library::TensorDescription::element
NumericTypeID element
Numeric type of an individual element.
Definition: library.h:330
cutlass::library::NumericTypeID::kCU16
cutlass::library::GemmBatchedConfiguration::batch_count
int batch_count
Number of GEMMs in batch.
Definition: library.h:538
cutlass::library::GemmArguments::C
void const * C
Pointer to C matrix.
Definition: library.h:486
T from_string(std::string const &)
Lexical cast from string.
cutlass::library::TileDescription::threadblock_stages
int threadblock_stages
Describes the number of pipeline stages in the threadblock-scoped mainloop.
Definition: library.h:269
Defines a canonical coordinate for rank=2 matrices offering named indices.
cutlass::library::GemmPlanarComplexConfiguration::imag_stride_D
int64_t imag_stride_D
Definition: library.h:593
cutlass::library::NumericTypeID::kS4
cutlass::library::NumericTypeID::kS32
cutlass::library::from_string< LayoutTypeID >
LayoutTypeID from_string< LayoutTypeID >(std::string const &str)
Parses a LayoutType enumerant from a string.
cutlass::library::GemmArrayArguments::pointer_mode
ScalarPointerMode pointer_mode
Definition: library.h:571
cutlass::library::MathInstructionDescription::MathInstructionDescription
MathInstructionDescription(cutlass::gemm::GemmCoord instruction_shape=cutlass::gemm::GemmCoord(), NumericTypeID element_accumulator=NumericTypeID::kInvalid, OpcodeClassID opcode_class=OpcodeClassID::kInvalid)
Definition: library.h:253
cutlass::library::GemmPlanarComplexBatchedConfiguration::batched_stride_A
int64_t batched_stride_A
Definition: library.h:619
cutlass::library::GemmDescription
Description of all GEMM computations.
Definition: library.h:364
cutlass::library::GemmBatchedConfiguration::lda
int64_t lda
Leading dimension of A matrix.
Definition: library.h:514
cutlass::library::GemmConfiguration::problem_size
gemm::GemmCoord problem_size
GEMM problem size.
Definition: library.h:458
cutlass::library::GemmDescription::split_k_mode
SplitKMode split_k_mode
Describes the structure of parallel reductions.
Definition: library.h:382
cutlass::library::cast_from_int64
bool cast_from_int64(std::vector< uint8_t > &bytes, NumericTypeID type, int64_t src)
Casts from a signed int64 to the destination type. Returns true if successful.
cutlass::library::TensorDescription::log_extent_range
int log_extent_range
log2() of the maximum extent of each dimension
Definition: library.h:339
cutlass::library::OperationDescription::name
char const * name
Unique identifier describing the operation.
Definition: library.h:307
cutlass::library::GemmBatchedConfiguration::batch_stride_B
int64_t batch_stride_B
Stride between instances of the B matrix in memory.
Definition: library.h:529
cutlass::library::MathInstructionDescription::instruction_shape
cutlass::gemm::GemmCoord instruction_shape
Shape of the target math instruction.
Definition: library.h:241
cutlass::library::GemmArrayArguments::beta
void const * beta
Definition: library.h:570
cutlass::library::OperationDescription::tile_description
TileDescription tile_description
Describes the tiled structure of a GEMM-like computation.
Definition: library.h:313
cutlass::library::NumericTypeID::kF16
cutlass::library::NumericTypeID::kF32
cutlass::library::GemmBatchedConfiguration::ldc
int64_t ldc
Leading dimension of C matrix.
Definition: library.h:520
cutlass::library::TensorDescription
Structure describing the properties of a tensor.
Definition: library.h:327
cutlass::library::GemmConfiguration::ldb
int64_t ldb
Leading dimension of B matrix.
Definition: library.h:464
cutlass::library::GemmPlanarComplexBatchedConfiguration::problem_size
gemm::GemmCoord problem_size
Definition: library.h:607
cutlass::library::is_unsigned_integer
bool is_unsigned_integer(NumericTypeID type)
Returns true if numeric type is an unsigned integer.
cutlass::library::GemmArrayArguments
Arguments for array GEMM, in which A, B, C and D are arrays of pointers (one pointer per matrix product).
Definition: library.h:564
cutlass::library::NumericTypeID::kCS32
cutlass::library::LayoutTypeID::kUnknown
cutlass::library::OpcodeClassID::kSimt
cutlass::library::OperationKind
OperationKind
Enumeration indicating the kind of operation.
Definition: library.h:117
cutlass::library::GemmArguments::alpha
void const * alpha
Host or device pointer to alpha scalar.
Definition: library.h:492
cutlass::library::SplitKMode::kParallelSerial
cutlass::library::OperationDescription::kind
OperationKind kind
Kind of operation.
Definition: library.h:310
cutlass::library::TileDescription::threadblock_shape
cutlass::gemm::GemmCoord threadblock_shape
Describes the shape of a threadblock (in elements)
Definition: library.h:266
cutlass::library::GemmPlanarComplexBatchedConfiguration::ldd
int64_t ldd
Definition: library.h:612
cutlass::library::NumericTypeID::kVoid
cutlass::library::ScalarPointerMode::kDevice
cutlass::library::NumericTypeID::kF64
cutlass::library::GemmPlanarComplexConfiguration::imag_stride_C
int64_t imag_stride_C
Definition: library.h:592
cutlass::library::TileDescription::math_instruction
MathInstructionDescription math_instruction
Core math instruction.
Definition: library.h:275
Basic include for CUTLASS.
SplitKMode
Describes how reductions are performed across threadblocks.
Definition: library.h:130
cutlass::library::NumericTypeID::kU64
cutlass::library::NumericTypeID::kCS4
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
cutlass::library::GemmBatchedConfiguration::batch_stride_D
int64_t batch_stride_D
Stride between instances of the D matrix in memory.
Definition: library.h:535
cutlass::library::LayoutTypeID::kRowMajorInterleavedK4
cutlass::library::TensorDescription::TensorDescription
TensorDescription(NumericTypeID element=NumericTypeID::kInvalid, LayoutTypeID layout=LayoutTypeID::kInvalid, int alignment=1, int log_extent_range=24, int log_stride_range=24)
Definition: library.h:347
Generated by 1.8.11