Back to Cutlass

CUTLASS: mma_sm70.h Source File

docs/mma__sm70_8h_source.html

4.4.273.0 KB
Original Source

CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers

mma_sm70.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

28 #pragma once

29

30 #include <assert.h>

31

32 #include "mma.h"

33 #include "cutlass/layout/matrix.h"

34 #include "cutlass/numeric_types.h"

35

// CUTLASS_ARCH_MMA_SM70_SUPPORTED is defined when the CUDA compiler version is 10.1 or newer.
36 #if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))

37 #define CUTLASS_ARCH_MMA_SM70_SUPPORTED

38 #endif

39

// CUTLASS_ARCH_MMA_SM70_ENABLED is defined when __CUDA_ARCH__ is defined and >= 700 and the
// same CUDA 10.1+ compiler check passes. The operator() bodies below test this macro to decide
// between emitting the mma.sync instruction and falling back to assert(0).
40 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))

41

42 #if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 &&__CUDACC_VER_MINOR__ >= 1))

43 #define CUTLASS_ARCH_MMA_SM70_ENABLED

44 #endif

45

46 #endif

47

49

50 namespace cutlass {

51 namespace arch {

52

54 //

55 // Matrix multiply accumulate 884 - FP16 accumulation

56 //

58

/// Mma specialization for an 8x8x4 multiply-add (OpMultiplyAdd) on half_t operands with half_t
/// accumulators: A is column-major, B is column-major, C/D are row-major.
/// NOTE(review): the second template argument (8) is presumably the number of cooperating
/// threads -- confirm against the primary template in arch/mma.h.
60 template <>

61 struct Mma<

62 gemm::GemmShape<8,8,4>,

63 8,

64half_t,

65layout::ColumnMajor,

66half_t,

67layout::ColumnMajor,

68half_t,

69layout::RowMajor,

70 OpMultiplyAdd> {

71

72using Shape = gemm::GemmShape<8, 8, 4>;

73

74using ElementA = half_t;

75using LayoutA = layout::ColumnMajor;

76using FragmentA = Array<half_t, 4>;

77

78using ElementB = half_t;

79using LayoutB = layout::ColumnMajor;

80using FragmentB = Array<half_t, 4>;

81

82using ElementC = half_t;

83using LayoutC = layout::RowMajor;

84using FragmentC = Array<half_t, 8>;

85

86using Operator = OpMultiplyAdd;

87

/// Multiply-add: issues one PTX mma.sync.aligned.m8n8k4 (.col.col, f16 accumulate) instruction.
/// @param d  destination accumulator fragment (8 x half_t), written by the instruction
/// @param a  A operand fragment (4 x half_t)
/// @param b  B operand fragment (4 x half_t)
/// @param c  source accumulator fragment (8 x half_t)
88CUTLASS_HOST_DEVICE

89void operator()(

90FragmentC &d,

91FragmentA const &a,

92FragmentB const &b,

93FragmentC const &c

94 ) {

95

96 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

97

// Reinterpret the fragments as arrays of unsigned so they can be bound to "r" register operands.
// NOTE(review): assumes two half_t values pack into one unsigned word -- confirm.
98unsigned const *A = reinterpret_cast<unsigned const *>(&a);

99unsigned const *B = reinterpret_cast<unsigned const *>(&b);

100unsigned const *C = reinterpret_cast<unsigned const *>(&c);

101unsigned *D = reinterpret_cast<unsigned *>(&d);

102

// Operand map: %0-%3 = D (outputs), %4-%5 = A, %6-%7 = B, %8-%11 = C.
103asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"

104 : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])

105 : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])

106 );

107

// CUTLASS_ARCH_MMA_SM70_ENABLED is not defined: no instruction is emitted, only assert(0).
108 #else

109 assert(0);

110 #endif

111 }

112 };

113

/// Mma specialization for an 8x8x4 multiply-add (OpMultiplyAdd) on half_t operands with half_t
/// accumulators: A is column-major, B is row-major, C/D are row-major.
/// NOTE(review): the second template argument (8) is presumably the number of cooperating
/// threads -- confirm against the primary template in arch/mma.h.
115 template <>

