Back to Cutlass

CUTLASS: functional.h Source File

docs/functional_8h_source.html

4.4.278.4 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

functional.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

31 #pragma once

32

33 #include "cutlass/cutlass.h"

34 #include "cutlass/numeric_types.h"

35

36 #include "cutlass/complex.h"

37

38 #include "cutlass/array.h"

39 #include "cutlass/half.h"

40

41 namespace cutlass {

42

44

/// Binary functor returning lhs + rhs.
///
/// The left operand is received by value and used directly as the
/// accumulator, so T only needs a compound assignment operator (+=).
template <typename T>
struct plus {
  CUTLASS_HOST_DEVICE
  T operator()(T accum, T const &addend) const {
    accum += addend;
    return accum;
  }
};

53

/// Binary functor returning lhs - rhs.
///
/// The minuend is received by value and reduced in place via T's -=.
template <typename T>
struct minus {
  CUTLASS_HOST_DEVICE
  T operator()(T accum, T const &subtrahend) const {
    accum -= subtrahend;
    return accum;
  }
};

62

/// Binary functor returning lhs * rhs.
///
/// The left factor is received by value and scaled in place via T's *=.
template <typename T>
struct multiplies {
  CUTLASS_HOST_DEVICE
  T operator()(T accum, T const &factor) const {
    accum *= factor;
    return accum;
  }
};

71

/// Binary functor returning lhs / rhs.
///
/// The dividend is received by value and divided in place via T's /=, so
/// integer types follow the language's truncating division.
template <typename T>
struct divides {
  CUTLASS_HOST_DEVICE
  T operator()(T accum, T const &divisor) const {
    accum /= divisor;
    return accum;
  }
};

80

81

/// Unary functor returning the arithmetic negation of its operand.
template <typename T>
struct negate {
  CUTLASS_HOST_DEVICE
  T operator()(T operand) const {
    T negated = -operand;
    return negated;
  }
};

89

/// Fused multiply-add functor: returns a * b + c.
///
/// Both multiplicands are first converted to the accumulator type C, so the
/// product and sum are formed in C's arithmetic. The product and the addition
/// stay in a single expression so that compilers remain free to contract them
/// into one FMA exactly as before.
template <typename A, typename B = A, typename C = A>
struct multiply_add {
  CUTLASS_HOST_DEVICE
  C operator()(A const &a, B const &b, C const &c) const {
    C const a_acc = C(a);
    C const b_acc = C(b);
    return a_acc * b_acc + c;
  }
};

98

/// Ternary functor returning (a ^ b) + c, i.e. a bitwise XOR of the first two
/// operands followed by an addition of the third.
template <typename T>
struct xor_add {
  CUTLASS_HOST_DEVICE
  T operator()(T const &a, T const &b, T const &c) const {
    auto const mixed = a ^ b;
    return mixed + c;
  }
};

107

109 //

110 // Partial specialization for complex<T> to target four scalar fused multiply-adds.

111 //

113

115 template <typename T>

116 struct multiply_add<complex<T>, complex<T>, complex<T>> {

117CUTLASS_HOST_DEVICE

118complex<T> operator()(

119complex<T> const &a,

120complex<T> const &b,

121complex<T> const &c) const {

122

123 T real = c.real();

124 T imag = c.imag();

125

126 real += a.real() * b.real();

127 real += -a.imag() * b.imag();

128 imag += a.real() * b.imag();

129 imag += a.imag () * b.real();

130

131return complex<T>{

132real,

133 imag

134 };

135 }

136 };

137

139 template <typename T>

140 struct multiply_add<complex<T>, T, complex<T>> {

141CUTLASS_HOST_DEVICE

142complex<T> operator()(

143complex<T> const &a,

144 T const &b,

145complex<T> const &c) const {

146

147 T real = c.real();

148 T imag = c.imag();

149

150 real += a.real() * b;

151 imag += a.imag () * b;

152

153return complex<T>{

154real,

155 imag

156 };

157 }

158 };

159

161 template <typename T>

162 struct multiply_add<T, complex<T>, complex<T>> {

163CUTLASS_HOST_DEVICE

164complex<T> operator()(

165 T const &a,

166complex<T> const &b,

167complex<T> const &c) const {

168

169 T real = c.real();

170 T imag = c.imag();

171

172 real += a * b.real();

173 imag += a * b.imag();

174

175return complex<T>{

176real,

177 imag

178 };

179 }

180 };

181

183 //

184 // Partial specializations for Array<T, N>

185 //

187

188 template <typename T, int N>

189 struct plus<Array<T, N>> {

190CUTLASS_HOST_DEVICE

191 Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {

192

193 Array<T, N> result;

194plus<T> scalar_op;

195

196CUTLASS_PRAGMA_UNROLL

197for (int i = 0; i < N; ++i) {

198 result[i] = scalar_op(lhs[i], rhs[i]);

199 }

200

201return result;

202 }

203

204CUTLASS_HOST_DEVICE

205 Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {

206

207 Array<T, N> result;

208plus<T> scalar_op;

209

210CUTLASS_PRAGMA_UNROLL

211for (int i = 0; i < N; ++i) {

212 result[i] = scalar_op(lhs[i], scalar);

213 }

214

215return result;

216 }

217

218CUTLASS_HOST_DEVICE

219 Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {

220

221 Array<T, N> result;

222plus<T> scalar_op;

223

224CUTLASS_PRAGMA_UNROLL

225for (int i = 0; i < N; ++i) {

226 result[i] = scalar_op(scalar, rhs[i]);

227 }

228

229return result;

230 }

231 };

232

233

/// Binary functor returning the larger of two values.
///
/// Only operator< is used. When neither operand compares less than the other
/// (equal values, or unordered values such as NaN), lhs is returned.
template <typename T>
struct maximum {
  CUTLASS_HOST_DEVICE
  T operator()(T const &lhs, T const &rhs) const {
    if (lhs < rhs) {
      return rhs;
    }
    return lhs;
  }
};

242

243 template <>

244 struct maximum<float> {

245CUTLASS_HOST_DEVICE

246float operator()(float const &lhs, float const &rhs) const {

247return fmaxf(lhs, rhs);

248 }

249 };

250

251 template <typename T, int N>

252 struct maximum<Array<T, N>> {

253

254CUTLASS_HOST_DEVICE

255 Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {

256

257 Array<T, N> result;

258maximum<T> scalar_op;

259

260CUTLASS_PRAGMA_UNROLL

261for (int i = 0; i < N; ++i) {

262 result[i] = scalar_op(lhs[i], rhs[i]);

263 }

264

265return result;

266 }

267

268CUTLASS_HOST_DEVICE

269 Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {

270

271 Array<T, N> result;

272maximum<T> scalar_op;

273

274CUTLASS_PRAGMA_UNROLL

275for (int i = 0; i < N; ++i) {

276 result[i] = scalar_op(lhs[i], scalar);

277 }

278

279return result;

280 }

281

282CUTLASS_HOST_DEVICE

283 Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {

284

285 Array<T, N> result;

286maximum<T> scalar_op;

287

288CUTLASS_PRAGMA_UNROLL

289for (int i = 0; i < N; ++i) {

290 result[i] = scalar_op(scalar, rhs[i]);

291 }

292

293return result;

294 }

295 };

296

/// Binary functor returning the smaller of two values.
///
/// Only operator< is used. When neither operand compares less than the other
/// (equal values, or unordered values such as NaN), lhs is returned.
template <typename T>
struct minimum {
  CUTLASS_HOST_DEVICE
  T operator()(T const &lhs, T const &rhs) const {
    if (rhs < lhs) {
      return rhs;
    }
    return lhs;
  }
};

305

306 template <>

307 struct minimum<float> {

308CUTLASS_HOST_DEVICE

309float operator()(float const &lhs, float const &rhs) const {

310return fminf(lhs, rhs);

311 }

312 };

313

314 template <typename T, int N>

315 struct minimum<Array<T, N>> {

316

317CUTLASS_HOST_DEVICE

318static T scalar_op(T const &lhs, T const &rhs) {

319return (rhs < lhs ? rhs : lhs);

320 }

321

322CUTLASS_HOST_DEVICE

323 Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {

324

325 Array<T, N> result;

326minimum<T> scalar_op;

327

328CUTLASS_PRAGMA_UNROLL

329for (int i = 0; i < N; ++i) {

330 result[i] = scalar_op(lhs[i], rhs[i]);

331 }

332

333return result;

334 }

335

336CUTLASS_HOST_DEVICE

337 Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {

338

339 Array<T, N> result;

340minimum<T> scalar_op;

341

342CUTLASS_PRAGMA_UNROLL

343for (int i = 0; i < N; ++i) {

344 result[i] = scalar_op(lhs[i], scalar);

345 }

346

347return result;

348 }

349

350CUTLASS_HOST_DEVICE

351 Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {

352

353 Array<T, N> result;

354minimum<T> scalar_op;

355

356CUTLASS_PRAGMA_UNROLL

357for (int i = 0; i < N; ++i) {

358 result[i] = scalar_op(scalar, rhs[i]);

359 }

360

361return result;

362 }

363 };

364

365 template <typename T, int N>

366 struct minus<Array<T, N>> {

367

368CUTLASS_HOST_DEVICE

369 Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {

370

371 Array<T, N> result;

372minus<T> scalar_op;

373

374CUTLASS_PRAGMA_UNROLL

375for (int i = 0; i < N; ++i) {

376 result[i] = scalar_op(lhs[i], rhs[i]);

377 }

378

379return result;

380 }

381

382CUTLASS_HOST_DEVICE

383 Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {

384

385 Array<T, N> result;

386minus<T> scalar_op;

387

388CUTLASS_PRAGMA_UNROLL

389for (int i = 0; i < N; ++i) {

390 result[i] = scalar_op(lhs[i], scalar);

391 }

392

393return result;

394 }

395

396CUTLASS_HOST_DEVICE

397 Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {

398

399 Array<T, N> result;

400minus<T> scalar_op;

401

402CUTLASS_PRAGMA_UNROLL

403for (int i = 0; i < N; ++i) {

404 result[i] = scalar_op(scalar, rhs[i]);

405 }

406

407return result;

408 }

409 };

410

411 template <typename T, int N>

412 struct multiplies<Array<T, N>> {

413

414CUTLASS_HOST_DEVICE

415 Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {

416

417 Array<T, N> result;

418multiplies<T> scalar_op;

419

420CUTLASS_PRAGMA_UNROLL

421for (int i = 0; i < N; ++i) {

422 result[i] = scalar_op(lhs[i], rhs[i]);

423 }

424

425return result;

426 }

427

428CUTLASS_HOST_DEVICE

429 Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {

430

431 Array<T, N> result;

432multiplies<T> scalar_op;

433

434CUTLASS_PRAGMA_UNROLL

435for (int i = 0; i < N; ++i) {

436 result[i] = scalar_op(lhs[i], scalar);

437 }

438

439return result;

440 }

441

442CUTLASS_HOST_DEVICE

443 Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {

444

445 Array<T, N> result;

446multiplies<T> scalar_op;

447

448CUTLASS_PRAGMA_UNROLL

449for (int i = 0; i < N; ++i) {

450 result[i] = scalar_op(scalar, rhs[i]);

451 }

452

453return result;

454 }

455 };

456

457 template <typename T, int N>

458 struct divides<Array<T, N>> {

459

460CUTLASS_HOST_DEVICE

461 Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {

462

463 Array<T, N> result;

464divides<T> scalar_op;

465

466CUTLASS_PRAGMA_UNROLL

467for (int i = 0; i < N; ++i) {

468 result[i] = scalar_op(lhs[i], rhs[i]);

469 }

470

471return result;

472 }

473

474CUTLASS_HOST_DEVICE

475 Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {

476

477 Array<T, N> result;

478divides<T> scalar_op;

479

480CUTLASS_PRAGMA_UNROLL

481for (int i = 0; i < N; ++i) {

482 result[i] = scalar_op(lhs[i], scalar);

483 }

484

485return result;

486 }

487

488CUTLASS_HOST_DEVICE

489 Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {

490

491 Array<T, N> result;

492divides<T> scalar_op;

493

494CUTLASS_PRAGMA_UNROLL

495for (int i = 0; i < N; ++i) {

496 result[i] = scalar_op(scalar, rhs[i]);

497 }

498

499return result;

500 }

501 };

502

503

504 template <typename T, int N>

505 struct negate<Array<T, N>> {

506

507CUTLASS_HOST_DEVICE

508 Array<T, N> operator()(Array<T, N> const &lhs) const {

509

510 Array<T, N> result;

511negate<T> scalar_op;

512

513CUTLASS_PRAGMA_UNROLL

514for (int i = 0; i < N; ++i) {

515 result[i] = scalar_op(lhs[i]);

516 }

517

518return result;

519 }

520 };

521

523 template <typename T, int N>

524 struct multiply_add<Array<T, N>, Array<T, N>, Array<T, N>> {

525

526CUTLASS_HOST_DEVICE

527 Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {

528

529 Array<T, N> result;

530multiply_add<T> scalar_op;

531

532CUTLASS_PRAGMA_UNROLL

533for (int i = 0; i < N; ++i) {

534 result[i] = scalar_op(a[i], b[i], c[i]);

535 }

536

537return result;

538 }

539

540CUTLASS_HOST_DEVICE

541 Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {

542

543 Array<T, N> result;

544multiply_add<T> scalar_op;

545

546CUTLASS_PRAGMA_UNROLL

547for (int i = 0; i < N; ++i) {

548 result[i] = scalar_op(a[i], scalar, c[i]);

549 }

550

551return result;

552 }

553

554CUTLASS_HOST_DEVICE

555 Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {

556

557 Array<T, N> result;

558multiply_add<T> scalar_op;

559

560CUTLASS_PRAGMA_UNROLL

561for (int i = 0; i < N; ++i) {

562 result[i] = scalar_op(scalar, b[i], c[i]);

563 }

564

565return result;

566 }

567 };

568

570 //

571 // Partial specializations for Array<half_t, N> targeting SIMD instructions in device code.

572 //

574

575 template <int N>

576 struct plus<Array<half_t, N>> {

577CUTLASS_HOST_DEVICE

578 Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {

579 Array<half_t, N> result;

580 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

581

582 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

583 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);

584 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

585

586CUTLASS_PRAGMA_UNROLL

587for (int i = 0; i < N / 2; ++i) {

588 result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]);

589 }

