Back to Cutlass

CUTLASS: epilogue_workspace.h Source File

docs/epilogue__workspace_8h_source.html

4.4.2 · 17.0 KB
Original Source

CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers

epilogue_workspace.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

46 #pragma once

47

48 #include "cutlass/cutlass.h"

49 #include "cutlass/numeric_types.h"

50 #include "cutlass/array.h"

51

53

54 namespace cutlass {

55 namespace epilogue {

56

58

/// Epilogue helper that streams each thread's accumulator fragment out through
/// Params::ptr_C using kAccessSizeInBits-wide vector stores.
///
/// Every thread views its fragment as kIterations vectors (AccessType) and writes
/// them at a stride of kWarpSize vectors, starting from an offset derived from its
/// warp index, lane index, and the threadblock tile coordinate.
59 template <

/// Exposed as the `Shape` alias; not otherwise referenced inside this class
60typename Shape_,

/// Number of warps; scales kWarpAccesses up to kThreadblockAccesses
61int WarpCount,

/// Accumulator fragment type; must provide `value_type` and `kElements`
62typename FragmentC_

63 >

64 class EpilogueWorkspace {

65 public:

66

67using Shape = Shape_;

68using FragmentC = FragmentC_;

69using ElementC = typename FragmentC::value_type;

70

/// Number of warps (the WarpCount template argument)
71static int const kWarpCount = WarpCount;

72

/// Optimize for 128b accesses
74static int const kAccessSizeInBits = 128;

75

/// Warp size from the perspective of memory operations
77static int const kWarpSize = 32;

78

/// Vector length of accesses
80static int const kElementsPerAccess =

81 kAccessSizeInBits / sizeof_bits<ElementC>::value;

82

/// Number of stores per thread
84static int const kIterations = FragmentC::kElements / kElementsPerAccess;

85

// The fragment is reinterpreted as whole vectors in operator(), so a partial
// trailing vector is rejected at compile time.
86static_assert(

87 !(FragmentC::kElements % kElementsPerAccess),

88"The number of accumulators must be divisible by the access size.");

89

/// Total number of vectorized accesses in warp (in units of vector)
91static int const kWarpAccesses = kIterations * kWarpSize;

92

/// Total number of vectorized accesses in threadblock tile (in units of vector)
94static int const kThreadblockAccesses = kWarpAccesses * kWarpCount;

95

/// Parameters structure
97struct Params {

98

/// Pointer to C matrix
100ElementC *ptr_C;

101

/// Stride between tiles along the GEMM N dimension (in units of vectors)
103int stride_n;

104

/// Stride between tiles along the GEMM K dimension (in units of vectors)
106int stride_k;

107

108//

109// Methods

110//

111

/// Stores the pointer and divides both incoming strides by kElementsPerAccess
/// so that the stored strides are in units of vectors (integer division).
112CUTLASS_HOST_DEVICE

113Params(

114ElementC *ptr_C,

115int stride_n_,

116int stride_k_

117 ):

118 ptr_C(ptr_C), stride_n(stride_n_ / kElementsPerAccess), stride_k(stride_k_ / kElementsPerAccess) {

119

120 }

121 };

122

/// Shared storage allocation needed by the epilogue
124struct SharedStorage {

125// Intentionally empty

126 };

127

128 private:

129

/// One vector access: kElementsPerAccess elements, aligned to kAccessSizeInBits / 8 bytes
130struct alignas((kAccessSizeInBits / 8)) AccessType {

131 Array<ElementC, kElementsPerAccess> storage;

132 };

133

/// Vector-typed view of Params::ptr_C with this thread's offset already applied
135 AccessType *pointer_;

136

/// Copy of Params::stride_n (units of vectors)
138int stride_n_;

139

/// Copy of Params::stride_k (units of vectors)
141int stride_k_;

142

143 public:

144

/// Constructor
///
/// \param params    pointer and vector-unit strides
/// \param warp_idx  selects a block of kWarpAccesses vectors
/// \param lane_idx  selects one vector within that block
/// The SharedStorage argument is unnamed and unused.
146 CUTLASS_DEVICE

147EpilogueWorkspace(

148Params const &params,

149SharedStorage &,

150int warp_idx,

151int lane_idx

152

153 ):

154 pointer_(reinterpret_cast<AccessType *>(params.ptr_C)),

155 stride_n_(params.stride_n),

156 stride_k_(params.stride_k) {

157

158// Add per-thread offset

159 pointer_ += lane_idx + warp_idx * kWarpAccesses;

160 }

161

/// Streams the result to global memory
///
/// \param problem_size   not referenced in the body
/// \param tb_tile_coord  threadblock tile coordinate; m() is scaled by
///                       kThreadblockAccesses, n() by stride_n_, k() by stride_k_
/// \param accum          accumulator fragment, read as kIterations AccessType vectors
163 CUTLASS_DEVICE

164void operator()(

165cutlass::gemm::GemmCoord problem_size,

166cutlass::gemm::GemmCoord tb_tile_coord,

167FragmentC const &accum) {

168

169// Compute offset for entire threadblock (note, per-thread offset has been folded in already)

170 AccessType *pointer = pointer_ +

171 tb_tile_coord.m() * kThreadblockAccesses +

172 tb_tile_coord.n() * stride_n_ +

173 tb_tile_coord.k() * stride_k_;

174

175// Cast to vectorized view of accumulator fragments

176 AccessType const * src_pointer = reinterpret_cast<AccessType const *>(&accum);

177

178// Write out accumulators at full speed

// Store i lands kWarpSize vectors after store i-1; combined with the lane_idx
// offset applied in the constructor, lanes write adjacent vectors in each iteration.
179CUTLASS_PRAGMA_UNROLL

180for (int i = 0; i < kIterations; ++i) {

181 pointer[i * kWarpSize] = src_pointer[i];

182 }

183 }

184 };

185

187

188 } // namespace epilogue

189 } // namespace cutlass

