Back to Cutlass

CUTLASS: predicate_vector.h Source File

docs/predicate__vector_8h_source.html

4.4.244.0 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

predicate_vector.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #if !defined(__CUDACC_RTC__)

32 #include <assert.h>

33 #endif

34 #include <stdint.h>

35

36 #include "cutlass/cutlass.h"

37

38 #include "cutlass/platform/platform.h"

39

40 namespace cutlass {

41

43

60

80

96

99 template <

101int kPredicates_,

103int kPredicatesPerByte_ = 4,

105int kPredicateStart_ = 0>

106 struct PredicateVector {

108static int const kPredicates = kPredicates_;

109

111static int const kPredicatesPerByte = kPredicatesPerByte_;

112

114static int const kPredicateStart = kPredicateStart_;

115

116// Make sure no one tries to put more than 8 bits in a byte :)

117static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte");

118// Make sure the "offsetted" bits fit in one byte.

119static_assert(kPredicateStart + kPredicatesPerByte <= 8,

120"The offsetted predicates must fit within an actual byte.");

121

123typedef uint32_t Storage;

124

126static int const kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte;

127

129static int const kWordCount = (kBytes + sizeof(Storage) - 1) / sizeof(Storage);

130

131private:

132//

133// Data members

134//

135

137 Storage storageData[kWordCount];

138

139//

140// Methods

141//

142

144CUTLASS_HOST_DEVICE void computeStorageOffset(int &word, int &bit, int idx) const {

145CUTLASS_ASSERT(idx < kPredicates);

146

147int byte = (idx / kPredicatesPerByte);

148int bit_offset = (idx % kPredicatesPerByte);

149

150 word = byte / sizeof(Storage);

151int byte_offset = (byte % sizeof(Storage));

152

153 bit = byte_offset * 8 + bit_offset + kPredicateStart;

154 }

155

157CUTLASS_HOST_DEVICE Storage &storage(int word) {

158CUTLASS_ASSERT(word < kWordCount);

159return storageData[word];

160 }

161

163CUTLASS_HOST_DEVICE Storage const &storage(int word) const {

164CUTLASS_ASSERT(word < kWordCount);

165return storageData[word];

166 }

167

168public:

169//

170// Iterator

171//

172

178class Iterator {

180PredicateVector &vec_;

181

183int bit_;

184

185public:

187CUTLASS_HOST_DEVICE

188Iterator(Iterator const &it) : vec_(it.vec_), bit_(it.bit_) {}

189

191CUTLASS_HOST_DEVICE

192Iterator(PredicateVector &vec, int _start = 0) : vec_(vec), bit_(_start) {}

193

195CUTLASS_HOST_DEVICE

196Iterator &operator++() {

197 ++bit_;

198return *this;

199 }

200

202CUTLASS_HOST_DEVICE

203Iterator &operator+=(int offset) {

204 bit_ += offset;

205return *this;

206 }

207

209CUTLASS_HOST_DEVICE

210Iterator &operator--() {

211 --bit_;

212return *this;

213 }

214

216CUTLASS_HOST_DEVICE

217Iterator &operator-=(int offset) {

218 bit_ -= offset;

219return *this;

220 }

221

223CUTLASS_HOST_DEVICE

224Iterator operator++(int) {

225Iterator ret(*this);

226 ret.bit_++;

227return ret;

228 }

229

231CUTLASS_HOST_DEVICE

232Iterator operator--(int) {

233Iterator ret(*this);

234 ret.bit_--;

235return ret;

236 }

237

239CUTLASS_HOST_DEVICE

240Iterator operator+(int offset) {

241Iterator ret(*this);

242 ret.bit_ += offset;

243return ret;

244 }

245

247CUTLASS_HOST_DEVICE

248Iterator operator-(int offset) {

249ConstIterator ret(*this);

250 ret.bit_ -= offset;

251return ret;

252 }

253

255CUTLASS_HOST_DEVICE

256bool operator==(Iterator const &it) const { return bit_ == it.bit_; }

257

259CUTLASS_HOST_DEVICE

260bool operator!=(Iterator const &it) const { return bit_ != it.bit_; }

261

263CUTLASS_HOST_DEVICE

264bool get() { return vec_.at(bit_); }

265

267CUTLASS_HOST_DEVICE

268bool at() const { return vec_.at(bit_); }

269

271CUTLASS_HOST_DEVICE

272bool operator*() const { return at(); }

273

275CUTLASS_HOST_DEVICE

276void set(bool value = true) { vec_.set(bit_, value); }

277 };

278

284class ConstIterator {

286PredicateVector const &vec_;

287

289int bit_;

290

291public:

293CUTLASS_HOST_DEVICE

294ConstIterator(ConstIterator const &it) : vec_(it.vec_), bit_(it.bit_) {}

295

297CUTLASS_HOST_DEVICE

298ConstIterator(PredicateVector const &vec, int _start = 0) : vec_(vec), bit_(_start) {}

299

301CUTLASS_HOST_DEVICE

302ConstIterator &operator++() {

303 ++bit_;

304return *this;

305 }

306

308CUTLASS_HOST_DEVICE

309ConstIterator &operator+=(int offset) {

310 bit_ += offset;

311return *this;

312 }

313

315CUTLASS_HOST_DEVICE

316ConstIterator &operator--() {

317 --bit_;

318return *this;

319 }

320

322CUTLASS_HOST_DEVICE

323ConstIterator &operator-=(int offset) {

324 bit_ -= offset;

325return *this;

326 }

327

329CUTLASS_HOST_DEVICE

330ConstIterator operator++(int) {

331ConstIterator ret(*this);

332 ret.bit_++;

333return ret;

334 }

335

337CUTLASS_HOST_DEVICE

338ConstIterator operator--(int) {

339ConstIterator ret(*this);

340 ret.bit_--;

341return ret;

342 }

343

345CUTLASS_HOST_DEVICE

346ConstIterator operator+(int offset) {

347ConstIterator ret(*this);

348 ret.bit_ += offset;

349return ret;

350 }

351

353CUTLASS_HOST_DEVICE

354ConstIterator operator-(int offset) {

355ConstIterator ret(*this);

356 ret.bit_ -= offset;

357return ret;

358 }

359

361CUTLASS_HOST_DEVICE

362bool operator==(ConstIterator const &it) const { return bit_ == it.bit_; }

363

365CUTLASS_HOST_DEVICE

366bool operator!=(ConstIterator const &it) const { return bit_ != it.bit_; }

367

369CUTLASS_HOST_DEVICE

370bool get() { return vec_.at(bit_); }

371

373CUTLASS_HOST_DEVICE

374bool at() const { return vec_.at(bit_); }

375

377CUTLASS_HOST_DEVICE

378bool operator*() const { return at(); }

379 };

380

382struct TrivialIterator {

384CUTLASS_HOST_DEVICE

385TrivialIterator() {}

386

388CUTLASS_HOST_DEVICE

389TrivialIterator(Iterator const &it) {}

390

392CUTLASS_HOST_DEVICE

393TrivialIterator(PredicateVector const &_vec) {}

394

396CUTLASS_HOST_DEVICE

397TrivialIterator &operator++() { return *this; }

398

400CUTLASS_HOST_DEVICE

401TrivialIterator operator++(int) { return *this; }

402

404CUTLASS_HOST_DEVICE

405bool operator*() const { return true; }

406 };

407

408public:

409//

410// Methods

411//

412

414CUTLASS_HOST_DEVICE PredicateVector(bool value = true) { fill(value); }

415

417CUTLASS_HOST_DEVICE void fill(bool value = true) {

418 Storage item = (value ? ~Storage(0) : Storage(0));

419

420CUTLASS_PRAGMA_UNROLL

421for (int i = 0; i < kWordCount; ++i) {

422 storage(i) = item;

423 }

424 }

425

427CUTLASS_HOST_DEVICE void clear() {

428CUTLASS_PRAGMA_UNROLL

429for (int i = 0; i < kWordCount; ++i) {

430 storage(i) = 0;

431 }

432 }

433

435CUTLASS_HOST_DEVICE void enable() {

436CUTLASS_PRAGMA_UNROLL

437for (int i = 0; i < kWordCount; ++i) {

438 storage(i) = ~Storage(0);

439 }

440 }

441

443CUTLASS_HOST_DEVICE bool operator[](int idx) const { return at(idx); }

444

446CUTLASS_HOST_DEVICE bool at(int idx) const {

447int bit, word;

448 computeStorageOffset(word, bit, idx);

449

450return ((storage(word) >> bit) & 1);

451 }

452

454CUTLASS_HOST_DEVICE void set(int idx, bool value = true) {

455int bit, word;

456 computeStorageOffset(word, bit, idx);

457

458 Storage disable_mask = (~(Storage(1) << bit));

459 Storage enable_mask = (Storage(value) << bit);

460

461 storage(word) = ((storage(word) & disable_mask) | enable_mask);

462 }

463

465CUTLASS_HOST_DEVICE PredicateVector &operator&=(PredicateVector const &predicates) {

466CUTLASS_PRAGMA_UNROLL

467for (int i = 0; i < kWordCount; ++i) {

468 storage(i) = (storage(i) & predicates.storage(i));

469 }

470return *this;

471 }

472

474CUTLASS_HOST_DEVICE PredicateVector &operator|=(PredicateVector const &predicates) {

475CUTLASS_PRAGMA_UNROLL

476for (int i = 0; i < kWordCount; ++i) {

477 storage(i) = (storage(i) | predicates.storage(i));

478 }

479return *this;

480 }

481

483CUTLASS_HOST_DEVICE bool is_zero() const {

484 Storage mask(0);

485for (int byte = 0; byte < sizeof(Storage); ++byte) {

486 Storage byte_mask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart);

487 mask |= (byte_mask << (byte * 8));

488 }

489 uint32_t result = 0;

490for (int word = 0; word < kWordCount; ++word) {

491 result |= storage(word);

492 }

493return result == 0;

494 }

495

497 CUTLASS_DEVICE

498Iterator begin() { return Iterator(*this); }

499

501 CUTLASS_DEVICE

502Iterator end() { return Iterator(*this, kPredicates); }

503

505 CUTLASS_DEVICE

506ConstIterator const_begin() const { return ConstIterator(*this); }

507

509 CUTLASS_DEVICE

510ConstIterator const_end() const { return ConstIterator(*this, kPredicates); }

511 };

