docs/mma__base_8h_source.html
CUTLASS — CUDA Templates for Linear Algebra Subroutines and Solvers
mma_base.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/aligned_buffer.h"
32 #include "cutlass/arch/memory.h"
33 #include "cutlass/array.h"
34 #include "cutlass/cutlass.h"
35 #include "cutlass/gemm/gemm.h"
36 #include "cutlass/matrix_shape.h"
37 #include "cutlass/numeric_types.h"
39
40 namespace cutlass {
41 namespace gemm {
42 namespace threadblock {
43
45
/// Policy object describing the tuning details of a threadblock-scoped MMA
/// (warp-level operator, shared-memory padding, and K-dimension partitioning).
template <
    /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt)
    typename Operator_,
    /// Padding used for A operand in shared memory (concept: MatrixShape)
    typename SmemPaddingA_,
    /// Padding used for B operand in shared memory (concept: MatrixShape)
    typename SmemPaddingB_,
    /// Number of partitions of the K dimension of the GEMM
    int PartitionsK = 1>
struct MmaPolicy {

  /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt)
  using Operator = Operator_;

  /// Padding used for A operand in shared memory
  using SmemPaddingA = SmemPaddingA_;

  /// Padding used for B operand in shared memory
  using SmemPaddingB = SmemPaddingB_;

  /// Number of partitions of K dimension
  static int const kPartitionsK = PartitionsK;
};
69
71
74 template <
76typename Shape_,
78typename Policy_,
80int Stages,
82typename Enable = bool>
84public:
86using Shape = Shape_;
87
90
91//
92// Dependent types
93//
94
96using Operator = typename Policy::Operator;
97
100using WarpGemm = typename Policy::Operator::Shape;
101
103using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
104 Shape::kN / WarpGemm::kN,
105 Shape::kK / WarpGemm::kK>;
106
108static int const kWarpGemmIterations =
109 (WarpGemm::kK / Operator::Policy::MmaShape::kK);
110
112static int const kStages = Stages;
113
115using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
116
118using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
119
120//
121// Nested structs
122//
123
125class SharedStorage {
126public:
127//
128// Type definitions
129//
130
132using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
133 Shape::kK * kStages +
134 Policy::SmemPaddingA::kColumn>;
135
137using ShapeB =
138MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
139 Shape::kN + Policy::SmemPaddingB::kColumn>;
140
141public:
142//
143// Data members
144//
145
147AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
148
150AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
151
152public:
153
154//
155// Methods
156//
157
159 CUTLASS_DEVICE
160static typename Operator::LayoutA LayoutA() {
161return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
162 }
163
166static typename Operator::LayoutB LayoutB() {
167return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
168 }
169
172TensorRefA operand_A_ref() {
173return TensorRefA{operand_A.data(), LayoutA()};
174 }
175
178TensorRefB operand_B_ref() {
179return TensorRefB{operand_B.data(), LayoutB()};
180 }
181 };
182
183protected:
184
185//
186// Data members
187//
188
190typename Operator::IteratorA warp_tile_iterator_A_;
191
193typename Operator::IteratorB warp_tile_iterator_B_;
194
195 public:
196
198 CUTLASS_DEVICE
201 SharedStorage &shared_storage,
203int thread_idx,
205int warp_idx,
207int lane_idx
208 ):
209 warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
210 warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {
211
212 }
213 };
214
216
217 } // namespace threadblock
218 } // namespace gemm
219 } // namespace cutlass
220
cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, 1 >::Policy
Policy_ Policy
Definition: mma_base.h:89
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
Architecture-specific operators on memory.
cutlass::gemm::threadblock::MmaBase::SharedStorage::operand_B
AlignedBuffer< typename Operator::ElementB, ShapeB::kCount > operand_B
Buffer for B operand.
Definition: mma_base.h:150
cutlass::gemm::threadblock::MmaBase::warp_tile_iterator_B_
Operator::IteratorB warp_tile_iterator_B_
Iterator to load a warp-scoped tile of B operand from shared memory.
Definition: mma_base.h:193
cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, 1 >::WarpGemm
typename Policy::Operator::Shape WarpGemm
Definition: mma_base.h:100
Defines common types used for all GEMM-like operators.
cutlass::gemm::threadblock::MmaBase::SharedStorage
Shared storage object needed by threadblock-scoped GEMM.
Definition: mma_base.h:125
cutlass::gemm::threadblock::MmaBase::Shape
Shape_ Shape
Policy describing tuning details.
Definition: mma_base.h:88
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
cutlass::gemm::threadblock::MmaPolicy::Operator
Operator_ Operator
Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) ...
Definition: mma_base.h:58
cutlass::gemm::threadblock::MmaPolicy::SmemPaddingA
SmemPaddingA_ SmemPaddingA
Padding used for A operand in shared memory.
Definition: mma_base.h:61
Defines a Shape template for matrix tiles.
cutlass::gemm::threadblock::MmaBase::SharedStorage::LayoutB
static CUTLASS_HOST_DEVICE Operator::LayoutB LayoutB()
Returns a layout object for the B matrix.
Definition: mma_base.h:166
cutlass::gemm::threadblock::MmaPolicy
Policy object describing MmaTensorOp.
Definition: mma_base.h:56
Definition: tensor_ref.h:146
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Modifies semantics of cutlass::Array<> to provide guaranteed alignment.
Definition: aligned_buffer.h:45
cutlass::gemm::threadblock::MmaBase::SharedStorage::operand_A_ref
CUTLASS_HOST_DEVICE TensorRefA operand_A_ref()
Returns a TensorRef to the A operand.
Definition: mma_base.h:172
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
CUTLASS_HOST_DEVICE pointer data()
Definition: aligned_buffer.h:84
cutlass::gemm::threadblock::MmaBase::MmaBase
CUTLASS_DEVICE MmaBase(SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Construct from tensor references.
Definition: mma_base.h:199
cutlass::gemm::threadblock::MmaBase
Definition: mma_base.h:83
cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, 1 >::Operator
typename Policy::Operator Operator
Warp-level Mma.
Definition: mma_base.h:96
cutlass::gemm::threadblock::MmaBase::warp_tile_iterator_A_
Operator::IteratorA warp_tile_iterator_A_
Iterator to load a warp-scoped tile of A operand from shared memory.
Definition: mma_base.h:190
cutlass::gemm::threadblock::MmaBase::SharedStorage::operand_A
AlignedBuffer< typename Operator::ElementA, ShapeA::kCount > operand_A
Buffer for A operand.
Definition: mma_base.h:147
cutlass::gemm::threadblock::MmaBase::SharedStorage::LayoutA
static CUTLASS_DEVICE Operator::LayoutA LayoutA()
Returns a layout object for the A matrix.
Definition: mma_base.h:160
cutlass::gemm::threadblock::MmaPolicy::SmemPaddingB
SmemPaddingB_ SmemPaddingB
Padding used for B operand in shared memory.
Definition: mma_base.h:64
cutlass::gemm::threadblock::MmaPolicy::kPartitionsK
static int const kPartitionsK
Number of partitions of K dimension.
Definition: mma_base.h:67
cutlass::gemm::threadblock::MmaBase::SharedStorage::operand_B_ref
CUTLASS_HOST_DEVICE TensorRefB operand_B_ref()
Returns a TensorRef to the B operand.
Definition: mma_base.h:178
Basic include for CUTLASS.
Generated by 1.8.11