Back to Cutlass

CUTLASS: output_tile_thread_map.h Source File

docs/output__tile__thread__map_8h_source.html

4.4.243.5 KB
Original Source

CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers

output_tile_thread_map.h

[Go to the documentation of this file.](output__tile__thread__map_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

31 #pragma once

32

33 #include "cutlass/cutlass.h"

34 #include "cutlass/numeric_types.h"

35 #include "cutlass/array.h"

36 #include "cutlass/layout/matrix.h"

37 #include "cutlass/matrix_shape.h"

38 #include "cutlass/tensor_ref.h"

39 #include "cutlass/fast_math.h"

40

42

43 namespace cutlass {

44 namespace epilogue {

45 namespace threadblock {

46

48

/// Tuple defining point in output tile.
///
/// A five-level tuple ordered (column, row, group, cluster, tile). This header uses the
/// same template for tile shapes, repetition counts, per-thread iteration counts,
/// strides between accesses, and warp partitions.
template <
  int Column,
  int Row,
  int Group,
  int Cluster,
  int Tile
>
struct OutputTileShape {
  // Each level re-exported under the kName convention used throughout CUTLASS.
  static int const kColumn  = Column;
  static int const kRow     = Row;
  static int const kGroup   = Group;
  static int const kCluster = Cluster;
  static int const kTile    = Tile;

  /// Product of all five levels.
  static int const kCount = Column * Row * Group * Cluster * Tile;
};

66

68

69 template <

70typename ThreadMap_,

71typename Shape_,

72typename Iterations_,

73typename Delta_,

74typename Count_

75 >

76 struct OutputTileThreadMap {

77

79using ThreadMap = ThreadMap_;

80

82static int const kThreads = ThreadMap::kThreads;

83

85static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

86

88using Shape = Shape_;

89

91using Iterations = Iterations_;

92

94using Delta = Delta_;

95

97using Count = Count_;

98

100CUTLASS_HOST_DEVICE

101static MatrixCoord initial_offset(int thread_idx) {

102

103using Index = typename layout::PitchLinearCoord::Index;

104

105layout::PitchLinearCoord coord = ThreadMap::initial_offset(thread_idx);

106

107 Index cluster = coord.strided() / (Shape::kGroup * Shape::kRow);

108 Index cluster_residual = coord.strided() % (Shape::kGroup * Shape::kRow);

109

110 Index group = cluster_residual / (Shape::kRow);

111 Index row = cluster_residual % (Shape::kRow);

112

113return MatrixCoord{

114 row + group * Shape::kRow * Count::kRow

115 + cluster * Shape::kGroup * Count::kGroup * Shape::kRow * Count::kRow,

116 coord.contiguous()

117 };

118 }

119 };

120

122

namespace detail {

/// RowArrangement determines how one or more warps cover a region of consecutive rows.
///
/// Shape             - output tile shape (kRow and kColumn are read)
/// WarpsRemaining    - number of warps available to cover Shape::kRow rows
/// ElementsPerAccess - elements handled by one thread per access
/// ElementSize       - element size; the 2D specialization divides it by 8 to obtain
///                     bytes, which suggests it is given in bits
/// Is2dTile          - chooses between the two specializations below
template <
  typename Shape,
  int WarpsRemaining,
  int ElementsPerAccess,
  int ElementSize,
  bool Is2dTile
>
struct RowArrangement;

/// Single-row arrangement (OutputTileOptimalThreadMap selects it when
/// Shape::kRow <= WarpsRemaining).
///
/// A warp covers one row. Its 32 lanes sit side by side, each moving
/// kElementsPerAccess elements, and the warp steps across Shape::kColumn.
template <
  typename Shape,
  int WarpsRemaining,
  int ElementsPerAccess,
  int ElementSize
>
struct RowArrangement<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, false> {
  static int const kWarpSize = 32;
  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kElementSize = ElementSize;

  // One row per warp.
  static int const kIterationsRow = 1;
  static int const kDeltaRow = 1;

  // A full warp advances kWarpSize accesses per column step.
  static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize;
  static int const kDeltaColumn = kElementsPerAccess * kWarpSize;

  // Lanes are laid out as a 1 x kWarpSize strip.
  static int const kAccessWidth = kWarpSize;
  static int const kAccessRows = 1;

  // Remaining warps are reported as column partitions.
  static int const kWarpPartitionsRow = 1;
  static int const kWarpPartitionsColumn = WarpsRemaining;
};

/// Two-dimensional arrangement (OutputTileOptimalThreadMap selects it when
/// Shape::kRow > WarpsRemaining).
///
/// The warps split Shape::kRow evenly. Within one warp, the lanes are arranged as
/// kAccessRows x kAccessWidth. The width is chosen so that one row of accesses spans
/// at most kMemoryAccessSize bytes.
template <
  typename Shape,
  int WarpsRemaining,
  int ElementsPerAccess,
  int ElementSize
>
struct RowArrangement<Shape, WarpsRemaining, ElementsPerAccess, ElementSize, true> {

  static int const kMemoryAccessSize = 128;
  static int const kWarpSize = 32;

  static int const kElementsPerAccess = ElementsPerAccess;
  static int const kElementSize = ElementSize;

  struct Detail {
    // Rows owned by a single warp.
    static int const kShapeRow = Shape::kRow / WarpsRemaining;

    // Row width measured in accesses instead of elements.
    static int const kShapeWidth = Shape::kColumn / kElementsPerAccess;

    // Number of accesses whose combined size equals kMemoryAccessSize bytes.
    static int const kTargetMemoryAccessWidth =
      kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8);

    // Rows one warp would span if each row used kTargetMemoryAccessWidth lanes.
    static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth;
  };

  // When the warp owns fewer rows than the target layout would need, its lanes are
  // spread over exactly kShapeRow rows. Otherwise the width is capped by the row
  // width, the warp size and the target access width.
  static int const kAccessWidth =
    (Detail::kShapeRow < Detail::kTargetAccessRows)
      ? kWarpSize / Detail::kShapeRow
      : const_min(
          const_min(Detail::kShapeWidth, kWarpSize),
          Detail::kTargetMemoryAccessWidth);

  static int const kAccessRows =
    (Detail::kShapeRow < Detail::kTargetAccessRows)
      ? Detail::kShapeRow
      : const_min(Shape::kRow, kWarpSize / kAccessWidth);

  // Row iterations: the warp's rows divided into blocks of kAccessRows.
  static int const kIterationsRow = Detail::kShapeRow / kAccessRows;
  static int const kDeltaRow = kAccessRows;

  // Column iterations: the row width (in accesses) divided into blocks of kAccessWidth.
  static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth;
  static int const kDeltaColumn = kAccessWidth * kElementsPerAccess;

  static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access");
  static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" );
  static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" );

  // No further warp partitioning inside this arrangement.
  static int const kWarpPartitionsRow = 1;
  static int const kWarpPartitionsColumn = 1;
};

}

211

213

221 template <

222typename Shape_,

223typename Count_,

224int Threads,

225int ElementsPerAccess,

226int ElementSize

227 >

228 struct OutputTileOptimalThreadMap {

229

230using Shape = Shape_;

231using Count = Count_;

232

233static int const kWarpSize = 32;

234static int const kThreads = Threads;

235static int const kWarpCount = kThreads / kWarpSize;

236

237static int const kElementsPerAccess = ElementsPerAccess;

238static int const kElementSize = ElementSize;

239

240//

241// Metaprogram computation

242//

243

244struct Detail {

245

246// Clusters

247static int const kIterationsCluster =

248 ((Shape::kCluster > kWarpCount) ?

249 Shape::kCluster / kWarpCount

250 : 1);

251

252static int const kDeltaCluster =

253 ((Shape::kCluster > kWarpCount) ?

254 Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster

255 : 1);

256

257static int const kCompactedDeltaCluster =

258 ((Shape::kCluster > kWarpCount) ?

259 Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster

260 : 1);

261

262static int const kWarpPartitionsCluster =

263 ((Shape::kCluster > kWarpCount) ?

264 kWarpCount

265 : kWarpCount / Shape::kCluster);

266

267static int const kWarpsRemainingForGroups =

268 ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster);

269

270// Groups

271static int const kIterationsGroup =

272 ((Shape::kGroup > kWarpsRemainingForGroups) ?

273 Shape::kGroup / kWarpsRemainingForGroups

274 : 1);

275

276static int const kDeltaGroup =

277 ((Shape::kGroup > kWarpsRemainingForGroups) ?

278 Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup

279 : 1);

280

281static int const kCompactedDeltaGroup =

282 ((Shape::kGroup > kWarpsRemainingForGroups) ?

283 Shape::kRow * Shape::kGroup / kIterationsGroup

284 : 1);

285

286static int const kWarpPartitionsGroup =

287 ((Shape::kGroup > kWarpsRemainingForGroups) ?

288 1

289 : kWarpsRemainingForGroups / Shape::kGroup);

290

291static int const kWarpsRemainingForRows =

292 ((Shape::kGroup > kWarpsRemainingForGroups) ?

293 1

294 : kWarpsRemainingForGroups / Shape::kGroup);

295

296// Rows

297using RowArrangement = detail::RowArrangement<

298Shape,

299 kWarpsRemainingForRows,

300 kElementsPerAccess,

301 kElementSize,

302 (Shape::kRow > kWarpsRemainingForRows)

303 >;

304

305// Warp partitions

306using WarpPartitions = OutputTileShape<

307 RowArrangement::kWarpPartitionsColumn,

308 RowArrangement::kWarpPartitionsRow,

309 kWarpPartitionsGroup,

310 kWarpPartitionsCluster,

311 1>;

312

313static int const kAccessWidth = RowArrangement::kAccessWidth;

314static int const kAccessRows = RowArrangement::kAccessRows;

315 };

316

317//

318// Output

319//

320

321using Iterations = OutputTileShape<

322 Detail::RowArrangement::kIterationsColumn,

323 Detail::RowArrangement::kIterationsRow,

324 Detail::kIterationsGroup,

325 Detail::kIterationsCluster,

326 1>;

327

328using Delta = OutputTileShape<

329 Detail::RowArrangement::kDeltaColumn,

330 Detail::RowArrangement::kDeltaRow,

331 Detail::kDeltaGroup,

332 Detail::kDeltaCluster,

333 1>;

334

336CUTLASS_HOST_DEVICE

337static MatrixCoord initial_offset(int thread_idx) {

338

339int warp_idx = thread_idx / kWarpSize;

340int lane_idx = thread_idx % kWarpSize;

341

342// Compute warp location

343int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster;

344int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster;

345

346int group_idx = residual_cluster / Detail::WarpPartitions::kGroup;

347int residual_group = residual_cluster % Detail::WarpPartitions::kGroup;

348

349int row_idx = residual_group / Detail::WarpPartitions::kRow;

350int col_idx = residual_group % Detail::WarpPartitions::kRow;

351

352// Compute per-lane offset

353int lane_row_offset = lane_idx / Detail::kAccessWidth;

354int lane_col_offset = lane_idx % Detail::kAccessWidth;

355

356// Compute coordinate in output space

357int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup;

358int group_offset = group_idx * Shape::kRow * Count::kRow;

359int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows;

360int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess;

361

362return MatrixCoord(

363 cluster_offset + group_offset + row_offset + lane_row_offset,

364 (column_offset + lane_col_offset) * kElementsPerAccess

365 );

366 }

367

369struct CompactedThreadMap {

370

371

372using Shape = Shape_;

373

374using Iterations = OutputTileShape<

375 Detail::RowArrangement::kIterationsColumn,

376 Detail::RowArrangement::kIterationsRow,

377 Detail::kIterationsGroup,

378 Detail::kIterationsCluster,

379 1>;

380

381using Delta = OutputTileShape<

382 Detail::RowArrangement::kDeltaColumn,

383 Detail::RowArrangement::kDeltaRow,

384 Detail::kCompactedDeltaGroup,

385 Detail::kCompactedDeltaCluster,

386 1>;

387

389static int const kElementsPerAccess = ElementsPerAccess;

390

392static int const kThreads = Threads;

393

395CUTLASS_HOST_DEVICE

396static MatrixCoord initial_offset(int thread_idx) {

397

398int warp_idx = thread_idx / kWarpSize;

399int lane_idx = thread_idx % kWarpSize;

400

401// Compute warp location

402int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster;

403int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster;

404

405int group_idx = residual_cluster / Detail::WarpPartitions::kGroup;

406int residual_group = residual_cluster % Detail::WarpPartitions::kGroup;

407

408int row_idx = residual_group / Detail::WarpPartitions::kRow;

409int col_idx = residual_group % Detail::WarpPartitions::kRow;

410

411// Compute per-lane offset

412int lane_row_offset = lane_idx / Detail::kAccessWidth;

413int lane_col_offset = lane_idx % Detail::kAccessWidth;

414

415// Compute coordinate in output space

416int cluster_offset = cluster_idx * Shape::kRow * Shape::kGroup;

417int group_offset = group_idx * Shape::kRow;

418int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows;

419int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess;

420

421MatrixCoord coord(

422 cluster_offset + group_offset + row_offset + lane_row_offset,

423 (column_offset + lane_col_offset) * kElementsPerAccess

424 );

425

426return coord;

427 }

428 };

429 };

430

432

440 template <typename WarpCount_, typename MmaCount_, int Threads,

441int ElementsPerAccess, int ElementSize>

442 struct InterleavedOutputTileThreadMap {

443using WarpCount = WarpCount_;

444using MmaCount = MmaCount_;

445

446static int const kWarpSize = 32;

447static int const kThreads = Threads;

448static int const kWarpCount = kThreads / kWarpSize;

449

450static int const kElementsPerAccess = ElementsPerAccess;

451static int const kElementSize = ElementSize;

452

453//

454// Metaprogram computation

455//

456

457struct Detail {};

458

459//

460// Output

461//

462

463using Iterations = MmaCount;

464

465using Delta = layout::PitchLinearShape<kWarpSize * kElementsPerAccess, 1>;

466

468CUTLASS_HOST_DEVICE

469static layout::PitchLinearCoord initial_offset(int thread_idx) {

470int warp_idx = thread_idx / kWarpSize;

471int lane_idx = thread_idx % kWarpSize;

472

473// Compute warp location

474layout::PitchLinearCoord warp_footprint{

475 Delta::kContiguous * Iterations::kContiguous,

476 Delta::kStrided * Iterations::kStrided};

477

478layout::PitchLinearCoord warp_offset{warp_idx % WarpCount::kContiguous,

479 warp_idx / WarpCount::kContiguous};

480

481// Compute per-lane offset

482layout::PitchLinearCoord thread_offset_in_warp{

483 lane_idx * kElementsPerAccess, 0};

484

485layout::PitchLinearCoord thread_offset_in_threadblock_tile =

486 warp_footprint * warp_offset + thread_offset_in_warp;

487

488return thread_offset_in_threadblock_tile;

489 }

490 };

491

493

494 } // namespace threadblock

495 } // namespace epilogue

