Back to Cutlass

CUTLASS: mma_sm75.h Source File

docs/mma__sm75_8h_source.html

4.4.2175.3 KB
Original Source
<!-- do not remove this div, it is closed by doxygen! -->

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

<!-- end header part --><!-- Generated by Doxygen 1.8.11 --> <input type="text" id="MSearchField" value="Search" accesskey="S" onfocus="searchBox.OnSearchFieldFocus(true)" onblur="searchBox.OnSearchFieldFocus(false)" onkeyup="searchBox.OnSearchFieldChange(event)"> <!-- window showing the filter options --> <!-- iframe showing the search results (closed by default) --> <!-- top -->

mma_sm75.h

<!--header-->

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include <assert.h>

32

33 #include "cutlass/arch/wmma.h"

34

35 #if defined(CUTLASS_ARCH_WMMA_ENABLED)

36 // CUDA Toolkit includes for nvcuda::wmma needed for binarized matrix multiply.

37 #include <mma.h>

38 #include "cutlass/wmma_array.h"

39 #endif

40

41 // CUTLASS includes

42 #include "cutlass/arch/mma.h"

43 #include "cutlass/layout/matrix.h"

44 #include "cutlass/numeric_types.h"

45

47

// CUTLASS_ARCH_MMA_SM75_SUPPORTED is defined (as 1) when compiling with a CUDA toolkit
// version of 10.2 or newer.
// CUTLASS_ARCH_MMA_SM75_ENABLED is defined only under that same toolkit condition AND while
// compiling device code for a __CUDA_ARCH__ of 750 or higher; the Mma specializations below
// use it to choose between issuing PTX and asserting.
#if (__CUDACC_VER_MAJOR__ > 10) || \
    ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 2))

#define CUTLASS_ARCH_MMA_SM75_SUPPORTED 1

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
#define CUTLASS_ARCH_MMA_SM75_ENABLED
#endif  // __CUDA_ARCH__ >= 750

#endif  // CUDA toolkit >= 10.2

56

58

59 namespace cutlass {

60 namespace arch {

61

63 //

64 // Matrix Multiply 1688 - FP16 accumulation

65 //

67

69 template <>

70 struct Mma<

71 gemm::GemmShape<16, 8, 8>,

72 32,

73half_t,

74layout::RowMajor,

75half_t,

76layout::ColumnMajor,

77half_t,

78layout::RowMajor,

79 OpMultiplyAdd> {

80

81using Shape = gemm::GemmShape<16, 8, 8>;

82

83using ElementA = half_t;

84using LayoutA = layout::RowMajor;

85using FragmentA = Array<half_t, 4>;

86

87using ElementB = half_t;

88using LayoutB = layout::ColumnMajor;

89using FragmentB = Array<half_t, 2>;

90

91using ElementC = half_t;

92using LayoutC = layout::RowMajor;

93using FragmentC = Array<half_t, 4>;

94

95using Operator = OpMultiplyAdd;

96

97CUTLASS_HOST_DEVICE

98void operator()(

99FragmentC &d,

100FragmentA const &a,

101FragmentB const &b,

102FragmentC const &c

103 ) const {

104

105 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

106

107unsigned const *A = reinterpret_cast<unsigned const *>(&a);

108unsigned const *B = reinterpret_cast<unsigned const *>(&b);

109unsigned const *C = reinterpret_cast<unsigned const *>(&c);

110unsigned *D = reinterpret_cast<unsigned *>(&d);

111

112asm volatile(

113"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"

114 : "=r"(D[0]), "=r"(D[1])

115 : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1]));

116

117 #else

118 assert(0);

119 #endif

120 }

121 };

122

124 //

125 // Matrix Multiply 1688 - FP32 accumulation

126 //

128

130 template <>

131 struct Mma<

132 gemm::GemmShape<16, 8, 8>,

133 32,

134half_t,

135layout::RowMajor,

136half_t,

137layout::ColumnMajor,

138 float,

139layout::RowMajor,

140 OpMultiplyAdd> {

141

142using Shape = gemm::GemmShape<16, 8, 8>;

143

144using ElementA = half_t;

145using LayoutA = layout::RowMajor;

146using FragmentA = Array<half_t, 4>;

147

148using ElementB = half_t;

149using LayoutB = layout::ColumnMajor;

150using FragmentB = Array<half_t, 2>;

151

152using ElementC = float;

153using LayoutC = layout::RowMajor;

154using FragmentC = Array<float, 4>;

155

156using Operator = OpMultiplyAdd;

157

159CUTLASS_HOST_DEVICE

160void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,

161FragmentC const &c) const {

162

163 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

164

165unsigned const *A = reinterpret_cast<unsigned const *>(&a);

166unsigned const *B = reinterpret_cast<unsigned const *>(&b);

167float const *C = reinterpret_cast<float const *>(&c);

168float *D = reinterpret_cast<float *>(&d);

169

170asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"

171 : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])

172 :

173"r"(A[0]), "r"(A[1]),

174"r"(B[0]),

175"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])

176 );

177

178 #else

179 assert(0);

180 #endif

181 }

182 };

183

185 //

186 // Integer matrix multiply m8n8k16 (8b)

187 //

189

191 template <>

192 struct Mma<

193 gemm::GemmShape<8, 8, 16>,

194 32,

195 int8_t,

196layout::RowMajor,

197 int8_t,

198layout::ColumnMajor,

