Back to Cutlass

CUTLASS: mma_sm60.h Source File

docs/gemm_2thread_2mma__sm60_8h_source.html

4.4.292.7 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

gemm/thread/mma_sm60.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include "cutlass/cutlass.h"

32 #include "cutlass/tensor_ref.h"

33 #include "cutlass/layout/matrix.h"

34 #include "cutlass/gemm/gemm.h"

35 #include "cutlass/gemm/thread/mma.h"

36 #include "cutlass/functional.h"

37 #include "cutlass/reduction/thread/reduce.h"

38

40

41 namespace cutlass {

42 namespace gemm {

43 namespace thread {

44

46

47 namespace detail {

48

/// Thread-level matrix product built from packed half-precision multiply-adds
/// (HFMA2). Only the partial specializations in this file are defined.
///
/// The trailing boolean picks the formulation:
///   true  - outer product: each instruction updates two adjacent outputs
///           along M (column-major C) or along N (row-major C);
///   false - inner product: pairs along K are accumulated, then the two lanes
///           are summed.
template <
  /// Size of the thread-level problem (provides kM, kN, kK, kMK, kKN, kMN)
  typename Shape,
  /// Layout of the A operand
  typename LayoutA,
  /// Layout of the B operand
  typename LayoutB,
  /// Layout of the C / D operand
  typename LayoutC,
  /// Selects the outer-product (true) or inner-product (false) kernels
  bool kUseOuterProduct
>
struct Mma_HFMA2;

67

68

70 // Specialization for NNN //

72

73 template <typename Shape>

74 struct Mma_HFMA2 <

75 Shape,

76 layout::ColumnMajor,

77layout::ColumnMajor,

78layout::ColumnMajor,

79 true

80 > {

81

82static_assert(

83 !(Shape::kM % 2),

84"Mma_HFMA2 requires the M dimension to be divisible by 2."

85 );

86

88using FragmentA = Array<half_t, Shape::kMK>;

89

91using FragmentB = Array<half_t, Shape::kKN>;

92

94using FragmentC = Array<half_t, Shape::kMN>;

95

96//

97// Methods

98//

99

101CUTLASS_HOST_DEVICE

102void operator()(

103FragmentC & D,

104FragmentA const & A,

105FragmentB const & B,

106FragmentC const & C) {

107

109 D = C;

110

112using Mma = arch::Mma<

113gemm::GemmShape<2,1,1>,

114 1,

115half_t,

116layout::ColumnMajor,

117 half_t,

118 layout::ColumnMajor,

119 half_t,

120 layout::ColumnMajor,

121 arch::OpMultiplyAdd>;

122

123 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);

124 Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);

125 Array<half_t, 1> const *ptr_B = reinterpret_cast<Array<half_t, 1> const *>(&B);

126

127Mma mma;

128

129CUTLASS_PRAGMA_UNROLL

130for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){

131

132CUTLASS_PRAGMA_UNROLL

133for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){

134

135CUTLASS_PRAGMA_UNROLL

136for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){

137

138 Array<half_t, 2> tmp;

139 Array<half_t, 2> *ptr_tmp = &tmp;

140 ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m];

141

142 mma(

143 tmp,

144 ptr_A[k*Shape::kM/2 + m],

145 ptr_B[n*Shape::kK + k],

146 tmp);

147

148 ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0];

149 }

150 }

151 }

152 }

153 };

154

156 // Specialization for NNT //

158

159 template <typename Shape>

160 struct Mma_HFMA2<

161 Shape,

162 layout::ColumnMajor,

163layout::ColumnMajor,

164layout::RowMajor,

165 true

166 > {

167

168static_assert(

169 !(Shape::kN % 2),

170"Mma_HFMA2 requires the N dimension to be divisible by 2."

171 );

172

174using FragmentA = Array<half_t, Shape::kMK>;

175

177using FragmentB = Array<half_t, Shape::kKN>;

178

180using FragmentC = Array<half_t, Shape::kMN>;

181

182//

183// Methods

184//

185

187CUTLASS_HOST_DEVICE

188void operator()(

189FragmentC & D,

190FragmentA const & A,

191FragmentB const & B,

192FragmentC const & C) {

193

195 D = C;

196

198using Mma = arch::Mma<

199gemm::GemmShape<1,2,1>,

200 1,

201half_t,

202layout::ColumnMajor,

203 half_t,

204 layout::ColumnMajor,

205 half_t,

206layout::RowMajor,

207 arch::OpMultiplyAdd>;

208

209 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);

210 Array<half_t, 1> const *ptr_A = reinterpret_cast<Array<half_t, 1> const *>(&A);

211 Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);

212

213Mma mma;

214

215CUTLASS_PRAGMA_UNROLL

216for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){

217

218CUTLASS_PRAGMA_UNROLL

219for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){

220

221CUTLASS_PRAGMA_UNROLL

222for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){

223

224 Array<half_t, 2> tmp;

225 Array<half_t, 2> *ptr_tmp = &tmp;

226 ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];

227

228 Array<half_t, 2> tmp_B;

229 tmp_B[0] = ptr_B->at(2*n*Shape::kK + k);

230 tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k);

231

232 mma(

233 tmp,

234 ptr_A[k*Shape::kM + m],

235 tmp_B,

236 tmp);

237

238 ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];

239 }

240 }

241 }

242 }

243 };

244

245

247 // Specialization for NTN //

249

250 template <typename Shape>

251 struct Mma_HFMA2 <

252 Shape,

253 layout::ColumnMajor,

254layout::RowMajor,

255layout::ColumnMajor,

256 true