590

591if (N % 2) {

592 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

593 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

594 __half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);

595

596 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

597 }

598

599 #else

600

601CUTLASS_PRAGMA_UNROLL

602for (int i = 0; i < N; ++i) {

603 result[i] = lhs[i] + rhs[i];

604 }

605 #endif

606

607return result;

608 }

609

610CUTLASS_HOST_DEVICE

611 Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {

612 Array<half_t, N> result;

613 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

614

615 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

616 __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));

617 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

618

619CUTLASS_PRAGMA_UNROLL

620for (int i = 0; i < N / 2; ++i) {

621 result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]);

622 }

623

624if (N % 2) {

625 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

626 __half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);

627

628 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

629 }

630

631 #else

632

633CUTLASS_PRAGMA_UNROLL

634for (int i = 0; i < N; ++i) {

635 result[i] = lhs + rhs[i];

636 }

637 #endif

638

639return result;

640 }

641

642CUTLASS_HOST_DEVICE

643 Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {

644 Array<half_t, N> result;

645 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

646

647 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

648 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);

649 __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

650

651CUTLASS_PRAGMA_UNROLL

652for (int i = 0; i < N / 2; ++i) {

653 result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair);

654 }

655

656if (N % 2) {

657 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

658 __half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));

659

660 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

661 }