496 } // namespace cutlass

cutlass::layout::PitchLinearCoord::Index

int Index

Integer-valued index.

Definition: pitch_linear.h:56

cutlass::epilogue::threadblock::OutputTileThreadMap::ThreadMap

ThreadMap_ ThreadMap

Conventional thread map (concept: ThreadMap)

Definition: output_tile_thread_map.h:79

cutlass::epilogue::threadblock::OutputTileOptimalThreadMap

Definition: output_tile_thread_map.h:228

cutlass

Definition: aligned_buffer.h:35

cutlass::layout::PitchLinearCoord

Coordinate in pitch-linear space.

Definition: pitch_linear.h:52

tensor_ref.h

Defines a structure containing strides, bounds, and a pointer to tensor data.

cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::Count

Count_ Count

Definition: output_tile_thread_map.h:231

cutlass::epilogue::threadblock::OutputTileShape::kGroup

static int const kGroup

Definition: output_tile_thread_map.h:60

cutlass::epilogue::threadblock::OutputTileShape

Tuple defining point in output tile.

Definition: output_tile_thread_map.h:57

cutlass::epilogue::threadblock::InterleavedOutputTileThreadMap::WarpCount

WarpCount_ WarpCount

Definition: output_tile_thread_map.h:443

cutlass::epilogue::threadblock::OutputTileThreadMap::Iterations