199 int,

200layout::RowMajor,

201 OpMultiplyAdd> {

202

203using Shape = gemm::GemmShape<8, 8, 16>;

204

205using ElementA = int8_t;

206using LayoutA = layout::RowMajor;

207using FragmentA = Array<int8_t, 4>;

208

209using ElementB = int8_t;

210using LayoutB = layout::ColumnMajor;

211using FragmentB = Array<int8_t, 4>;

212

213using ElementC = int;

214using LayoutC = layout::RowMajor;

215using FragmentC = Array<int, 2>;

216

217using Operator = OpMultiplyAdd;

218

220CUTLASS_HOST_DEVICE

221void operator()(

222FragmentC &d,

223FragmentA const &a,

224FragmentB const &b,

225FragmentC const &c

226 ) const {

227

228 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

229

230unsigned const & A = reinterpret_cast<unsigned const &>(a);

231unsigned const & B = reinterpret_cast<unsigned const &>(b);

232

233int const *C = reinterpret_cast<int const *>(&c);

234int *D = reinterpret_cast<int *>(&d);

235

236asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

237 : "=r"(D[0]), "=r"(D[1])

238 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

239

240 #else

241 assert(0);

242 #endif

243 }

244 };

245

247 template <>

248 struct Mma<

249 gemm::GemmShape<8, 8, 16>,

250 32,

251 uint8_t,

252layout::RowMajor,

253 int8_t,

254layout::ColumnMajor,

255 int,

256layout::RowMajor,

257 OpMultiplyAdd> {

258

259using Shape = gemm::GemmShape<8, 8, 16>;

260

261using ElementA = uint8_t;

262using LayoutA = layout::RowMajor;

263using FragmentA = Array<uint8_t, 4>;

264

265using ElementB = int8_t;

266using LayoutB = layout::ColumnMajor;

267using FragmentB = Array<int8_t, 4>;

268

269using ElementC = int;

270using LayoutC = layout::RowMajor;

271using FragmentC = Array<int, 2>;

272

273using Operator = OpMultiplyAdd;

274

276CUTLASS_HOST_DEVICE

277void operator()(

278FragmentC &d,

279FragmentA const &a,

280FragmentB const &b,

281FragmentC const &c

282 ) const {

283

284 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

285

286unsigned const & A = reinterpret_cast<unsigned const &>(a);

287unsigned const & B = reinterpret_cast<unsigned const &>(b);

288

289int const *C = reinterpret_cast<int const *>(&c);

290int *D = reinterpret_cast<int *>(&d);

291

292asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

293 : "=r"(D[0]), "=r"(D[1])

294 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

295

296 #else

297 assert(0);

298 #endif

299 }

300 };

301

303 template <>

304 struct Mma<

305 gemm::GemmShape<8, 8, 16>,

306 32,

307 int8_t,

308layout::RowMajor,

309 uint8_t,

310layout::ColumnMajor,

311 int,

312layout::RowMajor,