116 struct Mma<

117 gemm::GemmShape<8, 8, 4>,

118 8,

119half_t,

120layout::ColumnMajor,

121half_t,

122layout::RowMajor,

123half_t,

124layout::RowMajor,

125 OpMultiplyAdd> {

126

127using Shape = gemm::GemmShape<8, 8, 4>;

128

129using ElementA = half_t;

130using LayoutA = layout::ColumnMajor;

131using FragmentA = Array<half_t, 4>;

132

133using ElementB = half_t;

134using LayoutB = layout::RowMajor;

135using FragmentB = Array<half_t, 4>;

136

137using ElementC = half_t;

138using LayoutC = layout::RowMajor;

139using FragmentC = Array<half_t, 8>;

140

141using Operator = OpMultiplyAdd;

142

/// Multiply-add: issues one PTX mma.sync.aligned.m8n8k4 (.col.row, f16 accumulate) instruction.
/// @param d  destination accumulator fragment (8 x half_t), written by the instruction
/// @param a  A operand fragment (4 x half_t)
/// @param b  B operand fragment (4 x half_t)
/// @param c  source accumulator fragment (8 x half_t)
143CUTLASS_HOST_DEVICE

144void operator()(

145FragmentC &d,

146FragmentA const &a,

147FragmentB const &b,

148FragmentC const &c

149 ) {

150

151 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

152

// Reinterpret the fragments as arrays of unsigned so they can be bound to "r" register operands.
// NOTE(review): assumes two half_t values pack into one unsigned word -- confirm.
153unsigned const *A = reinterpret_cast<unsigned const *>(&a);

154unsigned const *B = reinterpret_cast<unsigned const *>(&b);

155unsigned const *C = reinterpret_cast<unsigned const *>(&c);

156unsigned *D = reinterpret_cast<unsigned *>(&d);

157

// Operand map: %0-%3 = D (outputs), %4-%5 = A, %6-%7 = B, %8-%11 = C.
158asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"

159 : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])

160 : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])

161 );

162

// CUTLASS_ARCH_MMA_SM70_ENABLED is not defined: no instruction is emitted, only assert(0).
163 #else

164 assert(0);

165 #endif

166 }

167 };

168

/// Mma specialization for an 8x8x4 multiply-add (OpMultiplyAdd) on half_t operands with half_t
/// accumulators: A is row-major, B is column-major, C/D are row-major.
/// NOTE(review): the second template argument (8) is presumably the number of cooperating
/// threads -- confirm against the primary template in arch/mma.h.
170 template <>

171 struct Mma<

172 gemm::GemmShape<8, 8, 4>,

173 8,

174half_t,

175layout::RowMajor,

176half_t,

177layout::ColumnMajor,

178half_t,

179layout::RowMajor,

180 OpMultiplyAdd> {

181

182using Shape = gemm::GemmShape<8, 8, 4>;

183

184using ElementA = half_t;

185using LayoutA = layout::RowMajor;

186using FragmentA = Array<half_t, 4>;

187

188using ElementB = half_t;

189using LayoutB = layout::ColumnMajor;

190using FragmentB = Array<half_t, 4>;

191

192using ElementC = half_t;

193using LayoutC = layout::RowMajor;

194using FragmentC = Array<half_t, 8>;

195

196using Operator = OpMultiplyAdd;

197

/// Multiply-add: issues one PTX mma.sync.aligned.m8n8k4 (.row.col, f16 accumulate) instruction.
/// @param d  destination accumulator fragment (8 x half_t), written by the instruction
/// @param a  A operand fragment (4 x half_t)
/// @param b  B operand fragment (4 x half_t)
/// @param c  source accumulator fragment (8 x half_t)
198CUTLASS_HOST_DEVICE

199void operator()(

200FragmentC &d,

201FragmentA const &a,

202FragmentB const &b,

203FragmentC const &c

204 ) {

205

206 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

207

// Reinterpret the fragments as arrays of unsigned so they can be bound to "r" register operands.
// NOTE(review): assumes two half_t values pack into one unsigned word -- confirm.
208unsigned const *A = reinterpret_cast<unsigned const *>(&a);

209unsigned const *B = reinterpret_cast<unsigned const *>(&b);

210unsigned const *C = reinterpret_cast<unsigned const *>(&c);

211unsigned *D = reinterpret_cast<unsigned *>(&d);

212

// Operand map: %0-%3 = D (outputs), %4-%5 = A, %6-%7 = B, %8-%11 = C.
213asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"

214 : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])

