Vectorize pow for integer base / exponent types
diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp
index 3e96e4c..64c4c2a 100644
--- a/test/array_cwise.cpp
+++ b/test/array_cwise.cpp
@@ -126,6 +126,72 @@
VERIFY(all_pass);
}
+/** Upper bound on |base| for testing base^exponent in Scalar arithmetic.
+ *
+ *  Computes the root 2^(digits / min(exponent, digits)), where digits = NumTraits<Scalar>::digits().
+ *  For integer Scalar the result is the largest integer strictly below that root, so base^exponent
+ *  stays below 2^digits and cannot overflow. An exact root must be excluded: for a 32-bit unsigned
+ *  type and exponent 16 the root is exactly 4, and 4^16 = 2^32 overflows. The same holds when the
+ *  exponent is clamped to digits (root 2, and 2^digits overflows).
+ *  For floating-point Scalar the (possibly non-integral) root itself is returned.
+ *  Exponents below 2 impose no bound and yield NumTraits<Scalar>::highest().
+ */
+template <typename Scalar, typename ScalarExponent>
+Scalar calc_overflow_threshold(const ScalarExponent exponent) {
+  EIGEN_USING_STD(exp2);
+  EIGEN_USING_STD(ceil);
+  EIGEN_STATIC_ASSERT((NumTraits<Scalar>::digits() < 2 * NumTraits<double>::digits()), BASE_TYPE_IS_TOO_BIG);
+
+  if (exponent < 2) return NumTraits<Scalar>::highest();
+
+  const double max_exponent = static_cast<double>(NumTraits<Scalar>::digits());
+  const double clamped_exponent = exponent < max_exponent ? static_cast<double>(exponent) : max_exponent;
+  const double root = exp2(max_exponent / clamped_exponent);
+  // For integers 2^digits already overflows, so an exact root is not a safe base:
+  // ceil(root) - 1 is the largest integer strictly below root.
+  const double threshold = NumTraits<Scalar>::IsInteger ? ceil(root) - 1.0 : root;
+  return static_cast<Scalar>(threshold);
+}
+
+/** Checks ArrayBase::pow(exponent) against std::pow for a range of Base values.
+ *
+ *  The largest tested |base| is the minimum of three bounds:
+ *   - the bound that keeps base^|exponent| from overflowing Base;
+ *   - the bound that keeps std::pow's double result verifiable;
+ *   - a fixed cap of 10000 that keeps the test short.
+ *  Bases cover [-max_base, max_base] for signed Base and [0, max_base] otherwise.
+ *  Base 0 is skipped for negative exponents (division by zero).
+ *  Every mismatch is printed, then a single VERIFY reports the overall result.
+ */
+template <typename Base, typename Exponent>
+void test_exponent(Exponent exponent) {
+  EIGEN_USING_STD(pow);
+
+  // cap on |base| so the test stays short
+  const Base max_abs_bases = 10000;
+  const Exponent abs_exponent = numext::abs(exponent);
+  // avoid integer overflow in the Base type
+  const Base threshold = calc_overflow_threshold<Base, Exponent>(abs_exponent);
+  // avoid numbers that can't be verified with std::pow (computed in double precision)
+  const double double_threshold = calc_overflow_threshold<double, Exponent>(abs_exponent);
+  // use the lesser of these two thresholds, then apply the cap
+  const Base testing_threshold = threshold < double_threshold ? threshold : static_cast<Base>(double_threshold);
+  const Base max_base = numext::mini(testing_threshold, max_abs_bases);
+  const Base min_base = NumTraits<Base>::IsSigned ? -max_base : 0;
+  // two full packets plus one trailing element: exercises both vectorized and non-vectorized code paths
+  const Index array_size = 2 * internal::packet_traits<Base>::size + 1;
+
+  ArrayX<Base> x(array_size), y(array_size);
+
+  bool all_pass = true;
+  for (Base base = min_base; base <= max_base; base++) {
+    // 0 raised to a negative power would divide by zero
+    if (exponent < 0 && base == 0) continue;
+    x.setConstant(base);
+    y = x.pow(exponent);
+    const Base e = pow(base, exponent);
+    for (Base a : y) {
+      const bool pass = a == e;
+      all_pass &= pass;
+      if (!pass) {
+        std::cout << "pow(" << base << "," << exponent << ") = " << a << " != " << e << std::endl;
+      }
+    }
+  }
+
+  VERIFY(all_pass);
+}
+
+/** Runs test_exponent<Base, Exponent> for every exponent e with -digits <= e < digits
+ *  (0 <= e < digits when Exponent is unsigned), where digits = NumTraits<Base>::digits().
+ */
+template <typename Base, typename Exponent>
+void int_pow_test() {
+  const Exponent upper = NumTraits<Base>::digits();
+  const Exponent lower = NumTraits<Exponent>::IsSigned ? -upper : 0;
+  Exponent e = lower;
+  while (e < upper) {
+    test_exponent<Base, Exponent>(e);
+    ++e;
+  }
+}
+
template<typename ArrayType> void array(const ArrayType& m)
{
typedef typename ArrayType::Scalar Scalar;
@@ -500,6 +566,14 @@
VERIFY_IS_APPROX(m3.pow(RealScalar(-2)), m3.square().inverse());
pow_test<Scalar>();
+ typedef typename internal::make_integer<Scalar>::type SignedInt;
+ typedef typename std::make_unsigned<SignedInt>::type UnsignedInt;
+
+ int_pow_test<SignedInt, SignedInt>();
+ int_pow_test<SignedInt, UnsignedInt>();
+ int_pow_test<UnsignedInt, SignedInt>();
+ int_pow_test<UnsignedInt, UnsignedInt>();
+
VERIFY_IS_APPROX(log10(m3), log(m3)/numext::log(Scalar(10)));
VERIFY_IS_APPROX(log2(m3), log(m3)/numext::log(Scalar(2)));