257 > {

258

259static_assert(

260 !(Shape::kM % 2),

261"Mma_HFMA2 requires the GEMM M dimension to be divisible by 2."

262 );

263

265using FragmentA = Array<half_t, Shape::kMK>;

266

268using FragmentB = Array<half_t, Shape::kKN>;

269

271using FragmentC = Array<half_t, Shape::kMN>;

272

273//

274// Methods

275//

276

278CUTLASS_HOST_DEVICE

279void operator()(

280FragmentC & D,

281FragmentA const & A,

282FragmentB const & B,

283FragmentC const & C) {

284

286 D = C;

287

288using Mma = arch::Mma<

289gemm::GemmShape<2,1,1>,

290 1,

291half_t,

292layout::ColumnMajor,

293 half_t,

294layout::RowMajor,

295 half_t,

296 layout::ColumnMajor,

297 arch::OpMultiplyAdd>;

298

299 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);

300 Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);

301 Array<half_t, 1> const *ptr_B = reinterpret_cast<Array<half_t, 1> const *>(&B);

302

303Mma mma;

304

305CUTLASS_PRAGMA_UNROLL

306for (int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) {

307

308CUTLASS_PRAGMA_UNROLL

309for (int m = 0; m < Shape::kM / Mma::Shape::kM; ++m) {

310

311CUTLASS_PRAGMA_UNROLL

312for (int n = 0; n < Shape::kN / Mma::Shape::kN; ++n) {

313

314 Array<half_t, 2> tmp;

315 Array<half_t, 2> *ptr_tmp = &tmp;

316

317 ptr_tmp[0] = ptr_D[m + n * Shape::kM/2];

318

319 mma(

320 tmp,

321 ptr_A[m + k * Shape::kM/2],

322 ptr_B[k * Shape::kN + n],

323 tmp);

324

325 ptr_D[m + n * Shape::kM/2] = ptr_tmp[0];

326 }

327 }

328 }

329 }

330 };

331

333 // Specialization for NTT //

335

336 template <typename Shape>

337 struct Mma_HFMA2<

338 Shape,

339 layout::ColumnMajor,

340layout::RowMajor,

341layout::RowMajor,

342 true

343 > {

344

345static_assert(

346 !(Shape::kN % 2),

347"Mma_HFMA2 requires the N dimension to be divisible by 2."

348 );

349

351using FragmentA = Array<half_t, Shape::kMK>;

352

354using FragmentB = Array<half_t, Shape::kKN>;

355

357using FragmentC = Array<half_t, Shape::kMN>;

358

359//

360// Methods

361//

362

364CUTLASS_HOST_DEVICE

365void operator()(

366FragmentC & D,

367FragmentA const & A,

368FragmentB const & B,

369FragmentC const & C) {

370

372 D = C;

373

375using Mma = arch::Mma<

376gemm::GemmShape<1,2,1>,

377 1,

378half_t,

379layout::ColumnMajor,

380 half_t,

381layout::RowMajor,

382 half_t,

383 layout::RowMajor,

384 arch::OpMultiplyAdd>;

385

386 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);

387 Array<half_t, 1> const *ptr_A = reinterpret_cast<Array<half_t, 1> const *>(&A);

388 Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);

389

390Mma mma;

391

392CUTLASS_PRAGMA_UNROLL

393for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){

394

395CUTLASS_PRAGMA_UNROLL

396for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){

397

398CUTLASS_PRAGMA_UNROLL

399for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){

400

401 Array<half_t, 2> tmp;

402 Array<half_t, 2> *ptr_tmp = &tmp;

403 ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];

404

405 mma(

406 tmp,

407 ptr_A[k*Shape::kM + m],

408 ptr_B[k*Shape::kN/2 + n],

409 tmp);

410

411 ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];

412 }

413 }

414 }

415 }

416 };

417

418

420 // Specialization for TNN //

422

423 template <typename Shape>

424 struct Mma_HFMA2 <

425 Shape,

426 layout::RowMajor,

427layout::ColumnMajor,

428layout::ColumnMajor,

429 true

430 > {

431

432static_assert(

433 !(Shape::kM % 2),

434"Mma_HFMA2 requires the M dimension to be divisible by 2."

435 );

436

438using FragmentA = Array<half_t, Shape::kMK>;

439

441using FragmentB = Array<half_t, Shape::kKN>;

442

444using FragmentC = Array<half_t, Shape::kMN>;

445

446//

447// Methods

448//

449

451CUTLASS_HOST_DEVICE

452void operator()(

453FragmentC & D,

454FragmentA const & A,

455FragmentB const & B,

456FragmentC const & C) {

457

459 D = C;

460

462using Mma = arch::Mma<

463gemm::GemmShape<2,1,1>,

464 1,

465half_t,

466layout::RowMajor,

467 half_t,

468layout::ColumnMajor,

469 half_t,

470 layout::ColumnMajor,

471 arch::OpMultiplyAdd>;

472

473 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);

474 Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);

475 Array<half_t, 1> const *ptr_B = reinterpret_cast<Array<half_t, 1> const *>(&B);

476

477Mma mma;

478

479CUTLASS_PRAGMA_UNROLL

480for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){

481

482CUTLASS_PRAGMA_UNROLL

483for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){

484

485CUTLASS_PRAGMA_UNROLL

486for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){

487

488 Array<half_t, 2> tmp;

489 Array<half_t, 2> *ptr_tmp = &tmp;

490 ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m];

491

492 Array<half_t, 2> tmp_A;

493 tmp_A[0] = ptr_A->at(2*m*Shape::kK + k);

494 tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k);

495

496 mma(

497 tmp,

498 tmp_A,

499 ptr_B[n*Shape::kK + k],

500 tmp);

501

502 ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0];

