Back to Cutlass

CUTLASS: default_gemm_configuration.h Source File

docs/default__gemm__configuration_8h_source.html

4.4.250.3 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

default_gemm_configuration.h

[Go to the documentation of this file.](default__gemm__configuration_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include "cutlass/cutlass.h"

32 #include "cutlass/numeric_types.h"

33 #include "cutlass/arch/arch.h"

34 #include "cutlass/arch/mma.h"

35 #include "cutlass/arch/wmma.h"

36

37 #include "cutlass/gemm/gemm.h"

38 #include "cutlass/epilogue/thread/linear_combination.h"

39 #include "[cutlass/epilogue/thread/linear_combination_clamp.h](linear combination clamp_8h.html)"

40

42

43 namespace cutlass {

44 namespace gemm {

45 namespace device {

46

48

/// Primary template: maps (operator class, architecture tag, A/B/C element
/// types, accumulator type) to a bundle of default GEMM parameters.
///
/// It is declared but never defined in this header; usable defaults come only
/// from the partial specializations that follow. A combination with no
/// matching specialization names an incomplete type.
template <typename OperatorClass, typename ArchTag, typename ElementA,
          typename ElementB, typename ElementC, typename ElementAccumulator>
struct DefaultGemmConfiguration;

58

60

61 template <

62typename ArchTag,

63typename ElementA,

64typename ElementB,

65typename ElementC,

66typename ElementAccumulator>

67 struct DefaultGemmConfiguration<

68 arch::OpClassSimt,

69 ArchTag,

70 ElementA,

71 ElementB,

72 ElementC,

73 ElementAccumulator> {

74

75static int const kAlignmentA = 1;

76static int const kAlignmentB = 1;

77using ThreadblockShape = GemmShape<128, 128, 8>;

78using WarpShape = GemmShape<32, 64, 8>;

79using InstructionShape = GemmShape<1, 1, 1>;

80static int const kStages = 2;

81

82using EpilogueOutputOp = epilogue::thread::LinearCombination<

83 ElementC,

84 1,

85 ElementAccumulator,

86 ElementAccumulator

87 >;

88

89using Operator = arch::OpMultiplyAdd;

90 };

91

93

94 template <

95typename ArchTag,

96typename ElementC>

97 struct DefaultGemmConfiguration<arch::OpClassSimt, ArchTag, int8_t, int8_t, ElementC, int32_t> {

98

99static int const kAlignmentA = 4;

100static int const kAlignmentB = 4;

101using ThreadblockShape = GemmShape<128, 128, 32>;

102using WarpShape = GemmShape<32, 64, 32>;

103using InstructionShape = GemmShape<1, 1, 4>;

104static int const kStages = 2;

105

106using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<

107 ElementC,

108 1,

109 int32_t,

110float

111 >;

112

113using Operator = arch::OpMultiplyAdd;

114 };

115

117

118 template <

119typename ArchTag,

120typename ElementA,

121typename ElementB,

122typename ElementC,

123typename ElementAccumulator>

124 struct DefaultGemmConfiguration<

125 arch::OpClassWmmaTensorOp,

126 ArchTag,

127 ElementA,

128 ElementB,

129 ElementC,

130 ElementAccumulator> {

131

132static int const kAlignmentA = 128 / sizeof_bits<ElementA>::value;

133static int const kAlignmentB = 128 / sizeof_bits<ElementB>::value;

134

135static int const kStages = 2;

136

137using EpilogueOutputOp = epilogue::thread::LinearCombination<

138 ElementC,

139 128 / sizeof_bits<ElementC>::value,

140 ElementAccumulator,

141 ElementAccumulator

142 >;

143

144using Operator = arch::OpMultiplyAdd;

145 };

146

148

149 template <

150typename ElementA,

151typename ElementB,

152typename ElementC,

153typename ElementAccumulator>

154 struct DefaultGemmConfiguration<

155 arch::OpClassTensorOp,

156arch::Sm70,

157 ElementA,

158 ElementB,

159 ElementC,

160 ElementAccumulator> {

161

162static int const kAlignmentA = 128 / sizeof_bits<ElementA>::value;

163static int const kAlignmentB = 128 / sizeof_bits<ElementB>::value;

164

165using ThreadblockShape = GemmShape<128, 256, 32>;

166using WarpShape = GemmShape<64, 64, 32>;

167using InstructionShape = GemmShape<16, 16, 4>;

168static int const kStages = 2;

169

170using EpilogueOutputOp = epilogue::thread::LinearCombination<

171 ElementC,

172 128 / sizeof_bits<ElementC>::value,

173 ElementAccumulator,

174 ElementAccumulator

175 >;

176

177using Operator = arch::OpMultiplyAdd;

178 };

179

181

182 template <

183typename ElementA,

184typename ElementB,

185typename ElementC,

186typename ElementAccumulator>

187 struct DefaultGemmConfiguration<

188 arch::OpClassTensorOp,

189arch::Sm75,

190 ElementA,

191 ElementB,

192 ElementC,

193 ElementAccumulator> {

194

195static int const kAlignmentA = 128 / sizeof_bits<ElementA>::value;

196static int const kAlignmentB = 128 / sizeof_bits<ElementA>::value;

197using ThreadblockShape = GemmShape<128, 256, 32>;

198using WarpShape = GemmShape<64, 64, 32>;

199using InstructionShape = GemmShape<16, 8, 8>;

200static int const kStages = 2;

201

202using EpilogueOutputOp = epilogue::thread::LinearCombination<

203 ElementC,

204 128 / sizeof_bits<ElementC>::value,

205 ElementAccumulator,

206 ElementAccumulator

207 >;

208

209using Operator = typename platform::conditional<

210 (platform::is_same<ElementA, int8_t>::value ||

211platform::is_same<ElementA, int4b_t>::value ||

212platform::is_same<ElementA, uint8_t>::value ||

213platform::is_same<ElementA, uint4b_t>::value),

214 arch::OpMultiplyAddSaturate, arch::OpMultiplyAdd>::type;

215 };

216

218

219 template <

220typename ElementC>

221 struct DefaultGemmConfiguration<

222 arch::OpClassTensorOp,

223arch::Sm75,

224 int8_t,

225 int8_t,

226 ElementC,

227 int32_t> {

228

229static int const kAlignmentA = 128 / sizeof_bits<int8_t>::value;

230static int const kAlignmentB = 128 / sizeof_bits<int8_t>::value;

231

232using ThreadblockShape = GemmShape<128, 256, 64>;

233using WarpShape = GemmShape<64, 64, 64>;

234using InstructionShape = GemmShape<8, 8, 16>;

235static int const kStages = 2;

236

237using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<

238 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;

239

240using Operator = arch::OpMultiplyAddSaturate;

241 };

242

244

245 template <

246typename ElementC>

247 struct DefaultGemmConfiguration<

248 arch::OpClassTensorOp,

249arch::Sm75,

250 int8_t,

251 uint8_t,

252 ElementC,

253 int32_t> {

254

255static int const kAlignmentA = 128 / sizeof_bits<int8_t>::value;

256static int const kAlignmentB = 128 / sizeof_bits<uint8_t>::value;

257

258using ThreadblockShape = GemmShape<128, 256, 64>;

259using WarpShape = GemmShape<64, 64, 64>;

260using InstructionShape = GemmShape<8, 8, 16>;

261static int const kStages = 2;

262

263using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<

264 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;

265

266using Operator = arch::OpMultiplyAddSaturate;

267 };

268

270

271 template <

272typename ElementC>

273 struct DefaultGemmConfiguration<

274 arch::OpClassTensorOp,

275arch::Sm75,

276 uint8_t,

277 int8_t,

278 ElementC,

279 int32_t> {

280

281static int const kAlignmentA = 128 / sizeof_bits<uint8_t>::value;

282static int const kAlignmentB = 128 / sizeof_bits<int8_t>::value;

283

284using ThreadblockShape = GemmShape<128, 256, 64>;

285using WarpShape = GemmShape<64, 64, 64>;

286using InstructionShape = GemmShape<8, 8, 16>;

287static int const kStages = 2;

288

289using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<

290 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;

291

292using Operator = arch::OpMultiplyAddSaturate;

293 };

294

296

297 template <

298typename ElementC>

299 struct DefaultGemmConfiguration<

300 arch::OpClassTensorOp,

301arch::Sm75,

302 uint8_t,

303 uint8_t,

304 ElementC,

305 int32_t> {

306

307static int const kAlignmentA = 128 / sizeof_bits<uint8_t>::value;

308static int const kAlignmentB = 128 / sizeof_bits<uint8_t>::value;

309

310using ThreadblockShape = GemmShape<128, 256, 64>;

311using WarpShape = GemmShape<64, 64, 64>;

312using InstructionShape = GemmShape<8, 8, 16>;

313static int const kStages = 2;

314

315using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<

316 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;

317

318using Operator = arch::OpMultiplyAddSaturate;

319 };

320

322

323 template <

324typename ElementC>

325 struct DefaultGemmConfiguration<

326 arch::OpClassTensorOp,

327arch::Sm75,

328int4b_t,

329int4b_t,

330 ElementC,

331 int32_t> {

332

333static int const kAlignmentA = 128 / sizeof_bits<int4b_t>::value;

334static int const kAlignmentB = 128 / sizeof_bits<int4b_t>::value;

335

336using ThreadblockShape = GemmShape<128, 256, 128>;

337using WarpShape = GemmShape<64, 64, 128>;

338using InstructionShape = GemmShape<8, 8, 32>;

339static int const kStages = 2;

340

341using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<

342 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;

343

344using Operator = arch::OpMultiplyAddSaturate;

345 };

346

348

349 template <

350typename ElementC>

351 struct DefaultGemmConfiguration<

352 arch::OpClassTensorOp,

353arch::Sm75,

354int4b_t,

355uint4b_t,

356 ElementC,

357 int32_t> {

358

359static int const kAlignmentA = 128 / sizeof_bits<int4b_t>::value;

360static int const kAlignmentB = 128 / sizeof_bits<uint4b_t>::value;

361

362using ThreadblockShape = GemmShape<128, 256, 128>;

363using WarpShape = GemmShape<64, 64, 128>;

364using InstructionShape = GemmShape<8, 8, 32>;

365static int const kStages = 2;

366

367using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<

368 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;

369

370using Operator = arch::OpMultiplyAddSaturate;

371 };

372

374

375 template <

376typename ElementC>

377 struct DefaultGemmConfiguration<

378 arch::OpClassTensorOp,

379arch::Sm75,

380uint4b_t,

381int4b_t,

382 ElementC,

383 int32_t> {

384

385static int const kAlignmentA = 128 / sizeof_bits<uint4b_t>::value;

386static int const kAlignmentB = 128 / sizeof_bits<int4b_t>::value;

387

388using ThreadblockShape = GemmShape<128, 256, 128>;

389using WarpShape = GemmShape<64, 64, 128>;

390using InstructionShape = GemmShape<8, 8, 32>;

391static int const kStages = 2;

392

393using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<

394 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;

395

396using Operator = arch::OpMultiplyAddSaturate;

397 };

398

400

401 template <

402typename ElementC>

403 struct DefaultGemmConfiguration<

404 arch::OpClassTensorOp,

405arch::Sm75,

406uint4b_t,

407uint4b_t,

408 ElementC,

409 int32_t> {

410

411static int const kAlignmentA = 128 / sizeof_bits<uint4b_t>::value;

412static int const kAlignmentB = 128 / sizeof_bits<uint4b_t>::value;

413

414using ThreadblockShape = GemmShape<128, 256, 128>;

415using WarpShape = GemmShape<64, 64, 128>;

416using InstructionShape = GemmShape<8, 8, 32>;

417static int const kStages = 2;

418

419using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<

420 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;

421

422using Operator = arch::OpMultiplyAddSaturate;

423 };

424

426 } // namespace device

427 } // namespace gemm

428 } // namespace cutlass

429

cutlass

Definition: aligned_buffer.h:35

cutlass::gemm::device::DefaultGemmConfiguration< arch::OpClassTensorOp, arch::Sm75, int8_t, int8_t, ElementC, int32_t >::Operator

arch::OpMultiplyAddSaturate Operator

Definition: default_gemm_configuration.h:240

cutlass::epilogue::thread::LinearCombination

Definition: linear_combination.h:56

cutlass::platform::is_same

std::is_same (false specialization)

Definition: platform.h:394

cutlass::gemm::device::DefaultGemmConfiguration< arch::OpClassTensorOp, arch::Sm75, uint4b_t, uint4b_t, ElementC, int32_t >::Operator

arch::OpMultiplyAddSaturate Operator

Definition: default_gemm_configuration.h:422

cutlass::gemm::device::DefaultGemmConfiguration< arch::OpClassTensorOp, arch::Sm75, uint8_t, int8_t, ElementC, int32_t >::Operator

arch::OpMultiplyAddSaturate Operator

Definition: default_gemm_configuration.h:292

cutlass::epilogue::thread::LinearCombinationClamp

Definition: linear_combination_clamp.h:58

cutlass::integer_subbyte

4-bit signed integer type

Definition: integer_subbyte.h:42

[linear_combination_clamp.h](linear__combination__clamp_8h.html)

Functor performing linear scaling operations used by epilogues. Values are clamped before converting ...

cutlass::arch::Sm70

Definition: arch.h:46

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::gemm::device::DefaultGemmConfiguration< arch::OpClassTensorOp, arch::Sm75, ElementA, ElementB, ElementC, ElementAccumulator >::Operator

typename platform::conditional< (platform::is_same< ElementA, int8_t >::value||platform::is_same< ElementA, int4b_t >::value||platform::is_same< ElementA, uint8_t >::value||platform::is_same< ElementA, uint4b_t >::value), arch::OpMultiplyAddSaturate, arch::OpMultiplyAdd >::type Operator

Definition: default_gemm_configuration.h:214

cutlass::gemm::device::DefaultGemmConfiguration< arch::OpClassTensorOp, arch::Sm75, uint4b_t, int4b_t, ElementC, int32_t >::Operator

arch::OpMultiplyAddSaturate Operator

Definition: default_gemm_configuration.h:396

mma.h

Templates exposing architecture support for multiply-add operations.

cutlass::arch::Sm75

Definition: arch.h:52

cutlass::gemm::device::DefaultGemmConfiguration< arch::OpClassWmmaTensorOp, ArchTag, ElementA, ElementB, ElementC, ElementAccumulator >::Operator

arch::OpMultiplyAdd Operator

Definition: default_gemm_configuration.h:144

linear_combination.h

Functor performing linear combination operations used by epilogues.

cutlass::sizeof_bits

Defines the size of an element in bits.

Definition: numeric_types.h:42

cutlass::gemm::device::DefaultGemmConfiguration< arch::OpClassSimt, ArchTag, int8_t, int8_t, ElementC, int32_t >::Operator

arch::OpMultiplyAdd Operator

Definition: default_gemm_configuration.h:113

cutlass::gemm::device::DefaultGemmConfiguration< arch::OpClassTensorOp, arch::Sm75, int8_t, uint8_t, ElementC, int32_t >::Operator

arch::OpMultiplyAddSaturate Operator

Definition: default_gemm_configuration.h:266

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

cutlass::platform::conditional

std::conditional (true specialization)

Definition: platform.h:325

cutlass::gemm::device::DefaultGemmConfiguration< arch::OpClassTensorOp, arch::Sm70, ElementA, ElementB, ElementC, ElementAccumulator >::Operator

arch::OpMultiplyAdd Operator

Definition: default_gemm_configuration.h:177

cutlass::gemm::device::DefaultGemmConfiguration

Definition: default_gemm_configuration.h:57

arch.h

Defines tags for architecture-specific configurations.

wmma.h

Templates exposing architecture support for warp matrix multiply-add (WMMA) operations.

cutlass::gemm::device::DefaultGemmConfiguration< arch::OpClassTensorOp, arch::Sm75, uint8_t, uint8_t, ElementC, int32_t >::Operator

arch::OpMultiplyAddSaturate Operator

Definition: default_gemm_configuration.h:318

cutlass.h

Basic include for CUTLASS.

cutlass::gemm::device::DefaultGemmConfiguration< arch::OpClassTensorOp, arch::Sm75, int4b_t, int4b_t, ElementC, int32_t >::Operator

arch::OpMultiplyAddSaturate Operator

Definition: default_gemm_configuration.h:344

cutlass::gemm::device::DefaultGemmConfiguration< arch::OpClassTensorOp, arch::Sm75, int4b_t, uint4b_t, ElementC, int32_t >::Operator

arch::OpMultiplyAddSaturate Operator

Definition: default_gemm_configuration.h:370

cutlass::gemm::device::DefaultGemmConfiguration< arch::OpClassSimt, ArchTag, ElementA, ElementB, ElementC, ElementAccumulator >::Operator

arch::OpMultiplyAdd Operator

Definition: default_gemm_configuration.h:89


Generated by 1.8.11