Back to Cutlass

CUTLASS: predicated_tile_iterator.h Source File

docs/epilogue_2threadblock_2predicated__tile__iterator_8h_source.html

4.4.275.0 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

epilogue/threadblock/predicated_tile_iterator.h

[Go to the documentation of this file.](epilogue_2threadblock_2predicated__tile__iterator_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

33 #pragma once

34

35 #include "cutlass/cutlass.h"

36 #include "cutlass/numeric_types.h"

37 #include "cutlass/array.h"

38 #include "cutlass/layout/matrix.h"

39 #include "cutlass/matrix_shape.h"

40 #include "cutlass/tensor_ref.h"

41

42 #include "[cutlass/transform/pitch_linear_thread_map.h](pitch linear thread__map_8h.html)"

43 #include "[cutlass/epilogue/threadblock/output_tile_thread_map.h](output tile thread__map_8h.html)"

44

45

47

48 namespace cutlass {

49

51

52 namespace epilogue {

53 namespace threadblock {

54

56

61 template <

62typename ThreadMap_,

63typename Element_

64 >

65 class PredicatedTileIterator {

66 public:

67using ThreadMap = ThreadMap_;

68using Shape = typename ThreadMap::Shape;

69

70using Element = Element_;

71

72using Layout = layout::RowMajor;

73using TensorRef = TensorRef<Element, Layout>;

74using ConstTensorRef = typename TensorRef::ConstTensorRef;

75

76using Index = typename Layout::Index;

77using LongIndex = typename Layout::LongIndex;

78using TensorCoord = MatrixCoord;

79

80static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

81static int const kThreads = ThreadMap::kThreads;

82static int const kIterations = ThreadMap::Count::kTile;

83

84static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0");

85static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0");

86static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0");

87static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0");

88

90using Fragment = Array<

91Element,

92 ThreadMap::Iterations::kColumn *

93 ThreadMap::Iterations::kRow *

94 ThreadMap::Iterations::kGroup *

95 ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;

96

98using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;

99

100//

101// Parameters struct

102//

103

104struct Params {

105

106//

107// Data members

108//

109

110Index stride;

111

112Index increment_row;

113Index increment_group;

114Index increment_cluster;

115

116Index advance_row;

117Index advance_group;

118Index advance_cluster;

119Index advance_tile;

120

121//

122// Methods

123//

124

125CUTLASS_HOST_DEVICE

126Status initialize(Index stride_) {

127

128 stride = stride_;

129

130 increment_row = stride * ThreadMap::Delta::kRow;

131

132 increment_group = stride * ThreadMap::Delta::kGroup

133 - stride * ThreadMap::Delta::kRow * (ThreadMap::Iterations::kRow - 1);

134

135 increment_cluster = stride * ThreadMap::Delta::kCluster

136 - stride * ThreadMap::Delta::kGroup * (ThreadMap::Iterations::kGroup - 1)

137 - stride * ThreadMap::Delta::kRow * (ThreadMap::Iterations::kRow - 1);

138

139 advance_row = stride * ThreadMap::Shape::kRow;

140

141 advance_group = stride * (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow;

142

143 advance_cluster =

144 stride *

145 ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;;

146

147 advance_tile =

148 stride *

149 ThreadMap::Shape::kGroup *

150 ThreadMap::Shape::kRow *

151 ThreadMap::Shape::kCluster *

152 ThreadMap::Shape::kTile;

153

154return Status::kSuccess;

155 }

156

157CUTLASS_HOST_DEVICE

158Params() {

159initialize(0);

160 }

161

162CUTLASS_HOST_DEVICE

163Params(Layout const &layout) {

164

165initialize(layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess);

166 }

167 };

168

170struct Mask {

171

172static int const kCount = ThreadMap::Iterations::kColumn;

173

175bool predicates[kCount];

176

177//

178// Mask

179//

180CUTLASS_HOST_DEVICE

181Mask() {

182 enable();

183 }

184

186CUTLASS_HOST_DEVICE void clear() {

187CUTLASS_PRAGMA_UNROLL

188for (int i = 0; i < kCount; ++i) {

189 predicates[i] = false;

190 }

191 }

192

194 CUTLASS_DEVICE void enable() {

195CUTLASS_PRAGMA_UNROLL

196for (int i = 0; i < kCount; ++i) {

197 predicates[i] = true;

198 }

199 }

200 };

201

202 private:

203

204//

205// Data members

206//

207

209Params params_;

210

212 uint8_t *byte_pointer_;

213

215Mask mask_;

216

218Index extent_row_;

219

221Index thread_start_row_;

222

224int state_[3];

225

226 private:

227

228//

229// Methods

230//

231

232 public:

233

234//

235// Methods

236//

237

239 CUTLASS_DEVICE

240PredicatedTileIterator(

241Params const & params,

242 Element *pointer,

243TensorCoord extent,

244int thread_idx,

245TensorCoord threadblock_offset = TensorCoord()

246 ):

247 params_(params) {

248

249TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;

250

251 extent_row_ = extent.row();

252 thread_start_row_ = thread_offset.row();

253

254// Initialize predicates

255CUTLASS_PRAGMA_UNROLL

256for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {

257

258 mask_.predicates[c] = ((thread_offset.column()

259 + ThreadMap::Delta::kColumn * c) < extent.column());

260 }

261

262// Initialize pointer

263 byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +

264 thread_offset.row() * params_.stride +

265 thread_offset.column() * sizeof(AccessType) / kElementsPerAccess;

266

267// Initialize internal state counter

268 state_[0] = state_[1] = state_[2] = 0;

269 }

270

272CUTLASS_HOST_DEVICE

273void add_pointer_offset(LongIndex pointer_offset) {

274 byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;

275 }

276

278 CUTLASS_DEVICE

279void load(Fragment &frag) {

280

281 uint8_t *byte_pointer = byte_pointer_;

282AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);

283

284CUTLASS_PRAGMA_UNROLL

285for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {

286

287CUTLASS_PRAGMA_UNROLL

288for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {

289

290CUTLASS_PRAGMA_UNROLL

291for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {

292

293int frag_row_idx =

294 (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));

295

296int row_offset = row * ThreadMap::Delta::kRow

297 + group * ThreadMap::Delta::kGroup

298 + cluster * ThreadMap::Delta::kCluster;

299

300bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

301

302AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);

303

304CUTLASS_PRAGMA_UNROLL

305for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {

306

307bool guard = row_guard && mask_.predicates[column];

308

309if (guard) {

310 frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] =

311 memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess];

312 }

313 }

314

315if (row + 1 < ThreadMap::Iterations::kRow) {

316 byte_pointer += params_.increment_row;

317 }

318 }

319

320if (group + 1 < ThreadMap::Iterations::kGroup) {

321 byte_pointer += params_.increment_group;

322 }

323 }

324

325if (cluster + 1 < ThreadMap::Iterations::kCluster) {

326 byte_pointer += params_.increment_cluster;

327 }

328 }

329 }

330

332 CUTLASS_DEVICE

333void store(Fragment const &frag) {

334 uint8_t *byte_pointer = byte_pointer_;

335AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);

336

337CUTLASS_PRAGMA_UNROLL

338for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {

339

340CUTLASS_PRAGMA_UNROLL

341for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {

342

343CUTLASS_PRAGMA_UNROLL

344for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {

345

346int frag_row_idx =

347 (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));

348

349int row_offset = row * ThreadMap::Delta::kRow

350 + group * ThreadMap::Delta::kGroup

351 + cluster * ThreadMap::Delta::kCluster;

352

353bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

354

355AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);

356

357CUTLASS_PRAGMA_UNROLL

358for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {

359

360bool guard = row_guard && mask_.predicates[column];

361

362if (guard) {

363

364 memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] =

365 frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column];

366 }

367 }