662

663 #else

664

665CUTLASS_PRAGMA_UNROLL

666for (int i = 0; i < N; ++i) {

667 result[i] = lhs[i] + rhs;

668 }

669 #endif

670

671return result;

672 }

673 };

674

675 template <int N>

676 struct minus<Array<half_t, N>> {

677CUTLASS_HOST_DEVICE

678 Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {

679 Array<half_t, N> result;

680 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

681

682 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

683 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);

684 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

685

686CUTLASS_PRAGMA_UNROLL

687for (int i = 0; i < N / 2; ++i) {

688 result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]);

689 }

690

691if (N % 2) {

692 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

693 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

694 __half d_residual = __hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);

695

696 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

697 }

698

699 #else

700

701CUTLASS_PRAGMA_UNROLL

702for (int i = 0; i < N; ++i) {

703 result[i] = lhs[i] - rhs[i];

704 }

705 #endif

706

707return result;

708 }

709

710CUTLASS_HOST_DEVICE

711 Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {

712 Array<half_t, N> result;

713 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

714

715 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

716 __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));

717 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

718

719CUTLASS_PRAGMA_UNROLL

720for (int i = 0; i < N / 2; ++i) {

721 result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]);

722 }

723

724if (N % 2) {

725 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

726 __half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);

727

728 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