512

514

515 } // namespace cutlass

cutlass::PredicateVector::operator|=

CUTLASS_HOST_DEVICE PredicateVector & operator|=(PredicateVector const &predicates)

Computes the union of two identical predicate vectors.

Definition: predicate_vector.h:474

cutlass::PredicateVector::TrivialIterator::operator++

CUTLASS_HOST_DEVICE TrivialIterator & operator++()

Pre-increment.

Definition: predicate_vector.h:397

cutlass::PredicateVector::ConstIterator::ConstIterator

CUTLASS_HOST_DEVICE ConstIterator(PredicateVector const &vec, int _start=0)

Constructs an iterator from a PredicateVector.

Definition: predicate_vector.h:298

cutlass

Definition: aligned_buffer.h:35

cutlass::PredicateVector::Storage

uint32_t Storage

Storage type of individual elements.

Definition: predicate_vector.h:117

cutlass::PredicateVector::TrivialIterator::TrivialIterator

CUTLASS_HOST_DEVICE TrivialIterator(PredicateVector const &_vec)

Constructs an iterator from a PredicateVector.

Definition: predicate_vector.h:393

cutlass::PredicateVector::is_zero

CUTLASS_HOST_DEVICE bool is_zero() const

Returns true if entire predicate array is zero.

Definition: predicate_vector.h:483

