docs/wmma__sm75_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
wmma_sm75.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include <assert.h>
32 #include "cutlass/layout/matrix.h"
33
35 namespace cutlass {
36 namespace arch {
37
39 //
40 // WMMA template structure defines nvcuda::wmma::fragments and static assert for
41 // wmma native instruction sizes supported for cutlass::int4b_t (experimental::s4).
42 //
44 template <
45 typename Shape_,
46 typename LayoutA_,
47 typename LayoutB_,
48 typename LayoutC_>
[49](structcutlass_1_1arch_1_1Wmma_3_01Shape _00_01cutlass_1_1int4b t_00_01LayoutA___00_01cutlass_16fd808a90b3cf9d7cfc99f30888ca3fe.html) struct Wmma<
50 Shape_,
52 LayoutA_,
54 LayoutB_,
55 int32_t,
56 LayoutC_,
57 cutlass::arch::OpMultiplyAdd
58 > {
59 #if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED)
60using Shape = Shape_;
61using ElementA = cutlass::int4b_t;
62using LayoutA = LayoutA_;
63using ElementB = cutlass::int4b_t;
64using LayoutB = LayoutB_;
65using ElementC = int32_t;
66using LayoutC = LayoutC_;
67using Operator = cutlass::arch::OpMultiplyAdd;
68
69// check supported wmma shape for the given multiplicand data types
71platform::is_same<cutlass::gemm::GemmShape<8, 8, 32>, Shape>::value,
72"Supported list of wmma operator shape for s8 multiplicands is: 8x8x32");
73
74
75// Wmma Fragment
76using FragmentA = nvcuda::wmma::fragment<
77 nvcuda::wmma::matrix_a,
78 Shape::kM,
79 Shape::kN,
80 Shape::kK,
81typename CutlassToWmmaDataType<ElementA>::Type,
82typename CutlassToWmmaLayout<LayoutA>::Layout>;
83
84using FragmentB = nvcuda::wmma::fragment<
85 nvcuda::wmma::matrix_b,
86 Shape::kM,
87 Shape::kN,
88 Shape::kK,
89typename CutlassToWmmaDataType<ElementB>::Type,
90typename CutlassToWmmaLayout<LayoutB>::Layout>;
91
92using FragmentC = nvcuda::wmma::fragment<
93 nvcuda::wmma::accumulator,
94 Shape::kM,
95 Shape::kN,
96 Shape::kK,
97typename CutlassToWmmaDataType<ElementC>::Type>;
98
100 CUTLASS_DEVICE
101void operator()(
102 FragmentC &D,
103 FragmentA const &A,
104 FragmentB const &B,
105 FragmentC const &C) const {
106 nvcuda::wmma::mma_sync(D, A, B, C);
107 }
108
109 #else
110static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond");
111 #endif
112
113 };
114
116 //
117 // WMMA template structure defines nvcuda::wmma::fragments and static assert for
118 // wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1)
119 // (nvcuda::wmma targeting SASS instruction BMMA)
120 //
122 template <
123 typename Shape_,
124 typename LayoutA_,
125 typename LayoutB_,
126 typename LayoutC_>
[127](structcutlass_1_1arch_1_1Wmma_3_01Shape _00_01cutlass_1_1uint1b t_00_01LayoutA___00_01cutlass_c80a7ea4d219cd9b13b560b493338028.html) struct Wmma<
128 Shape_,
130 LayoutA_,
132 LayoutB_,
133 int32_t,
134 LayoutC_,
135 cutlass::arch::OpXorPopc
136 > {
137 #if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED)
138using Shape = Shape_;
139using ElementA = cutlass::uint1b_t;
140using LayoutA = LayoutA_;
141using ElementB = cutlass::uint1b_t;
142using LayoutB = LayoutB_;
143using ElementC = int32_t;
144using LayoutC = LayoutC_;
145using Operator = cutlass::arch::OpXorPopc;
146
147// check supported wmma shape for the given multiplicand data types
148static_assert(
149platform::is_same<cutlass::gemm::GemmShape<8, 8, 128>, Shape>::value,
150"Supported list of wmma operator shape for b1 multiplicands is: 8x8x128");
151
152
153// Wmma Fragment
154using FragmentA = nvcuda::wmma::fragment<
155 nvcuda::wmma::matrix_a,
156 Shape::kM,
157 Shape::kN,
158 Shape::kK,
159typename CutlassToWmmaDataType<ElementA>::Type,
160typename CutlassToWmmaLayout<LayoutA>::Layout>;
161
162using FragmentB = nvcuda::wmma::fragment<
163 nvcuda::wmma::matrix_b,
164 Shape::kM,
165 Shape::kN,
166 Shape::kK,
167typename CutlassToWmmaDataType<ElementB>::Type,
168typename CutlassToWmmaLayout<LayoutB>::Layout>;
169
170using FragmentC = nvcuda::wmma::fragment<
171 nvcuda::wmma::accumulator,
172 Shape::kM,
173 Shape::kN,
174 Shape::kK,
175typename CutlassToWmmaDataType<ElementC>::Type>;
176
178 CUTLASS_DEVICE
179void operator()(
180 FragmentC &D,
181 FragmentA const &A,
182 FragmentB const &B,
183 FragmentC const &C) const {
184
185 nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
186 nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
187 }
188
189 #else
190static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond");
191 #endif
192
193 };
194
195 } // namespace arch
196 } // namespace cutlass
Definition: aligned_buffer.h:35
std::is_same (false specialization)
Definition: platform.h:394
integer_subbyte< 1, false > uint1b_t
1-bit Unsigned integer type
Definition: integer_subbyte.h:152
4-bit signed integer type
Definition: integer_subbyte.h:42
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
#define static_assert(__e, __m)
Definition: platform.h:153
Defines layout functions used by TensorRef and derived classes.
integer_subbyte< 4, true > int4b_t
4-bit Integer type
Definition: integer_subbyte.h:155
Generated by 1.8.11