503 }

504 }

505 }

506 }

507 };

508

510 // Specialization for TNT //

512

513 template <typename Shape>

514 struct Mma_HFMA2 <

515 Shape,

516 layout::RowMajor,

517layout::ColumnMajor,

518layout::RowMajor,

519 true

520 > {

521

522static_assert(

523 !(Shape::kN % 2),

524"Mma_HFMA2 requires the N dimension to be divisible by 2."

525 );

526

528using FragmentA = Array<half_t, Shape::kMK>;

529

531using FragmentB = Array<half_t, Shape::kKN>;

532

534using FragmentC = Array<half_t, Shape::kMN>;

535

536//

537// Methods

538//

539

541CUTLASS_HOST_DEVICE

542void operator()(

543FragmentC & D,

544FragmentA const & A,

545FragmentB const & B,

546FragmentC const & C) {

547

549 D = C;

550

552using Mma = arch::Mma<

553gemm::GemmShape<1,2,1>,

554 1,

555half_t,

556layout::RowMajor,

557 half_t,

558layout::ColumnMajor,

559 half_t,

560 layout::RowMajor,

561 arch::OpMultiplyAdd>;

562

563 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);

564 Array<half_t, 1> const *ptr_A = reinterpret_cast<Array<half_t, 1> const *>(&A);

565 Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);

566

567Mma mma;

568

569CUTLASS_PRAGMA_UNROLL

570for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){

571

572CUTLASS_PRAGMA_UNROLL

573for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){

574

575CUTLASS_PRAGMA_UNROLL

576for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){

577

578 Array<half_t, 2> tmp;

579 Array<half_t, 2> *ptr_tmp = &tmp;

580 ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];

581

582 Array<half_t, 2> tmp_B;

583 tmp_B[0] = ptr_B->at(2*n*Shape::kK + k);

584 tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k);

585

586 mma(

587 tmp,

588 ptr_A[m*Shape::kK + k],

589 tmp_B,

590 tmp);

591

592 ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];

593 }

594 }

595 }

596 }

597 };

598

600 // Specialization for TTN //

602

603 template <typename Shape>

604 struct Mma_HFMA2 <

605 Shape,

606 layout::RowMajor,

607layout::RowMajor,

608layout::ColumnMajor,

609 true

610 > {

611

612static_assert(

613 !(Shape::kM % 2),

614"Mma_HFMA2 requires the M dimension to be divisible by 2."

615 );

616

618using FragmentA = Array<half_t, Shape::kMK>;

619

621using FragmentB = Array<half_t, Shape::kKN>;

622

624using FragmentC = Array<half_t, Shape::kMN>;

625

626//

627// Methods

628//

629

631CUTLASS_HOST_DEVICE

632void operator()(

633FragmentC & D,

634FragmentA const & A,

635FragmentB const & B,

636FragmentC const & C) {

637

639 D = C;

640

642using Mma = arch::Mma<

643gemm::GemmShape<2,1,1>,

644 1,

645half_t,

646layout::RowMajor,

647 half_t,

648 layout::RowMajor,

649 half_t,

650layout::ColumnMajor,

651 arch::OpMultiplyAdd>;

652

653 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);

654 Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);

655 Array<half_t, 1> const *ptr_B = reinterpret_cast<Array<half_t, 1> const *>(&B);

656

657Mma mma;

658

659CUTLASS_PRAGMA_UNROLL

660for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){

661

662CUTLASS_PRAGMA_UNROLL

663for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){

664

665CUTLASS_PRAGMA_UNROLL

666for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){

667

668 Array<half_t, 2> tmp;

669 Array<half_t, 2> *ptr_tmp = &tmp;

670 ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m];

671

672 Array<half_t, 2> tmp_A;

673 tmp_A[0] = ptr_A->at(2*m*Shape::kK + k);

674 tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k);

675

676 mma(

677 tmp,

678 tmp_A,

679 ptr_B[k*Shape::kN + n],

680 tmp);

681

682 ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0];

683 }

684 }

685 }

686 }

687 };

688

689

691 // Specialization for TTT //

693

694 template <typename Shape>

695 struct Mma_HFMA2<

696 Shape,

697 layout::RowMajor,

698layout::RowMajor,

699layout::RowMajor,

700 true

701 > {

702

703static_assert(

704 !(Shape::kN % 2),

705"Mma_HFMA2 requires the N dimension to be divisible by 2."

706 );

707

709using FragmentA = Array<half_t, Shape::kMK>;

710

712using FragmentB = Array<half_t, Shape::kKN>;

713

715using FragmentC = Array<half_t, Shape::kMN>;

716

717//

718// Methods

719//

720

722CUTLASS_HOST_DEVICE

723void operator()(

724FragmentC & D,

725FragmentA const & A,

726FragmentB const & B,

727FragmentC const & C) {

728

730 D = C;

731

733using Mma = arch::Mma<

734gemm::GemmShape<1,2,1>,

735 1,

736half_t,

737layout::RowMajor,

738 half_t,

739 layout::RowMajor,

740 half_t,

741 layout::RowMajor,

742 arch::OpMultiplyAdd>;

743

744 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);

745 Array<half_t, 1> const *ptr_A = reinterpret_cast<Array<half_t, 1> const *>(&A);

746 Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);

747

748Mma mma;

749

750CUTLASS_PRAGMA_UNROLL

751for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){

752

753CUTLASS_PRAGMA_UNROLL

754for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){

755

756CUTLASS_PRAGMA_UNROLL

757for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){

758

759 Array<half_t, 2> tmp;

760 Array<half_t, 2> *ptr_tmp = &tmp;

761 ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];

762

