Back to Cutlass

CUTLASS: batched_reduction_traits.h Source File

docs/batched__reduction__traits_8h_source.html

4.4.222.3 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

batched_reduction_traits.h

[Go to the documentation of this file.](batched__reduction__traits_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

29 #pragma once

30 #include "cutlass/cutlass.h"

31 #include "cutlass/shape.h"

32 #include "cutlass/reduction/threadblock_swizzle.h"

33 #include "cutlass/reduction/batched_reduction.h"

34 #include "cutlass/gemm/linear_scaling.h"

35

36 namespace cutlass {

37 namespace reduction {

38

39 /*

40 OutputTile defines the work load per thread block

41 Subtile defines the work load per thread block per iteration

42 OutputTile / Subtile = number of iterations within a kernel

43 ThreadShape defines the work load per thread

44 Subtile / ThreadShape = number of threads per thread block

45 */

46 template <

48typename ScalarA_,

50typename ScalarC_,

52typename ScalarD_,

54typename ScalarAlphaBeta_,

56typename ScalarAccum_,

58int ReductionSize_ = 1,

60typename OutputTile_ = Shape<1, 1, 128>,

62typename SubTile_ = Shape<1, 1, 64>,

64typename ThreadShape_ = Shape<1, 1, 2>,

66typename Index_ = int,

68typename BlockSwizzle_ = DefaultBlockSwizzle,

70int maxInReg_ = 160,

72int maxOutReg_ = 64,

74typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >

75 >

76 struct BatchedReductionTraits {

78typedef BatchedReductionTraits<ScalarA_,

79 ScalarC_,

80 ScalarD_,

81 ScalarAlphaBeta_,

82 ScalarAccum_,

83 ReductionSize_,

84 OutputTile_,

85 SubTile_,

86 ThreadShape_,

87 Index_,

88 BlockSwizzle_,

89 maxInReg_,

90 maxOutReg_,

91 Functor_> This_;

93typedef typename cutlass::reduction::BatchedReduction<This_> KernelClass;

95typedef OutputTile_ OutputTile;

97typedef SubTile_ SubTile;

99typedef ThreadShape_ ThreadShape;

101typedef ScalarA_ ScalarA;

103typedef ScalarC_ ScalarC;

105typedef ScalarD_ ScalarD;

107typedef ScalarAlphaBeta_ ScalarAlphaBeta;

109typedef ScalarAccum_ ScalarAccum;

111typedef Index_ Index;

113typedef BlockSwizzle_ BlockSwizzle;

115static const int ReductionSize = ReductionSize_;

117static const bool ThreadShapeMultiple2 = (ThreadShape::kW % 2 == 0);

119typedef Functor_ Functor;

122static int const kThreads = SubTile::kW / ThreadShape::kW;

123//

124static int const maxInReg = maxInReg_;

125//

126static int const maxOutReg = maxOutReg_;

127//

128static_assert(SubTile::kW % ThreadShape::kW == 0, "cannot evenly distribute work load among threads");

129//

130static_assert(kThreads % 32 == 0, "threads per threadblock is not multiple of 32");

131//

132static_assert(OutputTile::kW % SubTile::kW == 0, "cannot evenly distribute work load among iterations");

133//

134static_assert(ReductionSize * ThreadShape::kW <= maxInReg, "ReductionSize * ThreadShape::kW should not be bigger than maxInReg");

135//

136static_assert(ThreadShape::kW <= maxOutReg, "ThreadShape::kW should not be bigger than maxOutReg");

137

138struct Params {

140Coord<3> problem_size;

142 ScalarAlphaBeta alpha;

144 ScalarAlphaBeta beta;

146long long int reduction_stride;

147//

148 ScalarA const *d_a;

149//

150 Index lda;

151//

152 ScalarC const *d_c;

153//

154 Index ldc;

155//

156 ScalarD *d_d;

157//

158 Index ldd;

160typename Functor::Params functorParams;

162CUTLASS_HOST_DEVICE int initialize(Index m_,

163 Index n_,

164 ScalarAlphaBeta alpha_,

165 ScalarAlphaBeta beta_,

166long long int reduction_stride_,

167 ScalarA const *d_a_,

168 Index lda_,

169 ScalarC const *d_c_,

170 Index ldc_,

171 ScalarD *d_d_,

172 Index ldd_){

173 problem_size = make_Coord(1, n_, m_);

174 alpha = alpha_;

175 beta = beta_;

176 reduction_stride = reduction_stride_;

177 d_a = d_a_;

178 lda = lda_;

179 d_c = d_c_;

180 d_d = d_d_;

181 ldc = ldc_;

182 ldd = ldd_;

183

184 functorParams.initialize(alpha_, beta_);

185

186return 0;

187 }

188 };

189

190 };

191 } // namespace reduction

192 } // namespace cutlass

cutlass::reduction::BatchedReductionTraits::Params::problem_size

Coord< 3 > problem_size

The dimensions of the output tensor.

Definition: batched_reduction_traits.h:140

cutlass

Definition: aligned_buffer.h:35

cutlass::reduction::BatchedReductionTraits::Params

Definition: batched_reduction_traits.h:138

cutlass::reduction::BatchedReductionTraits::BlockSwizzle

BlockSwizzle_ BlockSwizzle

The thread block swizzle.

Definition: batched_reduction_traits.h:113

cutlass::reduction::BatchedReductionTraits::This_

BatchedReductionTraits< ScalarA_, ScalarC_, ScalarD_, ScalarAlphaBeta_, ScalarAccum_, ReductionSize_, OutputTile_, SubTile_, ThreadShape_, Index_, BlockSwizzle_, maxInReg_, maxOutReg_, Functor_ > This_

Definition: batched_reduction_traits.h:91

cutlass::reduction::BatchedReductionTraits::kThreads

static int const kThreads

Definition: batched_reduction_traits.h:122

cutlass::make_Coord

CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)

