[libc++] Enable segmented iterator optimizations for join_view::iterator

Reviewed By: ldionne, #libc

Spies: libcxx-commits

Differential Revision: https://reviews.llvm.org/D138413

NOKEYCHECK=True
GitOrigin-RevId: 21f4232dd963c449231f03a90836071202fd134a
diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt
index 74d1f96..44e129a 100644
--- a/include/CMakeLists.txt
+++ b/include/CMakeLists.txt
@@ -392,6 +392,7 @@
   __iterator/iter_swap.h
   __iterator/iterator.h
   __iterator/iterator_traits.h
+  __iterator/iterator_with_data.h
   __iterator/mergeable.h
   __iterator/move_iterator.h
   __iterator/move_sentinel.h
diff --git a/include/__iterator/iterator_with_data.h b/include/__iterator/iterator_with_data.h
new file mode 100644
index 0000000..06c2fa6
--- /dev/null
+++ b/include/__iterator/iterator_with_data.h
@@ -0,0 +1,114 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef _LIBCPP___ITERATOR_ITERATOR_WITH_DATA_H
+#define _LIBCPP___ITERATOR_ITERATOR_WITH_DATA_H
+
+#include <__compare/compare_three_way_result.h>
+#include <__compare/three_way_comparable.h>
+#include <__config>
+#include <__iterator/concepts.h>
+#include <__iterator/incrementable_traits.h>
+#include <__iterator/iter_move.h>
+#include <__iterator/iter_swap.h>
+#include <__iterator/iterator_traits.h>
+#include <__iterator/readable_traits.h>
+#include <__utility/move.h>
+
+#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
+#  pragma GCC system_header
+#endif
+
+#if _LIBCPP_STD_VER >= 20
+
+_LIBCPP_BEGIN_NAMESPACE_STD
+
+// An iterator adaptor that wraps a forward iterator of type _Iterator together with
+// an extra value of type _Data. Traversal, dereference, equality, iter_move and
+// iter_swap are all forwarded to the wrapped iterator; the extra data is only carried
+// along and never participates in those operations.
+//
+// join_view's __segmented_iterator_traits uses it as the segment iterator: the wrapped
+// iterator is the outer iterator and the data is a pointer to the parent join_view,
+// which __compose needs in order to rebuild a join_view iterator.
+template <forward_iterator _Iterator, class _Data>
+class __iterator_with_data {
+  _Iterator __iter_{};
+  _Data __data_{};
+
+public:
+  using value_type      = iter_value_t<_Iterator>;
+  using difference_type = iter_difference_t<_Iterator>;
+
+  _LIBCPP_HIDE_FROM_ABI __iterator_with_data() = default;
+
+  constexpr _LIBCPP_HIDE_FROM_ABI __iterator_with_data(_Iterator __iter, _Data __data)
+      : __iter_(std::move(__iter)), __data_(std::move(__data)) {}
+
+  // Returns a copy of the wrapped iterator.
+  constexpr _LIBCPP_HIDE_FROM_ABI _Iterator __get_iter() const { return __iter_; }
+
+  // Moves the carried data out of the object; callable on rvalues only.
+  constexpr _LIBCPP_HIDE_FROM_ABI _Data __get_data() && { return std::move(__data_); }
+
+  // Only the wrapped iterators are compared; the carried data is ignored.
+  friend constexpr _LIBCPP_HIDE_FROM_ABI bool
+  operator==(const __iterator_with_data& __lhs, const __iterator_with_data& __rhs) {
+    return __lhs.__iter_ == __rhs.__iter_;
+  }
+
+  constexpr _LIBCPP_HIDE_FROM_ABI __iterator_with_data& operator++() {
+    ++__iter_;
+    return *this;
+  }
+
+  constexpr _LIBCPP_HIDE_FROM_ABI __iterator_with_data operator++(int) {
+    auto __tmp = *this;
+    __iter_++;
+    return __tmp;
+  }
+
+  constexpr _LIBCPP_HIDE_FROM_ABI __iterator_with_data& operator--()
+    requires bidirectional_iterator<_Iterator>
+  {
+    --__iter_;
+    return *this;
+  }
+
+  constexpr _LIBCPP_HIDE_FROM_ABI __iterator_with_data operator--(int)
+    requires bidirectional_iterator<_Iterator>
+  {
+    auto __tmp = *this;
+    --__iter_;
+    return __tmp;
+  }
+
+  constexpr _LIBCPP_HIDE_FROM_ABI iter_reference_t<_Iterator> operator*() const { return *__iter_; }
+
+  // Moves from the element the wrapped iterator refers to.
+  _LIBCPP_HIDE_FROM_ABI friend constexpr iter_rvalue_reference_t<_Iterator>
+  iter_move(const __iterator_with_data& __iter) noexcept(noexcept(ranges::iter_move(__iter.__iter_))) {
+    return ranges::iter_move(__iter.__iter_);
+  }
+
+  // Swaps the elements referred to by the two wrapped iterators. The carried data is
+  // not an iterator: it is neither swapped nor passed to ranges::iter_swap.
+  _LIBCPP_HIDE_FROM_ABI friend constexpr void
+  iter_swap(const __iterator_with_data& __lhs,
+            const __iterator_with_data& __rhs) noexcept(noexcept(ranges::iter_swap(__lhs.__iter_, __rhs.__iter_)))
+    requires indirectly_swappable<_Iterator>
+  {
+    ranges::iter_swap(__lhs.__iter_, __rhs.__iter_);
+  }
+};
+
+_LIBCPP_END_NAMESPACE_STD
+
+#endif // _LIBCPP_STD_VER >= 20
+
+#endif // _LIBCPP___ITERATOR_ITERATOR_WITH_DATA_H
diff --git a/include/__ranges/join_view.h b/include/__ranges/join_view.h
index 293926c..869540f 100644
--- a/include/__ranges/join_view.h
+++ b/include/__ranges/join_view.h
@@ -20,9 +20,12 @@
 #include <__iterator/iter_move.h>
 #include <__iterator/iter_swap.h>
 #include <__iterator/iterator_traits.h>
