docs/default__thread__map__simt_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
default_thread_map_simt.h
[Go to the documentation of this file.](default__thread__map__simt_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
30 #pragma once
31
32 #include "[predicated_tile_iterator.h](epilogue_2threadblock_2predicated__tile__iterator_8h.html)"
33 #include "cutlass/gemm/gemm.h"
34
36
37 namespace cutlass {
38 namespace epilogue {
39 namespace threadblock {
40
42
44 template <
45typename ThreadblockShape_,
46typename WarpShape_,
47typename MmaSimtPolicy_,
48int PartitionsK,
49typename Element_,
50int ElementsPerAccess
51 >
52 struct DefaultThreadMapSimt {
53
54using ThreadblockShape = ThreadblockShape_;
55using WarpShape = WarpShape_;
56using MmaSimtPolicy = MmaSimtPolicy_;
57static int const kPartitionsK = PartitionsK;
59static int const kElementsPerAccess = ElementsPerAccess;
60
61//
62// Definitions
63//
64
66
67static int const kWarpSize = 32;
68
70 !(ThreadblockShape::kM % WarpShape::kM) &&
71 !(ThreadblockShape::kM % WarpShape::kM), "Divisibility");
72
74using WarpCount = gemm::GemmShape<
75 ThreadblockShape::kM / WarpShape::kM,
76 ThreadblockShape::kN / WarpShape::kN,
77 kPartitionsK
78 >;
79
81static int const kGroupCount =
82 WarpShape::kM / (MmaSimtPolicy::WarpShape::kRow * MmaSimtPolicy::LaneMmaShape::kM);
83
85static int const kThreads = WarpCount::kCount * kWarpSize;
86
88static int const kIterations = MmaSimtPolicy::LaneMmaShape::kM * kGroupCount;
89 };
90
91//
92// ThreadMap
93//
94
96using Type = OutputTileOptimalThreadMap<
97OutputTileShape< // Shape
98 ThreadblockShape::kN,
99 1,
100 MmaSimtPolicy::WarpShape::kRow,
102 1>,
103OutputTileShape< // Count
104 1,
105 MmaSimtPolicy::LaneMmaShape::kM,
107 1,
108Detail::kIterations>,
109Detail::kThreads,
111sizeof_bits<Element>::value
112 >;
113 };
114
116
117 } // namespace threadblock
118 } // namespace epilogue
119 } // namespace cutlass
120
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
cutlass::epilogue::threadblock::OutputTileOptimalThreadMap
Definition: output_tile_thread_map.h:228
Definition: aligned_buffer.h:35
cutlass::epilogue::threadblock::DefaultThreadMapSimt::ThreadblockShape
ThreadblockShape_ ThreadblockShape
Definition: default_thread_map_simt.h:54
cutlass::epilogue::threadblock::DefaultThreadMapSimt::MmaSimtPolicy
MmaSimtPolicy_ MmaSimtPolicy
Definition: default_thread_map_simt.h:56
cutlass::epilogue::threadblock::OutputTileShape
Tuple defining point in output tile.
Definition: output_tile_thread_map.h:57
cutlass::epilogue::threadblock::DefaultThreadMapSimt::Detail::kThreads
static int const kThreads
Number of participating threads.
Definition: default_thread_map_simt.h:85
[predicated_tile_iterator.h](epilogue_2threadblock_2predicated__tile__iterator_8h.html)
Epilogue for threadblock scoped GEMMs using Tensor Ops.
Defines common types used for all GEMM-like operators.
cutlass::gemm::GemmShape::kCount
static int const kCount
Definition: include/cutlass/gemm/gemm.h:67
cutlass::epilogue::threadblock::DefaultThreadMapSimt
Defines the optimal thread map for SIMT accumulator layouts.
Definition: default_thread_map_simt.h:52
Defines the size of an element in bits.
Definition: numeric_types.h:42
cutlass::epilogue::threadblock::DefaultThreadMapSimt::Element
Element_ Element
Definition: default_thread_map_simt.h:58
cutlass::epilogue::threadblock::DefaultThreadMapSimt::kElementsPerAccess
static int const kElementsPerAccess
Definition: default_thread_map_simt.h:59
cutlass::epilogue::threadblock::DefaultThreadMapSimt::Detail::kIterations
static int const kIterations
Number of iterations.
Definition: default_thread_map_simt.h:88
cutlass::epilogue::threadblock::DefaultThreadMapSimt::Detail::kWarpSize
static int const kWarpSize
Definition: default_thread_map_simt.h:67
cutlass::epilogue::threadblock::DefaultThreadMapSimt::kPartitionsK
static int const kPartitionsK
Definition: default_thread_map_simt.h:57
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
#define static_assert(__e, __m)
Definition: platform.h:153
cutlass::epilogue::threadblock::DefaultThreadMapSimt::Detail
Definition: default_thread_map_simt.h:65
cutlass::epilogue::threadblock::DefaultThreadMapSimt::WarpShape
WarpShape_ WarpShape
Definition: default_thread_map_simt.h:55
cutlass::epilogue::threadblock::DefaultThreadMapSimt::Detail::kGroupCount
static int const kGroupCount
Computes the number of thread-level matrix multiplies needed to span a warp.
Definition: default_thread_map_simt.h:81
Generated by 1.8.11