Back to Cutlass

CUTLASS: mma_sm60.h Source File

docs/arch_2mma__sm60_8h_source.html

4.4.215.3 KB
Original Source

CUTLASS — CUDA Templates for Linear Algebra Subroutines and Solvers

arch/mma_sm60.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include <cuda_fp16.h>

32

33 #include "cutlass/arch/mma.h"

34

35 #include "cutlass/layout/matrix.h"

36

38

39 namespace cutlass {

40 namespace arch {

41

43

45 template <typename LayoutA, typename LayoutB, typename LayoutC>

46 struct Mma<

47 gemm::GemmShape<2,1,1>,

48 1,

49half_t,

50 LayoutA,

51half_t,

52 LayoutB,

53half_t,

54 LayoutC,

55 OpMultiplyAdd> {

56

57using Shape = gemm::GemmShape<2, 1, 1>;

58

59CUTLASS_HOST_DEVICE

60void operator()(

61 Array<half_t, 2> &d,

62 Array<half_t, 2> const &a,

63 Array<half_t, 1> const &b,

64 Array<half_t, 2> const &c

65 ) {

66

67 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

68

69 __half2 const & A = reinterpret_cast<__half2 const &>(a);

70 __half2 B = __half2half2(reinterpret_cast<__half const &>(b));

71 __half2 const & C = reinterpret_cast<__half2 const &>(c);

72

73 __half2 D = __hfma2(A, B, C);

74

75 d = reinterpret_cast<Array<half_t, 2> &>(D);

76

77 #else

78CUTLASS_PRAGMA_UNROLL

79for (int i = 0; i < 2; ++i) {

80 d[i] = a[i] * b[0] + c[i];

81 }

82 #endif

83 }

84 };

85

87

89 template <typename LayoutA, typename LayoutB>

90 struct Mma<

91 gemm::GemmShape<1,2,1>,

92 1,

93half_t,

94 LayoutA,

95half_t,

96 LayoutB,

97half_t,

98layout::RowMajor,

99 OpMultiplyAdd> {

100

101using Shape = gemm::GemmShape<1, 2, 1>;

102

103CUTLASS_HOST_DEVICE

104void operator()(

105 Array<half_t, 2> &d,

106 Array<half_t, 1> const &a,

107 Array<half_t, 2> const &b,

108 Array<half_t, 2> const &c

109 ) {

110

111 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

112

113 __half2 const & A = __half2half2(reinterpret_cast<__half const &>(a));

114 __half2 B = reinterpret_cast<__half2 const &>(b);

115 __half2 const & C = reinterpret_cast<__half2 const &>(c);

116

117 __half2 D = __hfma2(A, B, C);

118

119 d = reinterpret_cast<Array<half_t, 2> &>(D);

120

121 #else

122CUTLASS_PRAGMA_UNROLL

123for (int i = 0; i < 2; ++i) {

124 d[i] = a[0] * b[i] + c[i];

125 }

126 #endif

127 }

128 };

129

131

133 template <>

134 struct Mma <

135 gemm::GemmShape<2, 2, 1>,

136 1,

137half_t,

138layout::ColumnMajor,

139half_t,

140layout::RowMajor,

141half_t,

142layout::ColumnMajor,

143 OpMultiplyAdd> {

144

145using Shape = gemm::GemmShape<2, 2, 1>;

146

147CUTLASS_HOST_DEVICE

148void operator()(

149 Array<half_t, 4> &d,

150 Array<half_t, 2> const &a,

151 Array<half_t, 2> const &b,

152 Array<half_t, 4> const &c

153 ) {

154

155 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

156

157 __half2 const & A = reinterpret_cast<__half2 const &>(a);

158 __half2 Blo = __low2half2(reinterpret_cast<__half2 const &>(b));

159 __half2 Bhi = __high2half2(reinterpret_cast<__half2 const &>(b));

160

161 __half2 const *C = reinterpret_cast<__half2 const *>(&c);

162

163 __half2 Dlo = __hfma2(A, Blo, C[0]);

164 __half2 Dhi = __hfma2(A, Bhi, C[1]);

165

166 Array<half_t, 2> * D = reinterpret_cast<Array<half_t, 2> *>(&d);

167

168 D[0] = reinterpret_cast<Array<half_t, 2> const &>(Dlo);

169 D[1] = reinterpret_cast<Array<half_t, 2> const &>(Dhi);

170

171 #else

172CUTLASS_PRAGMA_UNROLL

173for (int j = 0; j < 2; ++j) {

174CUTLASS_PRAGMA_UNROLL

175for (int i = 0; i < 2; ++i) {

176 d[i + 2 * j] = a[i] * b[j] + c[i + 2 * j];

177 }

178 }

179 #endif

180 }

181 };

182

184

186 template <>

187 struct Mma<

188 gemm::GemmShape<2, 2, 1>,

189 1,

190half_t,

191layout::ColumnMajor,

192half_t,

193layout::RowMajor,

194half_t,

195layout::RowMajor,

196 OpMultiplyAdd> {

197

198using Shape = gemm::GemmShape<2, 2, 1>;

199

200CUTLASS_HOST_DEVICE

201void operator()(

202 Array<half_t, 4> &d,

203 Array<half_t, 2> const &a,

204 Array<half_t, 2> const &b,

205 Array<half_t, 4> const &c

206 ) {

207

208 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

209

210 __half2 Alo = __low2half2(reinterpret_cast<__half2 const &>(a));

211 __half2 Ahi = __high2half2(reinterpret_cast<__half2 const &>(a));

212 __half2 const & B = reinterpret_cast<__half2 const &>(b);

213

214 __half2 const *C = reinterpret_cast<__half2 const *>(&c);

215

216 __half2 Dlo = __hfma2(Alo, B, C[0]);

217 __half2 Dhi = __hfma2(Ahi, B, C[0]);

218

219 Array<half_t, 2> * D = reinterpret_cast<Array<half_t, 2> *>(&d);

220

221 D[0] = reinterpret_cast<Array<half_t, 2> &>(Dlo);

222 D[1] = reinterpret_cast<Array<half_t, 2> &>(Dhi);

223 #else

224CUTLASS_PRAGMA_UNROLL

225for (int i = 0; i < 2; ++i) {

226CUTLASS_PRAGMA_UNROLL

227for (int j = 0; j < 2; ++j) {

228 d[i * 2 + j] = a[i] * b[j] + c[i * 2 + j];

229 }

230 }

231 #endif

232 }

233 };

234

236

237 }

238 }