368

369if (row + 1 < ThreadMap::Iterations::kRow) {

370 byte_pointer += params_.increment_row;

371 }

372 }

373

374if (group + 1 < ThreadMap::Iterations::kGroup) {

375 byte_pointer += params_.increment_group;

376 }

377 }

378

379if (cluster + 1 < ThreadMap::Iterations::kCluster) {

380 byte_pointer += params_.increment_cluster;

381 }

382 }

383 }

384

386CUTLASS_HOST_DEVICE

387PredicatedTileIterator &operator++() {

388

389 ++state_[0];

390 byte_pointer_ += params_.advance_row;

391 thread_start_row_ += ThreadMap::Shape::kRow;

392

393if (state_[0] == ThreadMap::Count::kRow) {

394

395 state_[0] = 0;

396 ++state_[1];

397 byte_pointer_ += params_.advance_group;

398

399 thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *

400 ThreadMap::Shape::kRow * ThreadMap::Count::kRow;

401

402if (state_[1] == ThreadMap::Count::kGroup) {

403

404 state_[1] = 0;

405 ++state_[2];

406 byte_pointer_ += params_.advance_cluster;

407

408 thread_start_row_ += ThreadMap::Count::kGroup *

409 ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;

410

411if (state_[2] == ThreadMap::Count::kCluster) {

412 state_[2] = 0;

413 byte_pointer_ += params_.advance_tile;

414 }

415 }

416 }

417

418return *this;

419 }

