Back to Cutlass

CUTLASS: fast_math.h Source File

docs/fast__math_8h_source.html

4.4.214.7 KB
Original Source

CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers

fast_math.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

25

26 #pragma once

27

28 #include <cstdint>

29 #include "cutlass/cutlass.h"

30

36 namespace cutlass {

37

38 /******************************************************************************

39 * Static math utilities

40 ******************************************************************************/

41

/// Compile-time test for whether N is a power of two.
///
/// Clearing the lowest set bit of N (N & (N - 1)) leaves zero exactly when N
/// has at most one bit set. For N > 0 that means N is a power of two; note that
/// N == 0 also yields true with this test.
template <int N>
struct is_pow2 {
  /// True when N & (N - 1) is zero.
  static bool const value = !(N & (N - 1));
};

49

/// Compile-time floor(log2(N)) for N >= 1.
///
/// Each recursion step halves the running value `Remaining` (right shift by one)
/// and increments `Steps`; recursion stops when `Remaining` reaches 1.
/// NOTE(review): for N <= 0 the running value never becomes 1 (0 >> 1 == 0,
/// -1 >> 1 == -1), so instantiation recurses until the template-depth limit.
template <int N, int Remaining = N, int Steps = 0>
struct log2_down {
  /// floor(log2(N))
  enum { value = log2_down<N, (Remaining >> 1), Steps + 1>::value };
};

// Base case: the running value has been halved down to 1, so the number of
// halvings performed so far is the answer.
template <int N, int Steps>
struct log2_down<N, 1, Steps> {
  enum { value = Steps };
};

64

/// Compile-time ceil(log2(N)) for N >= 1.
///
/// Counts halvings of the running value `Remaining` down to 1, which gives
/// floor(log2(N)); the base case then adds one when N is larger than
/// 2^floor(log2(N)), i.e. when N is not an exact power of two.
/// NOTE(review): for N <= 0 the running value never becomes 1, so instantiation
/// recurses until the template-depth limit.
template <int N, int Remaining = N, int Steps = 0>
struct log2_up {
  /// ceil(log2(N))
  enum { value = log2_up<N, (Remaining >> 1), Steps + 1>::value };
};

// Base case: Steps == floor(log2(N)); round up if N exceeds 2^Steps.
template <int N, int Steps>
struct log2_up<N, 1, Steps> {
  enum { value = (N > (1 << Steps)) ? Steps + 1 : Steps };
};

79

83 template <int N>

84 struct sqrt_est {

85enum { value = 1 << (log2_up<N>::value / 2) };

86 };

87

92 template <int Dividend, int Divisor>

93 struct divide_assert {

94enum { value = Dividend / Divisor };

95

96static_assert((Dividend % Divisor == 0), "Not an even multiple");

97 };

98

99 /******************************************************************************

100 * Rounding

101 ******************************************************************************/

102

106 template <typename dividend_t, typename divisor_t>

107 CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divisor) {

108return ((dividend + divisor - 1) / divisor) * divisor;

109 }

110

114 template <typename value_t>

115 CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b) {

116for (;;) {

117if (a == 0) return b;

118 b %= a;

119if (b == 0) return a;

120 a %= b;

121 }

122 }

123

127 template <typename value_t>

128 CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) {

129 value_t temp = gcd(a, b);

130

131return temp ? (a / temp * b) : 0;

132 }

133

139 template <typename value_t>

140 CUTLASS_HOST_DEVICE value_t clz(value_t x) {

141for (int i = 31; i >= 0; --i) {

142if ((1 << i) & x) return 31 - i;

143 }

144return 32;

145 }

146

147 template <typename value_t>

148 CUTLASS_HOST_DEVICE value_t find_log2(value_t x) {

149int a = int(31 - clz(x));

150 a += (x & (x - 1)) != 0; // Round up, add 1 if not a power of 2.

151return a;

152 }

153

154

158 CUTLASS_HOST_DEVICE