763 mma(

764 tmp,

765 ptr_A[m*Shape::kK + k],

766 ptr_B[k*Shape::kN/2 + n],

767 tmp);

768

769 ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];

770 }

771 }

772 }

773 }

774 };

775

777 // Specialization for TNT + Inner Product or 1x1x2K + LayoutC = T //

779

780 template <typename Shape, typename LayoutA, typename LayoutB>

781 struct Mma_HFMA2<

782 Shape,

783 LayoutA,

784 LayoutB,

785 layout::RowMajor,

786 false

787 > {

788

789static_assert(

790 !(Shape::kK % 2),

791"Mma_HFMA2 requires the K dimension to be divisible by 2."

792 );

793

795using FragmentA = Array<half_t, Shape::kMK>;

796

798using FragmentB = Array<half_t, Shape::kKN>;

799

801using FragmentC = Array<half_t, Shape::kMN>;

802

803//

804// Methods

805//

806

808CUTLASS_HOST_DEVICE

809void operator()(

810FragmentC & D,

811FragmentA const & A,

812FragmentB const & B,

813FragmentC const & C) {

814

816 D = C;

817

819using GemmShape = gemm::GemmShape<1,1,2>;

820

821 Array<half_t, 1> *ptr_D = reinterpret_cast<Array<half_t, 1> *>(&D);

822 Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);

823 Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);

824

825// Inner product is calculated using MACs, followed by final reduction

826multiply_add<Array<half_t, 2>> mac;

827cutlass::reduction::thread::Reduce< plus<half_t>, Array<half_t, 2> > reduce;

828

829CUTLASS_PRAGMA_UNROLL

830for(auto n=0; n < Shape::kN / GemmShape::kN; n++){

831

832CUTLASS_PRAGMA_UNROLL

833for(auto m=0; m < Shape::kM / GemmShape::kM; m++){

834

835 Array<half_t, 2> tmp_C;

836 tmp_C.clear();

837 Array<half_t, 1> *ptr_tmp_C = reinterpret_cast<Array<half_t, 1> *>(&tmp_C);

838 ptr_tmp_C[0] = ptr_D[n*Shape::kM + m];

839

840CUTLASS_PRAGMA_UNROLL

841for(auto k=0; k < Shape::kK / GemmShape::kK; k++){

842 tmp_C = mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C);

843 }

844

845 Array<half_t, 1> res;

846 Array<half_t, 1> *ptr_res = &res;

847 res = reduce(tmp_C);

848

849 ptr_D[m*Shape::kN + n] = ptr_res[0];

850 }

851 }

852 }

853 };

854

856 // Specialization for TNN + Inner Product or 1x1x2K + LayoutC = N //

858

859 template <typename Shape, typename LayoutA, typename LayoutB>

860 struct Mma_HFMA2<

861 Shape,

862 LayoutA,

863 LayoutB,

864 layout::ColumnMajor,

865 false

866 > {

867

868static_assert(

869 !(Shape::kK % 2),

870"Mma_HFMA2 requires the K dimension to be divisible by 2."

871 );

872

874using FragmentA = Array<half_t, Shape::kMK>;

875

877using FragmentB = Array<half_t, Shape::kKN>;

878

880using FragmentC = Array<half_t, Shape::kMN>;

881

882//

883// Methods

884//

885

887CUTLASS_HOST_DEVICE

888void operator()(

889FragmentC & D,

890FragmentA const & A,

891FragmentB const & B,

892FragmentC const & C) {

893

895 D = C;

896

898using GemmShape= gemm::GemmShape<1,1,2>;

899

900 Array<half_t, 1> *ptr_D = reinterpret_cast<Array<half_t, 1> *>(&D);

901 Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);

902 Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);

903

904// Inner product is calculated using MACs, followed by final reduction

905multiply_add<Array<half_t, 2>> mac;

906cutlass::reduction::thread::Reduce< plus<half_t>, Array<half_t, 2> > reduce;

907

908CUTLASS_PRAGMA_UNROLL

909for(auto n=0; n < Shape::kN / GemmShape::kN; n++){

910

911CUTLASS_PRAGMA_UNROLL

912for(auto m=0; m < Shape::kM / GemmShape::kM; m++){

913

914 Array<half_t, 2> tmp_C;

915 tmp_C.clear();

916 Array<half_t, 1> *ptr_tmp_C = reinterpret_cast<Array<half_t, 1> *>(&tmp_C);

917 ptr_tmp_C[0] = ptr_D[n*Shape::kM + m];

918

919CUTLASS_PRAGMA_UNROLL

920for(auto k=0; k < Shape::kK / GemmShape::kK; k++){

921

922 tmp_C = mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C);

923

924 }

925

926 Array<half_t, 1> res;

927 Array<half_t, 1> *ptr_res = &res;

928 res = reduce(tmp_C);

929

930 ptr_D[n*Shape::kM + m] = ptr_res[0];

931 }

932 }

933 }

934 };

935

936 } // namespace detail

937

939

941 template <

943typename Shape_, typename LayoutA, typename LayoutB, typename LayoutC

944 >

[945](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html) struct Mma<

946 Shape_,

947half_t,

948 LayoutA,

949half_t,

950 LayoutB,

951half_t,

952 LayoutC,

953 arch::OpMultiplyAdd