420

422 CUTLASS_DEVICE void clear_mask() {

423 mask_.clear();

424 }

425

427 CUTLASS_DEVICE void enable_mask() {

428 mask_.enable();

429 }

430

432 CUTLASS_DEVICE void get_mask(Mask &mask) {

433return mask_;

434 }

435

437 CUTLASS_DEVICE void set_mask(Mask const &mask) {

438 mask_ = mask;

439 }

440 };

441

447 template <

448typename ThreadMap_,

449typename Element_,

450int InterleavedK

451 >

452 class InterleavedPredicatedTileIterator {

453 public:

454using ThreadMap = ThreadMap_;

455

456using Element = Element_;

457

458using Layout = layout::ColumnMajorInterleaved<InterleavedK>;

459using TensorRef = TensorRef<Element, Layout>;

460using ConstTensorRef = typename TensorRef::ConstTensorRef;

461

462using Index = typename Layout::Index;

463using LongIndex = typename Layout::LongIndex;

464using TensorCoord = layout::PitchLinearCoord;

465

466static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

467static int const kThreads = ThreadMap::kThreads;

468static int const kIterations = ThreadMap::Iterations::kCount;

469

471using Fragment = Array<Element, ThreadMap::kElementsPerAccess>;

472

474using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;

475

476//

477// Parameters struct

478//

479

480struct Params {

481

482//

483// Data members

484//

485

486Index stride;

487

488Index advance_row;

489Index advance_column;

490

491//

492// Methods

493//

494

495CUTLASS_HOST_DEVICE

496Status initialize(Index stride_) {

497 stride = stride_;

498

499 advance_row =

500 ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value / 8;

501

502 advance_column =

503 stride_ - ThreadMap::Iterations::kContiguous * kElementsPerAccess *

504sizeof_bits<Element>::value * ThreadMap::kWarpSize / 8;

505

506return Status::kSuccess;

507 }

508

509CUTLASS_HOST_DEVICE

510Params() {

511initialize(0);

512 }

513

514CUTLASS_HOST_DEVICE

515Params(Layout const &layout) {

516

517initialize(layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess);

518 }

519 };

520

522struct Mask {

523static int const kCount = (ThreadMap::Iterations::kContiguous < 8)

524 ? 8

525 : ThreadMap::Iterations::kContiguous;

526

528bool predicates[kCount];

529

530//

531// Mask

532//

533CUTLASS_HOST_DEVICE

534Mask() {

535 enable();

536 }

537

539CUTLASS_HOST_DEVICE void clear() {

540CUTLASS_PRAGMA_UNROLL

541for (int i = 0; i < kCount; ++i) {

542 predicates[i] = false;

543 }

544 }

545

547 CUTLASS_DEVICE void enable() {

548CUTLASS_PRAGMA_UNROLL

549for (int i = 0; i < kCount; ++i) {

550 predicates[i] = true;

551 }

552 }

553 };

554

555 private:

556

557//

558// Data members

559//

560

562Params params_;

563

565 uint8_t *byte_pointer_;

566

568Mask mask_;

569

571Index extent_col_;

572

575Index thread_start_col_;

576

578int iteration_contiguous_;

579

580int iteration_strided_;

581

582 private:

583

584//

585// Methods

586//

587

588 public:

589

590//

591// Methods

592//

593

595 CUTLASS_DEVICE

596InterleavedPredicatedTileIterator(

597Params const & params,

598 Element *pointer,

599TensorCoord extent,

600int thread_idx,

601TensorCoord threadblock_offset

602 ):

603 params_(params) {

604TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) +

605TensorCoord(threadblock_offset.contiguous() * InterleavedK,

606 threadblock_offset.strided() / InterleavedK);

607

608 extent_col_ = extent.strided() / InterleavedK;

609 thread_start_col_ = thread_offset.strided();

610

611// Initialize predicates

612CUTLASS_PRAGMA_UNROLL

613for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {

614 mask_.predicates[c] =

615 ((thread_offset.contiguous() + ThreadMap::Delta::kContiguous * c) <

616 (extent.contiguous() * InterleavedK));

617 }

618

619// Initialize pointer

620 byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +

621 thread_offset.strided() * params_.stride +

622 thread_offset.contiguous() * sizeof(AccessType) / kElementsPerAccess;

623

624// Initialize internal state counter

625 iteration_contiguous_ = iteration_strided_ = 0;

626 }

