docs/default__gemm__configuration_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
default_gemm_configuration.h
[Go to the documentation of this file.](default__gemm__configuration_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/cutlass.h"
32 #include "cutlass/numeric_types.h"
33 #include "cutlass/arch/arch.h"
34 #include "cutlass/arch/mma.h"
35 #include "cutlass/arch/wmma.h"
36
37 #include "cutlass/gemm/gemm.h"
38 #include "cutlass/epilogue/thread/linear_combination.h"
39 #include "[cutlass/epilogue/thread/linear_combination_clamp.h](linear combination clamp_8h.html)"
40
42
43 namespace cutlass {
44 namespace gemm {
45 namespace device {
46
48
/// Primary template that selects default GEMM parameters from an operator class, an architecture
/// tag, the A/B/C element types, and the accumulator type.
///
/// It is only declared here; the usable configurations come from the partial specializations.
template <typename OperatorClass,
          typename ArchTag,
          typename ElementA,
          typename ElementB,
          typename ElementC,
          typename ElementAccumulator>
struct DefaultGemmConfiguration;
58
60
61 template <
62typename ArchTag,
63typename ElementA,
64typename ElementB,
65typename ElementC,
66typename ElementAccumulator>
67 struct DefaultGemmConfiguration<
68 arch::OpClassSimt,
69 ArchTag,
70 ElementA,
71 ElementB,
72 ElementC,
73 ElementAccumulator> {
74
75static int const kAlignmentA = 1;
76static int const kAlignmentB = 1;
77using ThreadblockShape = GemmShape<128, 128, 8>;
78using WarpShape = GemmShape<32, 64, 8>;
79using InstructionShape = GemmShape<1, 1, 1>;
80static int const kStages = 2;
81
82using EpilogueOutputOp = epilogue::thread::LinearCombination<
83 ElementC,
84 1,
85 ElementAccumulator,
86 ElementAccumulator
87 >;
88
89using Operator = arch::OpMultiplyAdd;
90 };
91
93
94 template <
95typename ArchTag,
96typename ElementC>
97 struct DefaultGemmConfiguration<arch::OpClassSimt, ArchTag, int8_t, int8_t, ElementC, int32_t> {
98
99static int const kAlignmentA = 4;
100static int const kAlignmentB = 4;
101using ThreadblockShape = GemmShape<128, 128, 32>;
102using WarpShape = GemmShape<32, 64, 32>;
103using InstructionShape = GemmShape<1, 1, 4>;
104static int const kStages = 2;
105
106using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
107 ElementC,
108 1,
109 int32_t,
110float
111 >;
112
113using Operator = arch::OpMultiplyAdd;
114 };
115
117
118 template <
119typename ArchTag,
120typename ElementA,
121typename ElementB,
122typename ElementC,
123typename ElementAccumulator>
124 struct DefaultGemmConfiguration<
125 arch::OpClassWmmaTensorOp,
126 ArchTag,
127 ElementA,
128 ElementB,
129 ElementC,
130 ElementAccumulator> {
131
132static int const kAlignmentA = 128 / sizeof_bits<ElementA>::value;
133static int const kAlignmentB = 128 / sizeof_bits<ElementB>::value;
134
135static int const kStages = 2;
136
137using EpilogueOutputOp = epilogue::thread::LinearCombination<
138 ElementC,
139 128 / sizeof_bits<ElementC>::value,
140 ElementAccumulator,
141 ElementAccumulator
142 >;
143
144using Operator = arch::OpMultiplyAdd;
145 };
146
148
149 template <
150typename ElementA,
151typename ElementB,
152typename ElementC,
153typename ElementAccumulator>
154 struct DefaultGemmConfiguration<
155 arch::OpClassTensorOp,
156arch::Sm70,
157 ElementA,
158 ElementB,
159 ElementC,
160 ElementAccumulator> {
161
162static int const kAlignmentA = 128 / sizeof_bits<ElementA>::value;
163static int const kAlignmentB = 128 / sizeof_bits<ElementB>::value;
164
165using ThreadblockShape = GemmShape<128, 256, 32>;
166using WarpShape = GemmShape<64, 64, 32>;
167using InstructionShape = GemmShape<16, 16, 4>;
168static int const kStages = 2;
169
170using EpilogueOutputOp = epilogue::thread::LinearCombination<
171 ElementC,
172 128 / sizeof_bits<ElementC>::value,
173 ElementAccumulator,
174 ElementAccumulator
175 >;
176
177using Operator = arch::OpMultiplyAdd;
178 };
179
181
182 template <
183typename ElementA,
184typename ElementB,
185typename ElementC,
186typename ElementAccumulator>
187 struct DefaultGemmConfiguration<
188 arch::OpClassTensorOp,
189arch::Sm75,
190 ElementA,
191 ElementB,
192 ElementC,
193 ElementAccumulator> {
194
195static int const kAlignmentA = 128 / sizeof_bits<ElementA>::value;
196static int const kAlignmentB = 128 / sizeof_bits<ElementA>::value;
197using ThreadblockShape = GemmShape<128, 256, 32>;
198using WarpShape = GemmShape<64, 64, 32>;
199using InstructionShape = GemmShape<16, 8, 8>;
200static int const kStages = 2;
201
202using EpilogueOutputOp = epilogue::thread::LinearCombination<
203 ElementC,
204 128 / sizeof_bits<ElementC>::value,
205 ElementAccumulator,
206 ElementAccumulator
207 >;
208
209using Operator = typename platform::conditional<
210 (platform::is_same<ElementA, int8_t>::value ||
211platform::is_same<ElementA, int4b_t>::value ||
212platform::is_same<ElementA, uint8_t>::value ||
213platform::is_same<ElementA, uint4b_t>::value),
214 arch::OpMultiplyAddSaturate, arch::OpMultiplyAdd>::type;
215 };
216
218
219 template <
220typename ElementC>
221 struct DefaultGemmConfiguration<
222 arch::OpClassTensorOp,
223arch::Sm75,
224 int8_t,
225 int8_t,
226 ElementC,
227 int32_t> {
228
229static int const kAlignmentA = 128 / sizeof_bits<int8_t>::value;
230static int const kAlignmentB = 128 / sizeof_bits<int8_t>::value;
231
232using ThreadblockShape = GemmShape<128, 256, 64>;
233using WarpShape = GemmShape<64, 64, 64>;
234using InstructionShape = GemmShape<8, 8, 16>;
235static int const kStages = 2;
236
237using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
238 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
239
240using Operator = arch::OpMultiplyAddSaturate;
241 };
242
244
245 template <
246typename ElementC>
247 struct DefaultGemmConfiguration<
248 arch::OpClassTensorOp,
249arch::Sm75,
250 int8_t,
251 uint8_t,
252 ElementC,
253 int32_t> {
254
255static int const kAlignmentA = 128 / sizeof_bits<int8_t>::value;
256static int const kAlignmentB = 128 / sizeof_bits<uint8_t>::value;
257
258using ThreadblockShape = GemmShape<128, 256, 64>;
259using WarpShape = GemmShape<64, 64, 64>;
260using InstructionShape = GemmShape<8, 8, 16>;
261static int const kStages = 2;
262
263using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
264 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
265
266using Operator = arch::OpMultiplyAddSaturate;
267 };
268
270
271 template <
272typename ElementC>
273 struct DefaultGemmConfiguration<
274 arch::OpClassTensorOp,
275arch::Sm75,
276 uint8_t,
277 int8_t,
278 ElementC,
279 int32_t> {
280
281static int const kAlignmentA = 128 / sizeof_bits<uint8_t>::value;
282static int const kAlignmentB = 128 / sizeof_bits<int8_t>::value;
283
284using ThreadblockShape = GemmShape<128, 256, 64>;
285using WarpShape = GemmShape<64, 64, 64>;
286using InstructionShape = GemmShape<8, 8, 16>;
287static int const kStages = 2;
288
289using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
290 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
291
292using Operator = arch::OpMultiplyAddSaturate;
293 };
294
296
297 template <
298typename ElementC>
299 struct DefaultGemmConfiguration<
300 arch::OpClassTensorOp,
301arch::Sm75,
302 uint8_t,
303 uint8_t,
304 ElementC,
305 int32_t> {
306
307static int const kAlignmentA = 128 / sizeof_bits<uint8_t>::value;
308static int const kAlignmentB = 128 / sizeof_bits<uint8_t>::value;
309
310using ThreadblockShape = GemmShape<128, 256, 64>;
311using WarpShape = GemmShape<64, 64, 64>;
312using InstructionShape = GemmShape<8, 8, 16>;
313static int const kStages = 2;
314
315using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
316 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
317
318using Operator = arch::OpMultiplyAddSaturate;
319 };
320
322
323 template <
324typename ElementC>
325 struct DefaultGemmConfiguration<
326 arch::OpClassTensorOp,
327arch::Sm75,
328int4b_t,
329int4b_t,
330 ElementC,
331 int32_t> {
332
333static int const kAlignmentA = 128 / sizeof_bits<int4b_t>::value;
334static int const kAlignmentB = 128 / sizeof_bits<int4b_t>::value;
335
336using ThreadblockShape = GemmShape<128, 256, 128>;
337using WarpShape = GemmShape<64, 64, 128>;
338using InstructionShape = GemmShape<8, 8, 32>;
339static int const kStages = 2;
340
341using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
342 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
343
344using Operator = arch::OpMultiplyAddSaturate;
345 };
346
348
349 template <
350typename ElementC>
351 struct DefaultGemmConfiguration<
352 arch::OpClassTensorOp,
353arch::Sm75,
354int4b_t,
355uint4b_t,
356 ElementC,
357 int32_t> {
358
359static int const kAlignmentA = 128 / sizeof_bits<int4b_t>::value;
360static int const kAlignmentB = 128 / sizeof_bits<uint4b_t>::value;
361
362using ThreadblockShape = GemmShape<128, 256, 128>;
363using WarpShape = GemmShape<64, 64, 128>;
364using InstructionShape = GemmShape<8, 8, 32>;
365static int const kStages = 2;
366
367using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
368 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
369
370using Operator = arch::OpMultiplyAddSaturate;
371 };
372
374
375 template <
376typename ElementC>
377 struct DefaultGemmConfiguration<
378 arch::OpClassTensorOp,
379arch::Sm75,
380uint4b_t,
381int4b_t,
382 ElementC,
383 int32_t> {
384
385static int const kAlignmentA = 128 / sizeof_bits<uint4b_t>::value;
386static int const kAlignmentB = 128 / sizeof_bits<int4b_t>::value;
387
388using ThreadblockShape = GemmShape<128, 256, 128>;
389using WarpShape = GemmShape<64, 64, 128>;
390using InstructionShape = GemmShape<8, 8, 32>;
391static int const kStages = 2;
392
393using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
394 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
395
396using Operator = arch::OpMultiplyAddSaturate;
397 };
398
400
401 template <
402typename ElementC>
403 struct DefaultGemmConfiguration<
404 arch::OpClassTensorOp,
405arch::Sm75,
406uint4b_t,
407uint4b_t,
408 ElementC,
409 int32_t> {
410
411static int const kAlignmentA = 128 / sizeof_bits<uint4b_t>::value;
412static int const kAlignmentB = 128 / sizeof_bits<uint4b_t>::value;
413
414using ThreadblockShape = GemmShape<128, 256, 128>;
415using WarpShape = GemmShape<64, 64, 128>;
416using InstructionShape = GemmShape<8, 8, 32>;
417static int const kStages = 2;
418
419using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
420 ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
421
422using Operator = arch::OpMultiplyAddSaturate;
423 };
424
426 } // namespace device
427 } // namespace gemm
428 } // namespace cutlass
429
Definition: aligned_buffer.h:35
arch::OpMultiplyAddSaturate Operator
Definition: default_gemm_configuration.h:240
cutlass::epilogue::thread::LinearCombination
Definition: linear_combination.h:56
std::is_same (false specialization)
Definition: platform.h:394
arch::OpMultiplyAddSaturate Operator
Definition: default_gemm_configuration.h:422
arch::OpMultiplyAddSaturate Operator
Definition: default_gemm_configuration.h:292
cutlass::epilogue::thread::LinearCombinationClamp
Definition: linear_combination_clamp.h:58
4-bit signed integer type
Definition: integer_subbyte.h:42
[linear_combination_clamp.h](linear__combination__clamp_8h.html)
Functor performing linear scaling operations used by epilogues. Values are clamped before converting ...
Definition: arch.h:46
Defines common types used for all GEMM-like operators.
typename platform::conditional< (platform::is_same< ElementA, int8_t >::value||platform::is_same< ElementA, int4b_t >::value||platform::is_same< ElementA, uint8_t >::value||platform::is_same< ElementA, uint4b_t >::value), arch::OpMultiplyAddSaturate, arch::OpMultiplyAdd >::type Operator
Definition: default_gemm_configuration.h:214
arch::OpMultiplyAddSaturate Operator
Definition: default_gemm_configuration.h:396
Templates exposing architecture support for multiply-add operations.
Definition: arch.h:52
arch::OpMultiplyAdd Operator
Definition: default_gemm_configuration.h:144
Functor performing linear combination operations used by epilogues.
Defines the size of an element in bits.
Definition: numeric_types.h:42
arch::OpMultiplyAdd Operator
Definition: default_gemm_configuration.h:113
arch::OpMultiplyAddSaturate Operator
Definition: default_gemm_configuration.h:266
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
cutlass::platform::conditional
std::conditional (true specialization)
Definition: platform.h:325
arch::OpMultiplyAdd Operator
Definition: default_gemm_configuration.h:177
cutlass::gemm::device::DefaultGemmConfiguration
Definition: default_gemm_configuration.h:57
Defines tags for architecture-specific configurations.
Templates exposing architecture support for warp matrix multiply-add (WMMA) operations.
arch::OpMultiplyAddSaturate Operator
Definition: default_gemm_configuration.h:318
Basic include for CUTLASS.
arch::OpMultiplyAddSaturate Operator
Definition: default_gemm_configuration.h:344
arch::OpMultiplyAddSaturate Operator
Definition: default_gemm_configuration.h:370
arch::OpMultiplyAdd Operator
Definition: default_gemm_configuration.h:89
Generated by 1.8.11