[libc++] [P0879] constexpr std::nth_element, and rewrite its tests.
This patch is more than just adding the `constexpr` keyword, because
the old code relied on `goto`, and `goto` is not constexpr-friendly.
Refactor to eliminate `goto`, and then mark it as constexpr in C++20.
I freely admit that the name `__nth_element_partloop` is bad;
I couldn't find any better name because I don't really know
what this loop is doing, conceptually. By contrast, I think
`__nth_element_find_guard` has a decent name.
Now the only one we're still missing from P0879 is `sort`.
Differential Revision: https://reviews.llvm.org/D93557
GitOrigin-RevId: 207d4be4d9d39fbb9aca30e5d5d11245db9bccc1
diff --git a/include/algorithm b/include/algorithm
index 61fc94b..9bc31f8 100644
--- a/include/algorithm
+++ b/include/algorithm
@@ -385,11 +385,11 @@
RandomAccessIterator result_first, RandomAccessIterator result_last, Compare comp);
template <class RandomAccessIterator>
- void
+ constexpr void // constexpr in C++20
nth_element(RandomAccessIterator first, RandomAccessIterator nth, RandomAccessIterator last);
template <class RandomAccessIterator, class Compare>
- void
+ constexpr void // constexpr in C++20
nth_element(RandomAccessIterator first, RandomAccessIterator nth, RandomAccessIterator last, Compare comp);
template <class ForwardIterator, class T>
@@ -3833,7 +3833,7 @@
// stable, 2-3 compares, 0-2 swaps
template <class _Compare, class _ForwardIterator>
-unsigned
+_LIBCPP_CONSTEXPR_AFTER_CXX11 unsigned
__sort3(_ForwardIterator __x, _ForwardIterator __y, _ForwardIterator __z, _Compare __c)
{
unsigned __r = 0;
@@ -3927,7 +3927,7 @@
// Assumes size > 0
template <class _Compare, class _BidirectionalIterator>
-void
+_LIBCPP_CONSTEXPR_AFTER_CXX11 void
__selection_sort(_BidirectionalIterator __first, _BidirectionalIterator __last, _Compare __comp)
{
_BidirectionalIterator __lm1 = __last;
@@ -5279,8 +5279,70 @@
// nth_element
+template<class _Compare, class _RandomAccessIterator>
+_LIBCPP_CONSTEXPR_AFTER_CXX11 bool
+__nth_element_find_guard(_RandomAccessIterator& __i, _RandomAccessIterator& __j,
+ _RandomAccessIterator& __m, _Compare __comp)
+{
+ // manually guard downward moving __j against __i (__j is a reference, so the caller sees where the scan stopped)
+ while (--__j != __i)
+ {
+ if (__comp(*__j, *__m))
+ {
+ return true; // __comp(*__j, *__m) holds: guard found, __j is left pointing at it
+ }
+ }
+ return false; // __j was decremented all the way down to __i without __comp(*__j, *__m) ever holding
+}
+
+template<class _Compare, class _RandomAccessIterator>
+_LIBCPP_CONSTEXPR_AFTER_CXX11 bool
+__nth_element_partloop(_RandomAccessIterator __first, _RandomAccessIterator __last,
+ _RandomAccessIterator& __i, _RandomAccessIterator& __j,
+ unsigned& __n_swaps, _Compare __comp)
+{
+ // *__first == *__m, *__m <= all other elements
+ // Partition instead into [__first, __i) == *__first and *__first < [__i, __last)
+ ++__i; // __first + 1
+ __j = __last;
+ if (!__comp(*__first, *--__j)) // we need a guard if *__first == *(__last-1)
+ {
+ while (true)
+ {
+ if (__i == __j)
+ return true; // [__first, __last) all equivalent elements
+ if (__comp(*__first, *__i))
+ {
+ swap(*__i, *__j); // move the first element greater than *__first to the back, where it guards the scans below
+ ++__n_swaps;
+ ++__i;
+ break;
+ }
+ ++__i;
+ }
+ }
+ // [__first, __i) == *__first and *__first < [__j, __last) and __j == __last - 1
+ if (__i == __j)
+ return true; // nothing between __i and __j is left to partition; the caller returns when it sees true
+ while (true)
+ {
+ while (!__comp(*__first, *__i)) // advance __i past elements not greater than *__first
+ ++__i;
+ while (__comp(*__first, *--__j)) // retreat __j past elements greater than *__first
+ ; // empty body: the pre-decrement in the condition does the work
+ if (__i >= __j)
+ break;
+ swap(*__i, *__j);
+ ++__n_swaps;
+ ++__i;
+ }
+ // [__first, __i) == *__first and *__first < [__i, __last)
+ // The first part is sorted.
+ return false; // partition done; the caller compares __nth with __i to decide whether to continue on [__i, __last)
+}
+
template <class _Compare, class _RandomAccessIterator>
-void
+_LIBCPP_CONSTEXPR_AFTER_CXX11 void
__nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last, _Compare __comp)
{
// _Compare is known to be a reference type
@@ -5288,7 +5350,7 @@
const difference_type __limit = 7;
while (true)
{
- __restart:
+ // __restart: -- this is the target of a "continue" below
if (__nth == __last)
return;
difference_type __len = __last - __first;
@@ -5328,61 +5390,19 @@
if (!__comp(*__i, *__m)) // if *__first == *__m
{
// *__first == *__m, *__first doesn't go in first part
- // manually guard downward moving __j against __i
- while (true)
- {
- if (__i == --__j)
- {
- // *__first == *__m, *__m <= all other elements
- // Partition instead into [__first, __i) == *__first and *__first < [__i, __last)
- ++__i; // __first + 1
- __j = __last;
- if (!__comp(*__first, *--__j)) // we need a guard if *__first == *(__last-1)
- {
- while (true)
- {
- if (__i == __j)
- return; // [__first, __last) all equivalent elements
- if (__comp(*__first, *__i))
- {
- swap(*__i, *__j);
- ++__n_swaps;
- ++__i;
- break;
- }
- ++__i;
- }
- }
- // [__first, __i) == *__first and *__first < [__j, __last) and __j == __last - 1
- if (__i == __j)
- return;
- while (true)
- {
- while (!__comp(*__first, *__i))
- ++__i;
- while (__comp(*__first, *--__j))
- ;
- if (__i >= __j)
- break;
- swap(*__i, *__j);
- ++__n_swaps;
- ++__i;
- }
- // [__first, __i) == *__first and *__first < [__i, __last)
- // The first part is sorted,
- if (__nth < __i)
- return;
- // __nth_element the second part
- // _VSTD::__nth_element<_Compare>(__i, __nth, __last, __comp);
- __first = __i;
- goto __restart;
- }
- if (__comp(*__j, *__m))
- {
- swap(*__i, *__j);
- ++__n_swaps;
- break; // found guard for downward moving __j, now use unguarded partition
- }
+ if (_VSTD::__nth_element_find_guard<_Compare>(__i, __j, __m, __comp)) {
+ swap(*__i, *__j);
+ ++__n_swaps;
+ // found guard for downward moving __j, now use unguarded partition
+ } else if (_VSTD::__nth_element_partloop<_Compare>(__first, __last, __i, __j, __n_swaps, __comp)) {
+ return;
+ } else if (__nth < __i) {
+ return;
+ } else {
+ // __nth_element the second part
+ // _VSTD::__nth_element<_Compare>(__i, __nth, __last, __comp);
+ __first = __i;
+ continue; // i.e., goto __restart
}
}
++__i;
@@ -5426,15 +5446,16 @@
{
// Check for [__first, __i) already sorted
__j = __m = __first;
- while (++__j != __i)
+ while (true)
{
+ if (++__j == __i)
+ // [__first, __i) sorted
+ return;
if (__comp(*__j, *__m))
// not yet sorted, so sort
- goto not_sorted;
+ break;
__m = __j;
}
- // [__first, __i) sorted
- return;
}
else
{
@@ -5442,16 +5463,16 @@
__j = __m = __i;
- while (++__j != __last)
+ while (true) // __j is advanced exactly once per iteration, by the check just below
{
+ if (++__j == __last)
+ // [__i, __last) sorted
+ return;
if (__comp(*__j, *__m))
// not yet sorted, so sort
- goto not_sorted;
+ break;
__m = __j;
}
- // [__i, __last) sorted
- return;
}
}
-not_sorted:
// __nth_element on range containing __nth
if (__nth < __i)
{
@@ -5467,7 +5488,7 @@
}
template <class _RandomAccessIterator, class _Compare>
-inline _LIBCPP_INLINE_VISIBILITY
+inline _LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_AFTER_CXX17
void
nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last, _Compare __comp)
{
@@ -5476,7 +5497,7 @@
}
template <class _RandomAccessIterator>
-inline _LIBCPP_INLINE_VISIBILITY
+inline _LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_AFTER_CXX17
void
nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last)
{