+#include <__iterator/iterator_with_data.h>
+#include <__iterator/segmented_iterator.h>
 #include <__ranges/access.h>
 #include <__ranges/all.h>
 #include <__ranges/concepts.h>
+#include <__ranges/empty.h>
 #include <__ranges/non_propagating_cache.h>
 #include <__ranges/range_adaptor.h>
 #include <__ranges/view_interface.h>
@@ -63,6 +66,14 @@
     >;
   };
 
+  template <input_range _View, bool _Const>
+    requires view<_View> && input_range<range_reference_t<_View>>
+  struct __join_view_iterator;
+
+  template <input_range _View, bool _Const>
+    requires view<_View> && input_range<range_reference_t<_View>>
+  struct __join_view_sentinel;
+
   template<input_range _View>
     requires view<_View> && input_range<range_reference_t<_View>>
   class join_view
@@ -70,8 +81,22 @@
   private:
     using _InnerRange = range_reference_t<_View>;
 
-    template<bool> struct __iterator;
-    template<bool> struct __sentinel;
+    template<bool _Const>
+    using __iterator = __join_view_iterator<_View, _Const>;
+
+    template<bool _Const>
+    using __sentinel = __join_view_sentinel<_View, _Const>;
+
+    template <input_range _View2, bool _Const2>
+      requires view<_View2> && input_range<range_reference_t<_View2>>
+    friend struct __join_view_iterator;
+
+    template <input_range _View2, bool _Const2>
+      requires view<_View2> && input_range<range_reference_t<_View2>>
+    friend struct __join_view_sentinel;
+
+    template <class>
+    friend struct std::__segmented_iterator_traits;
 
     static constexpr bool _UseCache = !is_reference_v<_InnerRange>;
     using _Cache = _If<_UseCache, __non_propagating_cache<remove_cvref_t<_InnerRange>>, __empty_cache>;
@@ -139,49 +164,57 @@
     }
   };
 