215 : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])

216 );

217

// CUTLASS_ARCH_MMA_SM70_ENABLED is not defined: no instruction is emitted, only assert(0).
218 #else

219 assert(0);

220 #endif

221 }

222 };

223

/// Mma specialization for an 8x8x4 multiply-add (OpMultiplyAdd) on half_t operands with half_t
/// accumulators: A is row-major, B is row-major, C/D are row-major.
/// NOTE(review): the second template argument (8) is presumably the number of cooperating
/// threads -- confirm against the primary template in arch/mma.h.
225 template <>

226 struct Mma<

227 gemm::GemmShape<8, 8, 4>,

228 8,

229half_t,

230layout::RowMajor,

231half_t,

232layout::RowMajor,

233half_t,

234layout::RowMajor,

235 OpMultiplyAdd> {

236

237using Shape = gemm::GemmShape<8, 8, 4>;

238

239using ElementA = half_t;

240using LayoutA = layout::RowMajor;

241using FragmentA = Array<half_t, 4>;

242

243using ElementB = half_t;

244using LayoutB = layout::RowMajor;

245using FragmentB = Array<half_t, 4>;

246

247using ElementC = half_t;

248using LayoutC = layout::RowMajor;

249using FragmentC = Array<half_t, 8>;

250

251using Operator = OpMultiplyAdd;

252

/// Multiply-add: issues one PTX mma.sync.aligned.m8n8k4 (.row.row, f16 accumulate) instruction.
/// @param d  destination accumulator fragment (8 x half_t), written by the instruction
/// @param a  A operand fragment (4 x half_t)
/// @param b  B operand fragment (4 x half_t)
/// @param c  source accumulator fragment (8 x half_t)
253CUTLASS_HOST_DEVICE

254void operator()(

255FragmentC &d,

256FragmentA const &a,

257FragmentB const &b,

258FragmentC const &c

259 ) {

260

261 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

262

// Reinterpret the fragments as arrays of unsigned so they can be bound to "r" register operands.
// NOTE(review): assumes two half_t values pack into one unsigned word -- confirm.
263unsigned const *A = reinterpret_cast<unsigned const *>(&a);

264unsigned const *B = reinterpret_cast<unsigned const *>(&b);

265unsigned const *C = reinterpret_cast<unsigned const *>(&c);

266unsigned *D = reinterpret_cast<unsigned *>(&d);

267

// Operand map: %0-%3 = D (outputs), %4-%5 = A, %6-%7 = B, %8-%11 = C.
268asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"

269 : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])

270 : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])

271 );

272

// CUTLASS_ARCH_MMA_SM70_ENABLED is not defined: no instruction is emitted, only assert(0).
273 #else

274 assert(0);

275 #endif

276 }

277 };

278

280 //

281 // Matrix multiply accumulate 884 - FP32 accumulation

282 //

284

/// Mma specialization for an 8x8x4 multiply-add (OpMultiplyAdd) on half_t operands with float
/// accumulators: A is column-major, B is column-major, C/D are row-major.
/// NOTE(review): the second template argument (8) is presumably the number of cooperating
/// threads -- confirm against the primary template in arch/mma.h.
286 template <>

287 struct Mma<

288 gemm::GemmShape<8, 8, 4>,

289 8,

290half_t,

291layout::ColumnMajor,

292half_t,

293layout::ColumnMajor,

294 float,

295layout::RowMajor,

