Back to Cutlass

CUTLASS: default_gemv_core.h Source File

docs/default__gemv__core_8h_source.html

4.4.221.1 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

default_gemv_core.h

[Go to the documentation of this file.](default__gemv__core_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

32 #pragma once

33

34 #include "cutlass/cutlass.h"

35 #include "cutlass/array.h"

36 #include "cutlass/numeric_types.h"

37 #include "cutlass/matrix_shape.h"

38

39 #include "cutlass/layout/matrix.h"

40

41 #include "cutlass/platform/platform.h"

42

43 #include "cutlass/gemm/gemm.h"

44 #include "cutlass/gemm/thread/mma.h"

45

46 #include "[cutlass/transform/threadblock/predicated_tile_iterator.h](transform_2threadblock_2predicated__tile__iterator_8h.html)"

47 #include "[cutlass/transform/pitch_linear_thread_map.h](pitch__linear__thread__map_8h.html)"

48

49 #include "cutlass/gemm/threadblock/gemv.h"

50

52 namespace cutlass {

53 namespace gemm {

54 namespace threadblock {

55

/// Collects the compile-time types for a GEMV built from predicated tile iterators
/// and a thread-level multiply-add: one iterator policy and one iterator type per
/// operand (A, B, C), plus the thread-level Mma operator type.
///
/// Template parameters:
///   Shape_       - overall tile shape; its kM, kN and kK members are read below.
///   ThreadShape_ - shape whose kM, kN and kK members are read below
///                  (presumably the tile handled by one thread - TODO confirm).
///   ElementA_, LayoutA_ - element type and layout type of operand A.
///   ElementB_, LayoutB_ - element type and layout type of operand B.
///   ElementC_, LayoutC_ - element type and layout type of operand C.
///
/// Instantiation fails at compile time unless Shape::kM == 1, ThreadShape::kM == 1,
/// Shape::kK is a multiple of ThreadShape::kK, and ThreadShape::kK is one of
/// 1, 2, 4, 8, 16 or 32 (see the static_asserts at the end of the struct).
58 template <

59typename Shape_,

60typename ThreadShape_,

61typename ElementA_,

62typename LayoutA_,

63typename ElementB_,

64typename LayoutB_,

65typename ElementC_,

66typename LayoutC_

67 >

68 struct DefaultGemvCore {

69

// Re-export the template arguments under un-suffixed names.
70using Shape = Shape_;

71using ThreadShape = ThreadShape_;

72

73using LayoutA = LayoutA_;

74using LayoutB = LayoutB_;

75using LayoutC = LayoutC_;

76

77using ElementA = ElementA_;

78using ElementB = ElementB_;

79using ElementC = ElementC_;

80

// Ratio of the two N extents, computed with integer division.
// NOTE(review): nothing in this struct asserts that Shape::kN is a multiple of
// ThreadShape::kN, so a non-multiple silently truncates here.
81static int const kThreadsPerN = Shape::kN / ThreadShape::kN;

82

// Policy for operand A, chosen by layout:
//   LayoutA is RowMajor -> ThreadContiguous policy over PitchLinearShape<kK, kM>
//   otherwise           -> ThreadStrided policy over PitchLinearShape<kM, kK>
// The second policy argument is the literal 1 in both branches (presumably a
// thread count - confirm against pitch_linear_thread_map.h). The third argument
// is ThreadShape::kK or ThreadShape::kM respectively; the latter is forced to 1
// by the assertions below.
83using IteratorPolicyA = typename platform::conditional<

84platform::is_same<LayoutA, layout::RowMajor>::value,

85cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous<

86layout::PitchLinearShape<Shape::kK, Shape::kM>, 1, ThreadShape::kK>,

87cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided<

88layout::PitchLinearShape<Shape::kM, Shape::kK>, 1, ThreadShape::kM>>::type;

89

// Iterator type for operand A over a kM x kK tile. The integer argument is 1 here
// but 0 for B and C (presumably the advance rank - confirm against
// predicated_tile_iterator.h).
90using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<

91cutlass::MatrixShape<Shape::kM, Shape::kK>, ElementA, LayoutA, 1, IteratorPolicyA>;

92

// Policy for operand B, chosen by layout:
//   LayoutB is RowMajor -> ThreadContiguous policy over PitchLinearShape<kN, kK>
//   otherwise           -> ThreadStrided policy over PitchLinearShape<kK, kN>
// Both branches pass kThreadsPerN as the second argument; the third argument is
// ThreadShape::kN or ThreadShape::kK respectively.
93using IteratorPolicyB = typename platform::conditional<

94platform::is_same<LayoutB, layout::RowMajor>::value,

95cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous<

96layout::PitchLinearShape<Shape::kN, Shape::kK>, kThreadsPerN, ThreadShape::kN>,

97cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided<

98layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreadsPerN, ThreadShape::kK>>::type;

99

// Iterator type for operand B over a kK x kN tile.
100using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<

101cutlass::MatrixShape<Shape::kK, Shape::kN>, ElementB, LayoutB, 0, IteratorPolicyB>;

102

// Policy for operand C, chosen by layout:
//   LayoutC is RowMajor -> ThreadContiguous policy over PitchLinearShape<kN, kM>
//   otherwise           -> ThreadStrided policy over PitchLinearShape<kM, kN>
// Both branches pass kThreadsPerN as the second argument; the third argument is
// ThreadShape::kN or ThreadShape::kM respectively.
103using IteratorPolicyC = typename platform::conditional<

104platform::is_same<LayoutC, layout::RowMajor>::value,

105cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous<

106layout::PitchLinearShape<Shape::kN, Shape::kM>, kThreadsPerN, ThreadShape::kN>,

107cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided<

108layout::PitchLinearShape<Shape::kM, Shape::kN>, kThreadsPerN, ThreadShape::kM>>::type;

109

// Iterator type for operand C over a kM x kN tile.
110using IteratorC = cutlass::transform::threadblock::PredicatedTileIterator<

111cutlass::MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC, 0, IteratorPolicyC>;

112

// Thread-level multiply-add type. Its M and N extents come from ThreadShape,
// while its K extent is Shape::kK rather than ThreadShape::kK.
113using MmaSimtOp = typename cutlass::gemm::thread::Mma<

114cutlass::gemm::GemmShape<ThreadShape::kM, ThreadShape::kN, Shape::kK>,

115ElementA,

116LayoutA,

117ElementB,

118LayoutB,

119ElementC,

120 LayoutC>;

121

// Operator is a second name for the same Mma type.
122using Operator = MmaSimtOp;

123

124// Assertions for correctness

// Both the overall shape and the thread shape must have an M extent of exactly 1.
125static_assert((Shape::kM == 1), "M=1 is required for GEMV");

126

127static_assert((ThreadShape::kM == 1), "M=1 is required for GEMV");

128

// Shape::kK must divide evenly by ThreadShape::kK.
129static_assert(Shape::kK % ThreadShape::kK == 0, "Shape::K must be a multiple of ThreadShape::K");

130

// ThreadShape::kK is restricted to the listed powers of two up to 32.
131static_assert(((ThreadShape::kK == 1) ||

132 (ThreadShape::kK == 2) ||

133 (ThreadShape::kK == 4) ||

134 (ThreadShape::kK == 8) ||

135 (ThreadShape::kK == 16) ||

136 (ThreadShape::kK == 32)

137 ),

138"ThreadShape::K must be a 1, 2, 4, 8, 16 or 32");

139 };

140

142

143 } // namespace threadblock

144 } // namespace gemm

145 } // namespace cutlass

cutlass::MatrixShape

Describes the size of a matrix tile.

Definition: matrix_shape.h:42

cutlass

Definition: aligned_buffer.h:35

cutlass::gemm::threadblock::DefaultGemvCore::Shape

Shape_ Shape

Definition: default_gemv_core.h:70

cutlass::platform::is_same

std::is_same (false specialization)

Definition: platform.h:394

[pitch_linear_thread_map.h](pitch__linear__thread__map_8h.html)

Templates implementing how threads are mapped to a given tile.

cutlass::gemm::threadblock::DefaultGemvCore::ThreadShape

ThreadShape_ ThreadShape

Definition: default_gemv_core.h:71

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::gemm::threadblock::DefaultGemvCore::ElementA

ElementA_ ElementA

Definition: default_gemv_core.h:77

platform.h

C++ features that may be otherwise unimplemented for CUDA device functions.

cutlass::gemm::threadblock::DefaultGemvCore::IteratorPolicyC

typename platform::conditional< platform::is_same< LayoutC, layout::RowMajor >::value, cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous< layout::PitchLinearShape< Shape::kN, Shape::kM >, kThreadsPerN, ThreadShape::kN >, cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided< layout::PitchLinearShape< Shape::kM, Shape::kN >, kThreadsPerN, ThreadShape::kM >>::type IteratorPolicyC

Definition: default_gemv_core.h:108

cutlass::layout::PitchLinearShape

Template defining a shape used by pitch-linear operators.

Definition: pitch_linear.h:43

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

matrix_shape.h

Defines a Shape template for matrix tiles.

cutlass::gemm::threadblock::DefaultGemvCore::IteratorPolicyA

typename platform::conditional< platform::is_same< LayoutA, layout::RowMajor >::value, cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous< layout::PitchLinearShape< Shape::kK, Shape::kM >, 1, ThreadShape::kK >, cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided< layout::PitchLinearShape< Shape::kM, Shape::kK >, 1, ThreadShape::kM >>::type IteratorPolicyA

Definition: default_gemv_core.h:88

cutlass::gemm::threadblock::DefaultGemvCore::IteratorPolicyB

typename platform::conditional< platform::is_same< LayoutB, layout::RowMajor >::value, cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous< layout::PitchLinearShape< Shape::kN, Shape::kK >, kThreadsPerN, ThreadShape::kN >, cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided< layout::PitchLinearShape< Shape::kK, Shape::kN >, kThreadsPerN, ThreadShape::kK >>::type IteratorPolicyB

Definition: default_gemv_core.h:98

cutlass::gemm::threadblock::DefaultGemvCore::ElementC

ElementC_ ElementC

Definition: default_gemv_core.h:79

cutlass::gemm::threadblock::DefaultGemvCore::Operator

MmaSimtOp Operator

Definition: default_gemv_core.h:122

cutlass::gemm::threadblock::DefaultGemvCore::MmaSimtOp

typename cutlass::gemm::thread::Mma< cutlass::gemm::GemmShape< ThreadShape::kM, ThreadShape::kN, Shape::kK >, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC > MmaSimtOp

Definition: default_gemv_core.h:120

mma.h

Templates exposing architecture support for warp-level multiply-add operations.

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

cutlass::gemm::threadblock::DefaultGemvCore

Definition: default_gemv_core.h:68

cutlass::platform::conditional

std::conditional (true specialization)

Definition: platform.h:325

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::gemm::threadblock::DefaultGemvCore::LayoutB

LayoutB_ LayoutB

Definition: default_gemv_core.h:74

cutlass::transform::threadblock::PredicatedTileIterator

Definition: transform/threadblock/predicated_tile_iterator.h:133

cutlass::gemm::threadblock::DefaultGemvCore::kThreadsPerN

static int const kThreadsPerN

Definition: default_gemv_core.h:81

cutlass::gemm::thread::Mma

Structure to compute the matrix product.

Definition: gemm/thread/mma.h:66

matrix.h

Defines layout functions used by TensorRef and derived classes.

gemv.h

Template for a threadblock-scoped GEMV kernel.

[predicated_tile_iterator.h](transform_2threadblock_2predicated__tile__iterator_8h.html)

Templates implementing loading of tiles from pitch-linear rank=2 tensors.

cutlass::gemm::threadblock::DefaultGemvCore::ElementB

ElementB_ ElementB

Definition: default_gemv_core.h:78

cutlass::transform::PitchLinearTilePolicyStripminedThreadStrided

Definition: pitch_linear_thread_map.h:168

cutlass::gemm::threadblock::DefaultGemvCore::LayoutC

LayoutC_ LayoutC

Definition: default_gemv_core.h:75

cutlass::transform::PitchLinearTilePolicyStripminedThreadContiguous

Definition: pitch_linear_thread_map.h:140

cutlass.h

Basic include for CUTLASS.

cutlass::gemm::threadblock::DefaultGemvCore::LayoutA

LayoutA_ LayoutA

Definition: default_gemv_core.h:73


Generated by doxygen 1.8.11