Back to Cutlass

CUTLASS: gemm.h Source File

docs/include_2cutlass_2gemm_2gemm_8h_source.html

4.4.239.5 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

include/cutlass/gemm/gemm.h

Go to the documentation of this file.

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

28 #pragma once

29

30 #include "cutlass/cutlass.h"

31 #include "cutlass/coord.h"

32

33 namespace cutlass {

34 namespace gemm {

35

37

/// GEMM operand enumeration: D = A * B + C
///
/// Each enumerator is documented with a trailing `///<` comment so that Doxygen
/// attaches the description to the member on the same line. A plain `///` placed
/// after a member is attached to the *following* member instead, which mislabels
/// every operand in the generated documentation.
enum class Operand {
  kA,  ///< A multiplicand
  kB,  ///< B multiplicand
  kC,  ///< Source accumulator
  kD   ///< Destination accumulator
};

45

47

49 template <

51int M = 1,

53int N = 1,

55int K = 1

56 >

57 struct GemmShape {

58static int const kM = M;

59static int const kN = N;

60static int const kK = K;

61

62static int const kMN = M * N;

63static int const kMK = M * K;

64static int const kKN = N * K;

65static int const kMNK = M * N * K;

66

67static int const kCount = kMNK;

68

69

70//

71// Static member functions

72//

73

75CUTLASS_HOST_DEVICE

76static Coord<3> toCoord() {

77return make_Coord(kM, kN, kK);

78 }

79 };

80

82

84 template <

86typename Shape

87 >

88 using GemmShapeTranspose = GemmShape<Shape::kN, Shape::kM, Shape::kK>;

89

91

94 struct GemmCoord : public Coord<3, int> {

95

97typedef int Index;

98

100typedef Coord<3, Index> Base;

101

103static int const kM = 0;

104

106static int const kN = 1;

107

109static int const kK = 2;

110

111//

112// Methods

113//

114

116CUTLASS_HOST_DEVICE

117GemmCoord() { }

118

120CUTLASS_HOST_DEVICE

121GemmCoord(Coord<3, Index> const &coord): Base(make_Coord(coord[0], coord[1], coord[2])) { }

122

124CUTLASS_HOST_DEVICE

125GemmCoord(Index m, Index n, Index k): Base(make_Coord(m, n, k)) { }

126

128CUTLASS_HOST_DEVICE

129 Index const & m() const { return this->at(kM); }

130

132CUTLASS_HOST_DEVICE

133 Index & m() { return this->at(kM); }

134

136CUTLASS_HOST_DEVICE

137 Index const & n() const { return this->at(kN); }

138

140CUTLASS_HOST_DEVICE

141 Index & n() { return this->at(kN); }

142

144CUTLASS_HOST_DEVICE

145 Index const & k() const { return this->at(kK); }

146

148CUTLASS_HOST_DEVICE

149 Index & k() { return this->at(kK); }

150

152CUTLASS_HOST_DEVICE

153Coord<3> mnk() const {

154return make_Coord(m(), n(), k());

155 }

156

158CUTLASS_HOST_DEVICE

159Coord<3> knm() const {

160return make_Coord(k(), n(), m());

161 }

162

164CUTLASS_HOST_DEVICE

165Coord<2> nm() const {

166return make_Coord(n(), m());

167 }

168

170CUTLASS_HOST_DEVICE

171Coord<2> mn() const {

172return make_Coord(m(), n());

173 }

174

176CUTLASS_HOST_DEVICE

177Coord<2> mk() const {

178return make_Coord(m(), k());

179 }

180

182CUTLASS_HOST_DEVICE

183Coord<2> km() const {

184return make_Coord(k(), m());

185 }

186

188CUTLASS_HOST_DEVICE

189Coord<2> nk() const {

190return make_Coord(n(), k());

191 }

192

194CUTLASS_HOST_DEVICE

195Coord<2> kn() const {

196return make_Coord(k(), n());

197 }

198

199//

200// Coord operators

201//

202

204CUTLASS_HOST_DEVICE

205GemmCoord operator+(Base const& b) const {

206return GemmCoord(Base::operator+(b));

207 }

208

210CUTLASS_HOST_DEVICE

211GemmCoord operator-(Base const& b) const {

212return GemmCoord(Base::operator-(b));

213 }

214

216CUTLASS_HOST_DEVICE

217GemmCoord operator*(Base const& b) const {

218return GemmCoord(Base::operator*(b));

219 }

220

222CUTLASS_HOST_DEVICE

223GemmCoord operator/(Base const& b) const {

224return GemmCoord(Base::operator/(b));

225 }

226

228CUTLASS_HOST_DEVICE

229GemmCoord& operator+=(Base const& b) {

230Base::operator+=(b);

231return *this;

232 }

233

235CUTLASS_HOST_DEVICE

236GemmCoord& operator-=(Base const& b) {

237Base::operator-=(b);

238return *this;

239 }

240

242CUTLASS_HOST_DEVICE

243GemmCoord& operator*=(Base const& b) {

244Base::operator*=(b);

245return *this;

246 }

247

249CUTLASS_HOST_DEVICE

250GemmCoord& operator/=(Base const& b) {

251Base::operator/=(b);

252return *this;

253 }

254 };

255

257

260 struct BatchedGemmCoord : public Coord<4, int> {

261

263typedef int Index;

264

266typedef Coord<4, Index> Base;

267

269static int const kM = 0;

270

272static int const kN = 1;

273

275static int const kK = 2;

276

278static int const kBatch = 3;

279

280//

281// Methods

282//

283

285CUTLASS_HOST_DEVICE

286BatchedGemmCoord() { }

287

289CUTLASS_HOST_DEVICE

290BatchedGemmCoord(Base const &coord): Base(coord) { }

291

293CUTLASS_HOST_DEVICE

294BatchedGemmCoord(Index m, Index n, Index k, Index b): Base(make_Coord(m, n, k, b)) { }

295

297CUTLASS_HOST_DEVICE

298 Index const & m() const { return this->at(kM); }

299

301CUTLASS_HOST_DEVICE

302 Index & m() { return this->at(kM); }

303

305CUTLASS_HOST_DEVICE

306 Index const & n() const { return this->at(kN); }

307

309CUTLASS_HOST_DEVICE

310 Index & n() { return this->at(kN); }

311

313CUTLASS_HOST_DEVICE

314 Index const & k() const { return this->at(kK); }

315

317CUTLASS_HOST_DEVICE

318 Index & k() { return this->at(kK); }

319

321CUTLASS_HOST_DEVICE

322 Index const & batch() const { return this->at(kBatch); }

323

325CUTLASS_HOST_DEVICE

326 Index & batch() { return this->at(kBatch); }

327

329CUTLASS_HOST_DEVICE

330GemmCoord mnk() const {

331return GemmCoord(m(), n(), k());

332 }

333

335CUTLASS_HOST_DEVICE

336Coord<4> mnkb() const {

337return make_Coord(m(), n(), k(), batch());

338 }

339

340//

341// Coord operators

342//

343

345CUTLASS_HOST_DEVICE

346BatchedGemmCoord operator+(Base const& b) const {

347return BatchedGemmCoord(Base::operator+(b));

348 }

349

351CUTLASS_HOST_DEVICE

352BatchedGemmCoord operator-(Base const& b) const {

353return BatchedGemmCoord(Base::operator-(b));

354 }

355

357CUTLASS_HOST_DEVICE

358BatchedGemmCoord operator*(Base const& b) const {

359return BatchedGemmCoord(Base::operator*(b));

360 }

361

363CUTLASS_HOST_DEVICE

364BatchedGemmCoord operator/(Base const& b) const {

365return BatchedGemmCoord(Base::operator/(b));

366 }

367

369CUTLASS_HOST_DEVICE

370BatchedGemmCoord& operator+=(Base const& b) {

371Base::operator+=(b);

372return *this;

373 }

374

376CUTLASS_HOST_DEVICE

377BatchedGemmCoord& operator-=(Base const& b) {

378Base::operator-=(b);

379return *this;

380 }

381

383CUTLASS_HOST_DEVICE

384BatchedGemmCoord& operator*=(Base const& b) {

385Base::operator*=(b);

386return *this;

387 }

388

390CUTLASS_HOST_DEVICE

391BatchedGemmCoord& operator/=(Base const& b) {

392Base::operator/=(b);

393return *this;

394 }

395 };

396

398

399 } // namespace gemm

400 } // namespace cutlass

