docs/gemm_2thread_2mma__sm50_8h_source.html
CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm/thread/mma_sm50.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/cutlass.h"
32 #include "cutlass/tensor_ref.h"
33 #include "cutlass/layout/matrix.h"
34 #include "cutlass/arch/mma.h"
35 #include "cutlass/gemm/gemm.h"
36 #include "cutlass/gemm/thread/mma.h"
37
39
40 namespace cutlass {
41 namespace gemm {
42 namespace thread {
43
45
// MmaGeneric: per-thread matrix multiply-add D = A * B + C over a
// Shape::kM x Shape::kN x Shape::kK tile, issued one scalar element at a time through MmaOp.
// NOTE(review): this is a rendered Doxygen listing. Comment-only and macro-only lines were
// dropped by the renderer (e.g. the CUTLASS_HOST_DEVICE on operator(), per its tooltip), the
// member type aliases at listing lines 66-90 (Shape, ElementA, LayoutA, ElementB, LayoutB,
// ElementC, LayoutC, Operator) are not shown, and the MmaOp alias is only partially visible.
47 template <
// Size of the GEMM problem (concept: gemm::GemmShape<>).
49typename Shape_,
// Data type of operand A.
51typename ElementA_,
// Layout of A (concept: layout::MapFunc).
53typename LayoutA_,
// Data type of operand B.
55typename ElementB_,
// Layout of B (concept: layout::MapFunc).
57typename LayoutB_,
// Element type of operand C / accumulator D.
59typename ElementC_,
// Layout of C and D (concept: layout::MapFunc).
61typename LayoutC_,
// Underlying mathematical operator.
63typename Operator_
64 >
65 struct MmaGeneric {
66
69
72
75
78
81
84
87
90
// Operand storage: A holds Shape::kMK elements, B holds Shape::kKN, C and D hold Shape::kMN.
92using FragmentA = Array<ElementA, Shape::kMK>;
93
95using FragmentB = Array<ElementB, Shape::kKN>;
96
98using FragmentC = Array<ElementC, Shape::kMN>;
99
// NOTE(review): remnant of the MmaOp alias declaration; its full text is elided here.
// Presumably an arch::Mma specialization for a single-element (1x1x1) multiply-add — confirm.
103 1,
108
109//
110// Methods
111//
112
// Computes D = A * B + C. C is copied into D first, then D is accumulated in place.
115void operator()(
116FragmentC & D,
117FragmentA const & A,
118FragmentB const & B,
119FragmentC const & C) {
120
// View each fragment as a packed matrix in its layout so that elements can be addressed
// by (row, column) coordinates. d_ref aliases D's storage, so writes through it update D.
121TensorRef<ElementA const, LayoutA> a_ref(
122 reinterpret_cast<ElementA const *>(&A), LayoutA::packed({Shape::kM, Shape::kK}));
123
124TensorRef<ElementB const, LayoutB> b_ref(
125 reinterpret_cast<ElementB const *>(&B), LayoutB::packed({Shape::kK, Shape::kN}));
126
127TensorRef<ElementC, LayoutC> d_ref(
128 reinterpret_cast<ElementC *>(&D), LayoutC::packed({ Shape::kM, Shape::kN }));
129
130MmaOp mma_op;
131
132// Copy accumulators
133 D = C;
134
// Loop order: k outermost, then n, then m. Each (m, n, k) triple performs exactly one
// scalar multiply-add into D(m, n).
135// Compute matrix product
137for (int k = 0; k < Shape::kK; ++k) {
138
140for (int n = 0; n < Shape::kN; ++n) {
141
143for (int m = 0; m < Shape::kM; ++m) {
144
// On odd columns, m is visited in reverse (serpentine order). Results are unaffected:
// every (m, n) accumulator is still updated once per k, in increasing k order.
// Presumably chosen so consecutive iterations touch neighbouring rows — TODO confirm.
145int m_serpentine = (n % 2) ? (Shape::kM - 1 - m) : m;
146
147MatrixCoord mn(m_serpentine, n);
148MatrixCoord mk(m_serpentine, k);
149MatrixCoord kn(k, n);
150
// Stage one element of each operand in a 1-element Array for MmaOp.
151 Array<ElementC, 1> d;
152 Array<ElementA, 1> a;
153 Array<ElementB, 1> b;
154
155 d[0] = d_ref.at(mn);
156 a[0] = a_ref.at(mk);
157 b[0] = b_ref.at(kn);
158
// Presumably d = a * b + d, following the same (D, A, B, C) argument order as operator().
159 mma_op(d, a, b, d);
160
// Write the updated accumulator back into D.
161 d_ref.at(mn) = d[0];
162 }
163 }
164 }
165 }
166 };
167
168
170
// Partial specialization of thread::Mma for arch::OpMultiplyAdd. It is a thin wrapper that
// forwards every call to MmaGeneric, instantiated with the same Shape, element types,
// layouts and operator.
// NOTE(review): rendered Doxygen listing; comment-only and macro-only lines (e.g. the
// CUTLASS_HOST_DEVICE on operator(), per its tooltip) were dropped by the renderer.
172 template <
174typename Shape_,
176typename ElementA_,
178typename LayoutA_,
180typename ElementB_,
182typename LayoutB_,
184typename ElementC_,
186typename LayoutC_
187 >
[188](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html) struct Mma<
189 Shape_,
190 ElementA_,
191 LayoutA_,
192 ElementB_,
193 LayoutB_,
194 ElementC_,
195 LayoutC_,
196 arch::OpMultiplyAdd,
// The trailing `bool` fills the last template parameter of the primary thread::Mma template
// (declared in gemm/thread/mma.h, not shown); presumably an enable/SFINAE slot — confirm.
197 bool> {
198
// Type aliases re-exporting the template arguments (Shape is a gemm::GemmShape<>; layouts
// follow the layout::MapFunc concept).
[200](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#aeef7c1c07c481fb13e3ab2025d22133a)using [Shape](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA 00_01LayoutA 00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#aeef7c1c07c481fb13e3ab2025d22133a) = Shape_;
201
[203](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a400d6fc8296c16b6277c3d7ad650e7c1)using [ElementA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA 00_01LayoutA 00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a400d6fc8296c16b6277c3d7ad650e7c1) = ElementA_;
204
[206](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#af1c453f655d29855f026ab6dfc8f7ae9)using [LayoutA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA 00_01LayoutA 00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#af1c453f655d29855f026ab6dfc8f7ae9) = LayoutA_;
207
[209](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#aa1e00de6ae05673351b0c7bba92827ab)using [ElementB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA 00_01LayoutA 00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#aa1e00de6ae05673351b0c7bba92827ab) = ElementB_;
210
[212](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a0e71571693f24560bdba20fbd2ea1a77)using [LayoutB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA 00_01LayoutA 00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a0e71571693f24560bdba20fbd2ea1a77) = LayoutB_;
213
[215](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a58b3c904716c54edb20b1ae1ae0bc715)using [ElementC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA 00_01LayoutA 00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a58b3c904716c54edb20b1ae1ae0bc715) = ElementC_;
216
[218](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a08a2137eb47c1caa00adaf3572c706a0)using [LayoutC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA 00_01LayoutA 00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a08a2137eb47c1caa00adaf3572c706a0) = LayoutC_;
219
// The operator is fixed by this specialization rather than taken from a template argument.
[221](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a08207ff2d73d653194a061153edc27a9)using [Operator](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA 00_01LayoutA 00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a08207ff2d73d653194a061153edc27a9) = arch::OpMultiplyAdd;
222
// Operand storage: A holds Shape::kMK elements, B holds Shape::kKN, C and D hold Shape::kMN
// (the same fragment types MmaGeneric declares).
[224](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a66abc782808b6b3e68518aff43a0b200)using [FragmentA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA 00_01LayoutA 00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a66abc782808b6b3e68518aff43a0b200) = Array<ElementA, Shape::kMK>;
225
[227](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a2deaf8959c027ab4aca92630b85f5211)using [FragmentB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA 00_01LayoutA 00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a2deaf8959c027ab4aca92630b85f5211) = Array<ElementB, Shape::kKN>;
228
[230](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a6109558276e8c66a4d3e9ad53fb046d8)using [FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA 00_01LayoutA 00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a6109558276e8c66a4d3e9ad53fb046d8) = Array<ElementC, Shape::kMN>;
231
232//
233// Methods
234//
235
// Computes D = A * B + C by constructing a MmaGeneric with identical template arguments
// and invoking it with the same four fragments.
[238](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#ae98fd835ed4750d4f22d7e4e50b5e59f)void [operator()](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA 00_01LayoutA 00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#ae98fd835ed4750d4f22d7e4e50b5e59f)(
239[FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a6109558276e8c66a4d3e9ad53fb046d8) & D,
240[FragmentA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a66abc782808b6b3e68518aff43a0b200) const & A,
241[FragmentB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a2deaf8959c027ab4aca92630b85f5211) const & B,
242[FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a6109558276e8c66a4d3e9ad53fb046d8) const & C) {
243
// Delegate to the generic element-wise implementation.
244MmaGeneric<
245Shape,
246ElementA,
247LayoutA,
248ElementB,
249LayoutB,
250ElementC,
251LayoutC,
252[Operator](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape 00_01ElementA 00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a08207ff2d73d653194a061153edc27a9)> mma;
253
254 mma(D, A, B, C);
255 }
256 };
257
259
260 } // namespace thread
261 } // namespace gemm
262 } // namespace cutlass
263
cutlass::gemm::thread::MmaGeneric::Operator
Operator_ Operator
Underlying mathematical operator.
Definition: gemm/thread/mma_sm50.h:89
Definition: aligned_buffer.h:35
cutlass::gemm::thread::MmaGeneric::FragmentB
Array< ElementB, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm50.h:95
Defines a structure containing strides, bounds, and a pointer to tensor data.
[cutlass::gemm::thread::Mma< Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, arch::OpMultiplyAdd, bool >::ElementA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA___00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a400d6fc8296c16b6277c3d7ad650e7c1)
ElementA_ ElementA
Data type of operand A.
Definition: gemm/thread/mma_sm50.h:203
[cutlass::gemm::thread::Mma< Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, arch::OpMultiplyAdd, bool >::ElementB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA___00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#aa1e00de6ae05673351b0c7bba92827ab)
ElementB_ ElementB
Data type of operand B.
Definition: gemm/thread/mma_sm50.h:209
cutlass::gemm::thread::MmaGeneric::FragmentC
Array< ElementC, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm50.h:98
[cutlass::gemm::thread::Mma< Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, arch::OpMultiplyAdd, bool >::operator()](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA___00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#ae98fd835ed4750d4f22d7e4e50b5e59f)
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm50.h:238
cutlass::gemm::thread::MmaGeneric::LayoutA
LayoutA_ LayoutA
Layout of A matrix (concept: layout::MapFunc)
Definition: gemm/thread/mma_sm50.h:74
[cutlass::gemm::thread::Mma< Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, arch::OpMultiplyAdd, bool >::FragmentC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA___00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a6109558276e8c66a4d3e9ad53fb046d8)
Array< ElementC, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm50.h:230
cutlass::gemm::thread::MmaGeneric::LayoutC
LayoutC_ LayoutC
Layout of C matrix (concept: layout::MapFunc)
Definition: gemm/thread/mma_sm50.h:86
Defines common types used for all GEMM-like operators.
[cutlass::gemm::thread::Mma< Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, arch::OpMultiplyAdd, bool >::ElementC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA___00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a58b3c904716c54edb20b1ae1ae0bc715)
ElementC_ ElementC
Element type of operand C.
Definition: gemm/thread/mma_sm50.h:215
cutlass::gemm::thread::MmaGeneric::ElementA
ElementA_ ElementA
Data type of operand A.
Definition: gemm/thread/mma_sm50.h:71
cutlass::gemm::thread::MmaGeneric::LayoutB
LayoutB_ LayoutB
Layout of B matrix (concept: layout::MapFunc)
Definition: gemm/thread/mma_sm50.h:80
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Templates exposing architecture support for multiply-add operations.
[cutlass::gemm::thread::Mma< Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, arch::OpMultiplyAdd, bool >::LayoutB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA___00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a0e71571693f24560bdba20fbd2ea1a77)
LayoutB_ LayoutB
Layout of B matrix (concept: layout::MapFunc)
Definition: gemm/thread/mma_sm50.h:212
cutlass::gemm::thread::MmaGeneric
Template that handles all packed matrix layouts.
Definition: gemm/thread/mma_sm50.h:65
cutlass::TensorRef< ElementA const, LayoutA >
[cutlass::gemm::thread::Mma< Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, arch::OpMultiplyAdd, bool >::FragmentA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA___00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a66abc782808b6b3e68518aff43a0b200)
Array< ElementA, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm50.h:224
cutlass::gemm::thread::MmaGeneric::FragmentA
Array< ElementA, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm50.h:92
[cutlass::gemm::thread::Mma< Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, arch::OpMultiplyAdd, bool >::Operator](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA___00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a08207ff2d73d653194a061153edc27a9)
arch::OpMultiplyAdd Operator
Underlying mathematical operator.
Definition: gemm/thread/mma_sm50.h:221
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
cutlass::gemm::thread::MmaGeneric::Shape
Shape_ Shape
Size of the Gemm problem - concept: gemm::GemmShape<>
Definition: gemm/thread/mma_sm50.h:68
Templates exposing architecture support for warp-level multiply-add operations.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
[cutlass::gemm::thread::Mma< Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, arch::OpMultiplyAdd, bool >::FragmentB](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA___00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a2deaf8959c027ab4aca92630b85f5211)
Array< ElementB, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm50.h:227
[cutlass::gemm::thread::Mma< Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, arch::OpMultiplyAdd, bool >::Shape](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA___00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#aeef7c1c07c481fb13e3ab2025d22133a)
Shape_ Shape
Size of the Gemm problem - concept: gemm::GemmShape<>
Definition: gemm/thread/mma_sm50.h:200
cutlass::gemm::thread::MmaGeneric::operator()
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm50.h:115
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
Structure to compute the matrix product.
Definition: gemm/thread/mma.h:66
Defines layout functions used by TensorRef and derived classes.
[cutlass::gemm::thread::Mma< Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, arch::OpMultiplyAdd, bool >::LayoutA](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA___00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#af1c453f655d29855f026ab6dfc8f7ae9)
LayoutA_ LayoutA
Layout of A matrix (concept: layout::MapFunc)
Definition: gemm/thread/mma_sm50.h:206
Matrix multiply-add operation.
Definition: arch/mma.h:92
[cutlass::gemm::thread::Mma< Shape_, ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, arch::OpMultiplyAdd, bool >::LayoutC](structcutlass_1_1gemm_1_1thread_1_1Mma_3_01Shape___00_01ElementA___00_01LayoutA___00_01ElementB_e41c1cd6078b6d1347fac239b0639d56.html#a08a2137eb47c1caa00adaf3572c706a0)
LayoutC_ LayoutC
Layout of C matrix (concept: layout::MapFunc)
Definition: gemm/thread/mma_sm50.h:218
Basic include for CUTLASS.
Definition: matrix_coord.h:39
cutlass::gemm::thread::MmaGeneric::ElementB
ElementB_ ElementB
Data type of operand B.
Definition: gemm/thread/mma_sm50.h:77
cutlass::gemm::thread::MmaGeneric::ElementC
ElementC_ ElementC
Element type of operand C.
Definition: gemm/thread/mma_sm50.h:83
Matrix multiply-add operation - specialized for 1x1x1x1 matrix multiply operation.
Definition: arch/mma.h:113
Generated by 1.8.11