296 OpMultiplyAdd> {

297

298using Shape = gemm::GemmShape<8, 8, 4>;

299

300using ElementA = half_t;

301using LayoutA = layout::ColumnMajor;

302using FragmentA = Array<half_t, 4>;

303

304using ElementB = half_t;

305using LayoutB = layout::ColumnMajor;

306using FragmentB = Array<half_t, 4>;

307

308using ElementC = float;

309using LayoutC = layout::RowMajor;

310using FragmentC = Array<float, 8>;

311

312using Operator = OpMultiplyAdd;

313

/// Multiply-add: issues one PTX mma.sync.aligned.m8n8k4 (.col.col, f32 accumulate) instruction.
/// @param d  destination accumulator fragment (8 x float), written by the instruction
/// @param a  A operand fragment (4 x half_t)
/// @param b  B operand fragment (4 x half_t)
/// @param c  source accumulator fragment (8 x float)
315CUTLASS_HOST_DEVICE

316void operator()(

317FragmentC &d,

318FragmentA const &a,

319FragmentB const &b,

320FragmentC const &c

321 ) {

322

323 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

324

// A and B are viewed as unsigned words for "r" operands; C and D are viewed as float for "f" operands.
// NOTE(review): assumes two half_t values pack into one unsigned word -- confirm.
325unsigned const *A = reinterpret_cast<unsigned const *>(&a);

326unsigned const *B = reinterpret_cast<unsigned const *>(&b);

327float const *C = reinterpret_cast<float const *>(&c);

328float *D = reinterpret_cast<float *>(&d);

329

// Operand map: %0-%7 = D (outputs), %8-%9 = A, %10-%11 = B, %12-%19 = C.
330asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "

331"{%12,%13,%14,%15,%16,%17,%18,%19};\n"

332 : "=f"(D[0]),

333"=f"(D[1]),

334"=f"(D[2]),

335"=f"(D[3]),

336"=f"(D[4]),

337"=f"(D[5]),

338"=f"(D[6]),

339"=f"(D[7])

340 : "r"(A[0]),

341"r"(A[1]),

342"r"(B[0]),

343"r"(B[1]),

344"f"(C[0]),

345"f"(C[1]),

346"f"(C[2]),

347"f"(C[3]),

348"f"(C[4]),

349"f"(C[5]),

350"f"(C[6]),

351"f"(C[7])

352 );

353

// CUTLASS_ARCH_MMA_SM70_ENABLED is not defined: no instruction is emitted, only assert(0).
354 #else

355 assert(0);

356 #endif

357 }

358 };

359

/// Mma specialization for an 8x8x4 multiply-add (OpMultiplyAdd) on half_t operands with float
/// accumulators: A is column-major, B is row-major, C/D are row-major.
/// NOTE(review): the second template argument (8) is presumably the number of cooperating
/// threads -- confirm against the primary template in arch/mma.h.
361 template <>

362 struct Mma<

363 gemm::GemmShape<8, 8, 4>,

364 8,

365half_t,

366layout::ColumnMajor,

367half_t,

368layout::RowMajor,

369 float,

370layout::RowMajor,