cutlass::gemm::BatchedGemmCoord::mnkb

CUTLASS_HOST_DEVICE Coord< 4 > mnkb() const

Obtains a Coord<4> from BatchedGemmCoord.

Definition: include/cutlass/gemm/gemm.h:336

cutlass::gemm::BatchedGemmCoord::Base

Coord< 4, Index > Base

Base type is a Coord of rank=4.

Definition: include/cutlass/gemm/gemm.h:266

cutlass::gemm::BatchedGemmCoord::m

CUTLASS_HOST_DEVICE Index & m()

Returns reference to the GEMM M coordinate.

Definition: include/cutlass/gemm/gemm.h:302

cutlass

Definition: aligned_buffer.h:35

cutlass::gemm::BatchedGemmCoord::operator/

CUTLASS_HOST_DEVICE BatchedGemmCoord operator/(Base const &b) const

Element-wise division.

Definition: include/cutlass/gemm/gemm.h:364

cutlass::gemm::GemmCoord::operator/=

CUTLASS_HOST_DEVICE GemmCoord & operator/=(Base const &b)

In-place division.

Definition: include/cutlass/gemm/gemm.h:250

cutlass::gemm::GemmCoord::Index

int Index

Integer-valued index.

Definition: include/cutlass/gemm/gemm.h:97

