docs/wmma__sm70_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
wmma_sm70.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include <assert.h>
32 #include "cutlass/layout/matrix.h"
33
35 namespace cutlass {
36 namespace arch {
37
38
40 //
41 // WMMA template structure defines nvcuda::wmma::fragments and static assert for
42 // wmma native instruction sizes supported for half
43 //
45 template <
46 typename Shape_,
47 typename LayoutA_,
48 typename LayoutB_,
49 typename ElementC_,
50 typename LayoutC_>
[51](structcutlass_1_1arch_1_1Wmma_3_01Shape _00_01cutlass_1_1half t_00_01LayoutA___00_01cutlass_1_84e30c8cc93eeb7ca02f651bd16d4c38.html) struct Wmma<
52 Shape_,
54 LayoutA_,
56 LayoutB_,
57 ElementC_,
58 LayoutC_,
59 cutlass::arch::OpMultiplyAdd
60 > {
61
62 #if defined(CUTLASS_ARCH_WMMA_SM70_ENABLED)
63using Shape = Shape_;
64using ElementA = cutlass::half_t;
65using LayoutA = LayoutA_;
66using ElementB = cutlass::half_t;
67using LayoutB = LayoutB_;
68using ElementC = ElementC_;
69using LayoutC = LayoutC_;
70using Operator = cutlass::arch::OpMultiplyAdd;
71
72// check supported wmma shape for the given multiplicand data types
74platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||
75platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||
76platform::is_same<cutlass::gemm::GemmShape<32, 8, 16>, Shape>::value,
77"Supported list of wmma operator shape for f16 multiplicands are: 16x16x16, 8x328x16, and 32x8x16");
78
79// check supported wmma output data type for the given multiplicand data types
81platform::is_same<cutlass::half_t, ElementC>::value || platform::is_same<float, ElementC>::value,
82"Supported of wmma output data type for f16 multiplicands are: f16 and f32");
83
84// Wmma Fragment
85using FragmentA = nvcuda::wmma::fragment<
86 nvcuda::wmma::matrix_a,
87 Shape::kM,
88 Shape::kN,
89 Shape::kK,
90typename CutlassToWmmaDataType<ElementA>::Type,
91typename CutlassToWmmaLayout<LayoutA>::Layout>;
92
93using FragmentB = nvcuda::wmma::fragment<
94 nvcuda::wmma::matrix_b,
95 Shape::kM,
96 Shape::kN,
97 Shape::kK,
98typename CutlassToWmmaDataType<ElementB>::Type,
99typename CutlassToWmmaLayout<LayoutB>::Layout>;
100
101using FragmentC = nvcuda::wmma::fragment<
102 nvcuda::wmma::accumulator,
103 Shape::kM,
104 Shape::kN,
105 Shape::kK,
106typename CutlassToWmmaDataType<ElementC>::Type>;
107
109 CUTLASS_DEVICE
110void operator()(
111 FragmentC &D,
112 FragmentA const &A,
113 FragmentB const &B,
114 FragmentC const &C) const {
115
116 nvcuda::wmma::mma_sync(D, A, B, C);
117 }
118 #else
119static_assert(false, "wmma.mma.sync for floating point multiplicands is avialable only for SM70 and beyond");
120 #endif
121
122 };
123
124 } // namespace arch
125 } // namespace cutlass
Definition: aligned_buffer.h:35
std::is_same (false specialization)
Definition: platform.h:394
IEEE half-precision floating-point type.
Definition: half.h:126
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
#define static_assert(__e, __m)
Definition: platform.h:153
Defines layout functions used by TensorRef and derived classes.
Generated by 1.8.11