371 OpMultiplyAdd> {

372

373using Shape = gemm::GemmShape<8, 8, 4>;

374

375using ElementA = half_t;

376using LayoutA = layout::ColumnMajor;

377using FragmentA = Array<half_t, 4>;

378

379using ElementB = half_t;

380using LayoutB = layout::RowMajor;

381using FragmentB = Array<half_t, 4>;

382

383using ElementC = float;

384using LayoutC = layout::RowMajor;

385using FragmentC = Array<float, 8>;

386

387using Operator = OpMultiplyAdd;

388

/// Multiply-add: issues one PTX mma.sync.aligned.m8n8k4 (.col.row, f32 accumulate) instruction.
/// @param d  destination accumulator fragment (8 x float), written by the instruction
/// @param a  A operand fragment (4 x half_t)
/// @param b  B operand fragment (4 x half_t)
/// @param c  source accumulator fragment (8 x float)
390CUTLASS_HOST_DEVICE

391void operator()(

392FragmentC &d,

393FragmentA const &a,

394FragmentB const &b,

395FragmentC const &c

396 ) {

397

398 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

399

// A and B are viewed as unsigned words for "r" operands; C and D are viewed as float for "f" operands.
// NOTE(review): assumes two half_t values pack into one unsigned word -- confirm.
400unsigned const *A = reinterpret_cast<unsigned const *>(&a);

401unsigned const *B = reinterpret_cast<unsigned const *>(&b);

402float const *C = reinterpret_cast<float const *>(&c);

403float *D = reinterpret_cast<float *>(&d);

404

// Operand map: %0-%7 = D (outputs), %8-%9 = A, %10-%11 = B, %12-%19 = C.
405asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "

406"{%12,%13,%14,%15,%16,%17,%18,%19};\n"

407 : "=f"(D[0]),

408"=f"(D[1]),

409"=f"(D[2]),

410"=f"(D[3]),

411"=f"(D[4]),

412"=f"(D[5]),

413"=f"(D[6]),

414"=f"(D[7])

415 : "r"(A[0]),

416"r"(A[1]),

417"r"(B[0]),

418"r"(B[1]),

419"f"(C[0]),

420"f"(C[1]),

421"f"(C[2]),

422"f"(C[3]),

423"f"(C[4]),

424"f"(C[5]),

425"f"(C[6]),

426"f"(C[7])

427 );

428

// CUTLASS_ARCH_MMA_SM70_ENABLED is not defined: no instruction is emitted, only assert(0).
429 #else

430 assert(0);

431 #endif

432 }

433 };

434

/// Mma specialization for an 8x8x4 multiply-add (OpMultiplyAdd) on half_t operands with float
/// accumulators: A is row-major, B is column-major, C/D are row-major.
/// NOTE(review): the second template argument (8) is presumably the number of cooperating
/// threads -- confirm against the primary template in arch/mma.h.
436 template <>

437 struct Mma<

438 gemm::GemmShape<8, 8, 4>,

439 8,

440half_t,

441layout::RowMajor,

442half_t,

443layout::ColumnMajor,

444 float,

445layout::RowMajor,

446 OpMultiplyAdd> {

447

448using Shape = gemm::GemmShape<8, 8, 4>;

449

450using ElementA = half_t;

451using LayoutA = layout::RowMajor;

452using FragmentA = Array<half_t, 4>;

453

454using ElementB = half_t;

455using LayoutB = layout::ColumnMajor;

456using FragmentB = Array<half_t, 4>;

457

458using ElementC = float;

459using LayoutC = layout::RowMajor;

460using FragmentC = Array<float, 8>;

461

462using Operator = OpMultiplyAdd;

463

/// Multiply-add: issues one PTX mma.sync.aligned.m8n8k4 (.row.col, f32 accumulate) instruction.
/// @param d  destination accumulator fragment (8 x float), written by the instruction
/// @param a  A operand fragment (4 x half_t)
/// @param b  B operand fragment (4 x half_t)
/// @param c  source accumulator fragment (8 x float)
465CUTLASS_HOST_DEVICE

466void operator()(

467FragmentC &d,

468FragmentA const &a,

469FragmentB const &b,

470FragmentC const &c

471 ) {

472

473 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

474

// A and B are viewed as unsigned words for "r" operands; C and D are viewed as float for "f" operands.
// NOTE(review): assumes two half_t values pack into one unsigned word -- confirm.
475unsigned const *A = reinterpret_cast<unsigned const *>(&a);

476unsigned const *B = reinterpret_cast<unsigned const *>(&b);

477float const *C = reinterpret_cast<float const *>(&c);

478float *D = reinterpret_cast<float *>(&d);

479

// Operand map: %0-%7 = D (outputs), %8-%9 = A, %10-%11 = B, %12-%19 = C.
480asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "

481"{%12,%13,%14,%15,%16,%17,%18,%19};\n"

482 : "=f"(D[0]),

483"=f"(D[1]),

484"=f"(D[2]),

485"=f"(D[3]),

486"=f"(D[4]),

487"=f"(D[5]),

488"=f"(D[6]),

489"=f"(D[7])

490 : "r"(A[0]),

491"r"(A[1]),

492"r"(B[0]),

493"r"(B[1]),

494"f"(C[0]),

495"f"(C[1]),

496"f"(C[2]),

497"f"(C[3]),

498"f"(C[4]),

499"f"(C[5]),

500"f"(C[6]),

501"f"(C[7])

502 );

503

// CUTLASS_ARCH_MMA_SM70_ENABLED is not defined: no instruction is emitted, only assert(0).
504 #else

505 assert(0);

506 #endif

507 }