Iterations_ Iterations

Iterations performed by each thread.

Definition: output_tile_thread_map.h:91

cutlass::epilogue::threadblock::OutputTileShape::kColumn

static int const kColumn

Definition: output_tile_thread_map.h:58

cutlass::epilogue::threadblock::detail::RowArrangement

RowArrangement determines how one or more warps cover a region of consecutive rows.

Definition: output_tile_thread_map.h:133

cutlass::epilogue::threadblock::InterleavedOutputTileThreadMap

Definition: output_tile_thread_map.h:442

cutlass::layout::PitchLinearShape

Template defining a shape used by pitch-linear operators.

Definition: pitch_linear.h:43

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::CompactedThreadMap

Compacted thread map in which the 4D region is contiguous.

Definition: output_tile_thread_map.h:369

cutlass::epilogue::threadblock::OutputTileThreadMap::Count

Count_ Count

Number of iterator iterations.

Definition: output_tile_thread_map.h:97

cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::CompactedThreadMap::initial_offset

static CUTLASS_HOST_DEVICE MatrixCoord initial_offset(int thread_idx)

Function to compute each thread's initial offset.

Definition: output_tile_thread_map.h:396

matrix_shape.h

Defines a Shape template for matrix tiles.

cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::Shape

Shape_ Shape