-  template<input_range _View>
+  template<input_range _View, bool _Const>
     requires view<_View> && input_range<range_reference_t<_View>>
-  template<bool _Const> struct join_view<_View>::__sentinel {
-    template<bool> friend struct __sentinel;
+  struct __join_view_sentinel {
+    template<input_range _View2, bool>
+      requires view<_View2> && input_range<range_reference_t<_View2>>
+    friend struct __join_view_sentinel;
 
   private:
-    using _Parent = __maybe_const<_Const, join_view>;
+    using _Parent = __maybe_const<_Const, join_view<_View>>;
     using _Base = __maybe_const<_Const, _View>;
     sentinel_t<_Base> __end_ = sentinel_t<_Base>();
 
   public:
     _LIBCPP_HIDE_FROM_ABI
-    __sentinel() = default;
+    __join_view_sentinel() = default;
 
     _LIBCPP_HIDE_FROM_ABI
-    constexpr explicit __sentinel(_Parent& __parent)
+    constexpr explicit __join_view_sentinel(_Parent& __parent)
       : __end_(ranges::end(__parent.__base_)) {}
 
     _LIBCPP_HIDE_FROM_ABI
-    constexpr __sentinel(__sentinel<!_Const> __s)
+    constexpr __join_view_sentinel(__join_view_sentinel<_View, !_Const> __s)
       requires _Const && convertible_to<sentinel_t<_View>, sentinel_t<_Base>>
       : __end_(std::move(__s.__end_)) {}
 
     template<bool _OtherConst>
       requires sentinel_for<sentinel_t<_Base>, iterator_t<__maybe_const<_OtherConst, _View>>>
     _LIBCPP_HIDE_FROM_ABI
-    friend constexpr bool operator==(const __iterator<_OtherConst>& __x, const __sentinel& __y) {
+    friend constexpr bool operator==(const __join_view_iterator<_View, _OtherConst>& __x, const __join_view_sentinel& __y) {
       return __x.__outer_ == __y.__end_;
     }
   };
 
-  template<input_range _View>
+  template<input_range _View, bool _Const>
     requires view<_View> && input_range<range_reference_t<_View>>
-  template<bool _Const> struct join_view<_View>::__iterator
+  struct __join_view_iterator
     : public __join_view_iterator_category<__maybe_const<_Const, _View>> {
 
-    template<bool> friend struct __iterator;
+    template<input_range _View2, bool>
+      requires view<_View2> && input_range<range_reference_t<_View2>>
+    friend struct __join_view_iterator;
+
+    template <class>
+    friend struct std::__segmented_iterator_traits;
 
   private:
-    using _Parent = __maybe_const<_Const, join_view>;
+    using _Parent = __maybe_const<_Const, join_view<_View>>;
     using _Base = __maybe_const<_Const, _View>;
     using _Outer = iterator_t<_Base>;
     using _Inner = iterator_t<range_reference_t<_Base>>;
+    using _InnerRange = range_reference_t<_View>;
 
     static constexpr bool __ref_is_glvalue = is_reference_v<range_reference_t<_Base>>;
 
@@ -210,6 +243,9 @@
         __inner_.reset();
     }
 
+    _LIBCPP_HIDE_FROM_ABI constexpr __join_view_iterator(_Parent* __parent, _Outer __outer, _Inner __inner)
+      : __outer_(std::move(__outer)), __inner_(std::move(__inner)), __parent_(__parent) {}
+
   public:
     using iterator_concept = _If<
       __ref_is_glvalue && bidirectional_range<_Base> && bidirectional_range<range_reference_t<_Base>> &&
@@ -228,17 +264,17 @@
       range_difference_t<_Base>, range_difference_t<range_reference_t<_Base>>>;
 
     _LIBCPP_HIDE_FROM_ABI
-    __iterator() requires default_initializable<_Outer> = default;
+    __join_view_iterator() requires default_initializable<_Outer> = default;
 
     _LIBCPP_HIDE_FROM_ABI
-    constexpr __iterator(_Parent& __parent, _Outer __outer)
+    constexpr __join_view_iterator(_Parent& __parent, _Outer __outer)
       : __outer_(std::move(__outer))
       , __parent_(std::addressof(__parent)) {
       __satisfy();
     }
 
     _LIBCPP_HIDE_FROM_ABI
-    constexpr __iterator(__iterator<!_Const> __i)
+    constexpr __join_view_iterator(__join_view_iterator<_View, !_Const> __i)
       requires _Const &&
                convertible_to<iterator_t<_View>, _Outer> &&
                convertible_to<iterator_t<_InnerRange>, _Inner>
@@ -259,7 +295,7 @@
     }
 
     _LIBCPP_HIDE_FROM_ABI
-    constexpr __iterator& operator++() {
+    constexpr __join_view_iterator& operator++() {
       auto&& __inner = [&]() -> auto&& {
         if constexpr (__ref_is_glvalue)
           return *__outer_;
@@ -279,7 +315,7 @@
     }
 
     _LIBCPP_HIDE_FROM_ABI
-    constexpr __iterator operator++(int)
+    constexpr __join_view_iterator operator++(int)
       requires __ref_is_glvalue &&
                forward_range<_Base> &&
                forward_range<range_reference_t<_Base>>
@@ -290,7 +326,7 @@
     }
 
     _LIBCPP_HIDE_FROM_ABI
-    constexpr __iterator& operator--()
+    constexpr __join_view_iterator& operator--()
       requires __ref_is_glvalue &&
                bidirectional_range<_Base> &&
                bidirectional_range<range_reference_t<_Base>> &&
@@ -309,7 +345,7 @@
     }
 
     _LIBCPP_HIDE_FROM_ABI
-    constexpr __iterator operator--(int)
+    constexpr __join_view_iterator operator--(int)
       requires __ref_is_glvalue &&
                bidirectional_range<_Base> &&
                bidirectional_range<range_reference_t<_Base>> &&
@@ -321,7 +357,7 @@
     }
 
     _LIBCPP_HIDE_FROM_ABI
