Back to Cutlass

CUTLASS: threadblock_swizzle.h Source File

docs/gemm_2threadblock_2threadblock__swizzle_8h_source.html

4.4.241.9 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

gemm/threadblock/threadblock_swizzle.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

30 #pragma once

31

32 #include "cutlass/cutlass.h"

33

34 #include "cutlass/gemm/gemm.h"

35

37

38 namespace cutlass {

39 namespace gemm {

40 namespace threadblock {

41

43

45 CUTLASS_DEVICE

46 int RematerializeThreadIdxX() {

47return threadIdx.x;

48 }

49

51 CUTLASS_DEVICE

52 int RematerializeThreadIdxY() {

53return threadIdx.y;

54 }

55

57 CUTLASS_DEVICE

58 int RematerializeThreadIdxZ() {

59return threadIdx.z;

60 }

61

63 CUTLASS_DEVICE

64 int RematerializeBlockIdxX() {

65return blockIdx.x;

66 }

67

69 CUTLASS_DEVICE

70 int RematerializeBlockIdxY() {

71return blockIdx.y;

72 }

73

75 CUTLASS_DEVICE

76 int RematerializeBlockIdxZ() {

77return blockIdx.z;

78 }

79

81 CUTLASS_DEVICE

82 int RematerializeBlockDimX() {

83return blockDim.x;

84 }

85

87 CUTLASS_DEVICE

88 int RematerializeBlockDimY() {

89return blockDim.y;

90 }

91

93 CUTLASS_DEVICE

94 int RematerializeBlockDimZ() {

95return blockDim.z;

96 }

97

99

101 struct GemmIdentityThreadblockSwizzle {

102

103CUTLASS_HOST_DEVICE

104GemmIdentityThreadblockSwizzle() { }

105

106int const kTile = 1;

107

109CUTLASS_HOST_DEVICE

110GemmCoord get_tiled_shape(

111GemmCoord problem_size,

112GemmCoord tile_size,

113int split_k_slices) const {

114

115return GemmCoord(

116 (problem_size.m() + tile_size.m() - 1) / tile_size.m(),

117 (problem_size.n() + tile_size.n() - 1) / tile_size.n(),

118 split_k_slices);

119 }

120

122CUTLASS_HOST_DEVICE

123 dim3 get_grid_shape(GemmCoord tiled_shape) const {

124return dim3(tiled_shape.m() * kTile, (tiled_shape.n() + kTile - 1) / kTile, tiled_shape.k());

125 }

126

128 CUTLASS_DEVICE

129GemmCoord get_tile_offset() const {

130

131int block_idx_x = RematerializeBlockIdxX();

132int block_idx_y = RematerializeBlockIdxY();

133

134return GemmCoord{

135 (block_idx_x / kTile),

136 (block_idx_y * kTile) + (block_idx_x % kTile),

137RematerializeBlockIdxZ()

138 };

139 }

140 };

141

143

145 struct GemmHorizontalThreadblockSwizzle {

146

147CUTLASS_HOST_DEVICE

148GemmHorizontalThreadblockSwizzle() { }

149

151CUTLASS_HOST_DEVICE

152GemmCoord get_tiled_shape(

153GemmCoord problem_size,

154GemmCoord tile_size,

155int split_k_slices) const {

156

157return GemmCoord(

158 (problem_size.m() + tile_size.m() - 1) / tile_size.m(),

159 (problem_size.n() + tile_size.n() - 1) / tile_size.n(),

160 split_k_slices);

161 }

162

164CUTLASS_HOST_DEVICE

165 dim3 get_grid_shape(GemmCoord tiled_shape) const {

166return dim3(tiled_shape.n(), tiled_shape.m(), tiled_shape.k());

167 }

168

170 CUTLASS_DEVICE

171GemmCoord get_tile_offset() const {

172return GemmCoord{

173RematerializeBlockIdxY(),

174RematerializeBlockIdxX(),

175RematerializeBlockIdxZ()

176 };

177 }

178 };

179

181

183 struct GemmBatchedIdentityThreadblockSwizzle {

184

186CUTLASS_HOST_DEVICE

187GemmCoord get_tiled_shape(

188GemmCoord problem_size,

189int batch_count,

190GemmCoord tile_size) const {

191

192return GemmCoord(

193 (problem_size.m() + tile_size.m() - 1) / tile_size.m(),

194 (problem_size.n() + tile_size.n() - 1) / tile_size.n(),

195 batch_count % (1 << 16));

196 }

197

199CUTLASS_HOST_DEVICE

200 dim3 get_grid_shape(GemmCoord tiled_shape) const {

201return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k());

202 }

203

205 CUTLASS_DEVICE

206GemmCoord get_tile_offset() const {

207return GemmCoord{

208RematerializeBlockIdxX(),

209RematerializeBlockIdxY(),

210 0

211 };

212 }

213

215 CUTLASS_DEVICE

216int get_batch_idx() const {

217return RematerializeBlockIdxZ();

218 }

219 };

220

222

224 struct GemmSplitKIdentityThreadblockSwizzle {

225

227CUTLASS_HOST_DEVICE

228GemmCoord get_tiled_shape(

229GemmCoord problem_size,

230GemmCoord tile_size,

231int partitions) const {

232

233return GemmCoord(

234 (problem_size.m() + tile_size.m() - 1) / tile_size.m(),

235 (problem_size.n() + tile_size.n() - 1) / tile_size.n(),

236 partitions);

237 }

238

240CUTLASS_HOST_DEVICE

241 dim3 get_grid_shape(GemmCoord tiled_shape) const {

242return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k());

243 }

244

245

247 CUTLASS_DEVICE

248GemmCoord get_tile_offset() const {

249return GemmCoord{

250RematerializeBlockIdxX(),

251RematerializeBlockIdxY(),

252RematerializeBlockIdxZ()

253 };

254 }

255 };

256

258

260 struct GemmSplitKHorizontalThreadblockSwizzle {

261

263CUTLASS_HOST_DEVICE

264GemmCoord get_tiled_shape(

265GemmCoord problem_size,

266GemmCoord tile_size,

267int partitions) const {

268

269return GemmCoord(

270 (problem_size.m() + tile_size.m() - 1) / tile_size.m(),

271 (problem_size.n() + tile_size.n() - 1) / tile_size.n(),

272 partitions);

273 }

274

276CUTLASS_HOST_DEVICE

277 dim3 get_grid_shape(GemmCoord tiled_shape) const {

278return dim3(tiled_shape.n(), tiled_shape.m(), tiled_shape.k());

279 }

280

281

283 CUTLASS_DEVICE

284GemmCoord get_tile_offset() const {

285return GemmCoord{

286RematerializeBlockIdxY(),

287RematerializeBlockIdxX(),

288RematerializeBlockIdxZ()

289 };

290 }

291 };

292

294

296 struct GemvBatchedStridedThreadblockDefaultSwizzle {

297

299CUTLASS_HOST_DEVICE

300BatchedGemmCoord get_tiled_shape(

301BatchedGemmCoord problem_size,

302BatchedGemmCoord tile_size) const {

303

304return BatchedGemmCoord(

305 1, // M is always 1

306 (problem_size.n() + tile_size.n() - 1) / tile_size.n(),

307 (problem_size.k() + tile_size.k() - 1) / tile_size.k(),

308 (problem_size.batch() + tile_size.batch() - 1) / tile_size.batch());

309 }

310

312CUTLASS_HOST_DEVICE

313 dim3 get_grid_shape(BatchedGemmCoord tiled_shape) const {

314return dim3(tiled_shape.n(), tiled_shape.batch(), tiled_shape.k());

315 }

316

318 CUTLASS_DEVICE

319BatchedGemmCoord get_tile_offset() const {

320return BatchedGemmCoord{

321 0, // M is always 1

322RematerializeBlockIdxX(),

323RematerializeBlockIdxZ(),

324RematerializeBlockIdxY(),

325 };

326 }

327

329 CUTLASS_DEVICE

330int get_batch_tile_idx() const {

331return RematerializeBlockIdxY();

332 }

333

335 CUTLASS_DEVICE

336int get_batch_idx() const {

337return RematerializeBlockDimY()*RematerializeBlockIdxY() + RematerializeThreadIdxY();

338 }

339 };

340

342

343 } // namespace threadblock

