Back to Cutlass

CUTLASS: mma_simt_tile_iterator.h Source File

docs/mma__simt__tile__iterator_8h_source.html

4.4.2160.7 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

mma_simt_tile_iterator.h

[Go to the documentation of this file.](mma__simt__tile__iterator_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

30 #pragma once

31

32 #include "cutlass/cutlass.h"

33 #include "cutlass/array.h"

34 #include "cutlass/tensor_ref.h"

35 #include "cutlass/matrix_shape.h"

36 #include "cutlass/layout/matrix.h"

37

38 #include "cutlass/gemm/gemm.h"

39 #include "cutlass/gemm/warp/mma_simt_policy.h"

40

42

43 namespace cutlass {

44 namespace gemm {

45 namespace warp {

46

48

/// Iterates over operands to warp-level matrix multiply operations targeting SIMT instructions.
/// Primary template; behavior is supplied by the partial specializations on
/// (Operand, Layout) pairs defined later in this file.
template <
  /// Size of the matrix to load (concept: MatrixShape)
  typename Shape_,
  /// Identifies which operand (A, B, or C) the iterator addresses
  Operand Operand,
  /// Data type of elements
  typename Element_,
  /// Layout of the operand in memory
  typename Layout_,
  /// Warp-level decomposition policy (concept: MmaSimtPolicy)
  typename Policy_,
  /// Number of partitions along the K dimension - used in sliced-K
  int PartitionsK = 1,
  /// Group size along kPartition - used in sliced-K
  int PartitionGroupSize = 1
>
class MmaSimtTileIterator;

70

72

77 template <

79typename Shape_,

81typename Element_,

83typename Policy_,

85int PartitionsK,

87int PartitionGroupSize

88 >

89 class MmaSimtTileIterator<Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize> {

90 public:

91

93using Shape = Shape_;

94

96static Operand const kOperand = Operand::kA;

97

99using Element = Element_;

100

102using Layout = layout::ColumnMajor;

103

105using Policy = Policy_;

106

108using TensorRef = TensorRef<Element, Layout>;

109

111using Index = typename TensorRef::Index;

112

114using LongIndex = typename TensorRef::LongIndex;

115

117using TensorCoord = typename TensorRef::TensorCoord;

118

119//

120// Derived quantities

121//

122

123static_assert(!(Shape::kRow % Policy::WarpShape::kRow),

124"The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension.");

125

126static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");

127static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");

128static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");

129static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");

130

132using ThreadShape = MatrixShape<

133 Shape::kRow / Policy::WarpShape::kRow,

134 Shape::kColumn

135 >;

136

137static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM),

138"Thread-level GEMM must be divisible by Policy::LaneMmaShape.");

139

141using Iterations = MatrixShape<

142 ThreadShape::kRow / Policy::LaneMmaShape::kM,

143 ThreadShape::kColumn

144 >;

145

147using Fragment = Array<Element, ThreadShape::kCount>;

148

149 private:

150

152cutlass::TensorRef<Array<Element, Policy::LaneMmaShape::kM>, layout::ColumnMajor> ref_;

153

154 public:

155

157CUTLASS_HOST_DEVICE

158MmaSimtTileIterator() { }

159

161CUTLASS_HOST_DEVICE

162MmaSimtTileIterator(

163TensorRef ref,

164int lane_id

165 ) {

166

167// compute offset based on thread ID and lane layout

168typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

169

170MatrixCoord lane_offset = lane_layout.inverse(lane_id) *

171MatrixCoord(Policy::LaneMmaShape::kM, 0);

172

173 ref.add_coord_offset(lane_offset);

174

175 ref_.reset(

176reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(ref.data()),

177 ref.stride(0) / Policy::LaneMmaShape::kM);

178 }

179

180

182CUTLASS_HOST_DEVICE

183MmaSimtTileIterator &add_pointer_offset(LongIndex offset) {

184 ref_.add_pointer_offset(offset);

185return *this;

186 }

187

189CUTLASS_HOST_DEVICE

190MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) {

191

192 ref_.add_coord_offset({

193 coord.row() * Shape::kRow / Policy::LaneMmaShape::kM,

194 coord.column() * Shape::kColumn});

195

196return *this;

197 }

198

200CUTLASS_HOST_DEVICE

201MmaSimtTileIterator & operator++() {

202

203 ref_.add_coord_offset({0, Shape::kColumn});

204

205return *this;

206 }

207

209CUTLASS_HOST_DEVICE

210MmaSimtTileIterator & operator--() {

211

212 ref_.add_coord_offset({0, -Shape::kColumn});

213

214return *this;

215 }

216

218CUTLASS_HOST_DEVICE

219void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {

220 Array<Element, Policy::LaneMmaShape::kM> *dst_ptr =

221reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(&frag);

222

223CUTLASS_PRAGMA_UNROLL

224for (int k = 0; k < Iterations::kColumn; ++k) {

225CUTLASS_PRAGMA_UNROLL

226for (int m = 0; m < Iterations::kRow; ++m) {

227 dst_ptr[m + k * Iterations::kRow] =

228 *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM);

229 }

230 }

231 }

233CUTLASS_HOST_DEVICE

234void load(Fragment &frag) const {

235 load_with_pointer_offset(frag, 0);

236 }

237

239CUTLASS_HOST_DEVICE

240void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

241

242 Array<Element, Policy::LaneMmaShape::kM> const *src_ptr =

243reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(&frag);

244

245CUTLASS_PRAGMA_UNROLL

246for (int k = 0; k < Iterations::kN; ++k) {

247CUTLASS_PRAGMA_UNROLL

248for (int m = 0; m < Iterations::kM; ++m) {

249 *(ref_.data() + ref_.offset(m * Policy::WarpShape::kM, k) + pointer_offset / Policy::LaneMmaShape::kM) =

250 src_ptr[m + k * Iterations::kM];

251 }

252 }

253 }

254

256CUTLASS_HOST_DEVICE

257void store(Fragment const &frag) const {

258 store_with_pointer_offset(frag, 0);

259 }

260

268 CUTLASS_DEVICE

269void set_kgroup_index(int k_group) {

270// no operation here

271 }

272 };

273

275

280 template <

282typename Shape_,

284typename Element_,

286typename Policy_,

288int PartitionsK,

290int PartitionGroupSize

291 >

292 class MmaSimtTileIterator<Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize> {

293 public:

294

296using Shape = Shape_;

297

299static Operand const kOperand = Operand::kB;

300

302using Element = Element_;

303

305using Layout = layout::RowMajor;

306

308using Policy = Policy_;

309

311using TensorRef = TensorRef<Element, Layout>;

312

314using Index = typename TensorRef::Index;

315

317using LongIndex = typename TensorRef::LongIndex;

318

320using TensorCoord = typename TensorRef::TensorCoord;

321

322//

323// Derived quantities

324//

325

326static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn),