cutlass::gemm::GemmCoord::GemmCoord

CUTLASS_HOST_DEVICE GemmCoord(Coord< 3, Index > const &coord)

Constructs from a Coord<3> whose elements are (m, n, k).

Definition: include/cutlass/gemm/gemm.h:121

cutlass::gemm::BatchedGemmCoord::mnk

CUTLASS_HOST_DEVICE GemmCoord mnk() const

Obtains a GemmCoord from BatchedGemmCoord.

Definition: include/cutlass/gemm/gemm.h:330

cutlass::gemm::GemmCoord::operator+

CUTLASS_HOST_DEVICE GemmCoord operator+(Base const &b) const

Element-wise addition.

Definition: include/cutlass/gemm/gemm.h:205

coord.h

A Coord is a coordinate of arbitrary rank into a tensor or matrix.

cutlass::make_Coord

CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)

Helper to make a 1-element coordinate.

Definition: coord.h:387

cutlass::gemm::GemmCoord::m

CUTLASS_HOST_DEVICE Index & m()

Returns reference to the GEMM M coordinate.

Definition: include/cutlass/gemm/gemm.h:133

cutlass::gemm::Operand

Operand

GEMM operand enumeration: D = A * B + C.

Definition: include/cutlass/gemm/gemm.h:39

cutlass::gemm::GemmCoord

Definition: include/cutlass/gemm/gemm.h:94

cutlass::gemm::GemmCoord::mn

CUTLASS_HOST_DEVICE Coord< 2 > mn() const

Obtains a Coord<2> from GemmCoord.

Definition: include/cutlass/gemm/gemm.h:171

cutlass::operator/=

CUTLASS_HOST_DEVICE half_t & operator/=(half_t &lhs, half_t const &rhs)

Definition: half.h:684

cutlass::gemm::BatchedGemmCoord::operator-

CUTLASS_HOST_DEVICE BatchedGemmCoord operator-(Base const &b) const

Element-wise subtraction.

Definition: include/cutlass/gemm/gemm.h:352

cutlass::gemm::BatchedGemmCoord::batch

CUTLASS_HOST_DEVICE Index & batch()

Returns reference to the GEMM batch coordinate.

Definition: include/cutlass/gemm/gemm.h:326

cutlass::gemm::GemmCoord::n

CUTLASS_HOST_DEVICE Index const & n() const

Returns the GEMM N coordinate.

Definition: include/cutlass/gemm/gemm.h:137

cutlass::gemm::GemmCoord::nm

CUTLASS_HOST_DEVICE Coord< 2 > nm() const

Obtains a Coord<2> from GemmCoord.

Definition: include/cutlass/gemm/gemm.h:165

cutlass::gemm::BatchedGemmCoord::operator+

CUTLASS_HOST_DEVICE BatchedGemmCoord operator+(Base const &b) const

Element-wise addition.

Definition: include/cutlass/gemm/gemm.h:346

cutlass::gemm::GemmCoord::k

CUTLASS_HOST_DEVICE Index & k()

Returns reference to the GEMM K coordinate.

Definition: include/cutlass/gemm/gemm.h:149