cutlass::PredicateVector::ConstIterator::operator--

CUTLASS_HOST_DEVICE ConstIterator & operator--()

Pre-decrement.

Definition: predicate_vector.h:316

cutlass::PredicateVector::kBytes

static int const kBytes

Number of bytes needed.

Definition: predicate_vector.h:126

cutlass::PredicateVector::const_end

CUTLASS_DEVICE ConstIterator const_end() const

Returns a ConstIterator.

Definition: predicate_vector.h:510

cutlass::PredicateVector::ConstIterator::at

CUTLASS_HOST_DEVICE bool at() const

Gets the bit at the pointed to location.

Definition: predicate_vector.h:374

cutlass::PredicateVector::ConstIterator::operator++

CUTLASS_HOST_DEVICE ConstIterator & operator++()

Pre-increment.

Definition: predicate_vector.h:302

cutlass::PredicateVector::enable

CUTLASS_HOST_DEVICE void enable()

Sets all predicates to true.

Definition: predicate_vector.h:435

cutlass::PredicateVector::ConstIterator::operator==

CUTLASS_HOST_DEVICE bool operator==(ConstIterator const &it) const

Returns true if iterators point to the same bit.

Definition: predicate_vector.h:362

cutlass::PredicateVector::operator[]

CUTLASS_HOST_DEVICE bool operator[](int idx) const

