docs/arch_2mma__sm60_8h_source.html
CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
arch/mma_sm60.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include <cuda_fp16.h>
32
33 #include "cutlass/arch/mma.h"
34
35 #include "cutlass/layout/matrix.h"
36
38
39 namespace cutlass {
40 namespace arch {
41
43
/// Mma specialization for GemmShape<2,1,1> on half_t operands (OpMultiplyAdd): d[i] = a[i] * b[0] + c[i], i = 0, 1.
45 template <typename LayoutA, typename LayoutB, typename LayoutC>
47 gemm::GemmShape<2,1,1>,
48 1,
49half_t,
50 LayoutA,
51half_t,
52 LayoutB,
53half_t,
54 LayoutC,
55 OpMultiplyAdd> {
56
57using Shape = gemm::GemmShape<2, 1, 1>;
58
60void operator()(
61 Array<half_t, 2> &d,
62 Array<half_t, 2> const &a,
63 Array<half_t, 1> const &b,
64 Array<half_t, 2> const &c
65 ) {
66
67 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))  // device code for SM60+: a single packed __hfma2
68
69 __half2 const & A = reinterpret_cast<__half2 const &>(a);  // NOTE(review): treating Array<half_t, 2> as __half2 assumes a, c, d are 4-byte aligned — confirm
70 __half2 B = __half2half2(reinterpret_cast<__half const &>(b));  // b[0] broadcast into both lanes
71 __half2 const & C = reinterpret_cast<__half2 const &>(c);
72
73 __half2 D = __hfma2(A, B, C);  // lane-wise: D = A * B + C
74
75 d = reinterpret_cast<Array<half_t, 2> &>(D);
76
77 #else  // host / pre-SM60 path: scalar loop computing the same result
79for (int i = 0; i < 2; ++i) {
80 d[i] = a[i] * b[0] + c[i];
81 }
82 #endif
83 }
84 };
85
87
/// Mma specialization for GemmShape<1,2,1> on half_t operands (OpMultiplyAdd): d[i] = a[0] * b[i] + c[i], i = 0, 1.
89 template <typename LayoutA, typename LayoutB>
91 gemm::GemmShape<1,2,1>,
92 1,
93half_t,
94 LayoutA,
95half_t,
96 LayoutB,
97half_t,
99 OpMultiplyAdd> {
100
101using Shape = gemm::GemmShape<1, 2, 1>;
102
104void operator()(
105 Array<half_t, 2> &d,
106 Array<half_t, 1> const &a,
107 Array<half_t, 2> const &b,
108 Array<half_t, 2> const &c
109 ) {
110
111 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))  // device code for SM60+: a single packed __hfma2
112
113 __half2 const & A = __half2half2(reinterpret_cast<__half const &>(a));  // a[0] broadcast into both lanes; the const& binds to (and extends the lifetime of) the returned temporary
114 __half2 B = reinterpret_cast<__half2 const &>(b);  // NOTE(review): treating Array<half_t, 2> as __half2 assumes 4-byte alignment — confirm
115 __half2 const & C = reinterpret_cast<__half2 const &>(c);
116
117 __half2 D = __hfma2(A, B, C);  // lane-wise: D = A * B + C
118
119 d = reinterpret_cast<Array<half_t, 2> &>(D);
120
121 #else  // host / pre-SM60 path: scalar loop computing the same result
123for (int i = 0; i < 2; ++i) {
124 d[i] = a[0] * b[i] + c[i];
125 }
126 #endif
127 }
128 };
129
131
/// Mma specialization for a 2x2x1 half_t outer product with row-major B (OpMultiplyAdd). Both paths address
/// c and d column-major: d[i + 2 * j] = a[i] * b[j] + c[i + 2 * j].
133 template <>
135 gemm::GemmShape<2, 2, 1>,
136 1,
137half_t,
139half_t,
140layout::RowMajor,
141half_t,
143 OpMultiplyAdd> {
144
145using Shape = gemm::GemmShape<2, 2, 1>;
146
148void operator()(
149 Array<half_t, 4> &d,
150 Array<half_t, 2> const &a,
151 Array<half_t, 2> const &b,
152 Array<half_t, 4> const &c
153 ) {
154
155 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))  // device code for SM60+: two packed __hfma2, one per column
156
157 __half2 const & A = reinterpret_cast<__half2 const &>(a);  // (a[0], a[1])
158 __half2 Blo = __low2half2(reinterpret_cast<__half2 const &>(b));  // (b[0], b[0])
159 __half2 Bhi = __high2half2(reinterpret_cast<__half2 const &>(b));  // (b[1], b[1])
160
161 __half2 const *C = reinterpret_cast<__half2 const *>(&c);  // C[0] = (c[0], c[1]) = column 0; C[1] = (c[2], c[3]) = column 1
162
163 __half2 Dlo = __hfma2(A, Blo, C[0]);  // column 0: a * b[0] + C[0]
164 __half2 Dhi = __hfma2(A, Bhi, C[1]);  // column 1: a * b[1] + C[1]
165
166 Array<half_t, 2> * D = reinterpret_cast<Array<half_t, 2> *>(&d);  // D[0] = d[0..1], D[1] = d[2..3]
167
168 D[0] = reinterpret_cast<Array<half_t, 2> const &>(Dlo);
169 D[1] = reinterpret_cast<Array<half_t, 2> const &>(Dhi);
170
171 #else  // host / pre-SM60 path: scalar loops with column-major indexing
173for (int j = 0; j < 2; ++j) {
175for (int i = 0; i < 2; ++i) {
176 d[i + 2 * j] = a[i] * b[j] + c[i + 2 * j];
177 }
178 }
179 #endif
180 }
181 };
182
184
/// Mma specialization for a 2x2x1 half_t outer product with row-major B and row-major C/D (OpMultiplyAdd):
/// d[i * 2 + j] = a[i] * b[j] + c[i * 2 + j], computed identically by the packed SM60+ path and the scalar fallback.
186 template <>
188 gemm::GemmShape<2, 2, 1>,
189 1,
190half_t,
192half_t,
193layout::RowMajor,
194half_t,
195layout::RowMajor,
196 OpMultiplyAdd> {
197
198using Shape = gemm::GemmShape<2, 2, 1>;
199
201void operator()(
202 Array<half_t, 4> &d,
203 Array<half_t, 2> const &a,
204 Array<half_t, 2> const &b,
205 Array<half_t, 4> const &c
206 ) {
207
208 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))  // device code for SM60+: two packed __hfma2, one per row
209
210 __half2 Alo = __low2half2(reinterpret_cast<__half2 const &>(a));  // (a[0], a[0])
211 __half2 Ahi = __high2half2(reinterpret_cast<__half2 const &>(a));  // (a[1], a[1])
212 __half2 const & B = reinterpret_cast<__half2 const &>(b);  // (b[0], b[1])
213
214 __half2 const *C = reinterpret_cast<__half2 const *>(&c);  // C[0] = row 0 (c[0], c[1]); C[1] = row 1 (c[2], c[3])
215
216 __half2 Dlo = __hfma2(Alo, B, C[0]);  // row 0: a[0] * (b[0], b[1]) + (c[0], c[1])
217 __half2 Dhi = __hfma2(Ahi, B, C[1]);  // row 1: must accumulate onto C[1]; using C[0] added row 0 of c into row 1 of d
218
219 Array<half_t, 2> * D = reinterpret_cast<Array<half_t, 2> *>(&d);  // D[0] = d[0..1] (row 0), D[1] = d[2..3] (row 1)
220
221 D[0] = reinterpret_cast<Array<half_t, 2> const &>(Dlo);
222 D[1] = reinterpret_cast<Array<half_t, 2> const &>(Dhi);
223 #else  // host / pre-SM60 path: scalar loops with row-major indexing
225for (int i = 0; i < 2; ++i) {
227for (int j = 0; j < 2; ++j) {
228 d[i * 2 + j] = a[i] * b[j] + c[i * 2 + j];
229 }
230 }
231 #endif
232 }
233 };
234
236
237 }
238 }
239
CUTLASS_HOST_DEVICE void operator()(Array< half_t, 2 > &d, Array< half_t, 1 > const &a, Array< half_t, 2 > const &b, Array< half_t, 2 > const &c)
Definition: arch/mma_sm60.h:104
Definition: aligned_buffer.h:35
IEEE half-precision floating-point type.
Definition: half.h:126
CUTLASS_HOST_DEVICE void operator()(Array< half_t, 2 > &d, Array< half_t, 2 > const &a, Array< half_t, 1 > const &b, Array< half_t, 2 > const &c)
Definition: arch/mma_sm60.h:60
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Templates exposing architecture support for multiply-add operations.
CUTLASS_HOST_DEVICE void operator()(Array< half_t, 4 > &d, Array< half_t, 2 > const &a, Array< half_t, 2 > const &b, Array< half_t, 4 > const &c)
Definition: arch/mma_sm60.h:148
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Defines layout functions used by TensorRef and derived classes.
Matrix multiply-add operation.
Definition: arch/mma.h:92
CUTLASS_HOST_DEVICE void operator()(Array< half_t, 4 > &d, Array< half_t, 2 > const &a, Array< half_t, 2 > const &b, Array< half_t, 4 > const &c)
Definition: arch/mma_sm60.h:201
Generated by Doxygen 1.8.11