Back to Cutlass

CUTLASS: mma_complex_tensor_op.h Source File

docs/mma__complex__tensor__op_8h_source.html

4.4.2 · 28.6 KB
Original Source

| | CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers |

mma_complex_tensor_op.h

[Go to the documentation of this file.](mma__complex__tensor__op_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

30 #pragma once

31

32 #include "cutlass/cutlass.h"

33

34 #include "cutlass/array.h"

35 #include "cutlass/complex.h"

36 #include "cutlass/numeric_types.h"

37 #include "cutlass/matrix_shape.h"

38

39 #include "cutlass/arch/memory_sm75.h"

40 #include "cutlass/arch/mma_sm75.h"

41 #include "cutlass/gemm/gemm.h"

42 #include "cutlass/gemm/warp/mma.h"

43

44 #include "[cutlass/gemm/warp/mma_tensor_op_policy.h](mma tensor op__policy_8h.html)"

45 #include "[cutlass/gemm/warp/mma_tensor_op.h](mma tensor op_8h.html)"

46

47 #include "[cutlass/gemm/warp/mma_tensor_op_tile_iterator.h](mma tensor op tile iterator_8h.html)"

49

50 namespace cutlass {

51 namespace gemm {

52 namespace warp {

53

55

56 template <

58typename Shape_,

60typename RealElementA,

62typename LayoutA_,

64typename RealElementB,

66typename LayoutB_,

68typename RealElementC,

70typename LayoutC_,

72typename Policy_,

74ComplexTransform TransformA = ComplexTransform::kNone,

76ComplexTransform TransformB = ComplexTransform::kNone,

78typename Enable = bool

79 >

80 class MmaComplexTensorOp;

81

83

85 template <

87typename Shape_,

89typename RealElementA,

91typename LayoutA_,

93typename RealElementB,

95typename LayoutB_,

97typename RealElementC,

99typename LayoutC_,

101typename Policy_,

103ComplexTransform TransformA,

105ComplexTransform TransformB,

107typename Enable

108 >

109 class MmaComplexTensorOp<

110 Shape_,

111complex<RealElementA>,

112 LayoutA_,

113complex<RealElementB>,

114 LayoutB_,

115complex<RealElementC>,

116 LayoutC_,

117 Policy_,

118 TransformA,

119 TransformB,

120 Enable> {

121 public:

123using Shape = Shape_;

124

126using ElementA = complex<RealElementA>;

127

129using LayoutA = LayoutA_;

130

132using ElementB = complex<RealElementB>;

133

135using LayoutB = LayoutB_;

136

138using ElementC = complex<RealElementC>;

139

141using LayoutC = LayoutC_;

142

144using Policy = Policy_;

145

147static ComplexTransform const kTransformA = TransformA;

148

150static ComplexTransform const kTransformB = TransformB;

151

153using OperatorClass = arch::OpClassTensorOp;

154

156static int const kThreadCount = 32;

157

158 public:

159

161using IteratorA = MmaTensorOpMultiplicandTileIterator<

162MatrixShape<Shape::kM, Shape::kK>,

163Operand::kA,

164ElementA,

165LayoutA,

166MatrixShape<Policy::Operator::Shape::kM, Policy::Operator::Shape::kK>,

167 Policy::OpDelta::kRow,

168 32,

169 1

170 >;

171

173using FragmentA = typename IteratorA::Fragment;

174

176using IteratorB = MmaTensorOpMultiplicandTileIterator<

177MatrixShape<Shape::kK, Shape::kN>,

178Operand::kB,

179ElementB,

180LayoutB,

181MatrixShape<Policy::Operator::Shape::kK, Policy::Operator::Shape::kN>,

182 Policy::OpDelta::kColumn,

183 32,

184 1

185 >;

186

188using FragmentB = typename IteratorB::Fragment;

189

190

191static_assert(

192 !(Shape::kM % Policy::Operator::Shape::kM) &&

193 !(Shape::kN % Policy::Operator::Shape::kN),

194"Shape of warp-level Mma must be divisible by operator shape.");

195

197using MmaIterations = MatrixShape<

198 Shape::kM / Policy::Operator::Shape::kM,

199 Shape::kN / Policy::Operator::Shape::kN

200 >;

201

203using IteratorC = MmaTensorOpAccumulatorTileIterator<

204MatrixShape<Shape::kM, Shape::kN>,

205ElementC,

206LayoutC,

207typename Policy::Operator::Shape,

208typename Policy::OpDelta>;

209

214using FragmentC = typename IteratorC::Fragment;

215

216static_assert(

217 FragmentC::kElements == 2 * MmaIterations::kCount * Policy::Operator::FragmentC::kElements,

218"Unexpected planar complex fragment length.");

219

220 private:

221

222//

223// Data members

224//

225

227typename Policy::Operator mma;

228

229 public:

230

231//

232// Methods

233//

234

236 CUTLASS_DEVICE

237MmaComplexTensorOp() {}

238

240 CUTLASS_DEVICE

241void operator()(

242FragmentC &D,

243FragmentA const &A,

244FragmentB const &B,

245FragmentC const &C) const {

246

247// Alias types for underlying real-valued matrix multiply operator

248using MmaOperandA = typename Policy::Operator::FragmentA;

249using MmaOperandB = typename Policy::Operator::FragmentB;

250using MmaOperandC = typename Policy::Operator::FragmentC;

251

252static_assert(MmaOperandA::kElements == 1,

253"This implementation only supports math instructions in which exactly one element is needed for the A operand."

254"We can geneneralize later.");

255

256static_assert(MmaOperandB::kElements == 1,

257"This implementation only supports math instructions in which exactly one element is needed for the A operand."

258"We can geneneralize later.");

259

260 D = C;

261

262CUTLASS_PRAGMA_UNROLL

263for (int m = 0; m < MmaIterations::kRow; ++m) {

264

265// mma(accum.real(), a.real(), b.real(), accum.real());

266CUTLASS_PRAGMA_UNROLL

267for (int n = 0; n < MmaIterations::kColumn; ++n) {

268

269// Pack operands together. This may result in actual MOVs

270 MmaOperandA operand_A;

271 MmaOperandB operand_B;

272

273 operand_A[0] = A[m].real();

274 operand_B[0] = B[n].real();

275

276// Real-valued accumulator part

277 MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +

278 (m + n * MmaIterations::kRow);

279

280 mma(*accum, operand_A, operand_B, *accum);

281 }

282

283// mma(accum.imag(), a.real(), b.imag(), accum.imag());

284CUTLASS_PRAGMA_UNROLL

285for (int n = MmaIterations::kColumn - 1; n >= 0; --n) {

286

287// Pack operands together. This may result in actual MOVs

288 MmaOperandA operand_A;

289 MmaOperandB operand_B;

290

291 operand_A[0] = A[m].real();

292 operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag());

293

294// Complex-valued accumulator part

295 MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +

296 (m + n * MmaIterations::kRow) + MmaIterations::kCount;

297

298 mma(*accum, operand_A, operand_B, *accum);

299 }

300

301// mma(accum.real(), -a.imag(), b.imag(), accum.real())

302CUTLASS_PRAGMA_UNROLL

303for (int n = 0; n < MmaIterations::kColumn; ++n) {

304

305// Pack operands together. This may result in actual MOVs

306 MmaOperandA operand_A;

307 MmaOperandB operand_B;

308

309// A imaginary part is intentionally negated

310 operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? A[m].imag() : -A[m].imag());

311 operand_B[0] = (kTransformB == ComplexTransform::kConjugate ? -B[n].imag() : B[n].imag());

312

313// Complex-valued accumulator part

314 MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +

315 (m + n * MmaIterations::kRow);

316

317 mma(*accum, operand_A, operand_B, *accum);

318 }

319

320// mma(accum.imag(), a.imag(), b.real(), accum.imag())

321CUTLASS_PRAGMA_UNROLL

322for (int n = MmaIterations::kColumn - 1; n >= 0; --n) {

323

324// Pack operands together. This may result in actual MOVs

325 MmaOperandA operand_A;

326 MmaOperandB operand_B;

327

328 operand_A[0] = (kTransformA == ComplexTransform::kConjugate ? -A[m].imag() : A[m].imag());

329 operand_B[0] = B[n].real();

330

331// Real-valued accumulator part

332 MmaOperandC *accum = reinterpret_cast<MmaOperandC *>(&D) +

333 (m + n * MmaIterations::kRow) + MmaIterations::kCount;

334

335 mma(*accum, operand_A, operand_B, *accum);

336 }

337 }

338 }

339 };

340

342

343 // TODO - partial specializations of real*complex and complex*real

344

346

347 } // namespace warp

348 } // namespace gemm

