Back to Cutlass

CUTLASS: library.h Source File

docs/library_8h_source.html

4.4.283.4 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

library.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

40 #pragma once

41

43

44 #include <vector>

45 #include <string>

46 #include <cstdint>

47 #include <cuda_runtime.h>

48

49 #include "cutlass/cutlass.h"

50 #include "cutlass/matrix_coord.h"

51 #include "cutlass/tensor_coord.h"

52 #include "cutlass/layout/tensor.h"

53

54 #include "cutlass/gemm/gemm.h"

56

57 namespace cutlass {

58 namespace library {

59

61

63 enum class LayoutTypeID {

64kUnknown,

65kColumnMajor,

66kRowMajor,

67kColumnMajorInterleavedK4,

68kRowMajorInterleavedK4,

69kColumnMajorInterleavedK16,

70kRowMajorInterleavedK16,

71kTensorNCHW,

72kTensorNHWC,

73kInvalid

74 };

75

77 enum class NumericTypeID {

78kUnknown,

79kVoid,

80kB1,

81kU4,

82kU8,

83kU16,

84kU32,

85kU64,

86kS4,

87kS8,

88kS16,

89kS32,

90kS64,

91kF16,

92kF32,

93kF64,

94kCF16,

95kCF32,

96kCF64,

97kCS4,

98kCS8,

99kCS16,

100kCS32,

101kCS64,

102kCU4,

103kCU8,

104kCU16,

105kCU32,

106kCU64,

107 kInvalid

108 };

109

111 enum class ComplexTransform {

112kNone,

113kConjugate

114 };

115

117 enum class OperationKind {

118kGemm,

119 kInvalid

120 };

121

123 enum class ScalarPointerMode {

124kHost,

125kDevice,

126 kInvalid

127 };

128

130 enum class SplitKMode {

131kNone,

132kSerial,

133kParallel,

134kParallelSerial,

135 kInvalid

136 };

137

139 enum class OpcodeClassID {

140kSimt,

141kTensorOp,

142kWmmaTensorOp,

143 kInvalid

144 };

145

147

149 enum class GemmKind {

150kGemm,

151kBatched,

152kArray,

153kPlanarComplex,

154kPlanarComplexBatched,

155 kInvalid

156 };

157

159

161 template <typename T> T from_string(std::string const &);

162

164 char const *to_string(OperationKind type, bool pretty = false);

165

167 template <> OperationKind from_string<OperationKind>(std::string const &str);

168

170 char const *to_string(NumericTypeID type, bool pretty = false);

171

173 template <> NumericTypeID from_string<NumericTypeID>(std::string const &str);

174

176 int sizeof_bits(NumericTypeID type);

177

179 bool is_complex_type(NumericTypeID type);

180

182 NumericTypeID get_real_type(NumericTypeID type);

183

185 bool is_integer_type(NumericTypeID type);

186

188 bool is_signed_type(NumericTypeID type);

189

191 bool is_signed_integer(NumericTypeID type);

192

194 bool is_unsigned_integer(NumericTypeID type);

195

197 bool is_float_type(NumericTypeID type);

198

200 char const *to_string(Status status, bool pretty = false);

201

203 char const *to_string(LayoutTypeID layout, bool pretty = false);

204

206 template <> LayoutTypeID from_string<LayoutTypeID>(std::string const &str);

207

209 int get_layout_stride_rank(LayoutTypeID layout_id);

210

212 char const *to_string(OpcodeClassID type, bool pretty = false);

213

215 template <>

216 OpcodeClassID from_string<OpcodeClassID>(std::string const &str);

217

219 std::string lexical_cast(int64_t int_value);

220

222 bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string const &str);

223

225 std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type);

226

228 bool cast_from_int64(std::vector<uint8_t> &bytes, NumericTypeID type, int64_t src);

229

231 bool cast_from_uint64(std::vector<uint8_t> &bytes, NumericTypeID type, uint64_t src);

232

234 bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double src);

235

237

