docs/predicated__tile__iterator__2dthreadtile_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
predicated_tile_iterator_2dthreadtile.h
[Go to the documentation of this file.](predicated__tile__iterator__2dthreadtile_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
35 #pragma once
36
37 #include "[cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h](predicated__tile__access__iterator__2dthreadtile_8h.html)"
38 #include "cutlass/transform/thread/transpose.h"
39
41
42 namespace cutlass {
43 namespace transform {
44 namespace threadblock {
45
47
84 // template <typename Iterator>
85 // __global__ void kernel(
86 // typename Iterator::Params params,
87 // typename Iterator::Element *ptr,
88 // TensorCoord extent) {
89 //
90 // typename Iterator::Fragment fragment;
91 //
92 // TensorCoord threadblock_offset(0, 0);
93 //
94 // Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offset);
95 //
96 //
97 // iter.load(fragment); // load "residue" tile first
98 // ++iter; // advance to first "steady state" tile and update internal masks
99 //
100 //
101 // #pragma unroll
102 // for (int i = Remaining - 1; i >= 0; --i) {
103 //
104 // f(fragment);
105 //
106 // if (!i) {
107 // iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs.
108 // }
109 //
110 // iter.load(fragment); // load tile during "steady state" phase
111 // ++iter; // advance to next tile - lightweight due to steady-state masks
112 // }
113 // }
114 //
115 // void host(TensorView<Element, layout::PitchLinear> view) {
116 //
117 // using Iterator = transform::threadblock::PredicatedTileIterator2dThreadTile<Shape, Element, layout::PitchLinear, 1, ThreadMap>;
118 //
119 // typename Iterator::Params params(view.layout());
120 //
121 // kernel<Iterator>(params, view.data(), view.extent());
122 // }
/// PredicatedTileIterator2dThreadTile
///
/// Loads and stores tiles of a tensor, guarding every memory access with a
/// predicate. The primary template is only declared; the partial
/// specializations for the PitchLinear, ColumnMajor and RowMajor layouts
/// supply the implementations.
///
/// Template parameters:
///   Shape       - shape of the tile
///   Element     - element data type
///   Layout      - memory layout selecting the specialization
///   AdvanceRank - rank along which operator++ advances (0 or 1)
///   ThreadMap   - mapping of threads to accesses within the tile
///   Transpose   - whether loaded fragments are transposed (default: false)
template <typename Shape, typename Element, typename Layout, int AdvanceRank,
          typename ThreadMap, bool Transpose = false>
class PredicatedTileIterator2dThreadTile;
134
136
144 template <typename Shape_, typename Element_, int AdvanceRank, typename ThreadMap_, bool Transpose_>
145 class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::PitchLinear, AdvanceRank, ThreadMap_, Transpose_> {
146public:
147static_assert(
148 AdvanceRank == 0 || AdvanceRank == 1,
149"Specialization for pitch-linear iterator may along advance along the "
150"contiguous(rank=0) or strided(rank=1) dimension.");
151
153using Element = Element_;
154using Layout = layout::PitchLinear;
155static int const kAdvanceRank = AdvanceRank;
156using ThreadMap = ThreadMap_;
157
158using Index = typename Layout::Index;
159using LongIndex = typename Layout::LongIndex;
160
161using TensorRef = TensorRef<Element, Layout>;
162using TensorView = TensorView<Element, Layout>;
163using TensorCoord = typename Layout::TensorCoord;
164
166using NonConstPointer = typename platform::remove_const<Element>::type *;
167
170struct alignas((ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value /
172 8)) AccessType {
174 Array<Element, ThreadMap::kElementsPerAccess> storage;
176static int const kElements = ThreadMap::kElementsPerAccess;
177 };
178
180using Transform = thread::Transpose< ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount , layout::PitchLinearShape<4,4>, Element>;
181static bool const transpose = Transpose_;
182
184using TileAccessIterator =
185PredicatedTileAccessIterator2dThreadTile<Shape, Element, Layout, kAdvanceRank,
186 ThreadMap, AccessType>;
187
189using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
190 ThreadMap::ThreadAccessShape::kCount>;
191
193using Mask = typename TileAccessIterator::Mask;
194
196class Params {
197public:
198friend PredicatedTileIterator2dThreadTile;
199
200private:
202typename TileAccessIterator::Params params_;
203
204public:
207 Params(Layout const &layout) : params_(layout) { }
208
210 Params() { }
211 };
212
213private:
215using BytePointer = char *;
216
217private:
218//
219// Data members
220//
221
223TileAccessIterator address_iterator_;
224
225public:
229PredicatedTileIterator2dThreadTile(
231 Params const ¶ms,
233Pointer pointer,
235TensorCoord extent,
237int thread_id,
239TensorCoord const &threadblock_offset)
240 : address_iterator_(params.params_, pointer, extent, thread_id,
241 threadblock_offset) {}
242
245PredicatedTileIterator2dThreadTile(
246 Params const ¶ms,
247Pointer pointer,
248TensorCoord extent,
249int thread_id
250 )
251 : PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id,
252make_Coord(0, 0)) {}
253
256void add_pointer_offset(LongIndex pointer_offset) {
257 address_iterator_.add_pointer_offset(pointer_offset);
258 }
259
267PredicatedTileIterator2dThreadTile &operator++() {
268if (kAdvanceRank)
269 address_iterator_.add_tile_offset({0, 1});
270else
271 address_iterator_.add_tile_offset({1, 0});
272
273return *this;
274 }
275
283PredicatedTileIterator2dThreadTile operator++(int) {
284PredicatedTileIterator2dThreadTile self(*this);
285operator++();
286return self;
287 }
288
291void clear_mask() { address_iterator_.clear_mask(); }
292
295void enable_mask() { address_iterator_.enable_mask(); }
296
299void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }
300
303void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
304
306 CUTLASS_DEVICE
307void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
308
309 AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
310
312for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
314for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
316for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){
317
318int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + \
319 s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;
320
321 address_iterator_.set_iteration_index(access_idx);
322if (address_iterator_.valid()) {
323
324 frag_ptr[access_idx] =
325 *(address_iterator_.get() + pointer_offset);
326 }
327
328 ++address_iterator_;
329 }
330 }
331 }
332
333if (transpose) {
334Transform t;
335 t.transform(frag, frag);
336 }
337 }
338
340 CUTLASS_DEVICE
341void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
342
344 CUTLASS_DEVICE
345void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
346
347 AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
348
350for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
352for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
354for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){
355
356int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + \
357 s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;
358
359 address_iterator_.set_iteration_index(access_idx);
360if (address_iterator_.valid()) {
361 *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx];
362 }
363 ++address_iterator_;
364 }
365 }
366 }
367 }
368
370 CUTLASS_DEVICE
371void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
372 };
373
375
383 template <
384typename Shape_,
385typename Element_,
386int AdvanceRank,
387typename ThreadMap_,
388bool Transpose_
389 >
390 class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::ColumnMajor, AdvanceRank, ThreadMap_, Transpose_> {
391 public:
392
393static_assert(AdvanceRank == 0 || AdvanceRank == 1,
394"Specialization for pitch-linear iterator may along advance along the "
395"contiguous(rank=0) or strided(rank=1) dimension.");
398using Element = Element_;
399using Layout = layout::ColumnMajor;
400static int const kAdvanceRank = AdvanceRank;
401using ThreadMap = ThreadMap_;
402static bool const Transpose = Transpose_;
404using Index = typename Layout::Index;
405using LongIndex = typename Layout::LongIndex;
407using TensorRef = TensorRef<Element, Layout>;
408using TensorView = TensorView<Element, Layout>;
409using TensorCoord = typename Layout::TensorCoord;
412using NonConstPointer = typename platform::remove_const<Element>::type *;
413
414using UnderlyingIterator = PredicatedTileIterator2dThreadTile<
415layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
416 Element,
418 (kAdvanceRank == 0 ? 0 : 1),
419 ThreadMap,
420 Transpose
421 >;
423using AccessType = typename UnderlyingIterator::AccessType;
424
426using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount>;
427
429using Mask = typename UnderlyingIterator::Mask;
430
432class Params {
433private:
434
435friend PredicatedTileIterator2dThreadTile;
436
438typename UnderlyingIterator::Params params_;
439
440public:
441
443 Params() { }
444
447 Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {
448
449 }
450 };
451
452
453 private:
454
455//
456// Data members
457//
458
460UnderlyingIterator iterator_;
461
462 public:
463
466PredicatedTileIterator2dThreadTile(
467 Params const ¶ms,
468Pointer pointer,
469TensorCoord extent,
470int thread_id,
471TensorCoord const &threadblock_offset
472 ):
473 iterator_(
474 params.params_,
475 pointer,
476 layout::PitchLinearCoord(extent.row(), extent.column()),
477 thread_id,
478 layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())
479 ) { }
480
483PredicatedTileIterator2dThreadTile(
484 Params const ¶ms,
485Pointer pointer,
486TensorCoord extent,
487int thread_id
488 ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
489
492void add_pointer_offset(LongIndex pointer_offset) {
493 iterator_.add_pointer_offset(pointer_offset);
494 }
495
502PredicatedTileIterator2dThreadTile &operator++() {
503 ++iterator_;
504return *this;
505 }
506
513PredicatedTileIterator2dThreadTile operator++(int) {
514PredicatedTileIterator2dThreadTile self(*this);
515operator++();
516return self;
517 }
518
521void clear_mask() {
522 iterator_.clear_mask();
523 }
524
527void enable_mask() {
528 iterator_.enable_mask();
529 }
530
533void set_mask(Mask const &mask) {
534 iterator_.set_mask(mask);
535 }
536
539void get_mask(Mask &mask) {
540 iterator_.get_mask(mask);
541 }
542
544 CUTLASS_DEVICE
545void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
546 iterator_.load_with_pointer_offset(frag, pointer_offset);
547 }
548
550 CUTLASS_DEVICE
551void load(Fragment &frag) {
552 load_with_pointer_offset(frag, 0);
553 }
554
556 CUTLASS_DEVICE
557void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
558 iterator_.store_with_pointer_offset(frag, pointer_offset);
559 }
560
562 CUTLASS_DEVICE
563void store(Fragment const &frag) {
564 store_with_pointer_offset(frag, 0);
565 }
566 };
567
569
577 template <
578typename Shape_,
579typename Element_,
580int AdvanceRank,
581typename ThreadMap_,
582bool Transpose_
583 >
584 class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::RowMajor, AdvanceRank, ThreadMap_, Transpose_> {
585 public:
586
587static_assert(AdvanceRank == 0 || AdvanceRank == 1,
588"Specialization for pitch-linear iterator may along advance along the "
589"contiguous(rank=0) or strided(rank=1) dimension.");
592using Element = Element_;
593using Layout = layout::RowMajor;
594static int const kAdvanceRank = AdvanceRank;
595using ThreadMap = ThreadMap_;
596static bool const Transpose = Transpose_;
598using Index = typename Layout::Index;
599using LongIndex = typename Layout::LongIndex;
601using TensorRef = TensorRef<Element, Layout>;
602using TensorView = TensorView<Element, Layout>;
603using TensorCoord = typename Layout::TensorCoord;
606using NonConstPointer = typename platform::remove_const<Element>::type *;
607
608using UnderlyingIterator = PredicatedTileIterator2dThreadTile<
609layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
610 Element,
612 (kAdvanceRank == 0 ? 1 : 0),
613 ThreadMap,
614 Transpose
615 >;
617using AccessType = typename UnderlyingIterator::AccessType;
618
620using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount>;
621
623using Mask = typename UnderlyingIterator::Mask;
624
626class Params {
627private:
628
629friend PredicatedTileIterator2dThreadTile;
630
632typename UnderlyingIterator::Params params_;
633
634public:
635
637 Params() { }
638
641 Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {
642
643 };
644 };
645
646
647 private:
648
649//
650// Data members
651//
652
654UnderlyingIterator iterator_;
655
656 public:
657
660PredicatedTileIterator2dThreadTile(
661 Params const ¶ms,
662Pointer pointer,
663TensorCoord extent,
664int thread_id,
665TensorCoord const &threadblock_offset
666 ):
667 iterator_(
668 params.params_,
669 pointer,
670 layout::PitchLinearCoord(extent.column(), extent.row()),
671 thread_id,
672 layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())
673 ) { }
674
677PredicatedTileIterator2dThreadTile(
678 Params const ¶ms,
679Pointer pointer,
680TensorCoord extent,
681int thread_id
682 ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
683
686void add_pointer_offset(LongIndex pointer_offset) {
687 iterator_.add_pointer_offset(pointer_offset);
688 }
689
696PredicatedTileIterator2dThreadTile &operator++() {
697 ++iterator_;
698return *this;
699 }
700
707PredicatedTileIterator2dThreadTile operator++(int) {
708PredicatedTileIterator2dThreadTile self(*this);
709operator++();
710return self;
711 }
712
715void clear_mask() {
716 iterator_.clear_mask();
717 }
718
721void enable_mask() {
722 iterator_.enable_mask();
723 }
724
727void set_mask(Mask const &mask) {
728 iterator_.set_mask(mask);
729 }
730
733void get_mask(Mask &mask) {
734 iterator_.get_mask(mask);
735 }
736
738 CUTLASS_DEVICE
739void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
740 iterator_.load_with_pointer_offset(frag, pointer_offset);
741 }
742
744 CUTLASS_DEVICE
745void load(Fragment &frag) {
746 load_with_pointer_offset(frag, 0);
747 }
748
750 CUTLASS_DEVICE
751void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
752 iterator_.store_with_pointer_offset(frag, pointer_offset);
753 }
754
756 CUTLASS_DEVICE
757void store(Fragment const &frag) {
758 store_with_pointer_offset(frag, 0);
759 }
760 };
761
763
764 } // namespace threadblock
765 } // namespace transform
766 } // namespace cutlass
767
cutlass::layout::RowMajor::LongIndex
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:62
typename Layout::TensorCoord TensorCoord
Definition: predicated_tile_iterator_2dthreadtile.h:163
Definition: aligned_buffer.h:35
cutlass::layout::PitchLinearCoord
Coordinate in pitch-linear space.
Definition: pitch_linear.h:52
cutlass::platform::remove_const::type
T type
Definition: platform.h:351
Shape_ Shape
Definition: predicated_tile_iterator_2dthreadtile.h:590
Basic copy routines for tensor views.
typename Layout::LongIndex LongIndex
Definition: predicated_tile_iterator_2dthreadtile.h:598
Mapping function for pitch-linear memory.
Definition: pitch_linear.h:163
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:387
cutlass::layout::ColumnMajor::LongIndex
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:154
typename Layout::Index Index
Definition: predicated_tile_iterator_2dthreadtile.h:158
typename UnderlyingIterator::AccessType AccessType
Definition: predicated_tile_iterator_2dthreadtile.h:422
Element * Pointer
Definition: predicated_tile_iterator_2dthreadtile.h:165
typename Layout::TensorCoord TensorCoord
Definition: predicated_tile_iterator_2dthreadtile.h:602
cutlass::Array< Element, ThreadMap::Iterations::kCount *ThreadMap::ThreadAccessShape::kCount > Fragment
Fragment object to be loaded or stored.
Definition: predicated_tile_iterator_2dthreadtile.h:619
typename Layout::Index Index
Definition: predicated_tile_iterator_2dthreadtile.h:597
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
cutlass::layout::PitchLinearShape
Template defining a shape used by pitch-linear operators.
Definition: pitch_linear.h:43
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
cutlass::layout::RowMajor::Index
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:59
CUTLASS_HOST_DEVICE half_t & operator++(half_t &lhs)
Definition: half.h:694
cutlass::layout::PitchLinear::LongIndex
int64_t LongIndex
Long index type used for offsets.
Definition: pitch_linear.h:175
typename Layout::LongIndex LongIndex
Definition: predicated_tile_iterator_2dthreadtile.h:159
cutlass::TensorView< Element, Layout >
Element * Pointer
Definition: predicated_tile_iterator_2dthreadtile.h:410
typename TileAccessIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_iterator_2dthreadtile.h:192
typename Layout::LongIndex LongIndex
Definition: predicated_tile_iterator_2dthreadtile.h:404
cutlass::TensorRef< Element, Layout >
typename Layout::Index Index
Definition: predicated_tile_iterator_2dthreadtile.h:403
typename UnderlyingIterator::AccessType AccessType
Definition: predicated_tile_iterator_2dthreadtile.h:616
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
cutlass::transform::thread::Transpose
Transforms a fragment by doing a transpose.
Definition: transpose.h:39
#define static_assert(__e, __m)
Definition: platform.h:153
cutlass::layout::PitchLinear::Index
int32_t Index
Index type used for coordinates.
Definition: pitch_linear.h:172
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Shape_ Shape
Definition: predicated_tile_iterator_2dthreadtile.h:152
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_iterator_2dthreadtile.h:428
typename platform::remove_const< Element >::type * NonConstPointer
Definition: predicated_tile_iterator_2dthreadtile.h:605
cutlass::layout::ColumnMajor::Index
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:151
Element * Pointer
Definition: predicated_tile_iterator_2dthreadtile.h:604
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_iterator_2dthreadtile.h:622
typename platform::remove_const< Element >::type * NonConstPointer
Definition: predicated_tile_iterator_2dthreadtile.h:166
[predicated_tile_access_iterator_2dthreadtile.h](predicated__tile__access__iterator__2dthreadtile_8h.html)
Templates calculating the address and predicates to the load of tiles from pitch-linear rank=2 tensor...
cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile
Definition: predicated_tile_iterator_2dthreadtile.h:133
Shape_ Shape
Definition: predicated_tile_iterator_2dthreadtile.h:396
typename platform::remove_const< Element >::type * NonConstPointer
Definition: predicated_tile_iterator_2dthreadtile.h:411
Definition: matrix_coord.h:39
cutlass::Array< Element, ThreadMap::Iterations::kCount *ThreadMap::ThreadAccessShape::kCount > Fragment
Fragment object to be loaded or stored.
Definition: predicated_tile_iterator_2dthreadtile.h:189
typename Layout::TensorCoord TensorCoord
Definition: predicated_tile_iterator_2dthreadtile.h:408
cutlass::Array< Element, ThreadMap::Iterations::kCount *ThreadMap::ThreadAccessShape::kCount > Fragment
Fragment object to be loaded or stored.
Definition: predicated_tile_iterator_2dthreadtile.h:425
<!-- fragment --> <!-- contents --><!-- start footer part -->