327"The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension.");

328

329static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");

330static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");

331static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");

332static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");

333

335using ThreadShape = MatrixShape<

336 Shape::kRow,

337 Shape::kColumn / Policy::WarpShape::kColumn

338 >;

339

340static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN),

341"Thread-level GEMM must be divisible by Policy::LaneMmaShape.");

342

344using Iterations = MatrixShape<

345 ThreadShape::kRow,

346 ThreadShape::kColumn / Policy::LaneMmaShape::kN

347 >;

348

350using Fragment = Array<Element, ThreadShape::kCount>;

351

352 private:

353

355cutlass::TensorRef<Array<Element, Policy::LaneMmaShape::kN>, layout::RowMajor> ref_;

356

357

358 public:

359

361CUTLASS_HOST_DEVICE

362MmaSimtTileIterator() { }

363

365CUTLASS_HOST_DEVICE

366MmaSimtTileIterator(

367TensorRef ref,

368int lane_id

369 ) {

370

371// compute offset based on thread ID and lane layout

372typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

373

374MatrixCoord lane_offset = lane_layout.inverse(lane_id) *

375MatrixCoord(0, Policy::LaneMmaShape::kN);

376

377 ref.add_coord_offset(lane_offset);

378

379 ref_.reset(

380reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(ref.data()),

381 ref.stride(0) / Policy::LaneMmaShape::kN);

382 }

383

385CUTLASS_HOST_DEVICE

386MmaSimtTileIterator &add_pointer_offset(LongIndex offset) {

387 ref_.add_pointer_offset(offset);

388return *this;

389 }

390

392CUTLASS_HOST_DEVICE

393MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) {

394

395 ref_.add_coord_offset({

396 coord.row() * Shape::kRow,

397 coord.column() * Shape::kColumn / Policy::LaneMmaShape::kN});

398

399return *this;

400 }

401

403CUTLASS_HOST_DEVICE

404MmaSimtTileIterator & operator++() {

405

406 ref_.add_coord_offset({Shape::kRow, 0});

407

408return *this;

409 }

410

412CUTLASS_HOST_DEVICE

413MmaSimtTileIterator & operator--() {

414

415 ref_.add_coord_offset({-Shape::kRow, 0});

416

417return *this;

418 }

419

421CUTLASS_HOST_DEVICE

422void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {

423

424 Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =

425reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag);

426

427CUTLASS_PRAGMA_UNROLL

428for (int k = 0; k < Iterations::kRow; ++k) {

429CUTLASS_PRAGMA_UNROLL

430for (int n = 0; n < Iterations::kColumn; ++n) {

431 dst_ptr[n + k * Iterations::kColumn] =

432 *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN);

433 }

434 }

435 }

436

438CUTLASS_HOST_DEVICE

439void load(Fragment &frag) const {

440 load_with_pointer_offset(frag, 0);

441 }

442

444CUTLASS_HOST_DEVICE

445void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

446

447 Array<Element, Policy::LaneMmaShape::kN> const *src_ptr =

448reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag);

449

450CUTLASS_PRAGMA_UNROLL

451for (int k = 0; k < Iterations::kM; ++k) {

452CUTLASS_PRAGMA_UNROLL

453for (int n = 0; n < Iterations::kN; ++n) {

454 *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kN}) + pointer_offset / Policy::LaneMmaShape::kN) =

455 src_ptr[n + k * Iterations::kN];

456 }

457 }

458 }

459

461CUTLASS_HOST_DEVICE

462void store(Fragment const &frag, Index pointer_offset) const {

463 store_with_pointer_offset(frag, 0);

464 }

465

473 CUTLASS_DEVICE

474void set_kgroup_index(int k_group) {

475// no operation here

476 }

477 };

478

480

485 template <

487typename Shape_,

489typename Element_,

491typename Policy_

492 >

493 class MmaSimtTileIterator<Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_> {

494 public:

495

497using Shape = Shape_;

498

500static Operand const kOperand = Operand::kC;

501

503using Element = Element_;

504

506using Layout = layout::ColumnMajor;

507

509using Policy = Policy_;

510

512using TensorRef = TensorRef<Element, Layout>;

513

515using Index = typename TensorRef::Index;

516

518using LongIndex = typename TensorRef::LongIndex;

519

521using TensorCoord = typename TensorRef::TensorCoord;

522

523//

524// Derived quantities

525//

526

527static_assert(

528 (!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)),

529"Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");

530

531static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");

532static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");

533static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");

534static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");

535static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");

536static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");

537

539using ThreadShape = MatrixShape<

540 Shape::kRow / Policy::WarpShape::kRow,

541 Shape::kColumn / Policy::WarpShape::kColumn

542 >;

543

544static_assert(

545 (!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)),

546"Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");

547

549using Iterations = MatrixShape<

550 ThreadShape::kRow / Policy::LaneMmaShape::kM,

551 ThreadShape::kColumn / Policy::LaneMmaShape::kN

552 >;

553

554using Delta = MatrixShape<

555 Policy::WarpShape::kRow * Policy::LaneMmaShape::kM,

556 Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN

557 >;

558

560using Fragment = Array<Element, ThreadShape::kCount>;

561

562 private:

563

564TensorRef ref_;

565

566 public:

567

569CUTLASS_HOST_DEVICE

570MmaSimtTileIterator() { }

571

573CUTLASS_HOST_DEVICE

574MmaSimtTileIterator(

575TensorRef const &ref,

576int lane_id

577 ):

578 ref_(ref) {

579

580// compute offset based on thread ID and lane layout

581typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

582

583MatrixCoord lane_offset = lane_layout.inverse(lane_id) *

584MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);

585

586 ref_.add_coord_offset(lane_offset);

587 }

588

590CUTLASS_HOST_DEVICE

591MmaSimtTileIterator &add_pointer_offset(LongIndex offset) {

592 ref_.add_pointer_offset(offset);

593return *this;

594 }

595

597CUTLASS_HOST_DEVICE

598MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) {

599

600 ref_.add_coord_offset({

601 coord.row() * Shape::kRow,

602 coord.column() * Shape::kColumn});

603

604return *this;

605 }

606

608CUTLASS_HOST_DEVICE

609MmaSimtTileIterator & operator++() {

610

611 ref_.add_coord_offset({Shape::kRow, 0});

612

613return *this;

614 }

615

617CUTLASS_HOST_DEVICE