313 OpMultiplyAdd> {

314

315using Shape = gemm::GemmShape<8, 8, 16>;

316

317using ElementA = int8_t;

318using LayoutA = layout::RowMajor;

319using FragmentA = Array<int8_t, 4>;

320

321using ElementB = uint8_t;

322using LayoutB = layout::ColumnMajor;

323using FragmentB = Array<uint8_t, 4>;

324

325using ElementC = int;

326using LayoutC = layout::RowMajor;

327using FragmentC = Array<int, 2>;

328

329using Operator = OpMultiplyAdd;

330

332CUTLASS_HOST_DEVICE

333void operator()(

334FragmentC &d,

335FragmentA const &a,

336FragmentB const &b,

337FragmentC const &c

338 ) const {

339

340 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

341

342unsigned const & A = reinterpret_cast<unsigned const &>(a);

343unsigned const & B = reinterpret_cast<unsigned const &>(b);

344

345int const *C = reinterpret_cast<int const *>(&c);

346int *D = reinterpret_cast<int *>(&d);

347

348asm volatile("mma.sync.aligned.m8n8k16.row.col.s8.u8 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

349 : "=r"(D[0]), "=r"(D[1])

350 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

351

352

353 #else

354 assert(0);

355 #endif

356 }

357 };

358

360 template <>

361 struct Mma<

362 gemm::GemmShape<8, 8, 16>,

363 32,

364 uint8_t,

365layout::RowMajor,

366 uint8_t,

367layout::ColumnMajor,

368 int,

369layout::RowMajor,

370 OpMultiplyAdd> {

371

372using Shape = gemm::GemmShape<8, 8, 16>;

373

374using ElementA = uint8_t;

375using LayoutA = layout::RowMajor;

376using FragmentA = Array<uint8_t, 4>;

377

378using ElementB = uint8_t;

379using LayoutB = layout::ColumnMajor;

380using FragmentB = Array<uint8_t, 4>;

381

382using ElementC = int;

383using LayoutC = layout::RowMajor;

384using FragmentC = Array<int, 2>;

385

386using Operator = OpMultiplyAdd;

387

389CUTLASS_HOST_DEVICE

390void operator()(

391FragmentC &d,

392FragmentA const &a,

393FragmentB const &b,

394FragmentC const &c

395 ) const {

396

397 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

398

399unsigned const & A = reinterpret_cast<unsigned const &>(a);

400unsigned const & B = reinterpret_cast<unsigned const &>(b);

401

402int const *C = reinterpret_cast<int const *>(&c);

403int *D = reinterpret_cast<int *>(&d);

404

405asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

406 : "=r"(D[0]), "=r"(D[1])

407 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

408

409 #else

410 assert(0);

411 #endif

412 }

413 };

414

416 //

417 // Integer matrix multiply (8b) with SATURATE

418 //

420

422 template <>

423 struct Mma<

424 gemm::GemmShape<8,8,16>,

425 32,

426 int8_t,

427layout::RowMajor,

428 int8_t,

429layout::ColumnMajor,

430 int,

431layout::RowMajor,

432 OpMultiplyAddSaturate> {

433

434using Shape = gemm::GemmShape<8,8,16>;

435

436using ElementA = int8_t;

437using LayoutA = layout::RowMajor;

438using FragmentA = Array<int8_t, 4>;

439

440using ElementB = int8_t;

441using LayoutB = layout::ColumnMajor;

442using FragmentB = Array<int8_t, 4>;

443

444using ElementC = int;

445using LayoutC = layout::RowMajor;

446using FragmentC = Array<int, 2>;

447

448using Operator = OpMultiplyAddSaturate;

449

451CUTLASS_HOST_DEVICE

452void operator()(

453FragmentC &d,

454FragmentA const &a,

455FragmentB const &b,

456FragmentC const &c

457 ) const {

458

459 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

460

461unsigned const & A = reinterpret_cast<unsigned const &>(a);

462unsigned const & B = reinterpret_cast<unsigned const &>(b);

463

464int const *C = reinterpret_cast<int const *>(&c);

465int *D = reinterpret_cast<int *>(&d);

466

467asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

468 : "=r"(D[0]), "=r"(D[1])

469 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

470

471 #else

472 assert(0);

473 #endif

474 }

475 };

476

478 template <>

479 struct Mma<

480 gemm::GemmShape<8,8,16>,

481 32,

482 uint8_t,

483layout::RowMajor,

484 int8_t,

485layout::ColumnMajor,

486 int,

487layout::RowMajor,

488 OpMultiplyAddSaturate> {

489

490using Shape = gemm::GemmShape<8,8,16>;

491

492using ElementA = uint8_t;

493using LayoutA = layout::RowMajor;

494using FragmentA = Array<uint8_t, 4>;

495

496using ElementB = int8_t;

497using LayoutB = layout::ColumnMajor;

498using FragmentB = Array<int8_t, 4>;

499

500using ElementC = int;

501using LayoutC = layout::RowMajor;

502using FragmentC = Array<int, 2>;

503

504using Operator = OpMultiplyAddSaturate;

505

507CUTLASS_HOST_DEVICE

508void operator()(

509FragmentC &d,

510FragmentA const &a,

511FragmentB const &b,

512FragmentC const &c

513 ) const {

514

515 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

516

517unsigned const & A = reinterpret_cast<unsigned const &>(a);

518unsigned const & B = reinterpret_cast<unsigned const &>(b);

519

520int const *C = reinterpret_cast<int const *>(&c);

521int *D = reinterpret_cast<int *>(&d);

522

523asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

524 : "=r"(D[0]), "=r"(D[1])

525 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

526

527 #else

528 assert(0);

529 #endif

530 }

531 };

532

534 template <>

535 struct Mma<

536 gemm::GemmShape<8,8,16>,

537 32,

538 int8_t,

539layout::RowMajor,

540 uint8_t,

541layout::ColumnMajor,

542 int,

543layout::RowMajor,

544 OpMultiplyAddSaturate> {

545

546using Shape = gemm::GemmShape<8,8,16>;

547

548using ElementA = int8_t;

549using LayoutA = layout::RowMajor;

550using FragmentA = Array<int8_t, 4>;

551

552using ElementB = uint8_t;

553using LayoutB = layout::ColumnMajor;

554using FragmentB = Array<uint8_t, 4>;

555

556using ElementC = int;

557using LayoutC = layout::RowMajor;

558using FragmentC = Array<int, 2>;

559

560using Operator = OpMultiplyAddSaturate;

561

563CUTLASS_HOST_DEVICE

564void operator()(

565FragmentC &d,

566FragmentA const &a,

567FragmentB const &b,

568FragmentC const &c

569 ) const {

570

571 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

572

573unsigned const & A = reinterpret_cast<unsigned const &>(a);

574unsigned const & B = reinterpret_cast<unsigned const &>(b);

575

576int const *C = reinterpret_cast<int const *>(&c);

577int *D = reinterpret_cast<int *>(&d);

578

579asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

580 : "=r"(D[0]), "=r"(D[1])

581 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

582

583 #else

584 assert(0);

585 #endif

586 }

587 };

588

590 template <>

591 struct Mma<

592 gemm::GemmShape<8,8,16>,

593 32,

594 uint8_t,

595layout::RowMajor,

596 uint8_t,

597layout::ColumnMajor,

598 int,

599layout::RowMajor,

600 OpMultiplyAddSaturate> {

601

602using Shape = gemm::GemmShape<8,8,16>;

603

604using ElementA = uint8_t;

605using LayoutA = layout::RowMajor;

606using FragmentA = Array<uint8_t, 4>;

607

608using ElementB = uint8_t;

609using LayoutB = layout::ColumnMajor;

610using FragmentB = Array<uint8_t, 4>;

611

612using ElementC = int;

613using LayoutC = layout::RowMajor;

614using FragmentC = Array<int, 2>;

615

616using Operator = OpMultiplyAddSaturate;

617

619CUTLASS_HOST_DEVICE

620void operator()(

621FragmentC &d,

622FragmentA const &a,

623FragmentB const &b,

624FragmentC const &c

625 ) const {

626

627 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

628

629unsigned const & A = reinterpret_cast<unsigned const &>(a);

630unsigned const & B = reinterpret_cast<unsigned const &>(b);

631

632int const *C = reinterpret_cast<int const *>(&c);

633int *D = reinterpret_cast<int *>(&d);

634

635asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

636 : "=r"(D[0]), "=r"(D[1])

637 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

638

639 #else

640 assert(0);

641 #endif

642 }

643 };

644

646 //

647 // Integer matrix multiply (4b)

648 //

650

652 template <>

653 struct Mma<

654 gemm::GemmShape<8,8,32>,

655 32,

656int4b_t,

657layout::RowMajor,

658int4b_t,

659layout::ColumnMajor,

660 int,

661layout::RowMajor,

662 OpMultiplyAdd> {

663

664using Shape = gemm::GemmShape<8,8,32>;

665

666using ElementA = int4b_t;

667using LayoutA = layout::RowMajor;

668using FragmentA = Array<int4b_t, 8>;

669

670using ElementB = int4b_t;

671using LayoutB = layout::ColumnMajor;

672using FragmentB = Array<int4b_t, 8>;

673

674using ElementC = int;

675using LayoutC = layout::RowMajor;

676using FragmentC = Array<int, 2>;

677

678using Operator = OpMultiplyAdd;

679

681CUTLASS_HOST_DEVICE

682void operator()(

683FragmentC &d,

684FragmentA const &a,

685FragmentB const &b,

686FragmentC const &c

687 ) const {

688

689 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

690

691unsigned const & A = reinterpret_cast<unsigned const &>(a);

692unsigned const & B = reinterpret_cast<unsigned const &>(b);

693

694int const *C = reinterpret_cast<int const *>(&c);

695int *D = reinterpret_cast<int *>(&d);

696

697asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

698 : "=r"(D[0]), "=r"(D[1])

699 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

700

701 #else

702 assert(0);

703 #endif

704 }

705 };

706

708 template <>

709 struct Mma<

710 gemm::GemmShape<8,8,32>,

711 32,

712uint4b_t,

713layout::RowMajor,

714int4b_t,

715layout::ColumnMajor,

716 int,

717layout::RowMajor,

718 OpMultiplyAdd> {

719

720using Shape = gemm::GemmShape<8,8,32>;

721

722using ElementA = uint4b_t;

723using LayoutA = layout::RowMajor;

724using FragmentA = Array<uint4b_t, 8>;

725

726using ElementB = int4b_t;

727using LayoutB = layout::ColumnMajor;

728using FragmentB = Array<int4b_t, 8>;

729

730using ElementC = int;

731using LayoutC = layout::RowMajor;

732using FragmentC = Array<int, 2>;

733

734using Operator = OpMultiplyAdd;

735

737CUTLASS_HOST_DEVICE

738void operator()(

739FragmentC &d,

740FragmentA const &a,

741FragmentB const &b,

742FragmentC const &c

743 ) const {

744

745 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

746

747unsigned const & A = reinterpret_cast<unsigned const &>(a);

748unsigned const & B = reinterpret_cast<unsigned const &>(b);

749

750int const *C = reinterpret_cast<int const *>(&c);

751int *D = reinterpret_cast<int *>(&d);

752

753asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

754 : "=r"(D[0]), "=r"(D[1])

755 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

756

757 #else

758 assert(0);

759 #endif

760 }

761 };

762

764 template <>

765 struct Mma<

766 gemm::GemmShape<8,8,32>,

767 32,

768int4b_t,

769layout::RowMajor,

770uint4b_t,

771layout::ColumnMajor,

772 int,

773layout::RowMajor,

774 OpMultiplyAdd> {

775

776using Shape = gemm::GemmShape<8,8,32>;

777

778using ElementA = int4b_t;

779using LayoutA = layout::RowMajor;

780using FragmentA = Array<int4b_t, 8>;

781

782using ElementB = uint4b_t;

783using LayoutB = layout::ColumnMajor;

784using FragmentB = Array<uint4b_t, 8>;

785

786using ElementC = int;

787using LayoutC = layout::RowMajor;

788using FragmentC = Array<int, 2>;

789

790using Operator = OpMultiplyAdd;

791

793CUTLASS_HOST_DEVICE

794void operator()(

795FragmentC &d,

796FragmentA const &a,

797FragmentB const &b,

798FragmentC const &c

799 ) const {

800

801 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

802

803unsigned const & A = reinterpret_cast<unsigned const &>(a);

804unsigned const & B = reinterpret_cast<unsigned const &>(b);

805

806int const *C = reinterpret_cast<int const *>(&c);

807int *D = reinterpret_cast<int *>(&d);

808

809asm volatile("_mma.m8n8k32.row.col.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

810 : "=r"(D[0]), "=r"(D[1])

811 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

812

813 #else

814 assert(0);

815 #endif

816 }

817 };

818

820 template <>

821 struct Mma<

822 gemm::GemmShape<8,8,32>,

823 32,

824uint4b_t,

825layout::RowMajor,

826uint4b_t,

827layout::ColumnMajor,

828 int,

829layout::RowMajor,

830 OpMultiplyAdd> {

831

832using Shape = gemm::GemmShape<8,8,32>;

833

834using ElementA = uint4b_t;

835using LayoutA = layout::RowMajor;

836using FragmentA = Array<uint4b_t, 8>;

837

838using ElementB = uint4b_t;

839using LayoutB = layout::ColumnMajor;

840using FragmentB = Array<uint4b_t, 8>;

841

842using ElementC = int;

843using LayoutC = layout::RowMajor;

844using FragmentC = Array<int, 2>;

845

846using Operator = OpMultiplyAdd;

847

849CUTLASS_HOST_DEVICE

850void operator()(

851FragmentC &d,

852FragmentA const &a,

853FragmentB const &b,

854FragmentC const &c

855 ) const {

856

857 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

858

859unsigned const & A = reinterpret_cast<unsigned const &>(a);

860unsigned const & B = reinterpret_cast<unsigned const &>(b);

861

862int const *C = reinterpret_cast<int const *>(&c);

863int *D = reinterpret_cast<int *>(&d);

864

865asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

866 : "=r"(D[0]), "=r"(D[1])

867 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

868

869 #else

870 assert(0);

871 #endif

872 }

873 };

874

876 //

877 // Integer matrix multiply (4b) - SATURATE

878 //

880

882 template <>

883 struct Mma<

884 gemm::GemmShape<8,8,32>,

885 32,

886int4b_t,

887layout::RowMajor,

888int4b_t,

889layout::ColumnMajor,

890 int,

891layout::RowMajor,

892 OpMultiplyAddSaturate> {

893

894using Shape = gemm::GemmShape<8,8,32>;

895

896using ElementA = int4b_t;

897using LayoutA = layout::RowMajor;

898using FragmentA = Array<int4b_t, 8>;

899

900using ElementB = int4b_t;

901using LayoutB = layout::ColumnMajor;

902using FragmentB = Array<int4b_t, 8>;

903

904using ElementC = int;

905using LayoutC = layout::RowMajor;

906using FragmentC = Array<int, 2>;

907

908using Operator = OpMultiplyAddSaturate;

909

911CUTLASS_HOST_DEVICE

912void operator()(

913FragmentC &d,

914FragmentA const &a,

915FragmentB const &b,

916FragmentC const &c

917 ) const {

918

919 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

920

921unsigned const & A = reinterpret_cast<unsigned const &>(a);

922unsigned const & B = reinterpret_cast<unsigned const &>(b);

923

924int const *C = reinterpret_cast<int const *>(&c);

925int *D = reinterpret_cast<int *>(&d);

926

927asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

928 : "=r"(D[0]), "=r"(D[1])

929 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

930

931 #else

932 assert(0);

933 #endif

934 }

935 };

936

938 template <>

939 struct Mma<

940 gemm::GemmShape<8,8,32>,

941 32,

942uint4b_t,

943layout::RowMajor,

944int4b_t,

945layout::ColumnMajor,

946 int,

947layout::RowMajor,

948 OpMultiplyAddSaturate> {

949

950using Shape = gemm::GemmShape<8,8,32>;

951

952using ElementA = uint4b_t;

953using LayoutA = layout::RowMajor;

954using FragmentA = Array<uint4b_t, 8>;

955

956using ElementB = int4b_t;

957using LayoutB = layout::ColumnMajor;

958using FragmentB = Array<int4b_t, 8>;

959

960using ElementC = int;

961using LayoutC = layout::RowMajor;

962using FragmentC = Array<int, 2>;

963

964using Operator = OpMultiplyAddSaturate;

965

967CUTLASS_HOST_DEVICE

968void operator()(

969FragmentC &d,

970FragmentA const &a,

971FragmentB const &b,

972FragmentC const &c

973 ) const {

974

975 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

976

977unsigned const & A = reinterpret_cast<unsigned const &>(a);

978unsigned const & B = reinterpret_cast<unsigned const &>(b);

979

980int const *C = reinterpret_cast<int const *>(&c);

981int *D = reinterpret_cast<int *>(&d);

982

983asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

984 : "=r"(D[0]), "=r"(D[1])

985 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

986

987 #else

988 assert(0);

989 #endif

990 }

991 };

992

994 template <>

995 struct Mma<

996 gemm::GemmShape<8,8,32>,

997 32,

998int4b_t,

999layout::RowMajor,

1000uint4b_t,

1001layout::ColumnMajor,

1002 int,

1003layout::RowMajor,

1004 OpMultiplyAddSaturate> {

1005

1006using Shape = gemm::GemmShape<8,8,32>;

1007

1008using ElementA = int4b_t;

1009using LayoutA = layout::RowMajor;

1010using FragmentA = Array<int4b_t, 8>;

1011

1012using ElementB = uint4b_t;

1013using LayoutB = layout::ColumnMajor;

1014using FragmentB = Array<uint4b_t, 8>;

1015

1016using ElementC = int;

1017using LayoutC = layout::RowMajor;

1018using FragmentC = Array<int, 2>;

1019

1020using Operator = OpMultiplyAddSaturate;

1021

1023CUTLASS_HOST_DEVICE

1024void operator()(

1025FragmentC &d,

1026FragmentA const &a,

1027FragmentB const &b,

1028FragmentC const &c

1029 ) const {

1030

1031 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

1032

1033unsigned const & A = reinterpret_cast<unsigned const &>(a);

1034unsigned const & B = reinterpret_cast<unsigned const &>(b);

1035

1036int const *C = reinterpret_cast<int const *>(&c);

1037int *D = reinterpret_cast<int *>(&d);

1038

1039asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

1040 : "=r"(D[0]), "=r"(D[1])

1041 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

1042

1043 #else

1044 assert(0);

1045 #endif

1046 }

1047 };

1048

1050 template <>

1051 struct Mma<

1052 gemm::GemmShape<8,8,32>,

1053 32,

1054uint4b_t,

1055layout::RowMajor,

1056uint4b_t,

1057layout::ColumnMajor,

1058 int,

1059layout::RowMajor,

1060 OpMultiplyAddSaturate> {

1061

1062using Shape = gemm::GemmShape<8,8,32>;

1063

1064using ElementA = uint4b_t;

1065using LayoutA = layout::RowMajor;

1066using FragmentA = Array<uint4b_t, 8>;

1067

1068using ElementB = uint4b_t;

1069using LayoutB = layout::ColumnMajor;

1070using FragmentB = Array<uint4b_t, 8>;

1071

1072using ElementC = int;

1073using LayoutC = layout::RowMajor;

1074using FragmentC = Array<int, 2>;

1075

1076using Operator = OpMultiplyAddSaturate;

1077

1079CUTLASS_HOST_DEVICE

1080void operator()(

1081FragmentC &d,

1082FragmentA const &a,

1083FragmentB const &b,

1084FragmentC const &c

1085 ) const {

1086

1087 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

1088

1089unsigned const & A = reinterpret_cast<unsigned const &>(a);

1090unsigned const & B = reinterpret_cast<unsigned const &>(b);

1091

1092int const *C = reinterpret_cast<int const *>(&c);

1093int *D = reinterpret_cast<int *>(&d);

1094

1095asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"

1096 : "=r"(D[0]), "=r"(D[1])

1097 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

1098

1099 #else

1100 assert(0);

1101 #endif

1102 }

1103 };

1104

1106 //

1107 // b1 ^ b1 + s32 => s32

1108 //

1110

1112 template <>

1113 struct Mma<

1114 gemm::GemmShape<8,8,128>,

1115 32,

1116uint1b_t,

1117layout::RowMajor,

1118uint1b_t,

1119layout::ColumnMajor,

1120 int,

1121layout::RowMajor,

1122 OpXorPopc> {

1123

1124using Shape = gemm::GemmShape<8,8,128>;

1125

1126using ElementA = uint1b_t;

1127using LayoutA = layout::RowMajor;

1128using FragmentA = Array<uint1b_t, 32>;

1129

1130using ElementB = uint1b_t;

1131using LayoutB = layout::ColumnMajor;

1132using FragmentB = Array<uint1b_t, 32>;

1133

1134using ElementC = int;

1135using LayoutC = layout::RowMajor;

1136using FragmentC = Array<int, 2>;

1137

1138using Operator = OpXorPopc;

1139

1141CUTLASS_HOST_DEVICE

1142void operator()(

1143FragmentC &d,

1144FragmentA const &a,

1145FragmentB const &b,

1146FragmentC const &c

1147 ) const {

1148

1149 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

1150

1151 #if defined(CUTLASS_ARCH_WMMA_ENABLED)

1152using WmmaFragmentA = nvcuda::wmma::fragment<

1153 nvcuda::wmma::matrix_a,

1154 Shape::kM,

1155 Shape::kN,

1156 Shape::kK,

1157 nvcuda::wmma::experimental::precision::b1,

1158 nvcuda::wmma::row_major>;

1159

1160using WmmaFragmentB = nvcuda::wmma::fragment<

1161 nvcuda::wmma::matrix_b,

1162 Shape::kM,

1163 Shape::kN,

1164 Shape::kK,

1165 nvcuda::wmma::experimental::precision::b1,

1166 nvcuda::wmma::col_major>;

1167

1168using WmmaFragmentC = nvcuda::wmma::fragment<

1169 nvcuda::wmma::accumulator,

1170 Shape::kM,

1171 Shape::kN,

1172 Shape::kK,

1173int>;

1174

1175 WmmaFragmentA const & A = reinterpret_cast<WmmaFragmentA const &>(a);

1176 WmmaFragmentB const & B = reinterpret_cast<WmmaFragmentB const &>(b);

1177

1178 WmmaFragmentC const & C = reinterpret_cast<WmmaFragmentC const &>(c);

1179 WmmaFragmentC & D = reinterpret_cast<WmmaFragmentC &>(d);

1180

1181 nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,

1182 nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);

1183 #else

1184

1185 assert(0); // WMMA must be supported to issue binary matrix multiply-accumulate instructions.

1186

1187 #endif // defined(CUTLASS_ARCH_WMMA_ENABLED)

1188

1189 #else

1190 assert(0);

1191 #endif

1192

1193 }

1194 };

1195

1197

1198 } // namespace arch

