Back to Cutlass

CUTLASS: mma_sm50.h Source File

docs/arch_2mma__sm50_8h_source.html

4.4.225.0 KB
Original Source

CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers

arch/mma_sm50.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include "cutlass/arch/mma.h"

32 #include "cutlass/complex.h"

33

34 #include "cutlass/layout/matrix.h"

35 #include "cutlass/gemm/gemm.h"

36

38

39 namespace cutlass {

40 namespace arch {

41

43

45 template <

47typename LayoutA,

49typename LayoutB,

51typename LayoutC

52 >

53 struct Mma<gemm::GemmShape<1, 1, 1>, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> {

54

55using Shape = gemm::GemmShape<1, 1, 1>;

56

57CUTLASS_HOST_DEVICE

58void operator()(

59 Array<float, 1> &d,

60 Array<float, 1> const &a,

61 Array<float, 1> const &b,

62 Array<float, 1> const &c

63 ) {

64 d[0] = a[0] * b[0] + c[0];

65 }

66 };

67

69

71 template <

73typename LayoutA,

75typename LayoutB,

77typename LayoutC

78 >

79 struct Mma<gemm::GemmShape<1, 1, 1>, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd> {

80

81using Shape = gemm::GemmShape<1, 1, 1>;

82

83CUTLASS_HOST_DEVICE

84void operator()(

85 Array<double, 1> &d,

86 Array<double, 1> const &a,

87 Array<double, 1> const &b,

88 Array<double, 1> const &c

89 ) {

90

91 d[0] = a[0] * b[0] + c[0];

92 }

93 };

94

96

98 template <

100typename LayoutA,

102typename LayoutB,

104typename LayoutC

105 >

106 struct Mma<gemm::GemmShape<1, 1, 1>, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd> {

107

108using Shape = gemm::GemmShape<1, 1, 1>;

109

110CUTLASS_HOST_DEVICE

111void operator()(

112 Array<int, 1> &d,

113 Array<int, 1> const &a,

114 Array<int, 1> const &b,

115 Array<int, 1> const &c

116 ) {

117

118 d[0] = a[0] * b[0] + c[0];

119 }

120 };

121

123

125 template <

127typename LayoutA,

129typename LayoutB,

131typename LayoutC

132 >

133 struct Mma<

134 gemm::GemmShape<1, 1, 1>,

135 1,

136complex<float>,

137 LayoutA,

138complex<float>,

139 LayoutB,

140complex<float>,

141 LayoutC,

142 OpMultiplyAdd> {

143

144using Shape = gemm::GemmShape<1, 1, 1>;

145

146CUTLASS_HOST_DEVICE

147void operator()(

148 Array<complex<float>, 1> &d,

149 Array<complex<float>, 1> const &a,

150 Array<complex<float>, 1> const &b,

151 Array<complex<float>, 1> const &c

152 ) {

153

154 d[0].real() = a[0].real() * b[0].real() + c[0].real();

155 d[0].imag() = a[0].imag() * b[0].real() + c[0].imag();

156 d[0].real() = -a[0].imag() * b[0].imag() + d[0].real();

157 d[0].imag() = a[0].real() * b[0].imag() + d[0].imag();

158 }

159 };

160

162

164 template <

166typename LayoutA,

168typename LayoutB,

170typename LayoutC

171 >

172 struct Mma<

173 gemm::GemmShape<1, 1, 1>,

174 1,

175complex<float>,

176 LayoutA,

177 float,

178 LayoutB,

179complex<float>,

180 LayoutC,

181 OpMultiplyAdd> {

182

183using Shape = gemm::GemmShape<1, 1, 1>;

184

185CUTLASS_HOST_DEVICE

186void operator()(

187 Array<complex<float>, 1> &d,

188 Array<complex<float>, 1> const &a,

189 Array<float, 1> const &b,

190 Array<complex<float>, 1> const &c

191 ) {

192

193 d[0].real() = a[0].real() * b[0] + c[0].real();

194 d[0].imag() = a[0].imag() * b[0] + c[0].imag();

195 }

196 };

197

199

201 template <

203typename LayoutA,

205typename LayoutB,

207typename LayoutC

208 >

209 struct Mma<

210 gemm::GemmShape<1, 1, 1>,

211 1,

212 float,

213 LayoutA,

214complex<float>,

215 LayoutB,

216complex<float>,

217 LayoutC,

218 OpMultiplyAdd> {

219

220using Shape = gemm::GemmShape<1, 1, 1>;

221

222CUTLASS_HOST_DEVICE

223void operator()(

224 Array<complex<float>, 1> &d,

225 Array<float, 1> const &a,

226 Array<complex<float>, 1> const &b,

227 Array<complex<float>, 1> const &c

228 ) {

229

230 d[0].real() = a[0] * b[0].real() + c[0].real();

231 d[0].imag() = a[0] * b[0].imag() + d[0].imag();

232 }

233 };

234

236

238 template <

240typename LayoutA,

242typename LayoutB,

244typename LayoutC

245 >

246 struct Mma<

247 gemm::GemmShape<1, 1, 1>,

248 1,

249complex<double>,

250 LayoutA,

251complex<double>,

252 LayoutB,

253complex<double>,

254 LayoutC,

255 OpMultiplyAdd> {

256

257using Shape = gemm::GemmShape<1, 1, 1>;

258

259CUTLASS_HOST_DEVICE

260void operator()(

261 Array<complex<double>, 1> &d,

262 Array<complex<double>, 1> const &a,

263 Array<complex<double>, 1> const &b,

264 Array<complex<double>, 1> const &c

265 ) {

266

267 d[0].real() = a[0].real() * b[0].real() + c[0].real();

268 d[0].imag() = a[0].imag() * b[0].real() + c[0].imag();

269 d[0].real() = -a[0].imag() * b[0].imag() + d[0].real();

270 d[0].imag() = a[0].real() * b[0].imag() + d[0].imag();

271 }

272 };

273

275 template <

277typename LayoutA,

279typename LayoutB,

281typename LayoutC

282 >

283 struct Mma<

284 gemm::GemmShape<1, 1, 1>,

285 1,

286complex<double>,

287 LayoutA,

288 double,

289 LayoutB,

290complex<double>,

291 LayoutC,

292 OpMultiplyAdd> {

293

294using Shape = gemm::GemmShape<1, 1, 1>;

295

296CUTLASS_HOST_DEVICE

297void operator()(

298 Array<complex<double>, 1> &d,

299 Array<complex<double>, 1> const &a,

300 Array<double, 1> const &b,

301 Array<complex<double>, 1> const &c

302 ) {

303

304 d[0].real() = a[0].real() * b[0] + c[0].real();

305 d[0].imag() = a[0].imag() * b[0] + c[0].imag();

306 }

307 };

308

310 template <

312typename LayoutA,

314typename LayoutB,

316typename LayoutC

317 >

318 struct Mma<

319 gemm::GemmShape<1, 1, 1>,

320 1,

321 double,

322 LayoutA,

323complex<double>,

324 LayoutB,

325complex<double>,

326 LayoutC,

327 OpMultiplyAdd> {

328

329using Shape = gemm::GemmShape<1, 1, 1>;

330

331CUTLASS_HOST_DEVICE

332void operator()(

333 Array<complex<double>, 1> &d,

334 Array<double, 1> const &a,

335 Array<complex<double>, 1> const &b,

336 Array<complex<double>, 1> const &c

337 ) {

338

339 d[0].real() = a[0] * b[0].real() + c[0].real();

340 d[0].imag() = a[0] * b[0].imag() + d[0].imag();

341 }

342 };

343

345

347 template <

349typename LayoutA,

351typename LayoutB,

353typename LayoutC

354 >

355 struct Mma<gemm::GemmShape<1, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd> {

356

357using Shape = gemm::GemmShape<1, 1, 1>;

358

359CUTLASS_HOST_DEVICE

360void operator()(

361 Array<float, 1> &d,

362 Array<half_t, 1> const &a,

363 Array<half_t, 1> const &b,

364 Array<float, 1> const &c

365 ) {

366 d[0] = float(a[0]) * float(b[0]) + c[0];

367 }

368 };

369

371

372 }

373 }

cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< int, 1 > &d, Array< int, 1 > const &a, Array< int, 1 > const &b, Array< int, 1 > const &c)

Definition: arch/mma_sm50.h:111

cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, complex< double >, LayoutA, double, LayoutB, complex< double >, LayoutC, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< complex< double >, 1 > &d, Array< complex< double >, 1 > const &a, Array< double, 1 > const &b, Array< complex< double >, 1 > const &c)

Definition: arch/mma_sm50.h:297

cutlass

Definition: aligned_buffer.h:35

complex.h

cutlass::half_t

IEEE half-precision floating-point type.

Definition: half.h:126

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, complex< float >, LayoutA, complex< float >, LayoutB, complex< float >, LayoutC, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< complex< float >, 1 > &d, Array< complex< float >, 1 > const &a, Array< complex< float >, 1 > const &b, Array< complex< float >, 1 > const &c)

Definition: arch/mma_sm50.h:147

cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, complex< double >, LayoutA, complex< double >, LayoutB, complex< double >, LayoutC, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< complex< double >, 1 > &d, Array< complex< double >, 1 > const &a, Array< complex< double >, 1 > const &b, Array< complex< double >, 1 > const &c)

Definition: arch/mma_sm50.h:260

cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< float, 1 > &d, Array< half_t, 1 > const &a, Array< half_t, 1 > const &b, Array< float, 1 > const &c)

Definition: arch/mma_sm50.h:360

mma.h

Templates exposing architecture support for multiply-add operations.

cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, complex< float >, LayoutA, float, LayoutB, complex< float >, LayoutC, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< complex< float >, 1 > &d, Array< complex< float >, 1 > const &a, Array< float, 1 > const &b, Array< complex< float >, 1 > const &c)

Definition: arch/mma_sm50.h:186

cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< float, 1 > &d, Array< float, 1 > const &a, Array< float, 1 > const &b, Array< float, 1 > const &c)

Definition: arch/mma_sm50.h:58

cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, float, LayoutA, complex< float >, LayoutB, complex< float >, LayoutC, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< complex< float >, 1 > &d, Array< float, 1 > const &a, Array< complex< float >, 1 > const &b, Array< complex< float >, 1 > const &c)

Definition: arch/mma_sm50.h:223

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

cutlass::complex

Definition: complex.h:92

matrix.h

Defines layout functions used by TensorRef and derived classes.

cutlass::arch::Mma

Matrix multiply-add operation.

Definition: arch/mma.h:92

cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, double, LayoutA, complex< double >, LayoutB, complex< double >, LayoutC, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< complex< double >, 1 > &d, Array< double, 1 > const &a, Array< complex< double >, 1 > const &b, Array< complex< double >, 1 > const &c)

Definition: arch/mma_sm50.h:332

cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< double, 1 > &d, Array< double, 1 > const &a, Array< double, 1 > const &b, Array< double, 1 > const &c)

Definition: arch/mma_sm50.h:84


Generated by Doxygen 1.8.11