Add signbit function
diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h
index b67c4ed..af773dd 100644
--- a/Eigen/src/Core/GenericPacketMath.h
+++ b/Eigen/src/Core/GenericPacketMath.h
@@ -563,13 +563,13 @@
parg(const Packet& a) { using numext::arg; return arg(a); }
-/** \internal \returns \a a logically shifted by N bits to the right */
+/** \internal \returns \a a arithmetically shifted by N bits to the right */
template<int N> EIGEN_DEVICE_FUNC inline int
parithmetic_shift_right(const int& a) { return a >> N; }
template<int N> EIGEN_DEVICE_FUNC inline long int
parithmetic_shift_right(const long int& a) { return a >> N; }
-/** \internal \returns \a a arithmetically shifted by N bits to the right */
+/** \internal \returns \a a logically shifted by N bits to the right */
template<int N> EIGEN_DEVICE_FUNC inline int
plogical_shift_right(const int& a) { return static_cast<int>(static_cast<unsigned int>(a) >> N); }
template<int N> EIGEN_DEVICE_FUNC inline long int
@@ -1191,6 +1191,34 @@
return preciprocal<Packet>(psqrt(a));
}
+template <typename Packet, bool IsScalar = is_scalar<Packet>::value,
+ bool IsInteger = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>
+ struct psignbit_impl;
+template <typename Packet, bool IsInteger>
+struct psignbit_impl<Packet, true, IsInteger> {
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Packet run(const Packet& a) { return numext::signbit(a); }
+};
+template <typename Packet>
+struct psignbit_impl<Packet, false, false> {
+ // generic implementation if not specialized in PacketMath.h
+ // slower than arithmetic shift
+ typedef typename unpacket_traits<Packet>::type Scalar;
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static Packet run(const Packet& a) {
+ const Packet cst_pos_one = pset1<Packet>(Scalar(1));
+ const Packet cst_neg_one = pset1<Packet>(Scalar(-1));
+ return pcmp_eq(por(pand(a, cst_neg_one), cst_pos_one), cst_neg_one);
+ }
+};
+template <typename Packet>
+struct psignbit_impl<Packet, false, true> {
+ // generic implementation for integer packets
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Packet run(const Packet& a) { return pcmp_lt(a, pzero(a)); }
+};
+/** \internal \returns the sign bit of \a a as a bitmask*/
+template <typename Packet>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE constexpr Packet
+psignbit(const Packet& a) { return psignbit_impl<Packet>::run(a); }
+
} // end namespace internal
} // end namespace Eigen
diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h
index 0eee333..b194353 100644
--- a/Eigen/src/Core/MathFunctions.h
+++ b/Eigen/src/Core/MathFunctions.h
@@ -1531,6 +1531,37 @@
}
#endif
+template <typename Scalar, bool IsInteger = NumTraits<Scalar>::IsInteger, bool IsSigned = NumTraits<Scalar>::IsSigned>
+struct signbit_impl;
+template <typename Scalar>
+struct signbit_impl<Scalar, false, true> {
+ static constexpr size_t Size = sizeof(Scalar);
+ static constexpr size_t Shift = (CHAR_BIT * Size) - 1;
+ using intSize_t = typename get_integer_by_size<Size>::signed_type;
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static Scalar run(const Scalar& x) {
+ intSize_t a = bit_cast<intSize_t, Scalar>(x);
+ a = a >> Shift;
+ Scalar result = bit_cast<Scalar, intSize_t>(a);
+ return result;
+ }
+};
+template <typename Scalar>
+struct signbit_impl<Scalar, true, true> {
+ static constexpr size_t Size = sizeof(Scalar);
+ static constexpr size_t Shift = (CHAR_BIT * Size) - 1;
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Scalar run(const Scalar& x) { return x >> Shift; }
+};
+template <typename Scalar>
+struct signbit_impl<Scalar, true, false> {
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Scalar run(const Scalar& ) {
+ return Scalar(0);
+ }
+};
+template <typename Scalar>
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Scalar signbit(const Scalar& x) {
+ return signbit_impl<Scalar>::run(x);
+}
+
template<typename T>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T exp(const T &x) {
diff --git a/Eigen/src/Core/NumTraits.h b/Eigen/src/Core/NumTraits.h
index 4f1f992..53362ef 100644
--- a/Eigen/src/Core/NumTraits.h
+++ b/Eigen/src/Core/NumTraits.h
@@ -95,7 +95,7 @@
// Load src into registers first. This allows the memcpy to be elided by CUDA.
const Src staged = src;
EIGEN_USING_STD(memcpy)
- memcpy(&tgt, &staged, sizeof(Tgt));
+ memcpy(static_cast<void*>(&tgt),static_cast<const void*>(&staged), sizeof(Tgt));
return tgt;
}
} // namespace numext
diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index ecbb73c..33a4dee 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -229,10 +229,7 @@
Vectorizable = 1,
AlignedOnScalar = 1,
HasCmp = 1,
- size=4,
-
- // requires AVX512
- HasShift = 0,
+ size=4
};
};
#endif
@@ -360,6 +357,35 @@
EIGEN_STRONG_INLINE Packet4l plogical_shift_left(Packet4l a) {
return _mm256_slli_epi64(a, N);
}
+#ifdef EIGEN_VECTORIZE_AVX512FP16
+template <int N>
+EIGEN_STRONG_INLINE Packet4l parithmetic_shift_right(Packet4l a) { return _mm256_srai_epi64(a, N); }
+#else
+template <int N>
+EIGEN_STRONG_INLINE std::enable_if_t< (N == 0), Packet4l> parithmetic_shift_right(Packet4l a) {
+ return a;
+}
+template <int N>
+EIGEN_STRONG_INLINE std::enable_if_t< (N > 0) && (N < 32), Packet4l> parithmetic_shift_right(Packet4l a) {
+ __m256i hi_word = _mm256_srai_epi32(a, N);
+ __m256i lo_word = _mm256_srli_epi64(a, N);
+ return _mm256_blend_epi32(hi_word, lo_word, 0b01010101);
+}
+template <int N>
+EIGEN_STRONG_INLINE std::enable_if_t< (N >= 32) && (N < 63), Packet4l> parithmetic_shift_right(Packet4l a) {
+ __m256i hi_word = _mm256_srai_epi32(a, 31);
+ __m256i lo_word = _mm256_shuffle_epi32(_mm256_srai_epi32(a, N - 32), (shuffle_mask<1, 1, 3, 3>::mask));
+ return _mm256_blend_epi32(hi_word, lo_word, 0b01010101);
+}
+template <int N>
+EIGEN_STRONG_INLINE std::enable_if_t< (N == 63), Packet4l> parithmetic_shift_right(Packet4l a) {
+ return _mm256_shuffle_epi32(_mm256_srai_epi32(a, 31), (shuffle_mask<1, 1, 3, 3>::mask));
+}
+template <int N>
+EIGEN_STRONG_INLINE std::enable_if_t< (N < 0) || (N > 63), Packet4l> parithmetic_shift_right(Packet4l a) {
+ return parithmetic_shift_right<int(N&63)>(a);
+}
+#endif
template <>
EIGEN_STRONG_INLINE Packet4l pload<Packet4l>(const int64_t* from) {
EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
@@ -1103,6 +1129,11 @@
#endif
}
+template<> EIGEN_STRONG_INLINE Packet8h psignbit(const Packet8h& a) { return _mm_srai_epi16(a, 15); }
+template<> EIGEN_STRONG_INLINE Packet8bf psignbit(const Packet8bf& a) { return _mm_srai_epi16(a, 15); }
+template<> EIGEN_STRONG_INLINE Packet8f psignbit(const Packet8f& a) { return _mm256_castsi256_ps(parithmetic_shift_right<31>((Packet8i)_mm256_castps_si256(a))); }
+template<> EIGEN_STRONG_INLINE Packet4d psignbit(const Packet4d& a) { return _mm256_castsi256_pd(parithmetic_shift_right<63>((Packet4l)_mm256_castpd_si256(a))); }
+
template<> EIGEN_STRONG_INLINE Packet8f pfrexp<Packet8f>(const Packet8f& a, Packet8f& exponent) {
return pfrexp_generic(a,exponent);
}
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h
index 5f37740..c210f2f 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMath.h
@@ -1127,6 +1127,11 @@
return _mm512_abs_epi32(a);
}
+template<> EIGEN_STRONG_INLINE Packet16h psignbit(const Packet16h& a) { return _mm256_srai_epi16(a, 15); }
+template<> EIGEN_STRONG_INLINE Packet16bf psignbit(const Packet16bf& a) { return _mm256_srai_epi16(a, 15); }
+template<> EIGEN_STRONG_INLINE Packet16f psignbit(const Packet16f& a) { return _mm512_castsi512_ps(_mm512_srai_epi32(_mm512_castps_si512(a), 31)); }
+template<> EIGEN_STRONG_INLINE Packet8d psignbit(const Packet8d& a) { return _mm512_castsi512_pd(_mm512_srai_epi64(_mm512_castpd_si512(a), 63)); }
+
template<>
EIGEN_STRONG_INLINE Packet16f pfrexp<Packet16f>(const Packet16f& a, Packet16f& exponent){
return pfrexp_generic(a, exponent);
diff --git a/Eigen/src/Core/arch/AVX512/PacketMathFP16.h b/Eigen/src/Core/arch/AVX512/PacketMathFP16.h
index 58621d9..13f285e 100644
--- a/Eigen/src/Core/arch/AVX512/PacketMathFP16.h
+++ b/Eigen/src/Core/arch/AVX512/PacketMathFP16.h
@@ -196,6 +196,13 @@
return _mm512_abs_ph(a);
}
+// psignbit
+
+template <>
+EIGEN_STRONG_INLINE Packet32h psignbit<Packet32h>(const Packet32h& a) {
+ return _mm512_castsi512_ph(_mm512_srai_epi16(_mm512_castph_si512(a), 15));
+}
+
// pmin
template <>
diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h
index d9ddb5e..d30ead4 100644
--- a/Eigen/src/Core/arch/AltiVec/PacketMath.h
+++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h
@@ -1575,6 +1575,9 @@
return pand<Packet8us>(p8us_abs_mask, a);
}
+template<> EIGEN_STRONG_INLINE Packet8bf psignbit(const Packet8bf& a) { return vec_sra(a.m_val, vec_splat_u16(15)); }
+template<> EIGEN_STRONG_INLINE Packet4f psignbit(const Packet4f& a) { return (Packet4f)vec_sra((Packet4i)a, vec_splats(uint32_t(31))); }
+
template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(const Packet4i& a)
{ return vec_sra(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right(const Packet4i& a)
@@ -2928,7 +2931,7 @@
return vec_sld(a, a, 8);
}
template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vec_abs(a); }
-
+template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a) { return (Packet2d)vec_sra((Packet2l)a, vec_splats(uint64_t(63))); }
// VSX support varies between different compilers and even different
// versions of the same compiler. For gcc version >= 4.9.3, we can use
// vec_cts to efficiently convert Packet2d to Packet2l. Otherwise, use
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 5cbf4ac..067b725 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -2372,6 +2372,12 @@
}
template<> EIGEN_STRONG_INLINE Packet2ul pabs(const Packet2ul& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4h psignbit(const Packet4h& a) { vreinterpret_f16_s16( vshr_n_s16( vreinterpret_s16_f16(a), 15)); }
+template<> EIGEN_STRONG_INLINE Packet8h psignbit(const Packet8h& a) { vreinterpretq_f16_s16(vshrq_n_s16(vreinterpretq_s16_f16(a), 15)); }
+template<> EIGEN_STRONG_INLINE Packet2f psignbit(const Packet2f& a) { vreinterpret_f32_s32( vshr_n_s32( vreinterpret_s32_f32(a), 31)); }
+template<> EIGEN_STRONG_INLINE Packet4f psignbit(const Packet4f& a) { vreinterpretq_f32_s32(vshrq_n_s32(vreinterpretq_s32_f32(a), 31)); }
+template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a) { vreinterpretq_f64_s64(vshrq_n_s64(vreinterpretq_s64_f64(a), 63)); }
+
template<> EIGEN_STRONG_INLINE Packet2f pfrexp<Packet2f>(const Packet2f& a, Packet2f& exponent)
{ return pfrexp_generic(a,exponent); }
template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent)
diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h
index 847ff07..a0ff359 100644
--- a/Eigen/src/Core/arch/SSE/PacketMath.h
+++ b/Eigen/src/Core/arch/SSE/PacketMath.h
@@ -649,6 +649,17 @@
#endif
}
+template<> EIGEN_STRONG_INLINE Packet4f psignbit(const Packet4f& a) { return _mm_castsi128_ps(_mm_srai_epi32(_mm_castps_si128(a), 31)); }
+template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a)
+{
+ Packet4f tmp = psignbit<Packet4f>(_mm_castpd_ps(a));
+#ifdef EIGEN_VECTORIZE_AVX
+ return _mm_castps_pd(_mm_permute_ps(tmp, (shuffle_mask<1, 1, 3, 3>::mask)));
+#else
+ return _mm_castps_pd(_mm_shuffle_ps(tmp, tmp, (shuffle_mask<1, 1, 3, 3>::mask)));
+#endif // EIGEN_VECTORIZE_AVX
+}
+
#ifdef EIGEN_VECTORIZE_SSE4_1
template<> EIGEN_STRONG_INLINE Packet4f pround<Packet4f>(const Packet4f& a)
{
diff --git a/Eigen/src/Core/util/Meta.h b/Eigen/src/Core/util/Meta.h
index 32152ac..6c6fb71 100644
--- a/Eigen/src/Core/util/Meta.h
+++ b/Eigen/src/Core/util/Meta.h
@@ -43,6 +43,32 @@
typedef std::int32_t int32_t;
typedef std::uint64_t uint64_t;
typedef std::int64_t int64_t;
+
+template <size_t Size>
+struct get_integer_by_size {
+ typedef void signed_type;
+ typedef void unsigned_type;
+};
+template <>
+struct get_integer_by_size<1> {
+ typedef int8_t signed_type;
+ typedef uint8_t unsigned_type;
+};
+template <>
+struct get_integer_by_size<2> {
+ typedef int16_t signed_type;
+ typedef uint16_t unsigned_type;
+};
+template <>
+struct get_integer_by_size<4> {
+ typedef int32_t signed_type;
+ typedef uint32_t unsigned_type;
+};
+template <>
+struct get_integer_by_size<8> {
+ typedef int64_t signed_type;
+ typedef uint64_t unsigned_type;
+};
}
}
diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp
index a7e0ff4..94c9451 100644
--- a/test/array_cwise.cpp
+++ b/test/array_cwise.cpp
@@ -219,7 +219,7 @@
for (Exponent exponent = min_exponent; exponent < max_exponent; ++exponent) {
test_exponent<Base, Exponent>(exponent);
}
-};
+}
void mixed_pow_test() {
// The following cases will test promoting a smaller exponent type
@@ -260,6 +260,81 @@
unary_pow_test<long long, int>();
}
+namespace Eigen {
+namespace internal {
+template <typename Scalar>
+struct test_signbit_op {
+ Scalar constexpr operator()(const Scalar& a) const { return numext::signbit(a); }
+ template <typename Packet>
+ inline Packet packetOp(const Packet& a) const {
+ return psignbit(a);
+ }
+};
+template <typename Scalar>
+struct functor_traits<test_signbit_op<Scalar>> {
+ enum { Cost = 1, PacketAccess = true }; //todo: define HasSignbit flag
+};
+} // namespace internal
+} // namespace Eigen
+
+template <typename T, bool IsInteger = NumTraits<T>::IsInteger>
+struct ref_signbit_func_impl {
+ static bool run(const T& x) { return std::signbit(x); }
+};
+template <typename T>
+struct ref_signbit_func_impl<T, true> {
+ // MSVC (perhaps others) does not have a std::signbit overload for integers
+ static bool run(const T& x) { return x < T(0); }
+};
+template <typename T>
+bool ref_signbit_func(const T& x) {
+ return ref_signbit_func_impl<T>::run(x);
+}
+
+template <typename Scalar>
+void signbit_test() {
+ Scalar true_mask;
+ std::memset(static_cast<void*>(&true_mask), 0xff, sizeof(Scalar));
+ Scalar false_mask;
+ std::memset(static_cast<void*>(&false_mask), 0x00, sizeof(Scalar));
+
+ const size_t size = 100 * internal::packet_traits<Scalar>::size;
+ ArrayX<Scalar> x(size), y(size);
+ x.setRandom();
+ std::vector<Scalar> special_vals = special_values<Scalar>();
+ for (size_t i = 0; i < special_vals.size(); i++) {
+ x(2 * i + 0) = special_vals[i];
+ x(2 * i + 1) = -special_vals[i];
+ }
+ y = x.unaryExpr(internal::test_signbit_op<Scalar>());
+
+ bool all_pass = true;
+ for (size_t i = 0; i < size; i++) {
+ const Scalar ref_val = ref_signbit_func(x(i)) ? true_mask : false_mask;
+ bool not_same = internal::predux_any(internal::bitwise_helper<Scalar>::bitwise_xor(ref_val, y(i)));
+ if (not_same) std::cout << "signbit(" << x(i) << ") != " << y(i) << "\n";
+ all_pass = all_pass && !not_same;
+ }
+
+ VERIFY(all_pass);
+}
+void signbit_tests() {
+ signbit_test<float>();
+ signbit_test<double>();
+ signbit_test<Eigen::half>();
+ signbit_test<Eigen::bfloat16>();
+
+ signbit_test<uint8_t>();
+ signbit_test<uint16_t>();
+ signbit_test<uint32_t>();
+ signbit_test<uint64_t>();
+
+ signbit_test<int8_t>();
+ signbit_test<int16_t>();
+ signbit_test<int32_t>();
+ signbit_test<int64_t>();
+}
+
template<typename ArrayType> void array(const ArrayType& m)
{
typedef typename ArrayType::Scalar Scalar;
@@ -855,6 +930,35 @@
VERIFY( (m2 == m1.unaryExpr(arithmetic_shift_right<9>())).all() );
}
+template <typename ArrayType>
+struct signed_shift_test_impl {
+ typedef typename ArrayType::Scalar Scalar;
+ static constexpr size_t Size = sizeof(Scalar);
+ static constexpr size_t MaxShift = (CHAR_BIT * Size) - 1;
+
+ template <size_t N = 0>
+ static inline std::enable_if_t<(N > MaxShift), void> run(const ArrayType& ) {}
+ template <size_t N = 0>
+ static inline std::enable_if_t<(N <= MaxShift), void> run(const ArrayType& m) {
+ const Index rows = m.rows();
+ const Index cols = m.cols();
+
+ ArrayType m1 = ArrayType::Random(rows, cols), m2(rows, cols);
+
+ m2 = m1.unaryExpr([](const Scalar& x) { return x >> N; });
+ VERIFY((m2 == m1.unaryExpr(internal::scalar_shift_right_op<Scalar, N>())).all());
+
+ m2 = m1.unaryExpr([](const Scalar& x) { return x << N; });
+ VERIFY((m2 == m1.unaryExpr( internal::scalar_shift_left_op<Scalar, N>())).all());
+
+ run<N + 1>(m);
+ }
+};
+template <typename ArrayType>
+void signed_shift_test(const ArrayType& m) {
+ signed_shift_test_impl<ArrayType>::run(m);
+}
+
EIGEN_DECLARE_TEST(array_cwise)
{
for(int i = 0; i < g_repeat; i++) {
@@ -867,6 +971,9 @@
CALL_SUBTEST_6( array(Array<Index,Dynamic,Dynamic>(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_6( array_integer(ArrayXXi(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_6( array_integer(Array<Index,Dynamic,Dynamic>(internal::random<int>(1,EIGEN_TEST_MAX_SIZE), internal::random<int>(1,EIGEN_TEST_MAX_SIZE))) );
+ CALL_SUBTEST_7( signed_shift_test(ArrayXXi(internal::random<int>(1, EIGEN_TEST_MAX_SIZE), internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
+ CALL_SUBTEST_7( signed_shift_test(Array<Index, Dynamic, Dynamic>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE), internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));
+
}
for(int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_1( comparisons(Array<float, 1, 1>()) );
@@ -897,6 +1004,7 @@
for(int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_6( int_pow_test() );
CALL_SUBTEST_7( mixed_pow_test() );
+ CALL_SUBTEST_8( signbit_tests() );
}
VERIFY((internal::is_same< internal::global_math_functions_filtering_base<int>::type, int >::value));
diff --git a/test/numext.cpp b/test/numext.cpp
index ee879c9..5483e5c 100644
--- a/test/numext.cpp
+++ b/test/numext.cpp
@@ -239,6 +239,58 @@
check_rsqrt_impl<T>::run();
}
+template <typename T, bool IsInteger = NumTraits<T>::IsInteger>
+struct ref_signbit_func_impl {
+ static bool run(const T& x) { return std::signbit(x); }
+};
+template <typename T>
+struct ref_signbit_func_impl<T, true> {
+ // MSVC (perhaps others) does not have a std::signbit overload for integers
+ static bool run(const T& x) { return x < T(0); }
+};
+template <typename T>
+bool ref_signbit_func(const T& x) {
+ return ref_signbit_func_impl<T>::run(x);
+}
+
+template <typename T>
+struct check_signbit_impl {
+ static void run() {
+ T true_mask;
+ std::memset(static_cast<void*>(&true_mask), 0xff, sizeof(T));
+ T false_mask;
+ std::memset(static_cast<void*>(&false_mask), 0x00, sizeof(T));
+
+ // has sign bit
+ const T neg_zero = static_cast<T>(-0.0);
+ const T neg_one = static_cast<T>(-1.0);
+ const T neg_inf = -std::numeric_limits<T>::infinity();
+ const T neg_nan = -std::numeric_limits<T>::quiet_NaN();
+ // does not have sign bit
+ const T pos_zero = static_cast<T>(0.0);
+ const T pos_one = static_cast<T>(1.0);
+ const T pos_inf = std::numeric_limits<T>::infinity();
+ const T pos_nan = std::numeric_limits<T>::quiet_NaN();
+
+ std::vector<T> values = {neg_zero, neg_one, neg_inf, neg_nan, pos_zero, pos_one, pos_inf, pos_nan};
+
+ bool all_pass = true;
+
+ for (T val : values) {
+ const T numext_val = numext::signbit(val);
+ const T ref_val = ref_signbit_func(val) ? true_mask : false_mask;
+ bool not_same = internal::predux_any(internal::bitwise_helper<T>::bitwise_xor(ref_val, numext_val));
+ all_pass = all_pass && !not_same;
+ if (not_same) std::cout << "signbit(" << val << ") != " << numext_val << "\n";
+ }
+ VERIFY(all_pass);
+ }
+};
+template <typename T>
+void check_signbit() {
+ check_signbit_impl<T>::run();
+}
+
EIGEN_DECLARE_TEST(numext) {
for(int k=0; k<g_repeat; ++k)
{
@@ -271,5 +323,20 @@
CALL_SUBTEST( check_rsqrt<double>() );
CALL_SUBTEST( check_rsqrt<std::complex<float> >() );
CALL_SUBTEST( check_rsqrt<std::complex<double> >() );
+
+ CALL_SUBTEST( check_signbit<half>());
+ CALL_SUBTEST( check_signbit<bfloat16>());
+ CALL_SUBTEST( check_signbit<float>());
+ CALL_SUBTEST( check_signbit<double>());
+
+ CALL_SUBTEST( check_signbit<uint8_t>());
+ CALL_SUBTEST( check_signbit<uint16_t>());
+ CALL_SUBTEST( check_signbit<uint32_t>());
+ CALL_SUBTEST( check_signbit<uint64_t>());
+
+ CALL_SUBTEST( check_signbit<int8_t>());
+ CALL_SUBTEST( check_signbit<int16_t>());
+ CALL_SUBTEST( check_signbit<int32_t>());
+ CALL_SUBTEST( check_signbit<int64_t>());
}
}