gemv_batched_strided.h
[Go to the documentation of this file.](gemv__batched__strided_8h.html)
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"

#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/gemm/gemm.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

namespace detail
{
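  // Device-side functor applying the GEMV epilogue scaling element-wise over
  // matching fragments: D = alpha * accumulators (+ beta * C unless BetaIsZero).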
  template<typename ElementAlphaBeta, bool BetaIsZero>
  struct GemvBatchedStridedEpilogueScaling
  {
    ElementAlphaBeta const & alpha;
    ElementAlphaBeta const & beta;

    CUTLASS_DEVICE
    GemvBatchedStridedEpilogueScaling(ElementAlphaBeta& alpha_, ElementAlphaBeta& beta_) :
      alpha(alpha_), beta(beta_)
    { }

    template<typename FragmentCD, typename FragmentAccumulator>
    CUTLASS_DEVICE
    void operator()(FragmentAccumulator& accumulators,
                    FragmentCD const& fragment_C,
                    FragmentCD& fragment_D) const
    {
      using AccType = typename FragmentAccumulator::value_type;
      using CDType = typename FragmentCD::value_type;

      static_assert(FragmentCD::kElements == FragmentAccumulator::kElements,
                    "Mismatch in fragment sizes.");

      for (int i = 0; i < FragmentCD::kElements; ++i)
      {
        if (BetaIsZero)
        {
          fragment_D[i] = CDType(accumulators[i] * AccType(alpha));
        }
        else
        {
          fragment_D[i] = CDType(accumulators[i] * AccType(alpha)
                                 + AccType(fragment_C[i]) * AccType(beta));
        }
      }
    }
  };
} // namespace detail

/////////////////////////////////////////////////////////////////////////////////////////////////

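// Device-side body of the batched strided GEMV. Each threadblock computes one
// tile of one batch entry: D = alpha * A * B + beta * C, where lda/ldb/ldc/ldd
// are the strides (in elements) between consecutive batch entries.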
template <typename GemvKernel, typename ElementAlphaBeta, bool BetaIsZero=false>
CUTLASS_DEVICE void GemvBatchedStridedDevice(
  cutlass::gemm::BatchedGemmCoord problem_size,
  ElementAlphaBeta alpha,
  ElementAlphaBeta beta,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_C,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  using ThreadBlockGemv = typename GemvKernel::ThreadBlockGemv;
  using ThreadBlockSwizzle = typename GemvKernel::ThreadBlockSwizzle;
  using EpilogueScale = detail::GemvBatchedStridedEpilogueScaling<ElementAlphaBeta, BetaIsZero>;

  ThreadBlockSwizzle swizzler;

  // Compute initial location in logical coordinates
  BatchedGemmCoord tb_offset = swizzler.get_tile_offset();
  int const batch_idx = swizzler.get_batch_idx();

  // Offset to the batch
  ref_A.add_pointer_offset(batch_idx*lda);
  ref_B.add_pointer_offset(batch_idx*ldb);

  // Construct iterators to A and B operands
  typename GemvKernel::IteratorA::Params params_A(ref_A.layout());
  typename GemvKernel::IteratorA iterator_A(
    params_A,
    ref_A.data(),
    { 1, problem_size.k() },
    0,
    { 0, 0 });

  typename GemvKernel::IteratorB::Params params_B(ref_B.layout());
  typename GemvKernel::IteratorB iterator_B(
    params_B,
    ref_B.data(),
    { problem_size.k(), problem_size.n() },
    threadIdx.x,
    { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN });

  //
  // Main loop
  //

  // Construct thread-scoped matrix multiply
  ThreadBlockGemv mma;

  typename ThreadBlockGemv::FragmentC accumulators;
  accumulators.clear();

  // Compute threadblock-scoped gemv
  mma(problem_size.mnk(), accumulators, iterator_A, iterator_B, accumulators);
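  // Note: the accumulator fragment is passed as both the source (C operand)
  // and the destination of the threadblock-level multiply-accumulate.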

  //
  // Epilogue (TODO: Epilogue as template argument)
  //
  typename GemvKernel::FragmentCD fragment_CD;

  // Load C (skip if beta is zero)
  if (!BetaIsZero)
  {
    tb_offset = swizzler.get_tile_offset();
    ref_C.add_pointer_offset(batch_idx*ldc);
    typename GemvKernel::IteratorCD::Params params_C(ref_C.layout());
    typename GemvKernel::IteratorCD iterator_C(
      params_C,
      ref_C.data(),
      { 1, problem_size.n() },
      threadIdx.x,
      { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN });
    iterator_C.load(fragment_CD);
  }

  // Apply alpha/beta scaling
  EpilogueScale epilogue_scale(alpha, beta);
  epilogue_scale(accumulators, fragment_CD, fragment_CD);

  // Store D
  tb_offset = swizzler.get_tile_offset();
  ref_D.add_pointer_offset(batch_idx*ldd);
  typename GemvKernel::IteratorCD::Params params_D(ref_D.layout());
  typename GemvKernel::IteratorCD iterator_D(
    params_D,
    ref_D.data(),
    { 1, problem_size.n() },
    threadIdx.x,
    { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN });
  iterator_D.store(fragment_CD);
}

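// Kernel entry point computing D = alpha * A * B + beta * C for every batch
// entry, with compile-time selection of whether beta is zero.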
template <typename GemvKernel, typename ElementAlphaBeta, bool BetaIsZero>
__global__ void GemvBatchedStrided(
  cutlass::gemm::BatchedGemmCoord problem_size,
  ElementAlphaBeta alpha,
  ElementAlphaBeta beta,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_C,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  GemvBatchedStridedDevice<GemvKernel, ElementAlphaBeta, BetaIsZero>(
    problem_size, alpha, beta, ref_A, lda, ref_B, ldb, ref_C, ldc, ref_D, ldd
  );
}

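// Overload with beta fixed to zero at compile time: C is never read, and D
// doubles as the (unused) C operand.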
template <typename GemvKernel, typename ElementAlphaBeta>
__global__ void GemvBatchedStrided(
  cutlass::gemm::BatchedGemmCoord problem_size,
  ElementAlphaBeta alpha,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  GemvBatchedStridedDevice<GemvKernel, ElementAlphaBeta, true>(
    problem_size, alpha, ElementAlphaBeta(0), ref_A, lda, ref_B, ldb, ref_D, ldd, ref_D, ldd
  );
}

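// Overload with alpha fixed to one and beta fixed to zero: D = A * B.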
template <typename GemvKernel>
__global__ void GemvBatchedStrided(
  cutlass::gemm::BatchedGemmCoord problem_size,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  using ElementAlphaBeta = typename GemvKernel::IteratorCD::Element;
  GemvBatchedStridedDevice<GemvKernel, ElementAlphaBeta, true>(
    problem_size, ElementAlphaBeta(1), ElementAlphaBeta(0), ref_A, lda, ref_B, ldb, ref_D, ldd, ref_D, ldd
  );
}

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass
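
A minimal launch sketch, not part of the header: it assumes a concrete GemvKernel type supplied by the caller (providing the nested IteratorA/IteratorB/IteratorCD, ThreadBlockGemv, and ThreadBlockSwizzle types this header expects), and a caller-derived grid/block configuration; launch_gemv_batched is a hypothetical helper, and the BatchedGemmCoord argument order (m, n, k, batch) is as used elsewhere in this file.

// Hypothetical usage sketch (compile with nvcc in a .cu file).
#include "cutlass/gemm/gemm.h"
#include "gemv_batched_strided.h"

template <typename GemvKernel>
void launch_gemv_batched(
    typename GemvKernel::IteratorA::TensorRef ref_A,  // A: 1 x K row vector per batch
    typename GemvKernel::IteratorB::TensorRef ref_B,  // B: K x N matrix per batch
    typename GemvKernel::IteratorCD::TensorRef ref_D, // D: 1 x N row vector per batch
    int n, int k, int batch_count,
    long long lda, long long ldb, long long ldd,      // batch strides, in elements
    dim3 grid, dim3 block)                            // derived from ThreadBlockSwizzle
{
  // m is fixed to 1: the kernel constructs iterator A with extent {1, K}.
  cutlass::gemm::BatchedGemmCoord problem_size(1, n, k, batch_count);

  // alpha = 1, beta = 0 overload: computes D = A * B for each batch entry.
  cutlass::gemm::kernel::GemvBatchedStrided<GemvKernel>
      <<<grid, block>>>(problem_size, ref_A, lda, ref_B, ldb, ref_D, ldd);
}

The two- and three-argument template overloads follow the same pattern; the full overload additionally takes alpha, beta, and the C operand with its stride.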