954 > {

955

[957](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a041bfce41e4c95a7a67dc4156173e1f4)using [Shape](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01half t_00_01LayoutA_00_01half t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a041bfce41e4c95a7a67dc4156173e1f4) = Shape_;

958

[960](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#abc237ebaf010ac6a3e91a93830772707)using ElementA = half_t;

961

[963](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a4b52c217fcddfa6f6ec603ed0caff3f0)using ElementB = half_t;

964

[966](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a836cdbd43f3a01a930049af70f8009bd)using ElementC = half_t;

967

[969](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a62084aaf63a7538ba29de4c60d64d133)using [Operator](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01half t_00_01LayoutA_00_01half t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a62084aaf63a7538ba29de4c60d64d133) = arch::OpMultiplyAdd;

970

[972](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a64b2cf33786247c4acd872fb8856abd5)using [FragmentA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01half t_00_01LayoutA_00_01half t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a64b2cf33786247c4acd872fb8856abd5) = Array<ElementA, Shape::kMK>;

973

[975](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a5a00c6305fd345f12f9469b790e99f12)using [FragmentB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01half t_00_01LayoutA_00_01half t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a5a00c6305fd345f12f9469b790e99f12) = Array<ElementB, Shape::kKN>;

976

[978](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#ac5f3a1ad86714fdce4af1e8a4738a4f7)using [FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01half t_00_01LayoutA_00_01half t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#ac5f3a1ad86714fdce4af1e8a4738a4f7) = Array<ElementC, Shape::kMN>;

979

980//

981// Methods

982//

983

985CUTLASS_HOST_DEVICE

[986](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a7eb69f25c0b516fda203957a230df3ee)void [operator()](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01half t_00_01LayoutA_00_01half t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a7eb69f25c0b516fda203957a230df3ee)(

987[FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#ac5f3a1ad86714fdce4af1e8a4738a4f7) & D,

988[FragmentA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a64b2cf33786247c4acd872fb8856abd5) const & A,

989[FragmentB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a5a00c6305fd345f12f9469b790e99f12) const & B,

990[FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#ac5f3a1ad86714fdce4af1e8a4738a4f7) const & C) {

991

992constexpr bool a_row_major = platform::is_same< LayoutA, layout::RowMajor>::value;

993constexpr bool b_column_major = platform::is_same< LayoutB, layout::ColumnMajor>::value;

994constexpr bool c_row_major = platform::is_same< LayoutC, layout::RowMajor>::value;

995constexpr bool c_column_major = platform::is_same< LayoutC, layout::ColumnMajor>::value;

996

997constexpr bool m_mod2 = !(Shape::kM % 2);

998constexpr bool n_mod2 = !(Shape::kN % 2);

999constexpr bool k_mod2 = !(Shape::kK % 2);

1000

1001// HFMA based MMA optimizations are of 2 types :

1002// 1. Inner product

1003// 2. Outer product

1004// It is chosen based on LayoutC (for outer product gemm) or

1005// Using LayoutA and LayoutB or shape=1x1x2K (for inner product gemms)

1006// If all fails, we choose the generic MMA

1007constexpr bool use_outer_prod = (c_column_major && m_mod2) || (c_row_major && n_mod2);

1008constexpr bool use_inner_prod = (a_row_major && b_column_major && k_mod2) || (Shape::kM==1 && Shape::kN==1 && k_mod2);

1009constexpr bool use_optimized = (use_outer_prod || use_inner_prod);

1010

1011typename platform::conditional< use_optimized,

1012detail::Mma_HFMA2<Shape, LayoutA, LayoutB, LayoutC, use_outer_prod>,

1013MmaGeneric <Shape, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator>

1014 >::type mma;

1015

1016 mma(D, A, B, C);

1017

1018 }

1019 };

1020

1022

1023 namespace detail {

1024

1026template <

1027typename LayoutA,

1029typename LayoutB>

[1030](structcutlass_1_1gemm_1_1thread_1_1detail_1_1EnableMma Crow SM60.html)struct [EnableMma_Crow_SM60](structcutlass_1_1gemm_1_1thread_1_1detail_1_1EnableMma Crow SM60.html) {

1031

[1032](structcutlass_1_1gemm_1_1thread_1_1detail_1_1EnableMma Crow SM60.html#a8ec734b2126bd5147abafee8a3b7be70)static bool const kIsConventionalLayout =

1033 (platform::is_same<LayoutA, layout::RowMajor>::value ||

1034platform::is_same<LayoutA, layout::ColumnMajor>::value) &&

1035 (platform::is_same<LayoutB, layout::RowMajor>::value ||

1036platform::is_same<LayoutB, layout::ColumnMajor>::value);

1037

[1038](structcutlass_1_1gemm_1_1thread_1_1detail_1_1EnableMma Crow SM60.html#a2efb4c6abab3bfc29c0d58df8ccc0fd3)static bool const value = kIsConventionalLayout;

1039 };

1040 };

1041

1043

1045 template <

1047typename Shape_,

1048typename LayoutA_,

1049typename LayoutB_

1050 >

[1051](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html) struct Mma<

1052 Shape_,

1053half_t,

1054 LayoutA_,

1055half_t,

1056 LayoutB_,

1057half_t,

1058 layout::RowMajor,

1059 arch::OpMultiplyAdd,

1060 typename platform::enable_if<detail::EnableMma_Crow_SM60<

1061 LayoutA_,

1062 LayoutB_

1063 >::value>::type>{

1064

[1065](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a951f25ff3bb7a76bac1f867ee21c657f)using [Shape](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a951f25ff3bb7a76bac1f867ee21c657f) = Shape_;

[1066](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a7ffe7f427ffce1c269587417e4fed240)using ElementA = half_t;

[1067](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a0975c18cc4a9d376011858c6dbf740d0)using LayoutA = LayoutA_;

[1068](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a962acba07bc680b70ee1b08732d2516f)using ElementB = half_t;

[1069](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a28b637c1f311310a27b39c44e89e698e)using LayoutB = LayoutB_;

[1070](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#aff9afb3fc630bd0bdb35de1b402c65fa)using ElementC = half_t;

[1071](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a397dfb5a622d1ebe47177825194a03a9)using LayoutC = layout::RowMajor;

[1072](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#af96ae215c5f273447ed44baa1315ffcf)using [Operator](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#af96ae215c5f273447ed44baa1315ffcf) = arch::OpMultiplyAdd;

1073

1074using TransposeMma = Mma<

1075GemmShapeTranspose<Shape>,

1076half_t,

1077typename layout::LayoutTranspose<LayoutB>::type,

1078 half_t,

1079typename layout::LayoutTranspose<LayoutA>::type,

1080 half_t,

1081layout::ColumnMajor,

1082 arch::OpMultiplyAdd,

[1083](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a2acc2e5fb14c4e62ea997d80402730c5)bool>;

1084

[1085](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#adbd6a51a9e477d917f5739230a023524)using [FragmentA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#adbd6a51a9e477d917f5739230a023524) = Array<ElementA, Shape::kMK>;

[1086](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a687f0bd7056ea8ff518bfed26f027e4f)using [FragmentB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a687f0bd7056ea8ff518bfed26f027e4f) = Array<ElementB, Shape::kKN>;

[1087](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a0382879142faec5a4b6190869fc6187f)using [FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a0382879142faec5a4b6190869fc6187f) = Array<ElementC, Shape::kMN>;

1088

1089CUTLASS_HOST_DEVICE

[1090](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a72fad6edd8b029407aad12fb22937358)void [operator()](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a72fad6edd8b029407aad12fb22937358)(

1091[FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a0382879142faec5a4b6190869fc6187f) & D,

1092[FragmentA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#adbd6a51a9e477d917f5739230a023524) const & A,

1093[FragmentB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a687f0bd7056ea8ff518bfed26f027e4f) const & B,

1094[FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a0382879142faec5a4b6190869fc6187f) const & C) {

1095

1096TransposeMma mma;

1097

1098 mma(D, B, A, C);

1099 }

1100 };

1101

1103

1104 } // namespace thread

1105 } // namespace gemm

1106 } // namespace cutlass

1107

cutlass::multiply_add

Fused multiply-add.

Definition: functional.h:92

[cutlass::gemm::thread::detail::EnableMma_Crow_SM60](structcutlass_1_1gemm_1_1thread_1_1detail_1_1EnableMma Crow SM60.html)

Determines whether to enable thread::Mma<> specializations compatible with SM60. ...

Definition: gemm/thread/mma_sm60.h:1030

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, LayoutA, LayoutB, layout::RowMajor, false >::FragmentC

Array< half_t, Shape::kMN > FragmentC

C operand storage.

Definition: gemm/thread/mma_sm60.h:801

cutlass::gemm::GemmShape::kM

static int const kM

Definition: include/cutlass/gemm/gemm.h:58

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::RowMajor, layout::ColumnMajor, true >::FragmentC

Array< half_t, Shape::kMN > FragmentC

C operand storage.

Definition: gemm/thread/mma_sm60.h:271

cutlass

Definition: aligned_buffer.h:35

constexpr

#define constexpr

Definition: platform.h:137

tensor_ref.h

Defines a structure containing strides, bounds, and a pointer to tensor data.

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::ColumnMajor, layout::ColumnMajor, true >::FragmentC

Array< half_t, Shape::kMN > FragmentC

C operand storage.

Definition: gemm/thread/mma_sm60.h:94

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::ColumnMajor, layout::RowMajor, true >::FragmentA

Array< half_t, Shape::kMK > FragmentA

A operand storage.

Definition: gemm/thread/mma_sm60.h:528

cutlass::platform::is_same

std::is_same (false specialization)

Definition: platform.h:394

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, LayoutA, LayoutB, layout::RowMajor, false >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)

Computes a matrix product D = A * B + C.

Definition: gemm/thread/mma_sm60.h:809

cutlass::gemm::thread::detail::Mma_HFMA2

Structure to compute the matrix product for HFMA.

Definition: gemm/thread/mma_sm60.h:66

[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA_, half_t, LayoutB_, half_t, layout::RowMajor, arch::OpMultiplyAdd, typename platform::enable_if< detail::EnableMma_Crow_SM60< LayoutA_, LayoutB_ >::value >::type >::FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a0382879142faec5a4b6190869fc6187f)

Array< ElementC, Shape::kMN > FragmentC

Definition: gemm/thread/mma_sm60.h:1087

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::ColumnMajor, layout::ColumnMajor, true >::FragmentB

Array< half_t, Shape::kKN > FragmentB

B operand storage.

Definition: gemm/thread/mma_sm60.h:441

cutlass::half_t

IEEE half-precision floating-point type.

Definition: half.h:126

gemm.h

Defines common types used for all GEMM-like operators.

[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA_, half_t, LayoutB_, half_t, layout::RowMajor, arch::OpMultiplyAdd, typename platform::enable_if< detail::EnableMma_Crow_SM60< LayoutA_, LayoutB_ >::value >::type >::operator()](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a72fad6edd8b029407aad12fb22937358)

CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)

Definition: gemm/thread/mma_sm60.h:1090

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::ColumnMajor, layout::ColumnMajor, true >::FragmentC

Array< half_t, Shape::kMN > FragmentC

C operand storage.

Definition: gemm/thread/mma_sm60.h:444

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::ColumnMajor, layout::ColumnMajor, true >::FragmentA

Array< half_t, Shape::kMK > FragmentA

A operand storage.

Definition: gemm/thread/mma_sm60.h:438

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::RowMajor, layout::RowMajor, true >::FragmentC

Array< half_t, Shape::kMN > FragmentC

C operand storage.

Definition: gemm/thread/mma_sm60.h:357

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::ColumnMajor, layout::ColumnMajor, true >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)

Computes a matrix product D = A * B + C.

Definition: gemm/thread/mma_sm60.h:102

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::RowMajor, layout::RowMajor, true >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)

Computes a matrix product D = A * B + C.

Definition: gemm/thread/mma_sm60.h:723

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::RowMajor, layout::ColumnMajor, true >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)

Computes a matrix product D = A * B + C.

Definition: gemm/thread/mma_sm60.h:632

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::ColumnMajor, layout::RowMajor, true >::FragmentA

Array< half_t, Shape::kMK > FragmentA

A operand storage.

Definition: gemm/thread/mma_sm60.h:174

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::RowMajor, layout::RowMajor, true >::FragmentB

Array< half_t, Shape::kKN > FragmentB

B operand storage.

Definition: gemm/thread/mma_sm60.h:712

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::RowMajor, layout::ColumnMajor, true >::FragmentC

Array< half_t, Shape::kMN > FragmentC

C operand storage.

Definition: gemm/thread/mma_sm60.h:624

cutlass::layout::ColumnMajor

Mapping function for column-major matrices.

Definition: layout/matrix.h:142

[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA_, half_t, LayoutB_, half_t, layout::RowMajor, arch::OpMultiplyAdd, typename platform::enable_if< detail::EnableMma_Crow_SM60< LayoutA_, LayoutB_ >::value >::type >::Operator](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#af96ae215c5f273447ed44baa1315ffcf)

arch::OpMultiplyAdd Operator

Definition: gemm/thread/mma_sm60.h:1072

cutlass::gemm::GemmShape::kK

static int const kK

Definition: include/cutlass/gemm/gemm.h:60

CUTLASS_PRAGMA_UNROLL

#define CUTLASS_PRAGMA_UNROLL

Definition: cutlass.h:110

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::ColumnMajor, layout::RowMajor, true >::FragmentB

Array< half_t, Shape::kKN > FragmentB

B operand storage.

Definition: gemm/thread/mma_sm60.h:177

[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA, half_t, LayoutB, half_t, LayoutC, arch::OpMultiplyAdd >::FragmentB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a5a00c6305fd345f12f9469b790e99f12)

Array< ElementB, Shape::kKN > FragmentB

B operand storage.

Definition: gemm/thread/mma_sm60.h:975

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::ColumnMajor, layout::ColumnMajor, true >::FragmentB

Array< half_t, Shape::kKN > FragmentB

B operand storage.

Definition: gemm/thread/mma_sm60.h:91

[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA_, half_t, LayoutB_, half_t, layout::RowMajor, arch::OpMultiplyAdd, typename platform::enable_if< detail::EnableMma_Crow_SM60< LayoutA_, LayoutB_ >::value >::type >::FragmentB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a687f0bd7056ea8ff518bfed26f027e4f)

Array< ElementB, Shape::kKN > FragmentB

Definition: gemm/thread/mma_sm60.h:1086

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::RowMajor, layout::RowMajor, true >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)

Computes a matrix product D = A * B + C.

Definition: gemm/thread/mma_sm60.h:365

[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA, half_t, LayoutB, half_t, LayoutC, arch::OpMultiplyAdd >::Operator](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a62084aaf63a7538ba29de4c60d64d133)

arch::OpMultiplyAdd Operator

Underlying mathematical operator.

Definition: gemm/thread/mma_sm60.h:969

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::RowMajor, layout::RowMajor, true >::FragmentA

Array< half_t, Shape::kMK > FragmentA

A operand storage.

Definition: gemm/thread/mma_sm60.h:709

cutlass::layout::LayoutTranspose

Defines transposes of matrix layouts.

Definition: layout/matrix.h:921

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::ColumnMajor, layout::RowMajor, true >::FragmentB

Array< half_t, Shape::kKN > FragmentB

B operand storage.

Definition: gemm/thread/mma_sm60.h:531

cutlass::gemm::thread::MmaGeneric

Template that handles all packed matrix layouts.

Definition: gemm/thread/mma_sm50.h:65

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, LayoutA, LayoutB, layout::ColumnMajor, false >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)

Computes a matrix product D = A * B + C.

Definition: gemm/thread/mma_sm60.h:888

reduce.h

Defines basic thread-level reduction with specializations for Array<T, N>.

[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA, half_t, LayoutB, half_t, LayoutC, arch::OpMultiplyAdd >::FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#ac5f3a1ad86714fdce4af1e8a4738a4f7)

Array< ElementC, Shape::kMN > FragmentC

C operand storage.

Definition: gemm/thread/mma_sm60.h:978

cutlass::platform::enable_if

std::enable_if (true specialization)

Definition: platform.h:315

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::ColumnMajor, layout::RowMajor, true >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)

Computes a matrix product D = A * B + C.

Definition: gemm/thread/mma_sm60.h:188

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

mma.h

Templates exposing architecture support for warp-level multiply-add operations.

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

cutlass::platform::conditional

std::conditional (true specialization)

Definition: platform.h:325

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::RowMajor, layout::ColumnMajor, true >::FragmentA

Array< half_t, Shape::kMK > FragmentA

A operand storage.

Definition: gemm/thread/mma_sm60.h:265

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::ColumnMajor, layout::ColumnMajor, true >::FragmentA

Array< half_t, Shape::kMK > FragmentA

A operand storage.

Definition: gemm/thread/mma_sm60.h:88

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, LayoutA, LayoutB, layout::ColumnMajor, false >::FragmentC

Array< half_t, Shape::kMN > FragmentC

C operand storage.

Definition: gemm/thread/mma_sm60.h:880

[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA, half_t, LayoutB, half_t, LayoutC, arch::OpMultiplyAdd >::operator()](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a7eb69f25c0b516fda203957a230df3ee)

CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)