344 } // namespace gemm

345 } // namespace cutlass

346

cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle::kTile

int const kTile

Definition: gemm/threadblock/threadblock_swizzle.h:106

cutlass

Definition: aligned_buffer.h:35

cutlass::gemm::threadblock::RematerializeThreadIdxY

CUTLASS_DEVICE int RematerializeThreadIdxY()

Helper to rematerialize thread Idx. Reduces register liveness.

Definition: gemm/threadblock/threadblock_swizzle.h:52

cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle::get_grid_shape

CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const

Computes CUDA grid dimensions given a size in units of logical tiles.

Definition: gemm/threadblock/threadblock_swizzle.h:200

cutlass::gemm::GemmCoord

Definition: include/cutlass/gemm/gemm.h:94

cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle::get_tile_offset

CUTLASS_DEVICE GemmCoord get_tile_offset() const

Obtains the threadblock offset (in units of threadblock-scoped tiles)

Definition: gemm/threadblock/threadblock_swizzle.h:206

cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle::get_grid_shape

CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const

Computes CUDA grid dimensions given a size in units of logical tiles.

Definition: gemm/threadblock/threadblock_swizzle.h:241

cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle::get_tiled_shape

CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int partitions) const

Returns the shape of the problem in units of logical tiles.

Definition: gemm/threadblock/threadblock_swizzle.h:264

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::gemm::GemmCoord::n

