Back to Cutlass

CUTLASS: wmma_sm70.h Source File

docs/wmma__sm70_8h_source.html

4.4.26.9 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

wmma_sm70.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include <assert.h>

32 #include "cutlass/layout/matrix.h"

33

35 namespace cutlass {

36 namespace arch {

37

38

40 //

41 // WMMA template structure defines nvcuda::wmma::fragments and static assert for

42 // wmma native instruction sizes supported for half

43 //

45 template <

46 typename Shape_,

47 typename LayoutA_,

48 typename LayoutB_,

49 typename ElementC_,

50 typename LayoutC_>

[51](structcutlass_1_1arch_1_1Wmma_3_01Shape _00_01cutlass_1_1half t_00_01LayoutA___00_01cutlass_1_84e30c8cc93eeb7ca02f651bd16d4c38.html) struct Wmma<

52 Shape_,

53cutlass::half_t,

54 LayoutA_,

55cutlass::half_t,

56 LayoutB_,

57 ElementC_,

58 LayoutC_,

59 cutlass::arch::OpMultiplyAdd

60 > {

61

62 #if defined(CUTLASS_ARCH_WMMA_SM70_ENABLED)

63using Shape = Shape_;

64using ElementA = cutlass::half_t;

65using LayoutA = LayoutA_;

66using ElementB = cutlass::half_t;

67using LayoutB = LayoutB_;

68using ElementC = ElementC_;

69using LayoutC = LayoutC_;

70using Operator = cutlass::arch::OpMultiplyAdd;

71

72// check supported wmma shape for the given multiplicand data types

73static_assert(

74platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||

75platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||

76platform::is_same<cutlass::gemm::GemmShape<32, 8, 16>, Shape>::value,

77"Supported list of wmma operator shape for f16 multiplicands are: 16x16x16, 8x328x16, and 32x8x16");

78

79// check supported wmma output data type for the given multiplicand data types

80static_assert(

81platform::is_same<cutlass::half_t, ElementC>::value || platform::is_same<float, ElementC>::value,

82"Supported of wmma output data type for f16 multiplicands are: f16 and f32");

83

84// Wmma Fragment

85using FragmentA = nvcuda::wmma::fragment<

86 nvcuda::wmma::matrix_a,

87 Shape::kM,

88 Shape::kN,

89 Shape::kK,

90typename CutlassToWmmaDataType<ElementA>::Type,

91typename CutlassToWmmaLayout<LayoutA>::Layout>;

92

93using FragmentB = nvcuda::wmma::fragment<

94 nvcuda::wmma::matrix_b,

95 Shape::kM,

96 Shape::kN,

97 Shape::kK,

98typename CutlassToWmmaDataType<ElementB>::Type,

99typename CutlassToWmmaLayout<LayoutB>::Layout>;

100

101using FragmentC = nvcuda::wmma::fragment<

102 nvcuda::wmma::accumulator,

103 Shape::kM,

104 Shape::kN,

105 Shape::kK,

106typename CutlassToWmmaDataType<ElementC>::Type>;

107

109 CUTLASS_DEVICE

110void operator()(

111 FragmentC &D,

112 FragmentA const &A,

113 FragmentB const &B,

114 FragmentC const &C) const {

115

116 nvcuda::wmma::mma_sync(D, A, B, C);

117 }

118 #else

119static_assert(false, "wmma.mma.sync for floating point multiplicands is avialable only for SM70 and beyond");

120 #endif

121

122 };

123

124 } // namespace arch

125 } // namespace cutlass

cutlass

Definition: aligned_buffer.h:35

cutlass::platform::is_same

std::is_same (false specialization)

Definition: platform.h:394

cutlass::half_t

IEEE half-precision floating-point type.

Definition: half.h:126

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

matrix.h

Defines layout functions used by TensorRef and derived classes.


Generated by 1.8.11