blob: 5483e5c4aad73bb479898ebd00bfd66eb8d1105f [file] [log] [blame]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#include "main.h"
11
Antonio Sanchezbde67412021-01-16 10:22:07 -080012template<typename T, typename U>
13bool check_if_equal_or_nans(const T& actual, const U& expected) {
Erik Schultheisd271a7d2022-01-26 18:16:19 +000014 return (numext::equal_strict(actual, expected) || ((numext::isnan)(actual) && (numext::isnan)(expected)));
Antonio Sanchezbde67412021-01-16 10:22:07 -080015}
16
17template<typename T, typename U>
18bool check_if_equal_or_nans(const std::complex<T>& actual, const std::complex<U>& expected) {
19 return check_if_equal_or_nans(numext::real(actual), numext::real(expected))
20 && check_if_equal_or_nans(numext::imag(actual), numext::imag(expected));
21}
22
// Same predicate as check_if_equal_or_nans. On a mismatch it also prints both
// values to std::cerr, so a failing VERIFY shows what differed.
template<typename T, typename U>
bool test_is_equal_or_nans(const T& actual, const U& expected)
{
  const bool matches = check_if_equal_or_nans(actual, expected);
  if (!matches) {
    std::cerr << "\n actual = " << actual
              << "\n expected = " << expected << "\n\n";
  }
  return matches;
}