238 struct MathInstructionDescription {

239

241cutlass::gemm::GemmCoord instruction_shape;

242

244NumericTypeID element_accumulator;

245

247OpcodeClassID opcode_class;

248

249//

250// Methods

251//

252

253MathInstructionDescription(

254cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(),

255NumericTypeID element_accumulator = NumericTypeID::kInvalid,

256OpcodeClassID opcode_class = OpcodeClassID::kInvalid

257 ):

258 instruction_shape(instruction_shape), element_accumulator(element_accumulator), opcode_class(opcode_class) {}

259

260 };

261

263 struct TileDescription {

264

266cutlass::gemm::GemmCoord threadblock_shape;

267

269int threadblock_stages;

270

272cutlass::gemm::GemmCoord warp_count;

273

275MathInstructionDescription math_instruction;

276

278int minimum_compute_capability;

279

281int maximum_compute_capability;

282

283//

284// Methods

285//

286

287TileDescription(

288cutlass::gemm::GemmCoord threadblock_shape = cutlass::gemm::GemmCoord(),

289int threadblock_stages = 0,

290cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(),

291MathInstructionDescription math_instruction = MathInstructionDescription(),

292int minimum_compute_capability = 0,

293int maximum_compute_capability = 0

294 ):

295 threadblock_shape(threadblock_shape),

296 threadblock_stages(threadblock_stages),

297 warp_count(warp_count),

298 math_instruction(math_instruction),

299 minimum_compute_capability(minimum_compute_capability),

300 maximum_compute_capability(maximum_compute_capability) { }

301 };

302

304 struct OperationDescription {

305

307char const * name;

308

310OperationKind kind;

311

313TileDescription tile_description;

314

315//

316// Methods

317//

318OperationDescription(

319char const * name = "unknown",

320OperationKind kind = OperationKind::kInvalid,

321TileDescription const & tile_description = TileDescription()

322 ):

323 name(name), kind(kind), tile_description(tile_description) { }

324 };

325

327 struct TensorDescription {

328

330NumericTypeID element;

331

333LayoutTypeID layout;

334

336int alignment;

337

339int log_extent_range;

340

342int log_stride_range;

343

344//

345// Methods

346//

347TensorDescription(

348NumericTypeID element = NumericTypeID::kInvalid,

349LayoutTypeID layout = LayoutTypeID::kInvalid,

350int alignment = 1,

351int log_extent_range = 24,

352int log_stride_range = 24

353 ):

354 element(element),

355 layout(layout),

356 alignment(alignment),

357 log_extent_range(log_extent_range),

358 log_stride_range(log_stride_range) { }

359 };

360

362

364 struct GemmDescription : public OperationDescription {

365

367GemmKind gemm_kind;

368

370TensorDescription A;

371

373TensorDescription B;

374

376TensorDescription C;

377

379NumericTypeID element_epilogue;

380

382SplitKMode split_k_mode;

383

385ComplexTransform transform_A;

386

388ComplexTransform transform_B;

389

390//

391// Methods

392//

393

394GemmDescription(

395GemmKind gemm_kind = GemmKind::kGemm,

396TensorDescription const &A = TensorDescription(),

397TensorDescription const &B = TensorDescription(),

398TensorDescription const &C = TensorDescription(),

399NumericTypeID element_epilogue = NumericTypeID::kInvalid,

400SplitKMode split_k_mode = SplitKMode::kNone,

401ComplexTransform transform_A = ComplexTransform::kNone,

402ComplexTransform transform_B = ComplexTransform::kNone

403 ):

404 gemm_kind(gemm_kind),

405 A(A),

406 B(B),

407 C(C),

408 element_epilogue(element_epilogue),

409 split_k_mode(split_k_mode),

410 transform_A(transform_A),

411 transform_B(transform_B) {}

412 };

413

416

418 class Operation {

419 public:

420

421virtual ~Operation() { }

422

423virtual OperationDescription const & description() const = 0;

424

425virtual Status can_implement(

426void const *configuration,

427void const *arguments) const = 0;

428

429virtual uint64_t get_host_workspace_size(

430void const *configuration) const = 0;

431

432virtual uint64_t get_device_workspace_size(

433void const *configuration) const = 0;

434

435virtual Status initialize(

436void const *configuration,

437void *host_workspace,

438void *device_workspace,

439 cudaStream_t stream = nullptr) const = 0;

440

441virtual Status run(

442void const *arguments,

443void *host_workspace,

444void *device_workspace = nullptr,

445 cudaStream_t stream = nullptr) const = 0;

446 };

