docs/mma__sm70_8h_source.html
CUTLASS — CUDA Templates for Linear Algebra Subroutines and Solvers
mma_sm70.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
28 #pragma once
29
30 #include <assert.h>
31
32 #include "mma.h"
33 #include "cutlass/layout/matrix.h"
34 #include "cutlass/numeric_types.h"
35
36 #if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))
37 #define CUTLASS_ARCH_MMA_SM70_SUPPORTED
38 #endif
39
40 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))
41
42 #if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 &&__CUDACC_VER_MINOR__ >= 1))
43 #define CUTLASS_ARCH_MMA_SM70_ENABLED
44 #endif
45
46 #endif
47
49
50 namespace cutlass {
51 namespace arch {
52
54 //
55 // Matrix multiply accumulate 884 - FP16 accumulation
56 //
58
60 template <>
62 gemm::GemmShape<8,8,4>,
63 8,
64half_t,
66half_t,
68half_t,
70 OpMultiplyAdd> {
71
72using Shape = gemm::GemmShape<8, 8, 4>;
73
75using LayoutA = layout::ColumnMajor;
76using FragmentA = Array<half_t, 4>;
77
79using LayoutB = layout::ColumnMajor;
80using FragmentB = Array<half_t, 4>;
81
83using LayoutC = layout::RowMajor;
84using FragmentC = Array<half_t, 8>;
85
86using Operator = OpMultiplyAdd;
87
89void operator()(
90FragmentC &d,
91FragmentA const &a,
92FragmentB const &b,
93FragmentC const &c
94 ) {
95
96 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
97
98unsigned const *A = reinterpret_cast<unsigned const *>(&a);
99unsigned const *B = reinterpret_cast<unsigned const *>(&b);
100unsigned const *C = reinterpret_cast<unsigned const *>(&c);
101unsigned *D = reinterpret_cast<unsigned *>(&d);
102
103asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
104 : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
105 : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
106 );
107
108 #else
109 assert(0);
110 #endif
111 }
112 };
113
115 template <>
117 gemm::GemmShape<8, 8, 4>,
118 8,
119half_t,
121half_t,
122layout::RowMajor,
123half_t,
124layout::RowMajor,
125 OpMultiplyAdd> {
126
127using Shape = gemm::GemmShape<8, 8, 4>;
128
130using LayoutA = layout::ColumnMajor;
131using FragmentA = Array<half_t, 4>;
132
134using LayoutB = layout::RowMajor;
135using FragmentB = Array<half_t, 4>;
136
138using LayoutC = layout::RowMajor;
139using FragmentC = Array<half_t, 8>;
140
141using Operator = OpMultiplyAdd;
142
144void operator()(
145FragmentC &d,
146FragmentA const &a,
147FragmentB const &b,
148FragmentC const &c
149 ) {
150
151 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
152
153unsigned const *A = reinterpret_cast<unsigned const *>(&a);
154unsigned const *B = reinterpret_cast<unsigned const *>(&b);
155unsigned const *C = reinterpret_cast<unsigned const *>(&c);
156unsigned *D = reinterpret_cast<unsigned *>(&d);
157
158asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
159 : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
160 : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
161 );
162
163 #else
164 assert(0);
165 #endif
166 }
167 };
168
170 template <>
172 gemm::GemmShape<8, 8, 4>,
173 8,
174half_t,
175layout::RowMajor,
176half_t,
178half_t,
179layout::RowMajor,
180 OpMultiplyAdd> {
181
182using Shape = gemm::GemmShape<8, 8, 4>;
183
185using LayoutA = layout::RowMajor;
186using FragmentA = Array<half_t, 4>;
187
189using LayoutB = layout::ColumnMajor;
190using FragmentB = Array<half_t, 4>;
191
193using LayoutC = layout::RowMajor;
194using FragmentC = Array<half_t, 8>;
195
196using Operator = OpMultiplyAdd;
197
199void operator()(
200FragmentC &d,
201FragmentA const &a,
202FragmentB const &b,
203FragmentC const &c
204 ) {
205
206 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
207
208unsigned const *A = reinterpret_cast<unsigned const *>(&a);
209unsigned const *B = reinterpret_cast<unsigned const *>(&b);
210unsigned const *C = reinterpret_cast<unsigned const *>(&c);
211unsigned *D = reinterpret_cast<unsigned *>(&d);
212
213asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
214 : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
215 : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
216 );
217
218 #else
219 assert(0);
220 #endif
221 }
222 };
223
225 template <>
227 gemm::GemmShape<8, 8, 4>,
228 8,
229half_t,
230layout::RowMajor,
231half_t,
232layout::RowMajor,
233half_t,
234layout::RowMajor,
235 OpMultiplyAdd> {
236
237using Shape = gemm::GemmShape<8, 8, 4>;
238
240using LayoutA = layout::RowMajor;
241using FragmentA = Array<half_t, 4>;
242
244using LayoutB = layout::RowMajor;
245using FragmentB = Array<half_t, 4>;
246
248using LayoutC = layout::RowMajor;
249using FragmentC = Array<half_t, 8>;
250
251using Operator = OpMultiplyAdd;
252
254void operator()(
255FragmentC &d,
256FragmentA const &a,
257FragmentB const &b,
258FragmentC const &c
259 ) {
260
261 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
262
263unsigned const *A = reinterpret_cast<unsigned const *>(&a);
264unsigned const *B = reinterpret_cast<unsigned const *>(&b);
265unsigned const *C = reinterpret_cast<unsigned const *>(&c);
266unsigned *D = reinterpret_cast<unsigned *>(&d);
267
268asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
269 : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
270 : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
271 );
272
273 #else
274 assert(0);
275 #endif
276 }
277 };
278
280 //
281 // Matrix multiply accumulate 884 - FP32 accumulation
282 //
284
286 template <>
288 gemm::GemmShape<8, 8, 4>,
289 8,
290half_t,
292half_t,
294 float,
295layout::RowMajor,
296 OpMultiplyAdd> {
297
298using Shape = gemm::GemmShape<8, 8, 4>;
299
301using LayoutA = layout::ColumnMajor;
302using FragmentA = Array<half_t, 4>;
303
305using LayoutB = layout::ColumnMajor;
306using FragmentB = Array<half_t, 4>;
307
309using LayoutC = layout::RowMajor;
310using FragmentC = Array<float, 8>;
311
312using Operator = OpMultiplyAdd;
313
316void operator()(
317FragmentC &d,
318FragmentA const &a,
319FragmentB const &b,
320FragmentC const &c
321 ) {
322
323 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
324
325unsigned const *A = reinterpret_cast<unsigned const *>(&a);
326unsigned const *B = reinterpret_cast<unsigned const *>(&b);
327float const *C = reinterpret_cast<float const *>(&c);
328float *D = reinterpret_cast<float *>(&d);
329
330asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
331"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
332 : "=f"(D[0]),
333"=f"(D[1]),
334"=f"(D[2]),
335"=f"(D[3]),
336"=f"(D[4]),
337"=f"(D[5]),
338"=f"(D[6]),
339"=f"(D[7])
340 : "r"(A[0]),
341"r"(A[1]),
342"r"(B[0]),
343"r"(B[1]),
344"f"(C[0]),
345"f"(C[1]),
346"f"(C[2]),
347"f"(C[3]),
348"f"(C[4]),
349"f"(C[5]),
350"f"(C[6]),
351"f"(C[7])
352 );
353
354 #else
355 assert(0);
356 #endif
357 }
358 };
359
361 template <>
363 gemm::GemmShape<8, 8, 4>,
364 8,
365half_t,
367half_t,
368layout::RowMajor,
369 float,
370layout::RowMajor,
371 OpMultiplyAdd> {
372
373using Shape = gemm::GemmShape<8, 8, 4>;
374
376using LayoutA = layout::ColumnMajor;
377using FragmentA = Array<half_t, 4>;
378
380using LayoutB = layout::RowMajor;
381using FragmentB = Array<half_t, 4>;
382
384using LayoutC = layout::RowMajor;
385using FragmentC = Array<float, 8>;
386
387using Operator = OpMultiplyAdd;
388
391void operator()(
392FragmentC &d,
393FragmentA const &a,
394FragmentB const &b,
395FragmentC const &c
396 ) {
397
398 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
399
400unsigned const *A = reinterpret_cast<unsigned const *>(&a);
401unsigned const *B = reinterpret_cast<unsigned const *>(&b);
402float const *C = reinterpret_cast<float const *>(&c);
403float *D = reinterpret_cast<float *>(&d);
404
405asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
406"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
407 : "=f"(D[0]),
408"=f"(D[1]),
409"=f"(D[2]),
410"=f"(D[3]),
411"=f"(D[4]),
412"=f"(D[5]),
413"=f"(D[6]),
414"=f"(D[7])
415 : "r"(A[0]),
416"r"(A[1]),
417"r"(B[0]),
418"r"(B[1]),
419"f"(C[0]),
420"f"(C[1]),
421"f"(C[2]),
422"f"(C[3]),
423"f"(C[4]),
424"f"(C[5]),
425"f"(C[6]),
426"f"(C[7])
427 );
428
429 #else
430 assert(0);
431 #endif
432 }
433 };
434
436 template <>
438 gemm::GemmShape<8, 8, 4>,
439 8,
440half_t,
441layout::RowMajor,
442half_t,
444 float,
445layout::RowMajor,
446 OpMultiplyAdd> {
447
448using Shape = gemm::GemmShape<8, 8, 4>;
449
451using LayoutA = layout::RowMajor;
452using FragmentA = Array<half_t, 4>;
453
455using LayoutB = layout::ColumnMajor;
456using FragmentB = Array<half_t, 4>;
457
459using LayoutC = layout::RowMajor;
460using FragmentC = Array<float, 8>;
461
462using Operator = OpMultiplyAdd;
463
466void operator()(
467FragmentC &d,
468FragmentA const &a,
469FragmentB const &b,
470FragmentC const &c
471 ) {
472
473 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
474
475unsigned const *A = reinterpret_cast<unsigned const *>(&a);
476unsigned const *B = reinterpret_cast<unsigned const *>(&b);
477float const *C = reinterpret_cast<float const *>(&c);
478float *D = reinterpret_cast<float *>(&d);
479
480asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
481"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
482 : "=f"(D[0]),
483"=f"(D[1]),
484"=f"(D[2]),
485"=f"(D[3]),
486"=f"(D[4]),
487"=f"(D[5]),
488"=f"(D[6]),
489"=f"(D[7])
490 : "r"(A[0]),
491"r"(A[1]),
492"r"(B[0]),
493"r"(B[1]),
494"f"(C[0]),
495"f"(C[1]),
496"f"(C[2]),
497"f"(C[3]),
498"f"(C[4]),
499"f"(C[5]),
500"f"(C[6]),
501"f"(C[7])
502 );
503
504 #else
505 assert(0);
506 #endif
507 }
508 };
509
511 template <>
513 gemm::GemmShape<8, 8, 4>,
514 8,
515half_t,
516layout::RowMajor,
517half_t,
518layout::RowMajor,
519 float,
520layout::RowMajor,
521 OpMultiplyAdd> {
522
523using Shape = gemm::GemmShape<8, 8, 4>;
524
526using LayoutA = layout::RowMajor;
527using FragmentA = Array<half_t, 4>;
528
530using LayoutB = layout::RowMajor;
531using FragmentB = Array<half_t, 4>;
532
534using LayoutC = layout::RowMajor;
535using FragmentC = Array<float, 8>;
536
537using Operator = OpMultiplyAdd;
538
541void operator()(
542FragmentC &d,
543FragmentA const &a,
544FragmentB const &b,
545FragmentC const &c
546 ) {
547
548 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)
549
550unsigned const *A = reinterpret_cast<unsigned const *>(&a);
551unsigned const *B = reinterpret_cast<unsigned const *>(&b);
552float const *C = reinterpret_cast<float const *>(&c);
553float *D = reinterpret_cast<float *>(&d);
554
555asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
556"{%12,%13,%14,%15,%16,%17,%18,%19};\n"
557 : "=f"(D[0]),
558"=f"(D[1]),
559"=f"(D[2]),
560"=f"(D[3]),
561"=f"(D[4]),
562"=f"(D[5]),
563"=f"(D[6]),
564"=f"(D[7])
565 : "r"(A[0]),
566"r"(A[1]),
567"r"(B[0]),
568"r"(B[1]),
569"f"(C[0]),
570"f"(C[1]),
571"f"(C[2]),
572"f"(C[3]),
573"f"(C[4]),
574"f"(C[5]),
575"f"(C[6]),
576"f"(C[7])
577 );
578
579 #else
580 assert(0);
581 #endif
582 }
583 };
584
586
588 template <
589typename LayoutA,
590typename LayoutB,
591typename ElementC,
592typename LayoutC,
593typename Operator
594 >
596 gemm::GemmShape<16, 16, 4>,
597 32,
598half_t,
599 LayoutA,
600half_t,
601 LayoutB,
602 ElementC,
603 LayoutC,
604 Operator
605 > :
606public Mma<
607 gemm::GemmShape<8, 8, 4>,
608 8,
609 half_t,
610 LayoutA,
611 half_t,
612 LayoutB,
613 ElementC,
614 LayoutC,
615 Operator> {
616
617using Shape = gemm::GemmShape<16, 16, 4>;
618 };
619
621
622 } // namespace arch
623 } // namespace cutlass
Array< half_t, 8 > FragmentC
Definition: mma_sm70.h:84
Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:245
Definition: aligned_buffer.h:35
Array< float, 8 > FragmentC
Definition: mma_sm70.h:535
Array< half_t, 8 > FragmentC
Definition: mma_sm70.h:194
float ElementC
Definition: mma_sm70.h:308
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Definition: mma_sm70.h:199
OpMultiplyAdd Operator
Definition: mma_sm70.h:86
Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:131
OpMultiplyAdd Operator
Definition: mma_sm70.h:312
IEEE half-precision floating-point type.
Definition: half.h:126
Array< float, 8 > FragmentC
Definition: mma_sm70.h:310
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Multiply-add.
Definition: mma_sm70.h:391
Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:80
Array< float, 8 > FragmentC
Definition: mma_sm70.h:385
float ElementC
Definition: mma_sm70.h:458
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Definition: mma_sm70.h:89
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
float ElementC
Definition: mma_sm70.h:383
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Multiply-add.
Definition: mma_sm70.h:316
Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:241
Templates exposing architecture support for multiply-add operations.
Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:531
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Definition: mma_sm70.h:254
OpMultiplyAdd Operator
Definition: mma_sm70.h:251
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Definition: mma_sm70.h:144
Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:377
Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:306
float ElementC
Definition: mma_sm70.h:533
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:76
Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:186
Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:302
OpMultiplyAdd Operator
Definition: mma_sm70.h:196
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Array< half_t, 8 > FragmentC
Definition: mma_sm70.h:249
Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:527
Array< float, 8 > FragmentC
Definition: mma_sm70.h:460
OpMultiplyAdd Operator
Definition: mma_sm70.h:141
Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:190
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Multiply-add.
Definition: mma_sm70.h:541
Array< half_t, 8 > FragmentC
Definition: mma_sm70.h:139
Defines layout functions used by TensorRef and derived classes.
Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:135
OpMultiplyAdd Operator
Definition: mma_sm70.h:387
Matrix multiply-add operation.
Definition: arch/mma.h:92
CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Multiply-add.
Definition: mma_sm70.h:466
OpMultiplyAdd Operator
Definition: mma_sm70.h:537
OpMultiplyAdd Operator
Definition: mma_sm70.h:462
Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:452
Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:456
Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:381
Generated by 1.8.11