729 }

730

731 #else

732

733CUTLASS_PRAGMA_UNROLL

734for (int i = 0; i < N; ++i) {

735 result[i] = lhs - rhs[i];

736 }

737 #endif

738

739return result;

740 }

741

742CUTLASS_HOST_DEVICE

743 Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {

744 Array<half_t, N> result;

745 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

746

747 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

748 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);

749 __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

750

751CUTLASS_PRAGMA_UNROLL

752for (int i = 0; i < N / 2; ++i) {

753 result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair);

754 }

755

756if (N % 2) {

757 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

758 __half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));

759

760 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

761 }

762

763 #else

764

765CUTLASS_PRAGMA_UNROLL

766for (int i = 0; i < N; ++i) {

767 result[i] = lhs[i] - rhs;

768 }

769 #endif

770

771return result;

772 }

773 };

774

775 template <int N>

776 struct multiplies<Array<half_t, N>> {

777CUTLASS_HOST_DEVICE

778 Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {

779 Array<half_t, N> result;

780 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

781

782 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

783 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);

784 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

785

786CUTLASS_PRAGMA_UNROLL

787for (int i = 0; i < N / 2; ++i) {

788 result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]);

789 }

790

791if (N % 2) {

792 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

793 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

794 __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);

795

796 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

797 }

798

799 #else

800

801CUTLASS_PRAGMA_UNROLL

802for (int i = 0; i < N; ++i) {

803 result[i] = lhs[i] * rhs[i];

804 }

805 #endif

806

807return result;

808 }

809

810CUTLASS_HOST_DEVICE

811 Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {

812 Array<half_t, N> result;

813 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

814

815 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

816 __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));

817 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

818

819CUTLASS_PRAGMA_UNROLL

820for (int i = 0; i < N / 2; ++i) {

821 result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]);

822 }

823

824if (N % 2) {

825 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

826

827 __half d_residual = __hmul(

828 reinterpret_cast<__half const &>(lhs),

829 b_residual_ptr[N - 1]);

830

831 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

832 }

833

834 #else

835

836CUTLASS_PRAGMA_UNROLL

837for (int i = 0; i < N; ++i) {

838 result[i] = lhs * rhs[i];

839 }

840 #endif

841

842return result;

843 }

844

845CUTLASS_HOST_DEVICE

846 Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {

847 Array<half_t, N> result;

848 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

849

850 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

851 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);