447

449

451 //

452 // OperationKind: Gemm

453 // GemmKind: Gemm

454 //

455 struct GemmConfiguration {

456

458gemm::GemmCoord problem_size;

459

461 int64_t lda;

462

464 int64_t ldb;

465

467 int64_t ldc;

468

470 int64_t ldd;

471

473int split_k_slices;

474 };

475

477 struct GemmArguments {

478

480void const *A;

481

483void const *B;

484

486void const *C;

487

489void *D;

490

492void const *alpha;

493

495void const *beta;

496

498ScalarPointerMode pointer_mode;

499 };

500

502

504 //

505 // OperationKind: Gemm

506 // GemmKind: Batched

507

508 struct GemmBatchedConfiguration {

509

511gemm::GemmCoord problem_size;

512

514 int64_t lda;

515

517 int64_t ldb;

518

520 int64_t ldc;

521

523 int64_t ldd;

524

526 int64_t batch_stride_A;

527

529 int64_t batch_stride_B;

530

532 int64_t batch_stride_C;

533

535 int64_t batch_stride_D;

536

538int batch_count;

539 };

540

542 using GemmBatchedArguments = GemmArguments;

543

545

547 //

548 // OperationKind: Gemm

549 // GemmKind: Array

550

551 struct GemmArrayConfiguration {

552

553gemm::GemmCoord problem_size;

554

555 int64_t const *lda;

556 int64_t const *ldb;

557 int64_t const *ldc;

558 int64_t const *ldd;

559

560int batch_count;

561 };

562

564 struct GemmArrayArguments {

565void const * const *A;

566void const * const *B;

567void const * const *C;

568void * const *D;

569void const *alpha;

570void const *beta;

571ScalarPointerMode pointer_mode;

572 };

573

575

577 //

578 // OperationKind: Gemm

579 // GemmKind: Planar complex

580

581 struct GemmPlanarComplexConfiguration {

582

583gemm::GemmCoord problem_size;

584

585 int64_t lda;

586 int64_t ldb;

587 int64_t ldc;

588 int64_t ldd;

589

590 int64_t imag_stride_A;

591 int64_t imag_stride_B;

592 int64_t imag_stride_C;

593 int64_t imag_stride_D;

594 };

595

596 using GemmPlanarComplexArgments = GemmArguments;

597

599

601 //

602 // OperationKind: Gemm

603 // GemmKind: Planar complex batched

604 //

605 struct GemmPlanarComplexBatchedConfiguration {

606

607gemm::GemmCoord problem_size;

608

609 int64_t lda;

610 int64_t ldb;

611 int64_t ldc;

612 int64_t ldd;

613

614 int64_t imag_stride_A;

615 int64_t imag_stride_B;

616 int64_t imag_stride_C;

617 int64_t imag_stride_D;

618

619 int64_t batched_stride_A;

620 int64_t batched_stride_B;

621 int64_t batched_stride_C;

622 int64_t batched_stride_D;

623 };

624

625 using GemmPlanarComplexBatchedArguments = GemmArguments;

626

628

629 } // namespace library

630 } // namespace cutlass

631

cutlass::library::GemmPlanarComplexBatchedConfiguration::lda

int64_t lda

Definition: library.h:609

cutlass::library::NumericTypeID::kCS8

cutlass::library::TensorDescription::alignment

int alignment

Alignment restriction on pointers, strides, and extents.

Definition: library.h:336

cutlass::library::NumericTypeID::kCU32

cutlass::library::GemmArrayArguments::A

void const *const * A

Definition: library.h:565

cutlass::library::GemmKind::kPlanarComplexBatched

cutlass::library::NumericTypeID::kCS16

cutlass::library::OpcodeClassID::kWmmaTensorOp

cutlass::library::Operation::~Operation

virtual ~Operation()

Definition: library.h:421

cutlass::library::OperationDescription

High-level description of an operation.

Definition: library.h:304

