Back to Cutlass

CUTLASS: wmma_sm75.h Source File

docs/wmma__sm75_8h_source.html

4.4.29.4 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

wmma_sm75.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include <assert.h>

32 #include "cutlass/layout/matrix.h"

33

35 namespace cutlass {

36 namespace arch {

37

39 //

40 // WMMA template structure defines nvcuda::wmma::fragments and static assert for

41 // wmma native instruction sizes supported for cutlass::int4b_t (experimental::s4).

42 //

44 template <

45 typename Shape_,

46 typename LayoutA_,

47 typename LayoutB_,

48 typename LayoutC_>

[49](structcutlass_1_1arch_1_1Wmma_3_01Shape _00_01cutlass_1_1int4b t_00_01LayoutA___00_01cutlass_16fd808a90b3cf9d7cfc99f30888ca3fe.html) struct Wmma<

50 Shape_,

51cutlass::int4b_t,

52 LayoutA_,

53cutlass::int4b_t,

54 LayoutB_,

55 int32_t,

56 LayoutC_,

57 cutlass::arch::OpMultiplyAdd

58 > {

59 #if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED)

60using Shape = Shape_;

61using ElementA = cutlass::int4b_t;

62using LayoutA = LayoutA_;

63using ElementB = cutlass::int4b_t;

64using LayoutB = LayoutB_;

65using ElementC = int32_t;

66using LayoutC = LayoutC_;

67using Operator = cutlass::arch::OpMultiplyAdd;

68

69// check supported wmma shape for the given multiplicand data types

70static_assert(

71platform::is_same<cutlass::gemm::GemmShape<8, 8, 32>, Shape>::value,

72"Supported list of wmma operator shape for s8 multiplicands is: 8x8x32");

73

74

75// Wmma Fragment

76using FragmentA = nvcuda::wmma::fragment<

77 nvcuda::wmma::matrix_a,

78 Shape::kM,

79 Shape::kN,

80 Shape::kK,

81typename CutlassToWmmaDataType<ElementA>::Type,

82typename CutlassToWmmaLayout<LayoutA>::Layout>;

83

84using FragmentB = nvcuda::wmma::fragment<

85 nvcuda::wmma::matrix_b,

86 Shape::kM,

87 Shape::kN,

88 Shape::kK,

89typename CutlassToWmmaDataType<ElementB>::Type,

90typename CutlassToWmmaLayout<LayoutB>::Layout>;

91

92using FragmentC = nvcuda::wmma::fragment<

93 nvcuda::wmma::accumulator,

94 Shape::kM,

95 Shape::kN,

96 Shape::kK,

97typename CutlassToWmmaDataType<ElementC>::Type>;

98

100 CUTLASS_DEVICE

101void operator()(

102 FragmentC &D,

103 FragmentA const &A,

104 FragmentB const &B,

105 FragmentC const &C) const {

106 nvcuda::wmma::mma_sync(D, A, B, C);

107 }

108

109 #else

110static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond");

111 #endif

112

113 };

114

116 //

117 // WMMA template structure defines nvcuda::wmma::fragments and static assert for

118 // wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1)

119 // (nvcuda::wmma targeting SASS instruction BMMA)

120 //

122 template <

123 typename Shape_,

124 typename LayoutA_,

125 typename LayoutB_,

126 typename LayoutC_>

[127](structcutlass_1_1arch_1_1Wmma_3_01Shape _00_01cutlass_1_1uint1b t_00_01LayoutA___00_01cutlass_c80a7ea4d219cd9b13b560b493338028.html) struct Wmma<

128 Shape_,

129cutlass::uint1b_t,

130 LayoutA_,

131cutlass::uint1b_t,

132 LayoutB_,

133 int32_t,

134 LayoutC_,

135 cutlass::arch::OpXorPopc

136 > {

137 #if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED)

138using Shape = Shape_;

139using ElementA = cutlass::uint1b_t;

140using LayoutA = LayoutA_;

141using ElementB = cutlass::uint1b_t;

142using LayoutB = LayoutB_;

143using ElementC = int32_t;

144using LayoutC = LayoutC_;

145using Operator = cutlass::arch::OpXorPopc;

146

147// check supported wmma shape for the given multiplicand data types

148static_assert(

149platform::is_same<cutlass::gemm::GemmShape<8, 8, 128>, Shape>::value,

150"Supported list of wmma operator shape for b1 multiplicands is: 8x8x128");

151

152

153// Wmma Fragment

154using FragmentA = nvcuda::wmma::fragment<

155 nvcuda::wmma::matrix_a,

156 Shape::kM,

157 Shape::kN,

158 Shape::kK,

159typename CutlassToWmmaDataType<ElementA>::Type,

160typename CutlassToWmmaLayout<LayoutA>::Layout>;

161

162using FragmentB = nvcuda::wmma::fragment<

163 nvcuda::wmma::matrix_b,

164 Shape::kM,

165 Shape::kN,

166 Shape::kK,

167typename CutlassToWmmaDataType<ElementB>::Type,

168typename CutlassToWmmaLayout<LayoutB>::Layout>;

169

170using FragmentC = nvcuda::wmma::fragment<

171 nvcuda::wmma::accumulator,

172 Shape::kM,

173 Shape::kN,

174 Shape::kK,

175typename CutlassToWmmaDataType<ElementC>::Type>;

176

178 CUTLASS_DEVICE

179void operator()(

180 FragmentC &D,

181 FragmentA const &A,

182 FragmentB const &B,

183 FragmentC const &C) const {

184

185 nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,

186 nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);

187 }

188

189 #else

190static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond");

191 #endif

192

193 };

194

195 } // namespace arch

196 } // namespace cutlass

cutlass

Definition: aligned_buffer.h:35

cutlass::platform::is_same

std::is_same (false specialization)

Definition: platform.h:394

cutlass::uint1b_t

integer_subbyte< 1, false > uint1b_t

1-bit Unsigned integer type

Definition: integer_subbyte.h:152

cutlass::integer_subbyte

4-bit signed integer type

Definition: integer_subbyte.h:42

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

matrix.h

Defines layout functions used by TensorRef and derived classes.

cutlass::int4b_t

integer_subbyte< 4, true > int4b_t

4-bit Integer type

Definition: integer_subbyte.h:155


Generated by 1.8.11