1199 } // namespace cutlass

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:794

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm75.h:217

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementA

uint8_t ElementA

Definition: mma_sm75.h:492

cutlass::uint4b_t

integer_subbyte< 4, false > uint4b_t

4-bit Unsigned integer type

Definition: integer_subbyte.h:158

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm75.h:734

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentB

Array< uint4b_t, 8 > FragmentB

Definition: mma_sm75.h:1070

wmma_array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< int8_t, 4 > FragmentB

Definition: mma_sm75.h:211

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentB

Array< uint8_t, 4 > FragmentB

Definition: mma_sm75.h:610

cutlass

Definition: aligned_buffer.h:35

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementC

int ElementC

Definition: mma_sm75.h:269

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm75.h:846

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::Operator

OpMultiplyAddSaturate Operator

Definition: mma_sm75.h:616

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementC

int ElementC

Definition: mma_sm75.h:500

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:277

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm75.h:329

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:676

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< uint8_t, 4 > FragmentB

Definition: mma_sm75.h:323

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentB

Array< uint8_t, 4 > FragmentB

Definition: mma_sm75.h:554

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementC

int ElementC

Definition: mma_sm75.h:382

cutlass::uint1b_t

integer_subbyte< 1, false > uint1b_t

1-bit unsigned integer type