cutlass::gemm::BatchedGemmCoord::n

CUTLASS_HOST_DEVICE Index & n()

Returns reference to the GEMM N coordinate.

Definition: include/cutlass/gemm/gemm.h:310

cutlass::gemm::GemmCoord::operator/

CUTLASS_HOST_DEVICE GemmCoord operator/(Base const &b) const

Element-wise division.

Definition: include/cutlass/gemm/gemm.h:223

cutlass::gemm::GemmCoord::GemmCoord

CUTLASS_HOST_DEVICE GemmCoord(Index m, Index n, Index k)

Helper to construct from M, N, and K variables.

Definition: include/cutlass/gemm/gemm.h:125

cutlass::gemm::BatchedGemmCoord::Index

int Index

Integer-valued index.

Definition: include/cutlass/gemm/gemm.h:263

cutlass::operator+=

CUTLASS_HOST_DEVICE half_t & operator+=(half_t &lhs, half_t const &rhs)

Definition: half.h:654

cutlass::gemm::GemmCoord::nk

CUTLASS_HOST_DEVICE Coord< 2 > nk() const

Obtains a Coord<2> from GemmCoord.

Definition: include/cutlass/gemm/gemm.h:189

cutlass::gemm::BatchedGemmCoord::BatchedGemmCoord

CUTLASS_HOST_DEVICE BatchedGemmCoord(Base const &coord)

Constructs from Coord<4>

Definition: include/cutlass/gemm/gemm.h:290

cutlass::gemm::GemmCoord::k

CUTLASS_HOST_DEVICE Index const & k() const

Returns the GEMM K coordinate.

Definition: include/cutlass/gemm/gemm.h:145

cutlass::gemm::Operand::kC

Source accumulator.

cutlass::operator-=

CUTLASS_HOST_DEVICE half_t & operator-=(half_t &lhs, half_t const &rhs)

Definition: half.h:664

cutlass::gemm::GemmCoord::Base

Coord< 3, Index > Base

Base type is a Coord of rank=3.

Definition: include/cutlass/gemm/gemm.h:100

cutlass::gemm::Operand::kA

A multiplicand.

cutlass::gemm::GemmShape::toCoord

static CUTLASS_HOST_DEVICE Coord< 3 > toCoord()

Returns a Coord object.

Definition: include/cutlass/gemm/gemm.h:76

cutlass::gemm::GemmCoord::km

CUTLASS_HOST_DEVICE Coord< 2 > km() const

Obtains a Coord<2> from GemmCoord.

Definition: include/cutlass/gemm/gemm.h:183

cutlass::gemm::BatchedGemmCoord

Definition: include/cutlass/gemm/gemm.h:260

cutlass::gemm::GemmCoord::mnk

CUTLASS_HOST_DEVICE Coord< 3 > mnk() const

Obtains a Coord<3> from GemmCoord.

Definition: include/cutlass/gemm/gemm.h:153

cutlass::gemm::GemmCoord::operator-

CUTLASS_HOST_DEVICE GemmCoord operator-(Base const &b) const

Element-wise subtraction.

Definition: include/cutlass/gemm/gemm.h:211

cutlass::gemm::BatchedGemmCoord::batch

CUTLASS_HOST_DEVICE Index const & batch() const

Returns the GEMM batch coordinate.

Definition: include/cutlass/gemm/gemm.h:322

cutlass::gemm::BatchedGemmCoord::operator*=

CUTLASS_HOST_DEVICE BatchedGemmCoord & operator*=(Base const &b)

In-place multiplication.

Definition: include/cutlass/gemm/gemm.h:384

cutlass::gemm::BatchedGemmCoord::operator*

CUTLASS_HOST_DEVICE BatchedGemmCoord operator*(Base const &b) const

Element-wise multiplication.

Definition: include/cutlass/gemm/gemm.h:358

CUTLASS_HOST_DEVICE

#define CUTLASS_HOST_DEVICE

Definition: cutlass.h:89

cutlass::gemm::BatchedGemmCoord::k

CUTLASS_HOST_DEVICE Index const & k() const

Returns the GEMM K coordinate.

Definition: include/cutlass/gemm/gemm.h:314

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

cutlass::gemm::GemmCoord::mk

CUTLASS_HOST_DEVICE Coord< 2 > mk() const