cutlass::library::GemmKind::kPlanarComplex

cutlass::library::LayoutTypeID::kColumnMajor

cutlass::library::to_string

char const * to_string(OperationKind type, bool pretty=false)

Converts an OperationKind enumerant to a string.

cutlass

Definition: aligned_buffer.h:35

cutlass::library::is_complex_type

bool is_complex_type(NumericTypeID type)

Returns true if the numeric type is a complex data type or false if real-valued.

cutlass::library::GemmArrayArguments::D

void *const * D

Definition: library.h:568

cutlass::library::TensorDescription::layout

LayoutTypeID layout

Enumerant identifying the layout function for the tensor.

Definition: library.h:333

cutlass::library::GemmDescription::gemm_kind

GemmKind gemm_kind

Indicates the kind of GEMM performed.

Definition: library.h:367

cutlass::library::GemmPlanarComplexConfiguration::ldc

int64_t ldc

Definition: library.h:587

cutlass::library::GemmArguments

Arguments for GEMM.

Definition: library.h:477

cutlass::library::GemmKind::kArray

cutlass::library::GemmArrayConfiguration::batch_count

int batch_count

Definition: library.h:560

cutlass::library::ComplexTransform

ComplexTransform

Enumerated type describing a transformation on a complex value.

Definition: library.h:111

cutlass::library::LayoutTypeID::kColumnMajorInterleavedK16

cutlass::library::GemmArrayArguments::C

void const *const * C

Definition: library.h:567

cutlass::library::GemmPlanarComplexConfiguration::problem_size

gemm::GemmCoord problem_size

Definition: library.h:583

cutlass::library::GemmArrayConfiguration

Configuration for batched GEMM in which multiple matrix products are computed.

Definition: library.h:551

cutlass::library::is_signed_integer

bool is_signed_integer(NumericTypeID type)

Returns true if numeric type is a signed integer.

cutlass::library::GemmKind

GemmKind

Enumeration indicating what kind of GEMM operation to perform.

Definition: library.h:149

cutlass::library::NumericTypeID::kCF64

cutlass::library::NumericTypeID::kCU64

cutlass::library::get_real_type

NumericTypeID get_real_type(NumericTypeID type)

Returns the real-valued type underlying a type (only different from 'type' if complex) ...

cutlass::library::from_string< OperationKind >

OperationKind from_string< OperationKind >(std::string const &str)

Parses an OperationKind enumerant from a string.

cutlass::gemm::GemmCoord

Definition: include/cutlass/gemm/gemm.h:94

cutlass::library::get_layout_stride_rank

int get_layout_stride_rank(LayoutTypeID layout_id)

Returns the rank of a layout's stride based on the LayoutTypeID.

cutlass::library::GemmBatchedConfiguration::ldb

int64_t ldb

Leading dimension of B matrix.

Definition: library.h:517

cutlass::library::NumericTypeID::kCF32

cutlass::library::NumericTypeID::kCU8

cutlass::library::ComplexTransform::kNone

cutlass::library::GemmArrayConfiguration::ldc

int64_t const * ldc

Definition: library.h:557

cutlass::library::OpcodeClassID::kTensorOp

cutlass::library::NumericTypeID::kS64

cutlass::library::GemmPlanarComplexBatchedConfiguration::batched_stride_B

int64_t batched_stride_B

Definition: library.h:620

cutlass::library::GemmPlanarComplexConfiguration

Complex valued GEMM in which real and imaginary parts are separated by a stride.

Definition: library.h:581

cutlass::library::TensorDescription::log_stride_range

int log_stride_range

log2() of the maximum value each relevant stride may have

Definition: library.h:342

cutlass::library::ComplexTransform::kConjugate

cutlass::library::NumericTypeID::kB1

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::library::GemmDescription::transform_A

ComplexTransform transform_A

Transformation on A operand.

Definition: library.h:385

cutlass::library::NumericTypeID::kS16

cutlass::library::SplitKMode::kSerial

cutlass::library::GemmPlanarComplexBatchedConfiguration::imag_stride_B

int64_t imag_stride_B

Definition: library.h:615

cutlass::library::GemmDescription::GemmDescription