Helper to make a 1-element coordinate.

Definition: coord.h:387

cutlass::reduction::BatchedReductionTraits::ScalarAccum

ScalarAccum_ ScalarAccum

The type for accumulation.

Definition: batched_reduction_traits.h:109

cutlass::reduction::BatchedReductionTraits::Params::lda

Index lda

Definition: batched_reduction_traits.h:150

cutlass::reduction::BatchedReductionTraits::Params::beta

ScalarAlphaBeta beta

The beta.

Definition: batched_reduction_traits.h:144

threadblock_swizzle.h

Defines functors for mapping blockIdx to partitions of the batched reduction computation.

cutlass::reduction::BatchedReductionTraits::ThreadShape

ThreadShape_ ThreadShape

Definition: batched_reduction_traits.h:99

cutlass::reduction::BatchedReductionTraits::Params::ldd

Index ldd

Definition: batched_reduction_traits.h:158

cutlass::reduction::BatchedReductionTraits::ScalarD

ScalarD_ ScalarD

The output pointer type.

Definition: batched_reduction_traits.h:105

cutlass::reduction::BatchedReductionTraits::Params::reduction_stride

long long int reduction_stride

Stride between two elements that will be summed.

Definition: batched_reduction_traits.h:146

cutlass::reduction::BatchedReductionTraits::SubTile

SubTile_ SubTile

Definition: batched_reduction_traits.h:97

cutlass::reduction::BatchedReductionTraits::Params::d_c

ScalarC const * d_c

Definition: batched_reduction_traits.h:152

cutlass::reduction::BatchedReductionTraits::OutputTile

OutputTile_ OutputTile

Definition: batched_reduction_traits.h:95

cutlass::reduction::BatchedReductionTraits::Params::ldc

Index ldc

Definition: batched_reduction_traits.h:154

cutlass::reduction::BatchedReductionTraits::ScalarAlphaBeta

ScalarAlphaBeta_ ScalarAlphaBeta

The alpha beta type.

Definition: batched_reduction_traits.h:107

cutlass::reduction::BatchedReductionTraits::ScalarC

ScalarC_ ScalarC

Definition: batched_reduction_traits.h:103

cutlass::reduction::BatchedReductionTraits::Params::d_a

ScalarA const * d_a

Definition: batched_reduction_traits.h:148

cutlass::reduction::BatchedReduction

Definition: batched_reduction.h:52

cutlass::reduction::BatchedReductionTraits::ReductionSize

static const int ReductionSize

Definition: batched_reduction_traits.h:115

cutlass::reduction::BatchedReductionTraits::ScalarA

ScalarA_ ScalarA

The input pointer type.

Definition: batched_reduction_traits.h:101

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::Coord< 3 >

cutlass::reduction::BatchedReductionTraits::Params::alpha

ScalarAlphaBeta alpha

The alpha.

Definition: batched_reduction_traits.h:142

cutlass::reduction::BatchedReductionTraits::Functor

Functor_ Functor

Definition: batched_reduction_traits.h:119

cutlass::reduction::BatchedReductionTraits::maxInReg

static int const maxInReg

Definition: batched_reduction_traits.h:124

cutlass::reduction::BatchedReductionTraits::Params::d_d

ScalarD * d_d

Definition: batched_reduction_traits.h:156

cutlass::reduction::BatchedReductionTraits::Params::functorParams

Functor::Params functorParams

The functor params.

Definition: batched_reduction_traits.h:160

cutlass::reduction::BatchedReductionTraits::ThreadShapeMultiple2

static const bool ThreadShapeMultiple2

Checks whether ThreadShape::kW is a multiple of 2.

Definition: batched_reduction_traits.h:117

cutlass::reduction::BatchedReductionTraits::Index

Index_ Index

The index.

Definition: batched_reduction_traits.h:111

cutlass::reduction::BatchedReductionTraits::maxOutReg

static int const maxOutReg

Definition: batched_reduction_traits.h:126

batched_reduction.h

Implements a software-pipelined efficient batched reduction. D = alpha * Reduction(A) + beta * C...

cutlass.h

Basic include for CUTLASS.

cutlass::reduction::BatchedReductionTraits

Definition: batched_reduction_traits.h:76

cutlass::reduction::BatchedReductionTraits::Params::initialize

CUTLASS_HOST_DEVICE int initialize(Index m_, Index n_, ScalarAlphaBeta alpha_, ScalarAlphaBeta beta_, long long int reduction_stride_, ScalarA const *d_a_, Index lda_, ScalarC const *d_c_, Index ldc_, ScalarD *d_d_, Index ldd_)

Initialize the parameters for a 2D output tensor.

Definition: batched_reduction_traits.h:162

cutlass::reduction::BatchedReductionTraits::KernelClass

cutlass::reduction::BatchedReduction< This_ > KernelClass

The struct that consumes this Traits.

Definition: batched_reduction_traits.h:93


Generated by 1.8.11