Back to Cutlass

CUTLASS: default_thread_map_simt.h Source File

docs/default__thread__map__simt_8h_source.html

4.4.213.3 KB
Original Source

CUTLASS

CUDA Templates for Linear Algebra Subroutines and Solvers

default_thread_map_simt.h

[Go to the documentation of this file.](default__thread__map__simt_8h.html)

1 /***************************************************************************************************

2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.

3 *

4 * Redistribution and use in source and binary forms, with or without modification, are permitted

5 * provided that the following conditions are met:

6 * * Redistributions of source code must retain the above copyright notice, this list of

7 * conditions and the following disclaimer.

8 * * Redistributions in binary form must reproduce the above copyright notice, this list of

9 * conditions and the following disclaimer in the documentation and/or other materials

10 * provided with the distribution.

11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used

12 * to endorse or promote products derived from this software without specific prior written

13 * permission.

14 *

15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR

16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND

17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE

18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,

19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;

20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,

21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE

22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

23 *

24 **************************************************************************************************/

30 #pragma once

31

32 #include "[predicated_tile_iterator.h](epilogue_2threadblock_2predicated tile iterator_8h.html)"

33 #include "cutlass/gemm/gemm.h"

34

36

37 namespace cutlass {

38 namespace epilogue {

39 namespace threadblock {

40

42

44 template <

45typename ThreadblockShape_,

46typename WarpShape_,

47typename MmaSimtPolicy_,

48int PartitionsK,

49typename Element_,

50int ElementsPerAccess

51 >

52 struct DefaultThreadMapSimt {

53

54using ThreadblockShape = ThreadblockShape_;

55using WarpShape = WarpShape_;

56using MmaSimtPolicy = MmaSimtPolicy_;

57static int const kPartitionsK = PartitionsK;

58using Element = Element_;

59static int const kElementsPerAccess = ElementsPerAccess;

60

61//

62// Definitions

63//

64

65struct Detail {

66

67static int const kWarpSize = 32;

68

69static_assert(

70 !(ThreadblockShape::kM % WarpShape::kM) &&

71 !(ThreadblockShape::kM % WarpShape::kM), "Divisibility");

72

74using WarpCount = gemm::GemmShape<

75 ThreadblockShape::kM / WarpShape::kM,

76 ThreadblockShape::kN / WarpShape::kN,

77 kPartitionsK

78 >;

79

81static int const kGroupCount =

82 WarpShape::kM / (MmaSimtPolicy::WarpShape::kRow * MmaSimtPolicy::LaneMmaShape::kM);

83

85static int const kThreads = WarpCount::kCount * kWarpSize;

86

88static int const kIterations = MmaSimtPolicy::LaneMmaShape::kM * kGroupCount;

89 };

90

91//

92// ThreadMap

93//

94

96using Type = OutputTileOptimalThreadMap<

97OutputTileShape< // Shape

98 ThreadblockShape::kN,

99 1,

100 MmaSimtPolicy::WarpShape::kRow,

101Detail::WarpCount::kM,

102 1>,

103OutputTileShape< // Count

104 1,

105 MmaSimtPolicy::LaneMmaShape::kM,

106Detail::kGroupCount,

107 1,

108Detail::kIterations>,

109Detail::kThreads,

110kElementsPerAccess,

111sizeof_bits<Element>::value

112 >;

113 };

114

116

117 } // namespace threadblock

118 } // namespace epilogue

119 } // namespace cutlass

120

cutlass::gemm::GemmShape::kM

static int const kM

Definition: include/cutlass/gemm/gemm.h:58

cutlass::epilogue::threadblock::OutputTileOptimalThreadMap

Definition: output_tile_thread_map.h:228

cutlass

Definition: aligned_buffer.h:35

cutlass::epilogue::threadblock::DefaultThreadMapSimt::ThreadblockShape

ThreadblockShape_ ThreadblockShape

Definition: default_thread_map_simt.h:54

cutlass::epilogue::threadblock::DefaultThreadMapSimt::MmaSimtPolicy

MmaSimtPolicy_ MmaSimtPolicy

Definition: default_thread_map_simt.h:56

cutlass::epilogue::threadblock::OutputTileShape

Tuple defining point in output tile.

Definition: output_tile_thread_map.h:57

cutlass::epilogue::threadblock::DefaultThreadMapSimt::Detail::kThreads

static int const kThreads

Number of participating threads.

Definition: default_thread_map_simt.h:85

[predicated_tile_iterator.h](epilogue_2threadblock_2predicated__tile__iterator_8h.html)

Epilogue for threadblock scoped GEMMs using Tensor Ops.

gemm.h

Defines common types used for all GEMM-like operators.

cutlass::gemm::GemmShape::kCount

static int const kCount

Definition: include/cutlass/gemm/gemm.h:67

cutlass::epilogue::threadblock::DefaultThreadMapSimt

Defines the optimal thread map for SIMT accumulator layouts.

Definition: default_thread_map_simt.h:52

cutlass::sizeof_bits

Defines the size of an element in bits.

Definition: numeric_types.h:42

cutlass::epilogue::threadblock::DefaultThreadMapSimt::Element

Element_ Element

Definition: default_thread_map_simt.h:58

cutlass::epilogue::threadblock::DefaultThreadMapSimt::kElementsPerAccess

static int const kElementsPerAccess

Definition: default_thread_map_simt.h:59

cutlass::epilogue::threadblock::DefaultThreadMapSimt::Detail::kIterations

static int const kIterations

Number of iterations.

Definition: default_thread_map_simt.h:88

cutlass::epilogue::threadblock::DefaultThreadMapSimt::Detail::kWarpSize

static int const kWarpSize

Definition: default_thread_map_simt.h:67

cutlass::epilogue::threadblock::DefaultThreadMapSimt::kPartitionsK

static int const kPartitionsK

Definition: default_thread_map_simt.h:57

cutlass::gemm::GemmShape

Shape of a matrix multiply-add operation.

Definition: include/cutlass/gemm/gemm.h:57

static_assert

#define static_assert(__e, __m)

Definition: platform.h:153

cutlass::epilogue::threadblock::DefaultThreadMapSimt::Detail

Definition: default_thread_map_simt.h:65

cutlass::epilogue::threadblock::DefaultThreadMapSimt::WarpShape

WarpShape_ WarpShape

Definition: default_thread_map_simt.h:55

cutlass::epilogue::threadblock::DefaultThreadMapSimt::Detail::kGroupCount

static int const kGroupCount

Computes the number of thread-level matrix multiplies needed to span a warp.

Definition: default_thread_map_simt.h:81


Generated by Doxygen 1.8.11