Definition: integer_subbyte.h:152

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementC

int ElementC

Definition: mma_sm75.h:1016

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:446

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementB

uint8_t ElementB

Definition: mma_sm75.h:378

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementA

int8_t ElementA

Definition: mma_sm75.h:205

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementA

int8_t ElementA

Definition: mma_sm75.h:548

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< int8_t, 4 > FragmentB

Definition: mma_sm75.h:267

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:968

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:1024

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< int8_t, 4 > FragmentA

Definition: mma_sm75.h:207

cutlass::arch::Mma< gemm::GemmShape< 16, 8, 8 >, 32, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< half_t, 2 > FragmentB

Definition: mma_sm75.h:150

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< int8_t, 4 > FragmentA

Definition: mma_sm75.h:319

cutlass::integer_subbyte

Sub-byte integer type parameterized by bit width and signedness (e.g. uint1b_t, uint4b_t, int4b_t).

Definition: integer_subbyte.h:42

cutlass::half_t

IEEE half-precision floating-point type.

Definition: half.h:126

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementC

int ElementC

Definition: mma_sm75.h:730

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementC

int ElementC

Definition: mma_sm75.h:444

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:564

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementC

int ElementC

Definition: mma_sm75.h:904

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 128 >, 32, uint1b_t, layout::RowMajor, uint1b_t, layout::ColumnMajor, int, layout::RowMajor, OpXorPopc >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes the binary matrix operation selected by OpXorPopc (XOR followed by population count), accumulated with C.