Definition: output_tile_thread_map.h:230

cutlass::epilogue::threadblock::InterleavedOutputTileThreadMap::initial_offset

static CUTLASS_HOST_DEVICE layout::PitchLinearCoord initial_offset(int thread_idx)

Initial offset function.

Definition: output_tile_thread_map.h:469

cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::Detail::RowArrangement

detail::RowArrangement< Shape, kWarpsRemainingForRows, kElementsPerAccess, kElementSize,(Shape::kRow > kWarpsRemainingForRows) > RowArrangement

Definition: output_tile_thread_map.h:303

cutlass::epilogue::threadblock::InterleavedOutputTileThreadMap::Iterations

MmaCount Iterations

Definition: output_tile_thread_map.h:463

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::layout::PitchLinearCoord::contiguous

CUTLASS_HOST_DEVICE Index const & contiguous() const

Returns the contiguous dimension.

Definition: pitch_linear.h:89

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::CompactedThreadMap::Shape

Shape_ Shape

Definition: output_tile_thread_map.h:372

cutlass::epilogue::threadblock::OutputTileThreadMap::Delta

Delta_ Delta

Delta between accesses.

Definition: output_tile_thread_map.h:94

cutlass::epilogue::threadblock::OutputTileThreadMap::initial_offset

static CUTLASS_HOST_DEVICE MatrixCoord initial_offset(int thread_idx)