618MmaSimtTileIterator & operator--() {

619

620 ref_.add_coord_offset({-Shape::kRow, 0});

621

622return *this;

623 }

624

626CUTLASS_HOST_DEVICE

627void load_with_pointer_offset(

628Fragment &frag,

629Index pointer_offset) const {

630

631CUTLASS_PRAGMA_UNROLL

632for (int mma_n = 0; mma_n < Iterations::kN; ++mma_n) {

633CUTLASS_PRAGMA_UNROLL

634for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {

635

636 Array<Element, Policy::LaneMmaShape::kM> const *src_ptr =

637reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> const *>(

638 ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kN + n}));

639

640CUTLASS_PRAGMA_UNROLL

641for (int mma_m = 0; mma_m < Iterations::kM; ++mma_m) {

642

643 Array<Element, Policy::LaneMmaShape::kM> *dst_ptr =

644reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(&frag) +

645 mma_m + Iterations::kM * (n + mma_n * Policy::LaneMmaShape::kN);

646

647 *dst_ptr = src_ptr[mma_m * Policy::WarpShape::kM];

648 }

649 }

650 }

651 }

652

654CUTLASS_HOST_DEVICE

655void load(Fragment &frag) const {

656 load_with_pointer_offset(frag, 0);

657 }

658

660CUTLASS_HOST_DEVICE

661void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

662

663CUTLASS_PRAGMA_UNROLL

664for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {

665CUTLASS_PRAGMA_UNROLL

666for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {

667

668 Array<Element, Policy::LaneMmaShape::kM> *dst_ptr=

669reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(

670 ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kColumn + n}));

671

672CUTLASS_PRAGMA_UNROLL

673for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {

674

675 Array<Element, Policy::LaneMmaShape::kM> const *src_ptr =

676reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> const *>(&frag) +

677 mma_m + Iterations::kRow * (n + mma_n * Policy::LaneMmaShape::kN);

678

679 dst_ptr[mma_m * Policy::WarpShape::kRow] = *src_ptr;

680 }

681 }

682 }

683 }

685CUTLASS_HOST_DEVICE

686void store(Fragment const &frag) const {

687 store_with_pointer_offset(frag, 0);

688 }

689 };

690

692

697 template <

699typename Shape_,

701typename Element_,

703typename Policy_

704 >

705 class MmaSimtTileIterator<Shape_, Operand::kC, Element_, layout::RowMajor, Policy_> {

706 public:

707

709using Shape = Shape_;

710

712static Operand const kOperand = Operand::kC;

713

715using Element = Element_;

716

718using Layout = layout::RowMajor;

719

721using Policy = Policy_;

722

724using TensorRef = TensorRef<Element, Layout>;

725

727using Index = typename TensorRef::Index;

728

730using LongIndex = typename TensorRef::LongIndex;

731

733using TensorCoord = typename TensorRef::TensorCoord;

734

735//

736// Derived quantities

737//

738

739static_assert(

740 (!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)),

741"Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");

742

743static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");

744static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");

745static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");

746static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");

747static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");

748static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");

749

751using ThreadShape = MatrixShape<

752 Shape::kRow / Policy::WarpShape::kRow,

753 Shape::kColumn / Policy::WarpShape::kColumn

754 >;

755

756static_assert(

757 (!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)),

758"Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");

759

761using Iterations = MatrixShape<

762 ThreadShape::kRow / Policy::LaneMmaShape::kM,

763 ThreadShape::kColumn / Policy::LaneMmaShape::kN

764 >;

765

766using Delta = MatrixShape<

767 Policy::WarpShape::kRow * Policy::LaneMmaShape::kM,

768 Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN

769 >;

770

772using Fragment = Array<Element, ThreadShape::kCount>;

773

774 private:

775

776TensorRef ref_;

777

778 public:

779

781CUTLASS_HOST_DEVICE

782MmaSimtTileIterator() { }

783

785CUTLASS_HOST_DEVICE

786MmaSimtTileIterator(

787TensorRef const &ref,

788int lane_id

789 ):

790 ref_(ref) {

791

792// compute offset based on thread ID and lane layout

793typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

794

795MatrixCoord lane_offset = lane_layout.inverse(lane_id) *

796MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);

797

798 ref_.add_coord_offset(lane_offset);

799 }

800

802CUTLASS_HOST_DEVICE

803MmaSimtTileIterator &add_pointer_offset(LongIndex offset) {

804 ref_.add_pointer_offset(offset);

805return *this;

806 }

807

809CUTLASS_HOST_DEVICE

810MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) {

811

812 ref_.add_coord_offset({

813 coord.row() * Shape::kRow,

814 coord.column() * Shape::kColumn});

815

816return *this;

817 }

818

820CUTLASS_HOST_DEVICE

821MmaSimtTileIterator & operator++() {

822

823 ref_.add_coord_offset({Shape::kRow, 0});

824

825return *this;

826 }

827

829CUTLASS_HOST_DEVICE

830MmaSimtTileIterator & operator--() {

831

832 ref_.add_coord_offset({-Shape::kRow, 0});

833

834return *this;

835 }

836

838CUTLASS_HOST_DEVICE

839void load_with_pointer_offset(

840Fragment &frag,

841Index pointer_offset) const {

842

843CUTLASS_PRAGMA_UNROLL

844for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {

845CUTLASS_PRAGMA_UNROLL

846for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {

847

848 Array<Element, Policy::LaneMmaShape::kN> const *src_ptr =

849reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(

850 ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0}));

851

852CUTLASS_PRAGMA_UNROLL

853for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {

854

855 Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =

856reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag) +

857 mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM);

858

859 *dst_ptr = src_ptr[mma_n * Policy::WarpShape::kColumn];

860 }

861 }

862 }

863 }

864

866CUTLASS_HOST_DEVICE

867void load(Fragment &frag) const {

868 load_with_pointer_offset(frag, 0);

869 }

870

872CUTLASS_HOST_DEVICE

873void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

874

875CUTLASS_PRAGMA_UNROLL

876for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {

877CUTLASS_PRAGMA_UNROLL

878for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {

879

880 Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =

881reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(

882 ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0}));

883

884CUTLASS_PRAGMA_UNROLL

885for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {

886

887 Array<Element, Policy::LaneMmaShape::kN> const *src_ptr =

888reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(&frag) +

889 mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM);

890

891 dst_ptr[mma_n * Policy::WarpShape::kColumn] = *src_ptr;

892 }

893 }

894 }

895 }

896

898CUTLASS_HOST_DEVICE

899void store(Fragment const &frag) const {

900 store_with_pointer_offset(frag, 0);

901 }

902 };

903