349 } // namespace cutlass

350

cutlass::MatrixShape

Describes the size of a matrix tile.

Definition: matrix_shape.h:42

cutlass

Definition: aligned_buffer.h:35

cutlass::ComplexTransform

ComplexTransform

Enumerated type describing a transformation on a complex value.

Definition: complex.h:43

complex.h

memory_sm75.h

Architecture-specific operators on memory added for SM75.

[mma_tensor_op_tile_iterator.h](mma__tensor__op__tile__iterator_8h.html)

Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.

cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::FragmentA

typename IteratorA::Fragment FragmentA

Storage for A tile.

Definition: mma_complex_tensor_op.h:173

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::ComplexTransform::kNone

cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::LayoutB

LayoutB_ LayoutB

Layout of multiplicand B.

Definition: mma_complex_tensor_op.h:135

cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::FragmentB

typename IteratorB::Fragment FragmentB

Storage for B tile.

Definition: mma_complex_tensor_op.h:188

cutlass::gemm::warp::MmaComplexTensorOp

Definition: mma_complex_tensor_op.h:80

array.h

Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...

CUTLASS_PRAGMA_UNROLL

#define CUTLASS_PRAGMA_UNROLL

Definition: cutlass.h:110

mma.h

Templates exposing architecture support for warp-level multiply-add operations.

