docs/epilogue_2threadblock_2predicated__tile__iterator_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
epilogue/threadblock/predicated_tile_iterator.h
[Go to the documentation of this file.](epilogue_2threadblock_2predicated__tile__iterator_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
33 #pragma once
34
35 #include "cutlass/cutlass.h"
36 #include "cutlass/numeric_types.h"
37 #include "cutlass/array.h"
38 #include "cutlass/layout/matrix.h"
39 #include "cutlass/matrix_shape.h"
40 #include "cutlass/tensor_ref.h"
41
42 #include "[cutlass/transform/pitch_linear_thread_map.h](pitch__linear__thread__map_8h.html)"
43 #include "[cutlass/epilogue/threadblock/output_tile_thread_map.h](output__tile__thread__map_8h.html)"
44
45
47
48 namespace cutlass {
49
51
52 namespace epilogue {
53 namespace threadblock {
54
56
61 template <
62typename ThreadMap_,
63typename Element_
64 >
65 class PredicatedTileIterator {
66 public:
67using ThreadMap = ThreadMap_;
68using Shape = typename ThreadMap::Shape;
69
71
72using Layout = layout::RowMajor;
73using TensorRef = TensorRef<Element, Layout>;
74using ConstTensorRef = typename TensorRef::ConstTensorRef;
75
76using Index = typename Layout::Index;
77using LongIndex = typename Layout::LongIndex;
78using TensorCoord = MatrixCoord;
79
80static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
81static int const kThreads = ThreadMap::kThreads;
82static int const kIterations = ThreadMap::Count::kTile;
83
84static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0");
85static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0");
86static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0");
87static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0");
88
90using Fragment = Array<
91Element,
92 ThreadMap::Iterations::kColumn *
93 ThreadMap::Iterations::kRow *
94 ThreadMap::Iterations::kGroup *
95 ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
96
98using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
99
100//
101// Parameters struct
102//
103
105
106//
107// Data members
108//
109
111
115
120
121//
122// Methods
123//
124
126Status initialize(Index stride_) {
127
128 stride = stride_;
129
130 increment_row = stride * ThreadMap::Delta::kRow;
131
132 increment_group = stride * ThreadMap::Delta::kGroup
133 - stride * ThreadMap::Delta::kRow * (ThreadMap::Iterations::kRow - 1);
134
135 increment_cluster = stride * ThreadMap::Delta::kCluster
136 - stride * ThreadMap::Delta::kGroup * (ThreadMap::Iterations::kGroup - 1)
137 - stride * ThreadMap::Delta::kRow * (ThreadMap::Iterations::kRow - 1);
138
139 advance_row = stride * ThreadMap::Shape::kRow;
140
141 advance_group = stride * (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
142
143 advance_cluster =
144 stride *
145 ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;;
146
147 advance_tile =
148 stride *
149 ThreadMap::Shape::kGroup *
150 ThreadMap::Shape::kRow *
151 ThreadMap::Shape::kCluster *
152 ThreadMap::Shape::kTile;
153
154return Status::kSuccess;
155 }
156
159initialize(0);
160 }
161
163Params(Layout const &layout) {
164
165initialize(layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess);
166 }
167 };
168
171
172static int const kCount = ThreadMap::Iterations::kColumn;
173
175bool predicates[kCount];
176
177//
178// Mask
179//
182 enable();
183 }
184
/// Disables all accesses guarded by the mask: every one of the kCount
/// predicates is set to false, so the `row_guard && mask_.predicates[column]`
/// test in load()/store() fails for every column.
// NOTE(review): listing line 187 is elided here - presumably a loop-unroll pragma; confirm against the header.
186CUTLASS_HOST_DEVICE void clear() {
188for (int i = 0; i < kCount; ++i) {
189 predicates[i] = false;
190 }
191 }
192
/// Enables all accesses guarded by the mask: every one of the kCount
/// predicates is set to true.
// NOTE(review): this is CUTLASS_DEVICE while clear() is CUTLASS_HOST_DEVICE, and the
// Mask() constructor that calls enable() is listed as CUTLASS_HOST_DEVICE - confirm the
// intended execution space.
// NOTE(review): listing line 195 is elided here - presumably a loop-unroll pragma; confirm against the header.
194 CUTLASS_DEVICE void enable() {
196for (int i = 0; i < kCount; ++i) {
197 predicates[i] = true;
198 }
199 }
200 };
201
202 private:
203
204//
205// Data members
206//
207
209Params params_;
210
212 uint8_t *byte_pointer_;
213
215Mask mask_;
216
218Index extent_row_;
219
221Index thread_start_row_;
222
224int state_[3];
225
226 private:
227
228//
229// Methods
230//
231
232 public:
233
234//
235// Methods
236//
237
239 CUTLASS_DEVICE
241Params const & params,
242 Element *pointer,
243TensorCoord extent,
244int thread_idx,
245TensorCoord threadblock_offset = TensorCoord()
246 ):
247 params_(params) {
248
249TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
250
251 extent_row_ = extent.row();
252 thread_start_row_ = thread_offset.row();
253
254// Initialize predicates
256for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
257
258 mask_.predicates[c] = ((thread_offset.column()
259 + ThreadMap::Delta::kColumn * c) < extent.column());
260 }
261
262// Initialize pointer
263 byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
264 thread_offset.row() * params_.stride +
265 thread_offset.column() * sizeof(AccessType) / kElementsPerAccess;
266
267// Initialize internal state counter
268 state_[0] = state_[1] = state_[2] = 0;
269 }
270
273void add_pointer_offset(LongIndex pointer_offset) {
274 byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
275 }
276
/// Loads a fragment from memory.
///
/// Walks the cluster / group / row / column iteration space of ThreadMap, starting at
/// byte_pointer_, and copies one AccessType per (row, column) position into `frag`.
/// A position is read only when its row lies inside extent_row_ and its column predicate
/// in mask_ is set; fragment entries for skipped positions are left unmodified.
/// The iterator's own byte_pointer_ is not changed (a local copy is advanced).
// NOTE(review): listing lines 284, 287, 290 and 304 (directly above each `for`) are elided -
// presumably loop-unroll pragmas; confirm against the header.
278 CUTLASS_DEVICE
279void load(Fragment &frag) {
280
281 uint8_t *byte_pointer = byte_pointer_;
// View the fragment as an array of AccessType so each copy moves a whole access.
282AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
283
285for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
286
288for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
289
291for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
292
// Linearized (cluster, group, row) index; row varies fastest.
293int frag_row_idx =
294 (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
295
// Row distance of this iteration from the thread's starting row.
296int row_offset = row * ThreadMap::Delta::kRow
297 + group * ThreadMap::Delta::kGroup
298 + cluster * ThreadMap::Delta::kCluster;
299
// Rows at or beyond extent_row_ must not be touched.
300bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
301
302AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);
303
305for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
306
307bool guard = row_guard && mask_.predicates[column];
308
309if (guard) {
// Delta::kColumn is divided by kElementsPerAccess because memory_pointer
// is indexed in AccessType units.
310 frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] =
311 memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess];
312 }
313 }
314
// No increment after the last iteration of each loop; the enclosing loop's
// increment (see Params::initialize) accounts for the distance already walked.
315if (row + 1 < ThreadMap::Iterations::kRow) {
316 byte_pointer += params_.increment_row;
317 }
318 }
319
320if (group + 1 < ThreadMap::Iterations::kGroup) {
321 byte_pointer += params_.increment_group;
322 }
323 }
324
325if (cluster + 1 < ThreadMap::Iterations::kCluster) {
326 byte_pointer += params_.increment_cluster;
327 }
328 }
329 }
330
/// Stores a fragment to memory.
///
/// Mirror image of load(): walks the same cluster / group / row / column iteration space
/// from byte_pointer_ and writes one AccessType from `frag` per (row, column) position.
/// A position is written only when its row lies inside extent_row_ and its column
/// predicate in mask_ is set. The iterator's own byte_pointer_ is not changed.
// NOTE(review): listing lines 337, 340, 343 and 357 (directly above each `for`) are elided -
// presumably loop-unroll pragmas; confirm against the header.
332 CUTLASS_DEVICE
333void store(Fragment const &frag) {
334 uint8_t *byte_pointer = byte_pointer_;
// View the fragment as an array of AccessType so each copy moves a whole access.
335AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
336
338for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
339
341for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
342
344for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
345
// Linearized (cluster, group, row) index; row varies fastest.
346int frag_row_idx =
347 (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
348
// Row distance of this iteration from the thread's starting row.
349int row_offset = row * ThreadMap::Delta::kRow
350 + group * ThreadMap::Delta::kGroup
351 + cluster * ThreadMap::Delta::kCluster;
352
// Rows at or beyond extent_row_ must not be written.
353bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
354
355AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);
356
358for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
359
360bool guard = row_guard && mask_.predicates[column];
361
362if (guard) {
363
// Delta::kColumn is divided by kElementsPerAccess because memory_pointer
// is indexed in AccessType units.
364 memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] =
365 frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column];
366 }
367 }
368
// No increment after the last iteration of each loop; the enclosing loop's
// increment (see Params::initialize) accounts for the distance already walked.
369if (row + 1 < ThreadMap::Iterations::kRow) {
370 byte_pointer += params_.increment_row;
371 }
372 }
373
374if (group + 1 < ThreadMap::Iterations::kGroup) {
375 byte_pointer += params_.increment_group;
376 }
377 }
378
379if (cluster + 1 < ThreadMap::Iterations::kCluster) {
380 byte_pointer += params_.increment_cluster;
381 }
382 }
383 }
384
387PredicatedTileIterator &operator++() {
388
389 ++state_[0];
390 byte_pointer_ += params_.advance_row;
391 thread_start_row_ += ThreadMap::Shape::kRow;
392
393if (state_[0] == ThreadMap::Count::kRow) {
394
395 state_[0] = 0;
396 ++state_[1];
397 byte_pointer_ += params_.advance_group;
398
399 thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
400 ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
401
402if (state_[1] == ThreadMap::Count::kGroup) {
403
404 state_[1] = 0;
405 ++state_[2];
406 byte_pointer_ += params_.advance_cluster;
407
408 thread_start_row_ += ThreadMap::Count::kGroup *
409 ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
410
411if (state_[2] == ThreadMap::Count::kCluster) {
412 state_[2] = 0;
413 byte_pointer_ += params_.advance_tile;
414 }
415 }
416 }
417
418return *this;
419 }
420
422 CUTLASS_DEVICE void clear_mask() {
423 mask_.clear();
424 }
425
427 CUTLASS_DEVICE void enable_mask() {
428 mask_.enable();
429 }
430
432 CUTLASS_DEVICE void get_mask(Mask &mask) {
433return mask_;
434 }
435
437 CUTLASS_DEVICE void set_mask(Mask const &mask) {
438 mask_ = mask;
439 }
440 };
441
447 template <
448typename ThreadMap_,
449typename Element_,
450int InterleavedK
451 >
452 class InterleavedPredicatedTileIterator {
453 public:
454using ThreadMap = ThreadMap_;
455
456using Element = Element_;
457
458using Layout = layout::ColumnMajorInterleaved<InterleavedK>;
459using TensorRef = TensorRef<Element, Layout>;
460using ConstTensorRef = typename TensorRef::ConstTensorRef;
461
462using Index = typename Layout::Index;
463using LongIndex = typename Layout::LongIndex;
464using TensorCoord = layout::PitchLinearCoord;
465
466static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
467static int const kThreads = ThreadMap::kThreads;
468static int const kIterations = ThreadMap::Iterations::kCount;
469
471using Fragment = Array<Element, ThreadMap::kElementsPerAccess>;
472
474using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
475
476//
477// Parameters struct
478//
479
481
482//
483// Data members
484//
485
487
490
491//
492// Methods
493//
494
496Status initialize(Index stride_) {
497 stride = stride_;
498
499 advance_row =
500 ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value / 8;
501
502 advance_column =
503 stride_ - ThreadMap::Iterations::kContiguous * kElementsPerAccess *
504sizeof_bits<Element>::value * ThreadMap::kWarpSize / 8;
505
506return Status::kSuccess;
507 }
508
511initialize(0);
512 }
513
515Params(Layout const &layout) {
516
517initialize(layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess);
518 }
519 };
520
523static int const kCount = (ThreadMap::Iterations::kContiguous < 8)
524 ? 8
525 : ThreadMap::Iterations::kContiguous;
526
528bool predicates[kCount];
529
530//
531// Mask
532//
535 enable();
536 }
537
/// Disables all accesses guarded by the mask: every one of the kCount
/// predicates is set to false, so the `col_guard && mask_.predicates[...]`
/// test in load()/store() always fails.
// NOTE(review): listing line 540 is elided here - presumably a loop-unroll pragma; confirm against the header.
539CUTLASS_HOST_DEVICE void clear() {
541for (int i = 0; i < kCount; ++i) {
542 predicates[i] = false;
543 }
544 }
545
/// Enables all accesses guarded by the mask: every one of the kCount
/// predicates is set to true.
// NOTE(review): this is CUTLASS_DEVICE while clear() is CUTLASS_HOST_DEVICE, and the
// Mask() constructor that calls enable() is listed as CUTLASS_HOST_DEVICE - confirm the
// intended execution space.
// NOTE(review): listing line 548 is elided here - presumably a loop-unroll pragma; confirm against the header.
547 CUTLASS_DEVICE void enable() {
549for (int i = 0; i < kCount; ++i) {
550 predicates[i] = true;
551 }
552 }
553 };
554
555 private:
556
557//
558// Data members
559//
560
562Params params_;
563
565 uint8_t *byte_pointer_;
566
568Mask mask_;
569
571Index extent_col_;
572
575Index thread_start_col_;
576
578int iteration_contiguous_;
579
580int iteration_strided_;
581
582 private:
583
584//
585// Methods
586//
587
588 public:
589
590//
591// Methods
592//
593
/// Constructor.
///
/// \param params             precomputed Params (stride and pointer advances)
/// \param pointer            base pointer of the tensor being accessed
/// \param extent             extent of the tensor as a pitch-linear coordinate
/// \param thread_idx         thread index, passed to ThreadMap::initial_offset()
/// \param threadblock_offset starting coordinate of the threadblock's tile
595 CUTLASS_DEVICE
596InterleavedPredicatedTileIterator(
597Params const & params,
598 Element *pointer,
599TensorCoord extent,
600int thread_idx,
601TensorCoord threadblock_offset
602 ):
603 params_(params) {
// The threadblock offset is rescaled for the interleaved layout: the contiguous
// coordinate is multiplied by InterleavedK and the strided coordinate divided by it.
604TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) +
605TensorCoord(threadblock_offset.contiguous() * InterleavedK,
606 threadblock_offset.strided() / InterleavedK);
607
// The strided extent gets the same division so it can be compared against thread_start_col_.
608 extent_col_ = extent.strided() / InterleavedK;
609 thread_start_col_ = thread_offset.strided();
610
// Only the first Iterations::kContiguous predicates are computed here. Mask::kCount is
// at least 8, and any remaining entries keep the value set by Mask's constructor, which
// calls enable().
// NOTE(review): listing line 612 is elided here - presumably a loop-unroll pragma; confirm against the header.
611// Initialize predicates
613for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
614 mask_.predicates[c] =
615 ((thread_offset.contiguous() + ThreadMap::Delta::kContiguous * c) <
616 (extent.contiguous() * InterleavedK));
617 }
618
// Byte address of this thread's first access: strided coordinate times the stride
// from params_, plus the contiguous coordinate scaled by sizeof(AccessType) / kElementsPerAccess.
619// Initialize pointer
620 byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
621 thread_offset.strided() * params_.stride +
622 thread_offset.contiguous() * sizeof(AccessType) / kElementsPerAccess;
623
624// Initialize internal state counter
625 iteration_contiguous_ = iteration_strided_ = 0;
626 }
627
630void add_pointer_offset(LongIndex pointer_offset) {
631 byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
632 }
633
635 CUTLASS_DEVICE
636void load(Fragment &frag) {
637 uint8_t *byte_pointer = byte_pointer_;
638AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
639AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);
640
641int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided;
642
643bool col_guard = ((thread_start_col_ + col_offset) < extent_col_);
644
645bool guard = col_guard && mask_.predicates[iteration_contiguous_];
646
647if (guard) {
648 *frag_ptr = *memory_pointer;
649 }
650 }
651
653 CUTLASS_DEVICE
654void store(Fragment const &frag) {
655 uint8_t *byte_pointer = byte_pointer_;
656AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
657AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);
658
659int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided;
660
661bool col_guard = ((thread_start_col_ + col_offset) < extent_col_);
662
663bool guard = col_guard && mask_.predicates[iteration_contiguous_];
664
665if (guard) {
666 *memory_pointer = *frag_ptr;
667 }
668 }
669
672void set_iteration_index(int iteration) {
673 iteration_contiguous_ = iteration % ThreadMap::Iterations::kContiguous;
674 iteration_strided_ = iteration / ThreadMap::Iterations::kContiguous;
675 }
676
679InterleavedPredicatedTileIterator &operator++() {
680
681 ++iteration_contiguous_;
682 byte_pointer_ += params_.advance_row;
683
684if (iteration_contiguous_ == ThreadMap::Iterations::kContiguous) {
685
686 iteration_contiguous_ = 0;
687 ++iteration_strided_;
688 byte_pointer_ += params_.advance_column;
689
690if (iteration_strided_ == ThreadMap::Iterations::kStrided) {
691 iteration_strided_ = 0;
692 }
693 }
694
695return *this;
696 }
697
699 CUTLASS_DEVICE void clear_mask() {
700 mask_.clear();
701 }
702
704 CUTLASS_DEVICE void enable_mask() {
705 mask_.enable();
706 }
707
709 CUTLASS_DEVICE void get_mask(Mask &mask) {
710return mask_;
711 }
712
714 CUTLASS_DEVICE void set_mask(Mask const &mask) {
715 mask_ = mask;
716 }
717 };
718
720
721 } // namespace threadblock
722 } // namespace epilogue
723 } // namespace cutlass
724
cutlass::epilogue::threadblock::PredicatedTileIterator::Mask::predicates
bool predicates[kCount]
Predicate state.
Definition: epilogue/threadblock/predicated_tile_iterator.h:175
cutlass::layout::RowMajor::LongIndex
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:62
cutlass::epilogue::threadblock::PredicatedTileIterator::kElementsPerAccess
static int const kElementsPerAccess
Definition: epilogue/threadblock/predicated_tile_iterator.h:80
cutlass::epilogue::threadblock::PredicatedTileIterator::Mask::enable
CUTLASS_DEVICE void enable()
Definition: epilogue/threadblock/predicated_tile_iterator.h:194
CUTLASS_HOST_DEVICE Index const & column() const
Returns the column of the coordinate.
Definition: matrix_coord.h:85
cutlass::epilogue::threadblock::PredicatedTileIterator::Params::advance_row
Index advance_row
amount to add to move to the next 'row' position
Definition: epilogue/threadblock/predicated_tile_iterator.h:116
cutlass::epilogue::threadblock::PredicatedTileIterator::load
CUTLASS_DEVICE void load(Fragment &frag)
Loads a fragment from memory.
Definition: epilogue/threadblock/predicated_tile_iterator.h:279
cutlass::epilogue::threadblock::PredicatedTileIterator::Element
Element_ Element
Definition: epilogue/threadblock/predicated_tile_iterator.h:70
Definition: aligned_buffer.h:35
cutlass::layout::PitchLinearCoord
Coordinate in pitch-linear space.
Definition: pitch_linear.h:52
Defines a structure containing strides, bounds, and a pointer to tensor data.
cutlass::epilogue::threadblock::PredicatedTileIterator::AccessType
AlignedArray< Element, ThreadMap::kElementsPerAccess > AccessType
Memory access size.
Definition: epilogue/threadblock/predicated_tile_iterator.h:98
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Params::initialize
CUTLASS_HOST_DEVICE Status initialize(Index stride_)
Definition: epilogue/threadblock/predicated_tile_iterator.h:496
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Mask::clear
CUTLASS_HOST_DEVICE void clear()
Efficiently disables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:539
[pitch_linear_thread_map.h](pitch__linear__thread__map_8h.html)
Templates implementing how threads are mapped to a given tile.
cutlass::epilogue::threadblock::PredicatedTileIterator::get_mask
CUTLASS_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:432
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Fragment
Array< Element, ThreadMap::kElementsPerAccess > Fragment
Fragment object.
Definition: epilogue/threadblock/predicated_tile_iterator.h:471
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::ThreadMap
ThreadMap_ ThreadMap
Definition: epilogue/threadblock/predicated_tile_iterator.h:454
Aligned array type.
Definition: array.h:511
cutlass::epilogue::threadblock::PredicatedTileIterator::Mask
Mask object.
Definition: epilogue/threadblock/predicated_tile_iterator.h:170
CUTLASS_HOST_DEVICE Index const & row() const
Returns the row of the coordinate.
Definition: matrix_coord.h:77
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Mask::predicates
bool predicates[kCount]
Predicate state.
Definition: epilogue/threadblock/predicated_tile_iterator.h:528
cutlass::epilogue::threadblock::PredicatedTileIterator::ConstTensorRef
typename TensorRef::ConstTensorRef ConstTensorRef
Definition: epilogue/threadblock/predicated_tile_iterator.h:74
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Mask::Mask
CUTLASS_HOST_DEVICE Mask()
Constructs a mask with all accesses enabled.
Definition: epilogue/threadblock/predicated_tile_iterator.h:534
cutlass::layout::ColumnMajorInterleaved::stride
CUTLASS_HOST_DEVICE Stride stride() const
Returns the stride of the layout.
Definition: layout/matrix.h:418
cutlass::layout::RowMajor::stride
CUTLASS_HOST_DEVICE Stride stride() const
Returns the stride of the layout.
Definition: layout/matrix.h:112
cutlass::epilogue::threadblock::PredicatedTileIterator::Fragment
Array< Element, ThreadMap::Iterations::kColumn *ThreadMap::Iterations::kRow *ThreadMap::Iterations::kGroup *ThreadMap::Iterations::kCluster *ThreadMap::kElementsPerAccess > Fragment
Fragment object.
Definition: epilogue/threadblock/predicated_tile_iterator.h:95
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Params
Definition: epilogue/threadblock/predicated_tile_iterator.h:480
cutlass::epilogue::threadblock::PredicatedTileIterator::store
CUTLASS_DEVICE void store(Fragment const &frag)
Stores a fragment to memory.
Definition: epilogue/threadblock/predicated_tile_iterator.h:333
cutlass::epilogue::threadblock::PredicatedTileIterator::LongIndex
typename Layout::LongIndex LongIndex
Definition: epilogue/threadblock/predicated_tile_iterator.h:77
cutlass::TensorRef< Element, Layout >::ConstTensorRef
TensorRef< typename platform::remove_const< Element >::type const, Layout > ConstTensorRef
TensorRef to constant data.
Definition: tensor_ref.h:179
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::load
CUTLASS_DEVICE void load(Fragment &frag)
Loads a fragment from memory.
Definition: epilogue/threadblock/predicated_tile_iterator.h:636
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::InterleavedPredicatedTileIterator
CUTLASS_DEVICE InterleavedPredicatedTileIterator(Params const &params, Element *pointer, TensorCoord extent, int thread_idx, TensorCoord threadblock_offset)
Constructor.
Definition: epilogue/threadblock/predicated_tile_iterator.h:596
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Index
typename Layout::Index Index
Definition: epilogue/threadblock/predicated_tile_iterator.h:462
cutlass::layout::RowMajor::Index
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:59
cutlass::epilogue::threadblock::PredicatedTileIterator::Params::advance_cluster
Index advance_cluster
amount to add to move to the next 'cluster' position
Definition: epilogue/threadblock/predicated_tile_iterator.h:118
Defines a Shape template for matrix tiles.
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::store
CUTLASS_DEVICE void store(Fragment const &frag)
Stores a fragment to memory.
Definition: epilogue/threadblock/predicated_tile_iterator.h:654
Defines the size of an element in bits.
Definition: numeric_types.h:42
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Params::advance_row
Index advance_row
amount to add to move to the next 'row' position
Definition: epilogue/threadblock/predicated_tile_iterator.h:488
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::clear_mask
CUTLASS_DEVICE void clear_mask()
Efficiently disables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:699
cutlass::TensorRef< Element, Layout >
cutlass::epilogue::threadblock::PredicatedTileIterator::ThreadMap
ThreadMap_ ThreadMap
Definition: epilogue/threadblock/predicated_tile_iterator.h:67
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Mask
Mask object.
Definition: epilogue/threadblock/predicated_tile_iterator.h:522
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Params::Params
CUTLASS_HOST_DEVICE Params()
Definition: epilogue/threadblock/predicated_tile_iterator.h:510
cutlass::epilogue::threadblock::PredicatedTileIterator::PredicatedTileIterator
CUTLASS_DEVICE PredicatedTileIterator(Params const &params, Element *pointer, TensorCoord extent, int thread_idx, TensorCoord threadblock_offset=TensorCoord())
Constructor.
Definition: epilogue/threadblock/predicated_tile_iterator.h:240
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
cutlass::layout::PitchLinearCoord::contiguous
CUTLASS_HOST_DEVICE Index const & contiguous() const
Returns the contiguous dimension.
Definition: pitch_linear.h:89
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator
Definition: epilogue/threadblock/predicated_tile_iterator.h:452
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::ConstTensorRef
typename TensorRef::ConstTensorRef ConstTensorRef
Definition: epilogue/threadblock/predicated_tile_iterator.h:460
#define static_assert(__e, __m)
Definition: platform.h:153
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::set_iteration_index
CUTLASS_HOST_DEVICE void set_iteration_index(int iteration)
Overrides the internal iteration index.
Definition: epilogue/threadblock/predicated_tile_iterator.h:672
cutlass::epilogue::threadblock::PredicatedTileIterator::Params::stride
Index stride
stride in bytes between rows
Definition: epilogue/threadblock/predicated_tile_iterator.h:110
cutlass::epilogue::threadblock::PredicatedTileIterator::operator++
CUTLASS_HOST_DEVICE PredicatedTileIterator & operator++()
Advances to the next position to load or store.
Definition: epilogue/threadblock/predicated_tile_iterator.h:387
cutlass::epilogue::threadblock::PredicatedTileIterator::Params::Params
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Definition: epilogue/threadblock/predicated_tile_iterator.h:163
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Params::stride
Index stride
stride in bytes between columns
Definition: epilogue/threadblock/predicated_tile_iterator.h:486
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Params::advance_column
Index advance_column
amount to add to move to the next 'column' position
Definition: epilogue/threadblock/predicated_tile_iterator.h:489
cutlass::epilogue::threadblock::PredicatedTileIterator::Params
Definition: epilogue/threadblock/predicated_tile_iterator.h:104
cutlass::epilogue::threadblock::PredicatedTileIterator::kIterations
static int const kIterations
Definition: epilogue/threadblock/predicated_tile_iterator.h:82
cutlass::epilogue::threadblock::PredicatedTileIterator::Params::advance_tile
Index advance_tile
amount to add to move to the next 'tile'
Definition: epilogue/threadblock/predicated_tile_iterator.h:119
[output_tile_thread_map.h](output__tile__thread__map_8h.html)
Metaprogram for determining the mapping of output elements to threads for epilogue tiles...
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::operator++
CUTLASS_HOST_DEVICE InterleavedPredicatedTileIterator & operator++()
Advances to the next position to load or store.
Definition: epilogue/threadblock/predicated_tile_iterator.h:679
cutlass::epilogue::threadblock::PredicatedTileIterator::clear_mask
CUTLASS_DEVICE void clear_mask()
Efficiently disables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:422
cutlass::epilogue::threadblock::PredicatedTileIterator::Params::increment_group
Index increment_group
increment quantity (in bytes) to advance when moving to the next group
Definition: epilogue/threadblock/predicated_tile_iterator.h:113
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
cutlass::epilogue::threadblock::PredicatedTileIterator::Index
typename Layout::Index Index
Definition: epilogue/threadblock/predicated_tile_iterator.h:76
cutlass::epilogue::threadblock::PredicatedTileIterator
Definition: epilogue/threadblock/predicated_tile_iterator.h:65
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::LongIndex
typename Layout::LongIndex LongIndex
Definition: epilogue/threadblock/predicated_tile_iterator.h:463
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::set_mask
CUTLASS_DEVICE void set_mask(Mask const &mask)
Definition: epilogue/threadblock/predicated_tile_iterator.h:714
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Params::Params
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Definition: epilogue/threadblock/predicated_tile_iterator.h:515
cutlass::epilogue::threadblock::PredicatedTileIterator::set_mask
CUTLASS_DEVICE void set_mask(Mask const &mask)
Definition: epilogue/threadblock/predicated_tile_iterator.h:437
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Mask::enable
CUTLASS_DEVICE void enable()
Definition: epilogue/threadblock/predicated_tile_iterator.h:547
cutlass::epilogue::threadblock::PredicatedTileIterator::Params::advance_group
Index advance_group
amount to add to move to the next 'group' position
Definition: epilogue/threadblock/predicated_tile_iterator.h:117
Defines layout functions used by TensorRef and derived classes.
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::add_pointer_offset
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: epilogue/threadblock/predicated_tile_iterator.h:630
cutlass::epilogue::threadblock::PredicatedTileIterator::Params::initialize
CUTLASS_HOST_DEVICE Status initialize(Index stride_)
Definition: epilogue/threadblock/predicated_tile_iterator.h:126
Operation was successful.
cutlass::epilogue::threadblock::PredicatedTileIterator::add_pointer_offset
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: epilogue/threadblock/predicated_tile_iterator.h:273
cutlass::epilogue::threadblock::PredicatedTileIterator::Shape
typename ThreadMap::Shape Shape
Definition: epilogue/threadblock/predicated_tile_iterator.h:68
cutlass::layout::ColumnMajorInterleaved
Definition: layout/matrix.h:343
cutlass::epilogue::threadblock::PredicatedTileIterator::TensorCoord
MatrixCoord TensorCoord
Definition: epilogue/threadblock/predicated_tile_iterator.h:78
cutlass::epilogue::threadblock::PredicatedTileIterator::Params::Params
CUTLASS_HOST_DEVICE Params()
Definition: epilogue/threadblock/predicated_tile_iterator.h:158
cutlass::epilogue::threadblock::PredicatedTileIterator::enable_mask
CUTLASS_DEVICE void enable_mask()
Efficiently enables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:427
cutlass::epilogue::threadblock::PredicatedTileIterator::Params::increment_row
Index increment_row
increment quantity (in bytes) to advance when moving between rows
Definition: epilogue/threadblock/predicated_tile_iterator.h:112
cutlass::epilogue::threadblock::PredicatedTileIterator::Params::increment_cluster
Index increment_cluster
increment quantity (in bytes) to advance when moving to the next cluster
Definition: epilogue/threadblock/predicated_tile_iterator.h:114
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::get_mask
CUTLASS_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:709
Basic include for CUTLASS.
Definition: matrix_coord.h:39
cutlass::epilogue::threadblock::PredicatedTileIterator::Mask::clear
CUTLASS_HOST_DEVICE void clear()
Efficiently disables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:186
cutlass::layout::PitchLinearCoord::strided
CUTLASS_HOST_DEVICE Index const & strided() const
Returns the strided dimension.
Definition: pitch_linear.h:97
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::enable_mask
CUTLASS_DEVICE void enable_mask()
Efficiently enables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:704
cutlass::epilogue::threadblock::PredicatedTileIterator::Mask::Mask
CUTLASS_HOST_DEVICE Mask()
Constructs a mask with all accesses enabled.
Definition: epilogue/threadblock/predicated_tile_iterator.h:181
cutlass::epilogue::threadblock::PredicatedTileIterator::kThreads
static int const kThreads
Definition: epilogue/threadblock/predicated_tile_iterator.h:81
Generated by 1.8.11