Definition: mma_sm75.h:1142

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< int4b_t, 8 > FragmentA

Definition: mma_sm75.h:780

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 128 >, 32, uint1b_t, layout::RowMajor, uint1b_t, layout::ColumnMajor, int, layout::RowMajor, OpXorPopc >::ElementC

int ElementC

Definition: mma_sm75.h:1134

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementC

int ElementC

Definition: mma_sm75.h:842

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::Operator

OpMultiplyAddSaturate Operator

Definition: mma_sm75.h:964

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< uint8_t, 4 > FragmentB

Definition: mma_sm75.h:380

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::Operator

OpMultiplyAddSaturate Operator

Definition: mma_sm75.h:560

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementB

int8_t ElementB

Definition: mma_sm75.h:496

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:912

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementB

uint8_t ElementB

Definition: mma_sm75.h:608

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:1018

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentB

Array< int4b_t, 8 > FragmentB

Definition: mma_sm75.h:958

cutlass::layout::ColumnMajor

Mapping function for column-major matrices.

Definition: layout/matrix.h:142

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementC

int ElementC

Definition: mma_sm75.h:612

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementB

int8_t ElementB

Definition: mma_sm75.h:209

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:333

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementA

uint8_t ElementA

Definition: mma_sm75.h:261

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentB

