docs/functional_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
functional.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
31 #pragma once
32
33 #include "cutlass/cutlass.h"
34 #include "cutlass/numeric_types.h"
35
36 #include "cutlass/complex.h"
37
38 #include "cutlass/array.h"
39 #include "cutlass/half.h"
40
41 namespace cutlass {
42
44
/// Scalar addition functor: returns lhs + rhs.
/// NOTE(review): recovered from a Doxygen HTML dump; the dropped "struct plus {"
/// header was restored. Upstream CUTLASS presumably marks operator()
/// CUTLASS_HOST_DEVICE -- confirm and restore when reintegrating.
template <typename T>
struct plus {
  /// Returns the sum of the two operands.
  T operator()(T lhs, T const &rhs) const {
    lhs += rhs;
    return lhs;
  }
};
53
/// Scalar subtraction functor: returns lhs - rhs.
/// NOTE(review): restored "struct minus {" header lost in HTML extraction;
/// upstream presumably marks operator() CUTLASS_HOST_DEVICE -- confirm.
template <typename T>
struct minus {
  /// Returns the difference of the two operands.
  T operator()(T lhs, T const &rhs) const {
    lhs -= rhs;
    return lhs;
  }
};
62
/// Scalar multiplication functor: returns lhs * rhs.
/// NOTE(review): stray Doxygen line numbers stripped; upstream presumably
/// marks operator() CUTLASS_HOST_DEVICE -- confirm when reintegrating.
template <typename T>
struct multiplies {
  /// Returns the product of the two operands.
  T operator()(T lhs, T const &rhs) const {
    lhs *= rhs;
    return lhs;
  }
};
71
/// Scalar division functor: returns lhs / rhs.
/// NOTE(review): restored "struct divides {" header lost in HTML extraction;
/// upstream presumably marks operator() CUTLASS_HOST_DEVICE -- confirm.
template <typename T>
struct divides {
  /// Returns the quotient of the two operands. No zero-divisor check, by design.
  T operator()(T lhs, T const &rhs) const {
    lhs /= rhs;
    return lhs;
  }
};
80
81
/// Scalar negation functor: returns -lhs.
/// NOTE(review): restored "struct negate {" header lost in HTML extraction;
/// upstream presumably marks operator() CUTLASS_HOST_DEVICE -- confirm.
template <typename T>
struct negate {
  /// Returns the arithmetic negation of the operand.
  T operator()(T lhs) const {
    return -lhs;
  }
};
89
/// Fused multiply-add: computes a * b + c, converting the factors to the
/// accumulator type C before multiplying.
/// NOTE(review): stray Doxygen line numbers stripped; upstream presumably
/// marks operator() CUTLASS_HOST_DEVICE -- confirm when reintegrating.
template <typename A, typename B = A, typename C = A>
struct multiply_add {
  /// Returns C(a) * C(b) + c.
  C operator()(A const &a, B const &b, C const &c) const {
    return C(a) * C(b) + c;
  }
};
98
/// Fused XOR-add functor: computes (a ^ b) + c.
/// NOTE(review): restored "struct xor_add {" header lost in HTML extraction;
/// upstream presumably marks operator() CUTLASS_HOST_DEVICE -- confirm.
template <typename T>
struct xor_add {
  /// Returns (a XOR b) plus c.
  T operator()(T const &a, T const &b, T const &c) const {
    return ((a ^ b) + c);
  }
};
107
109 //
110 // Partial specialization for complex<T> to target four scalar fused multiply-adds.
111 //
113
115 template <typename T>
116 struct multiply_add<complex<T>, complex<T>, complex<T>> {
119complex<T> const &a,
120complex<T> const &b,
121complex<T> const &c) const {
122
125
126 real += a.real() * b.real();
127 real += -a.imag() * b.imag();
128 imag += a.real() * b.imag();
129 imag += a.imag () * b.real();
130
131return complex<T>{
132real,
133 imag
134 };
135 }
136 };
137
139 template <typename T>
140 struct multiply_add<complex<T>, T, complex<T>> {
143complex<T> const &a,
144 T const &b,
145complex<T> const &c) const {
146
149
150 real += a.real() * b;
151 imag += a.imag () * b;
152
153return complex<T>{
154real,
155 imag
156 };
157 }
158 };
159
161 template <typename T>
162 struct multiply_add<T, complex<T>, complex<T>> {
165 T const &a,
166complex<T> const &b,
167complex<T> const &c) const {
168
171
172 real += a * b.real();
173 imag += a * b.imag();
174
175return complex<T>{
176real,
177 imag
178 };
179 }
180 };
181
183 //
184 // Partial specializations for Array<T, N>
185 //
187
188 template <typename T, int N>
189 struct plus<Array<T, N>> {
191 Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
192
193 Array<T, N> result;
194plus<T> scalar_op;
195
197for (int i = 0; i < N; ++i) {
198 result[i] = scalar_op(lhs[i], rhs[i]);
199 }
200
201return result;
202 }
203
205 Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
206
207 Array<T, N> result;
208plus<T> scalar_op;
209
211for (int i = 0; i < N; ++i) {
212 result[i] = scalar_op(lhs[i], scalar);
213 }
214
215return result;
216 }
217
219 Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
220
221 Array<T, N> result;
222plus<T> scalar_op;
223
225for (int i = 0; i < N; ++i) {
226 result[i] = scalar_op(scalar, rhs[i]);
227 }
228
229return result;
230 }
231 };
232
233
/// Scalar maximum functor: returns the greater of the two operands.
/// NOTE(review): restored "struct maximum {" header lost in HTML extraction;
/// upstream presumably marks operator() CUTLASS_HOST_DEVICE -- confirm.
template <typename T>
struct maximum {
  /// Returns rhs when lhs < rhs, otherwise lhs.
  T operator()(T const &lhs, T const &rhs) const {
    return (lhs < rhs ? rhs : lhs);
  }
};
242
243 template <>
246float operator()(float const &lhs, float const &rhs) const {
247return fmaxf(lhs, rhs);
248 }
249 };
250
251 template <typename T, int N>
252 struct maximum<Array<T, N>> {
253
255 Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
256
257 Array<T, N> result;
258maximum<T> scalar_op;
259
261for (int i = 0; i < N; ++i) {
262 result[i] = scalar_op(lhs[i], rhs[i]);
263 }
264
265return result;
266 }
267
269 Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
270
271 Array<T, N> result;
272maximum<T> scalar_op;
273
275for (int i = 0; i < N; ++i) {
276 result[i] = scalar_op(lhs[i], scalar);
277 }
278
279return result;
280 }
281
283 Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
284
285 Array<T, N> result;
286maximum<T> scalar_op;
287
289for (int i = 0; i < N; ++i) {
290 result[i] = scalar_op(scalar, rhs[i]);
291 }
292
293return result;
294 }
295 };
296
/// Scalar minimum functor: returns the lesser of the two operands.
/// NOTE(review): restored "struct minimum {" header lost in HTML extraction;
/// upstream presumably marks operator() CUTLASS_HOST_DEVICE -- confirm.
template <typename T>
struct minimum {
  /// Returns rhs when rhs < lhs, otherwise lhs.
  T operator()(T const &lhs, T const &rhs) const {
    return (rhs < lhs ? rhs : lhs);
  }
};
305
306 template <>
309float operator()(float const &lhs, float const &rhs) const {
310return fminf(lhs, rhs);
311 }
312 };
313
314 template <typename T, int N>
315 struct minimum<Array<T, N>> {
316
318static T scalar_op(T const &lhs, T const &rhs) {
319return (rhs < lhs ? rhs : lhs);
320 }
321
323 Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
324
325 Array<T, N> result;
326minimum<T> scalar_op;
327
329for (int i = 0; i < N; ++i) {
330 result[i] = scalar_op(lhs[i], rhs[i]);
331 }
332
333return result;
334 }
335
337 Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
338
339 Array<T, N> result;
340minimum<T> scalar_op;
341
343for (int i = 0; i < N; ++i) {
344 result[i] = scalar_op(lhs[i], scalar);
345 }
346
347return result;
348 }
349
351 Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
352
353 Array<T, N> result;
354minimum<T> scalar_op;
355
357for (int i = 0; i < N; ++i) {
358 result[i] = scalar_op(scalar, rhs[i]);
359 }
360
361return result;
362 }
363 };
364
365 template <typename T, int N>
366 struct minus<Array<T, N>> {
367
369 Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
370
371 Array<T, N> result;
372minus<T> scalar_op;
373
375for (int i = 0; i < N; ++i) {
376 result[i] = scalar_op(lhs[i], rhs[i]);
377 }
378
379return result;
380 }
381
383 Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
384
385 Array<T, N> result;
386minus<T> scalar_op;
387
389for (int i = 0; i < N; ++i) {
390 result[i] = scalar_op(lhs[i], scalar);
391 }
392
393return result;
394 }
395
397 Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
398
399 Array<T, N> result;
400minus<T> scalar_op;
401
403for (int i = 0; i < N; ++i) {
404 result[i] = scalar_op(scalar, rhs[i]);
405 }
406
407return result;
408 }
409 };
410
411 template <typename T, int N>
412 struct multiplies<Array<T, N>> {
413
415 Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
416
417 Array<T, N> result;
418multiplies<T> scalar_op;
419
421for (int i = 0; i < N; ++i) {
422 result[i] = scalar_op(lhs[i], rhs[i]);
423 }
424
425return result;
426 }
427
429 Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
430
431 Array<T, N> result;
432multiplies<T> scalar_op;
433
435for (int i = 0; i < N; ++i) {
436 result[i] = scalar_op(lhs[i], scalar);
437 }
438
439return result;
440 }
441
443 Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
444
445 Array<T, N> result;
446multiplies<T> scalar_op;
447
449for (int i = 0; i < N; ++i) {
450 result[i] = scalar_op(scalar, rhs[i]);
451 }
452
453return result;
454 }
455 };
456
457 template <typename T, int N>
458 struct divides<Array<T, N>> {
459
461 Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
462
463 Array<T, N> result;
464divides<T> scalar_op;
465
467for (int i = 0; i < N; ++i) {
468 result[i] = scalar_op(lhs[i], rhs[i]);
469 }
470
471return result;
472 }
473
475 Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
476
477 Array<T, N> result;
478divides<T> scalar_op;
479
481for (int i = 0; i < N; ++i) {
482 result[i] = scalar_op(lhs[i], scalar);
483 }
484
485return result;
486 }
487
489 Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
490
491 Array<T, N> result;
492divides<T> scalar_op;
493
495for (int i = 0; i < N; ++i) {
496 result[i] = scalar_op(scalar, rhs[i]);
497 }
498
499return result;
500 }
501 };
502
503
504 template <typename T, int N>
505 struct negate<Array<T, N>> {
506
508 Array<T, N> operator()(Array<T, N> const &lhs) const {
509
510 Array<T, N> result;
511negate<T> scalar_op;
512
514for (int i = 0; i < N; ++i) {
515 result[i] = scalar_op(lhs[i]);
516 }
517
518return result;
519 }
520 };
521
523 template <typename T, int N>
524 struct multiply_add<Array<T, N>, Array<T, N>, Array<T, N>> {
525
527 Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
528
529 Array<T, N> result;
530multiply_add<T> scalar_op;
531
533for (int i = 0; i < N; ++i) {
534 result[i] = scalar_op(a[i], b[i], c[i]);
535 }
536
537return result;
538 }
539
541 Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
542
543 Array<T, N> result;
544multiply_add<T> scalar_op;
545
547for (int i = 0; i < N; ++i) {
548 result[i] = scalar_op(a[i], scalar, c[i]);
549 }
550
551return result;
552 }
553
555 Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
556
557 Array<T, N> result;
558multiply_add<T> scalar_op;
559
561for (int i = 0; i < N; ++i) {
562 result[i] = scalar_op(scalar, b[i], c[i]);
563 }
564
565return result;
566 }
567 };
568
570 //
571 // Partial specializations for Array<half_t, N> targeting SIMD instructions in device code.
572 //
574
575 template <int N>
576 struct plus<Array<half_t, N>> {
578 Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
579 Array<half_t, N> result;
580 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
581
582 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
583 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
584 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
585
587for (int i = 0; i < N / 2; ++i) {
588 result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]);
589 }
590
591if (N % 2) {
592 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
593 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
594 __half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);
595
596 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
597 }
598
599 #else
600
602for (int i = 0; i < N; ++i) {
603 result[i] = lhs[i] + rhs[i];
604 }
605 #endif
606
607return result;
608 }
609
611 Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
612 Array<half_t, N> result;
613 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
614
615 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
616 __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
617 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
618
620for (int i = 0; i < N / 2; ++i) {
621 result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]);
622 }
623
624if (N % 2) {
625 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
626 __half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);
627
628 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
629 }
630
631 #else
632
634for (int i = 0; i < N; ++i) {
635 result[i] = lhs + rhs[i];
636 }
637 #endif
638
639return result;
640 }
641
643 Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
644 Array<half_t, N> result;
645 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
646
647 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
648 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
649 __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
650
652for (int i = 0; i < N / 2; ++i) {
653 result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair);
654 }
655
656if (N % 2) {
657 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
658 __half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));
659
660 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
661 }
662
663 #else
664
666for (int i = 0; i < N; ++i) {
667 result[i] = lhs[i] + rhs;
668 }
669 #endif
670
671return result;
672 }
673 };
674
675 template <int N>
676 struct minus<Array<half_t, N>> {
678 Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
679 Array<half_t, N> result;
680 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
681
682 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
683 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
684 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
685
687for (int i = 0; i < N / 2; ++i) {
688 result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]);
689 }
690
691if (N % 2) {
692 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
693 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
694 __half d_residual = __hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);
695
696 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
697 }
698
699 #else
700
702for (int i = 0; i < N; ++i) {
703 result[i] = lhs[i] - rhs[i];
704 }
705 #endif
706
707return result;
708 }
709
711 Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
712 Array<half_t, N> result;
713 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
714
715 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
716 __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
717 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
718
720for (int i = 0; i < N / 2; ++i) {
721 result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]);
722 }
723
724if (N % 2) {
725 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
726 __half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);
727
728 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
729 }
730
731 #else
732
734for (int i = 0; i < N; ++i) {
735 result[i] = lhs - rhs[i];
736 }
737 #endif
738
739return result;
740 }
741
743 Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
744 Array<half_t, N> result;
745 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
746
747 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
748 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
749 __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
750
752for (int i = 0; i < N / 2; ++i) {
753 result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair);
754 }
755
756if (N % 2) {
757 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
758 __half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));
759
760 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
761 }
762
763 #else
764
766for (int i = 0; i < N; ++i) {
767 result[i] = lhs[i] - rhs;
768 }
769 #endif
770
771return result;
772 }
773 };
774
775 template <int N>
776 struct multiplies<Array<half_t, N>> {
778 Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
779 Array<half_t, N> result;
780 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
781
782 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
783 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
784 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
785
787for (int i = 0; i < N / 2; ++i) {
788 result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]);
789 }
790
791if (N % 2) {
792 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
793 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
794 __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);
795
796 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
797 }
798
799 #else
800
802for (int i = 0; i < N; ++i) {
803 result[i] = lhs[i] * rhs[i];
804 }
805 #endif
806
807return result;
808 }
809
811 Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
812 Array<half_t, N> result;
813 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
814
815 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
816 __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
817 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
818
820for (int i = 0; i < N / 2; ++i) {
821 result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]);
822 }
823
824if (N % 2) {
825 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
826
827 __half d_residual = __hmul(
828 reinterpret_cast<__half const &>(lhs),
829 b_residual_ptr[N - 1]);
830
831 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
832 }
833
834 #else
835
837for (int i = 0; i < N; ++i) {
838 result[i] = lhs * rhs[i];
839 }
840 #endif
841
842return result;
843 }
844
846 Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
847 Array<half_t, N> result;
848 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
849
850 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
851 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
852 __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
853
855for (int i = 0; i < N / 2; ++i) {
856 result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair);
857 }
858
859if (N % 2) {
860 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
861
862 __half d_residual = __hmul(
863 a_residual_ptr[N - 1],
864 reinterpret_cast<__half const &>(rhs));
865
866 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
867 }
868
869 #else
870
872for (int i = 0; i < N; ++i) {
873 result[i] = lhs[i] * rhs;
874 }
875 #endif
876
877return result;
878 }
879 };
880
881 template <int N>
882 struct divides<Array<half_t, N>> {
884 Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
885 Array<half_t, N> result;
886 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
887
888 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
889 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
890 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
891
893for (int i = 0; i < N / 2; ++i) {
894 result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]);
895 }
896
897if (N % 2) {
898 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
899 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
900
901 __half d_residual = __hdiv(
902 a_residual_ptr[N - 1],
903 b_residual_ptr[N - 1]);
904
905 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
906 }
907
908 #else
909
911for (int i = 0; i < N; ++i) {
912 result[i] = lhs[i] / rhs[i];
913 }
914 #endif
915
916return result;
917 }
918
920 Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
921 Array<half_t, N> result;
922 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
923
924 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
925 __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
926 __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
927
929for (int i = 0; i < N / 2; ++i) {
930 result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]);
931 }
932
933if (N % 2) {
934 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
935
936 __half d_residual = __hdiv(
937 reinterpret_cast<__half const &>(lhs),
938 b_residual_ptr[N - 1]);
939
940 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
941 }
942
943 #else
944
946for (int i = 0; i < N; ++i) {
947 result[i] = lhs / rhs[i];
948 }
949 #endif
950
951return result;
952 }
953
955 Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
956 Array<half_t, N> result;
957 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
958
959 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
960 __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
961 __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
962
964for (int i = 0; i < N / 2; ++i) {
965 result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair);
966 }
967
968if (N % 2) {
969 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
970
971 __half d_residual = __hdiv(
972 a_residual_ptr[N - 1],
973 reinterpret_cast<__half const &>(rhs));
974
975 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
976 }
977
978 #else
979
981for (int i = 0; i < N; ++i) {
982 result[i] = lhs[i] / rhs;
983 }
984 #endif
985
986return result;
987 }
988 };
989
990 template <int N>
991 struct negate<Array<half_t, N>> {
993 Array<half_t, N> operator()(Array<half_t, N> const & lhs) const {
994 Array<half_t, N> result;
995 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
996
997 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
998 __half2 const *source_ptr = reinterpret_cast<__half2 const *>(&lhs);
999
1001for (int i = 0; i < N / 2; ++i) {
1002 result_ptr[i] = __hneg2(source_ptr[i]);
1003 }
1004
1005if (N % 2) {
1006half_t x = lhs[N - 1];
1007 __half lhs_val = -reinterpret_cast<__half const &>(x);
1008 result[N - 1] = reinterpret_cast<half_t const &>(lhs_val);
1009 }
1010
1011 #else
1012
1014for (int i = 0; i < N; ++i) {
1015 result[i] = -lhs[i];
1016 }
1017 #endif
1018
1019return result;
1020 }
1021 };
1022
1024 template <int N>
[1025](structcutlass_1_1multiply add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html) struct multiply_add<Array<half_t, N>, Array<half_t, N>, Array<half_t, N>> {
1026
[1028](structcutlass_1_1multiply add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#abc9dd51cad4f2997dae521fad8f5b486) Array<half_t, N> [operator()](structcutlass_1_1multiply__add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#abc9dd51cad4f2997dae521fad8f5b486)(
1029 Array<half_t, N> const &a,
1030 Array<half_t, N> const &b,
1031 Array<half_t, N> const &c) const {
1032
1033 Array<half_t, N> result;
1034 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
1035
1036 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
1037 __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
1038 __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
1039 __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);
1040
1042for (int i = 0; i < N / 2; ++i) {
1043 result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]);
1044 }
1045
1046if (N % 2) {
1047
1048 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
1049 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
1050 __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
1051
1052 __half d_residual = __hfma(
1053 a_residual_ptr[N - 1],
1054 b_residual_ptr[N - 1],
1055 c_residual_ptr[N - 1]);
1056
1057 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
1058 }
1059
1060 #else
1061
1062multiply_add<half_t> op;
1063
1065for (int i = 0; i < N; ++i) {
1066 result[i] = op(a[i], b[i], c[i]);
1067 }
1068 #endif
1069
1070return result;
1071 }
1072
[1074](structcutlass_1_1multiply add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#a7659a559754edb4949d54cc641a5bd01) Array<half_t, N> [operator()](structcutlass_1_1multiply__add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#a7659a559754edb4949d54cc641a5bd01)(
1075half_t const &a,
1076 Array<half_t, N> const &b,
1077 Array<half_t, N> const &c) const {
1078
1079 Array<half_t, N> result;
1080 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
1081
1082 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
1083 __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a));
1084 __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
1085 __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);
1086
1088for (int i = 0; i < N / 2; ++i) {
1089 result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]);
1090 }
1091
1092if (N % 2) {
1093
1094 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
1095 __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
1096 __half d_residual = __hfma(
1097 reinterpret_cast<__half const &>(a),
1098 b_residual_ptr[N - 1],
1099 c_residual_ptr[N - 1]);
1100
1101 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
1102 }
1103
1104 #else
1105
1106multiply_add<half_t> op;
1107
1109for (int i = 0; i < N; ++i) {
1110 result[i] = op(a, b[i], c[i]);
1111 }
1112 #endif
1113
1114return result;
1115 }
1116
[1118](structcutlass_1_1multiply add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#aed1e0930836e1da4c53fcdfb683847a1) Array<half_t, N> [operator()](structcutlass_1_1multiply__add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#aed1e0930836e1da4c53fcdfb683847a1)(
1119 Array<half_t, N> const &a,
1120half_t const &b,
1121 Array<half_t, N> const &c) const {
1122
1123 Array<half_t, N> result;
1124 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
1125
1126 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
1127 __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
1128 __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b));
1129 __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);
1130
1132for (int i = 0; i < N / 2; ++i) {
1133 result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]);
1134 }
1135
1136if (N % 2) {
1137
1138 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
1139 __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
1140
1141 __half d_residual = __hfma(
1142 a_residual_ptr[N - 1],
1143 reinterpret_cast<__half const &>(b),
1144 c_residual_ptr[N - 1]);
1145
1146 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
1147 }
1148
1149 #else
1150
1151multiply_add<half_t> op;
1152
1154for (int i = 0; i < N; ++i) {
1155 result[i] = op(a[i], b, c[i]);
1156 }
1157 #endif
1158
1159return result;
1160 }
1161
[1163](structcutlass_1_1multiply add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#a6b4fe9d3366d389034a53f5ba71bdaee) Array<half_t, N> [operator()](structcutlass_1_1multiply__add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#a6b4fe9d3366d389034a53f5ba71bdaee)(
1164 Array<half_t, N> const &a,
1165 Array<half_t, N> const &b,
1166half_t const &c) const {
1167
1168 Array<half_t, N> result;
1169 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
1170
1171 __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
1172 __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
1173 __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
1174 __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c));
1175
1177for (int i = 0; i < N / 2; ++i) {
1178 result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair);
1179 }
1180
1181if (N % 2) {
1182
1183 __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
1184 __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
1185
1186 __half d_residual = __hfma(
1187 a_residual_ptr[N - 1],
1188 b_residual_ptr[N - 1],
1189 reinterpret_cast<__half const &>(c));
1190
1191 result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
1192 }
1193
1194 #else
1195
1196multiply_add<half_t> op;
1197
1199for (int i = 0; i < N; ++i) {
1200 result[i] = op(a[i], b[i], c);
1201 }
1202 #endif
1203
1204return result;
1205 }
1206 };
1207
1209
1210 } // namespace cutlass
Fused multiply-add.
Definition: functional.h:92
cutlass::minimum< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:351
cutlass::plus< Array< half_t, N > >::operator()
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:578
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE T operator()(T const &a, T const &b, T const &c) const
Definition: functional.h:103
cutlass::divides< Array< half_t, N > >::operator()
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const
Definition: functional.h:955
cutlass::minus< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:383
CUTLASS_HOST_DEVICE float const & imag(cuFloatComplex const &z)
Returns the imaginary part of the complex number.
Definition: complex.h:72
cutlass::minus< Array< half_t, N > >::operator()
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:678
CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const
Definition: functional.h:48
cutlass::maximum< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:269
cutlass::minimum< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:323
cutlass::plus< Array< half_t, N > >::operator()
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:611
Defines a class for using IEEE half-precision floating-point types in host or device code...
cutlass::divides< Array< half_t, N > >::operator()
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:920
cutlass::multiplies< Array< half_t, N > >::operator()
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:778
cutlass::plus< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:191
Definition: functional.h:298
cutlass::maximum< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:255
Definition: functional.h:235
IEEE half-precision floating-point type.
Definition: half.h:126
CUTLASS_HOST_DEVICE T operator()(T lhs) const
Definition: functional.h:85
cutlass::multiplies< Array< half_t, N > >::operator()
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:811
cutlass::plus< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:205
CUTLASS_HOST_DEVICE float const & real(cuFloatComplex const &z)
Returns the real part of the complex number.
Definition: complex.h:56
cutlass::minus< Array< half_t, N > >::operator()
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const
Definition: functional.h:743
cutlass::minimum< float >::operator()
CUTLASS_HOST_DEVICE float operator()(float const &lhs, float const &rhs) const
Definition: functional.h:309
cutlass::minimum< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:337
CUTLASS_HOST_DEVICE T const & imag() const
Accesses the imaginary part of the complex number.
Definition: complex.h:240
cutlass::multiply_add::operator()
CUTLASS_HOST_DEVICE C operator()(A const &a, B const &b, C const &c) const
Definition: functional.h:94
[cutlass::multiply_add< Array< half_t, N >, Array< half_t, N >, Array< half_t, N > >::operator()](structcutlass_1_1multiply add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#a6b4fe9d3366d389034a53f5ba71bdaee)
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &a, Array< half_t, N > const &b, half_t const &c) const
Definition: functional.h:1163
Definition: functional.h:46
cutlass::minus< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:369
CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const
Definition: functional.h:238
cutlass::multiplies< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:415
cutlass::minus< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:397
cutlass::multiplies::operator()
CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const
Definition: functional.h:66
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
cutlass::maximum< float >::operator()
CUTLASS_HOST_DEVICE float operator()(float const &lhs, float const &rhs) const
Definition: functional.h:246
Definition: functional.h:83
CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const
Definition: functional.h:75
cutlass::multiply_add< Array< T, N >, Array< T, N >, Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &a, T const &scalar, Array< T, N > const &c) const
Definition: functional.h:541
cutlass::multiply_add< Array< T, N >, Array< T, N >, Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &a, Array< T, N > const &b, Array< T, N > const &c) const
Definition: functional.h:527
cutlass::multiplies< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:443
Definition: functional.h:64
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
cutlass::multiplies< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:429
cutlass::divides< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:489
cutlass::divides< Array< half_t, N > >::operator()
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:884
Definition: functional.h:73
CUTLASS_HOST_DEVICE T const & real() const
Accesses the real part of the complex number.
Definition: complex.h:232
cutlass::plus< Array< half_t, N > >::operator()
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const
Definition: functional.h:643
Definition: complex.h:92
CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const
Definition: functional.h:301
cutlass::minimum< Array< T, N > >::scalar_op
static CUTLASS_HOST_DEVICE T scalar_op(T const &lhs, T const &rhs)
Definition: functional.h:318
cutlass::multiply_add< T, complex< T >, complex< T > >::operator()
CUTLASS_HOST_DEVICE complex< T > operator()(T const &a, complex< T > const &b, complex< T > const &c) const
Definition: functional.h:164
[cutlass::multiply_add< Array< half_t, N >, Array< half_t, N >, Array< half_t, N > >::operator()](structcutlass_1_1multiply add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#a7659a559754edb4949d54cc641a5bd01)
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &a, Array< half_t, N > const &b, Array< half_t, N > const &c) const
Definition: functional.h:1074
cutlass::multiply_add< complex< T >, T, complex< T > >::operator()
CUTLASS_HOST_DEVICE complex< T > operator()(complex< T > const &a, T const &b, complex< T > const &c) const
Definition: functional.h:142
cutlass::negate< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs) const
Definition: functional.h:508
Fused multiply-add.
Definition: functional.h:101
cutlass::maximum< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:283
cutlass::divides< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:475
CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const
Definition: functional.h:57
[cutlass::multiply_add< Array< half_t, N >, Array< half_t, N >, Array< half_t, N > >::operator()](structcutlass_1_1multiply add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#aed1e0930836e1da4c53fcdfb683847a1)
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &a, half_t const &b, Array< half_t, N > const &c) const
Definition: functional.h:1118
cutlass::plus< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:219
cutlass::minus< Array< half_t, N > >::operator()
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:711
cutlass::multiply_add< Array< T, N >, Array< T, N >, Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &b, Array< T, N > const &c) const
Definition: functional.h:555
cutlass::divides< Array< T, N > >::operator()
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:461
cutlass::negate< Array< half_t, N > >::operator()
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs) const
Definition: functional.h:993
Definition: functional.h:55
cutlass::multiply_add< complex< T >, complex< T >, complex< T > >::operator()
CUTLASS_HOST_DEVICE complex< T > operator()(complex< T > const &a, complex< T > const &b, complex< T > const &c) const
Definition: functional.h:118
cutlass::multiplies< Array< half_t, N > >::operator()
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const
Definition: functional.h:846
Basic include for CUTLASS.
[cutlass::multiply_add< Array< half_t, N >, Array< half_t, N >, Array< half_t, N > >::operator()](structcutlass_1_1multiply add_3_01Array_3_01half t_00_01N_01_4_00_01Array_3_01half__t_00_01N_01adaeadb27c0e4439444709c0eb30963.html#abc9dd51cad4f2997dae521fad8f5b486)
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &a, Array< half_t, N > const &b, Array< half_t, N > const &c) const
Definition: functional.h:1028
<!-- fragment --> <!-- contents --><!-- start footer part -->