Back to Cutlass

CUTLASS: mma_sm61.h Source File

docs/arch_2mma__sm61_8h_source.html

4.4.28.5 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

arch/mma_sm61.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30

31 #include "cutlass/layout/matrix.h"

32

34

35 namespace cutlass {

36 namespace arch {

37

39

41 template <typename LayoutA, typename LayoutB, typename LayoutC>

42 struct Mma<

43 gemm::GemmShape<1,1,4>,

44 1,

45 int8_t,

46 LayoutA,

47 int8_t,

48 LayoutB,

49 int,

50 LayoutC,

51 OpMultiplyAdd> {

52

53using Shape = gemm::GemmShape<1, 1, 4>;

54

55CUTLASS_HOST_DEVICE

56void operator()(

57 Array<int, 1> &d,

58 Array<int8_t, 4> const &a,

59 Array<int8_t, 4> const &b,

60 Array<int, 1> const &c

61 ) {

62

63 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610))

64

65unsigned const &A = reinterpret_cast<unsigned const &>(a);

66unsigned const &B = reinterpret_cast<unsigned const &>(b);

67

68asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"

69 : "=r"(d[0])

70 : "r"(A), "r"(B), "r"(c[0]));

71

72 #else

73

74 d[0] = c[0];

75

76CUTLASS_PRAGMA_UNROLL

77for (int k = 0; k < 4; ++k) {

78 d[0] += a[k] * b[k];

79 }

80

81 #endif

82 }

83 };

84

86

88 template <typename LayoutC>

89 struct Mma<

90 gemm::GemmShape<1, 1, 2>,

91 1,

92 int16_t,

93layout::RowMajor,

94 int16_t,

95layout::ColumnMajor,

96 int,

97 LayoutC,

98 OpMultiplyAdd> {

99

100using Shape = gemm::GemmShape<1, 1, 2>;

101

102CUTLASS_HOST_DEVICE

103void operator()(

104 Array<int, 1> &d,

105 Array<int16_t, 2> const &a,

106 Array<int16_t, 2> const &b,

107 Array<int, 1> const &c

108 ) {

109

110 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610))

111

112unsigned const &A = reinterpret_cast<unsigned const &>(a);

113unsigned const &B = reinterpret_cast<unsigned const &>(b);

114

115asm volatile("dp2a.s32.s32 %0, %1, %2, %3;"

116 : "=r"(d[0])

117 : "r"(A), "r"(B), "r"(c[0]));

118 #else

119 d[0] = c[0];

120

121CUTLASS_PRAGMA_UNROLL

122for (int k = 0; k < 2; ++k) {

123 d[0] += a[k] * b[k];

124 }

125 #endif

126 }

127 };

128

130

131 }

132 }

133

cutlass

Definition: aligned_buffer.h:35

cutlass::arch::Mma< gemm::GemmShape< 1, 1, 4 >, 1, int8_t, LayoutA, int8_t, LayoutB, int, LayoutC, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< int, 1 > &d, Array< int8_t, 4 > const &a, Array< int8_t, 4 > const &b, Array< int, 1 > const &c)

Definition: arch/mma_sm61.h:56

cutlass::layout::ColumnMajor

Mapping function for column-major matrices.

Definition: layout/matrix.h:142

CUTLASS_PRAGMA_UNROLL

#define CUTLASS_PRAGMA_UNROLL

Definition: cutlass.h:110

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

cutlass::layout::RowMajor

Mapping function for row-major matrices.

Definition: layout/matrix.h:50

matrix.h

Defines layout functions used by TensorRef and derived classes.

cutlass::arch::Mma

Matrix multiply-add operation.

Definition: arch/mma.h:92

cutlass::arch::Mma< gemm::GemmShape< 1, 1, 2 >, 1, int16_t, layout::RowMajor, int16_t, layout::ColumnMajor, int, LayoutC, OpMultiplyAdd >::operator()

CUTLASS_HOST_DEVICE void operator()(Array< int, 1 > &d, Array< int16_t, 2 > const &a, Array< int16_t, 2 > const &b, Array< int, 1 > const &c)

Definition: arch/mma_sm61.h:103


Generated by 1.8.11