docs/gemm_2thread_2mma__sm60_8h_source.html
CUTLASS — CUDA Templates for Linear Algebra Subroutines and Solvers
gemm/thread/mma_sm60.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/cutlass.h"
32 #include "cutlass/tensor_ref.h"
33 #include "cutlass/layout/matrix.h"
34 #include "cutlass/gemm/gemm.h"
35 #include "cutlass/gemm/thread/mma.h"
36 #include "cutlass/functional.h"
37 #include "cutlass/reduction/thread/reduce.h"
38
40
41 namespace cutlass {
42 namespace gemm {
43 namespace thread {
44
46
47 namespace detail {
48
50 template <
52typename Shape,
53
55typename LayoutA,
56
58typename LayoutB,
59
61typename LayoutC,
62
64bool
65 >
67
68
70 // Specialization for NNN //
72
73 template <typename Shape>
75 Shape,
76 layout::ColumnMajor,
79 true
80 > {
81
83 !(Shape::kM % 2),
84"Mma_HFMA2 requires the M dimension to be divisible by 2."
85 );
86
88using FragmentA = Array<half_t, Shape::kMK>;
89
91using FragmentB = Array<half_t, Shape::kKN>;
92
94using FragmentC = Array<half_t, Shape::kMN>;
95
96//
97// Methods
98//
99
102void operator()(
103FragmentC & D,
104FragmentA const & A,
105FragmentB const & B,
106FragmentC const & C) {
107
109 D = C;
110
114 1,
115half_t,
117 half_t,
118 layout::ColumnMajor,
119 half_t,
120 layout::ColumnMajor,
121 arch::OpMultiplyAdd>;
122
123 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
124 Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);
125 Array<half_t, 1> const *ptr_B = reinterpret_cast<Array<half_t, 1> const *>(&B);
126
127Mma mma;
128
130for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
131
133for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
134
136for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
137
138 Array<half_t, 2> tmp;
139 Array<half_t, 2> *ptr_tmp = &tmp;
140 ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m];
141
142 mma(
143 tmp,
144 ptr_A[k*Shape::kM/2 + m],
145 ptr_B[n*Shape::kK + k],
146 tmp);
147
148 ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0];
149 }
150 }
151 }
152 }
153 };
154
156 // Specialization for NNT //
158
159 template <typename Shape>
161 Shape,
162 layout::ColumnMajor,
164layout::RowMajor,
165 true
166 > {
167
168static_assert(
169 !(Shape::kN % 2),
170"Mma_HFMA2 requires the N dimension to be divisible by 2."
171 );
172
174using FragmentA = Array<half_t, Shape::kMK>;
175
177using FragmentB = Array<half_t, Shape::kKN>;
178
180using FragmentC = Array<half_t, Shape::kMN>;
181
182//
183// Methods
184//
185
188void operator()(
189FragmentC & D,
190FragmentA const & A,
191FragmentB const & B,
192FragmentC const & C) {
193
195 D = C;
196
200 1,
201half_t,
203 half_t,
204 layout::ColumnMajor,
205 half_t,
206layout::RowMajor,
207 arch::OpMultiplyAdd>;
208
209 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
210 Array<half_t, 1> const *ptr_A = reinterpret_cast<Array<half_t, 1> const *>(&A);
211 Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);
212
213Mma mma;
214
216for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
217
219for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
220
222for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
223
224 Array<half_t, 2> tmp;
225 Array<half_t, 2> *ptr_tmp = &tmp;
226 ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
227
228 Array<half_t, 2> tmp_B;
229 tmp_B[0] = ptr_B->at(2*n*Shape::kK + k);
230 tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k);
231
232 mma(
233 tmp,
234 ptr_A[k*Shape::kM + m],
235 tmp_B,
236 tmp);
237
238 ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
239 }
240 }
241 }
242 }
243 };
244
245
247 // Specialization for NTN //
249
250 template <typename Shape>
252 Shape,
253 layout::ColumnMajor,
254layout::RowMajor,
256 true
257 > {
258
259static_assert(
260 !(Shape::kM % 2),
261"Mma_HFMA2 requires the GEMM M dimension to be divisible by 2."
262 );
263
265using FragmentA = Array<half_t, Shape::kMK>;
266
268using FragmentB = Array<half_t, Shape::kKN>;
269
271using FragmentC = Array<half_t, Shape::kMN>;
272
273//
274// Methods
275//
276
279void operator()(
280FragmentC & D,
281FragmentA const & A,
282FragmentB const & B,
283FragmentC const & C) {
284
286 D = C;
287
290 1,
291half_t,
293 half_t,
294layout::RowMajor,
295 half_t,
296 layout::ColumnMajor,
297 arch::OpMultiplyAdd>;
298
299 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
300 Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);
301 Array<half_t, 1> const *ptr_B = reinterpret_cast<Array<half_t, 1> const *>(&B);
302
303Mma mma;
304
306for (int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) {
307
309for (int m = 0; m < Shape::kM / Mma::Shape::kM; ++m) {
310
312for (int n = 0; n < Shape::kN / Mma::Shape::kN; ++n) {
313
314 Array<half_t, 2> tmp;
315 Array<half_t, 2> *ptr_tmp = &tmp;
316
317 ptr_tmp[0] = ptr_D[m + n * Shape::kM/2];
318
319 mma(
320 tmp,
321 ptr_A[m + k * Shape::kM/2],
322 ptr_B[k * Shape::kN + n],
323 tmp);
324
325 ptr_D[m + n * Shape::kM/2] = ptr_tmp[0];
326 }
327 }
328 }
329 }
330 };
331
333 // Specialization for NTT //
335
336 template <typename Shape>
338 Shape,
339 layout::ColumnMajor,
340layout::RowMajor,
341layout::RowMajor,
342 true
343 > {
344
345static_assert(
346 !(Shape::kN % 2),
347"Mma_HFMA2 requires the N dimension to be divisible by 2."
348 );
349
351using FragmentA = Array<half_t, Shape::kMK>;
352
354using FragmentB = Array<half_t, Shape::kKN>;
355
357using FragmentC = Array<half_t, Shape::kMN>;
358
359//
360// Methods
361//
362
365void operator()(
366FragmentC & D,
367FragmentA const & A,
368FragmentB const & B,
369FragmentC const & C) {
370
372 D = C;
373
377 1,
378half_t,
380 half_t,
381layout::RowMajor,
382 half_t,
383 layout::RowMajor,
384 arch::OpMultiplyAdd>;
385
386 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
387 Array<half_t, 1> const *ptr_A = reinterpret_cast<Array<half_t, 1> const *>(&A);
388 Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);
389
390Mma mma;
391
393for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
394
396for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
397
399for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
400
401 Array<half_t, 2> tmp;
402 Array<half_t, 2> *ptr_tmp = &tmp;
403 ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
404
405 mma(
406 tmp,
407 ptr_A[k*Shape::kM + m],
408 ptr_B[k*Shape::kN/2 + n],
409 tmp);
410
411 ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
412 }
413 }
414 }
415 }
416 };
417
418
420 // Specialization for TNN //
422
423 template <typename Shape>
425 Shape,
426 layout::RowMajor,
429 true
430 > {
431
432static_assert(
433 !(Shape::kM % 2),
434"Mma_HFMA2 requires the M dimension to be divisible by 2."
435 );
436
438using FragmentA = Array<half_t, Shape::kMK>;
439
441using FragmentB = Array<half_t, Shape::kKN>;
442
444using FragmentC = Array<half_t, Shape::kMN>;
445
446//
447// Methods
448//
449
452void operator()(
453FragmentC & D,
454FragmentA const & A,
455FragmentB const & B,
456FragmentC const & C) {
457
459 D = C;
460
464 1,
465half_t,
466layout::RowMajor,
467 half_t,
469 half_t,
470 layout::ColumnMajor,
471 arch::OpMultiplyAdd>;
472
473 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
474 Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);
475 Array<half_t, 1> const *ptr_B = reinterpret_cast<Array<half_t, 1> const *>(&B);
476
477Mma mma;
478
480for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
481
483for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
484
486for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
487
488 Array<half_t, 2> tmp;
489 Array<half_t, 2> *ptr_tmp = &tmp;
490 ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m];
491
492 Array<half_t, 2> tmp_A;
493 tmp_A[0] = ptr_A->at(2*m*Shape::kK + k);
494 tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k);
495
496 mma(
497 tmp,
498 tmp_A,
499 ptr_B[n*Shape::kK + k],
500 tmp);
501
502 ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0];
503 }
504 }
505 }
506 }
507 };
508
510 // Specialization for TNT //
512
513 template <typename Shape>
515 Shape,
516 layout::RowMajor,
518layout::RowMajor,
519 true
520 > {
521
522static_assert(
523 !(Shape::kN % 2),
524"Mma_HFMA2 requires the N dimension to be divisible by 2."
525 );
526
528using FragmentA = Array<half_t, Shape::kMK>;
529
531using FragmentB = Array<half_t, Shape::kKN>;
532
534using FragmentC = Array<half_t, Shape::kMN>;
535
536//
537// Methods
538//
539
542void operator()(
543FragmentC & D,
544FragmentA const & A,
545FragmentB const & B,
546FragmentC const & C) {
547
549 D = C;
550
554 1,
555half_t,
556layout::RowMajor,
557 half_t,
559 half_t,
560 layout::RowMajor,
561 arch::OpMultiplyAdd>;
562
563 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
564 Array<half_t, 1> const *ptr_A = reinterpret_cast<Array<half_t, 1> const *>(&A);
565 Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);
566
567Mma mma;
568
570for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
571
573for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
574
576for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
577
578 Array<half_t, 2> tmp;
579 Array<half_t, 2> *ptr_tmp = &tmp;
580 ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
581
582 Array<half_t, 2> tmp_B;
583 tmp_B[0] = ptr_B->at(2*n*Shape::kK + k);
584 tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k);
585
586 mma(
587 tmp,
588 ptr_A[m*Shape::kK + k],
589 tmp_B,
590 tmp);
591
592 ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
593 }
594 }
595 }
596 }
597 };
598
600 // Specialization for TTN //
602
603 template <typename Shape>
605 Shape,
606 layout::RowMajor,
607layout::RowMajor,
609 true
610 > {
611
612static_assert(
613 !(Shape::kM % 2),
614"Mma_HFMA2 requires the M dimension to be divisible by 2."
615 );
616
618using FragmentA = Array<half_t, Shape::kMK>;
619
621using FragmentB = Array<half_t, Shape::kKN>;
622
624using FragmentC = Array<half_t, Shape::kMN>;
625
626//
627// Methods
628//
629
632void operator()(
633FragmentC & D,
634FragmentA const & A,
635FragmentB const & B,
636FragmentC const & C) {
637
639 D = C;
640
644 1,
645half_t,
646layout::RowMajor,
647 half_t,
648 layout::RowMajor,
649 half_t,
651 arch::OpMultiplyAdd>;
652
653 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
654 Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);
655 Array<half_t, 1> const *ptr_B = reinterpret_cast<Array<half_t, 1> const *>(&B);
656
657Mma mma;
658
660for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
661
663for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
664
666for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
667
668 Array<half_t, 2> tmp;
669 Array<half_t, 2> *ptr_tmp = &tmp;
670 ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m];
671
672 Array<half_t, 2> tmp_A;
673 tmp_A[0] = ptr_A->at(2*m*Shape::kK + k);
674 tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k);
675
676 mma(
677 tmp,
678 tmp_A,
679 ptr_B[k*Shape::kN + n],
680 tmp);
681
682 ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0];
683 }
684 }
685 }
686 }
687 };
688
689
691 // Specialization for TTT //
693
694 template <typename Shape>
696 Shape,
697 layout::RowMajor,
698layout::RowMajor,
699layout::RowMajor,
700 true
701 > {
702
703static_assert(
704 !(Shape::kN % 2),
705"Mma_HFMA2 requires the N dimension to be divisible by 2."
706 );
707
709using FragmentA = Array<half_t, Shape::kMK>;
710
712using FragmentB = Array<half_t, Shape::kKN>;
713
715using FragmentC = Array<half_t, Shape::kMN>;
716
717//
718// Methods
719//
720
723void operator()(
724FragmentC & D,
725FragmentA const & A,
726FragmentB const & B,
727FragmentC const & C) {
728
730 D = C;
731
735 1,
736half_t,
737layout::RowMajor,
738 half_t,
739 layout::RowMajor,
740 half_t,
741 layout::RowMajor,
742 arch::OpMultiplyAdd>;
743
744 Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
745 Array<half_t, 1> const *ptr_A = reinterpret_cast<Array<half_t, 1> const *>(&A);
746 Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);
747
748Mma mma;
749
751for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
752
754for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
755
757for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
758
759 Array<half_t, 2> tmp;
760 Array<half_t, 2> *ptr_tmp = &tmp;
761 ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
762
763 mma(
764 tmp,
765 ptr_A[m*Shape::kK + k],
766 ptr_B[k*Shape::kN/2 + n],
767 tmp);
768
769 ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
770 }
771 }
772 }
773 }
774 };
775
777 // Specialization for TNT + Inner Product or 1x1x2K + LayoutC = T //
779
780 template <typename Shape, typename LayoutA, typename LayoutB>
782 Shape,
783 LayoutA,
784 LayoutB,
785 layout::RowMajor,
786 false
787 > {
788
789static_assert(
790 !(Shape::kK % 2),
791"Mma_HFMA2 requires the K dimension to be divisible by 2."
792 );
793
795using FragmentA = Array<half_t, Shape::kMK>;
796
798using FragmentB = Array<half_t, Shape::kKN>;
799
801using FragmentC = Array<half_t, Shape::kMN>;
802
803//
804// Methods
805//
806
809void operator()(
810FragmentC & D,
811FragmentA const & A,
812FragmentB const & B,
813FragmentC const & C) {
814
816 D = C;
817
819using GemmShape = gemm::GemmShape<1,1,2>;
820
821 Array<half_t, 1> *ptr_D = reinterpret_cast<Array<half_t, 1> *>(&D);
822 Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);
823 Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);
824
825// Inner product is calculated using MACs, followed by final reduction
826multiply_add<Array<half_t, 2>> mac;
827cutlass::reduction::thread::Reduce< plus<half_t>, Array<half_t, 2> > reduce;
828
830for(auto n=0; n < Shape::kN / GemmShape::kN; n++){
831
833for(auto m=0; m < Shape::kM / GemmShape::kM; m++){
834
835 Array<half_t, 2> tmp_C;
836 tmp_C.clear();
837 Array<half_t, 1> *ptr_tmp_C = reinterpret_cast<Array<half_t, 1> *>(&tmp_C);
838 ptr_tmp_C[0] = ptr_D[n*Shape::kM + m];
839
841for(auto k=0; k < Shape::kK / GemmShape::kK; k++){
842 tmp_C = mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C);
843 }
844
845 Array<half_t, 1> res;
846 Array<half_t, 1> *ptr_res = &res;
847 res = reduce(tmp_C);
848
849 ptr_D[m*Shape::kN + n] = ptr_res[0];
850 }
851 }
852 }
853 };
854
856 // Specialization for TNN + Inner Product or 1x1x2K + LayoutC = N //
858
859 template <typename Shape, typename LayoutA, typename LayoutB>
861 Shape,
862 LayoutA,
863 LayoutB,
864 layout::ColumnMajor,
865 false
866 > {
867
868static_assert(
869 !(Shape::kK % 2),
870"Mma_HFMA2 requires the K dimension to be divisible by 2."
871 );
872
874using FragmentA = Array<half_t, Shape::kMK>;
875
877using FragmentB = Array<half_t, Shape::kKN>;
878
880using FragmentC = Array<half_t, Shape::kMN>;
881
882//
883// Methods
884//
885
888void operator()(
889FragmentC & D,
890FragmentA const & A,
891FragmentB const & B,
892FragmentC const & C) {
893
895 D = C;
896
898using GemmShape= gemm::GemmShape<1,1,2>;
899
900 Array<half_t, 1> *ptr_D = reinterpret_cast<Array<half_t, 1> *>(&D);
901 Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);
902 Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);
903
904// Inner product is calculated using MACs, followed by final reduction
905multiply_add<Array<half_t, 2>> mac;
906cutlass::reduction::thread::Reduce< plus<half_t>, Array<half_t, 2> > reduce;
907
909for(auto n=0; n < Shape::kN / GemmShape::kN; n++){
910
912for(auto m=0; m < Shape::kM / GemmShape::kM; m++){
913
914 Array<half_t, 2> tmp_C;
915 tmp_C.clear();
916 Array<half_t, 1> *ptr_tmp_C = reinterpret_cast<Array<half_t, 1> *>(&tmp_C);
917 ptr_tmp_C[0] = ptr_D[n*Shape::kM + m];
918
920for(auto k=0; k < Shape::kK / GemmShape::kK; k++){
921
922 tmp_C = mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C);
923
924 }
925
926 Array<half_t, 1> res;
927 Array<half_t, 1> *ptr_res = &res;
928 res = reduce(tmp_C);
929
930 ptr_D[n*Shape::kM + m] = ptr_res[0];
931 }
932 }
933 }
934 };
935
936 } // namespace detail
937
939
941 template <
943typename Shape_, typename LayoutA, typename LayoutB, typename LayoutC
944 >
[945](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html) struct Mma<
946 Shape_,
947half_t,
948 LayoutA,
949half_t,
950 LayoutB,
951half_t,
952 LayoutC,
953 arch::OpMultiplyAdd
954 > {
955
[957](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a041bfce41e4c95a7a67dc4156173e1f4)using [Shape](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01half t_00_01LayoutA_00_01half t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a041bfce41e4c95a7a67dc4156173e1f4) = Shape_;
958
[960](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#abc237ebaf010ac6a3e91a93830772707)using ElementA = half_t;
961
[963](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a4b52c217fcddfa6f6ec603ed0caff3f0)using ElementB = half_t;
964
[966](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a836cdbd43f3a01a930049af70f8009bd)using ElementC = half_t;
967
[969](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a62084aaf63a7538ba29de4c60d64d133)using [Operator](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01half t_00_01LayoutA_00_01half t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a62084aaf63a7538ba29de4c60d64d133) = arch::OpMultiplyAdd;
970
[972](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a64b2cf33786247c4acd872fb8856abd5)using [FragmentA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01half t_00_01LayoutA_00_01half t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a64b2cf33786247c4acd872fb8856abd5) = Array<ElementA, Shape::kMK>;
973
[975](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a5a00c6305fd345f12f9469b790e99f12)using [FragmentB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01half t_00_01LayoutA_00_01half t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a5a00c6305fd345f12f9469b790e99f12) = Array<ElementB, Shape::kKN>;
976
[978](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#ac5f3a1ad86714fdce4af1e8a4738a4f7)using [FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01half t_00_01LayoutA_00_01half t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#ac5f3a1ad86714fdce4af1e8a4738a4f7) = Array<ElementC, Shape::kMN>;
979
980//
981// Methods
982//
983
[986](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a7eb69f25c0b516fda203957a230df3ee)void [operator()](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01half t_00_01LayoutA_00_01half t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a7eb69f25c0b516fda203957a230df3ee)(
987[FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#ac5f3a1ad86714fdce4af1e8a4738a4f7) & D,
988[FragmentA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a64b2cf33786247c4acd872fb8856abd5) const & A,
989[FragmentB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a5a00c6305fd345f12f9469b790e99f12) const & B,
990[FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#ac5f3a1ad86714fdce4af1e8a4738a4f7) const & C) {
991
992constexpr bool a_row_major = platform::is_same< LayoutA, layout::RowMajor>::value;
993constexpr bool b_column_major = platform::is_same< LayoutB, layout::ColumnMajor>::value;
994constexpr bool c_row_major = platform::is_same< LayoutC, layout::RowMajor>::value;
995constexpr bool c_column_major = platform::is_same< LayoutC, layout::ColumnMajor>::value;
996
997constexpr bool m_mod2 = !(Shape::kM % 2);
998constexpr bool n_mod2 = !(Shape::kN % 2);
999constexpr bool k_mod2 = !(Shape::kK % 2);
1000
1001// HFMA based MMA optimizations are of 2 types :
1002// 1. Inner product
1003// 2. Outer product
1004// It is chosen based on LayoutC (for outer product gemm) or
1005// Using LayoutA and LayoutB or shape=1x1x2K (for inner product gemms)
1006// If all fails, we choose the generic MMA
1007constexpr bool use_outer_prod = (c_column_major && m_mod2) || (c_row_major && n_mod2);
1008constexpr bool use_inner_prod = (a_row_major && b_column_major && k_mod2) || (Shape::kM==1 && Shape::kN==1 && k_mod2);
1009constexpr bool use_optimized = (use_outer_prod || use_inner_prod);
1010
1011typename platform::conditional< use_optimized,
1012detail::Mma_HFMA2<Shape, LayoutA, LayoutB, LayoutC, use_outer_prod>,
1013MmaGeneric <Shape, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator>
1014 >::type mma;
1015
1016 mma(D, A, B, C);
1017
1018 }
1019 };
1020
1022
1023 namespace detail {
1024
1026template <
1027typename LayoutA,
1029typename LayoutB>
[1030](structcutlass_1_1gemm_1_1thread_1_1detail_1_1EnableMma Crow SM60.html)struct [EnableMma_Crow_SM60](structcutlass_1_1gemm_1_1thread_1_1detail_1_1EnableMma Crow SM60.html) {
1031
[1032](structcutlass_1_1gemm_1_1thread_1_1detail_1_1EnableMma Crow SM60.html#a8ec734b2126bd5147abafee8a3b7be70)static bool const kIsConventionalLayout =
1033 (platform::is_same<LayoutA, layout::RowMajor>::value ||
1034platform::is_same<LayoutA, layout::ColumnMajor>::value) &&
1035 (platform::is_same<LayoutB, layout::RowMajor>::value ||
1036platform::is_same<LayoutB, layout::ColumnMajor>::value);
1037
[1038](structcutlass_1_1gemm_1_1thread_1_1detail_1_1EnableMma Crow SM60.html#a2efb4c6abab3bfc29c0d58df8ccc0fd3)static bool const value = kIsConventionalLayout;
1039 };
1040 };
1041
1043
1045 template <
1047typename Shape_,
1048typename LayoutA_,
1049typename LayoutB_
1050 >
[1051](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html) struct Mma<
1052 Shape_,
1053half_t,
1054 LayoutA_,
1055half_t,
1056 LayoutB_,
1057half_t,
1058 layout::RowMajor,
1059 arch::OpMultiplyAdd,
1060 typename platform::enable_if<detail::EnableMma_Crow_SM60<
1061 LayoutA_,
1062 LayoutB_
1063 >::value>::type>{
1064
[1065](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a951f25ff3bb7a76bac1f867ee21c657f)using [Shape](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a951f25ff3bb7a76bac1f867ee21c657f) = Shape_;
[1066](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a7ffe7f427ffce1c269587417e4fed240)using ElementA = half_t;
[1067](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a0975c18cc4a9d376011858c6dbf740d0)using LayoutA = LayoutA_;
[1068](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a962acba07bc680b70ee1b08732d2516f)using ElementB = half_t;
[1069](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a28b637c1f311310a27b39c44e89e698e)using LayoutB = LayoutB_;
[1070](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#aff9afb3fc630bd0bdb35de1b402c65fa)using ElementC = half_t;
[1071](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a397dfb5a622d1ebe47177825194a03a9)using LayoutC = layout::RowMajor;
[1072](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#af96ae215c5f273447ed44baa1315ffcf)using [Operator](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#af96ae215c5f273447ed44baa1315ffcf) = arch::OpMultiplyAdd;
1073
1074using TransposeMma = Mma<
1075GemmShapeTranspose<Shape>,
1076half_t,
1077typename layout::LayoutTranspose<LayoutB>::type,
1078 half_t,
1079typename layout::LayoutTranspose<LayoutA>::type,
1080 half_t,
1081layout::ColumnMajor,
1082 arch::OpMultiplyAdd,
[1083](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a2acc2e5fb14c4e62ea997d80402730c5)bool>;
1084
[1085](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#adbd6a51a9e477d917f5739230a023524)using [FragmentA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#adbd6a51a9e477d917f5739230a023524) = Array<ElementA, Shape::kMK>;
[1086](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a687f0bd7056ea8ff518bfed26f027e4f)using [FragmentB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a687f0bd7056ea8ff518bfed26f027e4f) = Array<ElementB, Shape::kKN>;
[1087](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a0382879142faec5a4b6190869fc6187f)using [FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a0382879142faec5a4b6190869fc6187f) = Array<ElementC, Shape::kMN>;
1088
[1090](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a72fad6edd8b029407aad12fb22937358)void [operator()](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a72fad6edd8b029407aad12fb22937358)(
1091[FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a0382879142faec5a4b6190869fc6187f) & D,
1092[FragmentA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#adbd6a51a9e477d917f5739230a023524) const & A,
1093[FragmentB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a687f0bd7056ea8ff518bfed26f027e4f) const & B,
1094[FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a0382879142faec5a4b6190869fc6187f) const & C) {
1095
1096TransposeMma mma;
1097
1098 mma(D, B, A, C);
1099 }
1100 };
1101
1103
1104 } // namespace thread
1105 } // namespace gemm
1106 } // namespace cutlass
1107
Fused multiply-add.
Definition: functional.h:92
[cutlass::gemm::thread::detail::EnableMma_Crow_SM60](structcutlass_1_1gemm_1_1thread_1_1detail_1_1EnableMma Crow SM60.html)
Determines whether to enable thread::Gemm<> specializations compatible with SM50. ...
Definition: gemm/thread/mma_sm60.h:1030
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:801
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:271
Definition: aligned_buffer.h:35
#define constexpr
Definition: platform.h:137
Defines a structure containing strides, bounds, and a pointer to tensor data.
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:94
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:528
std::is_same (false specialization)
Definition: platform.h:394
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:809
cutlass::gemm::thread::detail::Mma_HFMA2
Structure to compute the matrix product for HFMA.
Definition: gemm/thread/mma_sm60.h:66
[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA_, half_t, LayoutB_, half_t, layout::RowMajor, arch::OpMultiplyAdd, typename platform::enable_if< detail::EnableMma_Crow_SM60< LayoutA_, LayoutB_ >::value >::type >::FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a0382879142faec5a4b6190869fc6187f)
Array< ElementC, Shape::kMN > FragmentC
Definition: gemm/thread/mma_sm60.h:1087
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:441
IEEE half-precision floating-point type.
Definition: half.h:126
Defines common types used for all GEMM-like operators.
[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA_, half_t, LayoutB_, half_t, layout::RowMajor, arch::OpMultiplyAdd, typename platform::enable_if< detail::EnableMma_Crow_SM60< LayoutA_, LayoutB_ >::value >::type >::operator()](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a72fad6edd8b029407aad12fb22937358)
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Definition: gemm/thread/mma_sm60.h:1090
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:444
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:438
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:357
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:102
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:723
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:632
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:174
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:712
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:624
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA_, half_t, LayoutB_, half_t, layout::RowMajor, arch::OpMultiplyAdd, typename platform::enable_if< detail::EnableMma_Crow_SM60< LayoutA_, LayoutB_ >::value >::type >::Operator](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#af96ae215c5f273447ed44baa1315ffcf)
arch::OpMultiplyAdd Operator
Definition: gemm/thread/mma_sm60.h:1072
static int const kK
Definition: include/cutlass/gemm/gemm.h:60
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:177
[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA, half_t, LayoutB, half_t, LayoutC, arch::OpMultiplyAdd >::FragmentB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a5a00c6305fd345f12f9469b790e99f12)
Array< ElementB, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:975
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:91
[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA_, half_t, LayoutB_, half_t, layout::RowMajor, arch::OpMultiplyAdd, typename platform::enable_if< detail::EnableMma_Crow_SM60< LayoutA_, LayoutB_ >::value >::type >::FragmentB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a687f0bd7056ea8ff518bfed26f027e4f)
Array< ElementB, Shape::kKN > FragmentB
Definition: gemm/thread/mma_sm60.h:1086
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:365
[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA, half_t, LayoutB, half_t, LayoutC, arch::OpMultiplyAdd >::Operator](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a62084aaf63a7538ba29de4c60d64d133)
arch::OpMultiplyAdd Operator
Underlying mathematical operator.
Definition: gemm/thread/mma_sm60.h:969
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:709
cutlass::layout::LayoutTranspose
Defines transposes of matrix layouts.
Definition: layout/matrix.h:921
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:531
cutlass::gemm::thread::MmaGeneric
Template that handles all packed matrix layouts.
Definition: gemm/thread/mma_sm50.h:65
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:888
Defines basic thread level reduction with specializations for Array<T, N>.
[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA, half_t, LayoutB, half_t, LayoutC, arch::OpMultiplyAdd >::FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#ac5f3a1ad86714fdce4af1e8a4738a4f7)
Array< ElementC, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:978
std::enable_if (true specialization)
Definition: platform.h:315
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:188
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Templates exposing architecture support for warp-level multiply-add operations.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
cutlass::platform::conditional
std::conditional (true specialization)
Definition: platform.h:325
#define static_assert(__e, __m)
Definition: platform.h:153
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:265
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:88
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:880
[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA, half_t, LayoutB, half_t, LayoutC, arch::OpMultiplyAdd >::operator()](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a7eb69f25c0b516fda203957a230df3ee)
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:986
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:795
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:351
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:452
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:621
[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA, half_t, LayoutB, half_t, LayoutC, arch::OpMultiplyAdd >::Shape](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a041bfce41e4c95a7a67dc4156173e1f4)
Shape_ Shape
Size of the Gemm problem - concept: gemm::GemmShape<>
Definition: gemm/thread/mma_sm60.h:957
Structure to compute the matrix product.
Definition: gemm/thread/mma.h:66
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:877
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:715
Defines layout functions used by TensorRef and derived classes.
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:534
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:180
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:798
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:618
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:268
Matrix multiply-add operation.
Definition: arch/mma.h:92
[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA, half_t, LayoutB, half_t, LayoutC, arch::OpMultiplyAdd >::FragmentA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA_00_01half__t_00_01L066c9d2371712cdf0cac099ca9bcc578.html#a64b2cf33786247c4acd872fb8856abd5)
Array< ElementA, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:972
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:354
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:542
Basic include for CUTLASS.
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:279
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA_, half_t, LayoutB_, half_t, layout::RowMajor, arch::OpMultiplyAdd, typename platform::enable_if< detail::EnableMma_Crow_SM60< LayoutA_, LayoutB_ >::value >::type >::FragmentA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#adbd6a51a9e477d917f5739230a023524)
Array< ElementA, Shape::kMK > FragmentA
Definition: gemm/thread/mma_sm60.h:1085
cutlass::reduction::thread::Reduce
Structure to compute the thread level reduction.
Definition: reduce.h:43
CUTLASS_HOST_DEVICE Array< T, N > mac(Array< T, N > const &a, Array< T, N > const &b, Array< T, N > const &c)
Definition: simd.h:84
[cutlass::gemm::thread::Mma< Shape_, half_t, LayoutA_, half_t, LayoutB_, half_t, layout::RowMajor, arch::OpMultiplyAdd, typename platform::enable_if< detail::EnableMma_Crow_SM60< LayoutA_, LayoutB_ >::value >::type >::Shape](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape _00_01half t_00_01LayoutA _00_01half t_00_088f0e99e501b6012297eb30b4e89bcea.html#a951f25ff3bb7a76bac1f867ee21c657f)
Shape_ Shape
Definition: gemm/thread/mma_sm60.h:1065
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:874
static int const kN
Definition: include/cutlass/gemm/gemm.h:59
Generated by 1.8.11