blob: cf1ca173dd66f3abc49974df24f09843fedcd2db [file] [log] [blame]
Gael Guennebaud498aa952017-06-09 11:53:49 +02001// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#include "main.h"
11
Antonio Sanchezbde67412021-01-16 10:22:07 -080012template<typename T, typename U>
13bool check_if_equal_or_nans(const T& actual, const U& expected) {
14 return ((actual == expected) || ((numext::isnan)(actual) && (numext::isnan)(expected)));
15}
16
17template<typename T, typename U>
18bool check_if_equal_or_nans(const std::complex<T>& actual, const std::complex<U>& expected) {
19 return check_if_equal_or_nans(numext::real(actual), numext::real(expected))
20 && check_if_equal_or_nans(numext::imag(actual), numext::imag(expected));
21}
22
// Verification helper behind VERIFY_IS_EQUAL_OR_NANS.
// Returns true when check_if_equal_or_nans(actual, expected) holds.
// Otherwise writes both values to std::cerr, so the failing case shows up in
// the test log, and returns false.
template<typename T, typename U>
bool test_is_equal_or_nans(const T& actual, const U& expected)
{
  if (!check_if_equal_or_nans(actual, expected)) {
    std::cerr << "\n actual = " << actual;
    std::cerr << "\n expected = " << expected;
    std::cerr << "\n\n";
    return false;
  }
  return true;
}
36
// Test assertion: passes when `actual_value` equals `expected_value` or both
// are NaN (component-wise for std::complex); on failure the two values are
// printed by test_is_equal_or_nans before VERIFY reports the error.
#define VERIFY_IS_EQUAL_OR_NANS(actual_value, expected_value) VERIFY(test_is_equal_or_nans(actual_value, expected_value))
38
Gael Guennebaud498aa952017-06-09 11:53:49 +020039template<typename T>
40void check_abs() {
41 typedef typename NumTraits<T>::Real Real;
Gael Guennebaud32d72322019-01-15 11:18:48 +010042 Real zero(0);
Gael Guennebaud498aa952017-06-09 11:53:49 +020043
44 if(NumTraits<T>::IsSigned)
45 VERIFY_IS_EQUAL(numext::abs(-T(1)), T(1));
46 VERIFY_IS_EQUAL(numext::abs(T(0)), T(0));
47 VERIFY_IS_EQUAL(numext::abs(T(1)), T(1));
48
Antonio Sanchezbde67412021-01-16 10:22:07 -080049 for(int k=0; k<100; ++k)
Gael Guennebaud498aa952017-06-09 11:53:49 +020050 {
51 T x = internal::random<T>();
52 if(!internal::is_same<T,bool>::value)
53 x = x/Real(2);
54 if(NumTraits<T>::IsSigned)
55 {
56 VERIFY_IS_EQUAL(numext::abs(x), numext::abs(-x));
Gael Guennebaud32d72322019-01-15 11:18:48 +010057 VERIFY( numext::abs(-x) >= zero );
Gael Guennebaud498aa952017-06-09 11:53:49 +020058 }
Gael Guennebaud32d72322019-01-15 11:18:48 +010059 VERIFY( numext::abs(x) >= zero );
Gael Guennebaud498aa952017-06-09 11:53:49 +020060 VERIFY_IS_APPROX( numext::abs2(x), numext::abs2(numext::abs(x)) );
61 }
62}
63
Antonio Sanchezbde67412021-01-16 10:22:07 -080064template<typename T>
65struct check_sqrt_impl {
66 static void run() {
67 for (int i=0; i<1000; ++i) {
68 const T x = numext::abs(internal::random<T>());
69 const T sqrtx = numext::sqrt(x);
70 VERIFY_IS_APPROX(sqrtx*sqrtx, x);
71 }
Gael Guennebaud498aa952017-06-09 11:53:49 +020072
Antonio Sanchezbde67412021-01-16 10:22:07 -080073 // Corner cases.
74 const T zero = T(0);
75 const T one = T(1);
76 const T inf = std::numeric_limits<T>::infinity();
77 const T nan = std::numeric_limits<T>::quiet_NaN();
78 VERIFY_IS_EQUAL(numext::sqrt(zero), zero);
79 VERIFY_IS_EQUAL(numext::sqrt(inf), inf);
80 VERIFY((numext::isnan)(numext::sqrt(nan)));
81 VERIFY((numext::isnan)(numext::sqrt(-one)));
82 }
83};
84
85template<typename T>
86struct check_sqrt_impl<std::complex<T> > {
87 static void run() {
88 typedef typename std::complex<T> ComplexT;
89
90 for (int i=0; i<1000; ++i) {
91 const ComplexT x = internal::random<ComplexT>();
92 const ComplexT sqrtx = numext::sqrt(x);
93 VERIFY_IS_APPROX(sqrtx*sqrtx, x);
94 }
95
96 // Corner cases.
97 const T zero = T(0);
98 const T one = T(1);
99 const T inf = std::numeric_limits<T>::infinity();
100 const T nan = std::numeric_limits<T>::quiet_NaN();
101
102 // Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt
103 const int kNumCorners = 20;
104 const ComplexT corners[kNumCorners][2] = {
105 {ComplexT(zero, zero), ComplexT(zero, zero)},
106 {ComplexT(-zero, zero), ComplexT(zero, zero)},
107 {ComplexT(zero, -zero), ComplexT(zero, zero)},
108 {ComplexT(-zero, -zero), ComplexT(zero, zero)},
109 {ComplexT(one, inf), ComplexT(inf, inf)},
110 {ComplexT(nan, inf), ComplexT(inf, inf)},
111 {ComplexT(one, -inf), ComplexT(inf, -inf)},
112 {ComplexT(nan, -inf), ComplexT(inf, -inf)},
113 {ComplexT(-inf, one), ComplexT(zero, inf)},
114 {ComplexT(inf, one), ComplexT(inf, zero)},
115 {ComplexT(-inf, -one), ComplexT(zero, -inf)},
116 {ComplexT(inf, -one), ComplexT(inf, -zero)},
117 {ComplexT(-inf, nan), ComplexT(nan, inf)},
118 {ComplexT(inf, nan), ComplexT(inf, nan)},
119 {ComplexT(zero, nan), ComplexT(nan, nan)},
120 {ComplexT(one, nan), ComplexT(nan, nan)},
121 {ComplexT(nan, zero), ComplexT(nan, nan)},
122 {ComplexT(nan, one), ComplexT(nan, nan)},
123 {ComplexT(nan, -one), ComplexT(nan, nan)},
124 {ComplexT(nan, nan), ComplexT(nan, nan)},
125 };
126
127 for (int i=0; i<kNumCorners; ++i) {
128 const ComplexT& x = corners[i][0];
129 const ComplexT sqrtx = corners[i][1];
130 VERIFY_IS_EQUAL_OR_NANS(numext::sqrt(x), sqrtx);
131 }
132 }
133};
134
135template<typename T>
136void check_sqrt() {
137 check_sqrt_impl<T>::run();
138}
139
140template<typename T>
141struct check_rsqrt_impl {
142 static void run() {
143 const T zero = T(0);
144 const T one = T(1);
145 const T inf = std::numeric_limits<T>::infinity();
146 const T nan = std::numeric_limits<T>::quiet_NaN();
147
148 for (int i=0; i<1000; ++i) {
149 const T x = numext::abs(internal::random<T>());
150 const T rsqrtx = numext::rsqrt(x);
151 const T invx = one / x;
152 VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
153 }
154
155 // Corner cases.
156 VERIFY_IS_EQUAL(numext::rsqrt(zero), inf);
157 VERIFY_IS_EQUAL(numext::rsqrt(inf), zero);
158 VERIFY((numext::isnan)(numext::rsqrt(nan)));
159 VERIFY((numext::isnan)(numext::rsqrt(-one)));
160 }
161};
162
163template<typename T>
164struct check_rsqrt_impl<std::complex<T> > {
165 static void run() {
166 typedef typename std::complex<T> ComplexT;
167 const T zero = T(0);
168 const T one = T(1);
169 const T inf = std::numeric_limits<T>::infinity();
170 const T nan = std::numeric_limits<T>::quiet_NaN();
171
172 for (int i=0; i<1000; ++i) {
173 const ComplexT x = internal::random<ComplexT>();
174 const ComplexT invx = ComplexT(one, zero) / x;
175 const ComplexT rsqrtx = numext::rsqrt(x);
176 VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx);
177 }
178
179 // GCC and MSVC differ in their treatment of 1/(0 + 0i)
180 // GCC/clang = (inf, nan)
181 // MSVC = (nan, nan)
182 // and 1 / (x + inf i)
183 // GCC/clang = (0, 0)
184 // MSVC = (nan, nan)
185 #if (EIGEN_COMP_GNUC)
186 {
187 const int kNumCorners = 20;
188 const ComplexT corners[kNumCorners][2] = {
189 // Only consistent across GCC, clang
190 {ComplexT(zero, zero), ComplexT(zero, zero)},
191 {ComplexT(-zero, zero), ComplexT(zero, zero)},
192 {ComplexT(zero, -zero), ComplexT(zero, zero)},
193 {ComplexT(-zero, -zero), ComplexT(zero, zero)},
194 {ComplexT(one, inf), ComplexT(inf, inf)},
195 {ComplexT(nan, inf), ComplexT(inf, inf)},
196 {ComplexT(one, -inf), ComplexT(inf, -inf)},
197 {ComplexT(nan, -inf), ComplexT(inf, -inf)},
198 // Consistent across GCC, clang, MSVC
199 {ComplexT(-inf, one), ComplexT(zero, inf)},
200 {ComplexT(inf, one), ComplexT(inf, zero)},
201 {ComplexT(-inf, -one), ComplexT(zero, -inf)},
202 {ComplexT(inf, -one), ComplexT(inf, -zero)},
203 {ComplexT(-inf, nan), ComplexT(nan, inf)},
204 {ComplexT(inf, nan), ComplexT(inf, nan)},
205 {ComplexT(zero, nan), ComplexT(nan, nan)},
206 {ComplexT(one, nan), ComplexT(nan, nan)},
207 {ComplexT(nan, zero), ComplexT(nan, nan)},
208 {ComplexT(nan, one), ComplexT(nan, nan)},
209 {ComplexT(nan, -one), ComplexT(nan, nan)},
210 {ComplexT(nan, nan), ComplexT(nan, nan)},
211 };
212
213 for (int i=0; i<kNumCorners; ++i) {
214 const ComplexT& x = corners[i][0];
215 const ComplexT rsqrtx = ComplexT(one, zero) / corners[i][1];
216 VERIFY_IS_EQUAL_OR_NANS(numext::rsqrt(x), rsqrtx);
217 }
218 }
219 #endif
220 }
221};
222
223template<typename T>
224void check_rsqrt() {
225 check_rsqrt_impl<T>::run();
226}
227
228EIGEN_DECLARE_TEST(numext) {
229 for(int k=0; k<g_repeat; ++k)
230 {
231 CALL_SUBTEST( check_abs<bool>() );
232 CALL_SUBTEST( check_abs<signed char>() );
233 CALL_SUBTEST( check_abs<unsigned char>() );
234 CALL_SUBTEST( check_abs<short>() );
235 CALL_SUBTEST( check_abs<unsigned short>() );
236 CALL_SUBTEST( check_abs<int>() );
237 CALL_SUBTEST( check_abs<unsigned int>() );
238 CALL_SUBTEST( check_abs<long>() );
239 CALL_SUBTEST( check_abs<unsigned long>() );
240 CALL_SUBTEST( check_abs<half>() );
241 CALL_SUBTEST( check_abs<bfloat16>() );
242 CALL_SUBTEST( check_abs<float>() );
243 CALL_SUBTEST( check_abs<double>() );
244 CALL_SUBTEST( check_abs<long double>() );
245
246 CALL_SUBTEST( check_abs<std::complex<float> >() );
247 CALL_SUBTEST( check_abs<std::complex<double> >() );
248
249 CALL_SUBTEST( check_sqrt<float>() );
250 CALL_SUBTEST( check_sqrt<double>() );
251 CALL_SUBTEST( check_sqrt<std::complex<float> >() );
252 CALL_SUBTEST( check_sqrt<std::complex<double> >() );
253
254 CALL_SUBTEST( check_rsqrt<float>() );
255 CALL_SUBTEST( check_rsqrt<double>() );
256 CALL_SUBTEST( check_rsqrt<std::complex<float> >() );
257 CALL_SUBTEST( check_rsqrt<std::complex<double> >() );
258 }
Gael Guennebaud498aa952017-06-09 11:53:49 +0200259}