docs/batched__reduction_8h_source.html
| | CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers |
batched_reduction.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #if !defined(__CUDACC_RTC__)
32 #include <cuda.h>
33 #endif
34
35 #include "cutlass/coord.h"
36 #include "cutlass/util/platform.h"
37 #include "cutlass/fragment.h"
38
39 namespace cutlass {
40 namespace reduction {
41
43
/// Kernel entry point for a batched reduction.
///
/// Instantiates the device-side reduction object from the launch-time
/// Params and executes it. __launch_bounds__ pins the block size to
/// Traits::kThreads (with min 1 block per SM) so the compiler can budget
/// registers accordingly.
template <typename batched_reduction_>
__global__ __launch_bounds__(batched_reduction_::Traits::kThreads, 1)
void batched_reduction_kernel(typename batched_reduction_::Params params) {
  // Construct the reduction functor object in-kernel and run it.
  batched_reduction_ op(params);
  op.run();
}
50
/// Batched reduction device object.
///
/// Each thread accumulates Traits::ReductionSize partial tiles of A
/// (spaced tileSize elements apart in d_a), adds them together, combines the
/// result with the corresponding C values through Traits::Functor, and writes
/// the outcome to D. Work is distributed over sub-tiles of the output; each
/// thread owns Traits::ThreadShape::kW consecutive elements per sub-tile.
template <typename BatchedReductionTraits_>
struct BatchedReduction {
  /// This class.
  typedef BatchedReduction<BatchedReductionTraits_> This_;
  /// The traits.
  typedef BatchedReductionTraits_ Traits;
  /// Params.
  typedef typename Traits::Params Params;
  /// The functor type.
  typedef typename Traits::Functor Functor;

  /// Ctor: captures the params by reference and constructs the functor from
  /// its sub-params.
  CUTLASS_DEVICE BatchedReduction(Params const &params_)
      : params(params_), functor(params_.functorParams) {}

  /// The core execution loop. Compiled only for SM60+ targets; on older
  /// architectures the body is empty.
  CUTLASS_DEVICE void run() {
#if (__CUDA_ARCH__ >= 600)
    // Swizzle the IDs of the block.
    typename Traits::BlockSwizzle block_swizzle;
    Coord<3> threadblock_offset =
        block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::SubTile>());

    // Elements covered by the whole grid per outer iteration.
    int subTileSize = gridDim.x * Traits::SubTile::kW;
    // Total elements in one output tile.
    int tileSize = params.problem_size[1] * params.problem_size[2];
    // This thread's offset within the current sub-tile.
    int subTileOffset = threadblock_offset[2] + threadIdx.x * Traits::ThreadShape::kW;

    int subTileBase = 0;

    typename Traits::ScalarA inRegs[Traits::maxInReg];
    typename Traits::ScalarAccum AccumRegs[Traits::maxOutReg];
    // NOTE(review): the trip count of this outer loop is runtime-dependent
    // (tileSize / subTileSize), so the unroll pragma is advisory only.
    #pragma unroll
    for (int subTile = 0; subTile < tileSize; subTile += subTileSize) {
      int tileOffset = subTileBase + subTileOffset;

      // Init AccumRegs.
      #pragma unroll
      for (int i = 0; i < Traits::ThreadShape::kW; i++)
        AccumRegs[i] = static_cast<typename Traits::ScalarAccum>(0.0f);

      // Fetch c0 (the C operand values for this thread's elements).
      typename Traits::ScalarAccum c0[Traits::ThreadShape::kW];
      #pragma unroll
      for (int i = 0; i < Traits::ThreadShape::kW; i++)
        c0[i] = static_cast<typename Traits::ScalarAccum>(params.d_c[tileOffset + i]);

      // Fetch partial sums from A: ReductionSize slices, each tileSize apart.
      #pragma unroll
      for (int s = 0; s < Traits::ReductionSize; s++) {
        int inRegOffset = s * Traits::ThreadShape::kW;
        int dOffset = (s * tileSize) + tileOffset;
        #pragma unroll
        for (int i = 0; i < Traits::ThreadShape::kW; i++) {
          inRegs[inRegOffset + i] = params.d_a[dOffset + i];
        }
      }

      // Accumulate the partial sums in the accumulator precision.
      #pragma unroll
      for (int s = 0; s < Traits::ReductionSize; s++) {
        int inRegOffset = s * Traits::ThreadShape::kW;
        #pragma unroll
        for (int i = 0; i < Traits::ThreadShape::kW; i++) {
          AccumRegs[i] =
              static_cast<typename Traits::ScalarAccum>(inRegs[inRegOffset + i]) + AccumRegs[i];
        }
      }

      // Apply the epilogue functor (e.g. alpha/beta scaling) against c0.
      functor_caller<Traits::ThreadShapeMultiple2>(AccumRegs, c0, AccumRegs);

      // Store AccumRegs to D.
      #pragma unroll
      for (int i = 0; i < Traits::ThreadShape::kW; i++) {
        params.d_d[tileOffset + i] = static_cast<typename Traits::ScalarD>(AccumRegs[i]);
      }

      // Advance sub-tile pointer.
      subTileBase += subTileSize;
    }  // end for loop
#endif  // #if (__CUDA_ARCH__ >= 600)
  }

  /// Dispatches the functor over this thread's elements, two at a time when
  /// ThreadShape::kW is a multiple of 2 (compile-time decision), otherwise
  /// element-by-element. `accum` and `output` may alias (run() passes the
  /// same array for both).
  template <bool ThreadShapeMultiple2>
  CUTLASS_DEVICE void functor_caller(typename Traits::ScalarAccum const *accum,
                                     typename Traits::ScalarAccum const *old,
                                     typename Traits::ScalarAccum *output) {
    // ThreadShapeMultiple2 is a compile-time constant: the dead branch is
    // eliminated by the compiler. (Fixed redundant `== true` comparison.)
    if (ThreadShapeMultiple2) {
      #pragma unroll
      for (int i = 0; i < Traits::ThreadShape::kW / 2; i++) {
        functor.template evaluate<typename Traits::ScalarAccum,
                                  typename Traits::ScalarAccum, 2>(
            &accum[2 * i], &old[2 * i], &output[2 * i]);
      }
    } else {
      #pragma unroll
      for (int i = 0; i < Traits::ThreadShape::kW; i++) {
        functor.template evaluate<typename Traits::ScalarAccum,
                                  typename Traits::ScalarAccum, 1>(
            &accum[i], &old[i], &output[i]);
      }
    }
  }

  //
  // Static function members
  //
#if !defined(__CUDACC_RTC__)
  /// Launch the kernel.
  ///
  /// Derives the grid from the problem size via the traits' block swizzle,
  /// uses Traits::kThreads threads per block, and launches on `stream`.
  /// Returns the launch status via cudaGetLastError().
  static __host__ cudaError_t launch(Params const &params,
                                     cudaStream_t stream = cudaStreamDefault) {
    // Setup the grid.
    typename Traits::BlockSwizzle block_swizzle;
    dim3 grid = block_swizzle.get_grid_layout(
        params.problem_size, make_Coord_from_shape<typename Traits::OutputTile>());

    dim3 block;
    block.x = Traits::kThreads;
    batched_reduction_kernel<This_><<<grid, block, 0, stream>>>(params);
    // Kernel launches do not return errors directly; query the launch status.
    return cudaGetLastError();
  }
#endif

  //
  // Data members
  //

  /// The params.
  Params const &params;
  /// The functor.
  Functor functor;
};
177
178 } // namespace reduction
179 } // namespace cutlass
Definition: aligned_buffer.h:35
cutlass::reduction::BatchedReduction::params
Params const & params
The params.
Definition: batched_reduction.h:173
cutlass::reduction::__launch_bounds__
__global__ __launch_bounds__(batched_reduction_::Traits::kThreads, 1) void batched_reduction_kernel(typename batched_reduction_
Definition: batched_reduction.h:45
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
cutlass::reduction::BatchedReduction::run
CUTLASS_DEVICE void run()
Definition: batched_reduction.h:68
cutlass::reduction::BatchedReduction::This_
BatchedReduction< BatchedReductionTraits_ > This_
This class.
Definition: batched_reduction.h:54
cutlass::reduction::BatchedReduction::functor
Functor functor
Definition: batched_reduction.h:175
cutlass::reduction::BatchedReduction
Definition: batched_reduction.h:52
cutlass::reduction::BatchedReduction::BatchedReduction
CUTLASS_DEVICE BatchedReduction(Params const &params_)
ctor
Definition: batched_reduction.h:63
cutlass::reduction::BatchedReduction::Params
Traits::Params Params
Params.
Definition: batched_reduction.h:58
cutlass::reduction::BatchedReduction::functor_caller
CUTLASS_DEVICE void functor_caller(typename Traits::ScalarAccum const *accum, typename Traits::ScalarAccum const *old, typename Traits::ScalarAccum *output)
Definition: batched_reduction.h:134
cutlass::reduction::BatchedReduction::Functor
Traits::Functor Functor
functor
Definition: batched_reduction.h:60
cutlass::reduction::BatchedReduction::Traits
BatchedReductionTraits_ Traits
The traits.
Definition: batched_reduction.h:56
cutlass::reduction::BatchedReduction::launch
static __host__ cudaError_t launch(Params const &params, cudaStream_t stream=cudaStreamDefault)
Launch the kernel.
Definition: batched_reduction.h:154
Generated by 1.8.11