docs/wmma__sm72_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
wmma_sm72.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include <assert.h>
32 #include "cutlass/layout/matrix.h"
33
35 namespace cutlass {
36 namespace arch {
37
39 //
40 // WMMA template structure defines nvcuda::wmma::fragments and static assert for
41 // wmma native instruction sizes supported for int8_t
42 //
44 template <
45 typename Shape_,
46 typename LayoutA_,
47 typename LayoutB_,
48 typename LayoutC_>
[49](structcutlass_1_1arch_1_1Wmma_3_01Shape _00_01int8 t_00_01LayoutA _00_01int8 t_00_01LayoutB_505c57bb6818a941dc16f00cf35a9ec0.html) struct Wmma<
50 Shape_,
51 int8_t,
52 LayoutA_,
53 int8_t,
54 LayoutB_,
55 int32_t,
56 LayoutC_,
57cutlass::arch::OpMultiplyAdd
58 > {
59 #if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED)
60using Shape = Shape_;
61using ElementA = int8_t;
62using LayoutA = LayoutA_;
63using ElementB = int8_t;
64using LayoutB = LayoutB_;
65using ElementC = int32_t;
66using LayoutC = LayoutC_;
67using Operator = cutlass::arch::OpMultiplyAdd;
68
69// check supported wmma shape for the given multiplicand data types
71platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||
72platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||
73platform::is_same<cutlass::gemm::GemmShape<32, 8, 16>, Shape>::value,
74"Supported list of wmma operator shape for s8 multiplicands are: 16x16x16, 8x328x16, and 32x8x16");
75
76
77// Wmma Fragment
78using FragmentA = nvcuda::wmma::fragment<
79 nvcuda::wmma::matrix_a,
80 Shape::kM,
81 Shape::kN,
82 Shape::kK,
83typename CutlassToWmmaDataType<ElementA>::Type,
84typename CutlassToWmmaLayout<LayoutA>::Layout>;
85
86using FragmentB = nvcuda::wmma::fragment<
87 nvcuda::wmma::matrix_b,
88 Shape::kM,
89 Shape::kN,
90 Shape::kK,
91typename CutlassToWmmaDataType<ElementB>::Type,
92typename CutlassToWmmaLayout<LayoutB>::Layout>;
93
94using FragmentC = nvcuda::wmma::fragment<
95 nvcuda::wmma::accumulator,
96 Shape::kM,
97 Shape::kN,
98 Shape::kK,
99typename CutlassToWmmaDataType<ElementC>::Type>;
100
102 CUTLASS_DEVICE
103void operator()(
104 FragmentC &D,
105 FragmentA const &A,
106 FragmentB const &B,
107 FragmentC const &C) const {
108
109 nvcuda::wmma::mma_sync(D, A, B, C);
110 }
111
112 #else
113static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond");
114 #endif
115
116 };
117
119 //
120 // WMMA template structure defines nvcuda::wmma::fragments and static assert for
121 // wmma native instruction sizes supported for uint8_t
122 //
124 template <
125 typename Shape_,
126 typename LayoutA_,
127 typename LayoutB_,
128 typename LayoutC_>
[129](structcutlass_1_1arch_1_1Wmma_3_01Shape _00_01uint8 t_00_01LayoutA _00_01uint8 t_00_01Layout219a464a1248ebfc37aa29bcb10cb1b0.html) struct Wmma<
130 Shape_,
131 uint8_t,
132 LayoutA_,
133 uint8_t,
134 LayoutB_,
135 int32_t,
136 LayoutC_,
137cutlass::arch::OpMultiplyAdd
138 > {
139 #if defined(CUTLASS_ARCH_WMMA_SM72_ENABLED)
140using Shape = Shape_;
141using ElementA = uint8_t;
142using LayoutA = LayoutA_;
143using ElementB = uint8_t;
144using LayoutB = LayoutB_;
145using ElementC = int32_t;
146using LayoutC = LayoutC_;
147using Operator = cutlass::arch::OpMultiplyAdd;
148
149// check supported wmma shape for the given multiplicand data types
150static_assert(
151platform::is_same<cutlass::gemm::GemmShape<16, 16, 16>, Shape>::value ||
152platform::is_same<cutlass::gemm::GemmShape< 8, 32, 16>, Shape>::value ||
153platform::is_same<cutlass::gemm::GemmShape<32, 8, 16>, Shape>::value,
154"Supported list of wmma operator shape for u8 multiplicands are: 16x16x16, 8x328x16, and 32x8x16");
155
156// Wmma Fragment
157using FragmentA = nvcuda::wmma::fragment<
158 nvcuda::wmma::matrix_a,
159 Shape::kM,
160 Shape::kN,
161 Shape::kK,
162typename CutlassToWmmaDataType<ElementA>::Type,
163typename CutlassToWmmaLayout<LayoutA>::Layout>;
164
165using FragmentB = nvcuda::wmma::fragment<
166 nvcuda::wmma::matrix_b,
167 Shape::kM,
168 Shape::kN,
169 Shape::kK,
170typename CutlassToWmmaDataType<ElementB>::Type,
171typename CutlassToWmmaLayout<LayoutB>::Layout>;
172
173using FragmentC = nvcuda::wmma::fragment<
174 nvcuda::wmma::accumulator,
175 Shape::kM,
176 Shape::kN,
177 Shape::kK,
178typename CutlassToWmmaDataType<ElementC>::Type>;
179
181 CUTLASS_DEVICE
182void operator()(
183 FragmentC &D,
184 FragmentA const &A,
185 FragmentB const &B,
186 FragmentC const &C) const {
187
188 nvcuda::wmma::mma_sync(D, A, B, C);
189 }
190
191 #else
192static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond");
193 #endif
194
195 };
196
197 } // namespace arch
198 } // namespace cutlass
Definition: aligned_buffer.h:35
std::is_same (false specialization)
Definition: platform.h:394
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
#define static_assert(__e, __m)
Definition: platform.h:153
Defines layout functions used by TensorRef and derived classes.
Generated by 1.8.11