159 void find_divisor(unsigned int& mul, unsigned int& shr, unsigned int denom) {

160if (denom == 1) {

161 mul = 0;

162 shr = 0;

163 } else {

164unsigned int p = 31 + find_log2(denom);

165unsigned m = unsigned(((1ull << p) + unsigned(denom) - 1) / unsigned(denom));

166

167 mul = m;

168 shr = p - 32;

169 }

170 }

171

175 CUTLASS_HOST_DEVICE

176 void fast_divmod(int& quo, int& rem, int src, int div, unsigned int mul, unsigned int shr) {

177

178 #if defined(__CUDA_ARCH__)

179// Use IMUL.HI if div != 1, else simply copy the source.

180 quo = (div != 1) ? __umulhi(src, mul) >> shr : src;

181 #else

182 quo = int((div != 1) ? int(src * mul) >> shr : src);

183 #endif

184

185// The remainder.

186 rem = src - (quo * div);

187

188 }

189

190 // For long int input

191 CUTLASS_HOST_DEVICE

192 void fast_divmod(int& quo, int64_t& rem, int64_t src, int div, unsigned int mul, unsigned int shr) {

193

194 #if defined(__CUDA_ARCH__)

195// Use IMUL.HI if div != 1, else simply copy the source.

196 quo = (div != 1) ? __umulhi(src, mul) >> shr : src;

197 #else

198 quo = int((div != 1) ? (src * mul) >> shr : src);

199 #endif

200// The remainder.

201 rem = src - (quo * div);

202 }

203

204 /******************************************************************************

205 * Min/Max

206 ******************************************************************************/

207

208 template <int A, int B>

209 struct Min {

210static int const kValue = (A < B) ? A : B;

211 };

212

213 template <int A, int B>

214 struct Max {

215static int const kValue = (A > B) ? A : B;

216 };

217

218 CUTLASS_HOST_DEVICE

219 constexpr int const_min(int a, int b) {

220return (b < a ? b : a);

221 }

222

223 CUTLASS_HOST_DEVICE

224 constexpr int const_max(int a, int b) {

225return (b > a ? b : a);

226 }

227

228 } // namespace cutlass

cutlass

Definition: aligned_buffer.h:35

constexpr

#define constexpr

Definition: platform.h:137

cutlass::fast_divmod

CUTLASS_HOST_DEVICE void fast_divmod(int &quo, int &rem, int src, int div, unsigned int mul, unsigned int shr)

Definition: fast_math.h:176

cutlass::find_log2

CUTLASS_HOST_DEVICE value_t find_log2(value_t x)

Definition: fast_math.h:148

cutlass::log2_down

Definition: fast_math.h:54

cutlass::Min

Definition: fast_math.h:209

cutlass::const_max

CUTLASS_HOST_DEVICE constexpr int const_max(int a, int b)

Definition: fast_math.h:224

cutlass::is_pow2::value

static bool const value

Definition: fast_math.h:47

cutlass::lcm

CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b)

Definition: fast_math.h:128

cutlass::round_nearest

CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divisor)

Definition: fast_math.h:107

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::Max

Definition: fast_math.h:214

cutlass::find_divisor

CUTLASS_HOST_DEVICE void find_divisor(unsigned int &mul, unsigned int &shr, unsigned int denom)

Definition: fast_math.h:159

cutlass::gcd

CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b)

Definition: fast_math.h:115

cutlass::divide_assert

Definition: fast_math.h:93

cutlass::log2_up

Definition: fast_math.h:69

cutlass::clz

CUTLASS_HOST_DEVICE value_t clz(value_t x)

Definition: fast_math.h:140

cutlass::is_pow2

Definition: fast_math.h:46

cutlass::const_min

CUTLASS_HOST_DEVICE constexpr int const_min(int a, int b)

Definition: fast_math.h:219

cutlass.h

Basic include for CUTLASS.

cutlass::sqrt_est

Definition: fast_math.h:84


Generated by Doxygen 1.8.11