508 };

509

/// Mma specialization for an 8x8x4 multiply-add (OpMultiplyAdd) on half_t operands with float
/// accumulators: A is row-major, B is row-major, C/D are row-major.
/// NOTE(review): the second template argument (8) is presumably the number of cooperating
/// threads -- confirm against the primary template in arch/mma.h.
511 template <>

512 struct Mma<

513 gemm::GemmShape<8, 8, 4>,

514 8,

515half_t,

516layout::RowMajor,

517half_t,

518layout::RowMajor,

519 float,

520layout::RowMajor,

521 OpMultiplyAdd> {

522

523using Shape = gemm::GemmShape<8, 8, 4>;

524

525using ElementA = half_t;

526using LayoutA = layout::RowMajor;

527using FragmentA = Array<half_t, 4>;

528

529using ElementB = half_t;

530using LayoutB = layout::RowMajor;

531using FragmentB = Array<half_t, 4>;

532

533using ElementC = float;

534using LayoutC = layout::RowMajor;

535using FragmentC = Array<float, 8>;

536

537using Operator = OpMultiplyAdd;

538

/// Multiply-add: issues one PTX mma.sync.aligned.m8n8k4 (.row.row, f32 accumulate) instruction.
/// @param d  destination accumulator fragment (8 x float), written by the instruction
/// @param a  A operand fragment (4 x half_t)
/// @param b  B operand fragment (4 x half_t)
/// @param c  source accumulator fragment (8 x float)
540CUTLASS_HOST_DEVICE

541void operator()(

542FragmentC &d,

543FragmentA const &a,

544FragmentB const &b,

545FragmentC const &c

546 ) {

547

548 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

549

// A and B are viewed as unsigned words for "r" operands; C and D are viewed as float for "f" operands.
// NOTE(review): assumes two half_t values pack into one unsigned word -- confirm.
550unsigned const *A = reinterpret_cast<unsigned const *>(&a);

551unsigned const *B = reinterpret_cast<unsigned const *>(&b);

552float const *C = reinterpret_cast<float const *>(&c);

553float *D = reinterpret_cast<float *>(&d);

554

// Operand map: %0-%7 = D (outputs), %8-%9 = A, %10-%11 = B, %12-%19 = C.
555asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "

556"{%12,%13,%14,%15,%16,%17,%18,%19};\n"

557 : "=f"(D[0]),

558"=f"(D[1]),

559"=f"(D[2]),

560"=f"(D[3]),

561"=f"(D[4]),

562"=f"(D[5]),

563"=f"(D[6]),

564"=f"(D[7])

565 : "r"(A[0]),

566"r"(A[1]),

567"r"(B[0]),

568"r"(B[1]),

569"f"(C[0]),

570"f"(C[1]),

571"f"(C[2]),

572"f"(C[3]),

573"f"(C[4]),

574"f"(C[5]),

575"f"(C[6]),

576"f"(C[7])

577 );

578

// CUTLASS_ARCH_MMA_SM70_ENABLED is not defined: no instruction is emitted, only assert(0).
579 #else

580 assert(0);

581 #endif

582 }

583 };

584

586

/// Partial specialization of Mma for shape 16x16x4 with second argument 32 and half_t A/B
/// operands. It adds no implementation of its own: it publicly derives from the matching
/// 8x8x4 / 8 specialization (same layouts, accumulator type and operator), so operator() and
/// the fragment typedefs come from that base, and only Shape is redeclared as 16x16x4.
/// NOTE(review): 32 vs. 8 presumably distinguishes a full warp from a smaller thread group --
/// confirm against the primary template in arch/mma.h.
588 template <

589typename LayoutA,

590typename LayoutB,

591typename ElementC,

592typename LayoutC,

593typename Operator