627

629CUTLASS_HOST_DEVICE

630void add_pointer_offset(LongIndex pointer_offset) {

631 byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;

632 }

633

635 CUTLASS_DEVICE

636void load(Fragment &frag) {

637 uint8_t *byte_pointer = byte_pointer_;

638AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);

639AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);

640

641int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided;

642

643bool col_guard = ((thread_start_col_ + col_offset) < extent_col_);

644

645bool guard = col_guard && mask_.predicates[iteration_contiguous_];

646

647if (guard) {

648 *frag_ptr = *memory_pointer;

649 }

650 }

651

653 CUTLASS_DEVICE

654void store(Fragment const &frag) {

655 uint8_t *byte_pointer = byte_pointer_;

656AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);

657AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);

658

659int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided;

660

661bool col_guard = ((thread_start_col_ + col_offset) < extent_col_);

662

663bool guard = col_guard && mask_.predicates[iteration_contiguous_];

664

665if (guard) {

666 *memory_pointer = *frag_ptr;

667 }

668 }

669

671CUTLASS_HOST_DEVICE

672void set_iteration_index(int iteration) {

673 iteration_contiguous_ = iteration % ThreadMap::Iterations::kContiguous;

674 iteration_strided_ = iteration / ThreadMap::Iterations::kContiguous;

675 }

676

678CUTLASS_HOST_DEVICE

679InterleavedPredicatedTileIterator &operator++() {

680

681 ++iteration_contiguous_;

682 byte_pointer_ += params_.advance_row;

683

684if (iteration_contiguous_ == ThreadMap::Iterations::kContiguous) {

685

686 iteration_contiguous_ = 0;

687 ++iteration_strided_;

688 byte_pointer_ += params_.advance_column;

689

690if (iteration_strided_ == ThreadMap::Iterations::kStrided) {

691 iteration_strided_ = 0;

692 }

693 }

694

695return *this;

696 }

697

699 CUTLASS_DEVICE void clear_mask() {

700 mask_.clear();

701 }

702

704 CUTLASS_DEVICE void enable_mask() {

705 mask_.enable();

706 }

707

709 CUTLASS_DEVICE void get_mask(Mask &mask) {

710return mask_;

711 }

712

714 CUTLASS_DEVICE void set_mask(Mask const &mask) {

715 mask_ = mask;

716 }

717 };

718

720

721 } // namespace threadblock

722 } // namespace epilogue

723 } // namespace cutlass

724

