docs/epilogue__base_8h_source.html
CUTLASS: CUDA Templates for Linear Algebra Subroutines and Solvers
epilogue_base.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
33 #pragma once
34
35 #include <assert.h>
36
37 #include "cutlass/cutlass.h"
38 #include "cutlass/matrix_shape.h"
39 #include "cutlass/numeric_types.h"
40 #include "cutlass/array.h"
41 #include "cutlass/layout/vector.h"
42 #include "cutlass/layout/tensor.h"
43 #include "cutlass/tensor_coord.h"
44 #include "cutlass/aligned_buffer.h"
45
46 #include "cutlass/gemm/gemm.h"
47
48 #include "[cutlass/transform/pitch_linear_thread_map.h](pitch linear thread__map_8h.html)"
49
51
52 namespace cutlass {
53 namespace epilogue {
54 namespace threadblock {
55
57
59 template <
60typename Shape_,
61typename WarpMmaOperator_,
62int PartitionsK,
63typename AccumulatorFragmentIterator_,
64typename WarpTileIterator_,
65typename Padding_
66 >
67 class EpilogueBase {
68 public:
69
71using WarpMmaOperator = WarpMmaOperator_;
72static int const kPartitionsK = PartitionsK;
73using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
74using WarpTileIterator = WarpTileIterator_;
76
78using Layout = layout::RowMajor;
79
81using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
82
84using ElementAccumulator = typename AccumulatorTile::Element;
85
86
88using WarpCount = gemm::GemmShape<
89 Shape::kM / WarpMmaOperator::Shape::kM,
90 Shape::kN / WarpMmaOperator::Shape::kN,
91 kPartitionsK
92 >;
93
94 public:
95
97struct SharedStorage {
98
99//
100// Type definitions
101//
102
104using Element = typename WarpTileIterator::Element;
105
107using TensorRef = typename WarpTileIterator::TensorRef;
108
110using Layout = typename WarpTileIterator::Layout;
111
113using Shape = MatrixShape<
114WarpCount::kM * WarpTileIterator::Shape::kRow * WarpCount::kK,
115WarpCount::kN * WarpTileIterator::Shape::kColumn
116 >;
117
119using StorageShape = MatrixShape<
120Shape::kRow + Padding::kRow,
121Shape::kColumn + Padding::kColumn
122 >;
123
124//
125// Data members
126//
127
128AlignedBuffer<Element, StorageShape::kCount> storage;
129
130//
131// Methods
132//
133
135 CUTLASS_DEVICE
137return storage.data();
138 }
139
141 CUTLASS_DEVICE
143return TensorRef(
144 storage.data(),
145Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
146 }
147
148 CUTLASS_DEVICE
149void debug_print() {
150if (threadIdx.x == 0) {
151
152 #pragma unroll 1
153for (int r = 0; r < Shape::kRow; ++r) {
154
155 #pragma unroll 1
156for (int c = 0; c < Shape::kColumn; ++c) {
157
158 printf("%d ", int(storage.data()[r * StorageShape::kColumn + c]));
159 }
160 printf("\n");
161 }
162 }
163 __syncthreads();
164 }
165 };
166
167 protected:
168
169//
170// Data members
171//
172
173SharedStorage &shared_storage_;
174
176WarpTileIterator warp_tile_iterator_;
177
178 public:
179
181 CUTLASS_DEVICE
183SharedStorage &shared_storage,
184int thread_idx,
185int warp_idx,
186int lane_idx
187 ):
188 shared_storage_(shared_storage),
189 warp_tile_iterator_(shared_storage.reference(), lane_idx) {
190
191// Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
192//
193// _m: the warp's position within the threadblock along the M dimension
194// _n: the warp's position within the threadblock along the N dimension
195// _k: the warp's position within the threadblock along the K dimension
196
197int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN);
198int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);
199int warp_m = warp_mn % WarpCount::kM;
200int warp_n = warp_mn / WarpCount::kM;
201
202MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n};
203
204 warp_tile_iterator_.add_tile_offset(warp_offset);
205 }
206 };
207
209
210 } // namespace threadblock
211 } // namespace epilogue
212 } // namespace cutlass
213
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
cutlass::epilogue::threadblock::EpilogueBase::SharedStorage::debug_print
CUTLASS_DEVICE void debug_print()
Definition: epilogue_base.h:149
static int const kColumn
columns of a matrix
Definition: matrix_shape.h:44
cutlass::epilogue::threadblock::EpilogueBase::warp_tile_iterator_
WarpTileIterator warp_tile_iterator_
Stores a warp's fragment of accumulators to SMEM.
Definition: epilogue_base.h:176
cutlass::epilogue::threadblock::EpilogueBase::shared_storage_
SharedStorage & shared_storage_
Definition: epilogue_base.h:173
[pitch_linear_thread_map.h](pitch__linear__thread__map_8h.html)
Templates implementing how threads are mapped to a given tile.
cutlass::epilogue::threadblock::EpilogueBase::WarpMmaOperator
WarpMmaOperator_ WarpMmaOperator
Definition: epilogue_base.h:71
cutlass::epilogue::threadblock::EpilogueBase::SharedStorage
Shared storage allocation needed by the epilogue.
Definition: epilogue_base.h:97
cutlass::epilogue::threadblock::EpilogueBase::SharedStorage::data
CUTLASS_DEVICE Element * data()
Returns a pointer to the shared memory buffer.
Definition: epilogue_base.h:136
Defines common types used for all GEMM-like operators.
cutlass::epilogue::threadblock::EpilogueBase::ElementAccumulator
typename AccumulatorTile::Element ElementAccumulator
Accumulator element.
Definition: epilogue_base.h:84
cutlass::epilogue::threadblock::EpilogueBase::SharedStorage::TensorRef
typename WarpTileIterator::TensorRef TensorRef
Tensor reference to shared memory allocation.
Definition: epilogue_base.h:107
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
static int const kK
Definition: include/cutlass/gemm/gemm.h:60
Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...
Defines a Shape template for matrix tiles.
cutlass::epilogue::threadblock::EpilogueBase::kPartitionsK
static int const kPartitionsK
Definition: epilogue_base.h:72
cutlass::epilogue::threadblock::EpilogueBase::SharedStorage::Element
typename WarpTileIterator::Element Element
Element type of shared memory.
Definition: epilogue_base.h:104
Defines a canonical coordinate for rank=4 tensors offering named indices.
cutlass::epilogue::threadblock::EpilogueBase::SharedStorage::storage
AlignedBuffer< Element, StorageShape::kCount > storage
Definition: epilogue_base.h:128
static int const kRow
rows of a matrix
Definition: matrix_shape.h:43
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared memory.
Top-level include for all CUTLASS numeric types.
Modifies semantics of cutlass::Array<> to provide guaranteed alignment.
Definition: aligned_buffer.h:45
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
CUTLASS_HOST_DEVICE pointer data()
Definition: aligned_buffer.h:84
cutlass::epilogue::threadblock::EpilogueBase::SharedStorage::Layout
typename WarpTileIterator::Layout Layout
Layout of shared memory allocation.
Definition: epilogue_base.h:110
cutlass::epilogue::threadblock::EpilogueBase::AccumulatorFragmentIterator
AccumulatorFragmentIterator_ AccumulatorFragmentIterator
Definition: epilogue_base.h:73
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Defines layout functions used for rank=1 vectors.
cutlass::epilogue::threadblock::EpilogueBase::Shape
Shape_ Shape
Definition: epilogue_base.h:70
cutlass::epilogue::threadblock::EpilogueBase
Base class for epilogues; defines the warp-level storage of accumulator fragments to shared memory.
Definition: epilogue_base.h:67
cutlass::layout::RowMajor::packed
static CUTLASS_HOST_DEVICE RowMajor packed(MatrixCoord const &extent)
Helper returns a layout to a tightly packed tensor.
Definition: layout/matrix.h:93
cutlass::epilogue::threadblock::EpilogueBase::Padding
Padding_ Padding
Definition: epilogue_base.h:75
cutlass::epilogue::threadblock::EpilogueBase::EpilogueBase
CUTLASS_DEVICE EpilogueBase(SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Constructor.
Definition: epilogue_base.h:182
cutlass::epilogue::threadblock::EpilogueBase::WarpTileIterator
WarpTileIterator_ WarpTileIterator
Definition: epilogue_base.h:74
cutlass::epilogue::threadblock::EpilogueBase::AccumulatorTile
typename AccumulatorFragmentIterator::AccumulatorTile AccumulatorTile
The complete warp-level accumulator tile.
Definition: epilogue_base.h:81
Basic include for CUTLASS.
Definition: matrix_coord.h:39
cutlass::epilogue::threadblock::EpilogueBase::SharedStorage::reference
CUTLASS_DEVICE TensorRef reference()
Returns a tensor reference to the shared memory buffer.
Definition: epilogue_base.h:142
static int const kN
Definition: include/cutlass/gemm/gemm.h:59
Generated by 1.8.11