docs/inner__product_8h_source.html
CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
inner_product.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
28 #pragma once
29
30 #include "cutlass/cutlass.h"
31 #include "cutlass/array.h"
32
33 namespace cutlass {
34 namespace reference {
35 namespace detail {
36
38
40 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate with a
41// host-only type
42 template <typename Atype, typename Btype, typename Ctype>
44 Ctype inner_product(Atype a, Btype b, Ctype c) {
45return Ctype(a) * Ctype(b) + c;
46 }
47
49 template <>
51 int inner_product<Array<bin1_t, 32>, Array<bin1_t, 32>, int>(
52 Array<bin1_t, 32> a,
53 Array<bin1_t, 32> b,
54int c) {
55
56int accum = 0;
57for (int bit = 0; bit < 32; bit++) {
58 accum += a[bit] ^ b[bit];
59 }
60return accum + c;
61 }
62
63 /*
65 template <>
66 CUTLASS_HOST_DEVICE
67 int inner_product<Array<int4b_t, 8>, Array<int4b_t, 8>, int>(
68 Array<int4b_t, 8> a,
69 Array<int4b_t, 8> b,
70 int c) {
71
72 int accum = 0;
73 for (int k = 0; k < 8; k++) {
74 accum += a[k] * b[k];
75 }
76 return accum + c;
77 }
78
80 template <>
81 CUTLASS_HOST_DEVICE
82 int inner_product<Array<uint4b_t, 8>, Array<uint4b_t, 8>, int>(
83 Array<uint4b_t, 8> a,
84 Array<uint4b_t, 8> b,
85 int c) {
86
87 int accum = 0;
88 for (int k = 0; k < 8; k++) {
89 accum += a[k] * b[k];
90 }
91 return accum + c;
92 }
93 */
94
96
97 template <typename SrcType, typename DstType>
99// Default behavior: convert to the destination type
100 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
101// host-only type
103static DstType apply(SrcType src) { return static_cast<DstType>(src); };
104 };
105
106 template <>
107 struct Cast<float, int8_t> {
109static int8_t apply(float src) {
110// Clamp to the range of signed 8-bit integers.
111return static_cast<int8_t>(fmaxf(-128.f, fminf(127.f, src)));
112 };
113 };
114
115 template <>
116 struct Cast<float, uint8_t> {
118static uint8_t apply(float src) {
119// Clamp to the range of signed 8-bit integers.
120return static_cast<uint8_t>(fmaxf(0.f, fminf(255.f, src)));
121 };
122 };
123
125
126 } // namespace detail
127 } // namespace reference
128 } // namespace cutlass
129
Definition: aligned_buffer.h:35
cutlass::reference::detail::Cast::apply
static CUTLASS_HOST_DEVICE DstType apply(SrcType src)
Definition: inner_product.h:103
cutlass::reference::detail::inner_product
CUTLASS_HOST_DEVICE Ctype inner_product(Atype a, Btype b, Ctype c)
Template function to compute an inner product.
Definition: inner_product.h:44
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
cutlass::reference::detail::Cast< float, uint8_t >::apply
static CUTLASS_HOST_DEVICE uint8_t apply(float src)
Definition: inner_product.h:118
cutlass::reference::detail::Cast
Definition: inner_product.h:98
Basic include for CUTLASS.
cutlass::reference::detail::Cast< float, int8_t >::apply
static CUTLASS_HOST_DEVICE int8_t apply(float src)
Definition: inner_product.h:109
Generated by Doxygen 1.8.11