905

907

912 template <

914typename Shape_,

916typename Element_,

918typename Policy_,

920int PartitionsK,

922int PartitionGroupSize

923 >

924 class MmaSimtTileIterator<Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved<4>, Policy_, PartitionsK, PartitionGroupSize> {

925 public:

926

928using Shape = Shape_;

929

931static Operand const kOperand = Operand::kA;

932

934using Element = Element_;

935

937using Layout = layout::ColumnMajorInterleaved<4> ;

938

940using Policy = Policy_;

941

943using TensorRef = TensorRef<Element, Layout>;

944

946using Index = typename TensorRef::Index;

947

949using LongIndex = typename TensorRef::LongIndex;

950

952using TensorCoord = typename TensorRef::TensorCoord;

953

955static const int kInterleave = 4;

956

958static const int kPartitionsK = PartitionsK;

959

961static const int kGroupPerTile = PartitionGroupSize / Shape::kColumn;

962

963//

964// Derived quantities

965//

966

967static_assert(!(Shape::kRow % Policy::WarpShape::kRow),

968"The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension.");

969

970static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");

971static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");

972static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");

973static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");

974

976using ThreadShape = MatrixShape<

977 Shape::kRow / Policy::WarpShape::kRow,

978 Shape::kColumn

979 >;

980

981static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM) && !(ThreadShape::kColumn % Policy::LaneMmaShape::kK),

982"Thread-level GEMM must be divisible by Policy::LaneMmaShape.");

983

985using Iterations = MatrixShape<

986 ThreadShape::kRow / Policy::LaneMmaShape::kM,

987 ThreadShape::kColumn / Policy::LaneMmaShape::kK

988 >;

989

991using Fragment = Array<Element, ThreadShape::kCount>;

992

993 private:

994

996cutlass::TensorRef<Array<Element, Policy::LaneMmaShape::kMK>, layout::ColumnMajorInterleaved<4>> ref_;

997

999int k_group_idx_;

1000

1001 public:

1002CUTLASS_HOST_DEVICE

1003MmaSimtTileIterator() { }

1004

1006CUTLASS_HOST_DEVICE

1007MmaSimtTileIterator(

1008TensorRef ref,

1009int lane_id

1010 ) {

1011

1012// compute offset based on thread ID and lane layout

1013typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

1014

1015MatrixCoord lane_offset = lane_layout.inverse(lane_id) *

1016MatrixCoord(Policy::LaneMmaShape::kM, 0);

1017

1018 ref.add_coord_offset(lane_offset);

1019

1020 k_group_idx_ = 0;

1021 ref_.reset(reinterpret_cast<Array<Element, Policy::LaneMmaShape::kMK> *>(ref.data()), ref.stride(0)/Policy::LaneMmaShape::kMK);

1022 }

1023

1024

1026CUTLASS_HOST_DEVICE

1027MmaSimtTileIterator &add_pointer_offset(LongIndex offset) {

1028 ref_.add_pointer_offset(offset);

1029return *this;

1030 }

1031

1033CUTLASS_HOST_DEVICE

1034MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) {

1035

1036 ref_.add_coord_offset({

1037 coord.row() * Shape::kRow / Policy::LaneMmaShape::kMK,

1038 coord.column() * Shape::kColumn});

1039

1040return *this;

1041 }

1042

1044CUTLASS_HOST_DEVICE

1045MmaSimtTileIterator & operator++() {

1046

1047 add_tile_offset({0, 1});

1048

1049if (kPartitionsK > 1) {

1050 ++k_group_idx_;

1051// Jump to next stage

1052if (k_group_idx_ == kGroupPerTile) {

1053 k_group_idx_ = 0;

1054 add_tile_offset({0, kGroupPerTile * (kPartitionsK-1)});

1055 }

1056 }

1057

1058return *this;

1059 }

1060

1062CUTLASS_HOST_DEVICE

1063MmaSimtTileIterator & operator--() {

1064

1065 ref_.add_coord_offset({0, -Shape::kColumn});

1066

1067return *this;

1068 }

1069

1071CUTLASS_HOST_DEVICE

1072void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {

1073

1074 Array<Element, Policy::LaneMmaShape::kMK > *dst_ptr =

1075reinterpret_cast<Array<Element, Policy::LaneMmaShape::kMK> *>(&frag);

1076

1077CUTLASS_PRAGMA_UNROLL

1078for (int k = 0; k < Iterations::kColumn; ++k) {

1079

1080CUTLASS_PRAGMA_UNROLL

1081for (int m = 0; m < Iterations::kRow; ++m) {

1082

1083 dst_ptr[m + k * Iterations::kRow] =

1084 *((ref_.data() + ref_.offset({m * Policy::WarpShape::kRow / kInterleave,

1085 k*Policy::LaneMmaShape::kK}) + pointer_offset / Policy::LaneMmaShape::kM));

1086 }

1087 }

1088 }

1089

1091CUTLASS_HOST_DEVICE

1092void load(Fragment &frag) const {

1093 load_with_pointer_offset(frag, 0);

1094 }

1095

1097CUTLASS_HOST_DEVICE

1098void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

1099

1100 Array<Element, Policy::LaneMmaShape::kMK> const *src_ptr =

1101reinterpret_cast<Array<Element, Policy::LaneMmaShape::kMK > *>(&frag);

1102

1103CUTLASS_PRAGMA_UNROLL

1104for (int k = 0; k < Iterations::kN; ++k) {

1105CUTLASS_PRAGMA_UNROLL

1106for (int m = 0; m < Iterations::kM; ++m) {

1107 *(ref_.data() + ref_.offset(m * Policy::WarpShape::kM, k) + pointer_offset / Policy::LaneMmaShape::kM) =

1108 src_ptr[m + k * Iterations::kM];

1109 }

1110 }

1111 }

1112

1114CUTLASS_HOST_DEVICE

1115void store(Fragment const &frag) const {

1116 store_with_pointer_offset(frag, 0);

1117 }

1118

1126 CUTLASS_DEVICE

1127void set_kgroup_index(int k_group) {

1128// no operation here

1129 }

1130 };

1131

1133

1138 template <

1140typename Shape_,

1142typename Element_,

1144typename Policy_,

1146int PartitionsK,

1148int PartitionGroupSize

1149 >

