docs/predicated__tile__access__iterator_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
predicated_tile_access_iterator.h
[Go to the documentation of this file.](predicated__tile__access__iterator_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 *modification, are permitted provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice,
7 *this list of conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright
9 *notice, this list of conditions and the following disclaimer in the
10 *documentation and/or other materials provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its
12 *contributors may be used to endorse or promote products derived from this
13 *software without specific prior written permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
19 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
20 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
21 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
22 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
23 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
24 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 *
26 **************************************************************************************************/
40 #pragma once
41
42 #include "cutlass/array.h"
43 #include "cutlass/coord.h"
44 #include "cutlass/cutlass.h"
45 #include "cutlass/layout/matrix.h"
46 #include "cutlass/layout/pitch_linear.h"
47 #include "cutlass/matrix_shape.h"
48 #include "cutlass/predicate_vector.h"
49 #include "cutlass/tensor_ref.h"
50 #include "cutlass/tensor_view.h"
51
53
55
56 namespace cutlass {
57 namespace transform {
58 namespace threadblock {
59
61
// Primary template: declaration only. Every usable iterator comes from one of
// the layout-specific partial specializations that follow in this listing
// (PitchLinear, ColumnMajor, RowMajor, ColumnMajorInterleaved,
// RowMajorInterleaved).
64 template <typename Shape, typename Element, typename Layout, int AdvanceRank,
65typename ThreadMap, typename AccessType>
66 class PredicatedTileAccessIterator;
67
69
// Partial specialization of PredicatedTileAccessIterator for
// layout::PitchLinear. The iterator keeps a byte pointer into the tensor, a
// packed set of per-access predicate bits, and three iteration counters
// (vector, contiguous, strided). The other layout specializations in this
// listing wrap this one as their UnderlyingIterator.
//
// NOTE(review): this rendering of the header has dropped some listing lines.
// In this class, line 77 (the opener of the first static_assert), line 82
// (presumably `using Shape = Shape_;`), line 96 (presumably
// `using Pointer = Element *;`, as at listing line 1067) and lines 478-480
// (the `valid()` signature) are absent. `¶ms` looks like a mis-decoded
// `&params`. Confirm against the real header before relying on this text.
72 template <typename Shape_, typename Element_, int AdvanceRank,
73typename ThreadMap_, typename AccessType_>
74 class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
75 AdvanceRank, ThreadMap_, AccessType_> {
76public:
// NOTE(review): the `static_assert(` opener (listing line 77) is missing here;
// the three lines below are its condition and message.
78 AdvanceRank == 0 || AdvanceRank == 1,
79"Specialization for pitch-linear iterator may along advance along the "
80"contiguous(rank=0) or strided(rank=1) dimension.");
81
83using Element = Element_;
84using Layout = layout::PitchLinear;
85static int const kAdvanceRank = AdvanceRank;
86using ThreadMap = ThreadMap_;
87using AccessType = AccessType_;
88
89using Index = typename Layout::Index;
90using LongIndex = typename Layout::LongIndex;
91
92using TensorRef = TensorRef<Element, Layout>;
93using TensorView = TensorView<Element, Layout>;
94using TensorCoord = typename Layout::TensorCoord;
95
97using NonConstPointer = typename platform::remove_const<Element>::type *;
98
// Number of AccessType-sized accesses needed to cover one thread-map vector.
99static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
100
101static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
102"Vectors implied by the thread map must be divisible by the access type.");
103
// Predicate packing: each byte of a 32-bit word carries 4 predicate bits
// (bit position = byte_idx * 8 + bit_idx below), so one word holds 16.
104static int const kPredicatesPerByte = 4;
105static int const kPredicatesPerWord = 4 * kPredicatesPerByte;
106
// One predicate per (thread-map iteration, access-within-vector) pair.
107static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector;
108
// Round up to whole bytes, then to whole 32-bit words.
110static int const kPredicateByteCount =
111 (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte;
112static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4;
113
// Low kPredicatesPerByte bits set (0xF with the value above).
114static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u;
115
// At most 4 words, i.e. at most 64 predicates with the packing above.
116static_assert(kPredicateWordCount <= 4, "Too many predicates.");
117
// Externally visible copy of the predicate words (see set_mask / get_mask).
119using Mask = Array<uint32_t, kPredicateWordCount>;
120
// Precomputed pointer increments. The constructor scales every value by
// sizeof_bits<Element>::value / 8, i.e. they are stored in bytes.
122class Params {
123public:
124friend PredicatedTileAccessIterator;
125
126private:
// Taken from layout.stride(0).
128int stride_;
// Byte step between consecutive strided iterations (Delta::kStrided rows).
131int inc_strided_;
// Byte step applied after the last strided iteration of a tile.
134int inc_next_;
// Byte step of one whole tile along the advance dimension.
137int inc_advance_;
138
139public:
140
141// Default ctor
143Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { }
144
// Builds the increments from a pitch-linear layout.
147Params(Layout const &layout) : stride_(layout.stride(0)) {
148 inc_strided_ = (stride_ * ThreadMap::Delta::kStrided) *
149sizeof_bits<Element>::value / 8;
150
151if (kAdvanceRank) {
152// advance along strided dimension
153 inc_advance_ =
154 Shape::kStrided * stride_ * sizeof_bits<Element>::value / 8;
155 } else {
156// advance along contiguous dimension
157 inc_advance_ = Shape::kContiguous * sizeof_bits<Element>::value / 8;
158 }
159
// One tile forward, minus the (Iterations::kStrided - 1) strided steps
// already added to the pointer by operator++ while walking the tile.
160 inc_next_ = inc_advance_ - (ThreadMap::Iterations::kStrided - 1) *
161 ThreadMap::Delta::kStrided * stride_ *
162sizeof_bits<Element>::value / 8;
163 };
164 };
165
166private:
// All pointer arithmetic in this class is done on a char pointer.
168using BytePointer = char *;
169
170private:
171//
172// Data members
173//
174
// NOTE(review): `¶ms_` is presumably `&params_` (a reference member) with
// the `&para` prefix mis-decoded by the HTML conversion - confirm.
176 Params const ¶ms_;
177
// Current byte address; moved by add_pointer_offset, add_tile_offset and
// operator++.
179 BytePointer pointer_;
180
// Packed guard bits, laid out as described at kPredicatesPerByte.
182 uint32_t predicates_[kPredicateWordCount];
183
// Tensor extent passed to the constructor.
185TensorCoord extent_;
186
// Logical coordinate of this thread's first access.
188TensorCoord thread_offset_;
189
// Size of the first (residue) tile along the advance dimension.
191TensorCoord residue_offset_;
192
// True until the first add_tile_offset call.
194bool is_residue_tile_;
195
// Position inside the current tile, decomposed as vector / contiguous /
// strided (see set_iteration_index).
197int iteration_vector_;
198
200int iteration_contiguous_;
201
203int iteration_strided_;
204
205private:
// Recomputes every predicate bit for the given extent.
//   extent          - bound that each access coordinate is compared against.
//   is_steady_state - when true only one dimension is tested: the strided
//                     one if kAdvanceRank == 0, otherwise the contiguous
//                     one. When false both dimensions are tested.
207 CUTLASS_DEVICE
208void compute_predicates_(
210TensorCoord extent,
212bool is_steady_state = false) {
213
// Start from an all-false mask; bits are OR-ed in below.
215for (int i = 0; i < kPredicateWordCount; ++i) {
216 predicates_[i] = 0u;
217 }
218
219for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) {
220
// Split the linear access index into strided (s), contiguous (c) and
// within-vector (v) components.
221int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
222
223int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
224
225int c = access_residual / kAccessesPerVector;
226int v = access_residual % kAccessesPerVector;
227
// Offset of this access relative to the thread's first access.
228TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements,
229 s * ThreadMap::Delta::kStrided);
230
231TensorCoord coord = thread_offset_ + iteration_coord;
232
233bool guard;
234
235if (is_steady_state) {
236if (kAdvanceRank == 0) {
237 guard = (coord.strided() < extent.strided());
238 } else {
239 guard = (coord.contiguous() < extent.contiguous());
240 }
241 } else {
242 guard = (coord.strided() < extent.strided() &&
243 coord.contiguous() < extent.contiguous());
244 }
245
// Same linearisation that valid() uses to look the bit up again.
246int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s);
247
248int word_idx = pred_idx / kPredicatesPerWord;
249int residual = pred_idx % kPredicatesPerWord;
250int byte_idx = residual / kPredicatesPerByte;
251int bit_idx = residual % kPredicatesPerByte;
252
253 predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
254
255 }
256
257 }
258
259public:
// Constructs the iterator positioned on the first (residue) tile.
//   params             - precomputed increments; held via params_.
//   pointer            - base of the tensor; const is cast away and the
//                        pointer is kept as a byte pointer.
//   extent             - tensor extent used for bounds predicates.
//   thread_id          - forwarded to ThreadMap::initial_offset.
//   threadblock_offset - logical starting coordinate of the threadblock.
263PredicatedTileAccessIterator(
265 Params const ¶ms,
267Pointer pointer,
269TensorCoord extent,
271int thread_id,
273TensorCoord const &threadblock_offset)
274 : params_(params),
275 pointer_(reinterpret_cast<BytePointer>(
276 const_cast<NonConstPointer>(pointer))),
277 extent_(extent),
278 is_residue_tile_(true) {
279
// The first tile covers extent % tile-size elements along the advance
// dimension, or a full tile when that remainder is zero. residue_extent
// clips the predicate bound to that first tile.
280TensorCoord residue_extent;
281if (kAdvanceRank) {
282
283Index residue_size = (extent_[kAdvanceRank] % Shape::kStrided);
284if (!residue_size) {
285 residue_size = Shape::kStrided;
286 }
287
288 residue_offset_ = make_Coord(0, residue_size);
289 residue_extent = make_Coord(
290 extent_.contiguous(),
291min(threadblock_offset.strided() + residue_offset_.strided(), extent_.strided())
292 );
293
294 } else {
295
296Index residue_size = (extent_[kAdvanceRank] % Shape::kContiguous);
297if (!residue_size) {
298 residue_size = Shape::kContiguous;
299 }
300 residue_offset_ = make_Coord(residue_size, 0);
301 residue_extent = make_Coord(
302min(extent_.contiguous(), threadblock_offset.contiguous() + residue_offset_.contiguous()),
303 extent_.strided()
304 );
305 }
306
307// Per-thread offset in logical coordinates of tensor
308 thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id);
309
310// update internal pointers
311 Layout layout(params_.stride_);
312 add_pointer_offset(layout(thread_offset_));
313
// Predicates for the residue tile check both dimensions.
314 compute_predicates_(residue_extent, false);
315
// Initialises the three iteration counters, which have no initialiser
// in the member-init list above.
316 set_iteration_index(0);
317 }
318
// Convenience overload: same as above with a (0, 0) threadblock offset.
321PredicatedTileAccessIterator(
323 Params const ¶ms,
325Pointer pointer,
327TensorCoord extent,
329int thread_id)
330 : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
331make_Coord(0, 0)) {}
332
// Sets the three iteration counters from a linear access index
// (vector fastest, then contiguous, then strided). Only the counters are
// written here; pointer_ is left unchanged.
335void set_iteration_index(int index) {
336
337 iteration_vector_ = index % kAccessesPerVector;
338int residual_access = index / kAccessesPerVector;
339
340 iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
341 iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
342
343 }
344
// Moves the pointer by pointer_offset elements, converted to bytes through
// sizeof_bits<Element>.
347void add_pointer_offset(LongIndex pointer_offset) {
348 pointer_ += sizeof_bits<Element>::value * pointer_offset / 8;
349 }
350
// Advances the iterator by whole tiles. The first call also steps over the
// residue tile and switches the predicates to the steady-state form, so one
// tile of the advance-dimension offset is consumed by the residue step
// (hence the "- 1" below).
352 CUTLASS_DEVICE
353void add_tile_offset(
354TensorCoord const &tile_offset) {
355if (is_residue_tile_) {
356
357 thread_offset_ += residue_offset_;
358
359 Layout layout(params_.stride_);
360 add_pointer_offset(layout(residue_offset_));
361
362 compute_predicates_(extent_, true);
363
// NOTE(review): the second statement of each branch (here and in the
// else-path below) adds Shape::k* * tile_offset.* directly to the byte
// pointer, with no element-size or stride scaling, unlike
// inc_advance_ - confirm this is intended.
364if (kAdvanceRank) {
365 pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1);
366 pointer_ += Shape::kContiguous * tile_offset.contiguous();
367 } else {
368 pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1);
369 pointer_ += Shape::kStrided * tile_offset.strided();
370 }
371 } else {
372if (kAdvanceRank) {
373 pointer_ += params_.inc_advance_ * tile_offset.strided();
374 pointer_ += Shape::kContiguous * tile_offset.contiguous();
375 } else {
376 pointer_ += params_.inc_advance_ * tile_offset.contiguous();
377 pointer_ += Shape::kStrided * tile_offset.strided();
378 }
379 }
380 is_residue_tile_ = false;
381 }
382
// Returns the address of the current access. The contiguous offset is added
// in bytes before the cast; iteration_vector_ is added after the cast, in
// units of AccessType. The strided position is not added here: it is
// already folded into pointer_ by operator++.
385AccessType *get() const {
386return reinterpret_cast<AccessType *>(
387 pointer_ +
388 iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value) / 8) + iteration_vector_;
389 }
390
// Pre-increment: steps vector -> contiguous -> strided. pointer_ only
// changes when the strided counter moves.
393PredicatedTileAccessIterator &operator++() {
394
395 ++iteration_vector_;
396if (iteration_vector_ < kAccessesPerVector) {
397return *this;
398 }
399
400 iteration_vector_ = 0;
401 ++iteration_contiguous_;
402
403if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
404return *this;
405 }
406
407// Enter here only if (iteration_contiguous_ ==
408// ThreadMap::Iteration::kContiguous)
409 iteration_contiguous_ = 0;
410 ++iteration_strided_;
411
412if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
413 pointer_ += params_.inc_strided_;
414return *this;
415 }
416
417// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
418// which means we enter the next tile.
419 iteration_strided_ = 0;
420
421// advance to next tile
422 pointer_ += params_.inc_next_;
423
424// now return to start tile - if the iterator is subsequently advanced, this
425// subtraction as well as the subsequent integer addition are both elided by
426// the compiler.
427 pointer_ -= params_.inc_advance_;
428
429return *this;
430 }
431
// Post-increment: returns a copy taken before the increment.
434PredicatedTileAccessIterator operator++(int) {
435PredicatedTileAccessIterator self(*this);
436operator++();
437return self;
438 }
439
// Sets every predicate word to zero (all accesses disabled).
442void clear_mask() {
444for (int i = 0; i < kPredicateWordCount; ++i) {
445 predicates_[i] = 0u;
446 }
447
448 }
449
// Sets every predicate word to 0xffffffff (all accesses enabled).
452void enable_mask() {
454for (int i = 0; i < kPredicateWordCount; ++i) {
455 predicates_[i] = 0xffffffff;
456 }
457 }
458
// Overwrites the stored predicate words with the given mask.
461void set_mask(Mask const &mask) {
463for (int i = 0; i < kPredicateWordCount; ++i) {
464 predicates_[i] = mask[i];
465 }
466
467 }
468
// Copies the stored predicate words into mask.
471void get_mask(Mask &mask) {
473for (int i = 0; i < kPredicateWordCount; ++i) {
474 mask[i] = predicates_[i];
475 }
476 }
477
// NOTE(review): the signature of this function (listing lines 478-480) is
// missing from this rendering. The cross-reference index at the end of the
// page lists `CUTLASS_HOST_DEVICE bool valid()` at line 480. The body below
// returns the predicate bit for the current iteration counters, using the
// same linearisation as compute_predicates_.
481
482
483int pred_idx =
484 iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous);
485
486int word_idx = pred_idx / kPredicatesPerWord;
487int residual = pred_idx % kPredicatesPerWord;
488int byte_idx = residual / kPredicatesPerByte;
489int bit_idx = residual % kPredicatesPerByte;
490
491bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
492return pred;
493
494
495//return true;
496 }
497 };
498
500
// Partial specialization for layout::ColumnMajor. A thin adapter around the
// PitchLinear specialization: (row, column) coordinates are passed through as
// (contiguous, strided) and every operation is forwarded to iterator_.
//
// NOTE(review): this rendering has dropped some listing lines. In this class,
// line 518 (`Shape`), line 532 (`Pointer`) and lines 668-670 (the `valid()`
// signature) are absent; the index at the end of the page lists
// `Shape_ Shape`, `Element * Pointer` and `bool valid()` at those lines.
// `¶ms` is presumably a mis-decoded `&params`.
508 template <typename Shape_, typename Element_, int AdvanceRank,
509typename ThreadMap_, typename AccessType_>
510 class PredicatedTileAccessIterator<Shape_, Element_, layout::ColumnMajor,
511 AdvanceRank, ThreadMap_, AccessType_> {
512public:
// The message text still says "pitch-linear"; it is emitted by the compiler
// verbatim if the assertion fails.
513static_assert(
514 AdvanceRank == 0 || AdvanceRank == 1,
515"Specialization for pitch-linear iterator may along advance along the "
516"contiguous(rank=0) or strided(rank=1) dimension.");
517
519using Element = Element_;
520using Layout = layout::ColumnMajor;
521static int const kAdvanceRank = AdvanceRank;
522using ThreadMap = ThreadMap_;
523using AccessType = AccessType_;
524
525using Index = typename Layout::Index;
526using LongIndex = typename Layout::LongIndex;
527
528using TensorRef = TensorRef<Element, Layout>;
529using TensorView = TensorView<Element, Layout>;
530using TensorCoord = typename Layout::TensorCoord;
531
533using NonConstPointer = typename platform::remove_const<Element>::type *;
534
// Rows map to the contiguous dimension and columns to the strided one; the
// advance rank is passed through unchanged.
535using UnderlyingIterator = PredicatedTileAccessIterator<
536layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
537layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>;
538
540using Mask = typename UnderlyingIterator::Mask;
541
542static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
543
// Wraps the underlying iterator's Params.
545class Params {
546private:
547friend PredicatedTileAccessIterator;
548
550typename UnderlyingIterator::Params params_;
551
552public:
553
557
// Builds the pitch-linear Params from layout.stride(0).
560Params(Layout const &layout)
561 : params_(layout::PitchLinear(layout.stride(0))){};
562 };
563
564private:
565//
566// Data members
567//
568
// The pitch-linear iterator that does all the work.
570UnderlyingIterator iterator_;
571
572public:
// Constructs the wrapped iterator, converting extent and threadblock_offset
// from (row, column) to PitchLinearCoord(row, column).
576PredicatedTileAccessIterator(
578 Params const ¶ms,
580Pointer pointer,
582TensorCoord extent,
584int thread_id,
586TensorCoord const &threadblock_offset)
587 : iterator_(params.params_, pointer,
588 layout::PitchLinearCoord(extent.row(), extent.column()),
589 thread_id,
590 layout::PitchLinearCoord(threadblock_offset.row(),
591 threadblock_offset.column())) {}
592
// Convenience overload with a (0, 0) threadblock offset.
595PredicatedTileAccessIterator(
596 Params const ¶ms,
597Pointer pointer,
598TensorCoord extent,
599int thread_id
600 )
601 : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
602make_Coord(0, 0)) {}
603
// Forwards to the underlying iterator.
606void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
607
// Forwards to the underlying iterator.
610void add_pointer_offset(LongIndex pointer_offset) {
611 iterator_.add_pointer_offset(pointer_offset);
612 }
613
// Forwards the tile offset as {row, column}.
617void add_tile_offset(TensorCoord const &tile_offset) {
618 iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
619 }
620
// Returns the underlying iterator's current access pointer.
623AccessType *get() const {
624return reinterpret_cast<AccessType *>(iterator_.get());
625 }
626
// Pre-increment: advances the underlying iterator.
634PredicatedTileAccessIterator &operator++() {
635 ++iterator_;
636return *this;
637 }
638
// Post-increment: returns a copy taken before the increment.
646PredicatedTileAccessIterator operator++(int) {
647PredicatedTileAccessIterator self(*this);
648operator++();
649return self;
650 }
651
// The four mask accessors below forward to the underlying iterator.
654void clear_mask() { iterator_.clear_mask(); }
655
658void enable_mask() { iterator_.enable_mask(); }
659
662void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
663
666void get_mask(Mask &mask) { iterator_.get_mask(mask); }
667
// NOTE(review): the `valid()` signature (listing lines 668-670) is missing
// from this rendering; only its body remains.
671return iterator_.valid();
672 }
673 };
674
676
// Partial specialization for layout::RowMajor. A thin adapter around the
// PitchLinear specialization: (row, column) coordinates are swapped to
// (contiguous = column, strided = row), the advance rank is swapped to match,
// and every operation is forwarded to iterator_.
//
// NOTE(review): this rendering has dropped some listing lines. In this class,
// line 694 (`Shape`), line 708 (`Pointer`) and lines 844-846 (the `valid()`
// signature) are absent; the index at the end of the page lists
// `Shape_ Shape`, `Element * Pointer` and `bool valid()` at those lines.
// `¶ms` is presumably a mis-decoded `&params`.
684 template <typename Shape_, typename Element_, int AdvanceRank,
685typename ThreadMap_, typename AccessType_>
686 class PredicatedTileAccessIterator<Shape_, Element_, layout::RowMajor,
687 AdvanceRank, ThreadMap_, AccessType_> {
688public:
// The message text still says "pitch-linear"; it is emitted by the compiler
// verbatim if the assertion fails.
689static_assert(
690 AdvanceRank == 0 || AdvanceRank == 1,
691"Specialization for pitch-linear iterator may along advance along the "
692"contiguous(rank=0) or strided(rank=1) dimension.");
693
695using Element = Element_;
696using Layout = layout::RowMajor;
697static int const kAdvanceRank = AdvanceRank;
698using ThreadMap = ThreadMap_;
699using AccessType = AccessType_;
700
701using Index = typename Layout::Index;
702using LongIndex = typename Layout::LongIndex;
703
704using TensorRef = TensorRef<Element, Layout>;
705using TensorView = TensorView<Element, Layout>;
706using TensorCoord = typename Layout::TensorCoord;
707
709using NonConstPointer = typename platform::remove_const<Element>::type *;
710
// Columns map to the contiguous dimension and rows to the strided one, so
// advance rank 0 (rows) becomes pitch-linear rank 1 and vice versa.
711using UnderlyingIterator = PredicatedTileAccessIterator<
712layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
713layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>;
714
715static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
716
718using Mask = typename UnderlyingIterator::Mask;
719
// Wraps the underlying iterator's Params.
721class Params {
722private:
723friend PredicatedTileAccessIterator;
724
726typename UnderlyingIterator::Params params_;
727
728public:
729
733
// Builds the pitch-linear Params from layout.stride(0).
736Params(Layout const &layout)
737 : params_(layout::PitchLinear(layout.stride(0))){};
738 };
739
740private:
741//
742// Data members
743//
744
// The pitch-linear iterator that does all the work.
746UnderlyingIterator iterator_;
747
748public:
// Constructs the wrapped iterator, converting extent and threadblock_offset
// from (row, column) to PitchLinearCoord(column, row).
752PredicatedTileAccessIterator(
754 Params const ¶ms,
756Pointer pointer,
758TensorCoord extent,
760int thread_id,
762TensorCoord const &threadblock_offset)
763 : iterator_(params.params_, pointer,
764 layout::PitchLinearCoord(extent.column(), extent.row()),
765 thread_id,
766 layout::PitchLinearCoord(threadblock_offset.column(),
767 threadblock_offset.row())) {}
768
// Convenience overload with a (0, 0) threadblock offset.
771PredicatedTileAccessIterator(
772 Params const ¶ms,
773Pointer pointer,
774TensorCoord extent,
775int thread_id
776 )
777 : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
778make_Coord(0, 0)) {}
779
// Forwards to the underlying iterator.
782void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
783
// Forwards to the underlying iterator.
786void add_pointer_offset(LongIndex pointer_offset) {
787 iterator_.add_pointer_offset(pointer_offset);
788 }
789
// Forwards the tile offset with the coordinates swapped to {column, row}.
793void add_tile_offset(TensorCoord const &tile_offset) {
794 iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
795 }
796
// Returns the underlying iterator's current access pointer.
799AccessType *get() const {
800return reinterpret_cast<AccessType *>(iterator_.get());
801 }
802
// Pre-increment: advances the underlying iterator.
810PredicatedTileAccessIterator &operator++() {
811 ++iterator_;
812return *this;
813 }
814
// Post-increment: returns a copy taken before the increment.
822PredicatedTileAccessIterator operator++(int) {
823PredicatedTileAccessIterator self(*this);
824operator++();
825return self;
826 }
827
// The four mask accessors below forward to the underlying iterator.
830void clear_mask() { iterator_.clear_mask(); }
831
834void enable_mask() { iterator_.enable_mask(); }
835
838void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
839
842void get_mask(Mask &mask) { iterator_.get_mask(mask); }
843
// NOTE(review): the `valid()` signature (listing lines 844-846) is missing
// from this rendering; only its body remains.
847return iterator_.valid();
848 }
849 };
850
852
861
// Partial specialization for layout::ColumnMajorInterleaved<InterleavedK>.
// Adapter around the PitchLinear specialization: the row count is multiplied
// by InterleavedK and the column count is divided by it, both for the tile
// shape and for the extent and threadblock offset given to the constructor.
//
// NOTE(review): this rendering has dropped some listing lines. In this class,
// line 873 (`Shape`), line 888 (`Pointer`) and line 895 (the tail of the
// UnderlyingIterator alias) are absent; the index at the end of the page
// lists `Shape_ Shape` at 873 and `Element * Pointer` at 888. `¶ms` is
// presumably a mis-decoded `&params`.
862 template <typename Shape_, typename Element_, int AdvanceRank,
863typename ThreadMap_, typename AccessType_, int InterleavedK>
864 class PredicatedTileAccessIterator<Shape_, Element_,
865 layout::ColumnMajorInterleaved<InterleavedK>,
866 AdvanceRank, ThreadMap_, AccessType_> {
867public:
// The message text still says "pitch-linear"; it is emitted by the compiler
// verbatim if the assertion fails.
868static_assert(
869 AdvanceRank == 0 || AdvanceRank == 1,
870"Specialization for pitch-linear iterator may along advance along the "
871"contiguous(rank=0) or strided(rank=1) dimension.");
872
874using Element = Element_;
875static int const kInterleavedK = InterleavedK;
876using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
877static int const kAdvanceRank = AdvanceRank;
878using ThreadMap = ThreadMap_;
879using AccessType = AccessType_;
880
881using Index = typename Layout::Index;
882using LongIndex = typename Layout::LongIndex;
883
884using TensorRef = TensorRef<Element, Layout>;
885using TensorView = TensorView<Element, Layout>;
886using TensorCoord = typename Layout::TensorCoord;
887
889using NonConstPointer = typename platform::remove_const<Element>::type *;
890
// Pitch-linear tile of (kRow * K) contiguous by (kColumn / K) strided; the
// advance rank is passed through unchanged.
// NOTE(review): the end of this alias (listing line 895, presumably
// `AccessType>;`) is missing from this rendering.
891using UnderlyingIterator = PredicatedTileAccessIterator<
892layout::PitchLinearShape<Shape::kRow * kInterleavedK,
893 Shape::kColumn / kInterleavedK>,
894 Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap,
896
897static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
898
900using Mask = typename UnderlyingIterator::Mask;
901
// Wraps the underlying iterator's Params.
903class Params {
904private:
905friend PredicatedTileAccessIterator;
906
908typename UnderlyingIterator::Params params_;
909
910public:
913
// Builds the pitch-linear Params from layout.stride(0).
916Params(Layout const &layout)
917 : params_(layout::PitchLinear(layout.stride(0))) {}
918 };
919
920private:
921//
922// Data members
923//
924
// The pitch-linear iterator that does all the work.
926UnderlyingIterator iterator_;
927
928public:
// Constructs the wrapped iterator with extent and threadblock_offset mapped
// to PitchLinearCoord(row * K, column / K).
932PredicatedTileAccessIterator(
934 Params const ¶ms,
936Pointer pointer,
938TensorCoord extent,
940int thread_id,
942TensorCoord const &threadblock_offset)
943 : iterator_(params.params_, pointer,
944 layout::PitchLinearCoord(extent.row() * kInterleavedK,
945 extent.column() / kInterleavedK),
946 thread_id,
947 layout::PitchLinearCoord(
948 threadblock_offset.row() * kInterleavedK,
949 threadblock_offset.column() / kInterleavedK)) {}
950
// Convenience overload with a (0, 0) threadblock offset.
953PredicatedTileAccessIterator(
954 Params const ¶ms,
955Pointer pointer,
956TensorCoord extent,
957int thread_id
958 )
959 : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
960make_Coord(0, 0)) {}
961
// Forwards to the underlying iterator.
964void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
965
// Forwards to the underlying iterator.
968void add_pointer_offset(LongIndex pointer_offset) {
969 iterator_.add_pointer_offset(pointer_offset);
970 }
971
// Forwards the tile offset as {row, column}; no interleave scaling is
// applied to tile counts here.
975void add_tile_offset(TensorCoord const &tile_offset) {
976 iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
977 }
978
// Returns the underlying iterator's current access pointer.
981AccessType *get() const {
982return reinterpret_cast<AccessType *>(iterator_.get());
983 }
984
// Pre-increment: advances the underlying iterator.
992PredicatedTileAccessIterator &operator++() {
993 ++iterator_;
994return *this;
995 }
996
// Post-increment: returns a copy taken before the increment.
1004PredicatedTileAccessIterator operator++(int) {
1005PredicatedTileAccessIterator self(*this);
1006operator++();
1007return self;
1008 }
1009
// The four mask accessors and valid() below forward to the underlying
// iterator.
1012void clear_mask() { iterator_.clear_mask(); }
1013
1016void enable_mask() { iterator_.enable_mask(); }
1017
1020void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
1021
1024void get_mask(Mask &mask) { iterator_.get_mask(mask); }
1025
1028bool valid() { return iterator_.valid(); }
1029 };
1030
1032
// Partial specialization for layout::RowMajorInterleaved<InterleavedK>.
// Adapter around the PitchLinear specialization: coordinates are swapped to
// (contiguous = column * K, strided = row / K), the advance rank is swapped
// to match, and every operation is forwarded to iterator_.
//
// NOTE(review): this rendering has dropped some listing lines. In this class,
// line 1052 (presumably `using Shape = Shape_;`, as the index at the end of
// the page shows for the other specializations) and line 1074 (the tail of
// the UnderlyingIterator alias) are absent. `¶ms` is presumably a
// mis-decoded `&params`.
1041 template <typename Shape_, typename Element_, int AdvanceRank,
1042typename ThreadMap_, typename AccessType_, int InterleavedK>
1043 class PredicatedTileAccessIterator<Shape_, Element_,
1044 layout::RowMajorInterleaved<InterleavedK>,
1045 AdvanceRank, ThreadMap_, AccessType_> {
1046public:
// The message text still says "pitch-linear"; it is emitted by the compiler
// verbatim if the assertion fails.
1047static_assert(
1048 AdvanceRank == 0 || AdvanceRank == 1,
1049"Specialization for pitch-linear iterator may along advance along the "
1050"contiguous(rank=0) or strided(rank=1) dimension.");
1051
1053using Element = Element_;
1054static int const kInterleavedK = InterleavedK;
1055using Layout = layout::RowMajorInterleaved<kInterleavedK>;
1056static int const kAdvanceRank = AdvanceRank;
1057using ThreadMap = ThreadMap_;
1058using AccessType = AccessType_;
1059
1060using Index = typename Layout::Index;
1061using LongIndex = typename Layout::LongIndex;
1062
1063using TensorRef = TensorRef<Element, Layout>;
1064using TensorView = TensorView<Element, Layout>;
1065using TensorCoord = typename Layout::TensorCoord;
1066
1067using Pointer = Element *;
1068using NonConstPointer = typename platform::remove_const<Element>::type *;
1069
// Pitch-linear tile of (kColumn * K) contiguous by (kRow / K) strided;
// advance rank 0 becomes pitch-linear rank 1 and vice versa.
// NOTE(review): the end of this alias (listing line 1074, presumably
// `AccessType>;`) is missing from this rendering.
1070using UnderlyingIterator = PredicatedTileAccessIterator<
1071layout::PitchLinearShape<Shape::kColumn * kInterleavedK,
1072 Shape::kRow / kInterleavedK>,
1073 Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap,
1075
1076
1077static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
1078
1080using Mask = typename UnderlyingIterator::Mask;
1081
// Wraps the underlying iterator's Params.
1083class Params {
1084private:
1085friend PredicatedTileAccessIterator;
1086
1088typename UnderlyingIterator::Params params_;
1089
1090public:
1093
// Builds the pitch-linear Params from layout.stride(0).
1096Params(Layout const &layout)
1097 : params_(layout::PitchLinear(layout.stride(0))) {}
1098 };
1099
1100private:
1101//
1102// Data members
1103//
1104
// The pitch-linear iterator that does all the work.
1106UnderlyingIterator iterator_;
1107
1108public:
// Constructs the wrapped iterator with extent and threadblock_offset mapped
// to PitchLinearCoord(column * K, row / K).
1112PredicatedTileAccessIterator(
1114 Params const ¶ms,
1116Pointer pointer,
1118TensorCoord extent,
1120int thread_id,
1122TensorCoord const &threadblock_offset)
1123 : iterator_(params.params_, pointer,
1124 layout::PitchLinearCoord(extent.column() * kInterleavedK,
1125 extent.row() / kInterleavedK),
1126 thread_id,
1127 layout::PitchLinearCoord(
1128 threadblock_offset.column() * kInterleavedK,
1129 threadblock_offset.row() / kInterleavedK)) {}
1130
// Convenience overload with a (0, 0) threadblock offset.
1133PredicatedTileAccessIterator(
1134 Params const ¶ms,
1135Pointer pointer,
1136TensorCoord extent,
1137int thread_id
1138 )
1139 : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
1140make_Coord(0, 0)) {}
1141
// Forwards to the underlying iterator.
1144void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
1145
// Forwards to the underlying iterator.
1148void add_pointer_offset(LongIndex pointer_offset) {
1149 iterator_.add_pointer_offset(pointer_offset);
1150 }
1151
// Forwards the tile offset with the coordinates swapped to {column, row};
// no interleave scaling is applied to tile counts here.
1155void add_tile_offset(TensorCoord const &tile_offset) {
1156 iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
1157 }
1158
// Returns the underlying iterator's current access pointer.
1161AccessType *get() const {
1162return reinterpret_cast<AccessType *>(iterator_.get());
1163 }
1164
// Pre-increment: advances the underlying iterator.
1172PredicatedTileAccessIterator &operator++() {
1173 ++iterator_;
1174return *this;
1175 }
1176
// Post-increment: returns a copy taken before the increment.
1184PredicatedTileAccessIterator operator++(int) {
1185PredicatedTileAccessIterator self(*this);
1186operator++();
1187return self;
1188 }
1189
// The four mask accessors and valid() below forward to the underlying
// iterator.
1192void clear_mask() { iterator_.clear_mask(); }
1193
1196void enable_mask() { iterator_.enable_mask(); }
1197
1200void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
1201
1204void get_mask(Mask &mask) { iterator_.get_mask(mask); }
1205
1208bool valid() { return iterator_.valid(); }
1209 };
1210
1212
1213 } // namespace threadblock
1214 } // namespace transform
1215 } // namespace cutlass
1216
CUTLASS_HOST_DEVICE void set_iteration_index(int index)
Overrides the internal iteration index.
Definition: predicated_tile_access_iterator.h:606
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator & operator++()
Definition: predicated_tile_access_iterator.h:1172
cutlass::layout::RowMajor::LongIndex
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:62
cutlass::layout::ColumnMajorInterleaved::LongIndex
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:355
typename Layout::LongIndex LongIndex
Definition: predicated_tile_access_iterator.h:702
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: predicated_tile_access_iterator.h:842
AccessType_ AccessType
Definition: predicated_tile_access_iterator.h:1058
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: predicated_tile_access_iterator.h:1204
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: predicated_tile_access_iterator.h:1200
cutlass::layout::PitchLinearCoord
Coordinate in pitch-linear space.
Definition: pitch_linear.h:52
Defines a structure containing strides, bounds, and a pointer to tensor data.
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: predicated_tile_access_iterator.h:932
typename Layout::LongIndex LongIndex
Definition: predicated_tile_access_iterator.h:882
cutlass::platform::remove_const::type
T type
Definition: platform.h:351
typename Layout::TensorCoord TensorCoord
Definition: predicated_tile_access_iterator.h:1065
typename Layout::Index Index
Definition: predicated_tile_access_iterator.h:881
cutlass::layout::RowMajorInterleaved::LongIndex
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:249
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileAccessIterator with zero threadblock offset.
Definition: predicated_tile_access_iterator.h:1133
Element * Pointer
Definition: predicated_tile_access_iterator.h:532
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: predicated_tile_access_iterator.h:610
Mapping function for pitch-linear memory.
Definition: pitch_linear.h:163
cutlass::layout::ColumnMajorInterleaved::Index
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:352
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
typename Layout::TensorCoord TensorCoord
Definition: predicated_tile_access_iterator.h:706
typename Layout::Index Index
Definition: predicated_tile_access_iterator.h:1060
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 2-element coordinate.
Definition: coord.h:387
cutlass::layout::ColumnMajor::LongIndex
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:154
CUTLASS_HOST_DEVICE void enable_mask()
Enables every predicate in the set (all mask bits are set).
Definition: predicated_tile_access_iterator.h:1196
CUTLASS_HOST_DEVICE void add_tile_offset(TensorCoord const &tile_offset)
Definition: predicated_tile_access_iterator.h:793
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileAccessIterator with zero threadblock offset.
Definition: predicated_tile_access_iterator.h:595
typename platform::remove_const< Element >::type * NonConstPointer
Definition: predicated_tile_access_iterator.h:97
typename platform::remove_const< Element >::type * NonConstPointer
Definition: predicated_tile_access_iterator.h:709
cutlass::layout::RowMajorInterleaved::Index
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:246
CUTLASS_HOST_DEVICE bool valid()
Returns whether access is valid or not.
Definition: predicated_tile_access_iterator.h:670
CUTLASS_HOST_DEVICE void enable_mask()
Enables every predicate in the set (all mask bits are set).
Definition: predicated_tile_access_iterator.h:1016
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_access_iterator.h:1080
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator & operator++()
Definition: predicated_tile_access_iterator.h:992
CUTLASS_HOST_DEVICE void set_iteration_index(int index)
Overrides the internal iteration index.
Definition: predicated_tile_access_iterator.h:964
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_access_iterator.h:900
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: predicated_tile_access_iterator.h:916
CUTLASS_HOST_DEVICE bool valid()
Returns whether access is valid or not.
Definition: predicated_tile_access_iterator.h:480
Shape_ Shape
Definition: predicated_tile_access_iterator.h:873
Defines a structure containing strides and a pointer to tensor data.
typename Layout::LongIndex LongIndex
Definition: predicated_tile_access_iterator.h:526
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator operator++(int)
Definition: predicated_tile_access_iterator.h:822
AccessType_ AccessType
Definition: predicated_tile_access_iterator.h:523
typename platform::remove_const< Element >::type * NonConstPointer
Definition: predicated_tile_access_iterator.h:889
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: predicated_tile_access_iterator.h:147
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
Element * Pointer
Definition: predicated_tile_access_iterator.h:708
typename Layout::TensorCoord TensorCoord
Definition: predicated_tile_access_iterator.h:530
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:830
Shape_ Shape
Definition: predicated_tile_access_iterator.h:518
Array< uint32_t, kPredicateWordCount > Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_access_iterator.h:119
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: predicated_tile_access_iterator.h:347
cutlass::layout::PitchLinearShape
Template defining a shape used by pitch-linear operators.
Definition: pitch_linear.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
typename Layout::TensorCoord TensorCoord
Definition: predicated_tile_access_iterator.h:886
CUTLASS_HOST_DEVICE bool valid()
Returns whether access is valid or not.
Definition: predicated_tile_access_iterator.h:846
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator operator++(int)
Definition: predicated_tile_access_iterator.h:646
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator & operator++()
Increment and return an instance to self.
Definition: predicated_tile_access_iterator.h:393
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileAccessIterator with zero threadblock offset.
Definition: predicated_tile_access_iterator.h:321
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
cutlass::layout::RowMajor::Index
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:59
Shape_ Shape
Definition: predicated_tile_access_iterator.h:694
AccessType_ AccessType
Definition: predicated_tile_access_iterator.h:87
CUTLASS_HOST_DEVICE half_t & operator++(half_t &lhs)
Definition: half.h:694
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:1192
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: predicated_tile_access_iterator.h:662
cutlass::layout::PitchLinear::LongIndex
int64_t LongIndex
Long index type used for offsets.
Definition: pitch_linear.h:175
CUTLASS_DEVICE void add_tile_offset(TensorCoord const &tile_offset)
Advances an iterator along logical dimensions of matrix in units of whole tiles.
Definition: predicated_tile_access_iterator.h:353
Element * Pointer
Definition: predicated_tile_access_iterator.h:888
cutlass::TensorView< Element, Layout >
typename Layout::Index Index
Definition: predicated_tile_access_iterator.h:525
CUTLASS_HOST_DEVICE void add_tile_offset(TensorCoord const &tile_offset)
Definition: predicated_tile_access_iterator.h:617
Defines a Shape template for matrix tiles.
typename Layout::LongIndex LongIndex
Definition: predicated_tile_access_iterator.h:1061
Defines the size of an element in bits.
Definition: numeric_types.h:42
CUTLASS_HOST_DEVICE Params()
Definition: predicated_tile_access_iterator.h:143
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileAccessIterator with zero threadblock offset.
Definition: predicated_tile_access_iterator.h:953
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_access_iterator.h:718
Element * Pointer
Definition: predicated_tile_access_iterator.h:96
Shape_ Shape
Definition: predicated_tile_access_iterator.h:1052
typename Layout::Index Index
Definition: predicated_tile_access_iterator.h:701
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileAccessIterator with zero threadblock offset.
Definition: predicated_tile_access_iterator.h:771
cutlass::TensorRef< Element, Layout >
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: predicated_tile_access_iterator.h:666
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: predicated_tile_access_iterator.h:471
CUTLASS_HOST_DEVICE Params()
Definition: predicated_tile_access_iterator.h:1092
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: predicated_tile_access_iterator.h:576
CUTLASS_HOST_DEVICE Params()
Default ctor.
Definition: predicated_tile_access_iterator.h:556
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
AccessType_ AccessType
Definition: predicated_tile_access_iterator.h:699
Shape_ Shape
Definition: predicated_tile_access_iterator.h:82
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator & operator++()
Definition: predicated_tile_access_iterator.h:634
CUTLASS_HOST_DEVICE Params()
Default ctor.
Definition: predicated_tile_access_iterator.h:732
CUTLASS_HOST_DEVICE constexpr const T & min(const T &a, const T &b)
std::min
Definition: platform.h:183
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator operator++(int)
Increment and return an instance to self.
Definition: predicated_tile_access_iterator.h:434
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: predicated_tile_access_iterator.h:1024
#define static_assert(__e, __m)
Definition: platform.h:153
cutlass::layout::PitchLinear::Index
int32_t Index
Index type used for coordinates.
Definition: pitch_linear.h:172
typename platform::remove_const< Element >::type * NonConstPointer
Definition: predicated_tile_access_iterator.h:1068
CUTLASS_HOST_DEVICE Params()
Definition: predicated_tile_access_iterator.h:912
CUTLASS_HOST_DEVICE void add_tile_offset(TensorCoord const &tile_offset)
Definition: predicated_tile_access_iterator.h:975
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: predicated_tile_access_iterator.h:1112
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: predicated_tile_access_iterator.h:1096
friend PredicatedTileAccessIterator
Definition: predicated_tile_access_iterator.h:124
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: predicated_tile_access_iterator.h:786
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: predicated_tile_access_iterator.h:263
CUTLASS_HOST_DEVICE bool valid()
Returns whether access is valid or not.
Definition: predicated_tile_access_iterator.h:1028
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
CUTLASS_HOST_DEVICE void set_iteration_index(int index)
Overrides the internal iteration index.
Definition: predicated_tile_access_iterator.h:782
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:1012
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: predicated_tile_access_iterator.h:752
CUTLASS_HOST_DEVICE void enable_mask()
Enables the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:658
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:654
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: predicated_tile_access_iterator.h:1148
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator operator++(int)
Definition: predicated_tile_access_iterator.h:1004
typename Layout::TensorCoord TensorCoord
Definition: predicated_tile_access_iterator.h:94
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: predicated_tile_access_iterator.h:461
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator operator++(int)
Definition: predicated_tile_access_iterator.h:1184
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: predicated_tile_access_iterator.h:968
AccessType_ AccessType
Definition: predicated_tile_access_iterator.h:879
Defines layout functions used by TensorRef and derived classes.
typename Layout::Index Index
Definition: predicated_tile_access_iterator.h:89
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: predicated_tile_access_iterator.h:560
typename Layout::LongIndex LongIndex
Definition: predicated_tile_access_iterator.h:90
CUTLASS_HOST_DEVICE void enable_mask()
Enables the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:834
Defines layout functions used by TensorRef and derived classes for pitch-linear memory.
cutlass::layout::ColumnMajorInterleaved
Definition: layout/matrix.h:343
cutlass::layout::ColumnMajor::Index
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:151
typename platform::remove_const< Element >::type * NonConstPointer
Definition: predicated_tile_access_iterator.h:533
CUTLASS_HOST_DEVICE void set_iteration_index(int index)
Overrides the internal iteration index.
Definition: predicated_tile_access_iterator.h:335
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: predicated_tile_access_iterator.h:1020
CUTLASS_HOST_DEVICE void set_iteration_index(int index)
Overrides the internal iteration index.
Definition: predicated_tile_access_iterator.h:1144
CUTLASS_HOST_DEVICE void enable_mask()
Enables the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:452
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:442
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator & operator++()
Definition: predicated_tile_access_iterator.h:810
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_access_iterator.h:540
CUTLASS_HOST_DEVICE bool valid()
Returns whether access is valid or not.
Definition: predicated_tile_access_iterator.h:1208
Element * Pointer
Definition: predicated_tile_access_iterator.h:1067
Basic include for CUTLASS.
Definition: matrix_coord.h:39
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: predicated_tile_access_iterator.h:736
cutlass::transform::threadblock::PredicatedTileAccessIterator
Definition: predicated_tile_access_iterator.h:66
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: predicated_tile_access_iterator.h:838
CUTLASS_HOST_DEVICE void add_tile_offset(TensorCoord const &tile_offset)
Definition: predicated_tile_access_iterator.h:1155
cutlass::layout::RowMajorInterleaved
Definition: layout/matrix.h:237
Generated by Doxygen 1.8.11