docs/wmma_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
wmma.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 // CUTLASS WMMA does not support clang at present.
// Consequently none of the feature macros below are defined when __clang__ is defined.
32 #if !defined(__clang__)
33
// Baseline WMMA: requires __CUDACC_VER_MAJOR__ >= 9 and either no __CUDA_ARCH__
// definition or __CUDA_ARCH__ >= 700. CUTLASS_ARCH_WMMA_ENABLED gates the remainder
// of this header; CUTLASS_ARCH_WMMA_SM70_ENABLED selects the wmma_sm70.h include.
34 #if (__CUDACC_VER_MAJOR__ >= 9)
35 #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700))
36 #define CUTLASS_ARCH_WMMA_ENABLED
37 #define CUTLASS_ARCH_WMMA_SM70_ENABLED
38 #endif
39 #endif
40
// Integer matrix multiply: requires __CUDACC_VER_MAJOR__ >= 10 and either no
// __CUDA_ARCH__ definition or __CUDA_ARCH__ >= 720. CUTLASS_ARCH_WMMA_SM72_ENABLED
// selects the wmma_sm72.h include.
41 #if (__CUDACC_VER_MAJOR__ >= 10)
42 #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 720))
43 #define CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED
44 #define CUTLASS_ARCH_WMMA_SM72_ENABLED
45 #endif
46 #endif
47
// Sub-byte integer matrix multiply: requires __CUDACC_VER_MAJOR__ >= 10 and either no
// __CUDA_ARCH__ definition or __CUDA_ARCH__ >= 750. The SUBBYTE macro also guards the
// int4b_t / uint4b_t / uint1b_t type mappings further down in this header;
// CUTLASS_ARCH_WMMA_SM75_ENABLED selects the wmma_sm75.h include.
48 #if (__CUDACC_VER_MAJOR__ >= 10)
49 #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750))
50 #define CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED
51 #define CUTLASS_ARCH_WMMA_SM75_ENABLED
52 #endif
53 #endif
54
55 #endif //__clang__
56
57 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
58
59 #include <mma.h>
60 #include "cutlass/arch/mma.h"
61 #include "cutlass/array.h"
62 #include "cutlass/numeric_types.h"
63 #include "cutlass/gemm/gemm.h"
64
65
67
68 namespace cutlass {
69 namespace arch {
70
/// Tags the kind of memory in which data resides (shared or global).
74 enum class MemoryKind {
75 kShared, // Data resides in shared memory
76 kGlobal // Data resides in global memory
77 };
78
79
/// Compile-time constants describing how a warp is subdivided.
/// The values are consistent with each other:
/// kQuadsPerWarp * kThreadsPerQuad == kThreadsPerWarp (8 * 4 == 32).
83 struct WarpParams {
// Number of threads in one warp.
84static int const kThreadsPerWarp = 32;
// Number of four-thread groups ("quads") in one warp.
85static int const kQuadsPerWarp = 8;
// Number of threads in one quad.
86static int const kThreadsPerQuad = 4;
87 };
88
/// Type trait mapping a CUTLASS element type to the type used with nvcuda::wmma.
/// The primary template is the identity mapping: Type is Type_ itself. Element types
/// that need a different WMMA-side type are handled by explicit specializations.
92 template <typename Type_>
93 struct CutlassToWmmaDataType{
94using Type = Type_;
95 };
96
/// Specialization: cutlass::half_t maps to __half.
98 template<>
99 struct CutlassToWmmaDataType<cutlass::half_t> {
100using Type = __half;
101 };
102
103
/// Specialization: int8_t maps to signed char.
105 template<>
106 struct CutlassToWmmaDataType<int8_t> {
107using Type = signed char;
108 };
109
/// Specialization: uint8_t maps to unsigned char.
111 template<>
112 struct CutlassToWmmaDataType<uint8_t> {
113using Type = unsigned char;
114 };
115
/// Specialization: int32_t maps to int.
117 template<>
118 struct CutlassToWmmaDataType<int32_t> {
119using Type = int;
120 };
121
// Sub-byte mappings are compiled only when CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED
// is defined; each maps to a tag type in nvcuda::wmma::experimental::precision.
122 #if defined(CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED)
123 template<>
/// Specialization: cutlass::int4b_t (4-bit integer) maps to precision::s4.
125 struct CutlassToWmmaDataType<cutlass::int4b_t> {
126using Type = nvcuda::wmma::experimental::precision::s4;
127 };
128
/// Specialization: cutlass::uint4b_t (4-bit unsigned integer) maps to precision::u4.
130 template<>
131 struct CutlassToWmmaDataType<cutlass::uint4b_t> {
132using Type = nvcuda::wmma::experimental::precision::u4;
133 };
134
/// Specialization: cutlass::uint1b_t (1-bit unsigned integer) maps to precision::b1.
136 template<>
137 struct CutlassToWmmaDataType<cutlass::uint1b_t> {
138using Type = nvcuda::wmma::experimental::precision::b1;
139 };
140 #endif
141
/// Type trait mapping a CUTLASS layout tag to its nvcuda::wmma counterpart.
/// The primary template is intentionally empty: it exposes neither `Layout` nor `value`,
/// so only layouts with an explicit specialization can be used through this trait.
145 template <typename Layout_>
146 struct CutlassToWmmaLayout {
147 };
148
/// Specialization for cutlass::layout::RowMajor.
/// `Layout` is the tag type nvcuda::wmma::row_major; `value` is the matching
/// nvcuda::wmma::layout_t enumerator, mem_row_major.
150 template <>
151 struct CutlassToWmmaLayout<cutlass::layout::RowMajor> {
152using Layout = nvcuda::wmma::row_major;
153static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_row_major;
154 };
155
/// Specialization for cutlass::layout::ColumnMajor.
/// `Layout` is the tag type nvcuda::wmma::col_major; `value` is the matching
/// nvcuda::wmma::layout_t enumerator, mem_col_major.
159 template <>
160 struct CutlassToWmmaLayout<cutlass::layout::ColumnMajor> {
161using Layout = nvcuda::wmma::col_major;
162static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_col_major;
163 };
165
/// Type trait mapping a type used with nvcuda::wmma back to a CUTLASS element type.
/// The primary template is the identity mapping: Type is Type_ itself. Types that
/// need a different CUTLASS-side type are handled by explicit specializations.
169 template <typename Type_>
170 struct WmmaToCutlassDataType{
171using Type = Type_;
172 };
173
/// Specialization: __half maps to cutlass::half_t.
175 template<>
176 struct WmmaToCutlassDataType<__half> {
177using Type = cutlass::half_t;
178 };
180
182 // WMMA template structure defines nvcuda::wmma::fragments and static assertion checks
183 // for a specific template parameterized data type (Element[A|B|C]), layout (Layout[A|B|C]),
184 // and native wmma size (Shape)
/// Primary Wmma template. It is only declared here, not defined.
/// NOTE(review): definitions are presumably supplied as specializations by the
/// per-architecture headers (wmma_sm70.h / wmma_sm72.h / wmma_sm75.h) included at the
/// end of this file -- confirm against those headers.
///
/// Template parameters:
///   Shape_                 - native wmma size
///   ElementA_ / LayoutA_   - data type and layout of operand A
///   ElementB_ / LayoutB_   - data type and layout of operand B
///   ElementC_ / LayoutC_   - data type and layout of operand C
///   Operator_              - operation tag; defaults to cutlass::arch::OpMultiplyAdd
186 template <
187typename Shape_,
188typename ElementA_,
189typename LayoutA_,
190typename ElementB_,
191typename LayoutB_,
192typename ElementC_,
193typename LayoutC_,
194typename Operator_ = cutlass::arch::OpMultiplyAdd
195 >
196 struct Wmma;
198
199
200 } // namespace arch
201 } // namespace cutlass
202
204
205 //
206 // Specializations for each compute capability
207 //
208 #ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED
209 #include "cutlass/arch/wmma_sm70.h"
210 #endif
211
212 #ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED
213 #include "cutlass/arch/wmma_sm72.h"
214 #endif
215
216 #ifdef CUTLASS_ARCH_WMMA_SM75_ENABLED
217 #include "cutlass/arch/wmma_sm75.h"
218 #endif
219
221
222 #endif //CUTLASS_ARCH_WMMA_ENABLED
integer_subbyte< 4, false > uint4b_t
4-bit Unsigned integer type
Definition: integer_subbyte.h:158
Definition: aligned_buffer.h:35
Matrix multiply.
integer_subbyte< 1, false > uint1b_t
1-bit Unsigned integer type
Definition: integer_subbyte.h:152
Matrix multiply.
IEEE half-precision floating-point type.
Definition: half.h:126
Defines common types used for all GEMM-like operators.
Matrix multiply.
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
Templates exposing architecture support for multiply-add operations.
Top-level include for all CUTLASS numeric types.
integer_subbyte< 4, true > int4b_t
4-bit Integer type
Definition: integer_subbyte.h:155
Generated by Doxygen 1.8.11