simd_neon.h
/*
* Copyright 2017 Google Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DIMSUM_SIMD_NEON_H_
#define DIMSUM_SIMD_NEON_H_
#include <arm_neon.h>
#include "operations.h"
namespace dimsum {
namespace detail {
// llvm.experimental.vector.reduce.add.* already lowers to single instructions
// like addv, but I'm not sure if there are intrinsics for them.
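// (Illustrative note: the AArch64 ACLE does provide horizontal-add intrinsics
// such as vaddvq_s32, so a hand-written reduction could be sketched as
//   int32_t sum = vaddvq_s32(to_raw(simd));
// if the single-instruction lowering ever needed to be forced explicitly.)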
} // namespace detail
template <>
inline Simd128<int8> abs(Simd128<int8> simd) {
return vabsq_s8(to_raw(simd));
}
template <>
inline Simd128<int16> abs(Simd128<int16> simd) {
return vabsq_s16(to_raw(simd));
}
template <>
inline Simd128<int32> abs(Simd128<int32> simd) {
return vabsq_s32(to_raw(simd));
}
template <>
inline Simd128<int64> abs(Simd128<int64> simd) {
return vabsq_s64(to_raw(simd));
}
template <>
inline Simd128<float> abs(Simd128<float> simd) {
return vabsq_f32(to_raw(simd));
}
template <>
inline Simd128<double> abs(Simd128<double> simd) {
return vabsq_f64(to_raw(simd));
}
template <>
inline Simd128<float> reciprocal_estimate(Simd128<float> simd) {
return vrecpeq_f32(to_raw(simd));
}
template <>
inline Simd128<float> sqrt(Simd128<float> simd) {
return vsqrtq_f32(to_raw(simd));
}
template <>
inline Simd128<double> sqrt(Simd128<double> simd) {
return vsqrtq_f64(to_raw(simd));
}
template <>
inline Simd128<float> reciprocal_sqrt_estimate(Simd128<float> simd) {
return vrsqrteq_f32(to_raw(simd));
}
template <>
inline Simd128<int8> add_saturated(Simd128<int8> lhs, Simd128<int8> rhs) {
return vqaddq_s8(to_raw(lhs), to_raw(rhs));
}
template <>
inline Simd128<uint8> add_saturated(Simd128<uint8> lhs, Simd128<uint8> rhs) {
return vqaddq_u8(to_raw(lhs), to_raw(rhs));
}
template <>
inline Simd128<int16> add_saturated(Simd128<int16> lhs, Simd128<int16> rhs) {
return vqaddq_s16(to_raw(lhs), to_raw(rhs));
}
template <>
inline Simd128<uint16> add_saturated(Simd128<uint16> lhs, Simd128<uint16> rhs) {
return vqaddq_u16(to_raw(lhs), to_raw(rhs));
}
template <>
inline Simd128<int8> sub_saturated(Simd128<int8> lhs, Simd128<int8> rhs) {
return vqsubq_s8(to_raw(lhs), to_raw(rhs));
}
template <>
inline Simd128<uint8> sub_saturated(Simd128<uint8> lhs, Simd128<uint8> rhs) {
return vqsubq_u8(to_raw(lhs), to_raw(rhs));
}
template <>
inline Simd128<int16> sub_saturated(Simd128<int16> lhs, Simd128<int16> rhs) {
return vqsubq_s16(to_raw(lhs), to_raw(rhs));
}
template <>
inline Simd128<uint16> sub_saturated(Simd128<uint16> lhs, Simd128<uint16> rhs) {
return vqsubq_u16(to_raw(lhs), to_raw(rhs));
}
// An alternative implementation of the SSE intrinsic _mm_madd_epi16 on ARM.
// It splits each operand into low and high halves, multiplies the
// corresponding int16 values with widening multiplies, pairwise-adds adjacent
// int32 products within each half, and concatenates the two halves before
// adding the accumulator.
template <>
inline Simd128<int32> mul_sum(Simd128<int16> lhs, Simd128<int16> rhs,
Simd128<int32> acc) {
int16x8_t lhs_raw = to_raw(lhs);
int16x8_t rhs_raw = to_raw(rhs);
int32x4_t mullo = vmull_s16(vget_low_s16(lhs_raw), vget_low_s16(rhs_raw));
int32x4_t mulhi = vmull_s16(vget_high_s16(lhs_raw), vget_high_s16(rhs_raw));
int32x2_t addlo = vpadd_s32(vget_low_s32(mullo), vget_high_s32(mullo));
int32x2_t addhi = vpadd_s32(vget_low_s32(mulhi), vget_high_s32(mulhi));
return vaddq_s32(to_raw(acc), vcombine_s32(addlo, addhi));
}
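// Illustrative usage (a sketch of the semantics): for
//   Simd128<int32> r = mul_sum(lhs, rhs, acc);
// each lane satisfies
//   r[i] == acc[i] + lhs[2*i] * rhs[2*i] + lhs[2*i+1] * rhs[2*i+1],
// i.e. the _mm_madd_epi16 product-sum added to the accumulator.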
// vrndnq_f{32,64} translate to single rounding instructions (FRINTN on
// AArch64), which round floating-point values to the nearest integer with
// ties to even (the Round to Nearest rounding mode in ARM parlance).
template <>
inline Simd128<float> round(Simd128<float> simd) {
return vrndnq_f32(to_raw(simd));
}
template <>
inline Simd128<double> round(Simd128<double> simd) {
return vrndnq_f64(to_raw(simd));
}
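// For example (a sketch of the rule, not an additional code path): with
// ties-to-even, the inputs {0.5f, 1.5f, 2.5f, -0.5f} round to
// {0.f, 2.f, 2.f, -0.f}, whereas round-half-away-from-zero would give
// {1.f, 2.f, 3.f, -1.f}.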
template <>
inline Simd128<int32> round_to_integer(Simd128<float> simd) {
return vcvtnq_s32_f32(to_raw(simd));
}
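// Multiplies two Simd64 objects lane by lane, first widening each lane to the
// element type of twice the width (via simd_cast), so the full-precision
// products fit without overflow.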
template <typename T>
Simd128<ScaleBy<T, 2>> mul_widened(Simd64<T> lhs, Simd64<T> rhs) {
return simd_cast<ScaleBy<T, 2>>(lhs) * simd_cast<ScaleBy<T, 2>>(rhs);
}
} // namespace dimsum
#endif // DIMSUM_SIMD_NEON_H_