cutlass::epilogue::threadblock::PredicatedTileIterator::Mask::predicates

bool predicates[kCount]

Predicate state.

Definition: epilogue/threadblock/predicated_tile_iterator.h:175

cutlass::layout::RowMajor::LongIndex

int64_t LongIndex

Long index type used for offsets.

Definition: layout/matrix.h:62

cutlass::epilogue::threadblock::PredicatedTileIterator::kElementsPerAccess

static int const kElementsPerAccess

Definition: epilogue/threadblock/predicated_tile_iterator.h:80

cutlass::epilogue::threadblock::PredicatedTileIterator::Mask::enable

CUTLASS_DEVICE void enable()

Definition: epilogue/threadblock/predicated_tile_iterator.h:194

cutlass::MatrixCoord::column

CUTLASS_HOST_DEVICE Index const & column() const

Returns the column of the coordinate.

Definition: matrix_coord.h:85

cutlass::epilogue::threadblock::PredicatedTileIterator::Params::advance_row

Index advance_row

amount to add to move to the next 'row' position

Definition: epilogue/threadblock/predicated_tile_iterator.h:116

cutlass::epilogue::threadblock::PredicatedTileIterator::load

CUTLASS_DEVICE void load(Fragment &frag)

Loads a fragment from memory.

Definition: epilogue/threadblock/predicated_tile_iterator.h:279

cutlass::epilogue::threadblock::PredicatedTileIterator::Element

Element_ Element

Definition: epilogue/threadblock/predicated_tile_iterator.h:70

cutlass

Definition: aligned_buffer.h:35

cutlass::layout::PitchLinearCoord

Coordinate in pitch-linear space.

Definition: pitch_linear.h:52

tensor_ref.h

Defines a structure containing strides, bounds, and a pointer to tensor data.

cutlass::epilogue::threadblock::PredicatedTileIterator::AccessType

AlignedArray< Element, ThreadMap::kElementsPerAccess > AccessType

Memory access size.

Definition: epilogue/threadblock/predicated_tile_iterator.h:98

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Params::initialize

CUTLASS_HOST_DEVICE Status initialize(Index stride_)

Definition: epilogue/threadblock/predicated_tile_iterator.h:496

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Mask::clear

CUTLASS_HOST_DEVICE void clear()

Efficiently disables all accesses guarded by mask.

Definition: epilogue/threadblock/predicated_tile_iterator.h:539

[pitch_linear_thread_map.h](pitch__linear__thread__map_8h.html)

Templates implementing how threads are mapped to a given tile.

cutlass::epilogue::threadblock::PredicatedTileIterator::get_mask

CUTLASS_DEVICE void get_mask(Mask &mask)

Gets the mask.

Definition: epilogue/threadblock/predicated_tile_iterator.h:432

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Fragment

Array< Element, ThreadMap::kElementsPerAccess > Fragment

Fragment object.

Definition: epilogue/threadblock/predicated_tile_iterator.h:471

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::ThreadMap

ThreadMap_ ThreadMap

Definition: epilogue/threadblock/predicated_tile_iterator.h:454

cutlass::AlignedArray

Aligned array type.

Definition: array.h:511

cutlass::epilogue::threadblock::PredicatedTileIterator::Mask

Mask object.

Definition: epilogue/threadblock/predicated_tile_iterator.h:170

cutlass::MatrixCoord::row

CUTLASS_HOST_DEVICE Index const & row() const

Returns the row of the coordinate.

Definition: matrix_coord.h:77

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Mask::predicates

bool predicates[kCount]

Predicate state.

Definition: epilogue/threadblock/predicated_tile_iterator.h:528

cutlass::epilogue::threadblock::PredicatedTileIterator::ConstTensorRef

typename TensorRef::ConstTensorRef ConstTensorRef

Definition: epilogue/threadblock/predicated_tile_iterator.h:74

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Mask::Mask

CUTLASS_HOST_DEVICE Mask()

Constructs a mask with all accesses enabled.

Definition: epilogue/threadblock/predicated_tile_iterator.h:534

