docs/arch_2mma__sm50_8h_source.html
CUTLASS: CUDA Templates for Linear Algebra Subroutines and Solvers
arch/mma_sm50.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/arch/mma.h"
32 #include "cutlass/complex.h"
33
34 #include "cutlass/layout/matrix.h"
35 #include "cutlass/gemm/gemm.h"
36
38
39 namespace cutlass {
40 namespace arch {
41
43
45 template <
47typename LayoutA,
49typename LayoutB,
51typename LayoutC
52 >
53 struct Mma<gemm::GemmShape<1, 1, 1>, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> {
54
55using Shape = gemm::GemmShape<1, 1, 1>;
56
58void operator()(
59 Array<float, 1> &d,
60 Array<float, 1> const &a,
61 Array<float, 1> const &b,
62 Array<float, 1> const &c
63 ) {
64 d[0] = a[0] * b[0] + c[0];
65 }
66 };
67
69
71 template <
73typename LayoutA,
75typename LayoutB,
77typename LayoutC
78 >
79 struct Mma<gemm::GemmShape<1, 1, 1>, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd> {
80
81using Shape = gemm::GemmShape<1, 1, 1>;
82
84void operator()(
85 Array<double, 1> &d,
86 Array<double, 1> const &a,
87 Array<double, 1> const &b,
88 Array<double, 1> const &c
89 ) {
90
91 d[0] = a[0] * b[0] + c[0];
92 }
93 };
94
96
98 template <
100typename LayoutA,
102typename LayoutB,
104typename LayoutC
105 >
106 struct Mma<gemm::GemmShape<1, 1, 1>, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd> {
107
108using Shape = gemm::GemmShape<1, 1, 1>;
109
111void operator()(
112 Array<int, 1> &d,
113 Array<int, 1> const &a,
114 Array<int, 1> const &b,
115 Array<int, 1> const &c
116 ) {
117
118 d[0] = a[0] * b[0] + c[0];
119 }
120 };
121
123
125 template <
127typename LayoutA,
129typename LayoutB,
131typename LayoutC
132 >
134 gemm::GemmShape<1, 1, 1>,
135 1,
136complex<float>,
137 LayoutA,
138complex<float>,
139 LayoutB,
140complex<float>,
141 LayoutC,
142 OpMultiplyAdd> {
143
144using Shape = gemm::GemmShape<1, 1, 1>;
145
147void operator()(
148 Array<complex<float>, 1> &d,
149 Array<complex<float>, 1> const &a,
150 Array<complex<float>, 1> const &b,
151 Array<complex<float>, 1> const &c
152 ) {
153
154 d[0].real() = a[0].real() * b[0].real() + c[0].real();
155 d[0].imag() = a[0].imag() * b[0].real() + c[0].imag();
156 d[0].real() = -a[0].imag() * b[0].imag() + d[0].real();
157 d[0].imag() = a[0].real() * b[0].imag() + d[0].imag();
158 }
159 };
160
162
164 template <
166typename LayoutA,
168typename LayoutB,
170typename LayoutC
171 >
173 gemm::GemmShape<1, 1, 1>,
174 1,
175complex<float>,
176 LayoutA,
177 float,
178 LayoutB,
179complex<float>,
180 LayoutC,
181 OpMultiplyAdd> {
182
183using Shape = gemm::GemmShape<1, 1, 1>;
184
186void operator()(
187 Array<complex<float>, 1> &d,
188 Array<complex<float>, 1> const &a,
189 Array<float, 1> const &b,
190 Array<complex<float>, 1> const &c
191 ) {
192
193 d[0].real() = a[0].real() * b[0] + c[0].real();
194 d[0].imag() = a[0].imag() * b[0] + c[0].imag();
195 }
196 };
197
199
201 template <
203typename LayoutA,
205typename LayoutB,
207typename LayoutC
208 >
210 gemm::GemmShape<1, 1, 1>,
211 1,
212 float,
213 LayoutA,
214complex<float>,
215 LayoutB,
216complex<float>,
217 LayoutC,
218 OpMultiplyAdd> {
219
220using Shape = gemm::GemmShape<1, 1, 1>;
221
223void operator()(
224 Array<complex<float>, 1> &d,
225 Array<float, 1> const &a,
226 Array<complex<float>, 1> const &b,
227 Array<complex<float>, 1> const &c
228 ) {
229
230 d[0].real() = a[0] * b[0].real() + c[0].real();
231 d[0].imag() = a[0] * b[0].imag() + d[0].imag();
232 }
233 };
234
236
238 template <
240typename LayoutA,
242typename LayoutB,
244typename LayoutC
245 >
247 gemm::GemmShape<1, 1, 1>,
248 1,
249complex<double>,
250 LayoutA,
251complex<double>,
252 LayoutB,
253complex<double>,
254 LayoutC,
255 OpMultiplyAdd> {
256
257using Shape = gemm::GemmShape<1, 1, 1>;
258
260void operator()(
261 Array<complex<double>, 1> &d,
262 Array<complex<double>, 1> const &a,
263 Array<complex<double>, 1> const &b,
264 Array<complex<double>, 1> const &c
265 ) {
266
267 d[0].real() = a[0].real() * b[0].real() + c[0].real();
268 d[0].imag() = a[0].imag() * b[0].real() + c[0].imag();
269 d[0].real() = -a[0].imag() * b[0].imag() + d[0].real();
270 d[0].imag() = a[0].real() * b[0].imag() + d[0].imag();
271 }
272 };
273
275 template <
277typename LayoutA,
279typename LayoutB,
281typename LayoutC
282 >
284 gemm::GemmShape<1, 1, 1>,
285 1,
286complex<double>,
287 LayoutA,
288 double,
289 LayoutB,
290complex<double>,
291 LayoutC,
292 OpMultiplyAdd> {
293
294using Shape = gemm::GemmShape<1, 1, 1>;
295
297void operator()(
298 Array<complex<double>, 1> &d,
299 Array<complex<double>, 1> const &a,
300 Array<double, 1> const &b,
301 Array<complex<double>, 1> const &c
302 ) {
303
304 d[0].real() = a[0].real() * b[0] + c[0].real();
305 d[0].imag() = a[0].imag() * b[0] + c[0].imag();
306 }
307 };
308
310 template <
312typename LayoutA,
314typename LayoutB,
316typename LayoutC
317 >
319 gemm::GemmShape<1, 1, 1>,
320 1,
321 double,
322 LayoutA,
323complex<double>,
324 LayoutB,
325complex<double>,
326 LayoutC,
327 OpMultiplyAdd> {
328
329using Shape = gemm::GemmShape<1, 1, 1>;
330
332void operator()(
333 Array<complex<double>, 1> &d,
334 Array<double, 1> const &a,
335 Array<complex<double>, 1> const &b,
336 Array<complex<double>, 1> const &c
337 ) {
338
339 d[0].real() = a[0] * b[0].real() + c[0].real();
340 d[0].imag() = a[0] * b[0].imag() + d[0].imag();
341 }
342 };
343
345
347 template <
349typename LayoutA,
351typename LayoutB,
353typename LayoutC
354 >
355 struct Mma<gemm::GemmShape<1, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd> {
356
357using Shape = gemm::GemmShape<1, 1, 1>;
358
360void operator()(
361 Array<float, 1> &d,
362 Array<half_t, 1> const &a,
363 Array<half_t, 1> const &b,
364 Array<float, 1> const &c
365 ) {
366 d[0] = float(a[0]) * float(b[0]) + c[0];
367 }
368 };
369
371
372 }
373 }
CUTLASS_HOST_DEVICE void operator()(Array< int, 1 > &d, Array< int, 1 > const &a, Array< int, 1 > const &b, Array< int, 1 > const &c)
Definition: arch/mma_sm50.h:111
CUTLASS_HOST_DEVICE void operator()(Array< complex< double >, 1 > &d, Array< complex< double >, 1 > const &a, Array< double, 1 > const &b, Array< complex< double >, 1 > const &c)
Definition: arch/mma_sm50.h:297
Definition: aligned_buffer.h:35
IEEE half-precision floating-point type.
Definition: half.h:126
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE void operator()(Array< complex< float >, 1 > &d, Array< complex< float >, 1 > const &a, Array< complex< float >, 1 > const &b, Array< complex< float >, 1 > const &c)
Definition: arch/mma_sm50.h:147
CUTLASS_HOST_DEVICE void operator()(Array< complex< double >, 1 > &d, Array< complex< double >, 1 > const &a, Array< complex< double >, 1 > const &b, Array< complex< double >, 1 > const &c)
Definition: arch/mma_sm50.h:260
CUTLASS_HOST_DEVICE void operator()(Array< float, 1 > &d, Array< half_t, 1 > const &a, Array< half_t, 1 > const &b, Array< float, 1 > const &c)
Definition: arch/mma_sm50.h:360
Templates exposing architecture support for multiply-add operations.
CUTLASS_HOST_DEVICE void operator()(Array< complex< float >, 1 > &d, Array< complex< float >, 1 > const &a, Array< float, 1 > const &b, Array< complex< float >, 1 > const &c)
Definition: arch/mma_sm50.h:186
CUTLASS_HOST_DEVICE void operator()(Array< float, 1 > &d, Array< float, 1 > const &a, Array< float, 1 > const &b, Array< float, 1 > const &c)
Definition: arch/mma_sm50.h:58
CUTLASS_HOST_DEVICE void operator()(Array< complex< float >, 1 > &d, Array< float, 1 > const &a, Array< complex< float >, 1 > const &b, Array< complex< float >, 1 > const &c)
Definition: arch/mma_sm50.h:223
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Definition: complex.h:92
Defines layout functions used by TensorRef and derived classes.
Matrix multiply-add operation.
Definition: arch/mma.h:92
CUTLASS_HOST_DEVICE void operator()(Array< complex< double >, 1 > &d, Array< double, 1 > const &a, Array< complex< double >, 1 > const &b, Array< complex< double >, 1 > const &c)
Definition: arch/mma_sm50.h:332
CUTLASS_HOST_DEVICE void operator()(Array< double, 1 > &d, Array< double, 1 > const &a, Array< double, 1 > const &b, Array< double, 1 > const &c)
Definition: arch/mma_sm50.h:84
Generated by 1.8.11