docs/reduce__split__k_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
reduce_split_k.h
[Go to the documentation of this file.](reduce_split_k_8h.html)
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/cutlass.h"
32 #include "cutlass/tensor_ref.h"
33 #include "cutlass/numeric_types.h"
34 #include "cutlass/array.h"
35 #include "cutlass/functional.h"
36 #include "cutlass/matrix_shape.h"
37 #include "cutlass/numeric_conversion.h"
38
39 #include "cutlass/layout/matrix.h"
40
42
43 namespace cutlass {
44 namespace reduction {
45 namespace kernel {
46
48
49 template <
50typename Shape_,
51typename OutputOp_ ,
52typename ReductionOp_,
53int PartitionsPerStage = 4
54 >
55 class ReduceSplitK {
56 public:
57
59using ReductionOp = ReductionOp_;
61static int const kElementsPerAccess = OutputOp::kCount;
62static int const kPartitionsPerStage = PartitionsPerStage;
63
64using ElementWorkspace = typename ReductionOp::Element;
65using ElementAccumulator = typename ReductionOp::ElementAccumulator;
66using ElementOutput = typename OutputOp::ElementOutput;
67
68using WorkspaceTensorRef = TensorRef<ElementWorkspace, layout::RowMajor>;
69using OutputTensorRef = TensorRef<ElementOutput, layout::RowMajor>;
70
71using FragmentWorkspace = AlignedArray<ElementWorkspace, kElementsPerAccess>;
72using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
73using FragmentOutput = AlignedArray<ElementOutput, kElementsPerAccess>;
74
75//
76// Types
77//
78
81
83int partitions;
84size_t partition_stride;
85WorkspaceTensorRef workspace;
86OutputTensorRef destination;
88typename OutputOp::Params output;
89typename ReductionOp::Params reduction;
90
91//
92// Methods
93//
94
97
100MatrixCoord problem_size_,
101int partitions_,
102size_t partition_stride_,
103WorkspaceTensorRef workspace_,
104OutputTensorRef destination_,
105OutputTensorRef source_,
106typename OutputOp::Params output_ = typename OutputOp::Params(),
107typename ReductionOp::Params reduction_ = typename ReductionOp::Params()
108 ):
109 problem_size(problem_size_),
110 partitions(partitions_),
111 partition_stride(sizeof(FragmentWorkspace) * partition_stride_ / kElementsPerAccess),
112 workspace(workspace_),
113 destination(destination_),
114 source(source_),
115 output(output_),
116 reduction(reduction_) {
117
118 }
119 };
120
121struct SharedStorage { };
122
123
124 public:
125
128static dim3 grid_shape(
129cutlass::MatrixCoord problem_size) {
130
131return dim3(
132 (problem_size.column() + Shape::kColumn - 1) / Shape::kColumn,
133 (problem_size.row() + Shape::kRow -1) / Shape::kRow);
134 }
135
138static dim3 block_shape() {
139return dim3(Shape::kColumn / kElementsPerAccess, Shape::kRow);
140 }
141
143 CUTLASS_DEVICE
144void operator()(Params const ¶ms, SharedStorage &storage) {
145
146// Determine CTA position
147MatrixCoord thread_offset(
148int(blockIdx.y) * Shape::kRow + threadIdx.y,
149int(blockIdx.x) * Shape::kColumn + threadIdx.x * kElementsPerAccess
150 );
151
152// One guard conditional
153if (!(thread_offset.row() < params.problem_size.row() &&
154 thread_offset.column() < params.problem_size.column())) {
155
156return;
157 }
158
159
160ReductionOp reduction_op(params.reduction);
161
162FragmentAccumulator accumulator;
163
164 accumulator.clear();
165
166//
167// Load the first slice
168//
169
170char const *workspace_ptr =
171reinterpret_cast<char const *>(
172 params.workspace.data() + params.workspace.offset(thread_offset));
173
174FragmentWorkspace workspace_frag[kPartitionsPerStage];
175
176//
177// Construct the output operator
178//
179
180OutputOp output_op(params.output);
181
182//
183// Load and accumulate with a simple batched loading sequence.
184//
185
187for (int k = 0; k < params.partitions; k += kPartitionsPerStage) {
188
190for (int i = 0; i < kPartitionsPerStage; ++i) {
191if (k + i < params.partitions) {
192 workspace_frag[i] = *reinterpret_cast<FragmentWorkspace const *>(workspace_ptr);
193 workspace_ptr += params.partition_stride;
194 }
195 }
196
198for (int i = 0; i < kPartitionsPerStage; ++i) {
199if (k + i < params.partitions) {
200 accumulator = reduction_op(accumulator, workspace_frag[i]);
201 }
202 }
203 }
204
205//
206// Conditionally load the source
207//
208
209FragmentOutput source_frag;
210
211 source_frag.clear();
212
213FragmentOutput const *source_ptr = reinterpret_cast<FragmentOutput const *>(
214 params.source.data() + params.source.offset(thread_offset));
215
216if (output_op.is_source_needed()) {
217reinterpret_cast<FragmentOutput &>(source_frag) = *source_ptr;
218 }
219
220//
221// Compute the output
222//
223
224typename OutputOp::FragmentOutput output_frag = output_op(accumulator, source_frag);
225
226//
227// Store
228//
229
230FragmentOutput *dest_ptr = reinterpret_cast<FragmentOutput *>(
231 params.destination.data() + params.destination.offset(thread_offset));
232
233 *dest_ptr = reinterpret_cast<FragmentOutput const &>(output_frag);
234 }
235 };
236
238
239 } // namespace kernel
240 } // namespace reduction
241 } // namespace cutlass
cutlass::reduction::kernel::ReduceSplitK::ElementWorkspace
typename ReductionOp::Element ElementWorkspace
Definition: reduce_split_k.h:64
cutlass::reduction::kernel::ReduceSplitK::ElementAccumulator
typename ReductionOp::ElementAccumulator ElementAccumulator
Definition: reduce_split_k.h:65
CUTLASS_HOST_DEVICE Index const & column() const
Returns the column of the coordinate.
Definition: matrix_coord.h:85
cutlass::reduction::kernel::ReduceSplitK::Params::source
OutputTensorRef source
Definition: reduce_split_k.h:87
Definition: aligned_buffer.h:35
cutlass::reduction::kernel::ReduceSplitK::Params::destination
OutputTensorRef destination
Definition: reduce_split_k.h:86
Defines a structure containing strides, bounds, and a pointer to tensor data.
cutlass::reduction::kernel::ReduceSplitK::Params::partition_stride
size_t partition_stride
Definition: reduce_split_k.h:84
CUTLASS_HOST_DEVICE Element * data() const
Returns the pointer to referenced data.
Definition: tensor_ref.h:254
cutlass::reduction::kernel::ReduceSplitK::block_shape
static CUTLASS_HOST_DEVICE dim3 block_shape()
Determines the threadblock shape.
Definition: reduce_split_k.h:138
Aligned array type.
Definition: array.h:511
cutlass::reduction::kernel::ReduceSplitK::ElementOutput
typename OutputOp::ElementOutput ElementOutput
Definition: reduce_split_k.h:66
CUTLASS_HOST_DEVICE Index const & row() const
Returns the row of the coordinate.
Definition: matrix_coord.h:77
cutlass::reduction::kernel::ReduceSplitK::Params::partitions
int partitions
Definition: reduce_split_k.h:83
cutlass::reduction::kernel::ReduceSplitK::Params::Params
CUTLASS_HOST_DEVICE Params(MatrixCoord problem_size_, int partitions_, size_t partition_stride_, WorkspaceTensorRef workspace_, OutputTensorRef destination_, OutputTensorRef source_, typename OutputOp::Params output_=typename OutputOp::Params(), typename ReductionOp::Params reduction_=typename ReductionOp::Params())
Definition: reduce_split_k.h:99
cutlass::reduction::kernel::ReduceSplitK::Params::Params
CUTLASS_HOST_DEVICE Params()
Definition: reduce_split_k.h:96
cutlass::reduction::kernel::ReduceSplitK::Params
Params structure.
Definition: reduce_split_k.h:80
cutlass::reduction::kernel::ReduceSplitK::ReductionOp
ReductionOp_ ReductionOp
Definition: reduce_split_k.h:59
cutlass::reduction::kernel::ReduceSplitK::operator()
CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &storage)
Perform a reduction.
Definition: reduce_split_k.h:144
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
cutlass::reduction::kernel::ReduceSplitK::Shape
Shape_ Shape
Definition: reduce_split_k.h:58
Boost-like numeric conversion operator for CUTLASS numeric types.
Defines a Shape template for matrix tiles.
cutlass::reduction::kernel::ReduceSplitK::Params::workspace
WorkspaceTensorRef workspace
Definition: reduce_split_k.h:85
cutlass::TensorRef< ElementWorkspace, layout::RowMajor >
cutlass::reduction::kernel::ReduceSplitK::Params::reduction
ReductionOp::Params reduction
Definition: reduce_split_k.h:89
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
CUTLASS_HOST_DEVICE LongIndex offset(TensorCoord const &coord) const
Computes the offset of an index from the origin of the tensor.
Definition: tensor_ref.h:301
cutlass::reduction::kernel::ReduceSplitK
Definition: reduce_split_k.h:55
#define CUTLASS_PRAGMA_NO_UNROLL
Definition: cutlass.h:111
cutlass::reduction::kernel::ReduceSplitK::kPartitionsPerStage
static int const kPartitionsPerStage
Definition: reduce_split_k.h:62
cutlass::reduction::kernel::ReduceSplitK::grid_shape
static CUTLASS_HOST_DEVICE dim3 grid_shape(cutlass::MatrixCoord problem_size)
Computes the grid size given a chosen threadblock shape.
Definition: reduce_split_k.h:128
cutlass::reduction::kernel::ReduceSplitK::kElementsPerAccess
static int const kElementsPerAccess
Definition: reduce_split_k.h:61
Defines layout functions used by TensorRef and derived classes.
cutlass::reduction::kernel::ReduceSplitK::OutputOp
OutputOp_ OutputOp
Definition: reduce_split_k.h:60
cutlass::reduction::kernel::ReduceSplitK::Params::problem_size
MatrixCoord problem_size
Definition: reduce_split_k.h:82
cutlass::reduction::kernel::ReduceSplitK::FragmentAccumulator
Array< ElementAccumulator, kElementsPerAccess > FragmentAccumulator
Definition: reduce_split_k.h:72
Basic include for CUTLASS.
Definition: matrix_coord.h:39
cutlass::reduction::kernel::ReduceSplitK::SharedStorage
Definition: reduce_split_k.h:121
cutlass::reduction::kernel::ReduceSplitK::Params::output
OutputOp::Params output
Definition: reduce_split_k.h:88
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
Generated by 1.8.11