1150 class MmaSimtTileIterator<Shape_, Operand::kB, Element_, layout::RowMajorInterleaved<4>, Policy_, PartitionsK, PartitionGroupSize> {

1151 public:

1152

1154using Shape = Shape_;

1155

1157static Operand const kOperand = Operand::kB;

1158

1160using Element = Element_;

1161

1163using Layout = layout::RowMajorInterleaved<4>;

1164

1166using Policy = Policy_;

1167

1169using TensorRef = TensorRef<Element, Layout>;

1170

1172using Index = typename TensorRef::Index;

1173

1175using LongIndex = typename TensorRef::LongIndex;

1176

1178using TensorCoord = typename TensorRef::TensorCoord;

1179

1181static const int kInterleave = 4;

1182

1184static const int kPartitionsK = PartitionsK;

1185

1187static const int kGroupPerTile = PartitionGroupSize / Shape::kRow;

1188

1189//

1190// Derived quantities

1191//

1192

1193static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn),

1194"The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension.");

1195

1196static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");

1197static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");

1198static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");

1199static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");

1200

1202using ThreadShape = MatrixShape<

1203 Shape::kRow,

1204 Shape::kColumn / Policy::WarpShape::kColumn

1205 >;

1206

1207static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN) && !(ThreadShape::kRow % Policy::LaneMmaShape::kK),

1208"Thread-level GEMM must be divisible by Policy::LaneMmaShape.");

1209

1211using Iterations = MatrixShape<

1212 ThreadShape::kRow / Policy::LaneMmaShape::kK,

1213 ThreadShape::kColumn / Policy::LaneMmaShape::kN

1214 >;

1215

1217using Fragment = Array<Element, ThreadShape::kCount>;

1218

1219

1220 private:

1221

1223cutlass::TensorRef<Array<Element, Policy::LaneMmaShape::kKN>, layout::RowMajorInterleaved<4>> ref_;

1224

1226int k_group_idx_;

1227

1228 public:

1229

1231CUTLASS_HOST_DEVICE

1232MmaSimtTileIterator() { }

1233

1235CUTLASS_HOST_DEVICE

1236MmaSimtTileIterator(

1237TensorRef ref,

1238int lane_id

1239 ) {

1240

1241// compute offset based on thread ID and lane layout

1242typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

1243

1244MatrixCoord lane_offset = lane_layout.inverse(lane_id) *

1245MatrixCoord(0, Policy::LaneMmaShape::kN);

1246

1247 ref.add_coord_offset(lane_offset);

1248

1249 k_group_idx_ = 0;

1250

1251 ref_.reset(

1252reinterpret_cast<Array<Element, Policy::LaneMmaShape::kKN> *>(ref.data()),

1253 ref.stride(0) / Policy::LaneMmaShape::kKN);

1254 }

1255

1257CUTLASS_HOST_DEVICE

1258MmaSimtTileIterator &add_pointer_offset(LongIndex offset) {

1259 ref_.add_pointer_offset(offset);

1260return *this;

1261 }

1262

1264CUTLASS_HOST_DEVICE

1265MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) {

1266

1267 ref_.add_coord_offset({

1268 coord.row() * Shape::kRow,

1269 coord.column() * Shape::kColumn / Policy::LaneMmaShape::kKN});

1270

1271return *this;

1272 }

1273

1275CUTLASS_HOST_DEVICE

1276MmaSimtTileIterator & operator++() {

1277

1278 add_tile_offset({1, 0});

1279

1280if (kPartitionsK > 1) {

1281 ++k_group_idx_;

1282// Jump to next stage

1283if (k_group_idx_ == kGroupPerTile) {

1284 k_group_idx_ = 0;

1285 add_tile_offset({kGroupPerTile * (kPartitionsK-1), 0});

1286 }

1287 }

1288

1289return *this;

1290 }

1291

1293CUTLASS_HOST_DEVICE

1294MmaSimtTileIterator & operator--() {

1295

1296 ref_.add_coord_offset({-Shape::kRow, 0});

1297

1298return *this;

1299 }

1300

1302CUTLASS_HOST_DEVICE

1303void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {

1304

1305 Array<Element, Policy::LaneMmaShape::kKN> *dst_ptr =

1306reinterpret_cast<Array<Element, Policy::LaneMmaShape::kKN> *>(&frag);

1307

1308CUTLASS_PRAGMA_UNROLL

1309for (int k = 0; k < Iterations::kRow; ++k) {

1310CUTLASS_PRAGMA_UNROLL

1311for (int n = 0; n < Iterations::kColumn; ++n) {

1312 dst_ptr[n + k * Iterations::kColumn] =

1313 *(ref_.data() + ref_.offset({k * Policy::LaneMmaShape::kK,

1314 n * Policy::WarpShape::kColumn / kInterleave}) + pointer_offset / Policy::LaneMmaShape::kN);

1315 }

1316 }

1317 }

1318

1320CUTLASS_HOST_DEVICE

1321void load(Fragment &frag) const {

1322 load_with_pointer_offset(frag, 0);

1323 }

1324

1326CUTLASS_HOST_DEVICE

1327void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

1328

1329 Array<Element, Policy::LaneMmaShape::kN> const *src_ptr =

1330reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag);

1331

1332CUTLASS_PRAGMA_UNROLL

1333for (int k = 0; k < Iterations::kM; ++k) {

1334CUTLASS_PRAGMA_UNROLL

1335for (int n = 0; n < Iterations::kN; ++n) {

1336 *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kN}) + pointer_offset / Policy::LaneMmaShape::kN) =

1337 src_ptr[n + k * Iterations::kN];

1338 }

1339 }

1340 }

1341

1343CUTLASS_HOST_DEVICE

1344void store(Fragment const &frag, Index pointer_offset) const {

1345 store_with_pointer_offset(frag, 0);

1346 }

1347

1355 CUTLASS_DEVICE

1356void set_kgroup_index(int k_group) {

1357// no operation here

1358 }

1359 };

1360

1362

1363 } // namespace warp

1364 } // namespace gemm

1365 } // namespace cutlass

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::store

CUTLASS_HOST_DEVICE void store(Fragment const &frag) const

Stores a fragment to memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:257

[mma_simt_policy.h](mma simt policy_8h.html)

Describes the lane policy used by warp-level matrix multiply operators targeting SIMT instructions...

cutlass::MatrixShape

Describes the size of a matrix tile.

Definition: matrix_shape.h:42

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::Fragment

Array< Element, ThreadShape::kCount > Fragment

Fragment object holding a thread's part of a tile.

Definition: mma_simt_tile_iterator.h:991

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::operator++

CUTLASS_HOST_DEVICE MmaSimtTileIterator & operator++()

Advances the iterator along the advance dimension.

Definition: mma_simt_tile_iterator.h:404

cutlass

Definition: aligned_buffer.h:35

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::store

CUTLASS_HOST_DEVICE void store(Fragment const &frag) const

Stores a fragment to memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:686