-    friend constexpr bool operator==(const __iterator& __x, const __iterator& __y)
+    friend constexpr bool operator==(const __join_view_iterator& __x, const __join_view_iterator& __y)
       requires __ref_is_glvalue &&
                equality_comparable<iterator_t<_Base>> &&
                equality_comparable<iterator_t<range_reference_t<_Base>>>
@@ -330,14 +366,14 @@
     }
 
     _LIBCPP_HIDE_FROM_ABI
-    friend constexpr decltype(auto) iter_move(const __iterator& __i)
+    friend constexpr decltype(auto) iter_move(const __join_view_iterator& __i)
       noexcept(noexcept(ranges::iter_move(*__i.__inner_)))
     {
       return ranges::iter_move(*__i.__inner_);
     }
 
     _LIBCPP_HIDE_FROM_ABI
-    friend constexpr void iter_swap(const __iterator& __x, const __iterator& __y)
+    friend constexpr void iter_swap(const __join_view_iterator& __x, const __join_view_iterator& __y)
       noexcept(noexcept(ranges::iter_swap(*__x.__inner_, *__y.__inner_)))
       requires indirectly_swappable<_Inner>
     {
@@ -365,6 +401,74 @@
 } // namespace views
 } // namespace ranges
 
+// Segmented-iterator traits for join_view's iterator: each element of the outer range is a
+// segment and iterators into the inner ranges are the local iterators. Only enabled when the
+// parent join_view is a common_range and both the outer and inner iterators are Cpp17
+// random-access iterators.
+template <class _View, bool _Const>
+  requires(ranges::common_range<typename ranges::__join_view_iterator<_View, _Const>::_Parent> &&
+           __is_cpp17_random_access_iterator<typename ranges::__join_view_iterator<_View, _Const>::_Outer>::value &&
+           __is_cpp17_random_access_iterator<typename ranges::__join_view_iterator<_View, _Const>::_Inner>::value)
+struct __segmented_iterator_traits<ranges::__join_view_iterator<_View, _Const>> {
+  using _JoinViewIterator = ranges::__join_view_iterator<_View, _Const>;
+
+  // A segment is an outer iterator that also carries the parent view pointer needed by __compose.
+  using __segment_iterator =
+      _LIBCPP_NODEBUG __iterator_with_data<typename _JoinViewIterator::_Outer, typename _JoinViewIterator::_Parent*>;
+  // A local iterator is an iterator into one inner range.
+  using __local_iterator = typename _JoinViewIterator::_Inner;
+
+  // TODO: Would it make sense to enable the optimization for other iterator types?
+
+  // Returns the segment (outer position) that `__iter` belongs to. An iterator without an
+  // engaged inner iterator (going by __satisfy(), how the past-the-end position is stored)
+  // is attributed to the previous outer element, so __local can report that segment's end.
+  // A value-initialized segment iterator is returned when the underlying view is empty.
+  static constexpr _LIBCPP_HIDE_FROM_ABI __segment_iterator __segment(_JoinViewIterator __iter) {
+      if (ranges::empty(__iter.__parent_->__base_))
+        return {};
+      if (!__iter.__inner_.has_value())
+        return __segment_iterator(--__iter.__outer_, __iter.__parent_);
+      return __segment_iterator(__iter.__outer_, __iter.__parent_);
+  }
+
+  // Returns the position of `__iter` inside its segment; mirrors __segment (end of the previous
+  // segment for a disengaged inner iterator, value-initialized for an empty underlying view).
+  static constexpr _LIBCPP_HIDE_FROM_ABI __local_iterator __local(_JoinViewIterator __iter) {
+      if (ranges::empty(__iter.__parent_->__base_))
+        return {};
+      if (!__iter.__inner_.has_value())
+        return ranges::end(*--__iter.__outer_);
+      return *__iter.__inner_;
+  }
+
+  // Returns an iterator to the first element of the inner range designated by `__iter`.
+  static constexpr _LIBCPP_HIDE_FROM_ABI __local_iterator __begin(__segment_iterator __iter) {
+      return ranges::begin(*__iter.__get_iter());
+  }
+
+  // Returns the past-the-end iterator of the inner range designated by `__iter`.
+  static constexpr _LIBCPP_HIDE_FROM_ABI __local_iterator __end(__segment_iterator __iter) {
+      return ranges::end(*__iter.__get_iter());
+  }
+
+  // Rebuilds a join_view iterator from a segment and a position inside it. A local iterator
+  // equal to the end of its inner range must not be stored as-is: e.g. at the end of the last
+  // segment the result would keep a dereferenceable outer iterator and never compare equal to
+  // the sentinel (which only compares outer iterators). Instead, advance the outer iterator and
+  // let the (_Parent&, _Outer) constructor run __satisfy() to find the next valid position.
+  static constexpr _LIBCPP_HIDE_FROM_ABI _JoinViewIterator
+  __compose(__segment_iterator __seg_iter, __local_iterator __local_iter) {
+      auto __outer   = __seg_iter.__get_iter();
+      auto* __parent = std::move(__seg_iter).__get_data();
+      if (__local_iter == ranges::end(*__outer)) {
+        ++__outer;
+        return _JoinViewIterator(*__parent, std::move(__outer));
+      }
+      return _JoinViewIterator(__parent, std::move(__outer), std::move(__local_iter));
+  }
+};
+
 #endif // _LIBCPP_STD_VER > 17
 
 _LIBCPP_END_NAMESPACE_STD
diff --git a/include/module.modulemap.in b/include/module.modulemap.in
index ffc8e27..ab9a213 100644
--- a/include/module.modulemap.in
+++ b/include/module.modulemap.in
@@ -996,6 +996,7 @@
       module iter_swap             { private header "__iterator/iter_swap.h" }
       module iterator              { private header "__iterator/iterator.h" }
       module iterator_traits       { private header "__iterator/iterator_traits.h" }
+      module iterator_with_data    { private header "__iterator/iterator_with_data.h" }
       module mergeable {
         private header "__iterator/mergeable.h"
         export functional.__functional.ranges_operations