GemmDescription(GemmKind gemm_kind=GemmKind::kGemm, TensorDescription const &A=TensorDescription(), TensorDescription const &B=TensorDescription(), TensorDescription const &C=TensorDescription(), NumericTypeID element_epilogue=NumericTypeID::kInvalid, SplitKMode split_k_mode=SplitKMode::kNone, ComplexTransform transform_A=ComplexTransform::kNone, ComplexTransform transform_B=ComplexTransform::kNone)

Definition: library.h:394

cutlass::library::sizeof_bits

int sizeof_bits(NumericTypeID type)

Returns the size of a data type in bits.

cutlass::library::Operation

Base class for all device-wide operations.

Definition: library.h:418

cutlass::library::GemmPlanarComplexConfiguration::imag_stride_A

int64_t imag_stride_A

Definition: library.h:590

cutlass::library::NumericTypeID::kCS64

cutlass::library::from_string< NumericTypeID >

NumericTypeID from_string< NumericTypeID >(std::string const &str)

Parses a NumericType enumerant from a string.

cutlass::library::LayoutTypeID

LayoutTypeID

Layout type identifier.

Definition: library.h:63

cutlass::library::OpcodeClassID

OpcodeClassID

Indicates the classification of the math instruction.

Definition: library.h:139

cutlass::library::GemmPlanarComplexBatchedConfiguration::ldc

int64_t ldc

Definition: library.h:611

cutlass::library::lexical_cast

std::string lexical_cast(int64_t int_value)

Lexical cast from int64_t to string.

cutlass::library::GemmArguments::pointer_mode

ScalarPointerMode pointer_mode

Enumerant indicating whether alpha/beta point to host or device memory.

Definition: library.h:498

cutlass::library::GemmArrayConfiguration::ldb

int64_t const * ldb

Definition: library.h:556

cutlass::library::GemmArrayConfiguration::problem_size

gemm::GemmCoord problem_size

Definition: library.h:553

cutlass::library::GemmPlanarComplexBatchedConfiguration::batched_stride_C

int64_t batched_stride_C

Definition: library.h:621

cutlass::library::OperationDescription::OperationDescription

OperationDescription(char const *name="unknown", OperationKind kind=OperationKind::kInvalid, TileDescription const &tile_description=TileDescription())

Definition: library.h:318

cutlass::library::TileDescription::maximum_compute_capability

int maximum_compute_capability

Maximum compute capability (e.g. 70, 75) of a device eligible to run the operation.

Definition: library.h:281

cutlass::library::NumericTypeID::kS8

cutlass::library::GemmConfiguration

Configuration for basic GEMM operations.

Definition: library.h:455

cutlass::library::GemmPlanarComplexBatchedConfiguration::imag_stride_D

int64_t imag_stride_D

Definition: library.h:617

cutlass::library::MathInstructionDescription

Definition: library.h:238

cutlass::library::GemmArguments::B

void const * B

Pointer to B matrix.

Definition: library.h:483

cutlass::library::GemmPlanarComplexBatchedConfiguration::imag_stride_A

int64_t imag_stride_A

Definition: library.h:614

cutlass::library::NumericTypeID::kU16

cutlass::library::LayoutTypeID::kRowMajorInterleavedK16

cutlass::library::GemmDescription::A

TensorDescription A

Describes the A operand.

Definition: library.h:370

cutlass::library::TileDescription

Structure describing the tiled structure of a GEMM-like computation.

Definition: library.h:263

cutlass::library::GemmConfiguration::split_k_slices

int split_k_slices

Number of partitions of K dimension.

Definition: library.h:473

cutlass::library::GemmPlanarComplexConfiguration::imag_stride_B

int64_t imag_stride_B

Definition: library.h:591

cutlass::library::from_string< OpcodeClassID >

OpcodeClassID from_string< OpcodeClassID >(std::string const &str)

Parses an OpcodeClassID enumerant from a string.

cutlass::library::NumericTypeID::kU8

cutlass::library::GemmArguments::A

void const * A

Pointer to A matrix.

Definition: library.h:480

tensor.h

Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...

cutlass::library::GemmConfiguration::ldd

int64_t ldd