852 __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

853

854CUTLASS_PRAGMA_UNROLL

855for (int i = 0; i < N / 2; ++i) {

856 result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair);

857 }

858

859if (N % 2) {

860 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

861

862 __half d_residual = __hmul(

863 a_residual_ptr[N - 1],

864 reinterpret_cast<__half const &>(rhs));

865

866 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

867 }

868

869 #else

870

871CUTLASS_PRAGMA_UNROLL

872for (int i = 0; i < N; ++i) {

873 result[i] = lhs[i] * rhs;

874 }

875 #endif

876

877return result;

878 }

879 };

880

881 template <int N>

882 struct divides<Array<half_t, N>> {

883CUTLASS_HOST_DEVICE

884 Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {

885 Array<half_t, N> result;

886 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

887

888 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

889 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);

890 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

891

892CUTLASS_PRAGMA_UNROLL

893for (int i = 0; i < N / 2; ++i) {

894 result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]);

895 }

896

897if (N % 2) {

898 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

899 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

900

901 __half d_residual = __hdiv(

902 a_residual_ptr[N - 1],

903 b_residual_ptr[N - 1]);

904

905 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

906 }

907

908 #else

909

910CUTLASS_PRAGMA_UNROLL

911for (int i = 0; i < N; ++i) {

912 result[i] = lhs[i] / rhs[i];

913 }

914 #endif

915

916return result;

917 }

918

919CUTLASS_HOST_DEVICE

920 Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {

921 Array<half_t, N> result;

922 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

923

924 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

925 __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));

926 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);

927

928CUTLASS_PRAGMA_UNROLL

929for (int i = 0; i < N / 2; ++i) {

930 result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]);

931 }

932

933if (N % 2) {

934 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);

935

936 __half d_residual = __hdiv(

937 reinterpret_cast<__half const &>(lhs),

938 b_residual_ptr[N - 1]);

939

940 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

941 }

942

943 #else

944

945CUTLASS_PRAGMA_UNROLL

946for (int i = 0; i < N; ++i) {

947 result[i] = lhs / rhs[i];

948 }

949 #endif

950

951return result;

952 }

953

954CUTLASS_HOST_DEVICE

955 Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {

956 Array<half_t, N> result;

957 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

958

959 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

960 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);

961 __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));

962

963CUTLASS_PRAGMA_UNROLL

964for (int i = 0; i < N / 2; ++i) {

965 result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair);

966 }

967

968if (N % 2) {

969 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);

970

971 __half d_residual = __hdiv(

972 a_residual_ptr[N - 1],

973 reinterpret_cast<__half const &>(rhs));

974

975 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

976 }

977

978 #else

979

980CUTLASS_PRAGMA_UNROLL

981for (int i = 0; i < N; ++i) {

982 result[i] = lhs[i] / rhs;

983 }

984 #endif

985

986return result;

987 }

988 };

989

990 template <int N>

991 struct negate<Array<half_t, N>> {

992CUTLASS_HOST_DEVICE

993 Array<half_t, N> operator()(Array<half_t, N> const & lhs) const {

994 Array<half_t, N> result;

995 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

996

997 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

998 __half2 const *source_ptr = reinterpret_cast<__half2 const *>(&lhs);

999

1000CUTLASS_PRAGMA_UNROLL

1001for (int i = 0; i < N / 2; ++i) {

1002 result_ptr[i] = __hneg2(source_ptr[i]);

1003 }

1004

1005if (N % 2) {

1006half_t x = lhs[N - 1];

1007 __half lhs_val = -reinterpret_cast<__half const &>(x);

1008 result[N - 1] = reinterpret_cast<half_t const &>(lhs_val);

1009 }

1010

1011 #else

1012

1013CUTLASS_PRAGMA_UNROLL

1014for (int i = 0; i < N; ++i) {

1015 result[i] = -lhs[i];

1016 }

1017 #endif

1018

1019return result;

1020 }

1021 };

1022

1024 template <int N>

[1025](structcutlass_1_1multiply add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html) struct multiply_add<Array<half_t, N>, Array<half_t, N>, Array<half_t, N>> {

1026

1027CUTLASS_HOST_DEVICE

[1028](structcutlass_1_1multiply add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#abc9dd51cad4f2997dae521fad8f5b486) Array<half_t, N> [operator()](structcutlass_1_1multiply__add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#abc9dd51cad4f2997dae521fad8f5b486)(

1029 Array<half_t, N> const &a,

1030 Array<half_t, N> const &b,

1031 Array<half_t, N> const &c) const {

1032

1033 Array<half_t, N> result;

1034 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

1035

1036 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

1037 __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);

1038 __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);

1039 __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

1040

1041CUTLASS_PRAGMA_UNROLL

1042for (int i = 0; i < N / 2; ++i) {

1043 result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]);

1044 }

1045

1046if (N % 2) {

1047

1048 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);

1049 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);