cutlass::layout::ColumnMajorInterleaved::stride

CUTLASS_HOST_DEVICE Stride stride() const

Returns the stride of the layout.

Definition: layout/matrix.h:418

cutlass::layout::RowMajor::stride

CUTLASS_HOST_DEVICE Stride stride() const

Returns the stride of the layout.

Definition: layout/matrix.h:112

cutlass::epilogue::threadblock::PredicatedTileIterator::Fragment

Array< Element, ThreadMap::Iterations::kColumn *ThreadMap::Iterations::kRow *ThreadMap::Iterations::kGroup *ThreadMap::Iterations::kCluster *ThreadMap::kElementsPerAccess > Fragment

Fragment object.

Definition: epilogue/threadblock/predicated_tile_iterator.h:95

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Params

Definition: epilogue/threadblock/predicated_tile_iterator.h:480

cutlass::epilogue::threadblock::PredicatedTileIterator::store

CUTLASS_DEVICE void store(Fragment const &frag)

Stores a fragment to memory.

Definition: epilogue/threadblock/predicated_tile_iterator.h:333

cutlass::epilogue::threadblock::PredicatedTileIterator::LongIndex

typename Layout::LongIndex LongIndex

Definition: epilogue/threadblock/predicated_tile_iterator.h:77

cutlass::TensorRef< Element, Layout >::ConstTensorRef

TensorRef< typename platform::remove_const< Element >::type const, Layout > ConstTensorRef

TensorRef to constant data.

Definition: tensor_ref.h:179

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::load

CUTLASS_DEVICE void load(Fragment &frag)

Loads a fragment from memory.

Definition: epilogue/threadblock/predicated_tile_iterator.h:636

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::InterleavedPredicatedTileIterator

CUTLASS_DEVICE InterleavedPredicatedTileIterator(Params const &params, Element *pointer, TensorCoord extent, int thread_idx, TensorCoord threadblock_offset)

Constructor.

Definition: epilogue/threadblock/predicated_tile_iterator.h:596

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

CUTLASS_PRAGMA_UNROLL

#define CUTLASS_PRAGMA_UNROLL

Definition: cutlass.h:110

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Index

typename Layout::Index Index

Definition: epilogue/threadblock/predicated_tile_iterator.h:462

cutlass::layout::RowMajor::Index

int32_t Index

Index type used for coordinates.

Definition: layout/matrix.h:59

cutlass::epilogue::threadblock::PredicatedTileIterator::Params::advance_cluster

Index advance_cluster

amount to add to move to the next 'cluster' position

Definition: epilogue/threadblock/predicated_tile_iterator.h:118

matrix_shape.h

Defines a Shape template for matrix tiles.

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::store

CUTLASS_DEVICE void store(Fragment const &frag)

Stores a fragment to memory.

Definition: epilogue/threadblock/predicated_tile_iterator.h:654

cutlass::sizeof_bits

Defines the size of an element in bits.

Definition: numeric_types.h:42

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Params::advance_row

Index advance_row

amount to add to move to the next 'row' position

Definition: epilogue/threadblock/predicated_tile_iterator.h:488

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::clear_mask

CUTLASS_DEVICE void clear_mask()

Efficiently disables all accesses guarded by mask.

Definition: epilogue/threadblock/predicated_tile_iterator.h:699

cutlass::TensorRef< Element, Layout >

cutlass::epilogue::threadblock::PredicatedTileIterator::ThreadMap

ThreadMap_ ThreadMap

Definition: epilogue/threadblock/predicated_tile_iterator.h:67

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Mask

Mask object.

Definition: epilogue/threadblock/predicated_tile_iterator.h:522

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Params::Params

CUTLASS_HOST_DEVICE Params()

Definition: epilogue/threadblock/predicated_tile_iterator.h:510

cutlass::epilogue::threadblock::PredicatedTileIterator::PredicatedTileIterator

CUTLASS_DEVICE PredicatedTileIterator(Params const &params, Element *pointer, TensorCoord extent, int thread_idx, TensorCoord threadblock_offset=TensorCoord())

