docs/mma__complex__tensor__op_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
mma_complex_tensor_op.h
[Go to the documentation of this file.](mma__complex__tensor__op_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
30 #pragma once
31
32 #include "cutlass/cutlass.h"
33
34 #include "cutlass/array.h"
35 #include "cutlass/complex.h"
36 #include "cutlass/numeric_types.h"
37 #include "cutlass/matrix_shape.h"
38
39 #include "cutlass/arch/memory_sm75.h"
40 #include "cutlass/arch/mma_sm75.h"
41 #include "cutlass/gemm/gemm.h"
42 #include "cutlass/gemm/warp/mma.h"
43
44 #include "[cutlass/gemm/warp/mma_tensor_op_policy.h](mma__tensor__op__policy_8h.html)"
45 #include "[cutlass/gemm/warp/mma_tensor_op.h](mma__tensor__op_8h.html)"
46
47 #include "[cutlass/gemm/warp/mma_tensor_op_tile_iterator.h](mma__tensor__op__tile__iterator_8h.html)"
49
50 namespace cutlass {
51 namespace gemm {
52 namespace warp {
53
55
56 template <
58typename Shape_,
60typename RealElementA,
62typename LayoutA_,
64typename RealElementB,
66typename LayoutB_,
68typename RealElementC,
70typename LayoutC_,
72typename Policy_,
74ComplexTransform TransformA = ComplexTransform::kNone,
76ComplexTransform TransformB = ComplexTransform::kNone,
78typename Enable = bool
79 >
80 class MmaComplexTensorOp;
81
83
85 template <
87typename Shape_,
89typename RealElementA,
91typename LayoutA_,
93typename RealElementB,
95typename LayoutB_,
97typename RealElementC,
99typename LayoutC_,
101typename Policy_,
103ComplexTransform TransformA,
105ComplexTransform TransformB,
107typename Enable
108 >
109 class MmaComplexTensorOp<
110 Shape_,
111complex<RealElementA>,
112 LayoutA_,
113complex<RealElementB>,
114 LayoutB_,
115complex<RealElementC>,
116 LayoutC_,
117 Policy_,
118 TransformA,
119 TransformB,
120 Enable> {
121 public:
124
126using ElementA = complex<RealElementA>;
127
130
132using ElementB = complex<RealElementB>;
133
136
138using ElementC = complex<RealElementC>;
139
142
145
147static ComplexTransform const kTransformA = TransformA;
148
150static ComplexTransform const kTransformB = TransformB;
151
153using OperatorClass = arch::OpClassTensorOp;
154
156static int const kThreadCount = 32;
157
158 public:
159
161using IteratorA = MmaTensorOpMultiplicandTileIterator<
162MatrixShape<Shape::kM, Shape::kK>,
163Operand::kA,
164ElementA,
165LayoutA,
166MatrixShape<Policy::Operator::Shape::kM, Policy::Operator::Shape::kK>,
167 Policy::OpDelta::kRow,
168 32,
169 1
170 >;
171
173using FragmentA = typename IteratorA::Fragment;
174
176using IteratorB = MmaTensorOpMultiplicandTileIterator<
177MatrixShape<Shape::kK, Shape::kN>,
178Operand::kB,
179ElementB,
180LayoutB,
181MatrixShape<Policy::Operator::Shape::kK, Policy::Operator::Shape::kN>,
182 Policy::OpDelta::kColumn,
183 32,
184 1
185 >;
186
188using FragmentB = typename IteratorB::Fragment;
189
190
191static_assert(
192 !(Shape::kM % Policy::Operator::Shape::kM) &&
193 !(Shape::kN % Policy::Operator::Shape::kN),
194"Shape of warp-level Mma must be divisible by operator shape.");
195
197using MmaIterations = MatrixShape<
198 Shape::kM / Policy::Operator::Shape::kM,
199 Shape::kN / Policy::Operator::Shape::kN
200 >;
201
203using IteratorC = MmaTensorOpAccumulatorTileIterator<
204MatrixShape<Shape::kM, Shape::kN>,
205ElementC,
206LayoutC,
207typename Policy::Operator::Shape,
208typename Policy::OpDelta>;
209
214using FragmentC = typename IteratorC::Fragment;
215
216static_assert(
217 FragmentC::kElements == 2 * MmaIterations::kCount * Policy::Operator::FragmentC::kElements,
218"Unexpected planar complex fragment length.");
219
220 private:
221
222//
223// Data members
224//
225
227typename Policy::Operator mma;
228
229 public:
230
231//
232// Methods
233//
234
236 CUTLASS_DEVICE
238
240 CUTLASS_DEVICE
241void operator()(
242FragmentC &D,
243FragmentA const &A,
244FragmentB const &B,
245FragmentC const &C) const {
246
247// Alias types for underlying real-valued matrix multiply operator
248using MmaOperandA = typename Policy::Operator::FragmentA;
249using MmaOperandB = typename Policy::Operator::FragmentB;
250using MmaOperandC = typename Policy::Operator::FragmentC;
251
252static_assert(MmaOperandA::kElements == 1,
253"This implementation only supports math instructions in which exactly one element is needed for the A operand."
254"We can geneneralize later.");
255
256static_assert(MmaOperandB::kElements == 1,
257"This implementation only supports math instructions in which exactly one element is needed for the A operand."
258"We can geneneralize later.");
259
260 D = C;
261
263for (int m = 0; m < MmaIterations::kRow; ++m) {
264
265// mma(accum.real(), a.real(), b.real(), accum.real());
267for (int n = 0; n < MmaIterations::kColumn; ++n) {
268
269// Pack operands together. This may result in actual MOVs
270 MmaOperandA operand_A;
271 MmaOperandB operand_B;
272
273 operand_A[0] = A[m].real();
274 operand_B[0] = B[n].real();
275
276// Real-valued accumulator part
277 MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +
278 (m + n * MmaIterations::kRow);
279
280 mma(*accum, operand_A, operand_B, *accum);
281 }
282
283// mma(accum.imag(), a.real(), b.imag(), accum.imag());
285for (int n = MmaIterations::kColumn - 1; n >= 0; --n) {
286
287// Pack operands together. This may result in actual MOVs
288 MmaOperandA operand_A;
289 MmaOperandB operand_B;
290
291 operand_A[0] = A[m].real();
292 operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag());
293
294// Complex-valued accumulator part
295 MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +
296 (m + n * MmaIterations::kRow) + MmaIterations::kCount;
297
298 mma(*accum, operand_A, operand_B, *accum);
299 }
300
301// mma(accum.real(), -a.imag(), b.imag(), accum.real())
303for (int n = 0; n < MmaIterations::kColumn; ++n) {
304
305// Pack operands together. This may result in actual MOVs
306 MmaOperandA operand_A;
307 MmaOperandB operand_B;
308
309// A imaginary part is intentionally negated
310 operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? A[m].imag() : -A[m].imag());
311 operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag());
312
313// Complex-valued accumulator part
314 MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +
315 (m + n * MmaIterations::kRow);
316
317 mma(*accum, operand_A, operand_B, *accum);
318 }
319
320// mma(accum.imag(), a.imag(), b.real(), accum.imag())
322for (int n = MmaIterations::kColumn - 1; n >= 0; --n) {
323
324// Pack operands together. This may result in actual MOVs
325 MmaOperandA operand_A;
326 MmaOperandB operand_B;
327
328 operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? -A[m].imag() : A[m].imag());
329 operand_B[0] = B[n].real();
330
331// Real-valued accumulator part
332 MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +
333 (m + n * MmaIterations::kRow) + MmaIterations::kCount;
334
335 mma(*accum, operand_A, operand_B, *accum);
336 }
337 }
338 }
339 };
340
342
343 // TODO - partial specializations of real*complex and complex*real
344
346
347 } // namespace warp
348 } // namespace gemm
349 } // namespace cutlass
350
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
ComplexTransform
Enumerated type describing a transformation on a complex value.
Definition: complex.h:43
Architecture-specific operators on memory added for SM75.
[mma_tensor_op_tile_iterator.h](mma__tensor__op__tile__iterator_8h.html)
Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
typename IteratorA::Fragment FragmentA
Storage for A tile.
Definition: mma_complex_tensor_op.h:173
Defines common types used for all GEMM-like operators.
cutlass::ComplexTransform::kNone
LayoutB_ LayoutB
Layout of multiplicand B.
Definition: mma_complex_tensor_op.h:135
typename IteratorB::Fragment FragmentB
Storage for B tile.
Definition: mma_complex_tensor_op.h:188
cutlass::gemm::warp::MmaComplexTensorOp
Definition: mma_complex_tensor_op.h:80
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Templates exposing architecture support for warp-level multiply-add operations.
Defines a Shape template for matrix tiles.
cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator
Definition: mma_tensor_op_tile_iterator.h:1794
typename IteratorC::Fragment FragmentC
Definition: mma_complex_tensor_op.h:214
Shape_ Shape
Shape of warp-level matrix operation (concept: GemmShape)
Definition: mma_complex_tensor_op.h:123
CUTLASS_DEVICE MmaComplexTensorOp()
Ctor.
Definition: mma_complex_tensor_op.h:237
Top-level include for all CUTLASS numeric types.
cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator
Definition: mma_tensor_op_tile_iterator.h:75
#define static_assert(__e, __m)
Definition: platform.h:153
arch::OpClassTensorOp OperatorClass
Indicates class of matrix operator.
Definition: mma_complex_tensor_op.h:153
Definition: complex.h:92
cutlass::ComplexTransform::kConjugate
Matrix multiply for SM75.
A multiplicand.
Policy_ Policy
Shape of the warp in units of thread (concept: MmaLanePolicySimt)
Definition: mma_complex_tensor_op.h:144
[mma_tensor_op.h](mma__tensor__op_8h.html)
Templates implementing warp-level matrix multiply-accumulate operations targeting Tensor Cores...
Basic include for CUTLASS.
CUTLASS_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C) const
Performs a warp-level matrix multiply-accumulate operation.
Definition: mma_complex_tensor_op.h:241
LayoutC_ LayoutC
Layout of accumulator matrix C.
Definition: mma_complex_tensor_op.h:141
[mma_tensor_op_policy.h](mma__tensor__op__policy_8h.html)
Policy describing implementation details of warp-level GEMM targeting Tensor Cores.
LayoutA_ LayoutA
Layout of multiplicand A.
Definition: mma_complex_tensor_op.h:129
Generated by 1.8.11