docs/gemm__pipelined_8h_source.html
CUTLASS: CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_pipelined.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
29 #pragma once
30
31 #include "cutlass/cutlass.h"
32
33 #include "cutlass/aligned_buffer.h"
34 #include "cutlass/array.h"
35
36 #include "cutlass/numeric_types.h"
37 #include "cutlass/matrix_shape.h"
38
39 #include "cutlass/gemm/gemm.h"
40
42
43 namespace cutlass {
44 namespace gemm {
45 namespace kernel {
46
48
/// GEMM kernel entry point: each threadblock runs the threadblock-scoped `Mma` main loop for
/// one output tile and then hands the accumulators to `Epilogue`.
///
/// Template parameters:
///   Mma                - threadblock-scoped matrix multiply-accumulate. Must provide IteratorA,
///                        IteratorB, FragmentC, SharedStorage, and Shape::kM / Shape::kN.
///   Epilogue           - functor invoked with the (M, N) extent, the accumulators, and the
///                        threadblock's output offset. Must provide Params and SharedStorage.
///   ThreadblockSwizzle - default-constructible functor whose get_tile_offset() yields the tile
///                        coordinate assigned to the calling threadblock.
///
/// Parameters:
///   problem_size     - GEMM extents; m(), n() and k() are read.
///   grid_tiled_shape - tile-grid extents; threadblocks whose tile coordinate is not strictly
///                      inside it along M and N return without doing any work.
///   params_A, ref_A  - iterator parameters and tensor reference for operand A.
///   params_B, ref_B  - iterator parameters and tensor reference for operand B.
///   params_epilogue  - parameters forwarded to the Epilogue constructor.
///
/// Thread indexing uses threadIdx.x only, and warps are taken to be 32 threads wide.
template <typename Mma, typename Epilogue, typename ThreadblockSwizzle>
__global__ void GemmPipelined(
  cutlass::gemm::GemmCoord problem_size,
  cutlass::gemm::GemmCoord grid_tiled_shape,
  typename Mma::IteratorA::Params params_A,
  typename Mma::IteratorA::TensorRef ref_A,
  typename Mma::IteratorB::Params params_B,
  typename Mma::IteratorB::TensorRef ref_B,
  typename Epilogue::Params params_epilogue
) {

  // The main loop and the epilogue overlay their storage in the same shared-memory allocation.
  __shared__ union {
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  } smem;

  // Ask the swizzle functor which tile this threadblock owns.
  ThreadblockSwizzle swizzle;

  cutlass::gemm::GemmCoord tile_coord = swizzle.get_tile_offset();

  // Nothing to do for threadblocks mapped outside the tiled grid along M or N.
  if (!(tile_coord.m() < grid_tiled_shape.m() &&
        tile_coord.n() < grid_tiled_shape.n())) {

    return;
  }

  // Linear thread index within the threadblock.
  int thread_idx = threadIdx.x;

  // Starting logical coordinates of this threadblock's A and B tiles. M and N tile indices are
  // scaled by the threadblock tile shape; the K component is used as returned by the swizzle.
  cutlass::MatrixCoord origin_A{
    tile_coord.m() * Mma::Shape::kM,
    tile_coord.k()
  };

  cutlass::MatrixCoord origin_B{
    tile_coord.k(),
    tile_coord.n() * Mma::Shape::kN
  };

  // Operand iterators over A (M-by-K extent) and B (K-by-N extent).
  typename Mma::IteratorA iterator_A(
    params_A,
    ref_A.data(),
    {problem_size.m(), problem_size.k()},
    thread_idx,
    origin_A);

  typename Mma::IteratorB iterator_B(
    params_B,
    ref_B.data(),
    {problem_size.k(), problem_size.n()},
    thread_idx,
    origin_B);

  // The warp index is read back from lane 0 through a full-mask shuffle, so every lane of the
  // warp holds the value computed by lane 0.
  int lane_idx = threadIdx.x % 32;
  int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

  //
  // Main loop
  //

  // Threadblock-scoped multiply-accumulate backed by the main-loop shared storage.
  Mma mma(smem.main_loop, thread_idx, warp_idx, lane_idx);

  // Accumulators start at zero and serve as both the destination and the source operand.
  typename Mma::FragmentC accumulators;

  accumulators.clear();

  mma(problem_size, accumulators, iterator_A, iterator_B, accumulators);

  //
  // Epilogue
  //

  Epilogue epilogue(
    params_epilogue,
    smem.epilogue,
    thread_idx,
    warp_idx,
    lane_idx);

  // The tile coordinate is queried again here rather than reusing the earlier value.
  tile_coord = swizzle.get_tile_offset();

  // Output offset of this threadblock's tile (assumes an identity swizzle).
  cutlass::MatrixCoord threadblock_offset(
    tile_coord.m() * Mma::Shape::kM,
    tile_coord.n() * Mma::Shape::kN
  );

  // Run the epilogue over the (M, N) output extent.
  epilogue({problem_size.m(), problem_size.n()}, accumulators, threadblock_offset);
}
145
147
148 } // namespace kernel
149 } // namespace gemm
150 } // namespace cutlass
Definition: aligned_buffer.h:35
Definition: include/cutlass/gemm/gemm.h:94
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
cutlass::gemm::kernel::GemmPipelined
__global__ void GemmPipelined(cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord grid_tiled_shape, typename Mma::IteratorA::Params params_A, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::Params params_B, typename Mma::IteratorB::TensorRef ref_B, typename Epilogue::Params params_epilogue)
Definition: gemm_pipelined.h:50
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
Defines a Shape template for matrix tiles.
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared memory.
Top-level include for all CUTLASS numeric types.
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
Basic include for CUTLASS.
Definition: matrix_coord.h:39
Generated by 1.8.11