594 >

595 struct Mma<

596 gemm::GemmShape<16, 16, 4>,

597 32,

598half_t,

599 LayoutA,

600half_t,

601 LayoutB,

602 ElementC,

603 LayoutC,

604 Operator

605 > :

606public Mma<

607 gemm::GemmShape<8, 8, 4>,

608 8,

609 half_t,

610 LayoutA,

611 half_t,

612 LayoutB,

613 ElementC,

614 LayoutC,

615 Operator> {

616

// Hides the base class's 8x8x4 Shape with the 16x16x4 shape named by this specialization.
617using Shape = gemm::GemmShape<16, 16, 4>;

618 };

619

621

622 } // namespace arch

623 } // namespace cutlass

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< half_t, 8 > FragmentC

Definition: mma_sm70.h:84

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< half_t, 4 > FragmentB

Definition: mma_sm70.h:245

cutlass

Definition: aligned_buffer.h:35

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< float, 8 > FragmentC

Definition: mma_sm70.h:535

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< half_t, 8 > FragmentC

Definition: mma_sm70.h:194

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::ElementC

float ElementC

Definition: mma_sm70.h:308

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)

Definition: mma_sm70.h:199

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm70.h:86

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< half_t, 4 > FragmentA

Definition: mma_sm70.h:131

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm70.h:312

cutlass::half_t

IEEE half-precision floating-point type.

Definition: half.h:126

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< float, 8 > FragmentC

Definition: mma_sm70.h:310

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)

Multiply-add.

Definition: mma_sm70.h:391

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< half_t, 4 > FragmentB

Definition: mma_sm70.h:80

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< float, 8 > FragmentC

Definition: mma_sm70.h:385

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::ElementC

float ElementC

Definition: mma_sm70.h:458

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)

Definition: mma_sm70.h:89

cutlass::layout::ColumnMajor

Mapping function for column-major matrices.

Definition: layout/matrix.h:142

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::ElementC

float ElementC

Definition: mma_sm70.h:383

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)

Multiply-add.

Definition: mma_sm70.h:316

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< half_t, 4 > FragmentA

Definition: mma_sm70.h:241

mma.h

Templates exposing architecture support for multiply-add operations.

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< half_t, 4 > FragmentB

Definition: mma_sm70.h:531

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)

Definition: mma_sm70.h:254

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm70.h:251

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)

Definition: mma_sm70.h:144

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< half_t, 4 > FragmentA

Definition: mma_sm70.h:377

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< half_t, 4 > FragmentB

Definition: mma_sm70.h:306

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::ElementC

float ElementC

Definition: mma_sm70.h:533

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< half_t, 4 > FragmentA

Definition: mma_sm70.h:76

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< half_t, 4 > FragmentA

Definition: mma_sm70.h:186

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< half_t, 4 > FragmentA

Definition: mma_sm70.h:302

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm70.h:196

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< half_t, 8 > FragmentC

Definition: mma_sm70.h:249

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< half_t, 4 > FragmentA

Definition: mma_sm70.h:527

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< float, 8 > FragmentC

Definition: mma_sm70.h:460

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm70.h:141

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< half_t, 4 > FragmentB

Definition: mma_sm70.h:190

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)

Multiply-add.

Definition: mma_sm70.h:541

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentC

Array< half_t, 8 > FragmentC

Definition: mma_sm70.h:139

matrix.h

Defines layout functions used by TensorRef and derived classes.

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< half_t, 4 > FragmentB

Definition: mma_sm70.h:135

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm70.h:387

cutlass::arch::Mma

Matrix multiply-add operation.

Definition: arch/mma.h:92

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)

Multiply-add.

Definition: mma_sm70.h:466

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm70.h:537

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::Operator

OpMultiplyAdd Operator

Definition: mma_sm70.h:462

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentA

Array< half_t, 4 > FragmentA

Definition: mma_sm70.h:452

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< half_t, 4 > FragmentB

Definition: mma_sm70.h:456

cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentB

Array< half_t, 4 > FragmentB

Definition: mma_sm70.h:381


Generated by Doxygen 1.8.11