1050 __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);

1051

1052 __half d_residual = __hfma(

1053 a_residual_ptr[N - 1],

1054 b_residual_ptr[N - 1],

1055 c_residual_ptr[N - 1]);

1056

1057 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

1058 }

1059

1060 #else

1061

1062multiply_add<half_t> op;

1063

1064CUTLASS_PRAGMA_UNROLL

1065for (int i = 0; i < N; ++i) {

1066 result[i] = op(a[i], b[i], c[i]);

1067 }

1068 #endif

1069

1070return result;

1071 }

1072

1073CUTLASS_HOST_DEVICE

[1074](structcutlass_1_1multiply add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#a7659a559754edb4949d54cc641a5bd01) Array<half_t, N> [operator()](structcutlass_1_1multiply__add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#a7659a559754edb4949d54cc641a5bd01)(

1075half_t const &a,

1076 Array<half_t, N> const &b,

1077 Array<half_t, N> const &c) const {

1078

1079 Array<half_t, N> result;

1080 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

1081

1082 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

1083 __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a));

1084 __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);

1085 __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

1086

1087CUTLASS_PRAGMA_UNROLL

1088for (int i = 0; i < N / 2; ++i) {

1089 result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]);

1090 }

1091

1092if (N % 2) {

1093

1094 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);

1095 __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);

1096 __half d_residual = __hfma(

1097 reinterpret_cast<__half const &>(a),

1098 b_residual_ptr[N - 1],

1099 c_residual_ptr[N - 1]);

1100

1101 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

1102 }

1103

1104 #else

1105

1106multiply_add<half_t> op;

1107

1108CUTLASS_PRAGMA_UNROLL

1109for (int i = 0; i < N; ++i) {

1110 result[i] = op(a, b[i], c[i]);

1111 }

1112 #endif

1113

1114return result;

1115 }

1116

/// Element-wise fused multiply-add with a scalar multiplier broadcast over the array:
///   result[i] = a[i] * b + c[i]   for i in [0, N)
///
/// In device code compiled for __CUDA_ARCH__ >= 530, elements are processed two at a time
/// with __hfma2. Otherwise (host compilation pass or older architectures), the function loops
/// over multiply_add<half_t> one element at a time.
///
/// @param a  per-element multiplicand
/// @param b  scalar multiplier applied to every element
/// @param c  per-element addend
/// @return   array of N fused multiply-add results
1117CUTLASS_HOST_DEVICE

[1118](structcutlass_1_1multiply add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#aed1e0930836e1da4c53fcdfb683847a1) Array<half_t, N> [operator()](structcutlass_1_1multiply__add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#aed1e0930836e1da4c53fcdfb683847a1)(

1119 Array<half_t, N> const &a,

1120half_t const &b,

1121 Array<half_t, N> const &c) const {

1122

1123 Array<half_t, N> result;

1124 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

1125

// Reinterpret the arrays as __half2 pairs and replicate b into both lanes,
// so that each __hfma2 call produces two outputs.
// NOTE(review): these casts assume Array<half_t, N> storage is contiguous and
// suitably aligned for __half2 access — confirm against Array's definition.
1126 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

1127 __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);

1128 __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b));

1129 __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);

1130

// Process N / 2 full pairs; an odd trailing element is handled below.
1131CUTLASS_PRAGMA_UNROLL

1132for (int i = 0; i < N / 2; ++i) {

1133 result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]);

1134 }

1135

// Odd N: the last element (index N - 1) has no partner lane,
// so it is computed with the scalar __hfma intrinsic.
1136if (N % 2) {

1137

1138 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);

1139 __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);

1140

1141 __half d_residual = __hfma(

1142 a_residual_ptr[N - 1],

1143 reinterpret_cast<__half const &>(b),

1144 c_residual_ptr[N - 1]);

1145

1146 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

1147 }

1148

1149 #else

1150

// Portable fallback: one multiply_add<half_t> call per element.
1151multiply_add<half_t> op;

1152

1153CUTLASS_PRAGMA_UNROLL

1154for (int i = 0; i < N; ++i) {

1155 result[i] = op(a[i], b, c[i]);

1156 }

1157 #endif

1158

1159return result;

1160 }

1161

/// Element-wise fused multiply-add with a scalar addend broadcast over the array:
///   result[i] = a[i] * b[i] + c   for i in [0, N)
///
/// In device code compiled for __CUDA_ARCH__ >= 530, elements are processed two at a time
/// with __hfma2. Otherwise (host compilation pass or older architectures), the function loops
/// over multiply_add<half_t> one element at a time.
///
/// @param a  per-element multiplicand
/// @param b  per-element multiplier
/// @param c  scalar addend applied to every element
/// @return   array of N fused multiply-add results
1162CUTLASS_HOST_DEVICE