tensor_ref.h

Defines a structure containing strides, bounds, and a pointer to tensor data.

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::MmaSimtTileIterator

CUTLASS_HOST_DEVICE MmaSimtTileIterator(TensorRef ref, int lane_id)

Constructor from TensorRef.

Definition: mma_simt_tile_iterator.h:1007

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::Policy

Policy_ Policy

Decomposition of elements among threads.

Definition: mma_simt_tile_iterator.h:308

cutlass::TensorRef::data

CUTLASS_HOST_DEVICE Element * data() const

Returns the pointer to referenced data.

Definition: tensor_ref.h:254

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::MmaSimtTileIterator

CUTLASS_HOST_DEVICE MmaSimtTileIterator(TensorRef const &ref, int lane_id)

Constructor from TensorRef.

Definition: mma_simt_tile_iterator.h:786

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::MmaSimtTileIterator

CUTLASS_HOST_DEVICE MmaSimtTileIterator(TensorRef const &ref, int lane_id)

Constructor from TensorRef.

Definition: mma_simt_tile_iterator.h:574

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::set_kgroup_index

CUTLASS_DEVICE void set_kgroup_index(int k_group)

Definition: mma_simt_tile_iterator.h:1356

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::MmaSimtTileIterator

CUTLASS_HOST_DEVICE MmaSimtTileIterator()

Definition: mma_simt_tile_iterator.h:1003

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::LongIndex

typename TensorRef::LongIndex LongIndex

Long Index type.

Definition: mma_simt_tile_iterator.h:730

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::set_kgroup_index

CUTLASS_DEVICE void set_kgroup_index(int k_group)

Definition: mma_simt_tile_iterator.h:1127

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::TensorCoord

typename TensorRef::TensorCoord TensorCoord

Coordinate for an element in the tensor.

Definition: mma_simt_tile_iterator.h:521

cutlass::gemm::Operand

Operand

GEMM operand enumeration: D = A * B + C.

Definition: include/cutlass/gemm/gemm.h:39

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::Index

typename TensorRef::Index Index

Index type.

Definition: mma_simt_tile_iterator.h:946

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::Policy

Policy_ Policy

Decomposition of elements among threads.

Definition: mma_simt_tile_iterator.h:105

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::LongIndex

typename TensorRef::LongIndex LongIndex

Long Index type.

Definition: mma_simt_tile_iterator.h:317

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::Fragment

Array< Element, ThreadShape::kCount > Fragment

Fragment object holding a thread's part of a tile.

Definition: mma_simt_tile_iterator.h:1217

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::load_with_pointer_offset

CUTLASS_HOST_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const

Loads a fragment from memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:1072

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::Fragment

Array< Element, ThreadShape::kCount > Fragment

Fragment object holding a thread's part of a tile.

Definition: mma_simt_tile_iterator.h:772

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::operator++

CUTLASS_HOST_DEVICE MmaSimtTileIterator & operator++()

Advances the iterator along the advance dimension.

Definition: mma_simt_tile_iterator.h:201

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::MmaSimtTileIterator

CUTLASS_HOST_DEVICE MmaSimtTileIterator(TensorRef ref, int lane_id)

Constructor from TensorRef.

Definition: mma_simt_tile_iterator.h:1236

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::store_with_pointer_offset

CUTLASS_HOST_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const

Stores a fragment to memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:445

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::MmaSimtTileIterator

CUTLASS_HOST_DEVICE MmaSimtTileIterator()

Default ctor constructs null iterator.

Definition: mma_simt_tile_iterator.h:362

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::load

CUTLASS_HOST_DEVICE void load(Fragment &frag) const

Loads a fragment from memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:655

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::add_pointer_offset

CUTLASS_HOST_DEVICE MmaSimtTileIterator & add_pointer_offset(LongIndex offset)

Adds a pointer offset to internal pointer(s) to advance through memory.

Definition: mma_simt_tile_iterator.h:386

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::Element

Element_ Element

Element type.

Definition: mma_simt_tile_iterator.h:302

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::load

CUTLASS_HOST_DEVICE void load(Fragment &frag) const

Loads a fragment from memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:234

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::add_tile_offset

CUTLASS_HOST_DEVICE MmaSimtTileIterator & add_tile_offset(TensorCoord const &coord)

Advances an iterator along logical dimensions of matrix in units of whole tiles.

Definition: mma_simt_tile_iterator.h:810

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::MmaSimtTileIterator

CUTLASS_HOST_DEVICE MmaSimtTileIterator()

Default ctor constructs null iterator.

Definition: mma_simt_tile_iterator.h:782

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::MmaSimtTileIterator

CUTLASS_HOST_DEVICE MmaSimtTileIterator(TensorRef ref, int lane_id)

Constructor from TensorRef.

Definition: mma_simt_tile_iterator.h:162

cutlass::TensorRef::add_coord_offset

CUTLASS_HOST_DEVICE TensorRef & add_coord_offset(TensorCoord const &coord)

Adds an offset to each pointer.

Definition: tensor_ref.h:326

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::Shape

Shape_ Shape

Shape of tile to load (concept: MatrixShape)

Definition: mma_simt_tile_iterator.h:497

cutlass::layout::ColumnMajor

Mapping function for column-major matrices.

Definition: layout/matrix.h:142

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::load_with_pointer_offset

CUTLASS_HOST_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const

Loads a fragment from memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:1303

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

cutlass::gemm::Operand::kC

Source accumulator.

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::store_with_pointer_offset

CUTLASS_HOST_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const

Stores a fragment to memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:1098

CUTLASS_PRAGMA_UNROLL

#define CUTLASS_PRAGMA_UNROLL

Definition: cutlass.h:110

cutlass::gemm::warp::MmaSimtTileIterator

Definition: mma_simt_tile_iterator.h:69

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::load

CUTLASS_HOST_DEVICE void load(Fragment &frag) const

Loads a fragment from memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:439

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::operator--

CUTLASS_HOST_DEVICE MmaSimtTileIterator & operator--()

Advances the iterator along the advance dimension.

Definition: mma_simt_tile_iterator.h:618

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::operator--

CUTLASS_HOST_DEVICE MmaSimtTileIterator & operator--()

Advances the iterator along the advance dimension.

Definition: mma_simt_tile_iterator.h:210

cutlass::gemm::Operand::kA

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::TensorRef

TensorRef< Element, Layout > TensorRef

TensorRef type for loading element from a tensor.

Definition: mma_simt_tile_iterator.h:724

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::Element

Element_ Element

Element type.

Definition: mma_simt_tile_iterator.h:1160

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::MmaSimtTileIterator