CUTLASS_HOST_DEVICE Index const & n() const

Returns the GEMM N coordinate.

Definition: include/cutlass/gemm/gemm.h:137

cutlass::gemm::threadblock::RematerializeBlockDimX

CUTLASS_DEVICE int RematerializeBlockDimX()

Helper to rematerialize block Dim. Reduces register liveness.

Definition: gemm/threadblock/threadblock_swizzle.h:82

cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle::get_grid_shape

CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const

Computes CUDA grid dimensions given a size in units of logical tiles.

Definition: gemm/threadblock/threadblock_swizzle.h:123

cutlass::gemm::GemmCoord::k

CUTLASS_HOST_DEVICE Index const & k() const

Returns the GEMM K coordinate.

Definition: include/cutlass/gemm/gemm.h:145

cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle::get_tiled_shape

CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int split_k_slices) const

Returns the shape of the problem in units of logical tiles.

Definition: gemm/threadblock/threadblock_swizzle.h:152

cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle::get_grid_shape

CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const

Computes CUDA grid dimensions given a size in units of logical tiles.

Definition: gemm/threadblock/threadblock_swizzle.h:165

cutlass::gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle::get_grid_shape

CUTLASS_HOST_DEVICE dim3 get_grid_shape(BatchedGemmCoord tiled_shape) const

Computes CUDA grid dimensions given a size in units of logical tiles.

Definition: gemm/threadblock/threadblock_swizzle.h:313

cutlass::gemm::threadblock::RematerializeThreadIdxX

CUTLASS_DEVICE int RematerializeThreadIdxX()

Helper to rematerialize thread Idx. Reduces register liveness.

Definition: gemm/threadblock/threadblock_swizzle.h:46

cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle

Threadblock swizzling function for GEMMs.

Definition: gemm/threadblock/threadblock_swizzle.h:145

cutlass::gemm::BatchedGemmCoord

Definition: include/cutlass/gemm/gemm.h:260

cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle::GemmHorizontalThreadblockSwizzle

CUTLASS_HOST_DEVICE GemmHorizontalThreadblockSwizzle()

Definition: gemm/threadblock/threadblock_swizzle.h:148

cutlass::gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle::get_batch_tile_idx

CUTLASS_DEVICE int get_batch_tile_idx() const

Gets the batch tile index.

Definition: gemm/threadblock/threadblock_swizzle.h:330

cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle::get_grid_shape

CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const

Computes CUDA grid dimensions given a size in units of logical tiles.

Definition: gemm/threadblock/threadblock_swizzle.h:277

cutlass::gemm::threadblock::RematerializeThreadIdxZ

CUTLASS_DEVICE int RematerializeThreadIdxZ()

Helper to rematerialize thread Idx. Reduces register liveness.

Definition: gemm/threadblock/threadblock_swizzle.h:58

cutlass::gemm::BatchedGemmCoord::batch

CUTLASS_HOST_DEVICE Index const & batch() const

Returns the GEMM batch coordinate.

Definition: include/cutlass/gemm/gemm.h:322

cutlass::gemm::threadblock::RematerializeBlockDimY

CUTLASS_DEVICE int RematerializeBlockDimY()

Helper to rematerialize block Dim. Reduces register liveness.

Definition: gemm/threadblock/threadblock_swizzle.h:88

cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle::GemmIdentityThreadblockSwizzle

CUTLASS_HOST_DEVICE GemmIdentityThreadblockSwizzle()

Definition: gemm/threadblock/threadblock_swizzle.h:104

cutlass::gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle::get_batch_idx

CUTLASS_DEVICE int get_batch_idx() const

Gets the absolute batch index.

Definition: gemm/threadblock/threadblock_swizzle.h:336

cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle::get_tiled_shape

CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int split_k_slices) const

Returns the shape of the problem in units of logical tiles.

Definition: gemm/threadblock/threadblock_swizzle.h:110

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

cutlass::gemm::BatchedGemmCoord::k

CUTLASS_HOST_DEVICE Index const & k() const

Returns the GEMM K coordinate.