[1163](structcutlass_1_1multiply add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#a6b4fe9d3366d389034a53f5ba71bdaee) Array<half_t, N> [operator()](structcutlass_1_1multiply__add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#a6b4fe9d3366d389034a53f5ba71bdaee)(

1164 Array<half_t, N> const &a,

1165 Array<half_t, N> const &b,

1166half_t const &c) const {

1167

1168 Array<half_t, N> result;

1169 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)

1170

// Reinterpret the arrays as __half2 pairs and replicate c into both lanes,
// so that each __hfma2 call produces two outputs.
// NOTE(review): these casts assume Array<half_t, N> storage is contiguous and
// suitably aligned for __half2 access — confirm against Array's definition.
1171 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);

1172 __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);

1173 __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);

1174 __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c));

1175

// Process N / 2 full pairs; an odd trailing element is handled below.
1176CUTLASS_PRAGMA_UNROLL

1177for (int i = 0; i < N / 2; ++i) {

1178 result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair);

1179 }

1180

// Odd N: the last element (index N - 1) has no partner lane,
// so it is computed with the scalar __hfma intrinsic.
1181if (N % 2) {

1182

1183 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);

1184 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);

1185

1186 __half d_residual = __hfma(

1187 a_residual_ptr[N - 1],

1188 b_residual_ptr[N - 1],

1189 reinterpret_cast<__half const &>(c));

1190

1191 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);

1192 }

1193

1194 #else

1195

// Portable fallback: one multiply_add<half_t> call per element.
1196multiply_add<half_t> op;

1197

1198CUTLASS_PRAGMA_UNROLL

1199for (int i = 0; i < N; ++i) {

1200 result[i] = op(a[i], b[i], c);

1201 }

1202 #endif

1203

1204return result;

1205 }

1206 };

1207

1209

1210 } // namespace cutlass

cutlass::multiply_add

Fused multiply-add.

Definition: functional.h:92

cutlass::minimum< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const

Definition: functional.h:351

cutlass::plus< Array< half_t, N > >::operator()

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const

Definition: functional.h:578

cutlass

Definition: aligned_buffer.h:35

cutlass::xor_add::operator()

CUTLASS_HOST_DEVICE T operator()(T const &a, T const &b, T const &c) const

Definition: functional.h:103

cutlass::divides< Array< half_t, N > >::operator()

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const

Definition: functional.h:955

complex.h

cutlass::minus< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const

Definition: functional.h:383

cutlass::imag

CUTLASS_HOST_DEVICE float const & imag(cuFloatComplex const &z)

Returns the imaginary part of the complex number.

Definition: complex.h:72

cutlass::minus< Array< half_t, N > >::operator()

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const

Definition: functional.h:678

cutlass::plus::operator()

CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const

Definition: functional.h:48

cutlass::maximum< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const

Definition: functional.h:269

cutlass::minimum< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const

Definition: functional.h:323

cutlass::plus< Array< half_t, N > >::operator()

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const

Definition: functional.h:611

half.h

Defines a class for using IEEE half-precision floating-point types in host or device code...

cutlass::divides< Array< half_t, N > >::operator()

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const

Definition: functional.h:920

cutlass::multiplies< Array< half_t, N > >::operator()

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const

Definition: functional.h:778

cutlass::plus< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const

Definition: functional.h:191

cutlass::minimum

Definition: functional.h:298

cutlass::maximum< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const

Definition: functional.h:255

cutlass::maximum

Definition: functional.h:235

cutlass::half_t

IEEE half-precision floating-point type.

Definition: half.h:126

cutlass::negate::operator()

CUTLASS_HOST_DEVICE T operator()(T lhs) const

Definition: functional.h:85

cutlass::multiplies< Array< half_t, N > >::operator()

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const

Definition: functional.h:811

cutlass::plus< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const

Definition: functional.h:205

cutlass::real

CUTLASS_HOST_DEVICE float const & real(cuFloatComplex const &z)

Returns the real part of the complex number.

Definition: complex.h:56

cutlass::minus< Array< half_t, N > >::operator()

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const

Definition: functional.h:743

cutlass::minimum< float >::operator()

CUTLASS_HOST_DEVICE float operator()(float const &lhs, float const &rhs) const

Definition: functional.h:309

cutlass::minimum< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const

Definition: functional.h:337

cutlass::complex::imag

CUTLASS_HOST_DEVICE T const & imag() const

Accesses the imaginary part of the complex number.

Definition: complex.h:240

cutlass::multiply_add::operator()

CUTLASS_HOST_DEVICE C operator()(A const &a, B const &b, C const &c) const

Definition: functional.h:94

[cutlass::multiply_add< Array< half_t, N >, Array< half_t, N >, Array< half_t, N > >::operator()](structcutlass_1_1multiply__add_3_01Array_3_01half__t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#a6b4fe9d3366d389034a53f5ba71bdaee)

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &a, Array< half_t, N > const &b, half_t const &c) const

Definition: functional.h:1163

cutlass::plus

Definition: functional.h:46

