Back to Cutlass

CUTLASS: gemv_batched_strided.h Source File

docs/gemv__batched__strided_8h_source.html

4.4.218.6 KB
Original Source

CUTLASS — CUDA Templates for Linear Algebra Subroutines and Solvers

gemv_batched_strided.h

[Go to the documentation of this file.](gemv__batched__strided_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

25

26 #pragma once

27

28 #include "cutlass/cutlass.h"

29

30 #include "cutlass/aligned_buffer.h"

31 #include "cutlass/array.h"

32

33 #include "cutlass/numeric_types.h"

34 #include "cutlass/matrix_shape.h"

35

36 #include "cutlass/gemm/gemm.h"

37

39

40 namespace cutlass {

41 namespace gemm {

42 namespace kernel {

43

namespace detail
{

/// Element-wise linear-combination epilogue for the batched strided GEMV:
///   fragment_D[i] = alpha * accumulators[i] + beta * fragment_C[i]
/// or, when BetaIsZero is true,
///   fragment_D[i] = alpha * accumulators[i]   (fragment_C is never read).
///
/// Arithmetic is done in the accumulator's value type; the result is converted to the
/// C/D fragment's value type.
///
/// @tparam ElementAlphaBeta  type of the alpha/beta scalars
/// @tparam BetaIsZero        compile-time switch that drops the C term entirely
template <typename ElementAlphaBeta, bool BetaIsZero>
struct GemvBatchedStridedEpilogueScaling
{
  // The scalars are held by value. The previous `const &` members forced the constructor to
  // take mutable lvalue references (so const values and temporaries could not be passed), and
  // relaxing only the constructor would have left dangling references. Copying a scalar is free.
  ElementAlphaBeta alpha;
  ElementAlphaBeta beta;

  /// @param alpha_  scale applied to the accumulators
  /// @param beta_   scale applied to the source fragment C (unused when BetaIsZero)
  CUTLASS_DEVICE
  GemvBatchedStridedEpilogueScaling(ElementAlphaBeta const &alpha_, ElementAlphaBeta const &beta_) :
    alpha(alpha_), beta(beta_)
  { }

  /// Applies the scaling element by element.
  ///
  /// fragment_C and fragment_D may alias: each element is read before it is written at the
  /// same index.
  ///
  /// @param accumulators  GEMV accumulators (read only)
  /// @param fragment_C    source fragment; not read when BetaIsZero
  /// @param fragment_D    destination fragment
  template <typename FragmentCD, typename FragmentAccumulator>
  CUTLASS_DEVICE
  void operator()(FragmentAccumulator const &accumulators,
                  FragmentCD const &fragment_C,
                  FragmentCD &fragment_D) const
  {
    using AccType = typename FragmentAccumulator::value_type;
    using CDType = typename FragmentCD::value_type;

    static_assert(FragmentCD::kElements == FragmentAccumulator::kElements,
                  "Mistmatch in fragment sizes.");

    // The scalar conversions are loop-invariant; do them once.
    AccType const alpha_acc = AccType(alpha);
    AccType const beta_acc = AccType(beta);

    for (int i = 0; i < FragmentCD::kElements; ++i)
    {
      if (BetaIsZero)
      {
        fragment_D[i] = CDType(accumulators[i] * alpha_acc);
      }
      else
      {
        fragment_D[i] = CDType(accumulators[i] * alpha_acc
                               + AccType(fragment_C[i]) * beta_acc);
      }
    }
  }
};

} // namespace detail

84

86

/// Device-side body of the batched, strided GEMV kernel.
///
/// The swizzle supplies a batch index and a tile offset for the calling threadblock. The block
/// then computes D = alpha * (A * B) + beta * C for one tile of ThreadBlockGemv::Shape::kN
/// output columns, starting at column tile_offset.n() * Shape::kN.
/// Iterator extents: A is 1 x K, B is K x N, and C/D are 1 x N.
///
/// @param problem_size  GEMV extents; k() and n() bound the iterators, mnk() is passed to the mainloop
/// @param alpha         scale applied to the A*B accumulators
/// @param beta          scale applied to C (C is not loaded at all when BetaIsZero)
/// @param ref_A, lda    operand A; lda is multiplied by the batch index and applied with
///                      add_pointer_offset, i.e. it acts as the stride between batches
/// @param ref_B, ldb    operand B; ldb is likewise the stride between batches
/// @param ref_C, ldc    source C; ldc is likewise the stride between batches
/// @param ref_D, ldd    destination D; ldd is likewise the stride between batches
template <typename GemvKernel, typename ElementAlphaBeta, bool BetaIsZero=false>
CUTLASS_DEVICE void GemvBatchedStridedDevice(
  cutlass::gemm::BatchedGemmCoord problem_size,
  ElementAlphaBeta alpha,
  ElementAlphaBeta beta,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_C,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  using ThreadBlockGemv = typename GemvKernel::ThreadBlockGemv;
  using ThreadBlockSwizzle = typename GemvKernel::ThreadBlockSwizzle;
  using EpilogueScale = detail::GemvBatchedStridedEpilogueScaling<ElementAlphaBeta, BetaIsZero>;

  ThreadBlockSwizzle swizzler;

  // Compute initial location in logical coordinates. The tile offset does not change during
  // the kernel, so it is computed once and reused for B, C and D (the previous code re-queried
  // the swizzle before C and again before D).
  BatchedGemmCoord const tb_offset = swizzler.get_tile_offset();
  int const batch_idx = swizzler.get_batch_idx();

  // First output column covered by this threadblock's tile.
  auto const column_offset = tb_offset.n() * ThreadBlockGemv::Shape::kN;

  // Offset to the batch
  ref_A.add_pointer_offset(batch_idx * lda);
  ref_B.add_pointer_offset(batch_idx * ldb);

  // Construct iterators to A and B operands.
  // A is constructed with thread index 0 and offset {0, 0}. Presumably every thread therefore
  // walks the same A elements (TODO confirm against IteratorA).
  typename GemvKernel::IteratorA::Params params_A(ref_A.layout());
  typename GemvKernel::IteratorA iterator_A(
    params_A,
    ref_A.data(),
    { 1, problem_size.k() },
    0,
    { 0, 0 });

  typename GemvKernel::IteratorB::Params params_B(ref_B.layout());
  typename GemvKernel::IteratorB iterator_B(
    params_B,
    ref_B.data(),
    { problem_size.k(), problem_size.n() },
    threadIdx.x,
    { 0, column_offset });

  //
  // Main loop
  //

  // Construct thread-scoped matrix multiply
  ThreadBlockGemv mma;

  typename ThreadBlockGemv::FragmentC accumulators;
  accumulators.clear();

  // Compute threadblock-scoped gemv
  mma(problem_size.mnk(), accumulators, iterator_A, iterator_B, accumulators);

  //
  // Epilogue (TODO: epilogue as template argument)
  //

  // Holds C on input and D on output. When BetaIsZero it is left unloaded, because the
  // scaling functor never reads the C term in that case.
  typename GemvKernel::FragmentCD fragment_CD;

  // Load C (skip if beta is zero)
  if (!BetaIsZero)
  {
    ref_C.add_pointer_offset(batch_idx * ldc);
    typename GemvKernel::IteratorCD::Params params_C(ref_C.layout());
    typename GemvKernel::IteratorCD iterator_C(
      params_C,
      ref_C.data(),
      { 1, problem_size.n() },
      threadIdx.x,
      { 0, column_offset });
    iterator_C.load(fragment_CD);
  }

  // Apply alpha/beta scaling in place (C and D share fragment_CD)
  EpilogueScale epilogue_scale(alpha, beta);
  epilogue_scale(accumulators, fragment_CD, fragment_CD);

  // Store D
  ref_D.add_pointer_offset(batch_idx * ldd);
  typename GemvKernel::IteratorCD::Params params_D(ref_D.layout());
  typename GemvKernel::IteratorCD iterator_D(
    params_D,
    ref_D.data(),
    { 1, problem_size.n() },
    threadIdx.x,
    { 0, column_offset });
  iterator_D.store(fragment_CD);
}

181

/// Batched, strided GEMV kernel entry point: D = alpha * (A * B) + beta * C.
///
/// This is a thin __global__ wrapper; all work is delegated to GemvBatchedStridedDevice with
/// the same template arguments. When BetaIsZero is true, C is never loaded.
/// Each ld* argument is forwarded unchanged with its tensor reference.
template <typename GemvKernel, typename ElementAlphaBeta, bool BetaIsZero>
__global__ void GemvBatchedStrided(
  cutlass::gemm::BatchedGemmCoord problem_size,
  ElementAlphaBeta alpha,
  ElementAlphaBeta beta,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_C,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  // Arguments grouped as (tensor, stride) pairs per operand.
  GemvBatchedStridedDevice<GemvKernel, ElementAlphaBeta, BetaIsZero>(
    problem_size,
    alpha, beta,
    ref_A, lda,
    ref_B, ldb,
    ref_C, ldc,
    ref_D, ldd);
}

200

/// Batched, strided GEMV kernel entry point without a source operand: D = alpha * (A * B).
///
/// This forwards to GemvBatchedStridedDevice with BetaIsZero = true and beta = 0. D is also
/// passed in the C slot, but because BetaIsZero is true the device function never loads C.
template <typename GemvKernel, typename ElementAlphaBeta>
__global__ void GemvBatchedStrided(
  cutlass::gemm::BatchedGemmCoord problem_size,
  ElementAlphaBeta alpha,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  ElementAlphaBeta const zero_beta(0);

  GemvBatchedStridedDevice<GemvKernel, ElementAlphaBeta, true>(
    problem_size,
    alpha, zero_beta,
    ref_A, lda,
    ref_B, ldb,
    ref_D, ldd,   // C slot: never read when BetaIsZero
    ref_D, ldd);
}

216

/// Batched, strided GEMV kernel entry point with no scaling: D = A * B.
///
/// The scalar type is taken from the C/D iterator's element type. The kernel forwards to
/// GemvBatchedStridedDevice with alpha = 1, beta = 0 and BetaIsZero = true. D fills the unused
/// C slot, which is never loaded.
template <typename GemvKernel>
__global__ void GemvBatchedStrided(
  cutlass::gemm::BatchedGemmCoord problem_size,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  using ScalarT = typename GemvKernel::IteratorCD::Element;

  ScalarT const unit_alpha(1);
  ScalarT const zero_beta(0);

  GemvBatchedStridedDevice<GemvKernel, ScalarT, true>(
    problem_size,
    unit_alpha, zero_beta,
    ref_A, lda,
    ref_B, ldb,
    ref_D, ldd,   // C slot: never read when BetaIsZero
    ref_D, ldd);
}

232

233

235

236 } // namespace kernel

237 } // namespace gemm

238 } // namespace cutlass

cutlass

Definition: aligned_buffer.h:35

cutlass::gemm::BatchedGemmCoord::mnk

CUTLASS_HOST_DEVICE GemmCoord mnk() const

Obtains a GemmCoord from BatchedGemmCoord.

Definition: include/cutlass/gemm/gemm.h:330

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::gemm::kernel::GemvBatchedStridedDevice

CUTLASS_DEVICE void GemvBatchedStridedDevice(cutlass::gemm::BatchedGemmCoord problem_size, ElementAlphaBeta alpha, ElementAlphaBeta beta, typename GemvKernel::IteratorA::TensorRef ref_A, typename GemvKernel::IteratorA::TensorRef::LongIndex lda, typename GemvKernel::IteratorB::TensorRef ref_B, typename GemvKernel::IteratorB::TensorRef::LongIndex ldb, typename GemvKernel::IteratorCD::TensorRef ref_C, typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc, typename GemvKernel::IteratorCD::TensorRef ref_D, typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)

Definition: gemv_batched_strided.h:88

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

cutlass::gemm::kernel::GemvBatchedStrided

__global__ void GemvBatchedStrided(cutlass::gemm::BatchedGemmCoord problem_size, ElementAlphaBeta alpha, ElementAlphaBeta beta, typename GemvKernel::IteratorA::TensorRef ref_A, typename GemvKernel::IteratorA::TensorRef::LongIndex lda, typename GemvKernel::IteratorB::TensorRef ref_B, typename GemvKernel::IteratorB::TensorRef::LongIndex ldb, typename GemvKernel::IteratorCD::TensorRef ref_C, typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc, typename GemvKernel::IteratorCD::TensorRef ref_D, typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)

Definition: gemv_batched_strided.h:183

matrix_shape.h

Defines a Shape template for matrix tiles.

cutlass::gemm::BatchedGemmCoord

Definition: include/cutlass/gemm/gemm.h:260

aligned_buffer.h

AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...

cutlass::gemm::BatchedGemmCoord::k

CUTLASS_HOST_DEVICE Index const & k() const

Returns the GEMM K coordinate.

Definition: include/cutlass/gemm/gemm.h:314

numeric_types.h

Top-level include for all CUTLASS numeric types.

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::gemm::kernel::detail::GemvBatchedStridedEpilogueScaling::beta

ElementAlphaBeta const & beta

Definition: gemv_batched_strided.h:50

cutlass::gemm::kernel::detail::GemvBatchedStridedEpilogueScaling::operator()

CUTLASS_DEVICE void operator()(FragmentAccumulator &accumulators, FragmentCD const &fragment_C, FragmentCD &fragment_D) const

Definition: gemv_batched_strided.h:59

cutlass::gemm::kernel::detail::GemvBatchedStridedEpilogueScaling::alpha

ElementAlphaBeta const & alpha

Definition: gemv_batched_strided.h:49

cutlass::gemm::kernel::detail::GemvBatchedStridedEpilogueScaling

Definition: gemv_batched_strided.h:47

cutlass::gemm::BatchedGemmCoord::n

CUTLASS_HOST_DEVICE Index const & n() const

Returns the GEMM N coordinate.

Definition: include/cutlass/gemm/gemm.h:306

cutlass.h

Basic include for CUTLASS.

cutlass::gemm::kernel::detail::GemvBatchedStridedEpilogueScaling::GemvBatchedStridedEpilogueScaling

CUTLASS_DEVICE GemvBatchedStridedEpilogueScaling(ElementAlphaBeta &alpha_, ElementAlphaBeta &beta_)

Definition: gemv_batched_strided.h:53


Generated by Doxygen 1.8.11