docs/predicated__tile__access__iterator__2dthreadtile_8h_source.html
CUTLASS — CUDA Templates for Linear Algebra Subroutines and Solvers
predicated_tile_access_iterator_2dthreadtile.h
[Go to the documentation of this file.](predicated__tile__access__iterator__2dthreadtile_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 *modification, are permitted provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice,
7 *this list of conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright
9 *notice, this list of conditions and the following disclaimer in the
10 *documentation and/or other materials provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its
12 *contributors may be used to endorse or promote products derived from this
13 *software without specific prior written permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
19 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
20 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
21 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
22 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
23 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
24 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 *
26 **************************************************************************************************/
40 #pragma once
41
42 #include "cutlass/array.h"
43 #include "cutlass/coord.h"
44 #include "cutlass/cutlass.h"
45 #include "cutlass/layout/matrix.h"
46 #include "cutlass/layout/pitch_linear.h"
47 #include "cutlass/matrix_shape.h"
48 #include "cutlass/predicate_vector.h"
49 #include "cutlass/tensor_ref.h"
50 #include "cutlass/tensor_view.h"
51
53
55
56 namespace cutlass {
57 namespace transform {
58 namespace threadblock {
59
61
/// Primary template for PredicatedTileAccessIterator2dThreadTile.
///
/// Only declared here; the behaviour is supplied by the partial
/// specializations on Layout that follow in this file (layout::PitchLinear,
/// layout::ColumnMajor and layout::RowMajor).
64 template <typename Shape, typename Element, typename Layout, int AdvanceRank,
65typename ThreadMap, typename AccessType>
66 class PredicatedTileAccessIterator2dThreadTile;
67
69
72 template <typename Shape_, typename Element_, int AdvanceRank,
73typename ThreadMap_, typename AccessType_>
74 class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::PitchLinear,
75 AdvanceRank, ThreadMap_, AccessType_> {
76public:
/// Compile-time check: the advance dimension must be rank 0 (contiguous) or
/// rank 1 (strided).
77static_assert(
78 AdvanceRank == 0 || AdvanceRank == 1,
79"Specialization for pitch-linear iterator may along advance along the "
80"contiguous(rank=0) or strided(rank=1) dimension.");
81
/// Tile shape; its kContiguous / kStrided extents size the per-tile steps
/// computed in Params and the residue-tile arithmetic in the constructor.
82using Shape = Shape_;
83using Element = Element_;
84using Layout = layout::PitchLinear;
85static int const kAdvanceRank = AdvanceRank;
86using ThreadMap = ThreadMap_;
87using AccessType = AccessType_;
88
89using Index = typename Layout::Index;
90using LongIndex = typename Layout::LongIndex;
91
92using TensorRef = TensorRef<Element, Layout>;
93using TensorView = TensorView<Element, Layout>;
94using TensorCoord = typename Layout::TensorCoord;
95
/// Pointer type accepted by the constructors.
96using Pointer = Element *;
/// Same pointer type with const removed from the element type; the
/// constructor const_casts to it before reinterpreting as a byte pointer.
97using NonConstPointer = typename platform::remove_const<Element>::type *;
98
/// Predicates occupy bit positions byte_idx * 8 + bit_idx with bit_idx < 4
/// (see compute_predicates_ and valid()), i.e. four predicates per byte and
/// sixteen per 32-bit word.
99static int const kPredicatesPerByte = 4;
100static int const kPredicatesPerWord = 4 * kPredicatesPerByte;
101
/// One predicate per (iteration, thread-access row), rounded up to whole
/// bytes and then to whole 32-bit words.
103static int const kPredicateByteCount = (ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kStrided + kPredicatesPerByte - 1) / kPredicatesPerByte;
104static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4;
105
/// Mask covering the predicate bits of a single byte: (1 << 4) - 1 == 0xF.
106static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u;
107
/// Predicate storage is capped at four 32-bit words.
108static_assert(kPredicateWordCount <= 4, "Too many predicates.");
109
/// Predicate vector stores mask to guard accesses.
111using Mask = Array<uint32_t, kPredicateWordCount>;
112
/// Precomputed pointer increments for the pitch-linear iterator.
///
/// All increments are expressed in bytes: each one carries a factor of
/// int(sizeof(Element)), and the iterator adds them to its char-typed pointer.
/// NOTE(review): every field is a plain int, so stride * tile extent *
/// sizeof(Element) must fit in 32 bits — confirm this holds for the largest
/// tensors the callers pass.
114class Params {
115public:
// The iterator reads the private increments directly.
116friend PredicatedTileAccessIterator2dThreadTile;
117
118private:
/// Value of layout.stride(0), taken from the layout passed to the ctor.
120int stride_;
/// Byte step between consecutive strided iterations
/// (stride_ * ThreadMap::Delta::kStrided * sizeof(Element)).
123int inc_strided_;
/// Byte step applied after the last strided iteration of a tile:
/// inc_advance_ minus the (Iterations::kStrided - 1) strided steps already taken.
126int inc_next_;
/// Byte step of one whole tile along the advance dimension.
129int inc_advance_;
130
131public:
132
133// Default ctor
// Zero-initializes the stride and every increment.
135Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { }
136
/// Construct the Params object given a pitch-linear tensor's layout.
139Params(Layout const &layout) : stride_(layout.stride(0)) {
140
141 inc_strided_ =
142 (stride_ * ThreadMap::Delta::kStrided) * int(sizeof(Element));
143
144if (kAdvanceRank) {
145// advance along strided dimension
146 inc_advance_ = Shape::kStrided * stride_ * int(sizeof(Element));
147 } else {
148// advance along contiguous dimension
149 inc_advance_ = Shape::kContiguous * int(sizeof(Element));
150 }
151
// operator++ adds inc_strided_ (Iterations::kStrided - 1) times while walking
// a tile; subtracting that total here makes inc_next_ the step that lands on
// the first access of the next tile.
152 inc_next_ = inc_advance_ - (ThreadMap::Iterations::kStrided - 1) *
153 ThreadMap::Delta::kStrided * stride_ *
154int(sizeof(Element));
155 };
156 };
157
158private:
/// Internal pointer type; char * so that increments can be applied in bytes.
160using BytePointer = char *;
161
162private:
163//
164// Data members
165//
166
/// Precomputed increments. Held by reference and bound to the constructor
/// argument, so the caller's Params object must outlive this iterator.
168 Params const &params_;
169
/// Byte pointer to the current position; advanced by operator++ and the
/// add_*_offset methods.
171 BytePointer pointer_;
172
/// Packed guard bits, one per access; written by compute_predicates_ and the
/// mask setters, read by valid().
174 uint32_t predicates_[kPredicateWordCount];
175
/// Tensor extent the guards are compared against.
177TensorCoord extent_;
178
/// Logical coordinate of this thread's first access.
180TensorCoord thread_offset_;
181
/// Index, along the advance dimension, of the tile the iterator is initially
/// positioned on (computed in the constructor).
183int residue_tile_idx_;
184
/// True from construction until the first add_tile_offset() call.
186bool is_residue_tile_;
187
/// Current contiguous iteration index.
189int iteration_contiguous_;
190
/// Current strided iteration index.
192int iteration_strided_;
193
/// Current row within the thread's access shape.
195int iteration_thread_;
196
197private:
/// Rebuilds predicates_ from scratch for the current thread_offset_.
///
/// For every (strided iteration s, contiguous iteration c, thread row ts) the
/// access coordinate is
///   thread_offset_ + (c * Delta::kContiguous, ts + s * Delta::kStrided)
/// and one guard bit is stored for it.
///
/// @param is_steady_state  When true, only the dimension that is NOT being
///        advanced is bounds-checked (strided if kAdvanceRank == 0, contiguous
///        otherwise). When false, both dimensions are checked against extent_.
200void compute_predicates_(
202bool is_steady_state = false) {
203
// Start with every predicate cleared; bits are OR-ed in below.
205for (int i = 0; i < kPredicateWordCount; ++i) {
206 predicates_[i] = 0u;
207 }
208
210for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
212for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
214for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++) {
215
216TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous,
217 ts + s * ThreadMap::Delta::kStrided);
218
219TensorCoord coord = thread_offset_ + iteration_coord;
220
221bool guard;
222
223if (is_steady_state) {
224if (kAdvanceRank == 0) {
225 guard = (coord.strided() < extent_.strided());
226 } else {
227 guard = (coord.contiguous() < extent_.contiguous());
228 }
229 } else {
230 guard = (coord.strided() < extent_.strided() &&
231 coord.contiguous() < extent_.contiguous());
232 }
233
// Linear predicate index, ordered (s, c, ts) with ts fastest. valid() uses
// the same formula on the iteration_* members, so the two must stay in sync.
234int pred_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;
235int word_idx = pred_idx / kPredicatesPerWord;
236int residual = pred_idx % kPredicatesPerWord;
237int byte_idx = residual / kPredicatesPerByte;
238int bit_idx = residual % kPredicatesPerByte;
239
// Each byte of the word holds kPredicatesPerByte guards in its low bits.
240 predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
241
242 }
243 }
244 }
245
246 }
247
248public:
/// Constructs the iterator from its precomputed state, a threadblock offset
/// and a thread ID.
///
/// The iterator starts on the "residue" tile: the tile with index
/// (extent - threadblock_offset - 1) / tile_size along the advance dimension.
/// Its guards are computed with full bounds checking in both dimensions. The
/// first add_tile_offset() call removes the residue offset again.
///
/// @param params              Precomputed increments; stored by reference, so
///                            it must outlive the iterator.
/// @param pointer             Pointer to the start of the tensor.
/// @param extent              Extent of the tensor.
/// @param thread_id           ID of the participating thread, passed to
///                            ThreadMap::initial_offset().
/// @param threadblock_offset  Initial offset of the threadblock.
252PredicatedTileAccessIterator2dThreadTile(
254 Params const &params,
256Pointer pointer,
258TensorCoord extent,
260int thread_id,
262TensorCoord const &threadblock_offset)
263 : params_(params),
264 pointer_(reinterpret_cast<BytePointer>(
265 const_cast<NonConstPointer>(pointer))),
266 extent_(extent),
267 is_residue_tile_(true) {
268
269
// Locate the residue tile along the advance dimension and express it as a
// logical coordinate offset.
270TensorCoord residue_offset;
271if (kAdvanceRank) {
272 residue_tile_idx_ =
273 (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) /
274 Shape::kStrided;
275 residue_offset = make_Coord(0, residue_tile_idx_ * Shape::kStrided);
276 } else {
277 residue_tile_idx_ =
278 (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) /
279 Shape::kContiguous;
280 residue_offset = make_Coord(residue_tile_idx_ * Shape::kContiguous, 0);
281 }
282
283// Per-thread offset in logical coordinates of tensor
284 thread_offset_ = threadblock_offset + residue_offset +
285 ThreadMap::initial_offset(thread_id);
286
287// update internal pointers
288 Layout layout(params_.stride_);
289 add_pointer_offset(layout(thread_offset_));
290
// Residue tile: guard both dimensions.
291 compute_predicates_(false);
292
// Initializes iteration_contiguous_, iteration_strided_ and iteration_thread_.
293 set_iteration_index(0);
294 }
295
/// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock
/// offset.
///
/// Delegates to the five-argument constructor with make_Coord(0, 0) as the
/// threadblock offset.
298PredicatedTileAccessIterator2dThreadTile(
300 Params const &params,
302Pointer pointer,
304TensorCoord extent,
306int thread_id)
307 : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id,
308make_Coord(0, 0)) {}
309
/// Overrides the internal iteration index.
///
/// Splits a linear index into the three iteration counters, assuming the
/// ordering
///   index = (strided * Iterations::kContiguous + contiguous)
///             * ThreadAccessShape::kStrided + thread.
/// Only the counters are changed; pointer_ is not moved here.
312void set_iteration_index(int index) {
313
314int residual = index % (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided);
315 iteration_strided_ = index / (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided);
316
317 iteration_contiguous_ = residual / ThreadMap::ThreadAccessShape::kStrided;
318 iteration_thread_ = residual % ThreadMap::ThreadAccessShape::kStrided;
319
320 }
321
/// Adds a pointer offset in units of Element.
///
/// The offset is scaled by sizeof(Element) because pointer_ is a byte pointer.
/// A negative offset moves the pointer backwards (add_tile_offset relies on this).
324void add_pointer_offset(LongIndex pointer_offset) {
325 pointer_ += int(sizeof(Element)) * pointer_offset;
326 }
327
/// Advances an iterator along logical dimensions of matrix in units of whole
/// tiles.
///
/// First call (is_residue_tile_ still true):
///  - removes the residue offset that the constructor added, from both
///    thread_offset_ and pointer_;
///  - recomputes the guards in steady-state mode (only the non-advance
///    dimension is bounds-checked);
///  - then moves by (tile_offset - 1) tiles along the advance dimension.
/// Later calls move by tile_offset tiles and leave the guards untouched.
///
/// NOTE(review): the advance-dimension step uses params_.inc_advance_, which
/// is in bytes. The step for the other component (Shape::kContiguous *
/// tile_offset.contiguous(), or Shape::kStrided * tile_offset.strided()) is
/// added to the byte pointer without a sizeof(Element) factor, and the strided
/// variant also lacks the stride. This looks like a unit mismatch whenever the
/// non-advance component of tile_offset is non-zero — confirm against callers.
329 CUTLASS_DEVICE
330void add_tile_offset(
331TensorCoord const &tile_offset) {
332if (is_residue_tile_) {
// Same residue offset the constructor computed from residue_tile_idx_.
333TensorCoord residue_offset;
334if (kAdvanceRank) {
335 residue_offset = TensorCoord(0, residue_tile_idx_ * Shape::kStrided);
336 } else {
337 residue_offset = TensorCoord(residue_tile_idx_ * Shape::kContiguous, 0);
338 }
339
340 thread_offset_ -= residue_offset;
341
// Undo the pointer displacement of the residue tile (offset in elements).
342 Layout layout(params_.stride_);
343 add_pointer_offset(-layout(residue_offset));
344
345 compute_predicates_(true);
346
347if (kAdvanceRank) {
348 pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1);
349 pointer_ += Shape::kContiguous * tile_offset.contiguous();
350 } else {
351 pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1);
352 pointer_ += Shape::kStrided * tile_offset.strided();
353 }
354 } else {
355if (kAdvanceRank) {
356 pointer_ += params_.inc_advance_ * tile_offset.strided();
357 pointer_ += Shape::kContiguous * tile_offset.contiguous();
358 } else {
359 pointer_ += params_.inc_advance_ * tile_offset.contiguous();
360 pointer_ += Shape::kStrided * tile_offset.strided();
361 }
362 }
// From now on the iterator is in steady state.
363 is_residue_tile_ = false;
364 }
365
/// Returns a pointer to the access selected by the current iteration counters.
///
/// The address is pointer_ plus, in bytes,
///   (iteration_thread_ * stride + iteration_contiguous_ * Delta::kContiguous)
///     * sizeof(Element).
/// iteration_strided_ does not appear because operator++ already folds each
/// strided step into pointer_.
/// NOTE(review): set_iteration_index() can change iteration_strided_ without
/// moving pointer_, so a non-zero strided index set that way is not reflected
/// here — confirm callers only use it to reset to 0.
367AccessType *get() const {
368
369AccessType *ret_val = reinterpret_cast<AccessType *>(
370 pointer_ + (iteration_thread_ * params_.stride_ + iteration_contiguous_ * ThreadMap::Delta::kContiguous) * int(sizeof(Element)));
371
372return ret_val;
373 }
374
/// Increment and return an instance to self.
///
/// Advances the counters in the order thread row -> contiguous -> strided.
/// pointer_ only moves when the strided counter changes; the thread-row and
/// contiguous offsets are applied on the fly in get().
377PredicatedTileAccessIterator2dThreadTile &operator++() {
378
379 iteration_thread_++;
380
381if (iteration_thread_ < ThreadMap::ThreadAccessShape::kStrided)
382return *this;
383
384 iteration_thread_ = 0;
385
386 ++iteration_contiguous_;
387
388if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
389return *this;
390
391// Enter here only if (iteration_contiguous_ ==
392// ThreadMap::Iteration::kContiguous)
393 iteration_contiguous_ = 0;
394 ++iteration_strided_;
395
396if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
397 pointer_ += params_.inc_strided_;
398return *this;
399 }
400
401// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
402// which means we enter the next tile.
403 iteration_strided_ = 0;
404
405// advance to next tile
406 pointer_ += params_.inc_next_;
407
408// now return to start tile - if the iterator is subsequently advanced, this
409// subtraction as well as the subsequent integer addition are both elided by
410// the compiler.
// Net effect of the two updates: pointer_ moves back by the
// (Iterations::kStrided - 1) strided steps taken within the tile, i.e. to the
// first access of the same tile.
411 pointer_ -= params_.inc_advance_;
412
413return *this;
414 }
415
/// Increment and return an instance to self.
///
/// Post-increment: copies the iterator, advances *this via the pre-increment
/// operator, and returns the copy taken before the advance.
418PredicatedTileAccessIterator2dThreadTile operator++(int) {
419PredicatedTileAccessIterator2dThreadTile self(*this);
420operator++();
421return self;
422 }
423
/// Clears the predicate set efficiently.
///
/// Zeroes every predicate word, so valid() returns false for all accesses
/// until the mask is set or enabled again.
426void clear_mask() {
428for (int i = 0; i < kPredicateWordCount; ++i) {
429 predicates_[i] = 0u;
430 }
431
432 }
433
/// Enables every predicate in the set.
///
/// Writes 0xffffffff to each predicate word, so valid() returns true for all
/// accesses regardless of the tensor extent.
436void enable_mask() {
438for (int i = 0; i < kPredicateWordCount; ++i) {
439 predicates_[i] = 0xffffffff;
440 }
441 }
442
/// Sets the predicate mask, overriding value stored in predicate iterator.
///
/// @param mask  Replacement predicate words, copied element by element into
///              predicates_.
445void set_mask(Mask const &mask) {
447for (int i = 0; i < kPredicateWordCount; ++i) {
448 predicates_[i] = mask[i];
449 }
450
451 }
452
/// Gets the mask.
///
/// @param mask  Output; receives a copy of every predicate word.
455void get_mask(Mask &mask) {
457for (int i = 0; i < kPredicateWordCount; ++i) {
458 mask[i] = predicates_[i];
459 }
460 }
461
/// Returns whether access is valid or not.
///
/// Looks up the guard bit of the access selected by the current iteration
/// counters. The predicate index and bit position use the same formula as
/// compute_predicates_, with (iteration_strided_, iteration_contiguous_,
/// iteration_thread_) in place of its loop indices.
463CUTLASS_HOST_DEVICE
464bool valid() {
465
466int pred_idx =
467 iteration_thread_ +
468 iteration_contiguous_ * ThreadMap::ThreadAccessShape::kStrided +
469 iteration_strided_ * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;
470
471int word_idx = pred_idx / kPredicatesPerWord;
472int residual = pred_idx % kPredicatesPerWord;
473int byte_idx = residual / kPredicatesPerByte;
474int bit_idx = residual % kPredicatesPerByte;
475
476bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
477
478return pred;
479 }
480 };
481
483
/// Specialization of PredicatedTileAccessIterator2dThreadTile for column-major
/// data.
///
/// A thin adapter over the pitch-linear specialization: rows map to the
/// contiguous dimension and columns to the strided dimension, and every member
/// forwards to the underlying iterator after converting coordinates
/// (row, column) -> (contiguous, strided). AdvanceRank is passed through
/// unchanged.
491 template <typename Shape_, typename Element_, int AdvanceRank,
492typename ThreadMap_, typename AccessType_>
493 class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::ColumnMajor,
494 AdvanceRank, ThreadMap_, AccessType_> {
495public:
496static_assert(
497 AdvanceRank == 0 || AdvanceRank == 1,
498"Specialization for pitch-linear iterator may along advance along the "
499"contiguous(rank=0) or strided(rank=1) dimension.");
500
/// Matrix tile shape; kRow / kColumn are used to build the pitch-linear shape.
501using Shape = Shape_;
502using Element = Element_;
503using Layout = layout::ColumnMajor;
504static int const kAdvanceRank = AdvanceRank;
505using ThreadMap = ThreadMap_;
506using AccessType = AccessType_;
507
508using Index = typename Layout::Index;
509using LongIndex = typename Layout::LongIndex;
510
511using TensorRef = TensorRef<Element, Layout>;
512using TensorView = TensorView<Element, Layout>;
513using TensorCoord = typename Layout::TensorCoord;
514
/// Pointer type accepted by the constructors.
515using Pointer = Element *;
516using NonConstPointer = typename platform::remove_const<Element>::type *;
517
/// Pitch-linear iterator that does the actual work
/// (contiguous = rows, strided = columns).
518using UnderlyingIterator = PredicatedTileAccessIterator2dThreadTile<
519layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
520layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>;
521
/// Predicate vector stores mask to guard accesses.
523using Mask = typename UnderlyingIterator::Mask;
524
/// Parameters object; wraps the underlying pitch-linear Params.
526class Params {
527private:
528friend PredicatedTileAccessIterator2dThreadTile;
529
/// Parameters of the underlying pitch-linear iterator.
531typename UnderlyingIterator::Params params_;
532
533public:
534
/// Default ctor.
536CUTLASS_HOST_DEVICE
537Params() { }
538
/// Construct the Params object from a column-major layout; layout.stride(0)
/// becomes the pitch-linear stride.
541Params(Layout const &layout)
542 : params_(layout::PitchLinear(layout.stride(0))){};
543 };
544
545private:
546//
547// Data members
548//
549
/// Underlying pitch-linear tile iterator.
551UnderlyingIterator iterator_;
552
553public:
/// Constructs the iterator from its precomputed state, a threadblock offset
/// and a thread ID. extent and threadblock_offset are converted from
/// (row, column) to (contiguous, strided).
557PredicatedTileAccessIterator2dThreadTile(
559 Params const &params,
561Pointer pointer,
563TensorCoord extent,
565int thread_id,
567TensorCoord const &threadblock_offset)
568 : iterator_(params.params_, pointer,
569 layout::PitchLinearCoord(extent.row(), extent.column()),
570 thread_id,
571 layout::PitchLinearCoord(threadblock_offset.row(),
572 threadblock_offset.column())) {}
573
/// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock
/// offset.
576PredicatedTileAccessIterator2dThreadTile(
577 Params const &params,
578Pointer pointer,
579TensorCoord extent,
580int thread_id
581 )
582 : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id,
583make_Coord(0, 0)) {}
584
/// Overrides the internal iteration index.
587void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
588
/// Adds a pointer offset in units of Element.
591void add_pointer_offset(LongIndex pointer_offset) {
592 iterator_.add_pointer_offset(pointer_offset);
593 }
594
/// Advances the iterator in units of whole tiles; (row, column) is forwarded
/// as (contiguous, strided).
598void add_tile_offset(TensorCoord const &tile_offset) {
599 iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
600 }
601
/// Returns the pointer produced by the underlying iterator.
604AccessType *get() const {
605return reinterpret_cast<AccessType *>(iterator_.get());
606 }
607
/// Pre-increment: advances the underlying iterator and returns *this.
615PredicatedTileAccessIterator2dThreadTile &operator++() {
616 ++iterator_;
617return *this;
618 }
619
/// Post-increment: returns a copy taken before the advance.
627PredicatedTileAccessIterator2dThreadTile operator++(int) {
628PredicatedTileAccessIterator2dThreadTile self(*this);
629operator++();
630return self;
631 }
632
/// Clears the predicate set efficiently.
635void clear_mask() { iterator_.clear_mask(); }
636
/// Enables every predicate in the set.
639void enable_mask() { iterator_.enable_mask(); }
640
/// Sets the predicate mask, overriding value stored in predicate iterator.
643void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
644
/// Gets the mask.
647void get_mask(Mask &mask) { iterator_.get_mask(mask); }
648
/// Returns whether access is valid or not.
650CUTLASS_HOST_DEVICE
651bool valid() {
652return iterator_.valid();
653 }
654 };
655
657
/// Specialization of PredicatedTileAccessIterator2dThreadTile for row-major
/// data.
///
/// A thin adapter over the pitch-linear specialization: columns map to the
/// contiguous dimension and rows to the strided dimension, and every member
/// forwards to the underlying iterator after converting coordinates
/// (row, column) -> (column, row). The advance rank is flipped accordingly
/// (0 becomes 1, anything else becomes 0).
665 template <typename Shape_, typename Element_, int AdvanceRank,
666typename ThreadMap_, typename AccessType_>
667 class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::RowMajor,
668 AdvanceRank, ThreadMap_, AccessType_> {
669public:
670static_assert(
671 AdvanceRank == 0 || AdvanceRank == 1,
672"Specialization for pitch-linear iterator may along advance along the "
673"contiguous(rank=0) or strided(rank=1) dimension.");
674
/// Matrix tile shape; kColumn / kRow are used to build the pitch-linear shape.
675using Shape = Shape_;
676using Element = Element_;
677using Layout = layout::RowMajor;
678static int const kAdvanceRank = AdvanceRank;
679using ThreadMap = ThreadMap_;
680using AccessType = AccessType_;
681
682using Index = typename Layout::Index;
683using LongIndex = typename Layout::LongIndex;
684
685using TensorRef = TensorRef<Element, Layout>;
686using TensorView = TensorView<Element, Layout>;
687using TensorCoord = typename Layout::TensorCoord;
688
/// Pointer type accepted by the constructors.
689using Pointer = Element *;
690using NonConstPointer = typename platform::remove_const<Element>::type *;
691
/// Pitch-linear iterator that does the actual work
/// (contiguous = columns, strided = rows).
692using UnderlyingIterator = PredicatedTileAccessIterator2dThreadTile<
693layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
694layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>;
695
/// Predicate vector stores mask to guard accesses.
697using Mask = typename UnderlyingIterator::Mask;
698
/// Parameters object; wraps the underlying pitch-linear Params.
700class Params {
701private:
702friend PredicatedTileAccessIterator2dThreadTile;
703
/// Parameters of the underlying pitch-linear iterator.
705typename UnderlyingIterator::Params params_;
706
707public:
708
/// Default ctor.
710CUTLASS_HOST_DEVICE
711Params() { }
712
/// Construct the Params object from a row-major layout; layout.stride(0)
/// becomes the pitch-linear stride.
715Params(Layout const &layout)
716 : params_(layout::PitchLinear(layout.stride(0))){};
717 };
718
719private:
720//
721// Data members
722//
723
/// Underlying pitch-linear tile iterator.
725UnderlyingIterator iterator_;
726
727public:
/// Constructs the iterator from its precomputed state, a threadblock offset
/// and a thread ID. extent and threadblock_offset are converted from
/// (row, column) to (contiguous = column, strided = row).
731PredicatedTileAccessIterator2dThreadTile(
733 Params const &params,
735Pointer pointer,
737TensorCoord extent,
739int thread_id,
741TensorCoord const &threadblock_offset)
742 : iterator_(params.params_, pointer,
743 layout::PitchLinearCoord(extent.column(), extent.row()),
744 thread_id,
745 layout::PitchLinearCoord(threadblock_offset.column(),
746 threadblock_offset.row())) {}
747
/// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock
/// offset.
750PredicatedTileAccessIterator2dThreadTile(
751 Params const &params,
752Pointer pointer,
753TensorCoord extent,
754int thread_id
755 )
756 : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id,
757make_Coord(0, 0)) {}
758
/// Overrides the internal iteration index.
761void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
762
/// Adds a pointer offset in units of Element.
765void add_pointer_offset(LongIndex pointer_offset) {
766 iterator_.add_pointer_offset(pointer_offset);
767 }
768
/// Advances the iterator in units of whole tiles; (row, column) is forwarded
/// as (contiguous = column, strided = row).
772void add_tile_offset(TensorCoord const &tile_offset) {
773 iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
774 }
775
/// Returns the pointer produced by the underlying iterator.
778AccessType *get() const {
779return reinterpret_cast<AccessType *>(iterator_.get());
780 }
781
/// Pre-increment: advances the underlying iterator and returns *this.
789PredicatedTileAccessIterator2dThreadTile &operator++() {
790 ++iterator_;
791return *this;
792 }
793
/// Post-increment: returns a copy taken before the advance.
801PredicatedTileAccessIterator2dThreadTile operator++(int) {
802PredicatedTileAccessIterator2dThreadTile self(*this);
803operator++();
804return self;
805 }
806
/// Clears the predicate set efficiently.
809void clear_mask() { iterator_.clear_mask(); }
810
/// Enables every predicate in the set.
813void enable_mask() { iterator_.enable_mask(); }
814
/// Sets the predicate mask, overriding value stored in predicate iterator.
817void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
818
/// Gets the mask.
821void get_mask(Mask &mask) { iterator_.get_mask(mask); }
822
/// Returns whether access is valid or not.
824CUTLASS_HOST_DEVICE
825bool valid() {
826return iterator_.valid();
827 }
828 };
829
831
833
834 } // namespace threadblock
835 } // namespace transform
836 } // namespace cutlass
837
cutlass::layout::RowMajor::LongIndex
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:62
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator2dThreadTile operator++(int)
Increment and return an instance to self.
Definition: predicated_tile_access_iterator_2dthreadtile.h:418
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator2dThreadTile & operator++()
Increment and return an instance to self.
Definition: predicated_tile_access_iterator_2dthreadtile.h:377
typename Layout::TensorCoord TensorCoord
Definition: predicated_tile_access_iterator_2dthreadtile.h:94
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE void enable_mask()
Enables every predicate in the set (all mask bits are set).
Definition: predicated_tile_access_iterator_2dthreadtile.h:436
cutlass::layout::PitchLinearCoord
Coordinate in pitch-linear space.
Definition: pitch_linear.h:52
Defines a structure containing strides, bounds, and a pointer to tensor data.
Shape_ Shape
Definition: predicated_tile_access_iterator_2dthreadtile.h:82
cutlass::platform::remove_const::type
T type
Definition: platform.h:351
typename platform::remove_const< Element >::type * NonConstPointer
Definition: predicated_tile_access_iterator_2dthreadtile.h:690
typename Layout::Index Index
Definition: predicated_tile_access_iterator_2dthreadtile.h:508
Shape_ Shape
Definition: predicated_tile_access_iterator_2dthreadtile.h:675
Shape_ Shape
Definition: predicated_tile_access_iterator_2dthreadtile.h:501
AccessType_ AccessType
Definition: predicated_tile_access_iterator_2dthreadtile.h:506
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator2dThreadTile(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset.
Definition: predicated_tile_access_iterator_2dthreadtile.h:576
Mapping function for pitch-linear memory.
Definition: pitch_linear.h:163
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:387
Element * Pointer
Definition: predicated_tile_access_iterator_2dthreadtile.h:689
cutlass::layout::ColumnMajor::LongIndex
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:154
CUTLASS_HOST_DEVICE void add_tile_offset(TensorCoord const &tile_offset)
Definition: predicated_tile_access_iterator_2dthreadtile.h:772
typename Layout::Index Index
Definition: predicated_tile_access_iterator_2dthreadtile.h:682
CUTLASS_HOST_DEVICE Params()
Definition: predicated_tile_access_iterator_2dthreadtile.h:135
CUTLASS_HOST_DEVICE void enable_mask()
Enables every predicate in the set (all mask bits are set).
Definition: predicated_tile_access_iterator_2dthreadtile.h:639
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: predicated_tile_access_iterator_2dthreadtile.h:635
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: predicated_tile_access_iterator_2dthreadtile.h:324
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator2dThreadTile(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset.
Definition: predicated_tile_access_iterator_2dthreadtile.h:750
typename Layout::Index Index
Definition: predicated_tile_access_iterator_2dthreadtile.h:89
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator2dThreadTile operator++(int)
Definition: predicated_tile_access_iterator_2dthreadtile.h:801
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: predicated_tile_access_iterator_2dthreadtile.h:139
typename platform::remove_const< Element >::type * NonConstPointer
Definition: predicated_tile_access_iterator_2dthreadtile.h:97
CUTLASS_HOST_DEVICE void set_iteration_index(int index)
Overrides the internal iteration index.
Definition: predicated_tile_access_iterator_2dthreadtile.h:312
cutlass::transform::threadblock::PredicatedTileAccessIterator2dThreadTile
Definition: predicated_tile_access_iterator_2dthreadtile.h:66
CUTLASS_DEVICE void add_tile_offset(TensorCoord const &tile_offset)
Advances an iterator along logical dimensions of matrix in units of whole tiles.
Definition: predicated_tile_access_iterator_2dthreadtile.h:330
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_access_iterator_2dthreadtile.h:697
Defines a structure containing strides and a pointer to tensor data.
CUTLASS_HOST_DEVICE bool valid()
Returns whether access is valid or not.
Definition: predicated_tile_access_iterator_2dthreadtile.h:825
CUTLASS_HOST_DEVICE bool valid()
Returns whether access is valid or not.
Definition: predicated_tile_access_iterator_2dthreadtile.h:651
typename Layout::TensorCoord TensorCoord
Definition: predicated_tile_access_iterator_2dthreadtile.h:513
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: predicated_tile_access_iterator_2dthreadtile.h:455
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
CUTLASS_HOST_DEVICE void set_iteration_index(int index)
Overrides the internal iteration index.
Definition: predicated_tile_access_iterator_2dthreadtile.h:761
cutlass::layout::PitchLinearShape
Template defining a shape used by pitch-linear operators.
Definition: pitch_linear.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
cutlass::layout::RowMajor::Index
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:59
CUTLASS_HOST_DEVICE half_t & operator++(half_t &lhs)
Definition: half.h:694
cutlass::layout::PitchLinear::LongIndex
int64_t LongIndex
Long index type used for offsets.
Definition: pitch_linear.h:175
Element * Pointer
Definition: predicated_tile_access_iterator_2dthreadtile.h:515
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: predicated_tile_access_iterator_2dthreadtile.h:591
cutlass::TensorView< Element, Layout >
Defines a Shape template for matrix tiles.
CUTLASS_HOST_DEVICE void set_iteration_index(int index)
Overrides the internal iteration index.
Definition: predicated_tile_access_iterator_2dthreadtile.h:587
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: predicated_tile_access_iterator_2dthreadtile.h:765
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator2dThreadTile(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: predicated_tile_access_iterator_2dthreadtile.h:557
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: predicated_tile_access_iterator_2dthreadtile.h:445
Array< uint32_t, kPredicateWordCount > Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_access_iterator_2dthreadtile.h:111
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator2dThreadTile & operator++()
Definition: predicated_tile_access_iterator_2dthreadtile.h:615
cutlass::TensorRef< Element, Layout >
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator2dThreadTile(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: predicated_tile_access_iterator_2dthreadtile.h:731
typename Layout::TensorCoord TensorCoord
Definition: predicated_tile_access_iterator_2dthreadtile.h:687
AccessType_ AccessType
Definition: predicated_tile_access_iterator_2dthreadtile.h:680
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: predicated_tile_access_iterator_2dthreadtile.h:715
CUTLASS_HOST_DEVICE Params()
Default ctor.
Definition: predicated_tile_access_iterator_2dthreadtile.h:711
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_access_iterator_2dthreadtile.h:523
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: predicated_tile_access_iterator_2dthreadtile.h:426
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator2dThreadTile & operator++()
Definition: predicated_tile_access_iterator_2dthreadtile.h:789
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
typename Layout::LongIndex LongIndex
Definition: predicated_tile_access_iterator_2dthreadtile.h:509
AccessType_ AccessType
Definition: predicated_tile_access_iterator_2dthreadtile.h:87
#define static_assert(__e, __m)
Definition: platform.h:153
typename platform::remove_const< Element >::type * NonConstPointer
Definition: predicated_tile_access_iterator_2dthreadtile.h:516
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator2dThreadTile(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: predicated_tile_access_iterator_2dthreadtile.h:252
cutlass::layout::PitchLinear::Index
int32_t Index
Index type used for coordinates.
Definition: pitch_linear.h:172
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: predicated_tile_access_iterator_2dthreadtile.h:643
CUTLASS_HOST_DEVICE bool valid()
Returns whether access is valid or not.
Definition: predicated_tile_access_iterator_2dthreadtile.h:464
CUTLASS_HOST_DEVICE Params()
Default ctor.
Definition: predicated_tile_access_iterator_2dthreadtile.h:537
CUTLASS_HOST_DEVICE void add_tile_offset(TensorCoord const &tile_offset)
Definition: predicated_tile_access_iterator_2dthreadtile.h:598
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
typename Layout::LongIndex LongIndex
Definition: predicated_tile_access_iterator_2dthreadtile.h:90
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: predicated_tile_access_iterator_2dthreadtile.h:541
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: predicated_tile_access_iterator_2dthreadtile.h:821
Defines layout functions used by TensorRef and derived classes.
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: predicated_tile_access_iterator_2dthreadtile.h:647
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: predicated_tile_access_iterator_2dthreadtile.h:809
Defines layout functions used by TensorRef and derived classes for pitch-linear memory.
cutlass::layout::ColumnMajor::Index
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:151
CUTLASS_HOST_DEVICE void enable_mask()
Enables every predicate in the set (all mask bits are set).
Definition: predicated_tile_access_iterator_2dthreadtile.h:813
friend PredicatedTileAccessIterator2dThreadTile
Definition: predicated_tile_access_iterator_2dthreadtile.h:116
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator2dThreadTile operator++(int)
Definition: predicated_tile_access_iterator_2dthreadtile.h:627
Basic include for CUTLASS.
Definition: matrix_coord.h:39
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: predicated_tile_access_iterator_2dthreadtile.h:817
Element * Pointer
Definition: predicated_tile_access_iterator_2dthreadtile.h:96
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator2dThreadTile(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset.
Definition: predicated_tile_access_iterator_2dthreadtile.h:298
typename Layout::LongIndex LongIndex
Definition: predicated_tile_access_iterator_2dthreadtile.h:683
Generated by Doxygen 1.8.11