cutlass::minus< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const

Definition: functional.h:369

cutlass::maximum::operator()

CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const

Definition: functional.h:238

cutlass::multiplies< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const

Definition: functional.h:415

cutlass::minus< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const

Definition: functional.h:397

cutlass::multiplies::operator()

CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const

Definition: functional.h:66

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

CUTLASS_PRAGMA_UNROLL

#define CUTLASS_PRAGMA_UNROLL

Definition: cutlass.h:110

cutlass::maximum< float >::operator()

CUTLASS_HOST_DEVICE float operator()(float const &lhs, float const &rhs) const

Definition: functional.h:246

cutlass::negate

Definition: functional.h:83

cutlass::divides::operator()

CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const

Definition: functional.h:75

cutlass::multiply_add< Array< T, N >, Array< T, N >, Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &a, T const &scalar, Array< T, N > const &c) const

Definition: functional.h:541

cutlass::multiply_add< Array< T, N >, Array< T, N >, Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &a, Array< T, N > const &b, Array< T, N > const &c) const

Definition: functional.h:527

cutlass::multiplies< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const

Definition: functional.h:443

cutlass::multiplies

Definition: functional.h:64

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::multiplies< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const

Definition: functional.h:429

cutlass::divides< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const

Definition: functional.h:489

cutlass::divides< Array< half_t, N > >::operator()

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const

Definition: functional.h:884

cutlass::divides

Definition: functional.h:73

cutlass::complex::real

CUTLASS_HOST_DEVICE T const & real() const

Accesses the real part of the complex number.

Definition: complex.h:232

cutlass::plus< Array< half_t, N > >::operator()

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const

Definition: functional.h:643

cutlass::complex

Definition: complex.h:92

cutlass::minimum::operator()

CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const

Definition: functional.h:301

cutlass::minimum< Array< T, N > >::scalar_op

static CUTLASS_HOST_DEVICE T scalar_op(T const &lhs, T const &rhs)

Definition: functional.h:318

cutlass::multiply_add< T, complex< T >, complex< T > >::operator()

CUTLASS_HOST_DEVICE complex< T > operator()(T const &a, complex< T > const &b, complex< T > const &c) const

Definition: functional.h:164

[cutlass::multiply_add< Array< half_t, N >, Array< half_t, N >, Array< half_t, N > >::operator()](structcutlass_1_1multiply__add_3_01Array_3_01half__t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#a7659a559754edb4949d54cc641a5bd01)

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &a, Array< half_t, N > const &b, Array< half_t, N > const &c) const

Definition: functional.h:1074

cutlass::multiply_add< complex< T >, T, complex< T > >::operator()

CUTLASS_HOST_DEVICE complex< T > operator()(complex< T > const &a, T const &b, complex< T > const &c) const

Definition: functional.h:142

cutlass::negate< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs) const

Definition: functional.h:508

cutlass::xor_add

Fused XOR-add: computes (a ^ b) + c.

Definition: functional.h:101

cutlass::maximum< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const

Definition: functional.h:283

cutlass::divides< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const

Definition: functional.h:475

cutlass::minus::operator()

CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const

Definition: functional.h:57

[cutlass::multiply_add< Array< half_t, N >, Array< half_t, N >, Array< half_t, N > >::operator()](structcutlass_1_1multiply__add_3_01Array_3_01half__t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#aed1e0930836e1da4c53fcdfb683847a1)

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &a, half_t const &b, Array< half_t, N > const &c) const

Definition: functional.h:1118

cutlass::plus< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const

Definition: functional.h:219

cutlass::minus< Array< half_t, N > >::operator()

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const

Definition: functional.h:711

cutlass::multiply_add< Array< T, N >, Array< T, N >, Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &b, Array< T, N > const &c) const

Definition: functional.h:555

cutlass::divides< Array< T, N > >::operator()

CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const

Definition: functional.h:461

cutlass::negate< Array< half_t, N > >::operator()

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs) const

Definition: functional.h:993

cutlass::minus

Definition: functional.h:55

cutlass::multiply_add< complex< T >, complex< T >, complex< T > >::operator()

CUTLASS_HOST_DEVICE complex< T > operator()(complex< T > const &a, complex< T > const &b, complex< T > const &c) const

Definition: functional.h:118

cutlass::multiplies< Array< half_t, N > >::operator()

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const

Definition: functional.h:846

cutlass.h

Basic include for CUTLASS.

[cutlass::multiply_add< Array< half_t, N >, Array< half_t, N >, Array< half_t, N > >::operator()](structcutlass_1_1multiply__add_3_01Array_3_01half__t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#abc9dd51cad4f2997dae521fad8f5b486)

CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &a, Array< half_t, N > const &b, Array< half_t, N > const &c) const

Definition: functional.h:1028

<!-- fragment --> <!-- contents --><!-- start footer part -->
<address class="footer"><small> Generated by Doxygen 1.8.11 </small></address>