Accesses a bit within the predicate vector.

Definition: predicate_vector.h:443

cutlass::PredicateVector::Iterator::operator==

CUTLASS_HOST_DEVICE bool operator==(Iterator const &it) const

Returns true if iterators point to the same bit.

Definition: predicate_vector.h:256

cutlass::PredicateVector::Iterator::operator!=

CUTLASS_HOST_DEVICE bool operator!=(Iterator const &it) const

Returns false if iterators point to the same bit.

Definition: predicate_vector.h:260

cutlass::PredicateVector::at

CUTLASS_HOST_DEVICE bool at(int idx) const

Accesses a bit within the predicate vector.

Definition: predicate_vector.h:446

cutlass::PredicateVector::Iterator::Iterator

CUTLASS_HOST_DEVICE Iterator(PredicateVector &vec, int _start=0)

Constructs an iterator from a PredicateVector.

Definition: predicate_vector.h:192

cutlass::PredicateVector::ConstIterator::operator++

CUTLASS_HOST_DEVICE ConstIterator operator++(int)

Post-increment.

Definition: predicate_vector.h:330

cutlass::PredicateVector::Iterator::operator++

CUTLASS_HOST_DEVICE Iterator operator++(int)

Post-increment.

Definition: predicate_vector.h:224

cutlass::PredicateVector::TrivialIterator::TrivialIterator

CUTLASS_HOST_DEVICE TrivialIterator(Iterator const &it)

Constructs a TrivialIterator from an Iterator (the argument is ignored).

Definition: predicate_vector.h:389

cutlass::PredicateVector::Iterator::operator*

CUTLASS_HOST_DEVICE bool operator*() const

Dereferences iterator.

Definition: predicate_vector.h:272

platform.h

C++ features that may be otherwise unimplemented for CUDA device functions.

cutlass::PredicateVector::TrivialIterator

Iterator that always returns true.

Definition: predicate_vector.h:382

cutlass::PredicateVector::TrivialIterator::operator++

CUTLASS_HOST_DEVICE TrivialIterator operator++(int)

Post-increment.

Definition: predicate_vector.h:401

cutlass::PredicateVector::ConstIterator::operator+

CUTLASS_HOST_DEVICE ConstIterator operator+(int offset)

Iterator advances by some amount.

Definition: predicate_vector.h:346

cutlass::PredicateVector::ConstIterator::operator+=

CUTLASS_HOST_DEVICE ConstIterator & operator+=(int offset)

Increment.

Definition: predicate_vector.h:309

CUTLASS_PRAGMA_UNROLL

#define CUTLASS_PRAGMA_UNROLL

Definition: cutlass.h:110

cutlass::PredicateVector::const_begin

CUTLASS_DEVICE ConstIterator const_begin() const

Returns a ConstIterator.

Definition: predicate_vector.h:506

cutlass::PredicateVector::Iterator::operator--

CUTLASS_HOST_DEVICE Iterator & operator--()

Pre-decrement.

Definition: predicate_vector.h:210

cutlass::PredicateVector::Iterator::operator-=

CUTLASS_HOST_DEVICE Iterator & operator-=(int offset)

Decrement.

Definition: predicate_vector.h:217

cutlass::PredicateVector::Iterator::Iterator

CUTLASS_HOST_DEVICE Iterator(Iterator const &it)

Copy constructor.

Definition: predicate_vector.h:188

cutlass::PredicateVector::ConstIterator::operator-

CUTLASS_HOST_DEVICE ConstIterator operator-(int offset)

Iterator recedes by some amount.

Definition: predicate_vector.h:354

cutlass::PredicateVector::Iterator::operator+

CUTLASS_HOST_DEVICE Iterator operator+(int offset)

Iterator advances by some amount.

Definition: predicate_vector.h:240

cutlass::PredicateVector::fill

CUTLASS_HOST_DEVICE void fill(bool value=true)

Fills all predicates with a given value.

Definition: predicate_vector.h:417

cutlass::PredicateVector::kPredicates

static int const kPredicates

Number of bits stored by the PredicateVector.

Definition: predicate_vector.h:108

cutlass::PredicateVector::end

CUTLASS_DEVICE Iterator end()

Returns an iterator.