Computes a matrix product D = A * B + C.

Definition: gemm/thread/mma_sm60.h:986

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, LayoutA, LayoutB, layout::RowMajor, false >::FragmentA

Array< half_t, Shape::kMK > FragmentA

A operand storage.

Definition: gemm/thread/mma_sm60.h:795

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::RowMajor, layout::RowMajor, true >::FragmentA

Array< half_t, Shape::kMK > FragmentA

A operand storage.

Definition: gemm/thread/mma_sm60.h:351

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::ColumnMajor, layout::ColumnMajor, true >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)

Computes a matrix product D = A * B + C.

Definition: gemm/thread/mma_sm60.h:452

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::RowMajor, layout::ColumnMajor, true >::FragmentB

Array< half_t, Shape::kKN > FragmentB

B operand storage.

Definition: gemm/thread/mma_sm60.h:621

[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA, half_t, LayoutB, half_t, LayoutC, arch::OpMultiplyAdd >::Shape](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a041bfce41e4c95a7a67dc4156173e1f4)

Shape_ Shape

Size of the Gemm problem - concept: gemm::GemmShape<>

Definition: gemm/thread/mma_sm60.h:957

cutlass::gemm::thread::Mma

Structure to compute the matrix product.

Definition: gemm/thread/mma.h:66

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, LayoutA, LayoutB, layout::ColumnMajor, false >::FragmentB