Obtains a Coord<2> from GemmCoord.

Definition: include/cutlass/gemm/gemm.h:177

cutlass::gemm::BatchedGemmCoord::BatchedGemmCoord

CUTLASS_HOST_DEVICE BatchedGemmCoord()

Default ctor.

Definition: include/cutlass/gemm/gemm.h:286

cutlass::gemm::BatchedGemmCoord::k

CUTLASS_HOST_DEVICE Index & k()

Returns reference to the GEMM K coordinate.

Definition: include/cutlass/gemm/gemm.h:318

cutlass::gemm::Operand::kD

Destination accumulator.

cutlass::gemm::GemmCoord::knm

CUTLASS_HOST_DEVICE Coord< 3 > knm() const

Obtains a Coord<3> from GemmCoord.

Definition: include/cutlass/gemm/gemm.h:159

cutlass::gemm::GemmCoord::operator+=

CUTLASS_HOST_DEVICE GemmCoord & operator+=(Base const &b)

In-place addition.

Definition: include/cutlass/gemm/gemm.h:229

cutlass::operator*=

CUTLASS_HOST_DEVICE half_t & operator*=(half_t &lhs, half_t const &rhs)

Definition: half.h:674

cutlass::Coord< 3 >

cutlass::gemm::GemmCoord::GemmCoord

CUTLASS_HOST_DEVICE GemmCoord()

Default ctor.

Definition: include/cutlass/gemm/gemm.h:117

cutlass::gemm::BatchedGemmCoord::operator/=

CUTLASS_HOST_DEVICE BatchedGemmCoord & operator/=(Base const &b)

In-place division.

Definition: include/cutlass/gemm/gemm.h:391

cutlass::gemm::GemmCoord::operator-=

CUTLASS_HOST_DEVICE GemmCoord & operator-=(Base const &b)

In-place subtraction.

Definition: include/cutlass/gemm/gemm.h:236

cutlass::gemm::BatchedGemmCoord::operator+=

CUTLASS_HOST_DEVICE BatchedGemmCoord & operator+=(Base const &b)

In-place addition.

Definition: include/cutlass/gemm/gemm.h:370

cutlass::gemm::GemmCoord::kn

CUTLASS_HOST_DEVICE Coord< 2 > kn() const

Obtains a Coord<2> from GemmCoord.

Definition: include/cutlass/gemm/gemm.h:195

cutlass::gemm::BatchedGemmCoord::n

CUTLASS_HOST_DEVICE Index const & n() const

Returns the GEMM N coordinate.

Definition: include/cutlass/gemm/gemm.h:306

cutlass::gemm::BatchedGemmCoord::BatchedGemmCoord

CUTLASS_HOST_DEVICE BatchedGemmCoord(Index m, Index n, Index k, Index b)

Helper to construct from M, N, K, and batch variables.

Definition: include/cutlass/gemm/gemm.h:294

cutlass::gemm::GemmCoord::m

CUTLASS_HOST_DEVICE Index const & m() const

Returns the GEMM M coordinate.

Definition: include/cutlass/gemm/gemm.h:129

cutlass::gemm::GemmCoord::operator*=

CUTLASS_HOST_DEVICE GemmCoord & operator*=(Base const &b)

In-place multiplication.

Definition: include/cutlass/gemm/gemm.h:243

cutlass::gemm::Operand::kB

B multiplicand.

cutlass::gemm::BatchedGemmCoord::operator-=

CUTLASS_HOST_DEVICE BatchedGemmCoord & operator-=(Base const &b)

In-place subtraction.

Definition: include/cutlass/gemm/gemm.h:377

cutlass.h

Basic include for CUTLASS.

cutlass::gemm::GemmCoord::n

CUTLASS_HOST_DEVICE Index & n()

Returns reference to the GEMM N coordinate.

Definition: include/cutlass/gemm/gemm.h:141

cutlass::gemm::GemmCoord::operator*

CUTLASS_HOST_DEVICE GemmCoord operator*(Base const &b) const

Element-wise multiplication.

Definition: include/cutlass/gemm/gemm.h:217

cutlass::gemm::BatchedGemmCoord::m

CUTLASS_HOST_DEVICE Index const & m() const

Returns the GEMM M coordinate.

Definition: include/cutlass/gemm/gemm.h:298


Generated by 1.8.11