Array< int8_t, 4 > FragmentB

Definition: mma_sm75.h:498

cutlass::arch::Mma< gemm::GemmShape< 16, 8, 8 >, 32, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::ElementC

float ElementC

Definition: mma_sm75.h:152

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementB

int8_t ElementB

Definition: mma_sm75.h:440

mma.h

Templates exposing architecture support for multiply-add operations.

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentA

Array< int8_t, 4 > FragmentA

Definition: mma_sm75.h:550

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementB

uint8_t ElementB

Definition: mma_sm75.h:321

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementA

uint8_t ElementA

Definition: mma_sm75.h:374

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:558

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentB

Array< uint4b_t, 8 > FragmentB

Definition: mma_sm75.h:1014

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentA

Array< int4b_t, 8 > FragmentA

Definition: mma_sm75.h:1010

cutlass::arch::Mma< gemm::GemmShape< 16, 8, 8 >, 32, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm75.h:156

cutlass::arch::Mma< gemm::GemmShape< 16, 8, 8 >, 32, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm75.h:95

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< uint4b_t, 8 > FragmentB

Definition: mma_sm75.h:840

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:508

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementC

int ElementC

Definition: mma_sm75.h:556

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementC

int ElementC

Definition: mma_sm75.h:674

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm75.h:273

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:327

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:384

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 128 >, 32, uint1b_t, layout::RowMajor, uint1b_t, layout::ColumnMajor, int, layout::RowMajor, OpXorPopc >::Operator

OpXorPopc Operator

Definition: mma_sm75.h:1138

