device/kernel/tensor_elementwise.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include <curand_kernel.h>

#include "cutlass/cutlass.h"

namespace cutlass {
namespace reference {
namespace device {
namespace kernel {

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Kernel to initialize tensor to a uniform random distribution
template <typename T>
__global__ void TensorInitializeUniform(
    Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
  // One cuRAND state per thread; blockDim.x must not exceed 1024.
  __shared__ curandState_t rng_state[1024];

  uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;

  curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);

  // Each thread owns one index in the contiguous dimension and iterates over blockDim.x
  // consecutive indices in the strided dimension.
  int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int s_idx = blockIdx.y * blockDim.x;

  tensor += s_idx * ldm + c_idx;

  for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
    if (s_idx < dim_strided && c_idx < dim_contiguous) {
      double range = dist.uniform.max - dist.uniform.min;

      double rnd = curand_uniform(&rng_state[threadIdx.x]);

      rnd = dist.uniform.min + range * rnd;

      // Random values are cast to integer after scaling by a power of two to facilitate error
      // testing
      if (dist.int_scale >= 0) {
        rnd = double(int(rnd * double(1 << dist.int_scale)));
        *tensor = T(rnd / double(1 << dist.int_scale));
      } else {
        *tensor = T(rnd);
      }

      tensor += ldm;
    }
  }
}

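// Example host-side launch (a minimal sketch; `kThreads`, `seed`, `device_ptr`, and the
// distribution values are illustrative assumptions, not part of this header). Each block
// covers blockDim.x contiguous indices and blockDim.x strided indices, so the grid tiles
// both dimensions:
//
//   int const kThreads = 256;  // must not exceed 1024, the size of rng_state
//   dim3 block(kThreads, 1, 1);
//   dim3 grid((dim_contiguous + kThreads - 1) / kThreads,
//             (dim_strided + kThreads - 1) / kThreads,
//             1);
//
//   cutlass::Distribution dist;
//   dist.uniform.min = -4;
//   dist.uniform.max = 4;
//   dist.int_scale = -1;  // negative value skips the power-of-two quantization
//
//   cutlass::reference::device::kernel::TensorInitializeUniform<float><<< grid, block >>>(
//       dist, seed, dim_contiguous, dim_strided, device_ptr, ldm);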

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Kernel to initialize tensor to a Gaussian random distribution
template <typename T>
__global__ void TensorInitializeGaussian(
    Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
  // One cuRAND state per thread; blockDim.x must not exceed 1024.
  __shared__ curandState_t rng_state[1024];

  uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;

  curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);

  int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int s_idx = blockIdx.y * blockDim.x;

  tensor += s_idx * ldm + c_idx;

  for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
    if (s_idx < dim_strided && c_idx < dim_contiguous) {
      double rnd = curand_normal(&rng_state[threadIdx.x]);

      rnd = dist.gaussian.mean + dist.gaussian.stddev * rnd;

      // Random values are cast to integer after scaling by a power of two to facilitate error
      // testing
      if (dist.int_scale >= 0) {
        rnd = double(int(rnd * double(1 << dist.int_scale)));
        *tensor = T(rnd / double(1 << dist.int_scale));
      } else {
        *tensor = T(rnd);
      }

      // Advance to the next element along the strided dimension.
      tensor += ldm;
    }
  }
}
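
// TensorInitializeGaussian is launched the same way as TensorInitializeUniform above; only
// the Distribution fields it reads differ (a sketch with illustrative values):
//
//   cutlass::Distribution dist;
//   dist.gaussian.mean = 0;
//   dist.gaussian.stddev = 2;
//   dist.int_scale = 3;  // round each value to a multiple of 1/8 for exact comparisons
//
//   cutlass::reference::device::kernel::TensorInitializeGaussian<float><<< grid, block >>>(
//       dist, seed, dim_contiguous, dim_strided, device_ptr, ldm);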

/// Kernel to initialize tensor to a linear combination of its logical coordinates
template <typename T>
__global__ void TensorInitializeLinear(
    Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
  // RNG state is initialized for interface parity with the kernels above, but this kernel is
  // deterministic and never draws from it.
  __shared__ curandState_t rng_state[1024];

  uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;

  curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);

  int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int s_idx = blockIdx.y * blockDim.x;

  tensor += s_idx * ldm + c_idx;

  for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
    if (s_idx < dim_strided && c_idx < dim_contiguous) {
      // Note: delta_row scales the contiguous index and delta_column the strided index.
      *tensor =
          dist.linear.offset + dist.linear.delta_row * c_idx + dist.linear.delta_column * s_idx;

      // Advance to the next element along the strided dimension.
      tensor += ldm;
    }
  }
}
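
// Worked example of the linear rule above (deterministic; values are illustrative): with
// dist.linear.offset = 1, delta_row = 10, and delta_column = 100, a tensor with
// dim_contiguous = 3 and dim_strided = 2 is filled as
//
//   s_idx = 0:    1   11   21
//   s_idx = 1:  101  111  121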

/// Kernel to initialize tensor to an identity matrix
template <typename T>
__global__ void TensorInitializeIdentity(
    Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) {
  // RNG state is initialized for interface parity with the kernels above, but this kernel is
  // deterministic: dist is unused and the random stream is never sampled.
  __shared__ curandState_t rng_state[1024];

  uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x;

  curand_init(seed, gtid, 0, &rng_state[threadIdx.x]);

  int c_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int s_idx = blockIdx.y * blockDim.x;

  tensor += s_idx * ldm + c_idx;

  for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) {
    if (s_idx < dim_strided && c_idx < dim_contiguous) {
      // Ones on the main diagonal, zeros elsewhere.
      *tensor = (c_idx == s_idx ? T(1) : T(0));

      // Advance to the next element along the strided dimension.
      tensor += ldm;
    }
  }
}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace device
} // namespace reference
} // namespace cutlass