CUTLASS_HOST_DEVICE MmaSimtTileIterator()

Default ctor constructs null iterator.

Definition: mma_simt_tile_iterator.h:1232

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::add_pointer_offset

CUTLASS_HOST_DEVICE MmaSimtTileIterator & add_pointer_offset(LongIndex offset)

Adds a pointer offset to internal pointer(s) to advance through memory.

Definition: mma_simt_tile_iterator.h:591

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::Policy

Policy_ Policy

Decomposition of elements among threads.

Definition: mma_simt_tile_iterator.h:940

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::load_with_pointer_offset

CUTLASS_HOST_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const

Loads a fragment from memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:219

cutlass::TensorRef::stride

CUTLASS_HOST_DEVICE Stride stride() const

Returns the layout object's stride vector.

Definition: tensor_ref.h:277

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::LongIndex

typename TensorRef::LongIndex LongIndex

Long Index type.

Definition: mma_simt_tile_iterator.h:114

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::operator++

CUTLASS_HOST_DEVICE MmaSimtTileIterator & operator++()

Advances the iterator along the advance dimension.

Definition: mma_simt_tile_iterator.h:1276

cutlass::TensorRef::TensorCoord

typename Layout::TensorCoord TensorCoord

Coordinate in logical tensor space.

Definition: tensor_ref.h:171

matrix_shape.h

Defines a Shape template for matrix tiles.

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::add_tile_offset

CUTLASS_HOST_DEVICE MmaSimtTileIterator & add_tile_offset(TensorCoord const &coord)

Advances an iterator along logical dimensions of matrix in units of whole tiles.

Definition: mma_simt_tile_iterator.h:393

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::Fragment

Array< Element, ThreadShape::kCount > Fragment

Fragment object holding a thread's part of a tile.

Definition: mma_simt_tile_iterator.h:560

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::set_kgroup_index

CUTLASS_DEVICE void set_kgroup_index(int k_group)

Definition: mma_simt_tile_iterator.h:474

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::load

CUTLASS_HOST_DEVICE void load(Fragment &frag) const

Loads a fragment from memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:1092

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::add_pointer_offset

CUTLASS_HOST_DEVICE MmaSimtTileIterator & add_pointer_offset(LongIndex offset)

Adds a pointer offset to internal pointer(s) to advance through memory.

Definition: mma_simt_tile_iterator.h:183

cutlass::TensorRef::reset

CUTLASS_HOST_DEVICE void reset(Element *ptr=nullptr)

Updates only the pointer.

Definition: tensor_ref.h:235

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::Fragment

Array< Element, ThreadShape::kCount > Fragment

Fragment object holding a thread's part of a tile.

Definition: mma_simt_tile_iterator.h:350

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::Index

typename TensorRef::Index Index

Index type.

Definition: mma_simt_tile_iterator.h:515

cutlass::TensorRef< Element, Layout >

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::add_pointer_offset

CUTLASS_HOST_DEVICE MmaSimtTileIterator & add_pointer_offset(LongIndex offset)

Adds a pointer offset to internal pointer(s) to advance through memory.

Definition: mma_simt_tile_iterator.h:1027

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::TensorCoord

typename TensorRef::TensorCoord TensorCoord

Coordinate for an element in the tensor.

Definition: mma_simt_tile_iterator.h:952

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::Element

Element_ Element

Element type.

Definition: mma_simt_tile_iterator.h:934

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::LongIndex

typename TensorRef::LongIndex LongIndex

Long Index type.

Definition: mma_simt_tile_iterator.h:518

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::TensorCoord

typename TensorRef::TensorCoord TensorCoord

Coordinate for an element in the tensor.

Definition: mma_simt_tile_iterator.h:320

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

cutlass::TensorRef::offset

CUTLASS_HOST_DEVICE LongIndex offset(TensorCoord const &coord) const

Computes the offset of an index from the origin of the tensor.

Definition: tensor_ref.h:301

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::set_kgroup_index

CUTLASS_DEVICE void set_kgroup_index(int k_group)

Definition: mma_simt_tile_iterator.h:269

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::MmaSimtTileIterator

CUTLASS_HOST_DEVICE MmaSimtTileIterator(TensorRef ref, int lane_id)

Constructor from TensorRef.

Definition: mma_simt_tile_iterator.h:366

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::store

CUTLASS_HOST_DEVICE void store(Fragment const &frag, Index pointer_offset) const

Stores a fragment to memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:462

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::operator--

CUTLASS_HOST_DEVICE MmaSimtTileIterator & operator--()

Advances the iterator along the advance dimension.

Definition: mma_simt_tile_iterator.h:1063

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::store_with_pointer_offset

CUTLASS_HOST_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const

Stores a fragment to memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:240

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::Shape

Shape_ Shape

Shape of tile to load (concept: MatrixShape)

Definition: mma_simt_tile_iterator.h:928

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::store_with_pointer_offset

CUTLASS_HOST_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const

Stores a fragment to memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:661

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::Element

Element_ Element

Element type.

Definition: mma_simt_tile_iterator.h:99

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::Policy

Policy_ Policy

Decomposition of elements among threads.

Definition: mma_simt_tile_iterator.h:509

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::load_with_pointer_offset

CUTLASS_HOST_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const

Loads a fragment from memory with additional logical offset.

Definition: mma_simt_tile_iterator.h:627

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::add_tile_offset

CUTLASS_HOST_DEVICE MmaSimtTileIterator & add_tile_offset(TensorCoord const &coord)

Advances an iterator along logical dimensions of matrix in units of whole tiles.

Definition: mma_simt_tile_iterator.h:1034

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::operator++

CUTLASS_HOST_DEVICE MmaSimtTileIterator & operator++()

Advances the iterator along the advance dimension.

Definition: mma_simt_tile_iterator.h:1045

cutlass::TensorRef::Index

typename Layout::Index Index

Index type.

Definition: tensor_ref.h:165

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::store_with_pointer_offset

CUTLASS_HOST_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const

Stores a fragment to memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:873

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::add_pointer_offset

CUTLASS_HOST_DEVICE MmaSimtTileIterator & add_pointer_offset(LongIndex offset)

Adds a pointer offset to internal pointer(s) to advance through memory.

Definition: mma_simt_tile_iterator.h:1258

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::load_with_pointer_offset

CUTLASS_HOST_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const

Loads a fragment from memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:422

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::Policy

Policy_ Policy

Decomposition of elements among threads.

Definition: mma_simt_tile_iterator.h:1166

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::load

CUTLASS_HOST_DEVICE void load(Fragment &frag) const

Loads a fragment from memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:867

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::Index