Leading dimension of D matrix.

Definition: library.h:470

cutlass::library::GemmDescription::transform_B

ComplexTransform transform_B

Transformation on B operand.

Definition: library.h:388

cutlass::library::GemmPlanarComplexBatchedConfiguration::ldb

int64_t ldb

Definition: library.h:610

cutlass::library::is_signed_type

bool is_signed_type(NumericTypeID type)

Returns true if numeric type is signed.

cutlass::library::NumericTypeID::kCU4

cutlass::library::GemmDescription::element_epilogue

NumericTypeID element_epilogue

Describes the data type of the scalars passed to the epilogue.

Definition: library.h:379

cutlass::library::GemmArrayConfiguration::lda

int64_t const * lda

Definition: library.h:555

cutlass::library::SplitKMode::kParallel

cutlass::library::GemmArrayConfiguration::ldd

int64_t const * ldd

Definition: library.h:558

cutlass::library::TileDescription::minimum_compute_capability

int minimum_compute_capability

Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation.

Definition: library.h:278

cutlass::library::GemmPlanarComplexConfiguration::ldd

int64_t ldd

Definition: library.h:588

cutlass::library::GemmBatchedConfiguration::batch_stride_C

int64_t batch_stride_C

Stride between instances of the C matrix in memory.

Definition: library.h:532

cutlass::library::GemmArrayArguments::B

void const *const * B

Definition: library.h:566

cutlass::library::NumericTypeID

NumericTypeID

Numeric data type.

Definition: library.h:77

cutlass::library::GemmPlanarComplexConfiguration::lda

int64_t lda

Definition: library.h:585

cutlass::library::TileDescription::warp_count

cutlass::gemm::GemmCoord warp_count

Number of warps in each logical dimension.

Definition: library.h:272

cutlass::library::GemmConfiguration::lda

int64_t lda

Leading dimension of A matrix.

Definition: library.h:461

cutlass::library::is_float_type

bool is_float_type(NumericTypeID type)

Returns true if numeric type is floating-point type.

cutlass::library::LayoutTypeID::kColumnMajorInterleavedK4

cutlass::library::TileDescription::TileDescription

TileDescription(cutlass::gemm::GemmCoord threadblock_shape=cutlass::gemm::GemmCoord(), int threadblock_stages=0, cutlass::gemm::GemmCoord warp_count=cutlass::gemm::GemmCoord(), MathInstructionDescription math_instruction=MathInstructionDescription(), int minimum_compute_capability=0, int maximum_compute_capability=0)

Definition: library.h:287

cutlass::library::cast_from_double

bool cast_from_double(std::vector< uint8_t > &bytes, NumericTypeID type, double src)

Casts from a real value represented as a double to the destination type. Returns true if successful...

cutlass::library::NumericTypeID::kU32

cutlass::library::MathInstructionDescription::element_accumulator

NumericTypeID element_accumulator

Describes the data type of the internal accumulator.

Definition: library.h:244

tensor_coord.h

Defines a canonical coordinate for rank=4 tensors offering named indices.

cutlass::library::GemmDescription::B

TensorDescription B

Describes the B operand.

Definition: library.h:373

cutlass::library::LayoutTypeID::kInvalid

cutlass::library::GemmArrayArguments::alpha

void const * alpha

Definition: library.h:569

cutlass::library::NumericTypeID::kU4

cutlass::library::GemmArguments::beta

void const * beta

Host or device pointer to beta scalar.

Definition: library.h:495

cutlass::library::LayoutTypeID::kTensorNCHW

cutlass::library::GemmBatchedConfiguration::ldd

int64_t ldd

Leading dimension of D matrix.

Definition: library.h:523

cutlass::library::MathInstructionDescription::opcode_class

OpcodeClassID opcode_class

Classification of math instruction.

Definition: library.h:247

cutlass::library::NumericTypeID::kCF16

cutlass::library::GemmBatchedConfiguration::problem_size

gemm::GemmCoord problem_size

GEMM problem size.

Definition: library.h:511

cutlass::library::GemmConfiguration::ldc

int64_t ldc

Leading dimension of C matrix.

Definition: library.h:467