cutlass::gemm::Operand::kA

matrix_shape.h

Defines a Shape template for matrix tiles.

cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator

Definition: mma_tensor_op_tile_iterator.h:1794

cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::FragmentC

typename IteratorC::Fragment FragmentC

Definition: mma_complex_tensor_op.h:214

cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::Shape

Shape_ Shape

Shape of warp-level matrix operation (concept: GemmShape)

Definition: mma_complex_tensor_op.h:123

cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::MmaComplexTensorOp

CUTLASS_DEVICE MmaComplexTensorOp()

Ctor.

Definition: mma_complex_tensor_op.h:237

numeric_types.h

Top-level include for all CUTLASS numeric types.

cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator

Definition: mma_tensor_op_tile_iterator.h:75

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::OperatorClass

arch::OpClassTensorOp OperatorClass

Indicates class of matrix operator.

Definition: mma_complex_tensor_op.h:153

cutlass::complex

Definition: complex.h:92

cutlass::ComplexTransform::kConjugate

mma_sm75.h

Matrix multiply for SM75.

cutlass::gemm::Operand::kB

B multiplicand.

cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::Policy

Policy_ Policy

Policy describing the underlying warp-level Tensor Core operator (provides Policy::Operator and Policy::OpDelta).

Definition: mma_complex_tensor_op.h:144

[mma_tensor_op.h](mma__tensor__op_8h.html)

Templates implementing warp-level matrix multiply-accumulate operations targeting Tensor Cores...

cutlass.h

Basic include for CUTLASS.

cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::operator()

CUTLASS_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C) const

Performs a warp-level matrix multiply-accumulate operation.

Definition: mma_complex_tensor_op.h:241

cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::LayoutC

LayoutC_ LayoutC

Layout of accumulator matrix C.

Definition: mma_complex_tensor_op.h:141

[mma_tensor_op_policy.h](mma__tensor__op__policy_8h.html)

Policy describing implementation details of warp-level GEMM targeting Tensor Cores.

cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::LayoutA

LayoutA_ LayoutA

Layout of multiplicand A.

Definition: mma_complex_tensor_op.h:129


Generated by 1.8.11