mma_simt_tile_iterator.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Describes iterators used by warp-level matrix multiply operations targeting SIMT
      instructions.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/layout/matrix.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma_simt_policy.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Iterates over operands to warp-level matrix multiply operations targeting SIMT instructions.
template <
  /// Size of the matrix to load (concept: MatrixShape)
  typename Shape_,
  /// Data type of elements
  typename Element_,
  /// Layout of the operand in memory
  typename Layout_,
  /// Decomposition of elements among threads (concept: MmaSimtPolicy)
  typename Policy_,
  /// Number of partitions along the K dimension - used in sliced-K
  int PartitionsK = 1,
  /// Group size along the K partition - used in sliced-K
  int PartitionGroupSize = 1
>
class MmaSimtTileIterator;
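
// Illustrative instantiation (an editorial sketch; the policy and shape values below are
// assumptions chosen for exposition, not definitions taken from this header):
//
//   // 32 lanes arranged 4 (M) x 8 (N), each lane computing a 4x4x1 MMA per step
//   using Policy = cutlass::gemm::warp::MmaSimtPolicy<
//     cutlass::MatrixShape<4, 8>,
//     cutlass::layout::RowMajorInterleaved<1>,
//     cutlass::gemm::GemmShape<4, 4, 1>>;
//
//   // Iterator over a 64 (M) x 4 (K) column-major tile of operand A
//   using IteratorA = cutlass::gemm::warp::MmaSimtTileIterator<
//     cutlass::MatrixShape<64, 4>, cutlass::gemm::Operand::kA,
//     float, cutlass::layout::ColumnMajor, Policy>;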

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Specialization for A operands of column-major layouts
template <
  /// Size of the matrix to load (concept: MatrixShape)
  typename Shape_,
  /// Data type of elements
  typename Element_,
  /// Decomposition of elements among threads (concept: MmaSimtPolicy)
  typename Policy_,
  /// Number of partitions along the K dimension
  int PartitionsK,
  /// Group size along the K partition
  int PartitionGroupSize
>
class MmaSimtTileIterator<Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize> {
public:

  /// Shape of tile to load (concept: MatrixShape)
  using Shape = Shape_;

  /// Operand tag
  static Operand const kOperand = Operand::kA;

  /// Element type
  using Element = Element_;

  /// Layout of the source tile
  using Layout = layout::ColumnMajor;

  /// Decomposition of elements among threads
  using Policy = Policy_;

  /// TensorRef type for loading elements from a tensor
  using TensorRef = TensorRef<Element, Layout>;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  //
  // Derived quantities
  //

  static_assert(!(Shape::kRow % Policy::WarpShape::kRow),
    "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension.");

  static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
  static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
  static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
  static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");

  /// Thread-level shape of a fragment
  using ThreadShape = MatrixShape<
    Shape::kRow / Policy::WarpShape::kRow,
    Shape::kColumn
  >;

  static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM),
    "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");

  /// Number of individual loads
  using Iterations = MatrixShape<
    ThreadShape::kRow / Policy::LaneMmaShape::kM,
    ThreadShape::kColumn
  >;

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, ThreadShape::kCount>;
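  // Worked example of the derived quantities above (editorial; values are the illustrative
  // assumptions from the earlier sketch: Shape = MatrixShape<64, 4>,
  // Policy::WarpShape = MatrixShape<4, 8>, Policy::LaneMmaShape::kM = 4):
  //
  //   ThreadShape = MatrixShape<64 / 4, 4> = MatrixShape<16, 4>   // per-lane slice
  //   Iterations  = MatrixShape<16 / 4, 4> = MatrixShape<4, 4>    // vector loads per lane
  //   Fragment    = Array<Element, 16 * 4>                        // 64 elements per lane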

private:

  /// Internal reference to the tensor, re-expressed in vectors of LaneMmaShape::kM elements
  cutlass::TensorRef<Array<Element, Policy::LaneMmaShape::kM>, layout::ColumnMajor> ref_;

public:

  /// Default ctor constructs null iterator
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator() { }

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator(
    TensorRef ref,
    int lane_id
  ) {

    // compute offset based on thread ID and lane layout
    typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

    MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
      MatrixCoord(Policy::LaneMmaShape::kM, 0);

    ref.add_coord_offset(lane_offset);

    ref_.reset(
      reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(ref.data()),
      ref.stride(0) / Policy::LaneMmaShape::kM);
  }
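
  // How the constructor places a lane (editorial note): lane_layout.inverse() maps the linear
  // lane index to a (row, column) coordinate within Policy::WarpShape; only the row term is
  // scaled because operand A is partitioned along M. For a hypothetical lane whose inverse
  // coordinate is (2, 1) and LaneMmaShape::kM = 4, lane_offset = (2 * 4, 1 * 0) = (8, 0).
  // The reset() call then re-expresses the pointer and stride in units of Array<Element, kM>,
  // so each subsequent dereference moves kM elements at once.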

  /// Adds a pointer offset to internal pointer(s) to advance through memory
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator &add_pointer_offset(LongIndex offset) {
    ref_.add_pointer_offset(offset);
    return *this;
  }

  /// Advances an iterator along logical dimensions of the matrix in units of whole tiles
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) {

    ref_.add_coord_offset({
      coord.row() * Shape::kRow / Policy::LaneMmaShape::kM,
      coord.column() * Shape::kColumn});

    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator & operator++() {

    ref_.add_coord_offset({0, Shape::kColumn});

    return *this;
  }

  /// Retreats the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator & operator--() {

    ref_.add_coord_offset({0, -Shape::kColumn});

    return *this;
  }

  /// Loads a fragment from memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {

    Array<Element, Policy::LaneMmaShape::kM> *dst_ptr =
      reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kColumn; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < Iterations::kRow; ++m) {
        dst_ptr[m + k * Iterations::kRow] =
          *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM);
      }
    }
  }

  /// Loads a fragment from memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_pointer_offset(frag, 0);
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

    Array<Element, Policy::LaneMmaShape::kM> const *src_ptr =
      reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> const *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kColumn; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < Iterations::kRow; ++m) {
        *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM) =
          src_ptr[m + k * Iterations::kRow];
      }
    }
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag) const {
    store_with_pointer_offset(frag, 0);
  }

  /// Notifies the iterator of the current k-group index; a no-op for this iterator
  CUTLASS_DEVICE
  void set_kgroup_index(int k_group) {
    // no operation here
  }
};
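
// Usage sketch (editorial; IteratorA is the illustrative instantiation shown earlier, and
// smem_ref_A, lane_id, and kKTiles are assumed names):
//
//   IteratorA iter_A(smem_ref_A, lane_id);
//   typename IteratorA::Fragment frag_A;
//
//   CUTLASS_PRAGMA_UNROLL
//   for (int k = 0; k < kKTiles; ++k) {
//     iter_A.load(frag_A);   // each lane gathers its ThreadShape::kCount elements
//     ++iter_A;              // advance one tile along the K dimension
//     // ... multiply-accumulate using frag_A ...
//   }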

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Specialization for B operands of row-major layouts
template <
  /// Size of the matrix to load (concept: MatrixShape)
  typename Shape_,
  /// Data type of elements
  typename Element_,
  /// Decomposition of elements among threads (concept: MmaSimtPolicy)
  typename Policy_,
  /// Number of partitions along the K dimension
  int PartitionsK,
  /// Group size along the K partition
  int PartitionGroupSize
>
class MmaSimtTileIterator<Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize> {
public:

  /// Shape of tile to load (concept: MatrixShape)
  using Shape = Shape_;

  /// Operand tag
  static Operand const kOperand = Operand::kB;

  /// Element type
  using Element = Element_;

  /// Layout of the source tile
  using Layout = layout::RowMajor;

  /// Decomposition of elements among threads
  using Policy = Policy_;

  /// TensorRef type for loading elements from a tensor
  using TensorRef = TensorRef<Element, Layout>;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  //
  // Derived quantities
  //

  static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn),
    "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension.");

  static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
  static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
  static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
  static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");

  /// Thread-level shape of a fragment
  using ThreadShape = MatrixShape<
    Shape::kRow,
    Shape::kColumn / Policy::WarpShape::kColumn
  >;

  static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN),
    "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");

  /// Number of individual loads
  using Iterations = MatrixShape<
    ThreadShape::kRow,
    ThreadShape::kColumn / Policy::LaneMmaShape::kN
  >;

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, ThreadShape::kCount>;
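  // Worked example (editorial; illustrative assumptions Shape = MatrixShape<4, 64>,
  // Policy::WarpShape = MatrixShape<4, 8>, Policy::LaneMmaShape::kN = 4):
  //
  //   ThreadShape = MatrixShape<4, 64 / 8> = MatrixShape<4, 8>   // per-lane slice
  //   Iterations  = MatrixShape<4, 8 / 4>  = MatrixShape<4, 2>   // vector loads per lane
  //   Fragment    = Array<Element, 4 * 8>                        // 32 elements per lane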

private:

  /// Internal reference to the tensor, re-expressed in vectors of LaneMmaShape::kN elements
  cutlass::TensorRef<Array<Element, Policy::LaneMmaShape::kN>, layout::RowMajor> ref_;

public:

  /// Default ctor constructs null iterator
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator() { }

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator(
    TensorRef ref,
    int lane_id
  ) {

    // compute offset based on thread ID and lane layout
    typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

    MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
      MatrixCoord(0, Policy::LaneMmaShape::kN);

    ref.add_coord_offset(lane_offset);

    ref_.reset(
      reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(ref.data()),
      ref.stride(0) / Policy::LaneMmaShape::kN);
  }

  /// Adds a pointer offset to internal pointer(s) to advance through memory
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator &add_pointer_offset(LongIndex offset) {
    ref_.add_pointer_offset(offset);
    return *this;
  }

  /// Advances an iterator along logical dimensions of the matrix in units of whole tiles
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) {

    ref_.add_coord_offset({
      coord.row() * Shape::kRow,
      coord.column() * Shape::kColumn / Policy::LaneMmaShape::kN});

    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator & operator++() {

    ref_.add_coord_offset({Shape::kRow, 0});

    return *this;
  }

  /// Retreats the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator & operator--() {

    ref_.add_coord_offset({-Shape::kRow, 0});

    return *this;
  }
  /// Loads a fragment from memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {

    Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =
      reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kRow; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {
        dst_ptr[n + k * Iterations::kColumn] =
          *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN);
      }
    }
  }

  /// Loads a fragment from memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_pointer_offset(frag, 0);
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

    Array<Element, Policy::LaneMmaShape::kN> const *src_ptr =
      reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kRow; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {
        *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN) =
          src_ptr[n + k * Iterations::kColumn];
      }
    }
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag, Index pointer_offset) const {
    store_with_pointer_offset(frag, pointer_offset);
  }

  /// Notifies the iterator of the current k-group index; a no-op for this iterator
  CUTLASS_DEVICE
  void set_kgroup_index(int k_group) {
    // no operation here
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Specialization for C operands of column-major layouts
template <
  /// Size of the matrix to load (concept: MatrixShape)
  typename Shape_,
  /// Data type of elements
  typename Element_,
  /// Decomposition of elements among threads (concept: MmaSimtPolicy)
  typename Policy_
>
class MmaSimtTileIterator<Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_> {
public:

  /// Shape of tile to load (concept: MatrixShape)
  using Shape = Shape_;

  /// Operand tag
  static Operand const kOperand = Operand::kC;

  /// Element type
  using Element = Element_;

  /// Layout of the source tile
  using Layout = layout::ColumnMajor;

  /// Decomposition of elements among threads
  using Policy = Policy_;

  /// TensorRef type for loading elements from a tensor
  using TensorRef = TensorRef<Element, Layout>;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  //
  // Derived quantities
  //

  static_assert(
    (!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)),
    "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");

  static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
  static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
  static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
  static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
  static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");
  static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");

  /// Thread-level shape of a fragment
  using ThreadShape = MatrixShape<
    Shape::kRow / Policy::WarpShape::kRow,
    Shape::kColumn / Policy::WarpShape::kColumn
  >;

  static_assert(
    (!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)),
    "Thread-level GEMM shape must be divisible by Policy::LaneMmaShape.");

  /// Number of iterations
  using Iterations = MatrixShape<
    ThreadShape::kRow / Policy::LaneMmaShape::kM,
    ThreadShape::kColumn / Policy::LaneMmaShape::kN
  >;

  /// Delta between accesses (units of elements, concept: MatrixShape)
  using Delta = MatrixShape<
    Policy::WarpShape::kRow * Policy::LaneMmaShape::kM,
    Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN
  >;

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, ThreadShape::kCount>;
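  // Worked example (editorial; illustrative assumptions Shape = MatrixShape<64, 64>,
  // Policy::WarpShape = MatrixShape<4, 8>, LaneMmaShape::kM = LaneMmaShape::kN = 4):
  //
  //   ThreadShape = MatrixShape<64 / 4, 64 / 8> = MatrixShape<16, 8>   // 128 accumulators
  //   Iterations  = MatrixShape<16 / 4, 8 / 4>  = MatrixShape<4, 2>
  //   Delta       = MatrixShape<4 * 4, 8 * 4>   = MatrixShape<16, 32>
  //
  // Delta is the element distance between a lane's successive accumulator sub-tiles: because
  // lanes are interleaved, this lane touches rows {r, r+16, r+32, r+48} and columns {c, c+32}
  // of the 64x64 warp tile.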

private:

  TensorRef ref_;

public:

  /// Default ctor constructs null iterator
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator() { }

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator(
    TensorRef const &ref,
    int lane_id
  ):
    ref_(ref) {

    // compute offset based on thread ID and lane layout
    typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

    MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
      MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);

    ref_.add_coord_offset(lane_offset);
  }

  /// Adds a pointer offset to internal pointer(s) to advance through memory
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator &add_pointer_offset(LongIndex offset) {
    ref_.add_pointer_offset(offset);
    return *this;
  }

  /// Advances an iterator along logical dimensions of the matrix in units of whole tiles
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) {

    ref_.add_coord_offset({
      coord.row() * Shape::kRow,
      coord.column() * Shape::kColumn});

    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator & operator++() {

    ref_.add_coord_offset({Shape::kRow, 0});

    return *this;
  }

  /// Retreats the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator & operator--() {

    ref_.add_coord_offset({-Shape::kRow, 0});

    return *this;
  }

  /// Loads a fragment from memory with additional logical offset
  CUTLASS_HOST_DEVICE
  void load_with_pointer_offset(
    Fragment &frag,
    Index pointer_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {

        Array<Element, Policy::LaneMmaShape::kM> const *src_ptr =
          reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> const *>(
            ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kColumn + n}));

        CUTLASS_PRAGMA_UNROLL
        for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {

          Array<Element, Policy::LaneMmaShape::kM> *dst_ptr =
            reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(&frag) +
            mma_m + Iterations::kRow * (n + mma_n * Policy::LaneMmaShape::kN);

          *dst_ptr = src_ptr[mma_m * Policy::WarpShape::kRow];
        }
      }
    }
  }

  /// Loads a fragment from memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_pointer_offset(frag, 0);
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {

        Array<Element, Policy::LaneMmaShape::kM> *dst_ptr =
          reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(
            ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kColumn + n}));

        CUTLASS_PRAGMA_UNROLL
        for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {

          Array<Element, Policy::LaneMmaShape::kM> const *src_ptr =
            reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> const *>(&frag) +
            mma_m + Iterations::kRow * (n + mma_n * Policy::LaneMmaShape::kN);

          dst_ptr[mma_m * Policy::WarpShape::kRow] = *src_ptr;
        }
      }
    }
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag) const {
    store_with_pointer_offset(frag, 0);
  }
};
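
// Usage sketch (editorial; the accumulator iterator is typically constructed once and used to
// restore or spill the warp's accumulator tile; IteratorC, ref_C, and lane_id are assumed names):
//
//   using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator<
//     cutlass::MatrixShape<64, 64>, cutlass::gemm::Operand::kC,
//     float, cutlass::layout::ColumnMajor, Policy>;
//
//   IteratorC iter_C(ref_C, lane_id);
//   typename IteratorC::Fragment accum;
//   iter_C.load(accum);    // read source accumulators C
//   // ... multiply-accumulate into accum ...
//   iter_C.store(accum);   // write the result tile back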

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Specialization for C operands of row-major layouts
template <
  /// Size of the matrix to load (concept: MatrixShape)
  typename Shape_,
  /// Data type of elements
  typename Element_,
  /// Decomposition of elements among threads (concept: MmaSimtPolicy)
  typename Policy_
>
class MmaSimtTileIterator<Shape_, Operand::kC, Element_, layout::RowMajor, Policy_> {
public:

  /// Shape of tile to load (concept: MatrixShape)
  using Shape = Shape_;

  /// Operand tag
  static Operand const kOperand = Operand::kC;

  /// Element type
  using Element = Element_;

  /// Layout of the source tile
  using Layout = layout::RowMajor;

  /// Decomposition of elements among threads
  using Policy = Policy_;

  /// TensorRef type for loading elements from a tensor
  using TensorRef = TensorRef<Element, Layout>;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  //
  // Derived quantities
  //

  static_assert(
    (!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)),
    "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");

  static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
  static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
  static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
  static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
  static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");
  static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");

  /// Thread-level shape of a fragment
  using ThreadShape = MatrixShape<
    Shape::kRow / Policy::WarpShape::kRow,
    Shape::kColumn / Policy::WarpShape::kColumn
  >;

  static_assert(
    (!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)),
    "Thread-level GEMM shape must be divisible by Policy::LaneMmaShape.");

  /// Number of iterations
  using Iterations = MatrixShape<
    ThreadShape::kRow / Policy::LaneMmaShape::kM,
    ThreadShape::kColumn / Policy::LaneMmaShape::kN
  >;

  /// Delta between accesses (units of elements, concept: MatrixShape)
  using Delta = MatrixShape<
    Policy::WarpShape::kRow * Policy::LaneMmaShape::kM,
    Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN
  >;

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, ThreadShape::kCount>;
private:

  TensorRef ref_;

public:

  /// Default ctor constructs null iterator
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator() { }

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator(
    TensorRef const &ref,
    int lane_id
  ):
    ref_(ref) {

    // compute offset based on thread ID and lane layout
    typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

    MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
      MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);

    ref_.add_coord_offset(lane_offset);
  }

  /// Adds a pointer offset to internal pointer(s) to advance through memory
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator &add_pointer_offset(LongIndex offset) {
    ref_.add_pointer_offset(offset);
    return *this;
  }

  /// Advances an iterator along logical dimensions of the matrix in units of whole tiles
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) {

    ref_.add_coord_offset({
      coord.row() * Shape::kRow,
      coord.column() * Shape::kColumn});

    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator & operator++() {

    ref_.add_coord_offset({Shape::kRow, 0});

    return *this;
  }

  /// Retreats the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator & operator--() {

    ref_.add_coord_offset({-Shape::kRow, 0});

    return *this;
  }

  /// Loads a fragment from memory with additional logical offset
  CUTLASS_HOST_DEVICE
  void load_with_pointer_offset(
    Fragment &frag,
    Index pointer_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {

        Array<Element, Policy::LaneMmaShape::kN> const *src_ptr =
          reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(
            ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0}));

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {

          Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =
            reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag) +
            mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM);

          *dst_ptr = src_ptr[mma_n * Policy::WarpShape::kColumn];
        }
      }
    }
  }
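  // Fragment indexing note (editorial): the destination index
  // mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM) lays the
  // LaneMmaShape::kN-wide vectors out row-major within the fragment. With the illustrative
  // values used earlier (Iterations = MatrixShape<4, 2>, LaneMmaShape::kM = 4), element row
  // m = 1 of MMA sub-tile (mma_m, mma_n) = (2, 1) lands at vector index 1 + 2 * (1 + 2 * 4) = 19.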

  /// Loads a fragment from memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_pointer_offset(frag, 0);
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {

        Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =
          reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(
            ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0}));

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {

          Array<Element, Policy::LaneMmaShape::kN> const *src_ptr =
            reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(&frag) +
            mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM);

          dst_ptr[mma_n * Policy::WarpShape::kColumn] = *src_ptr;
        }
      }
    }
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag) const {
    store_with_pointer_offset(frag, 0);
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Specialization for A operands of column-major-interleaved layouts
///
/// Supports sliced-K (PartitionsK > 1) with partitions of size PartitionGroupSize along K.
template <
  /// Size of the matrix to load (concept: MatrixShape)
  typename Shape_,
  /// Data type of elements
  typename Element_,
  /// Decomposition of elements among threads (concept: MmaSimtPolicy)
  typename Policy_,
  /// Number of partitions along the K dimension
  int PartitionsK,
  /// Group size along the K partition
  int PartitionGroupSize
>
class MmaSimtTileIterator<Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved<4>, Policy_, PartitionsK, PartitionGroupSize> {
public:

  /// Shape of tile to load (concept: MatrixShape)
  using Shape = Shape_;

  /// Operand tag
  static Operand const kOperand = Operand::kA;

  /// Element type
  using Element = Element_;

  /// Layout of the source tile
  using Layout = layout::ColumnMajorInterleaved<4>;

  /// Decomposition of elements among threads
  using Policy = Policy_;

  /// TensorRef type for loading elements from a tensor
  using TensorRef = TensorRef<Element, Layout>;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Interleaving quantity
  static const int kInterleave = 4;

  /// Number of partitions along the K dimension
  static const int kPartitionsK = PartitionsK;

  /// Number of k-groups per tile
  static const int kGroupPerTile = PartitionGroupSize / Shape::kColumn;
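  // Example (editorial, illustrative values): if the K extent handed in as PartitionGroupSize
  // is 32 and this iterator's per-step K extent Shape::kColumn is 4, then
  // kGroupPerTile = 32 / 4 = 8 k-groups make up one full tile.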

  //
  // Derived quantities
  //

  static_assert(!(Shape::kRow % Policy::WarpShape::kRow),
    "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension.");

  static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
  static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
  static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
  static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");

  /// Thread-level shape of a fragment
  using ThreadShape = MatrixShape<
    Shape::kRow / Policy::WarpShape::kRow,
    Shape::kColumn
  >;

  static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM) && !(ThreadShape::kColumn % Policy::LaneMmaShape::kK),
    "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");

  /// Number of individual loads
  using Iterations = MatrixShape<
    ThreadShape::kRow / Policy::LaneMmaShape::kM,
    ThreadShape::kColumn / Policy::LaneMmaShape::kK
  >;

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, ThreadShape::kCount>;

private:

  /// Internal reference to the tensor, re-expressed in vectors of LaneMmaShape::kMK elements
  cutlass::TensorRef<Array<Element, Policy::LaneMmaShape::kMK>, layout::ColumnMajorInterleaved<4>> ref_;

  /// Index of the k-group this iterator currently points to
  int k_group_idx_;

public:

  /// Default ctor constructs null iterator
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator() { }

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator(
    TensorRef ref,
    int lane_id
  ) {

    // compute offset based on thread ID and lane layout
    typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

    MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
      MatrixCoord(Policy::LaneMmaShape::kM, 0);

    ref.add_coord_offset(lane_offset);

    k_group_idx_ = 0;

    ref_.reset(
      reinterpret_cast<Array<Element, Policy::LaneMmaShape::kMK> *>(ref.data()),
      ref.stride(0) / Policy::LaneMmaShape::kMK);
  }

  /// Adds a pointer offset to internal pointer(s) to advance through memory
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator &add_pointer_offset(LongIndex offset) {
    ref_.add_pointer_offset(offset);
    return *this;
  }

  /// Advances an iterator along logical dimensions of the matrix in units of whole tiles
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) {

    ref_.add_coord_offset({
      coord.row() * Shape::kRow / Policy::LaneMmaShape::kMK,
      coord.column() * Shape::kColumn});

    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator & operator++() {

    add_tile_offset({0, 1});

    if (kPartitionsK > 1) {
      ++k_group_idx_;
      // Jump to next stage
      if (k_group_idx_ == kGroupPerTile) {
        k_group_idx_ = 0;
        add_tile_offset({0, kGroupPerTile * (kPartitionsK - 1)});
      }
    }

    return *this;
  }
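  // Walkthrough of the sliced-K advance above (editorial; illustrative values
  // kGroupPerTile = 8, kPartitionsK = 2): each partition owns 8 consecutive k-groups. Once
  // this iterator has consumed its 8th group, k_group_idx_ wraps to zero and the iterator
  // jumps forward 8 * (2 - 1) = 8 additional column tiles, skipping the k-groups owned by
  // the other partition.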

  /// Retreats the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator & operator--() {

    ref_.add_coord_offset({0, -Shape::kColumn});

    return *this;
  }

  /// Loads a fragment from memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {

    Array<Element, Policy::LaneMmaShape::kMK> *dst_ptr =
      reinterpret_cast<Array<Element, Policy::LaneMmaShape::kMK> *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kColumn; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < Iterations::kRow; ++m) {
        dst_ptr[m + k * Iterations::kRow] =
          *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow / kInterleave,
            k * Policy::LaneMmaShape::kK}) + pointer_offset / Policy::LaneMmaShape::kM);
      }
    }
  }

  /// Loads a fragment from memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_pointer_offset(frag, 0);
  }

  /// Stores a fragment to memory at the location pointed to by the iterator;
  /// mirrors the addressing used by load_with_pointer_offset
  CUTLASS_HOST_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

    Array<Element, Policy::LaneMmaShape::kMK> const *src_ptr =
      reinterpret_cast<Array<Element, Policy::LaneMmaShape::kMK> const *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kColumn; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < Iterations::kRow; ++m) {
        *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow / kInterleave,
          k * Policy::LaneMmaShape::kK}) + pointer_offset / Policy::LaneMmaShape::kM) =
          src_ptr[m + k * Iterations::kRow];
      }
    }
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag) const {
    store_with_pointer_offset(frag, 0);
  }

  /// Notifies the iterator of the current k-group index; a no-op for this iterator
  CUTLASS_DEVICE
  void set_kgroup_index(int k_group) {
    // no operation here
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Specialization for B operands of row-major-interleaved layouts
///
/// Supports sliced-K (PartitionsK > 1) with partitions of size PartitionGroupSize along K.
template <
  /// Size of the matrix to load (concept: MatrixShape)
  typename Shape_,
  /// Data type of elements
  typename Element_,
  /// Decomposition of elements among threads (concept: MmaSimtPolicy)
  typename Policy_,
  /// Number of partitions along the K dimension
  int PartitionsK,
  /// Group size along the K partition
  int PartitionGroupSize
>
class MmaSimtTileIterator<Shape_, Operand::kB, Element_, layout::RowMajorInterleaved<4>, Policy_, PartitionsK, PartitionGroupSize> {
public:

  /// Shape of tile to load (concept: MatrixShape)
  using Shape = Shape_;

  /// Operand tag
  static Operand const kOperand = Operand::kB;

  /// Element type
  using Element = Element_;

  /// Layout of the source tile
  using Layout = layout::RowMajorInterleaved<4>;

  /// Decomposition of elements among threads
  using Policy = Policy_;

  /// TensorRef type for loading elements from a tensor
  using TensorRef = TensorRef<Element, Layout>;

  /// Index type
  using Index = typename TensorRef::Index;

  /// Long Index type
  using LongIndex = typename TensorRef::LongIndex;

  /// Coordinate for an element in the tensor
  using TensorCoord = typename TensorRef::TensorCoord;

  /// Interleaving quantity
  static const int kInterleave = 4;

  /// Number of partitions along the K dimension
  static const int kPartitionsK = PartitionsK;

  /// Number of k-groups per tile
  static const int kGroupPerTile = PartitionGroupSize / Shape::kRow;

  //
  // Derived quantities
  //

  static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn),
    "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension.");

  static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
  static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
  static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
  static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");

  /// Thread-level shape of a fragment
  using ThreadShape = MatrixShape<
    Shape::kRow,
    Shape::kColumn / Policy::WarpShape::kColumn
  >;

  static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN) && !(ThreadShape::kRow % Policy::LaneMmaShape::kK),
    "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");

  /// Number of individual loads
  using Iterations = MatrixShape<
    ThreadShape::kRow / Policy::LaneMmaShape::kK,
    ThreadShape::kColumn / Policy::LaneMmaShape::kN
  >;

  /// Fragment object holding a thread's part of a tile
  using Fragment = Array<Element, ThreadShape::kCount>;

private:

  /// Internal reference to the tensor, re-expressed in vectors of LaneMmaShape::kKN elements
  cutlass::TensorRef<Array<Element, Policy::LaneMmaShape::kKN>, layout::RowMajorInterleaved<4>> ref_;

  /// Index of the k-group this iterator currently points to
  int k_group_idx_;

public:

  /// Default ctor constructs null iterator
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator() { }

  /// Constructor from TensorRef
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator(
    TensorRef ref,
    int lane_id
  ) {

    // compute offset based on thread ID and lane layout
    typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();

    MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
      MatrixCoord(0, Policy::LaneMmaShape::kN);

    ref.add_coord_offset(lane_offset);

    k_group_idx_ = 0;

    ref_.reset(
      reinterpret_cast<Array<Element, Policy::LaneMmaShape::kKN> *>(ref.data()),
      ref.stride(0) / Policy::LaneMmaShape::kKN);
  }

  /// Adds a pointer offset to internal pointer(s) to advance through memory
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator &add_pointer_offset(LongIndex offset) {
    ref_.add_pointer_offset(offset);
    return *this;
  }

  /// Advances an iterator along logical dimensions of the matrix in units of whole tiles
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) {

    ref_.add_coord_offset({
      coord.row() * Shape::kRow,
      coord.column() * Shape::kColumn / Policy::LaneMmaShape::kKN});

    return *this;
  }

  /// Advances the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator & operator++() {

    add_tile_offset({1, 0});

    if (kPartitionsK > 1) {
      ++k_group_idx_;
      // Jump to next stage
      if (k_group_idx_ == kGroupPerTile) {
        k_group_idx_ = 0;
        add_tile_offset({kGroupPerTile * (kPartitionsK - 1), 0});
      }
    }

    return *this;
  }

  /// Retreats the iterator along the advance dimension
  CUTLASS_HOST_DEVICE
  MmaSimtTileIterator & operator--() {

    ref_.add_coord_offset({-Shape::kRow, 0});

    return *this;
  }

  /// Loads a fragment from memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {

    Array<Element, Policy::LaneMmaShape::kKN> *dst_ptr =
      reinterpret_cast<Array<Element, Policy::LaneMmaShape::kKN> *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kRow; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {
        dst_ptr[n + k * Iterations::kColumn] =
          *(ref_.data() + ref_.offset({k * Policy::LaneMmaShape::kK,
            n * Policy::WarpShape::kColumn / kInterleave}) + pointer_offset / Policy::LaneMmaShape::kN);
      }
    }
  }

  /// Loads a fragment from memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_pointer_offset(frag, 0);
  }

  /// Stores a fragment to memory at the location pointed to by the iterator;
  /// mirrors the addressing used by load_with_pointer_offset
  CUTLASS_HOST_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

    Array<Element, Policy::LaneMmaShape::kKN> const *src_ptr =
      reinterpret_cast<Array<Element, Policy::LaneMmaShape::kKN> const *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kRow; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {
        *(ref_.data() + ref_.offset({k * Policy::LaneMmaShape::kK,
          n * Policy::WarpShape::kColumn / kInterleave}) + pointer_offset / Policy::LaneMmaShape::kN) =
          src_ptr[n + k * Iterations::kColumn];
      }
    }
  }

  /// Stores a fragment to memory at the location pointed to by the iterator
  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag, Index pointer_offset) const {
    store_with_pointer_offset(frag, pointer_offset);
  }

  /// Notifies the iterator of the current k-group index; a no-op for this iterator
  CUTLASS_DEVICE
  void set_kgroup_index(int k_group) {
    // no operation here
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass
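
////////////////////////////////////////////////////////////////////////////////////////////////////

// End-to-end sketch (editorial) of how these iterators cooperate inside a warp-level SIMT MMA
// mainloop; IteratorA and IteratorB follow the illustrative instantiations above, while
// thread_mma, FragmentC, ref_A, ref_B, lane_id, and kWarpGemmKIterations are assumed names:
//
//   IteratorA iter_A(ref_A, lane_id);
//   IteratorB iter_B(ref_B, lane_id);
//
//   typename IteratorA::Fragment frag_A;
//   typename IteratorB::Fragment frag_B;
//   FragmentC accum;
//   accum.clear();
//
//   for (int k = 0; k < kWarpGemmKIterations; ++k) {
//     iter_A.load(frag_A);
//     iter_B.load(frag_B);
//     ++iter_A;
//     ++iter_B;
//     thread_mma(accum, frag_A, frag_B, accum);   // thread-level GEMM on fragments
//   }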