cutlass::library::LayoutTypeID::kTensorNHWC

cutlass::library::ScalarPointerMode::kHost

cutlass::library::GemmArguments::D

void * D

Pointer to D matrix.

Definition: library.h:489

cutlass::library::cast_from_uint64

bool cast_from_uint64(std::vector< uint8_t > &bytes, NumericTypeID type, uint64_t src)

Casts from an unsigned int64 to the destination type. Returns true if successful. ...

cutlass::library::GemmDescription::C

TensorDescription C

Describes the source and destination matrices.

Definition: library.h:376

cutlass::library::GemmPlanarComplexConfiguration::ldb

int64_t ldb

Definition: library.h:586

cutlass::library::GemmPlanarComplexBatchedConfiguration::imag_stride_C

int64_t imag_stride_C

Definition: library.h:616

cutlass::library::GemmPlanarComplexBatchedConfiguration

Batched complex valued GEMM in which real and imaginary parts are separated by a stride.

Definition: library.h:605

cutlass::library::OperationKind::kGemm

cutlass::library::GemmPlanarComplexBatchedConfiguration::batched_stride_D

int64_t batched_stride_D

Definition: library.h:622

cutlass::library::GemmBatchedConfiguration

Configuration for batched GEMM in which multiple matrix products are computed.

Definition: library.h:508

cutlass::library::GemmKind::kBatched

cutlass::library::LayoutTypeID::kRowMajor

cutlass::library::GemmBatchedConfiguration::batch_stride_A

int64_t batch_stride_A

Stride between instances of the A matrix in memory.

Definition: library.h:526

cutlass::library::is_integer_type

bool is_integer_type(NumericTypeID type)

Returns true if numeric type is integer.

cutlass::library::ScalarPointerMode

ScalarPointerMode

Enumeration indicating whether scalars are in host or device memory.

Definition: library.h:123

cutlass::library::TensorDescription::element

NumericTypeID element

Numeric type of an individual element.

Definition: library.h:330

cutlass::library::NumericTypeID::kCU16

cutlass::library::GemmBatchedConfiguration::batch_count

int batch_count

Number of GEMMs in batch.

Definition: library.h:538

cutlass::library::GemmArguments::C

void const * C

Pointer to C matrix.

Definition: library.h:486

cutlass::library::from_string

T from_string(std::string const &)

Lexical cast from string.

cutlass::library::TileDescription::threadblock_stages

int threadblock_stages

Describes the number of pipeline stages in the threadblock-scoped mainloop.

Definition: library.h:269

matrix_coord.h

Defines a canonical coordinate for rank=2 matrices offering named indices.

cutlass::library::GemmPlanarComplexConfiguration::imag_stride_D

int64_t imag_stride_D

Definition: library.h:593

cutlass::library::NumericTypeID::kS4

cutlass::library::NumericTypeID::kS32

cutlass::library::from_string< LayoutTypeID >

LayoutTypeID from_string< LayoutTypeID >(std::string const &str)

Parses a LayoutType enumerant from a string.

cutlass::library::GemmArrayArguments::pointer_mode

ScalarPointerMode pointer_mode

Definition: library.h:571

cutlass::library::MathInstructionDescription::MathInstructionDescription

MathInstructionDescription(cutlass::gemm::GemmCoord instruction_shape=cutlass::gemm::GemmCoord(), NumericTypeID element_accumulator=NumericTypeID::kInvalid, OpcodeClassID opcode_class=OpcodeClassID::kInvalid)

Definition: library.h:253

cutlass::library::GemmPlanarComplexBatchedConfiguration::batched_stride_A

int64_t batched_stride_A

Definition: library.h:619

cutlass::library::GemmDescription

Description of all GEMM computations.

Definition: library.h:364

cutlass::library::GemmBatchedConfiguration::lda

int64_t lda

Leading dimension of A matrix.

Definition: library.h:514

cutlass::library::GemmConfiguration::problem_size

gemm::GemmCoord problem_size

GEMM problem size.

Definition: library.h:458

cutlass::library::GemmDescription::split_k_mode

SplitKMode split_k_mode

Describes the structure of parallel reductions.

Definition: library.h:382

