Allow mixed types for pow(), as long as the exponent is exactly representable in the base type.
diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp
index 1d4f7b5..5a4c282 100644
--- a/test/array_cwise.cpp
+++ b/test/array_cwise.cpp
@@ -138,7 +138,7 @@
// base^e <= highest ==> base <= 2^(log2(highest)/e)
// For floating-point types, consider the bound for integer values that can be reproduced exactly = 2 ^ digits
double highest_bits = numext::mini(static_cast<double>(NumTraits<Scalar>::digits()),
- log2(NumTraits<Scalar>::highest()));
+ static_cast<double>(log2(NumTraits<Scalar>::highest())));
return static_cast<Scalar>(
numext::floor(exp2(highest_bits / static_cast<double>(exponent))));
}
@@ -146,49 +146,90 @@
template <typename Base, typename Exponent>
void test_exponent(Exponent exponent) {
+ const Base max_abs_bases = static_cast<Base>(10000);
+ // avoid integer overflow in Base type
+ Base threshold = calc_overflow_threshold<Base, Exponent>(numext::abs(exponent));
+ // avoid numbers that can't be verified with std::pow
+ double double_threshold = calc_overflow_threshold<double, Exponent>(numext::abs(exponent));
+ // use the lesser of these two thresholds
+ Base testing_threshold =
+ static_cast<double>(threshold) < double_threshold ? threshold : static_cast<Base>(double_threshold);
+ // test both vectorized and non-vectorized code paths
+ const Index array_size = 2 * internal::packet_traits<Base>::size + 1;
+
+ Base max_base = numext::mini(testing_threshold, max_abs_bases);
+ Base min_base = NumTraits<Base>::IsSigned ? -max_base : Base(0);
+
+ ArrayX<Base> x(array_size), y(array_size);
+ bool all_pass = true;
+ for (Base base = min_base; base <= max_base; base++) {
+ if (exponent < 0 && base == 0) continue;
+ x.setConstant(base);
+ y = x.pow(exponent);
EIGEN_USING_STD(pow);
-
- const Base max_abs_bases = 10000;
- // avoid integer overflow in Base type
- Base threshold = calc_overflow_threshold<Base, Exponent>(numext::abs(exponent));
- // avoid numbers that can't be verified with std::pow
- double double_threshold = calc_overflow_threshold<double, Exponent>(numext::abs(exponent));
- // use the lesser of these two thresholds
- Base testing_threshold = threshold < double_threshold ? threshold : static_cast<Base>(double_threshold);
- // test both vectorized and non-vectorized code paths
- const Index array_size = 2 * internal::packet_traits<Base>::size + 1;
-
- Base max_base = numext::mini(testing_threshold, max_abs_bases);
- Base min_base = NumTraits<Base>::IsSigned ? -max_base : 0;
-
- ArrayX<Base> x(array_size), y(array_size);
-
- bool all_pass = true;
-
- for (Base base = min_base; base <= max_base; base++) {
- if (exponent < 0 && base == 0) continue;
- x.setConstant(base);
- y = x.pow(exponent);
- Base e = pow(base, exponent);
- for (Base a : y) {
- bool pass = a == e;
- all_pass &= pass;
- if (!pass) {
- std::cout << "pow(" << base << "," << exponent << ") = " << a << " != " << e << std::endl;
- }
- }
+ Base e = pow(base, static_cast<Base>(exponent));
+ for (Base a : y) {
+ bool pass = (a == e);
+ if (!NumTraits<Base>::IsInteger) {
+ pass = pass || (((numext::isfinite)(e) && internal::isApprox(a, e)) ||
+ ((numext::isnan)(a) && (numext::isnan)(e)));
+ }
+ all_pass &= pass;
+ if (!pass) {
+ std::cout << "pow(" << base << "," << exponent << ") = " << a << " != " << e << std::endl;
+ }
}
-
- VERIFY(all_pass);
+ }
+ VERIFY(all_pass);
}
-template <typename Base, typename Exponent>
-void int_pow_test() {
- Exponent max_exponent = NumTraits<Base>::digits();
- Exponent min_exponent = NumTraits<Exponent>::IsSigned ? -max_exponent : 0;
- for (Exponent exponent = min_exponent; exponent < max_exponent; exponent++) {
- test_exponent<Base, Exponent>(exponent);
- }
+// Runs test_exponent<Base, Exponent> for every exponent in [min_exponent, max_exponent).
+template <typename Base, typename Exponent>
+void unary_pow_test() {
+ Exponent max_exponent = static_cast<Exponent>(NumTraits<Base>::digits());
+ Exponent min_exponent = static_cast<Exponent>(NumTraits<Exponent>::IsSigned ? -max_exponent : 0);
+ for (Exponent exponent = min_exponent; exponent < max_exponent; ++exponent) {
+ test_exponent<Base, Exponent>(exponent);
+ }
+}
+
+// Checks pow() for floating-point bases whose exponent scalar type differs from the base type.
+void mixed_pow_test() {
+ // The exponent type is smaller than the base type and is promoted to it.
+ unary_pow_test<double, int>();
+ unary_pow_test<double, float>();
+ unary_pow_test<float, half>();
+ unary_pow_test<double, half>();
+ unary_pow_test<float, bfloat16>();
+ unary_pow_test<double, bfloat16>();
+
+ // Although in the following cases the exponent cannot always be represented
+ // exactly in the base type, we do not perform a conversion, but implement
+ // the operation using repeated squaring.
+ unary_pow_test<float, int>();
+ unary_pow_test<double, long long>();
+
+ // The following case converts a wider exponent type to a narrower base
+ // type (a narrowing conversion, not a promotion). This should compile but
+ // generate a deprecation warning:
+ unary_pow_test<float, double>();
+}
+
+// Integer-base pow() coverage: same-type pairs first, then mixed base/exponent types.
+void int_pow_test() {
+ unary_pow_test<int, int>();
+ unary_pow_test<unsigned int, unsigned int>();
+ unary_pow_test<long long, long long>();
+ unary_pow_test<unsigned long long, unsigned long long>();
+
+ // In the following cases the exponent type differs from the base type and
+ // may not be exactly representable in it; we do not perform a conversion,
+ // but implement the operation using repeated squaring.
+ unary_pow_test<long long, int>();
+ unary_pow_test<int, unsigned int>();
+ unary_pow_test<unsigned int, int>();
+ unary_pow_test<long long, unsigned long long>();
+ unary_pow_test<unsigned long long, long long>();
+}
template<typename ArrayType> void array(const ArrayType& m)
@@ -207,7 +248,7 @@
// Here we cap the size of the values in m1 such that pow(3)/cube()
// doesn't overflow and result in undefined behavior. Notice that because
// pow(int, int) promotes its inputs and output to double (according to
- // the C++ standard), we hvae to make sure that the result fits in 53 bits
+ // the C++ standard), we have to make sure that the result fits in 53 bits
// for int64,
RealScalar max_val =
numext::mini(RealScalar(std::cbrt(NumTraits<RealScalar>::highest())),
@@ -565,14 +606,6 @@
VERIFY_IS_APPROX(m3.pow(RealScalar(-2)), m3.square().inverse());
pow_test<Scalar>();
- typedef typename internal::make_integer<Scalar>::type SignedInt;
- typedef typename std::make_unsigned<SignedInt>::type UnsignedInt;
-
- int_pow_test<SignedInt, SignedInt>();
- int_pow_test<SignedInt, UnsignedInt>();
- int_pow_test<UnsignedInt, SignedInt>();
- int_pow_test<UnsignedInt, UnsignedInt>();
-
VERIFY_IS_APPROX(log10(m3), log(m3)/numext::log(Scalar(10)));
VERIFY_IS_APPROX(log2(m3), log(m3)/numext::log(Scalar(2)));
@@ -590,6 +623,7 @@
VERIFY_IS_APPROX(m3, m1);
}
+
template<typename ArrayType> void array_complex(const ArrayType& m)
{
typedef typename ArrayType::Scalar Scalar;
@@ -823,6 +857,11 @@
CALL_SUBTEST_4( array_complex(ArrayXXcf(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
}
+ for(int i = 0; i < g_repeat; i++) {
+ CALL_SUBTEST_6( int_pow_test() );
+ CALL_SUBTEST_7( mixed_pow_test() );
+ }
+
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<int>::type, int >::value));
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<float>::type, float >::value));
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<Array2i>::type, ArrayBase<Array2i> >::value));