Back to Cutlass

CUTLASS: wmma_sm72.h Source File

docs/wmma__sm72_8h_source.html

4.4.29.0 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

wmma_sm72.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include <assert.h>

32 #include "cutlass/layout/matrix.h"

33

35 namespace cutlass {

36 namespace arch {

37

39 //

40 // WMMA template structure defines nvcuda::wmma::fragments and static assert for

41 // wmma native instruction sizes supported for int8_t

42 //

44 template <

45 typename Shape_,

46 typename LayoutA_,

47 typename LayoutB_,

48 typename LayoutC_>

[49](structcutlass_1_1arch_1_1Wmma_3_01Shape _00_01int8 t_00_01LayoutA _00_01int8 t_00_01LayoutB_505c57bb6818a941dc16f00cf35a9ec0.html) struct Wmma<

50 Shape_,

51 int8_t,

52 LayoutA_,

53 int8_t,

54 LayoutB_,

55 int32_t,

56 LayoutC_,

57cutlass::arch::OpMultiplyAdd

58 > {

59 #if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED)

60using Shape = Shape_;

61using ElementA = int8_t;

62using LayoutA = LayoutA_;

63using ElementB = int8_t;

64using LayoutB = LayoutB_;

65using ElementC = int32_t;

66using LayoutC = LayoutC_;

67using Operator = cutlass::arch::OpMultiplyAdd;

68

69// check supported wmma shape for the given multiplicand data types

70static_assert(

71platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||

72platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||

73platform::is_same<cutlass::gemm::GemmShape<32, 8, 16>, Shape>::value,

74"Supported list of wmma operator shape for s8 multiplicands are: 16x16x16, 8x328x16, and 32x8x16");

75

76

77// Wmma Fragment

78using FragmentA = nvcuda::wmma::fragment<

79 nvcuda::wmma::matrix_a,

80 Shape::kM,

81 Shape::kN,

82 Shape::kK,

83typename CutlassToWmmaDataType<ElementA>::Type,

84typename CutlassToWmmaLayout<LayoutA>::Layout>;

85

86using FragmentB = nvcuda::wmma::fragment<

87 nvcuda::wmma::matrix_b,

88 Shape::kM,

89 Shape::kN,

90 Shape::kK,

91typename CutlassToWmmaDataType<ElementB>::Type,

92typename CutlassToWmmaLayout<LayoutB>::Layout>;

93

94using FragmentC = nvcuda::wmma::fragment<

95 nvcuda::wmma::accumulator,

96 Shape::kM,

97 Shape::kN,

98 Shape::kK,

99typename CutlassToWmmaDataType<ElementC>::Type>;

100

102 CUTLASS_DEVICE

103void operator()(

104 FragmentC &D,

105 FragmentA const &A,

106 FragmentB const &B,

107 FragmentC const &C) const {

108

109 nvcuda::wmma::mma_sync(D, A, B, C);

110 }

111

112 #else

113static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond");

114 #endif

115

116 };

117

119 //

120 // WMMA template structure defines nvcuda::wmma::fragments and static assert for

121 // wmma native instruction sizes supported for uint8_t

122 //

124 template <

125 typename Shape_,

126 typename LayoutA_,

127 typename LayoutB_,

128 typename LayoutC_>

[129](structcutlass_1_1arch_1_1Wmma_3_01Shape _00_01uint8 t_00_01LayoutA _00_01uint8 t_00_01Layout219a464a1248ebfc37aa29bcb10cb1b0.html) struct Wmma<

130 Shape_,

131 uint8_t,

132 LayoutA_,

133 uint8_t,

134 LayoutB_,

135 int32_t,

136 LayoutC_,

137cutlass::arch::OpMultiplyAdd

138 > {

139 #if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED)

140using Shape = Shape_;

141using ElementA = uint8_t;

142using LayoutA = LayoutA_;

143using ElementB = uint8_t;

144using LayoutB = LayoutB_;

145using ElementC = int32_t;

146using LayoutC = LayoutC_;

147using Operator = cutlass::arch::OpMultiplyAdd;

148

149// check supported wmma shape for the given multiplicand data types

150static_assert(

151platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||

152platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||

153platform::is_same<cutlass::gemm::GemmShape<32, 8, 16>, Shape>::value,

154"Supported list of wmma operator shape for u8 multiplicands are: 16x16x16, 8x328x16, and 32x8x16");

155

156// Wmma Fragment

157using FragmentA = nvcuda::wmma::fragment<

158 nvcuda::wmma::matrix_a,

159 Shape::kM,

160 Shape::kN,

161 Shape::kK,

162typename CutlassToWmmaDataType<ElementA>::Type,

163typename CutlassToWmmaLayout<LayoutA>::Layout>;

164

165using FragmentB = nvcuda::wmma::fragment<

166 nvcuda::wmma::matrix_b,

167 Shape::kM,

168 Shape::kN,

169 Shape::kK,

170typename CutlassToWmmaDataType<ElementB>::Type,

171typename CutlassToWmmaLayout<LayoutB>::Layout>;

172

173using FragmentC = nvcuda::wmma::fragment<

174 nvcuda::wmma::accumulator,

175 Shape::kM,

176 Shape::kN,

177 Shape::kK,

178typename CutlassToWmmaDataType<ElementC>::Type>;

179

181 CUTLASS_DEVICE

182void operator()(

183 FragmentC &D,

184 FragmentA const &A,

185 FragmentB const &B,

186 FragmentC const &C) const {

187

188 nvcuda::wmma::mma_sync(D, A, B, C);

189 }

190

191 #else

192static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond");

193 #endif

194

195 };

196

197 } // namespace arch

198 } // namespace cutlass

cutlass

Definition: aligned_buffer.h:35

cutlass::platform::is_same

std::is_same (false specialization)

Definition: platform.h:394

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

matrix.h

Defines layout functions used by TensorRef and derived classes.


Generated by 1.8.11