Constructor.

Definition: epilogue/threadblock/predicated_tile_iterator.h:240

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::layout::PitchLinearCoord::contiguous

CUTLASS_HOST_DEVICE Index const & contiguous() const

Returns the contiguous dimension.

Definition: pitch_linear.h:89

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator

Definition: epilogue/threadblock/predicated_tile_iterator.h:452

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::ConstTensorRef

typename TensorRef::ConstTensorRef ConstTensorRef

Definition: epilogue/threadblock/predicated_tile_iterator.h:460

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::set_iteration_index

CUTLASS_HOST_DEVICE void set_iteration_index(int iteration)

Overrides the internal iteration index.

Definition: epilogue/threadblock/predicated_tile_iterator.h:672

cutlass::epilogue::threadblock::PredicatedTileIterator::Params::stride

Index stride

stride in bytes between rows

Definition: epilogue/threadblock/predicated_tile_iterator.h:110

cutlass::epilogue::threadblock::PredicatedTileIterator::operator++

CUTLASS_HOST_DEVICE PredicatedTileIterator & operator++()

Advances to the next position to load or store.

Definition: epilogue/threadblock/predicated_tile_iterator.h:387

cutlass::epilogue::threadblock::PredicatedTileIterator::Params::Params

CUTLASS_HOST_DEVICE Params(Layout const &layout)

Definition: epilogue/threadblock/predicated_tile_iterator.h:163

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Params::stride

Index stride

stride in bytes between columns

Definition: epilogue/threadblock/predicated_tile_iterator.h:486

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Params::advance_column

Index advance_column

amount to add to move to the next 'column' position

Definition: epilogue/threadblock/predicated_tile_iterator.h:489

cutlass::epilogue::threadblock::PredicatedTileIterator::Params

Definition: epilogue/threadblock/predicated_tile_iterator.h:104

cutlass::epilogue::threadblock::PredicatedTileIterator::kIterations

static int const kIterations

Definition: epilogue/threadblock/predicated_tile_iterator.h:82

cutlass::epilogue::threadblock::PredicatedTileIterator::Params::advance_tile

Index advance_tile

amount to add to move to the next 'tile'

Definition: epilogue/threadblock/predicated_tile_iterator.h:119

[output_tile_thread_map.h](output__tile__thread__map_8h.html)

Metaprogram for determining the mapping of output elements to threads for epilogue tiles...

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::operator++

CUTLASS_HOST_DEVICE InterleavedPredicatedTileIterator & operator++()

Advances to the next position to load or store.

Definition: epilogue/threadblock/predicated_tile_iterator.h:679

cutlass::epilogue::threadblock::PredicatedTileIterator::clear_mask

CUTLASS_DEVICE void clear_mask()

Efficiently disables all accesses guarded by mask.

Definition: epilogue/threadblock/predicated_tile_iterator.h:422

cutlass::epilogue::threadblock::PredicatedTileIterator::Params::increment_group

Index increment_group

increment quantity (in bytes) to advance when moving to the next group

Definition: epilogue/threadblock/predicated_tile_iterator.h:113

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

cutlass::epilogue::threadblock::PredicatedTileIterator::Index

typename Layout::Index Index

Definition: epilogue/threadblock/predicated_tile_iterator.h:76

cutlass::epilogue::threadblock::PredicatedTileIterator

Definition: epilogue/threadblock/predicated_tile_iterator.h:65

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::LongIndex

typename Layout::LongIndex LongIndex

Definition: epilogue/threadblock/predicated_tile_iterator.h:463

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::set_mask

CUTLASS_DEVICE void set_mask(Mask const &mask)

Definition: epilogue/threadblock/predicated_tile_iterator.h:714

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Params::Params

CUTLASS_HOST_DEVICE Params(Layout const &layout)

Definition: epilogue/threadblock/predicated_tile_iterator.h:515

cutlass::epilogue::threadblock::PredicatedTileIterator::set_mask

CUTLASS_DEVICE void set_mask(Mask const &mask)

Definition: epilogue/threadblock/predicated_tile_iterator.h:437

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::Mask::enable

