docs/mma__simt_8h_source.html
CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
mma_simt.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/cutlass.h"
32 #include "cutlass/array.h"
33 #include "cutlass/numeric_types.h"
34 #include "cutlass/matrix_shape.h"
35 #include "cutlass/gemm/gemm.h"
36 #include "cutlass/gemm/warp/mma.h"
37
38 #include "cutlass/gemm/thread/mma.h"
39
40 #include "[cutlass/gemm/warp/mma_simt_tile_iterator.h](mma__simt__tile__iterator_8h.html)"
41 #include "[cutlass/gemm/warp/mma_simt_policy.h](mma__simt__policy_8h.html)"
42
44
45 namespace cutlass {
46 namespace gemm {
47 namespace warp {
48
50
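51 /// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
52 template <
54 typename Shape_,
56 typename ElementA_,
58 typename LayoutA_,
60 typename ElementB_,
62 typename LayoutB_,
64 typename ElementC_,
66 typename LayoutC_,
68 typename Policy_,
70 int PartitionsK = 1,
72 typename Enable = bool
73 >
74 class MmaSimt {
75 public: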
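76 /// Shape of warp-level matrix operation (concept: GemmShape)
77 using Shape = Shape_;
78
79 /// Data type of multiplicand A
80 using ElementA = ElementA_;
81
82 /// Layout of multiplicand A
83 using LayoutA = LayoutA_;
84
85 /// Data type of multiplicand B
86 using ElementB = ElementB_;
87
88 /// Layout of multiplicand B
89 using LayoutB = LayoutB_;
90
91 /// Data type of accumulator matrix C
92 using ElementC = ElementC_;
93
94 /// Layout of accumulator matrix C
95 using LayoutC = LayoutC_;
96
97 /// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
98 using Policy = Policy_;
99
100 /// Indicates class of matrix operator
101 using OperatorClass = arch::OpClassSimt;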
102
103 using ThreadLayoutA = typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA >::value,
104 layout::ColumnMajor,
105 typename platform::conditional < platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value,
106 layout::RowMajor,
107 LayoutA>::type
108 >::type;
109
110using ThreadLayoutB = typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutB >::value,
111 layout::ColumnMajor,
112typename platform::conditional < platform::is_same< layout::RowMajorInterleaved<4>, LayoutB >::value,
113 layout::RowMajor,
114LayoutB>::type
115 >::type;
116
117static constexpr bool use_dp4a = (platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA>::value ||
118platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value) &&
119platform::is_same< ElementA, int8_t >::value &&
120platform::is_same< ElementB, int8_t >::value;
121
122using dp4a_type = typename platform::conditional< use_dp4a , int8_t, bool >::type;
123
125using ThreadMma = thread::Mma<
126GemmShape<
127 Shape::kM / Policy::WarpShape::kRow,
128 Shape::kN / Policy::WarpShape::kColumn,
129 Policy::LaneMmaShape::kK>,
130ElementA,
131ThreadLayoutA,
132ElementB,
133ThreadLayoutB,
134ElementC,
135LayoutC,
136 arch::OpMultiplyAdd,
137dp4a_type
138 >;
139
140 public:
141
143using IteratorA = MmaSimtTileIterator<
144MatrixShape<Shape::kM, Policy::LaneMmaShape::kK>,
145Operand::kA,
146ElementA,
147LayoutA,
148Policy,
149 PartitionsK,
150 Shape::kK
151 >;
152
153 /// Storage for A tile
154 using FragmentA = typename IteratorA::Fragment;
155
157using IteratorB = MmaSimtTileIterator<
158MatrixShape<Policy::LaneMmaShape::kK, Shape::kN>,
159Operand::kB,
160ElementB,
161LayoutB,
162Policy,
163 PartitionsK,
164 Shape::kK
165 >;
166
167 /// Storage for B tile
168 using FragmentB = typename IteratorB::Fragment;
169
171using IteratorC = MmaSimtTileIterator<
172MatrixShape<Shape::kM, Shape::kN>,
173Operand::kC,
174ElementC,
175LayoutC,
176 Policy
177 >;
178
179 /// Storage for C tile
180 using FragmentC = typename ThreadMma::FragmentC;
181
182 public:
183
184//
185// Methods
186//
187
188 /// Ctor
189 CUTLASS_DEVICE
190 MmaSimt() {}
191
192 /// Performs a warp-level matrix multiply-accumulate operation
193 CUTLASS_DEVICE
194void operator()(
195FragmentC &d,
196FragmentA const &a,
197FragmentB const &b,
198FragmentC const &c, int group_idx = 0) const {
199
200ThreadMma mma;
201
202 mma(d, a, b, c);
203 }
204 };
205
207
208 } // namespace warp
209 } // namespace gemm
210 } // namespace cutlass
[mma_simt_policy.h](mma__simt__policy_8h.html)
Describes the lane policy used by warp-level matrix multiply operators targeting SIMT instructions...
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
cutlass::gemm::warp::MmaSimt::ElementC
ElementC_ ElementC
Data type of accumulator matrix C.
Definition: mma_simt.h:92
Definition: aligned_buffer.h:35
#define constexpr
Definition: platform.h:137
[mma_simt_tile_iterator.h](mma__simt__tile__iterator_8h.html)
Describes the lane policy used by warp-level matrix multiply operators targeting SIMT instructions...
cutlass::platform::conditional::type
T type
Definition: platform.h:326
std::is_same (false specialization)
Definition: platform.h:394
cutlass::gemm::warp::MmaSimt::FragmentC
typename ThreadMma::FragmentC FragmentC
Storage for C tile.
Definition: mma_simt.h:180
cutlass::gemm::warp::MmaSimt::Shape
Shape_ Shape
Shape of warp-level matrix operation (concept: GemmShape)
Definition: mma_simt.h:77
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Definition: mma_simt.h:74
Defines common types used for all GEMM-like operators.
cutlass::gemm::warp::MmaSimt::operator()
CUTLASS_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c, int group_idx=0) const
Performs a warp-level matrix multiply-accumulate operation.
Definition: mma_simt.h:194
cutlass::gemm::warp::MmaSimt::use_dp4a
static constexpr bool use_dp4a
Definition: mma_simt.h:117
cutlass::gemm::warp::MmaSimt::LayoutC
LayoutC_ LayoutC
Layout of accumulator matrix C.
Definition: mma_simt.h:95
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
B multiplicand.
Templates exposing architecture support for warp-level multiply-add operations.
cutlass::gemm::warp::MmaSimtTileIterator
Definition: mma_simt_tile_iterator.h:69
Defines a Shape template for matrix tiles.
cutlass::gemm::warp::MmaSimt::OperatorClass
arch::OpClassSimt OperatorClass
Indicates class of matrix operator.
Definition: mma_simt.h:101
cutlass::gemm::warp::MmaSimt::ThreadLayoutB
typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved< 4 >, LayoutB >::value, layout::ColumnMajor, typename platform::conditional< platform::is_same< layout::RowMajorInterleaved< 4 >, LayoutB >::value, layout::RowMajor, LayoutB >::type >::type ThreadLayoutB
Definition: mma_simt.h:115
cutlass::gemm::warp::MmaSimt::LayoutA
LayoutA_ LayoutA
Layout of multiplicand A.
Definition: mma_simt.h:83
Templates exposing architecture support for warp-level multiply-add operations.
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
cutlass::platform::conditional
std::conditional (true specialization)
Definition: platform.h:325
cutlass::gemm::warp::MmaSimt::Policy
Policy_ Policy
Shape of the warp in units of thread (concept: MmaLanePolicySimt)
Definition: mma_simt.h:98
cutlass::gemm::warp::MmaSimt::FragmentA
typename IteratorA::Fragment FragmentA
Storage for A tile.
Definition: mma_simt.h:154
cutlass::gemm::warp::MmaSimt::dp4a_type
typename platform::conditional< use_dp4a, int8_t, bool >::type dp4a_type
Definition: mma_simt.h:122
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Structure to compute the matrix product.
Definition: gemm/thread/mma.h:66
cutlass::gemm::warp::MmaSimt::ElementA
ElementA_ ElementA
Data type of multiplicand A.
Definition: mma_simt.h:80
A multiplicand.
cutlass::gemm::warp::MmaSimt::ThreadLayoutA
typename platform::conditional< platform::is_same< layout::ColumnMajorInterleaved< 4 >, LayoutA >::value, layout::ColumnMajor, typename platform::conditional< platform::is_same< layout::RowMajorInterleaved< 4 >, LayoutA >::value, layout::RowMajor, LayoutA >::type >::type ThreadLayoutA
Definition: mma_simt.h:108
cutlass::gemm::warp::MmaSimt::ElementB
ElementB_ ElementB
Data type of multiplicand B.
Definition: mma_simt.h:86
Basic include for CUTLASS.
cutlass::gemm::warp::MmaSimt::LayoutB
LayoutB_ LayoutB
Layout of multiplicand B.
Definition: mma_simt.h:89
cutlass::gemm::warp::MmaSimt::MmaSimt
CUTLASS_DEVICE MmaSimt()
Ctor.
Definition: mma_simt.h:190
cutlass::gemm::warp::MmaSimt::FragmentB
typename IteratorB::Fragment FragmentB
Storage for B tile.
Definition: mma_simt.h:168
Generated by Doxygen 1.8.11