docs/predicate__vector_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
predicate_vector.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #if !defined(__CUDACC_RTC__)
32 #include <assert.h>
33 #endif
34 #include <stdint.h>
35
36 #include "cutlass/cutlass.h"
37
38 #include "cutlass/platform/platform.h"
39
40 namespace cutlass {
41
43
60
80
96
99 template <
101int kPredicates_,
103int kPredicatesPerByte_ = 4,
105int kPredicateStart_ = 0>
106 struct PredicateVector {
108static int const kPredicates = kPredicates_;
109
111static int const kPredicatesPerByte = kPredicatesPerByte_;
112
114static int const kPredicateStart = kPredicateStart_;
115
116// Make sure no one tries to put more than 8 bits in a byte :)
117static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte");
118// Make sure the "offsetted" bits fit in one byte.
119static_assert(kPredicateStart + kPredicatesPerByte <= 8,
120"The offsetted predicates must fit within an actual byte.");
121
123typedef uint32_t Storage;
124
126static int const kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte;
127
129static int const kWordCount = (kBytes + sizeof(Storage) - 1) / sizeof(Storage);
130
131private:
132//
133// Data members
134//
135
137 Storage storageData[kWordCount];
138
139//
140// Methods
141//
142
144CUTLASS_HOST_DEVICE void computeStorageOffset(int &word, int &bit, int idx) const {
145CUTLASS_ASSERT(idx < kPredicates);
146
147int byte = (idx / kPredicatesPerByte);
148int bit_offset = (idx % kPredicatesPerByte);
149
150 word = byte / sizeof(Storage);
151int byte_offset = (byte % sizeof(Storage));
152
153 bit = byte_offset * 8 + bit_offset + kPredicateStart;
154 }
155
157CUTLASS_HOST_DEVICE Storage &storage(int word) {
158CUTLASS_ASSERT(word < kWordCount);
159return storageData[word];
160 }
161
163CUTLASS_HOST_DEVICE Storage const &storage(int word) const {
164CUTLASS_ASSERT(word < kWordCount);
165return storageData[word];
166 }
167
168public:
169//
170// Iterator
171//
172
180PredicateVector &vec_;
181
183int bit_;
184
185public:
188Iterator(Iterator const &it) : vec_(it.vec_), bit_(it.bit_) {}
189
192Iterator(PredicateVector &vec, int _start = 0) : vec_(vec), bit_(_start) {}
193
196Iterator &operator++() {
197 ++bit_;
198return *this;
199 }
200
203Iterator &operator+=(int offset) {
204 bit_ += offset;
205return *this;
206 }
207
210Iterator &operator--() {
211 --bit_;
212return *this;
213 }
214
217Iterator &operator-=(int offset) {
218 bit_ -= offset;
219return *this;
220 }
221
224Iterator operator++(int) {
225Iterator ret(*this);
226 ret.bit_++;
227return ret;
228 }
229
232Iterator operator--(int) {
233Iterator ret(*this);
234 ret.bit_--;
235return ret;
236 }
237
240Iterator operator+(int offset) {
241Iterator ret(*this);
242 ret.bit_ += offset;
243return ret;
244 }
245
248Iterator operator-(int offset) {
249ConstIterator ret(*this);
250 ret.bit_ -= offset;
251return ret;
252 }
253
256bool operator==(Iterator const &it) const { return bit_ == it.bit_; }
257
260bool operator!=(Iterator const &it) const { return bit_ != it.bit_; }
261
264bool get() { return vec_.at(bit_); }
265
268bool at() const { return vec_.at(bit_); }
269
272bool operator*() const { return at(); }
273
276void set(bool value = true) { vec_.set(bit_, value); }
277 };
278
284class ConstIterator {
286PredicateVector const &vec_;
287
289int bit_;
290
291public:
294ConstIterator(ConstIterator const &it) : vec_(it.vec_), bit_(it.bit_) {}
295
298ConstIterator(PredicateVector const &vec, int _start = 0) : vec_(vec), bit_(_start) {}
299
302ConstIterator &operator++() {
303 ++bit_;
304return *this;
305 }
306
309ConstIterator &operator+=(int offset) {
310 bit_ += offset;
311return *this;
312 }
313
316ConstIterator &operator--() {
317 --bit_;
318return *this;
319 }
320
323ConstIterator &operator-=(int offset) {
324 bit_ -= offset;
325return *this;
326 }
327
330ConstIterator operator++(int) {
331ConstIterator ret(*this);
332 ret.bit_++;
333return ret;
334 }
335
338ConstIterator operator--(int) {
339ConstIterator ret(*this);
340 ret.bit_--;
341return ret;
342 }
343
346ConstIterator operator+(int offset) {
347ConstIterator ret(*this);
348 ret.bit_ += offset;
349return ret;
350 }
351
354ConstIterator operator-(int offset) {
355ConstIterator ret(*this);
356 ret.bit_ -= offset;
357return ret;
358 }
359
362bool operator==(ConstIterator const &it) const { return bit_ == it.bit_; }
363
366bool operator!=(ConstIterator const &it) const { return bit_ != it.bit_; }
367
370bool get() { return vec_.at(bit_); }
371
374bool at() const { return vec_.at(bit_); }
375
378bool operator*() const { return at(); }
379 };
380
382struct TrivialIterator {
385TrivialIterator() {}
386
389TrivialIterator(Iterator const &it) {}
390
393TrivialIterator(PredicateVector const &_vec) {}
394
397TrivialIterator &operator++() { return *this; }
398
401TrivialIterator operator++(int) { return *this; }
402
405bool operator*() const { return true; }
406 };
407
408public:
409//
410// Methods
411//
412
414CUTLASS_HOST_DEVICE PredicateVector(bool value = true) { fill(value); }
415
417CUTLASS_HOST_DEVICE void fill(bool value = true) {
418 Storage item = (value ? ~Storage(0) : Storage(0));
419
421for (int i = 0; i < kWordCount; ++i) {
422 storage(i) = item;
423 }
424 }
425
427CUTLASS_HOST_DEVICE void clear() {
429for (int i = 0; i < kWordCount; ++i) {
430 storage(i) = 0;
431 }
432 }
433
435CUTLASS_HOST_DEVICE void enable() {
437for (int i = 0; i < kWordCount; ++i) {
438 storage(i) = ~Storage(0);
439 }
440 }
441
443CUTLASS_HOST_DEVICE bool operator[](int idx) const { return at(idx); }
444
446CUTLASS_HOST_DEVICE bool at(int idx) const {
447int bit, word;
448 computeStorageOffset(word, bit, idx);
449
450return ((storage(word) >> bit) & 1);
451 }
452
454CUTLASS_HOST_DEVICE void set(int idx, bool value = true) {
455int bit, word;
456 computeStorageOffset(word, bit, idx);
457
458 Storage disable_mask = (~(Storage(1) << bit));
459 Storage enable_mask = (Storage(value) << bit);
460
461 storage(word) = ((storage(word) & disable_mask) | enable_mask);
462 }
463
465CUTLASS_HOST_DEVICE PredicateVector &operator&=(PredicateVector const &predicates) {
467for (int i = 0; i < kWordCount; ++i) {
468 storage(i) = (storage(i) & predicates.storage(i));
469 }
470return *this;
471 }
472
474CUTLASS_HOST_DEVICE PredicateVector &operator|=(PredicateVector const &predicates) {
476for (int i = 0; i < kWordCount; ++i) {
477 storage(i) = (storage(i) | predicates.storage(i));
478 }
479return *this;
480 }
481
483CUTLASS_HOST_DEVICE bool is_zero() const {
484 Storage mask(0);
485for (int byte = 0; byte < sizeof(Storage); ++byte) {
486 Storage byte_mask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart);
487 mask |= (byte_mask << (byte * 8));
488 }
489 uint32_t result = 0;
490for (int word = 0; word < kWordCount; ++word) {
491 result |= storage(word);
492 }
493return result == 0;
494 }
495
497 CUTLASS_DEVICE
498Iterator begin() { return Iterator(*this); }
499
501 CUTLASS_DEVICE
502Iterator end() { return Iterator(*this, kPredicates); }
503
505 CUTLASS_DEVICE
506ConstIterator const_begin() const { return ConstIterator(*this); }
507
509 CUTLASS_DEVICE
510ConstIterator const_end() const { return ConstIterator(*this, kPredicates); }
511 };
512
514
515 } // namespace cutlass
cutlass::PredicateVector::operator|=
CUTLASS_HOST_DEVICE PredicateVector & operator|=(PredicateVector const &predicates)
Computes the union of two identical predicate vectors.
Definition: predicate_vector.h:474
cutlass::PredicateVector::TrivialIterator::operator++
CUTLASS_HOST_DEVICE TrivialIterator & operator++()
Pre-increment.
Definition: predicate_vector.h:397
cutlass::PredicateVector::ConstIterator::ConstIterator
CUTLASS_HOST_DEVICE ConstIterator(PredicateVector const &vec, int _start=0)
Constructs an iterator from a PredicateVector.
Definition: predicate_vector.h:298
Definition: aligned_buffer.h:35
cutlass::PredicateVector::Storage
uint32_t Storage
Storage type of individual elements.
Definition: predicate_vector.h:117
cutlass::PredicateVector::TrivialIterator::TrivialIterator
CUTLASS_HOST_DEVICE TrivialIterator(PredicateVector const &_vec)
Constructs an iterator from a PredicateVector.
Definition: predicate_vector.h:393
cutlass::PredicateVector::is_zero
CUTLASS_HOST_DEVICE bool is_zero() const
Returns true if entire predicate array is zero.
Definition: predicate_vector.h:483
cutlass::PredicateVector::ConstIterator::operator--
CUTLASS_HOST_DEVICE ConstIterator & operator--()
Pre-decrement.
Definition: predicate_vector.h:316
cutlass::PredicateVector::kBytes
static int const kBytes
Number of bytes needed.
Definition: predicate_vector.h:126
cutlass::PredicateVector::const_end
CUTLASS_DEVICE ConstIterator const_end() const
Returns a ConstIterator.
Definition: predicate_vector.h:510
cutlass::PredicateVector::ConstIterator::at
CUTLASS_HOST_DEVICE bool at() const
Gets the bit at the pointed to location.
Definition: predicate_vector.h:374
cutlass::PredicateVector::ConstIterator::operator++
CUTLASS_HOST_DEVICE ConstIterator & operator++()
Pre-increment.
Definition: predicate_vector.h:302
cutlass::PredicateVector::enable
CUTLASS_HOST_DEVICE void enable()
Sets all predicates to true.
Definition: predicate_vector.h:435
cutlass::PredicateVector::ConstIterator::operator==
CUTLASS_HOST_DEVICE bool operator==(ConstIterator const &it) const
Returns true if iterators point to the same bit.
Definition: predicate_vector.h:362
cutlass::PredicateVector::operator[]
CUTLASS_HOST_DEVICE bool operator[](int idx) const
Accesses a bit within the predicate vector.
Definition: predicate_vector.h:443
cutlass::PredicateVector::Iterator::operator==
CUTLASS_HOST_DEVICE bool operator==(Iterator const &it) const
Returns true if iterators point to the same bit.
Definition: predicate_vector.h:256
cutlass::PredicateVector::Iterator::operator!=
CUTLASS_HOST_DEVICE bool operator!=(Iterator const &it) const
Returns false if iterators point to the same bit.
Definition: predicate_vector.h:260
CUTLASS_HOST_DEVICE bool at(int idx) const
Accesses a bit within the predicate vector.
Definition: predicate_vector.h:446
cutlass::PredicateVector::Iterator::Iterator
CUTLASS_HOST_DEVICE Iterator(PredicateVector &vec, int _start=0)
Constructs an iterator from a PredicateVector.
Definition: predicate_vector.h:192
cutlass::PredicateVector::ConstIterator::operator++
CUTLASS_HOST_DEVICE ConstIterator operator++(int)
Post-increment.
Definition: predicate_vector.h:330
cutlass::PredicateVector::Iterator::operator++
CUTLASS_HOST_DEVICE Iterator operator++(int)
Post-increment.
Definition: predicate_vector.h:224
cutlass::PredicateVector::TrivialIterator::TrivialIterator
CUTLASS_HOST_DEVICE TrivialIterator(Iterator const &it)
Constructs a TrivialIterator from an Iterator (the argument is unused).
Definition: predicate_vector.h:389
cutlass::PredicateVector::Iterator::operator*
CUTLASS_HOST_DEVICE bool operator*() const
Dereferences iterator.
Definition: predicate_vector.h:272
C++ features that may be otherwise unimplemented for CUDA device functions.
cutlass::PredicateVector::TrivialIterator
Iterator that always returns true.
Definition: predicate_vector.h:382
cutlass::PredicateVector::TrivialIterator::operator++
CUTLASS_HOST_DEVICE TrivialIterator operator++(int)
Post-increment.
Definition: predicate_vector.h:401
cutlass::PredicateVector::ConstIterator::operator+
CUTLASS_HOST_DEVICE ConstIterator operator+(int offset)
Iterator advances by some amount.
Definition: predicate_vector.h:346
cutlass::PredicateVector::ConstIterator::operator+=
CUTLASS_HOST_DEVICE ConstIterator & operator+=(int offset)
Increment.
Definition: predicate_vector.h:309
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
cutlass::PredicateVector::const_begin
CUTLASS_DEVICE ConstIterator const_begin() const
Returns a ConstIterator.
Definition: predicate_vector.h:506
cutlass::PredicateVector::Iterator::operator--
CUTLASS_HOST_DEVICE Iterator & operator--()
Pre-decrement.
Definition: predicate_vector.h:210
cutlass::PredicateVector::Iterator::operator-=
CUTLASS_HOST_DEVICE Iterator & operator-=(int offset)
Decrement.
Definition: predicate_vector.h:217
cutlass::PredicateVector::Iterator::Iterator
CUTLASS_HOST_DEVICE Iterator(Iterator const &it)
Copy constructor.
Definition: predicate_vector.h:188
cutlass::PredicateVector::ConstIterator::operator-
CUTLASS_HOST_DEVICE ConstIterator operator-(int offset)
Iterator recedes by some amount.
Definition: predicate_vector.h:354
cutlass::PredicateVector::Iterator::operator+
CUTLASS_HOST_DEVICE Iterator operator+(int offset)
Iterator advances by some amount.
Definition: predicate_vector.h:240
cutlass::PredicateVector::fill
CUTLASS_HOST_DEVICE void fill(bool value=true)
Fills all predicates with a given value.
Definition: predicate_vector.h:417
cutlass::PredicateVector::kPredicates
static int const kPredicates
Number of bits stored by the PredicateVector.
Definition: predicate_vector.h:108
CUTLASS_DEVICE Iterator end()
Returns an iterator.
Definition: predicate_vector.h:502
#define CUTLASS_ASSERT(x)
Definition: cutlass.h:92
cutlass::PredicateVector::Iterator::operator+=
CUTLASS_HOST_DEVICE Iterator & operator+=(int offset)
Increment.
Definition: predicate_vector.h:203
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
cutlass::PredicateVector::kPredicatesPerByte
static int const kPredicatesPerByte
Number of bits stored within each byte of the predicate bit vector.
Definition: predicate_vector.h:111
cutlass::PredicateVector::operator&=
CUTLASS_HOST_DEVICE PredicateVector & operator&=(PredicateVector const &predicates)
Computes the intersection of two identical predicate vectors.
Definition: predicate_vector.h:465
#define static_assert(__e, __m)
Definition: platform.h:153
Statically sized array of bits implementing a vector of predicates.
Definition: predicate_vector.h:106
cutlass::PredicateVector::kWordCount
static int const kWordCount
Number of storage elements needed.
Definition: predicate_vector.h:129
cutlass::PredicateVector::ConstIterator::operator!=
CUTLASS_HOST_DEVICE bool operator!=(ConstIterator const &it) const
Returns false if iterators point to the same bit.
Definition: predicate_vector.h:366
cutlass::PredicateVector::Iterator::at
CUTLASS_HOST_DEVICE bool at() const
Gets the bit at the pointed to location.
Definition: predicate_vector.h:268
cutlass::PredicateVector::Iterator::operator++
CUTLASS_HOST_DEVICE Iterator & operator++()
Pre-increment.
Definition: predicate_vector.h:196
cutlass::PredicateVector::ConstIterator
An iterator implementing Predicate Iterator Concept enabling sequential read-only access to predicates.
Definition: predicate_vector.h:284
CUTLASS_HOST_DEVICE void set(int idx, bool value=true)
Set a bit within the predicate vector.
Definition: predicate_vector.h:454
cutlass::PredicateVector::Iterator::operator-
CUTLASS_HOST_DEVICE Iterator operator-(int offset)
Iterator recedes by some amount.
Definition: predicate_vector.h:248
cutlass::PredicateVector::kPredicateStart
static int const kPredicateStart
First bit within each byte containing predicates.
Definition: predicate_vector.h:114
cutlass::PredicateVector::ConstIterator::ConstIterator
CUTLASS_HOST_DEVICE ConstIterator(ConstIterator const &it)
Copy constructor.
Definition: predicate_vector.h:294
cutlass::PredicateVector::ConstIterator::operator*
CUTLASS_HOST_DEVICE bool operator*() const
Dereferences iterator.
Definition: predicate_vector.h:378
cutlass::PredicateVector::ConstIterator::operator--
CUTLASS_HOST_DEVICE ConstIterator operator--(int)
Post-decrement.
Definition: predicate_vector.h:338
cutlass::PredicateVector::clear
CUTLASS_HOST_DEVICE void clear()
Clears all predicates.
Definition: predicate_vector.h:427
cutlass::PredicateVector::PredicateVector
CUTLASS_HOST_DEVICE PredicateVector(bool value=true)
Initialize the predicate vector.
Definition: predicate_vector.h:414
cutlass::PredicateVector::begin
CUTLASS_DEVICE Iterator begin()
Returns an iterator to the start of the bit vector.
Definition: predicate_vector.h:498
Basic include for CUTLASS.
cutlass::PredicateVector::TrivialIterator::operator*
CUTLASS_HOST_DEVICE bool operator*() const
Dereferences iterator.
Definition: predicate_vector.h:405
cutlass::PredicateVector::Iterator
An iterator implementing Predicate Iterator Concept enabling sequential read and write access to pred...
Definition: predicate_vector.h:178
cutlass::PredicateVector::Iterator::operator--
CUTLASS_HOST_DEVICE Iterator operator--(int)
Post-decrement.
Definition: predicate_vector.h:232
cutlass::PredicateVector::ConstIterator::operator-=
CUTLASS_HOST_DEVICE ConstIterator & operator-=(int offset)
Decrement.
Definition: predicate_vector.h:323
cutlass::PredicateVector::TrivialIterator::TrivialIterator
CUTLASS_HOST_DEVICE TrivialIterator()
Constructor.
Definition: predicate_vector.h:385
Generated by 1.8.11