Initial offset function.

Definition: output_tile_thread_map.h:101

cutlass::epilogue::threadblock::InterleavedOutputTileThreadMap::Detail

Definition: output_tile_thread_map.h:457

cutlass::epilogue::threadblock::OutputTileShape::kRow

static int const kRow

Definition: output_tile_thread_map.h:59

matrix.h

Defines layout functions used by TensorRef and derived classes.

fast_math.h

Math utilities.

cutlass::epilogue::threadblock::OutputTileThreadMap

Definition: output_tile_thread_map.h:76

cutlass::epilogue::threadblock::OutputTileThreadMap::Shape

Shape_ Shape

Shape of the tile.

Definition: output_tile_thread_map.h:88

cutlass::epilogue::threadblock::OutputTileShape::kTile

static int const kTile

Definition: output_tile_thread_map.h:62

cutlass::epilogue::threadblock::OutputTileShape::kCount

static int const kCount

Definition: output_tile_thread_map.h:64

cutlass::epilogue::threadblock::InterleavedOutputTileThreadMap::MmaCount

MmaCount_ MmaCount

Definition: output_tile_thread_map.h:444

cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::initial_offset

static CUTLASS_HOST_DEVICE MatrixCoord initial_offset(int thread_idx)

Initial offset function.

Definition: output_tile_thread_map.h:337

cutlass::const_min

CUTLASS_HOST_DEVICE constexpr int const_min(int a, int b)

Definition: fast_math.h:219

cutlass.h

Basic include for CUTLASS.

cutlass::MatrixCoord

Definition: matrix_coord.h:39

cutlass::epilogue::threadblock::OutputTileOptimalThreadMap::Detail

Definition: output_tile_thread_map.h:244

cutlass::epilogue::threadblock::OutputTileShape::kCluster

static int const kCluster

Definition: output_tile_thread_map.h:61

cutlass::layout::PitchLinearCoord::strided

CUTLASS_HOST_DEVICE Index const & strided() const

Returns the strided dimension.

Definition: pitch_linear.h:97


Generated by Doxygen 1.8.11