Definition: predicate_vector.h:502

CUTLASS_ASSERT

#define CUTLASS_ASSERT(x)

Definition: cutlass.h:92

cutlass::PredicateVector::Iterator::operator+=

CUTLASS_HOST_DEVICE Iterator & operator+=(int offset)

Increment.

Definition: predicate_vector.h:203

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

cutlass::PredicateVector::kPredicatesPerByte

static int const kPredicatesPerByte

Number of bits stored within each byte of the predicate bit vector.

Definition: predicate_vector.h:111

cutlass::PredicateVector::operator&=

CUTLASS_HOST_DEVICE PredicateVector & operator&=(PredicateVector const &predicates)

Computes the intersection of two identical predicate vectors.

Definition: predicate_vector.h:465

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::PredicateVector

Statically sized array of bits implementing the predicate vector concept.

Definition: predicate_vector.h:106

cutlass::PredicateVector::kWordCount

static int const kWordCount

Number of storage elements needed.

Definition: predicate_vector.h:129

cutlass::PredicateVector::ConstIterator::operator!=

CUTLASS_HOST_DEVICE bool operator!=(ConstIterator const &it) const

Returns false if iterators point to the same bit.

Definition: predicate_vector.h:366

cutlass::PredicateVector::Iterator::at

CUTLASS_HOST_DEVICE bool at() const

Gets the bit at the pointed to location.

Definition: predicate_vector.h:268

cutlass::PredicateVector::Iterator::operator++

CUTLASS_HOST_DEVICE Iterator & operator++()

Pre-increment.

Definition: predicate_vector.h:196

cutlass::PredicateVector::ConstIterator

An iterator implementing Predicate Iterator Concept enabling sequential read-only access to predicates.

Definition: predicate_vector.h:284

cutlass::PredicateVector::set

CUTLASS_HOST_DEVICE void set(int idx, bool value=true)

Set a bit within the predicate vector.

Definition: predicate_vector.h:454

cutlass::PredicateVector::Iterator::operator-

CUTLASS_HOST_DEVICE Iterator operator-(int offset)

Iterator recedes by some amount.

Definition: predicate_vector.h:248

cutlass::PredicateVector::kPredicateStart

static int const kPredicateStart

First bit within each byte containing predicates.

Definition: predicate_vector.h:114

cutlass::PredicateVector::ConstIterator::ConstIterator

CUTLASS_HOST_DEVICE ConstIterator(ConstIterator const &it)

Copy constructor.

Definition: predicate_vector.h:294

cutlass::PredicateVector::ConstIterator::operator*

CUTLASS_HOST_DEVICE bool operator*() const

Dereferences iterator.

Definition: predicate_vector.h:378

cutlass::PredicateVector::ConstIterator::operator--

CUTLASS_HOST_DEVICE ConstIterator operator--(int)

Post-decrement.

Definition: predicate_vector.h:338

cutlass::PredicateVector::clear

CUTLASS_HOST_DEVICE void clear()

Clears all predicates.

Definition: predicate_vector.h:427

cutlass::PredicateVector::PredicateVector

CUTLASS_HOST_DEVICE PredicateVector(bool value=true)

Initialize the predicate vector.

Definition: predicate_vector.h:414

cutlass::PredicateVector::begin

CUTLASS_DEVICE Iterator begin()

Returns an iterator to the start of the bit vector.

Definition: predicate_vector.h:498

cutlass.h

Basic include for CUTLASS.

cutlass::PredicateVector::TrivialIterator::operator*

CUTLASS_HOST_DEVICE bool operator*() const

Dereferences iterator.

Definition: predicate_vector.h:405

cutlass::PredicateVector::Iterator

An iterator implementing Predicate Iterator Concept enabling sequential read and write access to predicates.

Definition: predicate_vector.h:178

cutlass::PredicateVector::Iterator::operator--

CUTLASS_HOST_DEVICE Iterator operator--(int)

Post-decrement.

Definition: predicate_vector.h:232

cutlass::PredicateVector::ConstIterator::operator-=

CUTLASS_HOST_DEVICE ConstIterator & operator-=(int offset)

Decrement.

Definition: predicate_vector.h:323

cutlass::PredicateVector::TrivialIterator::TrivialIterator

CUTLASS_HOST_DEVICE TrivialIterator()

Constructor.

Definition: predicate_vector.h:385


Generated by 1.8.11