Definition: include/cutlass/gemm/gemm.h:314

cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle::get_tile_offset

CUTLASS_DEVICE GemmCoord get_tile_offset() const

Obtains the threadblock offset (in units of threadblock-scoped tiles)

Definition: gemm/threadblock/threadblock_swizzle.h:129

cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle

Threadblock swizzling function for split-K GEMMs.

Definition: gemm/threadblock/threadblock_swizzle.h:260

cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle::get_tiled_shape

CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, int batch_count, GemmCoord tile_size) const

Returns the shape of the problem in units of logical tiles.

Definition: gemm/threadblock/threadblock_swizzle.h:187

cutlass::gemm::threadblock::RematerializeBlockIdxY

CUTLASS_DEVICE int RematerializeBlockIdxY()

Helper to rematerialize block Idx. Reduces register liveness.

Definition: gemm/threadblock/threadblock_swizzle.h:70

cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle::get_tiled_shape

CUTLASS_HOST_DEVICE GemmCoord get_tiled_shape(GemmCoord problem_size, GemmCoord tile_size, int partitions) const

Returns the shape of the problem in units of logical tiles.

Definition: gemm/threadblock/threadblock_swizzle.h:228

cutlass::gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle::get_tiled_shape

CUTLASS_HOST_DEVICE BatchedGemmCoord get_tiled_shape(BatchedGemmCoord problem_size, BatchedGemmCoord tile_size) const

Returns the shape of the problem in units of logical tiles.

Definition: gemm/threadblock/threadblock_swizzle.h:300

cutlass::gemm::threadblock::RematerializeBlockDimZ

CUTLASS_DEVICE int RematerializeBlockDimZ()

Helper to rematerialize block Dim. Reduces register liveness.

Definition: gemm/threadblock/threadblock_swizzle.h:94

cutlass::gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle::get_tile_offset

CUTLASS_DEVICE BatchedGemmCoord get_tile_offset() const

Obtains the threadblock offset (in units of threadblock-scoped tiles)

Definition: gemm/threadblock/threadblock_swizzle.h:319

cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle

Threadblock swizzling function for GEMMs.

Definition: gemm/threadblock/threadblock_swizzle.h:101

cutlass::gemm::threadblock::GemvBatchedStridedThreadblockDefaultSwizzle

Threadblock swizzling function for batched GEMVs.

Definition: gemm/threadblock/threadblock_swizzle.h:296

cutlass::gemm::BatchedGemmCoord::n

CUTLASS_HOST_DEVICE Index const & n() const

Returns the GEMM N coordinate.

Definition: include/cutlass/gemm/gemm.h:306

cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle::get_batch_idx

CUTLASS_DEVICE int get_batch_idx() const

Gets the batch index.

Definition: gemm/threadblock/threadblock_swizzle.h:216

cutlass::gemm::GemmCoord::m

CUTLASS_HOST_DEVICE Index const & m() const

Returns the GEMM M coordinate.

Definition: include/cutlass/gemm/gemm.h:129

cutlass::gemm::threadblock::RematerializeBlockIdxZ

CUTLASS_DEVICE int RematerializeBlockIdxZ()

Helper to rematerialize block Idx. Reduces register liveness.

Definition: gemm/threadblock/threadblock_swizzle.h:76

cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle

Threadblock swizzling function for batched GEMMs.

Definition: gemm/threadblock/threadblock_swizzle.h:183

cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle::get_tile_offset

CUTLASS_DEVICE GemmCoord get_tile_offset() const

Obtains the threadblock offset (in units of threadblock-scoped tiles)

Definition: gemm/threadblock/threadblock_swizzle.h:284

cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle::get_tile_offset

CUTLASS_DEVICE GemmCoord get_tile_offset() const

Obtains the threadblock offset (in units of threadblock-scoped tiles)

Definition: gemm/threadblock/threadblock_swizzle.h:171

cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle

Threadblock swizzling function for split-K GEMMs.

Definition: gemm/threadblock/threadblock_swizzle.h:224

cutlass.h

Basic include for CUTLASS.

cutlass::gemm::threadblock::RematerializeBlockIdxX

CUTLASS_DEVICE int RematerializeBlockIdxX()

Helper to rematerialize block Idx. Reduces register liveness.

Definition: gemm/threadblock/threadblock_swizzle.h:64

cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle::get_tile_offset

CUTLASS_DEVICE GemmCoord get_tile_offset() const

Obtains the threadblock offset (in units of threadblock-scoped tiles)

Definition: gemm/threadblock/threadblock_swizzle.h:248


Generated by Doxygen 1.8.11