// Assertion that passes when a and b are equal or both NaN (component-wise
// for complex values).
#define VERIFY_IS_EQUAL_OR_NANS(a, b) VERIFY(test_is_equal_or_nans(a, b))
38
Gael Guennebaud498aa952017-06-09 11:53:49 +020039template<typename T>
40void check_abs() {
41 typedef typename NumTraits<T>::Real Real;
Gael Guennebaud32d72322019-01-15 11:18:48 +010042 Real zero(0);
Gael Guennebaud498aa952017-06-09 11:53:49 +020043
44 if(NumTraits<T>::IsSigned)
45 VERIFY_IS_EQUAL(numext::abs(-T(1)), T(1));
46 VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
47 VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
48
Antonio Sanchezbde67412021-01-16 10:22:07 -080049 for(int k=0; k<100; ++k)
Gael Guennebaud498aa952017-06-09 11:53:49 +020050 {
51 T x = internal::random<T>();
52 if(!internal::is_same<T,bool>::value)
53 x = x/Real(2);
54 if(NumTraits<T>::IsSigned)
55 {
56 VERIFY_IS_EQUAL(numext::abs(x), numext::abs(-x));
Gael Guennebaud32d72322019-01-15 11:18:48 +010057 VERIFY( numext::abs(-x) >= zero );
Gael Guennebaud498aa952017-06-09 11:53:49 +020058 }
Gael Guennebaud32d72322019-01-15 11:18:48 +010059 VERIFY( numext::abs(x) >= zero );
Gael Guennebaud498aa952017-06-09 11:53:49 +020060 VERIFY_IS_APPROX( numext::abs2(x), numext::abs2(numext::abs(x)) );
61 }
62}
63
Antonio Sanchezbde67412021-01-16 10:22:07 -080064template<typename T>
Antonio Sanchez90e9a332021-05-07 08:24:32 -070065void check_arg() {
66 typedef typename NumTraits<T>::Real Real;
67 VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
68 VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
69
70 for(int k=0; k<100; ++k)
71 {
72 T x = internal::random<T>();
73 Real y = numext::arg(x);
74 VERIFY_IS_APPROX( y, std::arg(x) );
75 }
76}
77
78template<typename T>
Antonio Sanchezbde67412021-01-16 10:22:07 -080079struct check_sqrt_impl {
80 static void run() {
81 for (int i=0; i<1000; ++i) {
82 const T x = numext::abs(internal::random<T>());
83 const T sqrtx = numext::sqrt(x);
84 VERIFY_IS_APPROX(sqrtx*sqrtx, x);
85 }
Gael Guennebaud498aa952017-06-09 11:53:49 +020086
Antonio Sanchezbde67412021-01-16 10:22:07 -080087 // Corner cases.
88 const T zero = T(0);
89 const T one = T(1);
90 const T inf = std::numeric_limits<T>::infinity();
91 const T nan = std::numeric_limits<T>::quiet_NaN();
92 VERIFY_IS_EQUAL(numext::sqrt(zero), zero);
93 VERIFY_IS_EQUAL(numext::sqrt(inf), inf);
94 VERIFY((numext::isnan)(numext::sqrt(nan)));
95 VERIFY((numext::isnan)(numext::sqrt(-one)));
96 }
97};
98
99template<typename T>
100struct check_sqrt_impl<std::complex<T> > {
101 static void run() {
102 typedef typename std::complex<T> ComplexT;
103
104 for (int i=0; i<1000; ++i) {
105 const ComplexT x = internal::random<ComplexT>();
106 const ComplexT sqrtx = numext::sqrt(x);
107 VERIFY_IS_APPROX(sqrtx*sqrtx, x);
108 }
109
110 // Corner cases.
111 const T zero = T(0);
112 const T one = T(1);
113 const T inf = std::numeric_limits<T>::infinity();
114 const T nan = std::numeric_limits<T>::quiet_NaN();
115
116 // Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt
117 const int kNumCorners = 20;
118 const ComplexT corners[kNumCorners][2] = {
119 {ComplexT(zero, zero), ComplexT(zero, zero)},
120 {ComplexT(-zero, zero), ComplexT(zero, zero)},
121 {ComplexT(zero, -zero), ComplexT(zero, zero)},
122 {ComplexT(-zero, -zero), ComplexT(zero, zero)},
123 {ComplexT(one, inf), ComplexT(inf, inf)},
124 {ComplexT(nan, inf), ComplexT(inf, inf)},
125 {ComplexT(one, -inf), ComplexT(inf, -inf)},
126 {ComplexT(nan, -inf), ComplexT(inf, -inf)},
127 {ComplexT(-inf, one), ComplexT(zero, inf)},
128 {ComplexT(inf, one), ComplexT(inf, zero)},
129 {ComplexT(-inf, -one), ComplexT(zero, -inf)},
130 {ComplexT(inf, -one), ComplexT(inf, -zero)},
131 {ComplexT(-inf, nan), ComplexT(nan, inf)},
132 {ComplexT(inf, nan), ComplexT(inf, nan)},
133 {ComplexT(zero, nan), ComplexT(nan, nan)},
134 {ComplexT(one, nan), ComplexT(nan, nan)},
135 {ComplexT(nan, zero), ComplexT(nan, nan)},
136 {ComplexT(nan, one), ComplexT(nan, nan)},
137 {ComplexT(nan, -one), ComplexT(nan, nan)},
138 {ComplexT(nan, nan), ComplexT(nan, nan)},
139 };
140
141 for (int i=0; i<kNumCorners; ++i) {
142 const ComplexT& x = corners[i][0];
143 const ComplexT sqrtx = corners[i][1];
144 VERIFY_IS_EQUAL_OR_NANS(numext::sqrt(x), sqrtx);
145 }
146 }
147};
148
149template<typename T>
150void check_sqrt() {
151 check_sqrt_impl<T>::run();
152}
153
154template<typename T>
155struct check_rsqrt_impl {
156 static void run() {
157 const T zero = T(0);
158 const T one = T(1);
159 const T inf = std::numeric_limits<T>::infinity();
160 const T nan = std::numeric_limits<T>::quiet_NaN();
161
162 for (int i=0; i<1000; ++i) {
163 const T x = numext::abs(internal::random<T>());
164 const T rsqrtx = numext::rsqrt(x);
165 const T invx = one / x;
166 VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
167 }
168
169 // Corner cases.
170 VERIFY_IS_EQUAL(numext::rsqrt(zero), inf);
171 VERIFY_IS_EQUAL(numext::rsqrt(inf), zero);
172 VERIFY((numext::isnan)(numext::rsqrt(nan)));
173 VERIFY((numext::isnan)(numext::rsqrt(-one)));
174 }
175};
176
177template<typename T>
178struct check_rsqrt_impl<std::complex<T> > {
179 static void run() {
180 typedef typename std::complex<T> ComplexT;
181 const T zero = T(0);
182 const T one = T(1);
183 const T inf = std::numeric_limits<T>::infinity();
184 const T nan = std::numeric_limits<T>::quiet_NaN();
185
186 for (int i=0; i<1000; ++i) {
187 const ComplexT x = internal::random<ComplexT>();
188 const ComplexT invx = ComplexT(one, zero) / x;
189 const ComplexT rsqrtx = numext::rsqrt(x);
190 VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
191 }
192
193 // GCC and MSVC differ in their treatment of 1/(0 + 0i)
194 // GCC/clang = (inf, nan)
195 // MSVC = (nan, nan)
196 // and 1 / (x + inf i)
197 // GCC/clang = (0, 0)
198 // MSVC = (nan, nan)
199 #if (EIGEN_COMP_GNUC)
200 {
201 const int kNumCorners = 20;
202 const ComplexT corners[kNumCorners][2] = {
203 // Only consistent across GCC, clang
204 {ComplexT(zero, zero), ComplexT(zero, zero)},
205 {ComplexT(-zero, zero), ComplexT(zero, zero)},
206 {ComplexT(zero, -zero), ComplexT(zero, zero)},
207 {ComplexT(-zero, -zero), ComplexT(zero, zero)},
208 {ComplexT(one, inf), ComplexT(inf, inf)},
209 {ComplexT(nan, inf), ComplexT(inf, inf)},
210 {ComplexT(one, -inf), ComplexT(inf, -inf)},
211 {ComplexT(nan, -inf), ComplexT(inf, -inf)},
212 // Consistent across GCC, clang, MSVC
213 {ComplexT(-inf, one), ComplexT(zero, inf)},
214 {ComplexT(inf, one), ComplexT(inf, zero)},
215 {ComplexT(-inf, -one), ComplexT(zero, -inf)},
216 {ComplexT(inf, -one), ComplexT(inf, -zero)},
217 {ComplexT(-inf, nan), ComplexT(nan, inf)},
218 {ComplexT(inf, nan), ComplexT(inf, nan)},
219 {ComplexT(zero, nan), ComplexT(nan, nan)},
220 {ComplexT(one, nan), ComplexT(nan, nan)},
221 {ComplexT(nan, zero), ComplexT(nan, nan)},
222 {ComplexT(nan, one), ComplexT(nan, nan)},
223 {ComplexT(nan, -one), ComplexT(nan, nan)},
224 {ComplexT(nan, nan), ComplexT(nan, nan)},
225 };
226
227 for (int i=0; i<kNumCorners; ++i) {
228 const ComplexT& x = corners[i][0];
229 const ComplexT rsqrtx = ComplexT(one, zero) / corners[i][1];
230 VERIFY_IS_EQUAL_OR_NANS(numext::rsqrt(x), rsqrtx);
231 }
232 }
233 #endif
234 }
235};
236
237template<typename T>
238void check_rsqrt() {
239 check_rsqrt_impl<T>::run();
240}
241
Charles Schlosser82b152d2022-11-04 00:31:20 +0000242template <typename T, bool IsInteger = NumTraits<T>::IsInteger>
243struct ref_signbit_func_impl {
244 static bool run(const T& x) { return std::signbit(x); }
245};
246template <typename T>
247struct ref_signbit_func_impl<T, true> {
248 // MSVC (perhaps others) does not have a std::signbit overload for integers
249 static bool run(const T& x) { return x < T(0); }
250};
251template <typename T>
252bool ref_signbit_func(const T& x) {
253 return ref_signbit_func_impl<T>::run(x);
254}
255
256template <typename T>
257struct check_signbit_impl {
258 static void run() {
259 T true_mask;
260 std::memset(static_cast<void*>(&true_mask), 0xff, sizeof(T));
261 T false_mask;
262 std::memset(static_cast<void*>(&false_mask), 0x00, sizeof(T));
263
264 // has sign bit
265 const T neg_zero = static_cast<T>(-0.0);
266 const T neg_one = static_cast<T>(-1.0);
267 const T neg_inf = -std::numeric_limits<T>::infinity();
268 const T neg_nan = -std::numeric_limits<T>::quiet_NaN();
269 // does not have sign bit
270 const T pos_zero = static_cast<T>(0.0);
271 const T pos_one = static_cast<T>(1.0);
272 const T pos_inf = std::numeric_limits<T>::infinity();
273 const T pos_nan = std::numeric_limits<T>::quiet_NaN();
274
275 std::vector<T> values = {neg_zero, neg_one, neg_inf, neg_nan, pos_zero, pos_one, pos_inf, pos_nan};
276
277 bool all_pass = true;
278
279 for (T val : values) {
280 const T numext_val = numext::signbit(val);
281 const T ref_val = ref_signbit_func(val) ? true_mask : false_mask;
282 bool not_same = internal::predux_any(internal::bitwise_helper<T>::bitwise_xor(ref_val, numext_val));
283 all_pass = all_pass && !not_same;
284 if (not_same) std::cout << "signbit(" << val << ") != " << numext_val << "\n";
285 }
286 VERIFY(all_pass);
287 }
288};
289template <typename T>
290void check_signbit() {
291 check_signbit_impl<T>::run();
292}
293
Antonio Sanchezbde67412021-01-16 10:22:07 -0800294EIGEN_DECLARE_TEST(numext) {
295 for(int k=0; k<g_repeat; ++k)
296 {
297 CALL_SUBTEST( check_abs<bool>() );
298 CALL_SUBTEST( check_abs<signed char>() );
299 CALL_SUBTEST( check_abs<unsigned char>() );
300 CALL_SUBTEST( check_abs<short>() );
301 CALL_SUBTEST( check_abs<unsigned short>() );
302 CALL_SUBTEST( check_abs<int>() );
303 CALL_SUBTEST( check_abs<unsigned int>() );
304 CALL_SUBTEST( check_abs<long>() );
305 CALL_SUBTEST( check_abs<unsigned long>() );
306 CALL_SUBTEST( check_abs<half>() );
307 CALL_SUBTEST( check_abs<bfloat16>() );
308 CALL_SUBTEST( check_abs<float>() );
309 CALL_SUBTEST( check_abs<double>() );
310 CALL_SUBTEST( check_abs<long double>() );
Antonio Sanchezbde67412021-01-16 10:22:07 -0800311 CALL_SUBTEST( check_abs<std::complex<float> >() );
312 CALL_SUBTEST( check_abs<std::complex<double> >() );
313
Antonio Sanchez90e9a332021-05-07 08:24:32 -0700314 CALL_SUBTEST( check_arg<std::complex<float> >() );
315 CALL_SUBTEST( check_arg<std::complex<double> >() );
316
Antonio Sanchezbde67412021-01-16 10:22:07 -0800317 CALL_SUBTEST( check_sqrt<float>() );
318 CALL_SUBTEST( check_sqrt<double>() );
319 CALL_SUBTEST( check_sqrt<std::complex<float> >() );
320 CALL_SUBTEST( check_sqrt<std::complex<double> >() );
321
322 CALL_SUBTEST( check_rsqrt<float>() );
323 CALL_SUBTEST( check_rsqrt<double>() );
324 CALL_SUBTEST( check_rsqrt<std::complex<float> >() );
325 CALL_SUBTEST( check_rsqrt<std::complex<double> >() );
Charles Schlosser82b152d2022-11-04 00:31:20 +0000326
327 CALL_SUBTEST( check_signbit<half>());
328 CALL_SUBTEST( check_signbit<bfloat16>());
329 CALL_SUBTEST( check_signbit<float>());
330 CALL_SUBTEST( check_signbit<double>());
331
332 CALL_SUBTEST( check_signbit<uint8_t>());
333 CALL_SUBTEST( check_signbit<uint16_t>());
334 CALL_SUBTEST( check_signbit<uint32_t>());
335 CALL_SUBTEST( check_signbit<uint64_t>());
336
337 CALL_SUBTEST( check_signbit<int8_t>());
338 CALL_SUBTEST( check_signbit<int16_t>());
339 CALL_SUBTEST( check_signbit<int32_t>());
340 CALL_SUBTEST( check_signbit<int64_t>());
Antonio Sanchezbde67412021-01-16 10:22:07 -0800341 }
Gael Guennebaud498aa952017-06-09 11:53:49 +0200342}