239

cutlass::arch::Mma< gemm::GemmShape< 1, 2, 1 >, 1, half_t, LayoutA, half_t, LayoutB, half_t, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< half_t, 2 > &d, Array< half_t, 1 > const &a, Array< half_t, 2 > const &b, Array< half_t, 2 > const &c)

Definition: arch/mma_sm60.h:104

cutlass

Definition: aligned_buffer.h:35

cutlass::half_t

IEEE half-precision floating-point type.

Definition: half.h:126

cutlass::arch::Mma< gemm::GemmShape< 2, 1, 1 >, 1, half_t, LayoutA, half_t, LayoutB, half_t, LayoutC, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< half_t, 2 > &d, Array< half_t, 2 > const &a, Array< half_t, 1 > const &b, Array< half_t, 2 > const &c)

Definition: arch/mma_sm60.h:60

cutlass::layout::ColumnMajor

Mapping function for column-major matrices.

Definition: layout/matrix.h:142

CUTLASS_PRAGMA_UNROLL

#define CUTLASS_PRAGMA_UNROLL

Definition: cutlass.h:110

mma.h

Templates exposing architecture support for multiply-add operations.

cutlass::arch::Mma< gemm::GemmShape< 2, 2, 1 >, 1, half_t, layout::ColumnMajor, half_t, layout::RowMajor, half_t, layout::ColumnMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< half_t, 4 > &d, Array< half_t, 2 > const &a, Array< half_t, 2 > const &b, Array< half_t, 4 > const &c)

Definition: arch/mma_sm60.h:148

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

matrix.h

Defines layout functions used by TensorRef and derived classes.

cutlass::arch::Mma

Matrix multiply-add operation.

Definition: arch/mma.h:92

cutlass::arch::Mma< gemm::GemmShape< 2, 2, 1 >, 1, half_t, layout::ColumnMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< half_t, 4 > &d, Array< half_t, 2 > const &a, Array< half_t, 2 > const &b, Array< half_t, 4 > const &c)

Definition: arch/mma_sm60.h:201


Generated by Doxygen 1.8.11