docs/mma__sm75_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
<!-- end header part --><!-- Generated by Doxygen 1.8.11 --> <input type="text" id="MSearchField" value="Search" accesskey="S" onfocus="searchBox.OnSearchFieldFocus(true)" onblur="searchBox.OnSearchFieldFocus(false)" onkeyup="searchBox.OnSearchFieldChange(event)"> <!-- window showing the filter options --> <!-- iframe showing the search results (closed by default) --> <!-- top -->mma_sm75.h
<!--header-->Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include <assert.h>
32
33 #include "cutlass/arch/wmma.h"
34
35 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
36 // CUDA Toolkit includes for nvcuda::wmma needed for binarized matrix multiply.
37 #include <mma.h>
38 #include "cutlass/wmma_array.h"
39 #endif
40
41 // CUTLASS includes
42 #include "cutlass/arch/mma.h"
43 #include "cutlass/layout/matrix.h"
44 #include "cutlass/numeric_types.h"
45
47
48 #if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))
49
50 #define CUTLASS_ARCH_MMA_SM75_SUPPORTED 1
51
52 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
53 #define CUTLASS_ARCH_MMA_SM75_ENABLED
54 #endif
55 #endif
56
58
59 namespace cutlass {
60 namespace arch {
61
63 //
64 // Matrix Multiply 1688 - FP16 accumulation
65 //
67
69 template <>
71 gemm::GemmShape<16, 8, 8>,
72 32,
73half_t,
75half_t,
77half_t,
79 OpMultiplyAdd> {
80
81using Shape = gemm::GemmShape<16, 8, 8>;
82
84using LayoutA = layout::RowMajor;
85using FragmentA = Array<half_t, 4>;
86
88using LayoutB = layout::ColumnMajor;
89using FragmentB = Array<half_t, 2>;
90
92using LayoutC = layout::RowMajor;
93using FragmentC = Array<half_t, 4>;
94
95using Operator = OpMultiplyAdd;
96
98void operator()(
99FragmentC &d,
100FragmentA const &a,
101FragmentB const &b,
102FragmentC const &c
103 ) const {
104
105 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
106
107unsigned const *A = reinterpret_cast<unsigned const *>(&a);
108unsigned const *B = reinterpret_cast<unsigned const *>(&b);
109unsigned const *C = reinterpret_cast<unsigned const *>(&c);
110unsigned *D = reinterpret_cast<unsigned *>(&d);
111
112asm volatile(
113"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
114 : "=r"(D[0]), "=r"(D[1])
115 : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1]));
116
117 #else
118 assert(0);
119 #endif
120 }
121 };
122
124 //
125 // Matrix Multiply 1688 - FP32 accumulation
126 //
128
130 template <>
132 gemm::GemmShape<16, 8, 8>,
133 32,
134half_t,
135layout::RowMajor,
136half_t,
138 float,
139layout::RowMajor,
140 OpMultiplyAdd> {
141
142using Shape = gemm::GemmShape<16, 8, 8>;
143
145using LayoutA = layout::RowMajor;
146using FragmentA = Array<half_t, 4>;
147
149using LayoutB = layout::ColumnMajor;
150using FragmentB = Array<half_t, 2>;
151
153using LayoutC = layout::RowMajor;
154using FragmentC = Array<float, 4>;
155
156using Operator = OpMultiplyAdd;
157
160void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
161FragmentC const &c) const {
162
163 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
164
165unsigned const *A = reinterpret_cast<unsigned const *>(&a);
166unsigned const *B = reinterpret_cast<unsigned const *>(&b);
167float const *C = reinterpret_cast<float const *>(&c);
168float *D = reinterpret_cast<float *>(&d);
169
170asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
171 : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
172 :
173"r"(A[0]), "r"(A[1]),
174"r"(B[0]),
175"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
176 );
177
178 #else
179 assert(0);
180 #endif
181 }
182 };
183
185 //
186 // Integer matrix multiply .8816 (8b)
187 //
189
191 template <>
193 gemm::GemmShape<8, 8, 16>,
194 32,
195 int8_t,
196layout::RowMajor,
197 int8_t,
199 int,
200layout::RowMajor,
201 OpMultiplyAdd> {
202
203using Shape = gemm::GemmShape<8, 8, 16>;
204
206using LayoutA = layout::RowMajor;
207using FragmentA = Array<int8_t, 4>;
208
210using LayoutB = layout::ColumnMajor;
211using FragmentB = Array<int8_t, 4>;
212
214using LayoutC = layout::RowMajor;
215using FragmentC = Array<int, 2>;
216
217using Operator = OpMultiplyAdd;
218
221void operator()(
222FragmentC &d,
223FragmentA const &a,
224FragmentB const &b,
225FragmentC const &c
226 ) const {
227
228 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
229
230unsigned const & A = reinterpret_cast<unsigned const &>(a);
231unsigned const & B = reinterpret_cast<unsigned const &>(b);
232
233int const *C = reinterpret_cast<int const *>(&c);
234int *D = reinterpret_cast<int *>(&d);
235
236asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
237 : "=r"(D[0]), "=r"(D[1])
238 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
239
240 #else
241 assert(0);
242 #endif
243 }
244 };
245
247 template <>
249 gemm::GemmShape<8, 8, 16>,
250 32,
251 uint8_t,
252layout::RowMajor,
253 int8_t,
255 int,
256layout::RowMajor,
257 OpMultiplyAdd> {
258
259using Shape = gemm::GemmShape<8, 8, 16>;
260
262using LayoutA = layout::RowMajor;
263using FragmentA = Array<uint8_t, 4>;
264
266using LayoutB = layout::ColumnMajor;
267using FragmentB = Array<int8_t, 4>;
268
270using LayoutC = layout::RowMajor;
271using FragmentC = Array<int, 2>;
272
273using Operator = OpMultiplyAdd;
274
277void operator()(
278FragmentC &d,
279FragmentA const &a,
280FragmentB const &b,
281FragmentC const &c
282 ) const {
283
284 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
285
286unsigned const & A = reinterpret_cast<unsigned const &>(a);
287unsigned const & B = reinterpret_cast<unsigned const &>(b);
288
289int const *C = reinterpret_cast<int const *>(&c);
290int *D = reinterpret_cast<int *>(&d);
291
292asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
293 : "=r"(D[0]), "=r"(D[1])
294 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
295
296 #else
297 assert(0);
298 #endif
299 }
300 };
301
303 template <>
305 gemm::GemmShape<8, 8, 16>,
306 32,
307 int8_t,
308layout::RowMajor,
309 uint8_t,
311 int,
312layout::RowMajor,
313 OpMultiplyAdd> {
314
315using Shape = gemm::GemmShape<8, 8, 16>;
316
318using LayoutA = layout::RowMajor;
319using FragmentA = Array<int8_t, 4>;
320
322using LayoutB = layout::ColumnMajor;
323using FragmentB = Array<uint8_t, 4>;
324
326using LayoutC = layout::RowMajor;
327using FragmentC = Array<int, 2>;
328
329using Operator = OpMultiplyAdd;
330
333void operator()(
334FragmentC &d,
335FragmentA const &a,
336FragmentB const &b,
337FragmentC const &c
338 ) const {
339
340 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
341
342unsigned const & A = reinterpret_cast<unsigned const &>(a);
343unsigned const & B = reinterpret_cast<unsigned const &>(b);
344
345int const *C = reinterpret_cast<int const *>(&c);
346int *D = reinterpret_cast<int *>(&d);
347
348asm volatile("mma.sync.aligned.m8n8k16.row.col.s8.u8 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
349 : "=r"(D[0]), "=r"(D[1])
350 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
351
352
353 #else
354 assert(0);
355 #endif
356 }
357 };
358
360 template <>
362 gemm::GemmShape<8, 8, 16>,
363 32,
364 uint8_t,
365layout::RowMajor,
366 uint8_t,
368 int,
369layout::RowMajor,
370 OpMultiplyAdd> {
371
372using Shape = gemm::GemmShape<8, 8, 16>;
373
375using LayoutA = layout::RowMajor;
376using FragmentA = Array<uint8_t, 4>;
377
379using LayoutB = layout::ColumnMajor;
380using FragmentB = Array<uint8_t, 4>;
381
383using LayoutC = layout::RowMajor;
384using FragmentC = Array<int, 2>;
385
386using Operator = OpMultiplyAdd;
387
390void operator()(
391FragmentC &d,
392FragmentA const &a,
393FragmentB const &b,
394FragmentC const &c
395 ) const {
396
397 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
398
399unsigned const & A = reinterpret_cast<unsigned const &>(a);
400unsigned const & B = reinterpret_cast<unsigned const &>(b);
401
402int const *C = reinterpret_cast<int const *>(&c);
403int *D = reinterpret_cast<int *>(&d);
404
405asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
406 : "=r"(D[0]), "=r"(D[1])
407 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
408
409 #else
410 assert(0);
411 #endif
412 }
413 };
414
416 //
417 // Integer matrix multiply (8b) with SATURATE
418 //
420
422 template <>
424 gemm::GemmShape<8,8,16>,
425 32,
426 int8_t,
427layout::RowMajor,
428 int8_t,
430 int,
431layout::RowMajor,
432 OpMultiplyAddSaturate> {
433
434using Shape = gemm::GemmShape<8,8,16>;
435
437using LayoutA = layout::RowMajor;
438using FragmentA = Array<int8_t, 4>;
439
441using LayoutB = layout::ColumnMajor;
442using FragmentB = Array<int8_t, 4>;
443
445using LayoutC = layout::RowMajor;
446using FragmentC = Array<int, 2>;
447
448using Operator = OpMultiplyAddSaturate;
449
452void operator()(
453FragmentC &d,
454FragmentA const &a,
455FragmentB const &b,
456FragmentC const &c
457 ) const {
458
459 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
460
461unsigned const & A = reinterpret_cast<unsigned const &>(a);
462unsigned const & B = reinterpret_cast<unsigned const &>(b);
463
464int const *C = reinterpret_cast<int const *>(&c);
465int *D = reinterpret_cast<int *>(&d);
466
467asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
468 : "=r"(D[0]), "=r"(D[1])
469 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
470
471 #else
472 assert(0);
473 #endif
474 }
475 };
476
478 template <>
480 gemm::GemmShape<8,8,16>,
481 32,
482 uint8_t,
483layout::RowMajor,
484 int8_t,
486 int,
487layout::RowMajor,
488 OpMultiplyAddSaturate> {
489
490using Shape = gemm::GemmShape<8,8,16>;
491
493using LayoutA = layout::RowMajor;
494using FragmentA = Array<uint8_t, 4>;
495
497using LayoutB = layout::ColumnMajor;
498using FragmentB = Array<int8_t, 4>;
499
501using LayoutC = layout::RowMajor;
502using FragmentC = Array<int, 2>;
503
504using Operator = OpMultiplyAddSaturate;
505
508void operator()(
509FragmentC &d,
510FragmentA const &a,
511FragmentB const &b,
512FragmentC const &c
513 ) const {
514
515 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
516
517unsigned const & A = reinterpret_cast<unsigned const &>(a);
518unsigned const & B = reinterpret_cast<unsigned const &>(b);
519
520int const *C = reinterpret_cast<int const *>(&c);
521int *D = reinterpret_cast<int *>(&d);
522
523asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
524 : "=r"(D[0]), "=r"(D[1])
525 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
526
527 #else
528 assert(0);
529 #endif
530 }
531 };
532
534 template <>
536 gemm::GemmShape<8,8,16>,
537 32,
538 int8_t,
539layout::RowMajor,
540 uint8_t,
542 int,
543layout::RowMajor,
544 OpMultiplyAddSaturate> {
545
546using Shape = gemm::GemmShape<8,8,16>;
547
549using LayoutA = layout::RowMajor;
550using FragmentA = Array<int8_t, 4>;
551
553using LayoutB = layout::ColumnMajor;
554using FragmentB = Array<uint8_t, 4>;
555
557using LayoutC = layout::RowMajor;
558using FragmentC = Array<int, 2>;
559
560using Operator = OpMultiplyAddSaturate;
561
564void operator()(
565FragmentC &d,
566FragmentA const &a,
567FragmentB const &b,
568FragmentC const &c
569 ) const {
570
571 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
572
573unsigned const & A = reinterpret_cast<unsigned const &>(a);
574unsigned const & B = reinterpret_cast<unsigned const &>(b);
575
576int const *C = reinterpret_cast<int const *>(&c);
577int *D = reinterpret_cast<int *>(&d);
578
579asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
580 : "=r"(D[0]), "=r"(D[1])
581 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
582
583 #else
584 assert(0);
585 #endif
586 }
587 };
588
590 template <>
592 gemm::GemmShape<8,8,16>,
593 32,
594 uint8_t,
595layout::RowMajor,
596 uint8_t,
598 int,
599layout::RowMajor,
600 OpMultiplyAddSaturate> {
601
602using Shape = gemm::GemmShape<8,8,16>;
603
605using LayoutA = layout::RowMajor;
606using FragmentA = Array<uint8_t, 4>;
607
609using LayoutB = layout::ColumnMajor;
610using FragmentB = Array<uint8_t, 4>;
611
613using LayoutC = layout::RowMajor;
614using FragmentC = Array<int, 2>;
615
616using Operator = OpMultiplyAddSaturate;
617
620void operator()(
621FragmentC &d,
622FragmentA const &a,
623FragmentB const &b,
624FragmentC const &c
625 ) const {
626
627 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
628
629unsigned const & A = reinterpret_cast<unsigned const &>(a);
630unsigned const & B = reinterpret_cast<unsigned const &>(b);
631
632int const *C = reinterpret_cast<int const *>(&c);
633int *D = reinterpret_cast<int *>(&d);
634
635asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
636 : "=r"(D[0]), "=r"(D[1])
637 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
638
639 #else
640 assert(0);
641 #endif
642 }
643 };
644
646 //
647 // Integer matrix multiply (4b)
648 //
650
652 template <>
654 gemm::GemmShape<8,8,32>,
655 32,
656int4b_t,
657layout::RowMajor,
658int4b_t,
660 int,
661layout::RowMajor,
662 OpMultiplyAdd> {
663
664using Shape = gemm::GemmShape<8,8,32>;
665
667using LayoutA = layout::RowMajor;
668using FragmentA = Array<int4b_t, 8>;
669
671using LayoutB = layout::ColumnMajor;
672using FragmentB = Array<int4b_t, 8>;
673
675using LayoutC = layout::RowMajor;
676using FragmentC = Array<int, 2>;
677
678using Operator = OpMultiplyAdd;
679
682void operator()(
683FragmentC &d,
684FragmentA const &a,
685FragmentB const &b,
686FragmentC const &c
687 ) const {
688
689 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
690
691unsigned const & A = reinterpret_cast<unsigned const &>(a);
692unsigned const & B = reinterpret_cast<unsigned const &>(b);
693
694int const *C = reinterpret_cast<int const *>(&c);
695int *D = reinterpret_cast<int *>(&d);
696
697asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
698 : "=r"(D[0]), "=r"(D[1])
699 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
700
701 #else
702 assert(0);
703 #endif
704 }
705 };
706
708 template <>
710 gemm::GemmShape<8,8,32>,
711 32,
712uint4b_t,
713layout::RowMajor,
714int4b_t,
716 int,
717layout::RowMajor,
718 OpMultiplyAdd> {
719
720using Shape = gemm::GemmShape<8,8,32>;
721
723using LayoutA = layout::RowMajor;
724using FragmentA = Array<uint4b_t, 8>;
725
727using LayoutB = layout::ColumnMajor;
728using FragmentB = Array<int4b_t, 8>;
729
731using LayoutC = layout::RowMajor;
732using FragmentC = Array<int, 2>;
733
734using Operator = OpMultiplyAdd;
735
738void operator()(
739FragmentC &d,
740FragmentA const &a,
741FragmentB const &b,
742FragmentC const &c
743 ) const {
744
745 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
746
747unsigned const & A = reinterpret_cast<unsigned const &>(a);
748unsigned const & B = reinterpret_cast<unsigned const &>(b);
749
750int const *C = reinterpret_cast<int const *>(&c);
751int *D = reinterpret_cast<int *>(&d);
752
753asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
754 : "=r"(D[0]), "=r"(D[1])
755 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
756
757 #else
758 assert(0);
759 #endif
760 }
761 };
762
764 template <>
766 gemm::GemmShape<8,8,32>,
767 32,
768int4b_t,
769layout::RowMajor,
770uint4b_t,
772 int,
773layout::RowMajor,
774 OpMultiplyAdd> {
775
776using Shape = gemm::GemmShape<8,8,32>;
777
779using LayoutA = layout::RowMajor;
780using FragmentA = Array<int4b_t, 8>;
781
783using LayoutB = layout::ColumnMajor;
784using FragmentB = Array<uint4b_t, 8>;
785
787using LayoutC = layout::RowMajor;
788using FragmentC = Array<int, 2>;
789
790using Operator = OpMultiplyAdd;
791
794void operator()(
795FragmentC &d,
796FragmentA const &a,
797FragmentB const &b,
798FragmentC const &c
799 ) const {
800
801 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
802
803unsigned const & A = reinterpret_cast<unsigned const &>(a);
804unsigned const & B = reinterpret_cast<unsigned const &>(b);
805
806int const *C = reinterpret_cast<int const *>(&c);
807int *D = reinterpret_cast<int *>(&d);
808
809asm volatile("_mma.m8n8k32.row.col.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
810 : "=r"(D[0]), "=r"(D[1])
811 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
812
813 #else
814 assert(0);
815 #endif
816 }
817 };
818
820 template <>
822 gemm::GemmShape<8,8,32>,
823 32,
824uint4b_t,
825layout::RowMajor,
826uint4b_t,
828 int,
829layout::RowMajor,
830 OpMultiplyAdd> {
831
832using Shape = gemm::GemmShape<8,8,32>;
833
835using LayoutA = layout::RowMajor;
836using FragmentA = Array<uint4b_t, 8>;
837
839using LayoutB = layout::ColumnMajor;
840using FragmentB = Array<uint4b_t, 8>;
841
843using LayoutC = layout::RowMajor;
844using FragmentC = Array<int, 2>;
845
846using Operator = OpMultiplyAdd;
847
850void operator()(
851FragmentC &d,
852FragmentA const &a,
853FragmentB const &b,
854FragmentC const &c
855 ) const {
856
857 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
858
859unsigned const & A = reinterpret_cast<unsigned const &>(a);
860unsigned const & B = reinterpret_cast<unsigned const &>(b);
861
862int const *C = reinterpret_cast<int const *>(&c);
863int *D = reinterpret_cast<int *>(&d);
864
865asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
866 : "=r"(D[0]), "=r"(D[1])
867 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
868
869 #else
870 assert(0);
871 #endif
872 }
873 };
874
876 //
877 // Integer matrix multiply (4b) - SATURATE
878 //
880
882 template <>
884 gemm::GemmShape<8,8,32>,
885 32,
886int4b_t,
887layout::RowMajor,
888int4b_t,
890 int,
891layout::RowMajor,
892 OpMultiplyAddSaturate> {
893
894using Shape = gemm::GemmShape<8,8,32>;
895
897using LayoutA = layout::RowMajor;
898using FragmentA = Array<int4b_t, 8>;
899
901using LayoutB = layout::ColumnMajor;
902using FragmentB = Array<int4b_t, 8>;
903
905using LayoutC = layout::RowMajor;
906using FragmentC = Array<int, 2>;
907
908using Operator = OpMultiplyAddSaturate;
909
912void operator()(
913FragmentC &d,
914FragmentA const &a,
915FragmentB const &b,
916FragmentC const &c
917 ) const {
918
919 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
920
921unsigned const & A = reinterpret_cast<unsigned const &>(a);
922unsigned const & B = reinterpret_cast<unsigned const &>(b);
923
924int const *C = reinterpret_cast<int const *>(&c);
925int *D = reinterpret_cast<int *>(&d);
926
927asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
928 : "=r"(D[0]), "=r"(D[1])
929 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
930
931 #else
932 assert(0);
933 #endif
934 }
935 };
936
938 template <>
940 gemm::GemmShape<8,8,32>,
941 32,
942uint4b_t,
943layout::RowMajor,
944int4b_t,
946 int,
947layout::RowMajor,
948 OpMultiplyAddSaturate> {
949
950using Shape = gemm::GemmShape<8,8,32>;
951
953using LayoutA = layout::RowMajor;
954using FragmentA = Array<uint4b_t, 8>;
955
957using LayoutB = layout::ColumnMajor;
958using FragmentB = Array<int4b_t, 8>;
959
961using LayoutC = layout::RowMajor;
962using FragmentC = Array<int, 2>;
963
964using Operator = OpMultiplyAddSaturate;
965
968void operator()(
969FragmentC &d,
970FragmentA const &a,
971FragmentB const &b,
972FragmentC const &c
973 ) const {
974
975 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
976
977unsigned const & A = reinterpret_cast<unsigned const &>(a);
978unsigned const & B = reinterpret_cast<unsigned const &>(b);
979
980int const *C = reinterpret_cast<int const *>(&c);
981int *D = reinterpret_cast<int *>(&d);
982
983asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
984 : "=r"(D[0]), "=r"(D[1])
985 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
986
987 #else
988 assert(0);
989 #endif
990 }
991 };
992
994 template <>
996 gemm::GemmShape<8,8,32>,
997 32,
998int4b_t,
999layout::RowMajor,
1000uint4b_t,
1001layout::ColumnMajor,
1002 int,
1003layout::RowMajor,
1004 OpMultiplyAddSaturate> {
1005
1006using Shape = gemm::GemmShape<8,8,32>;
1007
1009using LayoutA = layout::RowMajor;
1010using FragmentA = Array<int4b_t, 8>;
1011
1012using ElementB = uint4b_t;
1013using LayoutB = layout::ColumnMajor;
1014using FragmentB = Array<uint4b_t, 8>;
1015
1017using LayoutC = layout::RowMajor;
1018using FragmentC = Array<int, 2>;
1019
1020using Operator = OpMultiplyAddSaturate;
1021
1024void operator()(
1025FragmentC &d,
1026FragmentA const &a,
1027FragmentB const &b,
1028FragmentC const &c
1029 ) const {
1030
1031 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
1032
1033unsigned const & A = reinterpret_cast<unsigned const &>(a);
1034unsigned const & B = reinterpret_cast<unsigned const &>(b);
1035
1036int const *C = reinterpret_cast<int const *>(&c);
1037int *D = reinterpret_cast<int *>(&d);
1038
1039asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
1040 : "=r"(D[0]), "=r"(D[1])
1041 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
1042
1043 #else
1044 assert(0);
1045 #endif
1046 }
1047 };
1048
1050 template <>
1052 gemm::GemmShape<8,8,32>,
1053 32,
1054uint4b_t,
1055layout::RowMajor,
1056uint4b_t,
1057layout::ColumnMajor,
1058 int,
1059layout::RowMajor,
1060 OpMultiplyAddSaturate> {
1061
1062using Shape = gemm::GemmShape<8,8,32>;
1063
1064using ElementA = uint4b_t;
1065using LayoutA = layout::RowMajor;
1066using FragmentA = Array<uint4b_t, 8>;
1067
1068using ElementB = uint4b_t;
1069using LayoutB = layout::ColumnMajor;
1070using FragmentB = Array<uint4b_t, 8>;
1071
1073using LayoutC = layout::RowMajor;
1074using FragmentC = Array<int, 2>;
1075
1076using Operator = OpMultiplyAddSaturate;
1077
1080void operator()(
1081FragmentC &d,
1082FragmentA const &a,
1083FragmentB const &b,
1084FragmentC const &c
1085 ) const {
1086
1087 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
1088
1089unsigned const & A = reinterpret_cast<unsigned const &>(a);
1090unsigned const & B = reinterpret_cast<unsigned const &>(b);
1091
1092int const *C = reinterpret_cast<int const *>(&c);
1093int *D = reinterpret_cast<int *>(&d);
1094
1095asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
1096 : "=r"(D[0]), "=r"(D[1])
1097 : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));
1098
1099 #else
1100 assert(0);
1101 #endif
1102 }
1103 };
1104
1106 //
1107 // b1 ^ b1 + s32 => s32
1108 //
1110
1112 template <>
1114 gemm::GemmShape<8,8,128>,
1115 32,
1116uint1b_t,
1117layout::RowMajor,
1118uint1b_t,
1119layout::ColumnMajor,
1120 int,
1121layout::RowMajor,
1122 OpXorPopc> {
1123
1124using Shape = gemm::GemmShape<8,8,128>;
1125
1126using ElementA = uint1b_t;
1127using LayoutA = layout::RowMajor;
1128using FragmentA = Array<uint1b_t, 32>;
1129
1130using ElementB = uint1b_t;
1131using LayoutB = layout::ColumnMajor;
1132using FragmentB = Array<uint1b_t, 32>;
1133
1135using LayoutC = layout::RowMajor;
1136using FragmentC = Array<int, 2>;
1137
1138using Operator = OpXorPopc;
1139
1142void operator()(
1143FragmentC &d,
1144FragmentA const &a,
1145FragmentB const &b,
1146FragmentC const &c
1147 ) const {
1148
1149 #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
1150
1151 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
1152using WmmaFragmentA = nvcuda::wmma::fragment<
1153 nvcuda::wmma::matrix_a,
1154 Shape::kM,
1155 Shape::kN,
1156 Shape::kK,
1157 nvcuda::wmma::experimental::precision::b1,
1158 nvcuda::wmma::row_major>;
1159
1160using WmmaFragmentB = nvcuda::wmma::fragment<
1161 nvcuda::wmma::matrix_b,
1162 Shape::kM,
1163 Shape::kN,
1164 Shape::kK,
1165 nvcuda::wmma::experimental::precision::b1,
1166 nvcuda::wmma::col_major>;
1167
1168using WmmaFragmentC = nvcuda::wmma::fragment<
1169 nvcuda::wmma::accumulator,
1170 Shape::kM,
1171 Shape::kN,
1172 Shape::kK,
1173int>;
1174
1175 WmmaFragmentA const & A = reinterpret_cast<WmmaFragmentA const &>(a);
1176 WmmaFragmentB const & B = reinterpret_cast<WmmaFragmentB const &>(b);
1177
1178 WmmaFragmentC const & C = reinterpret_cast<WmmaFragmentC const &>(c);
1179 WmmaFragmentC & D = reinterpret_cast<WmmaFragmentC &>(d);
1180
1181 nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
1182 nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
1183 #else
1184
1185 assert(0); // WMMA must be supported to issue binary matrix multiply-accumulate instructions.
1186
1187 #endif // defined(CUTLASS_ARCH_WMMA_ENABLED)
1188
1189 #else
1190 assert(0);
1191 #endif
1192
1193 }
1194 };
1195
1197
1198 } // namespace arch
1199 } // namespace cutlass
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:794
OpMultiplyAdd Operator
Definition: mma_sm75.h:217
uint8_t ElementA
Definition: mma_sm75.h:492
integer_subbyte< 4, false > uint4b_t
4-bit Unsigned integer type
Definition: integer_subbyte.h:158
OpMultiplyAdd Operator
Definition: mma_sm75.h:734
Array< uint4b_t, 8 > FragmentB
Definition: mma_sm75.h:1070
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
Array< int8_t, 4 > FragmentB
Definition: mma_sm75.h:211
Array< uint8_t, 4 > FragmentB
Definition: mma_sm75.h:610
Definition: aligned_buffer.h:35
int ElementC
Definition: mma_sm75.h:269
OpMultiplyAdd Operator
Definition: mma_sm75.h:846
OpMultiplyAddSaturate Operator
Definition: mma_sm75.h:616
int ElementC
Definition: mma_sm75.h:500
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:277
OpMultiplyAdd Operator
Definition: mma_sm75.h:329
Array< int, 2 > FragmentC
Definition: mma_sm75.h:676
Array< uint8_t, 4 > FragmentB
Definition: mma_sm75.h:323
Array< uint8_t, 4 > FragmentB
Definition: mma_sm75.h:554
int ElementC
Definition: mma_sm75.h:382
integer_subbyte< 1, false > uint1b_t
1-bit Unsigned integer type
Definition: integer_subbyte.h:152
int ElementC
Definition: mma_sm75.h:1016
Array< int, 2 > FragmentC
Definition: mma_sm75.h:446
uint8_t ElementB
Definition: mma_sm75.h:378
int8_t ElementA
Definition: mma_sm75.h:205
int8_t ElementA
Definition: mma_sm75.h:548
Array< int8_t, 4 > FragmentB
Definition: mma_sm75.h:267
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:968
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:1024
Array< int8_t, 4 > FragmentA
Definition: mma_sm75.h:207
Array< half_t, 2 > FragmentB
Definition: mma_sm75.h:150
Array< int8_t, 4 > FragmentA
Definition: mma_sm75.h:319
4-bit signed integer type
Definition: integer_subbyte.h:42
IEEE half-precision floating-point type.
Definition: half.h:126
int ElementC
Definition: mma_sm75.h:730
int ElementC
Definition: mma_sm75.h:444
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:564
int ElementC
Definition: mma_sm75.h:904
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:1142
Array< int4b_t, 8 > FragmentA
Definition: mma_sm75.h:780
int ElementC
Definition: mma_sm75.h:1134
int ElementC
Definition: mma_sm75.h:842
OpMultiplyAddSaturate Operator
Definition: mma_sm75.h:964
Array< uint8_t, 4 > FragmentB
Definition: mma_sm75.h:380
OpMultiplyAddSaturate Operator
Definition: mma_sm75.h:560
int8_t ElementB
Definition: mma_sm75.h:496
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:912
uint8_t ElementB
Definition: mma_sm75.h:608
Array< int, 2 > FragmentC
Definition: mma_sm75.h:1018
Array< int4b_t, 8 > FragmentB
Definition: mma_sm75.h:958
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
int ElementC
Definition: mma_sm75.h:612
int8_t ElementB
Definition: mma_sm75.h:209
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:333
uint8_t ElementA
Definition: mma_sm75.h:261
Array< int8_t, 4 > FragmentB
Definition: mma_sm75.h:498
float ElementC
Definition: mma_sm75.h:152
int8_t ElementB
Definition: mma_sm75.h:440
Templates exposing architecture support for multiply-add operations.
Array< int8_t, 4 > FragmentA
Definition: mma_sm75.h:550
uint8_t ElementB
Definition: mma_sm75.h:321
uint8_t ElementA
Definition: mma_sm75.h:374
Array< int, 2 > FragmentC
Definition: mma_sm75.h:558
Array< uint4b_t, 8 > FragmentB
Definition: mma_sm75.h:1014
Array< int4b_t, 8 > FragmentA
Definition: mma_sm75.h:1010
OpMultiplyAdd Operator
Definition: mma_sm75.h:156
OpMultiplyAdd Operator
Definition: mma_sm75.h:95
Array< uint4b_t, 8 > FragmentB
Definition: mma_sm75.h:840
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:508
int ElementC
Definition: mma_sm75.h:556
int ElementC
Definition: mma_sm75.h:674
OpMultiplyAdd Operator
Definition: mma_sm75.h:273
Array< int, 2 > FragmentC
Definition: mma_sm75.h:327
Array< int, 2 > FragmentC
Definition: mma_sm75.h:384
OpXorPopc Operator
Definition: mma_sm75.h:1138
Array< float, 4 > FragmentC
Definition: mma_sm75.h:154
Array< uint8_t, 4 > FragmentA
Definition: mma_sm75.h:606
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:682
int8_t ElementA
Definition: mma_sm75.h:436
Array< int, 2 > FragmentC
Definition: mma_sm75.h:732
Array< half_t, 4 > FragmentC
Definition: mma_sm75.h:93
Array< int, 2 > FragmentC
Definition: mma_sm75.h:215
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:221
uint8_t ElementB
Definition: mma_sm75.h:552
Array< int8_t, 4 > FragmentA
Definition: mma_sm75.h:438
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
OpMultiplyAddSaturate Operator
Definition: mma_sm75.h:1020
Top-level include for all CUTLASS numeric types.
OpMultiplyAddSaturate Operator
Definition: mma_sm75.h:1076
Array< half_t, 4 > FragmentA
Definition: mma_sm75.h:146
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:1080
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Array< int4b_t, 8 > FragmentB
Definition: mma_sm75.h:728
int8_t ElementB
Definition: mma_sm75.h:265
Array< int, 2 > FragmentC
Definition: mma_sm75.h:502
int8_t ElementA
Definition: mma_sm75.h:317
Array< uint4b_t, 8 > FragmentA
Definition: mma_sm75.h:954
Array< int, 2 > FragmentC
Definition: mma_sm75.h:1136
Array< int4b_t, 8 > FragmentA
Definition: mma_sm75.h:668
int ElementC
Definition: mma_sm75.h:1072
OpMultiplyAddSaturate Operator
Definition: mma_sm75.h:908
Array< uint4b_t, 8 > FragmentA
Definition: mma_sm75.h:724
Array< uint1b_t, 32 > FragmentA
Definition: mma_sm75.h:1128
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Array< uint4b_t, 8 > FragmentA
Definition: mma_sm75.h:1066
Array< int, 2 > FragmentC
Definition: mma_sm75.h:962
Array< uint8_t, 4 > FragmentA
Definition: mma_sm75.h:376
Array< uint4b_t, 8 > FragmentA
Definition: mma_sm75.h:836
Array< int, 2 > FragmentC
Definition: mma_sm75.h:844
Array< int, 2 > FragmentC
Definition: mma_sm75.h:906
OpMultiplyAdd Operator
Definition: mma_sm75.h:678
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:620
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:738
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:160
uint8_t ElementA
Definition: mma_sm75.h:604
int ElementC
Definition: mma_sm75.h:325
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:850
Defines layout functions used by TensorRef and derived classes.
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:452
OpMultiplyAddSaturate Operator
Definition: mma_sm75.h:448
Array< int8_t, 4 > FragmentB
Definition: mma_sm75.h:442
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Computes multiply-add.
Definition: mma_sm75.h:390
Array< uint8_t, 4 > FragmentA
Definition: mma_sm75.h:263
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const
Definition: mma_sm75.h:98
Matrix multiply-add operation.
Definition: arch/mma.h:92
Array< int4b_t, 8 > FragmentA
Definition: mma_sm75.h:898
Array< int4b_t, 8 > FragmentB
Definition: mma_sm75.h:902
Array< uint8_t, 4 > FragmentA
Definition: mma_sm75.h:494
OpMultiplyAdd Operator
Definition: mma_sm75.h:386
Array< uint1b_t, 32 > FragmentB
Definition: mma_sm75.h:1132
Templates exposing architecture support for warp matrix multiply-add (WMMA) operations.
Array< half_t, 2 > FragmentB
Definition: mma_sm75.h:89
Array< half_t, 4 > FragmentA
Definition: mma_sm75.h:85
Array< int, 2 > FragmentC
Definition: mma_sm75.h:1074
Array< int, 2 > FragmentC
Definition: mma_sm75.h:614
Array< uint4b_t, 8 > FragmentB
Definition: mma_sm75.h:784
integer_subbyte< 4, true > int4b_t
4-bit Integer type
Definition: integer_subbyte.h:155
OpMultiplyAddSaturate Operator
Definition: mma_sm75.h:504
Array< int4b_t, 8 > FragmentB
Definition: mma_sm75.h:672
Array< int, 2 > FragmentC
Definition: mma_sm75.h:271
int ElementC
Definition: mma_sm75.h:786
Array< int, 2 > FragmentC
Definition: mma_sm75.h:788
OpMultiplyAdd Operator
Definition: mma_sm75.h:790
int ElementC
Definition: mma_sm75.h:213
int ElementC
Definition: mma_sm75.h:960
<!-- fragment --> <!-- contents --><!-- start footer part -->