docs/direct__epilogue__tensor__op_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
direct_epilogue_tensor_op.h
[Go to the documentation of this file.](direct__epilogue__tensor__op_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
30 #pragma once
31
32 #include "cutlass/cutlass.h"
33 #include "cutlass/numeric_types.h"
34 #include "cutlass/array.h"
35
36 #include "cutlass/gemm/gemm.h"
37
39
40 namespace cutlass {
41 namespace epilogue {
42 namespace threadblock {
43
45
47 template <
48typename Shape_,
49typename Operator_,
50int PartitionsK,
51typename Element_,
52typename OutputOp_,
53typename ConvertOp_
54 >
55 class DirectEpilogueTensorOp {
56 public:
57
60
62using WarpCount = gemm::GemmShape<
63 Shape::kM / Operator::Shape::kM,
64 Shape::kN / Operator::Shape::kN,
65 PartitionsK,
66 >;
67
68static_assert(PartitionsK == 1,
69"Direct epilogue cannot be used with when the threadblock tile is partitioned along the K dimension.");
70
72using FragmentC = typename Operator::FragmentC;
73
76
78using Layout = layout::RowMajor;
79
82
84using ConvertOp = ConvertOp_;
85
87using TensorRef = TensorRef<Element, Layout::kRank, Layout>;
88
89 public:
90
93
94//
95// Data members
96//
97
100
101typename OutputOp::Params output_op;
102typename ConvertOp::Params convert_op;
103
104//
105// Methods
106//
107
111TensorRef destination_ref_,
112TensorRef source_ref_,
113typename OutputOp::Params output_op_,
114typename ConvertOp::Params convert_op_
115 ):
116 destination_ref(destination_ref_),
117 source_ref(source_ref_),
118 output_op(output_op_),
119 convert_op(convert_op_) {
120
121 }
122
126TensorRef destination_ref_,
127TensorRef source_ref_,
128typename OutputOp::Params output_op_
129 ):
130Params(
131 destination_ref,
132 source_ref,
133 output_op,
135 ) { }
136 };
137
139struct SharedStorage { };
140
141 private:
142
144ConvertOp convert_op;
145
146TensorRef destination_ref_;
147TensorRef source_ref_;
148
149MatrixCoord warp_origin_;
150
151 public:
152
154 CUTLASS_DEVICE
156Params const ¶ms,
157SharedStorage &shared_storage,
158int thread_idx,
159int warp_idx,
160int lane_idx
161 ):
162 output_op(params.output_op),
163 convert_op(params.convert_op),
164 destination_ref_(params.destination_ref),
165 source_ref_(params.source_ref) {
166
167
168// Compute warp location within threadblock tile by mapping the warp_id to three coordinates:
169//
170// _m: the warp's position within the threadblock along the M dimension
171// _n: the warp's position within the threadblock along the N dimension
172// _k: the warp's position within the threadblock along the K dimension
173
174int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);
175int warp_m = warp_mn % WarpCount::kM;
176int warp_n = warp_mn / WarpCount::kM;
177
178 warp_origin_ = MatrixCoord{
179 warp_m * Operator::Shape::kM,
180 warp_n * Operator::Shape::kN
181 };
182
183 destination_ref_.add_coord_offset(warp_origin_);
184 source_ref_.add_coord_offset(warp_origin_);
185 }
186
188 CUTLASS_DEVICE
189void operator()(
190gemm::GemmCoord problem_size,
191gemm::GemmCoord tb_tile_coord,
192FragmentC const &accumulators) {
193
194MatrixCoord thread_origin =
195MatrixCoord{tb_tile_coord.m() * Shape::kM, tb_tile_coord.n() * Shape::kN} + warp_origin_;
196
198using MmaIterations = MatrixShape<
199 Operator::Shape::kM / Operator::Policy::Operator::Shape::kM,
200 Operator::Shape::kN / Operator::Policy::Operator::Shape::kN
201 >;
202
203// Assume accumulator tile is an arrangement of 8-by-8 tiles replicated over the entire
204// shape, with each quad mapped to one row and each thread mapped to 1/4 of the elements
205// of that row. The accumulators within one row are assumed to be consecutive.
206int const kElementsPerAccess = Operator::Policy::Operator::Shape::kN / 4;
207int const kRowsPerTile = 8;
208int const kAccumulatorRows = Operator::Policy::Operator::Shape::kM / kRowsPerTile;
209
211for (int mma_n = 0; mma_n < MmaIterations::kN; ++mma_n) {
213for (int mma_m = 0; mma_m < MmaIterations::kM; ++mma_m) {
214
215int mma_accum_start = kAccumulatorRows * kElementsPerAccess *
216 (mma_m * MmaIterations::kN + mma_n);
217
219for (int row = 0; row < kAccumulatorRows; ++row) {
221for (int col = 0; col < kElementsPerAccess; ++col) {
222
223int accum_m = mma_m * Operator::Policy::Operator::Shape::kM + row * kRowsPerTile;
224int accum_n = mma_n * Operator::Policy::Operator::Shape::kN + col;
225int idx = mma_accum_start + row * kElementsPerAccess + col;
226
227MatrixCoord accum_coord = MatrixCoord{accum_m, accum_n};
228
229MatrixCoord thread_coord = thread_origin + accum_coord;
230
231if (thread_coord < MatrixCoord{problem_size.m(), problem_size.n()}) {
232
233typename ConvertOp::result_type converted_accum = convert_op(accumulators[idx]);
234
235typename OutputOp::result_type output = output_op(converted_accum, source_ref_.at(accum_coord));
236
237 destination_ref_.at(accum_coord) = output;
238 }
239 }
240 }
241 }
242 }
243 }
244 };
245
247
248 } // namespace threadblock
249 } // namespace epilogue
250 } // namespace cutlass
251
cutlass::epilogue::threadblock::DirectEpilogueTensorOp
Epilogue operator.
Definition: direct_epilogue_tensor_op.h:55
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Params::Params
CUTLASS_HOST_DEVICE Params(TensorRef destination_ref_, TensorRef source_ref_, typename OutputOp::Params output_op_)
Constructs a Params object.
Definition: direct_epilogue_tensor_op.h:125
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Params
Parameters structure for host-constructible state.
Definition: direct_epilogue_tensor_op.h:92
Definition: aligned_buffer.h:35
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::operator()
CUTLASS_DEVICE void operator()(gemm::GemmCoord problem_size, gemm::GemmCoord tb_tile_coord, FragmentC const &accumulators)
Streams the result to global memory.
Definition: direct_epilogue_tensor_op.h:189
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Params::destination_ref
TensorRef destination_ref
Definition: direct_epilogue_tensor_op.h:98
Definition: include/cutlass/gemm/gemm.h:94
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::TensorRef
TensorRef< Element, Layout::kRank, Layout > TensorRef
Reference to source and destination tensors.
Definition: direct_epilogue_tensor_op.h:87
cutlass::TensorRef::add_coord_offset
CUTLASS_HOST_DEVICE TensorRef & add_coord_offset(TensorCoord const &coord)
Adds an offset to each pointer.
Definition: tensor_ref.h:326
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Params::source_ref
TensorRef source_ref
Definition: direct_epilogue_tensor_op.h:99
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::DirectEpilogueTensorOp
CUTLASS_DEVICE DirectEpilogueTensorOp(Params const ¶ms, SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Constructor.
Definition: direct_epilogue_tensor_op.h:155
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Params::output_op
OutputOp::Params output_op
Definition: direct_epilogue_tensor_op.h:101
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Params::convert_op
ConvertOp::Params convert_op
Definition: direct_epilogue_tensor_op.h:102
cutlass::TensorRef< Element, Layout::kRank, Layout >
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Params::Params
CUTLASS_HOST_DEVICE Params(TensorRef destination_ref_, TensorRef source_ref_, typename OutputOp::Params output_op_, typename ConvertOp::Params convert_op_)
Constructs a Params object.
Definition: direct_epilogue_tensor_op.h:110
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::SharedStorage
Shared storage allocation needed by the epilogue.
Definition: direct_epilogue_tensor_op.h:139
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
#define static_assert(__e, __m)
Definition: platform.h:153
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::FragmentC
typename Operator::FragmentC FragmentC
Accumulator tile is really the warp-scoped tile.
Definition: direct_epilogue_tensor_op.h:72
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Operator
Operator_ Operator
Definition: direct_epilogue_tensor_op.h:59
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::OutputOp
OutputOp_ OutputOp
Function operator computing final output.
Definition: direct_epilogue_tensor_op.h:81
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::ConvertOp
ConvertOp_ ConvertOp
Conversion operator to shared memory.
Definition: direct_epilogue_tensor_op.h:84
Basic include for CUTLASS.
Definition: matrix_coord.h:39
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Element
Element_ Element
Data type of output tensor.
Definition: direct_epilogue_tensor_op.h:75
cutlass::epilogue::threadblock::DirectEpilogueTensorOp::Shape
Shape_ Shape
Definition: direct_epilogue_tensor_op.h:58
static int const kN
Definition: include/cutlass/gemm/gemm.h:59
Generated by 1.8.11