cutlass::library::cast_from_int64

bool cast_from_int64(std::vector< uint8_t > &bytes, NumericTypeID type, int64_t src)

Casts from a signed int64 to the destination type. Returns true if successful.

cutlass::library::TensorDescription::log_extent_range

int log_extent_range

log2() of the maximum extent of each dimension

Definition: library.h:339

cutlass::library::OperationDescription::name

char const * name

Unique identifier describing the operation.

Definition: library.h:307

cutlass::library::GemmBatchedConfiguration::batch_stride_B

int64_t batch_stride_B

Stride between instances of the B matrix in memory.

Definition: library.h:529

cutlass::library::MathInstructionDescription::instruction_shape

cutlass::gemm::GemmCoord instruction_shape

Shape of the target math instruction.

Definition: library.h:241

cutlass::library::GemmArrayArguments::beta

void const * beta

Definition: library.h:570

cutlass::library::OperationDescription::tile_description

TileDescription tile_description

Describes the tiled structure of a GEMM-like computation.

Definition: library.h:313

cutlass::library::NumericTypeID::kF16

cutlass::library::NumericTypeID::kF32

cutlass::library::GemmBatchedConfiguration::ldc

int64_t ldc

Leading dimension of C matrix.

Definition: library.h:520

cutlass::library::TensorDescription

Structure describing the properties of a tensor.

Definition: library.h:327

cutlass::library::GemmConfiguration::ldb

int64_t ldb

Leading dimension of B matrix.

Definition: library.h:464

cutlass::library::GemmPlanarComplexBatchedConfiguration::problem_size

gemm::GemmCoord problem_size

Definition: library.h:607

cutlass::library::is_unsigned_integer

bool is_unsigned_integer(NumericTypeID type)

Returns true if numeric type is an unsigned integer.

cutlass::library::GemmArrayArguments

Arguments for GEMM - used by all the GEMM operations.

Definition: library.h:564

cutlass::library::NumericTypeID::kCS32

cutlass::library::LayoutTypeID::kUnknown

cutlass::library::OpcodeClassID::kSimt

cutlass::library::OperationKind

OperationKind

Enumeration indicating the kind of operation.

Definition: library.h:117

cutlass::library::GemmArguments::alpha

void const * alpha

Host or device pointer to alpha scalar.

Definition: library.h:492

cutlass::library::SplitKMode::kParallelSerial

cutlass::library::OperationDescription::kind

OperationKind kind

Kind of operation.

Definition: library.h:310

cutlass::library::TileDescription::threadblock_shape

cutlass::gemm::GemmCoord threadblock_shape

Describes the shape of a threadblock (in elements)

Definition: library.h:266

cutlass::library::GemmPlanarComplexBatchedConfiguration::ldd

int64_t ldd

Definition: library.h:612

cutlass::library::NumericTypeID::kVoid

cutlass::library::ScalarPointerMode::kDevice

cutlass::library::NumericTypeID::kF64

cutlass::library::GemmPlanarComplexConfiguration::imag_stride_C

int64_t imag_stride_C

Definition: library.h:592

cutlass::library::TileDescription::math_instruction

MathInstructionDescription math_instruction

Core math instruction.

Definition: library.h:275

cutlass.h

Basic include for CUTLASS.

cutlass::library::SplitKMode

SplitKMode

Describes how reductions are performed across threadblocks.

Definition: library.h:130

cutlass::library::NumericTypeID::kU64

cutlass::library::NumericTypeID::kCS4

cutlass::Status

Status

Status code returned by CUTLASS operations.

Definition: cutlass.h:39

cutlass::library::GemmBatchedConfiguration::batch_stride_D

int64_t batch_stride_D

Stride between instances of the D matrix in memory.

Definition: library.h:535

cutlass::library::LayoutTypeID::kRowMajorInterleavedK4

cutlass::library::TensorDescription::TensorDescription

TensorDescription(NumericTypeID element=NumericTypeID::kInvalid, LayoutTypeID layout=LayoutTypeID::kInvalid, int alignment=1, int log_extent_range=24, int log_stride_range=24)

Definition: library.h:347


Generated by 1.8.11