CUTLASS_DEVICE void enable()

Definition: epilogue/threadblock/predicated_tile_iterator.h:547

cutlass::epilogue::threadblock::PredicatedTileIterator::Params::advance_group

Index advance_group

amount to add to move to the next 'group' position

Definition: epilogue/threadblock/predicated_tile_iterator.h:117

matrix.h

Defines layout functions used by TensorRef and derived classes.

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::add_pointer_offset

CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)

Adds a pointer offset in units of Element.

Definition: epilogue/threadblock/predicated_tile_iterator.h:630

cutlass::epilogue::threadblock::PredicatedTileIterator::Params::initialize

CUTLASS_HOST_DEVICE Status initialize(Index stride_)

Definition: epilogue/threadblock/predicated_tile_iterator.h:126

cutlass::Status::kSuccess

Operation was successful.

cutlass::epilogue::threadblock::PredicatedTileIterator::add_pointer_offset

CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)

Adds a pointer offset in units of Element.

Definition: epilogue/threadblock/predicated_tile_iterator.h:273

cutlass::epilogue::threadblock::PredicatedTileIterator::Shape

typename ThreadMap::Shape Shape

Definition: epilogue/threadblock/predicated_tile_iterator.h:68

cutlass::layout::ColumnMajorInterleaved

Definition: layout/matrix.h:343

cutlass::epilogue::threadblock::PredicatedTileIterator::TensorCoord

MatrixCoord TensorCoord

Definition: epilogue/threadblock/predicated_tile_iterator.h:78

cutlass::epilogue::threadblock::PredicatedTileIterator::Params::Params

CUTLASS_HOST_DEVICE Params()

Definition: epilogue/threadblock/predicated_tile_iterator.h:158

cutlass::epilogue::threadblock::PredicatedTileIterator::enable_mask

CUTLASS_DEVICE void enable_mask()

Efficiently enables all accesses guarded by mask.

Definition: epilogue/threadblock/predicated_tile_iterator.h:427

cutlass::epilogue::threadblock::PredicatedTileIterator::Params::increment_row

Index increment_row

increment quantity (in bytes) to advance when moving between rows

Definition: epilogue/threadblock/predicated_tile_iterator.h:112

cutlass::epilogue::threadblock::PredicatedTileIterator::Params::increment_cluster

Index increment_cluster

increment quantity (in bytes) to advance when moving to the next cluster

Definition: epilogue/threadblock/predicated_tile_iterator.h:114

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::get_mask

CUTLASS_DEVICE void get_mask(Mask &mask)

Gets the mask.

Definition: epilogue/threadblock/predicated_tile_iterator.h:709

cutlass.h

Basic include for CUTLASS.

cutlass::MatrixCoord

Definition: matrix_coord.h:39

cutlass::epilogue::threadblock::PredicatedTileIterator::Mask::clear

CUTLASS_HOST_DEVICE void clear()

Efficiently disables all accesses guarded by mask.

Definition: epilogue/threadblock/predicated_tile_iterator.h:186

cutlass::layout::PitchLinearCoord::strided

CUTLASS_HOST_DEVICE Index const & strided() const

Returns the strided dimension.

Definition: pitch_linear.h:97

cutlass::Status

Status

Status code returned by CUTLASS operations.

Definition: cutlass.h:39

cutlass::epilogue::threadblock::InterleavedPredicatedTileIterator::enable_mask

CUTLASS_DEVICE void enable_mask()

Efficiently enables all accesses guarded by mask.

Definition: epilogue/threadblock/predicated_tile_iterator.h:704

cutlass::epilogue::threadblock::PredicatedTileIterator::Mask::Mask

CUTLASS_HOST_DEVICE Mask()

Constructs a mask with all accesses enabled.

Definition: epilogue/threadblock/predicated_tile_iterator.h:181

cutlass::epilogue::threadblock::PredicatedTileIterator::kThreads

static int const kThreads

Definition: epilogue/threadblock/predicated_tile_iterator.h:81


Generated by 1.8.11