Array< half_t, Shape::kKN > FragmentB

B operand storage.

Definition: gemm/thread/mma_sm60.h:877

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::RowMajor, layout::RowMajor, true >::FragmentC

Array< half_t, Shape::kMN > FragmentC

C operand storage.

Definition: gemm/thread/mma_sm60.h:715

matrix.h

Defines layout functions used by TensorRef and derived classes.

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::ColumnMajor, layout::RowMajor, true >::FragmentC

Array< half_t, Shape::kMN > FragmentC

C operand storage.

Definition: gemm/thread/mma_sm60.h:534

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::ColumnMajor, layout::RowMajor, true >::FragmentC

Array< half_t, Shape::kMN > FragmentC

C operand storage.

Definition: gemm/thread/mma_sm60.h:180

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, LayoutA, LayoutB, layout::RowMajor, false >::FragmentB

Array< half_t, Shape::kKN > FragmentB

B operand storage.

Definition: gemm/thread/mma_sm60.h:798

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::RowMajor, layout::ColumnMajor, true >::FragmentA

Array< half_t, Shape::kMK > FragmentA

A operand storage.

Definition: gemm/thread/mma_sm60.h:618

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::RowMajor, layout::ColumnMajor, true >::FragmentB

Array< half_t, Shape::kKN > FragmentB

