docs/include_2cutlass_2gemm_2gemm_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
include/cutlass/gemm/gemm.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
28 #pragma once
29
30 #include "cutlass/cutlass.h"
31 #include "cutlass/coord.h"
32
33 namespace cutlass {
34 namespace gemm {
35
37
40kA,
41kB,
42kC,
43kD
44 };
45
47
49 template <
51int M = 1,
53int N = 1,
55int K = 1
56 >
58static int const kM = M;
59static int const kN = N;
60static int const kK = K;
61
62static int const kMN = M * N;
63static int const kMK = M * K;
64static int const kKN = N * K;
65static int const kMNK = M * N * K;
66
67static int const kCount = kMNK;
68
69
70//
71// Static member functions
72//
73
77return make_Coord(kM, kN, kK);
78 }
79 };
80
82
84 template <
86typename Shape
87 >
88 using GemmShapeTranspose = GemmShape<Shape::kN, Shape::kM, Shape::kK>;
89
91
94 struct GemmCoord : public Coord<3, int> {
95
98
100typedef Coord<3, Index> Base;
101
103static int const kM = 0;
104
106static int const kN = 1;
107
109static int const kK = 2;
110
111//
112// Methods
113//
114
118
121GemmCoord(Coord<3, Index> const &coord): Base(make_Coord(coord[0], coord[1], coord[2])) { }
122
125GemmCoord(Index m, Index n, Index k): Base(make_Coord(m, n, k)) { }
126
129 Index const & m() const { return this->at(kM); }
130
133 Index & m() { return this->at(kM); }
134
137 Index const & n() const { return this->at(kN); }
138
141 Index & n() { return this->at(kN); }
142
145 Index const & k() const { return this->at(kK); }
146
149 Index & k() { return this->at(kK); }
150
154return make_Coord(m(), n(), k());
155 }
156
160return make_Coord(k(), n(), m());
161 }
162
166return make_Coord(n(), m());
167 }
168
172return make_Coord(m(), n());
173 }
174
178return make_Coord(m(), k());
179 }
180
184return make_Coord(k(), m());
185 }
186
190return make_Coord(n(), k());
191 }
192
196return make_Coord(k(), n());
197 }
198
199//
200// Coord operators
201//
202
205GemmCoord operator+(Base const& b) const {
206return GemmCoord(Base::operator+(b));
207 }
208
211GemmCoord operator-(Base const& b) const {
212return GemmCoord(Base::operator-(b));
213 }
214
217GemmCoord operator*(Base const& b) const {
218return GemmCoord(Base::operator*(b));
219 }
220
223GemmCoord operator/(Base const& b) const {
224return GemmCoord(Base::operator/(b));
225 }
226
229GemmCoord& operator+=(Base const& b) {
230Base::operator+=(b);
231return *this;
232 }
233
236GemmCoord& operator-=(Base const& b) {
237Base::operator-=(b);
238return *this;
239 }
240
243GemmCoord& operator*=(Base const& b) {
244Base::operator*=(b);
245return *this;
246 }
247
250GemmCoord& operator/=(Base const& b) {
251Base::operator/=(b);
252return *this;
253 }
254 };
255
257
260 struct BatchedGemmCoord : public Coord<4, int> {
261
264
266typedef Coord<4, Index> Base;
267
269static int const kM = 0;
270
272static int const kN = 1;
273
275static int const kK = 2;
276
278static int const kBatch = 3;
279
280//
281// Methods
282//
283
286BatchedGemmCoord() { }
287
290BatchedGemmCoord(Base const &coord): Base(coord) { }
291
294BatchedGemmCoord(Index m, Index n, Index k, Index b): Base(make_Coord(m, n, k, b)) { }
295
298 Index const & m() const { return this->at(kM); }
299
302 Index & m() { return this->at(kM); }
303
306 Index const & n() const { return this->at(kN); }
307
310 Index & n() { return this->at(kN); }
311
314 Index const & k() const { return this->at(kK); }
315
318 Index & k() { return this->at(kK); }
319
322 Index const & batch() const { return this->at(kBatch); }
323
326 Index & batch() { return this->at(kBatch); }
327
331return GemmCoord(m(), n(), k());
332 }
333
337return make_Coord(m(), n(), k(), batch());
338 }
339
340//
341// Coord operators
342//
343
346BatchedGemmCoord operator+(Base const& b) const {
347return BatchedGemmCoord(Base::operator+(b));
348 }
349
352BatchedGemmCoord operator-(Base const& b) const {
353return BatchedGemmCoord(Base::operator-(b));
354 }
355
358BatchedGemmCoord operator*(Base const& b) const {
359return BatchedGemmCoord(Base::operator*(b));
360 }
361
364BatchedGemmCoord operator/(Base const& b) const {
365return BatchedGemmCoord(Base::operator/(b));
366 }
367
370BatchedGemmCoord& operator+=(Base const& b) {
371Base::operator+=(b);
372return *this;
373 }
374
377BatchedGemmCoord& operator-=(Base const& b) {
378Base::operator-=(b);
379return *this;
380 }
381
384BatchedGemmCoord& operator*=(Base const& b) {
385Base::operator*=(b);
386return *this;
387 }
388
391BatchedGemmCoord& operator/=(Base const& b) {
392Base::operator/=(b);
393return *this;
394 }
395 };
396
398
399 } // namespace gemm
400 } // namespace cutlass
cutlass::gemm::BatchedGemmCoord::mnkb
CUTLASS_HOST_DEVICE Coord< 4 > mnkb() const
Obtains a Coord<4> from BatchedGemmCoord.
Definition: include/cutlass/gemm/gemm.h:336
cutlass::gemm::BatchedGemmCoord::Base
Coord< 4, Index > Base
Base type is a Coord of rank=4.
Definition: include/cutlass/gemm/gemm.h:266
cutlass::gemm::BatchedGemmCoord::m
CUTLASS_HOST_DEVICE Index & m()
Returns reference to the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:302
Definition: aligned_buffer.h:35
cutlass::gemm::BatchedGemmCoord::operator/
CUTLASS_HOST_DEVICE BatchedGemmCoord operator/(Base const &b) const
Element-wise division.
Definition: include/cutlass/gemm/gemm.h:364
cutlass::gemm::GemmCoord::operator/=
CUTLASS_HOST_DEVICE GemmCoord & operator/=(Base const &b)
In-place division.
Definition: include/cutlass/gemm/gemm.h:250
cutlass::gemm::GemmCoord::Index
int Index
Integer-valued index.
Definition: include/cutlass/gemm/gemm.h:97
cutlass::gemm::GemmCoord::GemmCoord
CUTLASS_HOST_DEVICE GemmCoord(Coord< 3, Index > const &coord)
Constructs from Coord<3>.
Definition: include/cutlass/gemm/gemm.h:121
cutlass::gemm::BatchedGemmCoord::mnk
CUTLASS_HOST_DEVICE GemmCoord mnk() const
Obtains a GemmCoord from BatchedGemmCoord.
Definition: include/cutlass/gemm/gemm.h:330
cutlass::gemm::GemmCoord::operator+
CUTLASS_HOST_DEVICE GemmCoord operator+(Base const &b) const
Element-wise addition.
Definition: include/cutlass/gemm/gemm.h:205
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:387
CUTLASS_HOST_DEVICE Index & m()
Returns reference to the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:133
Operand
GEMM operand enumeration: D = A * B + C.
Definition: include/cutlass/gemm/gemm.h:39
Definition: include/cutlass/gemm/gemm.h:94
CUTLASS_HOST_DEVICE Coord< 2 > mn() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:171
CUTLASS_HOST_DEVICE half_t & operator/=(half_t &lhs, half_t const &rhs)
Definition: half.h:684
cutlass::gemm::BatchedGemmCoord::operator-
CUTLASS_HOST_DEVICE BatchedGemmCoord operator-(Base const &b) const
Element-wise subtraction.
Definition: include/cutlass/gemm/gemm.h:352
cutlass::gemm::BatchedGemmCoord::batch
CUTLASS_HOST_DEVICE Index & batch()
Returns reference to the GEMM batch coordinate.
Definition: include/cutlass/gemm/gemm.h:326
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
CUTLASS_HOST_DEVICE Coord< 2 > nm() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:165
cutlass::gemm::BatchedGemmCoord::operator+
CUTLASS_HOST_DEVICE BatchedGemmCoord operator+(Base const &b) const
Element-wise addition.
Definition: include/cutlass/gemm/gemm.h:346
CUTLASS_HOST_DEVICE Index & k()
Returns reference to the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:149
cutlass::gemm::BatchedGemmCoord::n
CUTLASS_HOST_DEVICE Index & n()
Returns reference to the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:310
cutlass::gemm::GemmCoord::operator/
CUTLASS_HOST_DEVICE GemmCoord operator/(Base const &b) const
Element-wise division.
Definition: include/cutlass/gemm/gemm.h:223
cutlass::gemm::GemmCoord::GemmCoord
CUTLASS_HOST_DEVICE GemmCoord(Index m, Index n, Index k)
Helper to construct from M, N, and K variables.
Definition: include/cutlass/gemm/gemm.h:125
cutlass::gemm::BatchedGemmCoord::Index
int Index
Integer-valued index.
Definition: include/cutlass/gemm/gemm.h:263
CUTLASS_HOST_DEVICE half_t & operator+=(half_t &lhs, half_t const &rhs)
Definition: half.h:654
CUTLASS_HOST_DEVICE Coord< 2 > nk() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:189
cutlass::gemm::BatchedGemmCoord::BatchedGemmCoord
CUTLASS_HOST_DEVICE BatchedGemmCoord(Base const &coord)
Constructs from Coord<4>
Definition: include/cutlass/gemm/gemm.h:290
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
B multiplicand.
CUTLASS_HOST_DEVICE half_t & operator-=(half_t &lhs, half_t const &rhs)
Definition: half.h:664
cutlass::gemm::GemmCoord::Base
Coord< 3, Index > Base
Base type is a Coord of rank=3.
Definition: include/cutlass/gemm/gemm.h:100
cutlass::gemm::GemmShape::toCoord
static CUTLASS_HOST_DEVICE Coord< 3 > toCoord()
Returns a Coord object.
Definition: include/cutlass/gemm/gemm.h:76
CUTLASS_HOST_DEVICE Coord< 2 > km() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:183
cutlass::gemm::BatchedGemmCoord
Definition: include/cutlass/gemm/gemm.h:260
CUTLASS_HOST_DEVICE Coord< 3 > mnk() const
Obtains a Coord<3> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:153
cutlass::gemm::GemmCoord::operator-
CUTLASS_HOST_DEVICE GemmCoord operator-(Base const &b) const
Element-wise subtraction.
Definition: include/cutlass/gemm/gemm.h:211
cutlass::gemm::BatchedGemmCoord::batch
CUTLASS_HOST_DEVICE Index const & batch() const
Returns the GEMM batch coordinate.
Definition: include/cutlass/gemm/gemm.h:322
cutlass::gemm::BatchedGemmCoord::operator*=
CUTLASS_HOST_DEVICE BatchedGemmCoord & operator*=(Base const &b)
In-place multiplication.
Definition: include/cutlass/gemm/gemm.h:384
cutlass::gemm::BatchedGemmCoord::operator*
CUTLASS_HOST_DEVICE BatchedGemmCoord operator*(Base const &b) const
Element-wise multiplication.
Definition: include/cutlass/gemm/gemm.h:358
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
cutlass::gemm::BatchedGemmCoord::k
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:314
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
CUTLASS_HOST_DEVICE Coord< 2 > mk() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:177
cutlass::gemm::BatchedGemmCoord::BatchedGemmCoord
CUTLASS_HOST_DEVICE BatchedGemmCoord()
Default ctor.
Definition: include/cutlass/gemm/gemm.h:286
cutlass::gemm::BatchedGemmCoord::k
CUTLASS_HOST_DEVICE Index & k()
Returns reference to the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:318
Source accumulator.
CUTLASS_HOST_DEVICE Coord< 3 > knm() const
Obtains a Coord<3> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:159
cutlass::gemm::GemmCoord::operator+=
CUTLASS_HOST_DEVICE GemmCoord & operator+=(Base const &b)
In-place addition.
Definition: include/cutlass/gemm/gemm.h:229
CUTLASS_HOST_DEVICE half_t & operator*=(half_t &lhs, half_t const &rhs)
Definition: half.h:674
cutlass::gemm::GemmCoord::GemmCoord
CUTLASS_HOST_DEVICE GemmCoord()
Default ctor.
Definition: include/cutlass/gemm/gemm.h:117
cutlass::gemm::BatchedGemmCoord::operator/=
CUTLASS_HOST_DEVICE BatchedGemmCoord & operator/=(Base const &b)
In-place division.
Definition: include/cutlass/gemm/gemm.h:391
cutlass::gemm::GemmCoord::operator-=
CUTLASS_HOST_DEVICE GemmCoord & operator-=(Base const &b)
In-place subtraction.
Definition: include/cutlass/gemm/gemm.h:236
cutlass::gemm::BatchedGemmCoord::operator+=
CUTLASS_HOST_DEVICE BatchedGemmCoord & operator+=(Base const &b)
In-place addition.
Definition: include/cutlass/gemm/gemm.h:370
CUTLASS_HOST_DEVICE Coord< 2 > kn() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:195
cutlass::gemm::BatchedGemmCoord::n
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:306
cutlass::gemm::BatchedGemmCoord::BatchedGemmCoord
CUTLASS_HOST_DEVICE BatchedGemmCoord(Index m, Index n, Index k, Index b)
Helper to construct from M, N, K, and batch variables.
Definition: include/cutlass/gemm/gemm.h:294
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
cutlass::gemm::GemmCoord::operator*=
CUTLASS_HOST_DEVICE GemmCoord & operator*=(Base const &b)
In-place multiplication.
Definition: include/cutlass/gemm/gemm.h:243
A multiplicand.
cutlass::gemm::BatchedGemmCoord::operator-=
CUTLASS_HOST_DEVICE BatchedGemmCoord & operator-=(Base const &b)
In-place subtraction.
Definition: include/cutlass/gemm/gemm.h:377
Basic include for CUTLASS.
CUTLASS_HOST_DEVICE Index & n()
Returns reference to the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:141
cutlass::gemm::GemmCoord::operator*
CUTLASS_HOST_DEVICE GemmCoord operator*(Base const &b) const
Element-wise multiplication.
Definition: include/cutlass/gemm/gemm.h:217
cutlass::gemm::BatchedGemmCoord::m
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:298
Generated by Doxygen 1.8.11