typename TensorRef::Index Index

Index type.

Definition: mma_simt_tile_iterator.h:1172

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::Index

typename TensorRef::Index Index

Index type.

Definition: mma_simt_tile_iterator.h:111

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::TensorCoord

typename TensorRef::TensorCoord TensorCoord

Coordinate for an element in the tensor.

Definition: mma_simt_tile_iterator.h:733

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::LongIndex

typename TensorRef::LongIndex LongIndex

Long Index type.

Definition: mma_simt_tile_iterator.h:1175

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::store

CUTLASS_HOST_DEVICE void store(Fragment const &frag) const

Stores a fragment to memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:899

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::load_with_pointer_offset

CUTLASS_HOST_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const

Loads a fragment from memory with additional logical offset.

Definition: mma_simt_tile_iterator.h:839

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::Index

typename TensorRef::Index Index

Index type.

Definition: mma_simt_tile_iterator.h:314

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::operator--

CUTLASS_HOST_DEVICE MmaSimtTileIterator & operator--()

Advances the iterator along the advance dimension.

Definition: mma_simt_tile_iterator.h:413

matrix.h

Defines layout functions used by TensorRef and derived classes.

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize >::Shape

Shape_ Shape

Shape of tile to load (concept: MatrixShape)

Definition: mma_simt_tile_iterator.h:296

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::operator++

CUTLASS_HOST_DEVICE MmaSimtTileIterator & operator++()

Advances the iterator along the advance dimension.

Definition: mma_simt_tile_iterator.h:821

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::MmaSimtTileIterator

CUTLASS_HOST_DEVICE MmaSimtTileIterator()

Default ctor constructs null iterator.

Definition: mma_simt_tile_iterator.h:570

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::Fragment

Array< Element, ThreadShape::kCount > Fragment

Fragment object holding a thread's part of a tile.

Definition: mma_simt_tile_iterator.h:147

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::add_tile_offset

CUTLASS_HOST_DEVICE MmaSimtTileIterator & add_tile_offset(TensorCoord const &coord)

Advances an iterator along logical dimensions of matrix in units of whole tiles.

Definition: mma_simt_tile_iterator.h:190

cutlass::layout::ColumnMajorInterleaved

Definition: layout/matrix.h:343

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::load

CUTLASS_HOST_DEVICE void load(Fragment &frag) const

Loads a fragment from memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:1321

cutlass::TensorRef::add_pointer_offset

CUTLASS_HOST_DEVICE TensorRef & add_pointer_offset(LongIndex offset_)

Adds an offset to each pointer.

Definition: tensor_ref.h:319

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::operator--

CUTLASS_HOST_DEVICE MmaSimtTileIterator & operator--()

Advances the iterator along the advance dimension.

Definition: mma_simt_tile_iterator.h:1294

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::add_pointer_offset

CUTLASS_HOST_DEVICE MmaSimtTileIterator & add_pointer_offset(LongIndex offset)

Adds a pointer offset to internal pointer(s) to advance through memory.

Definition: mma_simt_tile_iterator.h:803

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::MmaSimtTileIterator

CUTLASS_HOST_DEVICE MmaSimtTileIterator()

Default ctor constructs null iterator.

Definition: mma_simt_tile_iterator.h:158

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::operator--

CUTLASS_HOST_DEVICE MmaSimtTileIterator & operator--()

Advances the iterator along the advance dimension.

Definition: mma_simt_tile_iterator.h:830

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::add_tile_offset

CUTLASS_HOST_DEVICE MmaSimtTileIterator & add_tile_offset(TensorCoord const &coord)

Advances an iterator along logical dimensions of matrix in units of whole tiles.

Definition: mma_simt_tile_iterator.h:598

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::TensorCoord

typename TensorRef::TensorCoord TensorCoord

Coordinate for an element in the tensor.

Definition: mma_simt_tile_iterator.h:1178

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::store

CUTLASS_HOST_DEVICE void store(Fragment const &frag) const

Stores a fragment to memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:1115

cutlass::gemm::Operand::kB

A multiplicand.

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::Index

typename TensorRef::Index Index

Index type.

Definition: mma_simt_tile_iterator.h:727

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::Shape

Shape_ Shape

Shape of tile to load (concept: MatrixShape)

Definition: mma_simt_tile_iterator.h:1154

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::TensorCoord

typename TensorRef::TensorCoord TensorCoord

Coordinate for an element in the tensor.

Definition: mma_simt_tile_iterator.h:117

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::operator++

CUTLASS_HOST_DEVICE MmaSimtTileIterator & operator++()

Advances the iterator along the advance dimension.

Definition: mma_simt_tile_iterator.h:609

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize >::Shape

Shape_ Shape

Shape of tile to load (concept: MatrixShape)

Definition: mma_simt_tile_iterator.h:93

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::add_tile_offset

CUTLASS_HOST_DEVICE MmaSimtTileIterator & add_tile_offset(TensorCoord const &coord)

Advances an iterator along logical dimensions of matrix in units of whole tiles.

Definition: mma_simt_tile_iterator.h:1265

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::store

CUTLASS_HOST_DEVICE void store(Fragment const &frag, Index pointer_offset) const

Stores a fragment to memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:1344

cutlass.h

Basic include for CUTLASS.

cutlass::MatrixCoord

Definition: matrix_coord.h:39

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kB, Element_, layout::RowMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::store_with_pointer_offset

CUTLASS_HOST_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const

Stores a fragment to memory at the location pointed to by the iterator.

Definition: mma_simt_tile_iterator.h:1327

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved< 4 >, Policy_, PartitionsK, PartitionGroupSize >::LongIndex

typename TensorRef::LongIndex LongIndex

Long Index type.

Definition: mma_simt_tile_iterator.h:949

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::Element

Element_ Element

Element type.

Definition: mma_simt_tile_iterator.h:715

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_ >::Element

Element_ Element

Element type.

Definition: mma_simt_tile_iterator.h:503

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::Policy

Policy_ Policy

Decomposition of elements among threads.

Definition: mma_simt_tile_iterator.h:721

cutlass::TensorRef::LongIndex

typename Layout::LongIndex LongIndex

Long index used for pointer offsets.

Definition: tensor_ref.h:168

cutlass::gemm::warp::MmaSimtTileIterator< Shape_, Operand::kC, Element_, layout::RowMajor, Policy_ >::Shape

Shape_ Shape

Shape of tile to load (concept: MatrixShape)

Definition: mma_simt_tile_iterator.h:709

cutlass::layout::RowMajorInterleaved

Definition: layout/matrix.h:237


Generated by 1.8.11