B operand storage.

Definition: gemm/thread/mma_sm60.h:268

cutlass::arch::Mma

Matrix multiply-add operation.

Definition: arch/mma.h:92

[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA, half_t, LayoutB, half_t, LayoutC, arch::OpMultiplyAdd >::FragmentA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a64b2cf33786247c4acd872fb8856abd5)

Array< ElementA, Shape::kMK > FragmentA

A operand storage.

Definition: gemm/thread/mma_sm60.h:972

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::RowMajor, layout::RowMajor, true >::FragmentB

Array< half_t, Shape::kKN > FragmentB

B operand storage.

Definition: gemm/thread/mma_sm60.h:354

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::RowMajor, layout::ColumnMajor, layout::RowMajor, true >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)

Computes a matrix product D = A * B + C.

Definition: gemm/thread/mma_sm60.h:542

cutlass.h

Basic include for CUTLASS.

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, layout::ColumnMajor, layout::RowMajor, layout::ColumnMajor, true >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)

Computes a matrix product D = A * B + C.

Definition: gemm/thread/mma_sm60.h:279

functional.h

Defines basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...

[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA_, half_t, LayoutB_, half_t, layout::RowMajor, arch::OpMultiplyAdd, typename platform::enable_if< detail::EnableMma_Crow_SM60< LayoutA_, LayoutB_ >::value >::type >::FragmentA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#adbd6a51a9e477d917f5739230a023524)

Array< ElementA, Shape::kMK > FragmentA

Definition: gemm/thread/mma_sm60.h:1085

cutlass::reduction::thread::Reduce

Structure to compute the thread-level reduction.

Definition: reduce.h:43

cutlass::arch::mac

CUTLASS_HOST_DEVICE Array< T, N > mac(Array< T, N > const &a, Array< T, N > const &b, Array< T, N > const &c)

Definition: simd.h:84

[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA_, half_t, LayoutB_, half_t, layout::RowMajor, arch::OpMultiplyAdd, typename platform::enable_if< detail::EnableMma_Crow_SM60< LayoutA_, LayoutB_ >::value >::type >::Shape](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a951f25ff3bb7a76bac1f867ee21c657f)

Shape_ Shape

Definition: gemm/thread/mma_sm60.h:1065

cutlass::gemm::thread::detail::Mma_HFMA2< Shape, LayoutA, LayoutB, layout::ColumnMajor, false >::FragmentA

Array< half_t, Shape::kMK > FragmentA

A operand storage.

Definition: gemm/thread/mma_sm60.h:874

cutlass::gemm::GemmShape::kN

static int const kN

Definition: include/cutlass/gemm/gemm.h:59


Generated by Doxygen 1.8.11