cutlass::arch::Mma< gemm::GemmShape< 16, 8, 8 >, 32, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< float, 4 > FragmentC

Definition: mma_sm75.h:154

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentA

Array< uint8_t, 4 > FragmentA

Definition: mma_sm75.h:606

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:682

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementA

int8_t ElementA

Definition: mma_sm75.h:436

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:732

cutlass::arch::Mma< gemm::GemmShape< 16, 8, 8 >, 32, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< half_t, 4 > FragmentC

Definition: mma_sm75.h:93

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:215

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:221

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementB

uint8_t ElementB

Definition: mma_sm75.h:552

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentA

Array< int8_t, 4 > FragmentA

Definition: mma_sm75.h:438

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::Operator

OpMultiplyAddSaturate Operator

Definition: mma_sm75.h:1020

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::Operator

OpMultiplyAddSaturate Operator

Definition: mma_sm75.h:1076

cutlass::arch::Mma< gemm::GemmShape< 16, 8, 8 >, 32, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< half_t, 4 > FragmentA

Definition: mma_sm75.h:146

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:1080

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< int4b_t, 8 > FragmentB

Definition: mma_sm75.h:728

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementB

int8_t ElementB

Definition: mma_sm75.h:265

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:502

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementA

int8_t ElementA

Definition: mma_sm75.h:317

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentA

Array< uint4b_t, 8 > FragmentA

Definition: mma_sm75.h:954

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 128 >, 32, uint1b_t, layout::RowMajor, uint1b_t, layout::ColumnMajor, int, layout::RowMajor, OpXorPopc >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:1136

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< int4b_t, 8 > FragmentA

Definition: mma_sm75.h:668

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementC

int ElementC

Definition: mma_sm75.h:1072

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::Operator

OpMultiplyAddSaturate Operator

Definition: mma_sm75.h:908

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< uint4b_t, 8 > FragmentA

Definition: mma_sm75.h:724

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 128 >, 32, uint1b_t, layout::RowMajor, uint1b_t, layout::ColumnMajor, int, layout::RowMajor, OpXorPopc >::FragmentA

Array< uint1b_t, 32 > FragmentA

Definition: mma_sm75.h:1128

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentA

Array< uint4b_t, 8 > FragmentA

Definition: mma_sm75.h:1066

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:962

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< uint8_t, 4 > FragmentA

Definition: mma_sm75.h:376

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< uint4b_t, 8 > FragmentA

Definition: mma_sm75.h:836

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:844

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:906

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm75.h:678

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:620

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:738

cutlass::arch::Mma< gemm::GemmShape< 16, 8, 8 >, 32, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:160

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementA

uint8_t ElementA

Definition: mma_sm75.h:604

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementC

int ElementC

Definition: mma_sm75.h:325

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:850

matrix.h

Defines layout functions used by TensorRef and derived classes.

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:452

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::Operator

OpMultiplyAddSaturate Operator

Definition: mma_sm75.h:448

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentB

Array< int8_t, 4 > FragmentB

Definition: mma_sm75.h:442

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:390

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< uint8_t, 4 > FragmentA

Definition: mma_sm75.h:263

cutlass::arch::Mma< gemm::GemmShape< 16, 8, 8 >, 32, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const

Computes multiply-add.

Definition: mma_sm75.h:98

cutlass::arch::Mma

Matrix multiply-add operation.

Definition: arch/mma.h:92

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentA

Array< int4b_t, 8 > FragmentA

Definition: mma_sm75.h:898

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentB

Array< int4b_t, 8 > FragmentB

Definition: mma_sm75.h:902

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentA

Array< uint8_t, 4 > FragmentA

Definition: mma_sm75.h:494

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm75.h:386

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 128 >, 32, uint1b_t, layout::RowMajor, uint1b_t, layout::ColumnMajor, int, layout::RowMajor, OpXorPopc >::FragmentB

Array< uint1b_t, 32 > FragmentB

Definition: mma_sm75.h:1132

wmma.h

Templates exposing architecture support for warp matrix multiply-add (WMMA) operations.

cutlass::arch::Mma< gemm::GemmShape< 16, 8, 8 >, 32, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< half_t, 2 > FragmentB

Definition: mma_sm75.h:89

cutlass::arch::Mma< gemm::GemmShape< 16, 8, 8 >, 32, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< half_t, 4 > FragmentA

Definition: mma_sm75.h:85

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:1074

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, uint8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:614

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< uint4b_t, 8 > FragmentB

Definition: mma_sm75.h:784

cutlass::int4b_t

integer_subbyte< 4, true > int4b_t

4-bit signed integer type

Definition: integer_subbyte.h:155

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::Operator

OpMultiplyAddSaturate Operator

Definition: mma_sm75.h:504

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< int4b_t, 8 > FragmentB

Definition: mma_sm75.h:672

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, uint8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:271

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementC

int ElementC

Definition: mma_sm75.h:786

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< int, 2 > FragmentC

Definition: mma_sm75.h:788

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, int4b_t, layout::RowMajor, uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm75.h:790

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 16 >, 32, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAdd >::ElementC

int ElementC

Definition: mma_sm75.h:213

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 32 >, 32, uint4b_t, layout::RowMajor, int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate >::ElementC

int ElementC

Definition: mma_sm75.h:960

<!-- fragment --> <!-- contents --><!-- start footer part -->
<address class="footer"><small> Generated by 1.8.11 </small></address>