190

cutlass

Definition: aligned_buffer.h:35

cutlass::epilogue::EpilogueWorkspace::SharedStorage

Shared storage allocation needed by the epilogue.

Definition: epilogue_workspace.h:124

cutlass::epilogue::EpilogueWorkspace::kAccessSizeInBits

static int const kAccessSizeInBits

Optimize for 128b accesses.

Definition: epilogue_workspace.h:74

cutlass::gemm::GemmCoord

Definition: include/cutlass/gemm/gemm.h:94

cutlass::epilogue::EpilogueWorkspace::kWarpAccesses

static int const kWarpAccesses

Total number of vectorized accesses in warp (in units of vector)

Definition: epilogue_workspace.h:91

cutlass::epilogue::EpilogueWorkspace::Params::Params

CUTLASS_HOST_DEVICE Params(ElementC *ptr_C, int stride_n_, int stride_k_)

Definition: epilogue_workspace.h:113

cutlass::gemm::GemmCoord::n

CUTLASS_HOST_DEVICE Index const & n() const

Returns the GEMM N coordinate.

Definition: include/cutlass/gemm/gemm.h:137

cutlass::epilogue::EpilogueWorkspace::kIterations

static int const kIterations

Number of stores per thread.

Definition: epilogue_workspace.h:84

cutlass::epilogue::EpilogueWorkspace::operator()

CUTLASS_DEVICE void operator()(cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tb_tile_coord, FragmentC const &accum)

Streams the result to global memory.

Definition: epilogue_workspace.h:164

cutlass::epilogue::EpilogueWorkspace::Shape

Shape_ Shape

Definition: epilogue_workspace.h:67

cutlass::gemm::GemmCoord::k

CUTLASS_HOST_DEVICE Index const & k() const

Returns the GEMM K coordinate.

Definition: include/cutlass/gemm/gemm.h:145

cutlass::epilogue::EpilogueWorkspace::ElementC

typename FragmentC::value_type ElementC

Definition: epilogue_workspace.h:69

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

CUTLASS_PRAGMA_UNROLL

#define CUTLASS_PRAGMA_UNROLL

Definition: cutlass.h:110

cutlass::sizeof_bits

Defines the size of an element in bits.

Definition: numeric_types.h:42

cutlass::epilogue::EpilogueWorkspace::Params::ptr_C

ElementC * ptr_C

Pointer to C matrix.

Definition: epilogue_workspace.h:100

cutlass::epilogue::EpilogueWorkspace::FragmentC

FragmentC_ FragmentC

Definition: epilogue_workspace.h:68

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::epilogue::EpilogueWorkspace::EpilogueWorkspace

CUTLASS_DEVICE EpilogueWorkspace(Params const &params, SharedStorage &, int warp_idx, int lane_idx)

Constructor.

Definition: epilogue_workspace.h:147

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::epilogue::EpilogueWorkspace::kWarpCount

static int const kWarpCount

Definition: epilogue_workspace.h:71

cutlass::epilogue::EpilogueWorkspace::Params::stride_n

int stride_n

Stride between tiles along the GEMM N dimension (in units of vectors)

Definition: epilogue_workspace.h:103

cutlass::epilogue::EpilogueWorkspace::kElementsPerAccess

static int const kElementsPerAccess

Vector length of accesses.

Definition: epilogue_workspace.h:80

cutlass::gemm::GemmCoord::m

CUTLASS_HOST_DEVICE Index const & m() const

Returns the GEMM M coordinate.

Definition: include/cutlass/gemm/gemm.h:129

cutlass::epilogue::EpilogueWorkspace::Params::stride_k

int stride_k

Stride between tiles along the GEMM K dimension (in units of vectors)

Definition: epilogue_workspace.h:106

cutlass::epilogue::EpilogueWorkspace::Params

Parameters structure.

Definition: epilogue_workspace.h:97

cutlass::epilogue::EpilogueWorkspace

Definition: epilogue_workspace.h:64

cutlass.h

Basic include for CUTLASS.

cutlass::epilogue::EpilogueWorkspace::kWarpSize

static int const kWarpSize

Warp size from the perspective of memory operations.

Definition: epilogue_workspace.h:77

cutlass::epilogue::EpilogueWorkspace::kThreadblockAccesses

static int const kThreadblockAccesses

Total number of vectorized accesses in threadblock tile (in units of vector)

Definition: epilogue_workspace.h:94


Generated by Doxygen 1.8.11