10#ifndef XTENSOR_ITERATOR_HPP
11#define XTENSOR_ITERATOR_HPP
19#include <xtl/xcompare.hpp>
20#include <xtl/xiterator_base.hpp>
21#include <xtl/xmeta_utils.hpp>
22#include <xtl/xsequence.hpp>
24#include "../core/xlayout.hpp"
25#include "../core/xshape.hpp"
26#include "../utils/xexception.hpp"
27#include "../utils/xutils.hpp"
39 template <
bool is_const,
class CT>
45 struct get_stepper_iterator_impl
47 using type =
typename C::container_iterator;
51 struct get_stepper_iterator_impl<const C>
53 using type =
typename C::const_container_iterator;
57 struct get_stepper_iterator_impl<xscalar<CT>>
59 using type =
typename xscalar<CT>::dummy_iterator;
63 struct get_stepper_iterator_impl<const xscalar<CT>>
65 using type =
typename xscalar<CT>::const_dummy_iterator;
70 using get_stepper_iterator =
typename detail::get_stepper_iterator_impl<C>::type;
79 struct index_type_impl
81 using type = dynamic_shape<typename ST::value_type>;
84 template <
class V, std::
size_t L>
85 struct index_type_impl<std::array<V, L>>
87 using type = std::array<V, L>;
90 template <std::size_t... I>
91 struct index_type_impl<fixed_shape<I...>>
93 using type = std::array<std::size_t,
sizeof...(I)>;
98 using xindex_type_t =
typename detail::index_type_impl<C>::type;
109 using storage_type = C;
110 using subiterator_type = get_stepper_iterator<C>;
111 using subiterator_traits = std::iterator_traits<subiterator_type>;
112 using value_type =
typename subiterator_traits::value_type;
113 using reference =
typename subiterator_traits::reference;
114 using pointer =
typename subiterator_traits::pointer;
115 using difference_type =
typename subiterator_traits::difference_type;
116 using size_type =
typename storage_type::size_type;
117 using shape_type =
typename storage_type::shape_type;
118 using simd_value_type = xt_simd::simd_type<value_type>;
120 template <
class requested_type>
121 using simd_return_type = xt_simd::simd_return_type<value_type, requested_type>;
123 xstepper() =
default;
124 xstepper(storage_type* c, subiterator_type it, size_type offset)
noexcept;
126 reference operator*()
const;
128 void step(size_type dim, size_type n = 1);
129 void step_back(size_type dim, size_type n = 1);
130 void reset(size_type dim);
131 void reset_back(size_type dim);
137 simd_return_type<T> step_simd();
142 void store_simd(
const R& vec);
147 subiterator_type m_it;
151 template <layout_type L>
158 template <
class S,
class IT,
class ST>
159 static void increment_stepper(S& stepper, IT& index,
const ST& shape);
161 template <
class S,
class IT,
class ST>
162 static void decrement_stepper(S& stepper, IT& index,
const ST& shape);
164 template <
class S,
class IT,
class ST>
165 static void increment_stepper(S& stepper, IT& index,
const ST& shape,
typename S::size_type n);
167 template <
class S,
class IT,
class ST>
168 static void decrement_stepper(S& stepper, IT& index,
const ST& shape,
typename S::size_type n);
175 template <
class E,
bool is_const>
176 class xindexed_stepper
180 using self_type = xindexed_stepper<E, is_const>;
181 using xexpression_type = std::conditional_t<is_const, const E, E>;
183 using value_type =
typename xexpression_type::value_type;
184 using reference = std::
185 conditional_t<is_const, typename xexpression_type::const_reference, typename xexpression_type::reference>;
186 using pointer = std::
187 conditional_t<is_const, typename xexpression_type::const_pointer, typename xexpression_type::pointer>;
188 using size_type =
typename xexpression_type::size_type;
189 using difference_type =
typename xexpression_type::difference_type;
191 using shape_type =
typename xexpression_type::shape_type;
192 using index_type = xindex_type_t<shape_type>;
194 xindexed_stepper() =
default;
195 xindexed_stepper(xexpression_type* e, size_type offset,
bool end =
false)
noexcept;
197 reference operator*()
const;
199 void step(size_type dim, size_type n = 1);
200 void step_back(size_type dim, size_type n = 1);
201 void reset(size_type dim);
202 void reset_back(size_type dim);
209 xexpression_type* p_e;
217 static const bool value =
false;
220 template <
class T,
bool B>
223 static const bool value =
true;
226 template <
class T,
class R = T>
231 template <
class T,
class R = T>
234 template <
class T,
class R = T>
239 template <
class T,
class R = T>
253 using shape_type = S;
254 using param_type =
const S&;
256 shape_storage() =
default;
257 shape_storage(param_type shape);
258 const S& shape()
const;
266 class shape_storage<S*>
270 using shape_type = S;
271 using param_type =
const S*;
273 shape_storage(param_type shape = 0);
274 const S& shape()
const;
281 template <layout_type L>
282 struct LAYOUT_FORBIDEN_FOR_XITERATOR;
285 template <
class St,
class S, layout_type L>
286 class xiterator :
public xtl::xrandom_access_iterator_base<
288 typename St::value_type,
289 typename St::difference_type,
290 typename St::pointer,
291 typename St::reference>,
292 private detail::shape_storage<S>
296 using self_type = xiterator<St, S, L>;
298 using stepper_type = St;
299 using value_type =
typename stepper_type::value_type;
300 using reference =
typename stepper_type::reference;
301 using pointer =
typename stepper_type::pointer;
302 using difference_type =
typename stepper_type::difference_type;
303 using size_type =
typename stepper_type::size_type;
304 using iterator_category = std::random_access_iterator_tag;
306 using private_base = detail::shape_storage<S>;
307 using shape_type =
typename private_base::shape_type;
308 using shape_param_type =
typename private_base::param_type;
309 using index_type = xindex_type_t<shape_type>;
311 xiterator() =
default;
314 xiterator(St st, shape_param_type shape,
bool end_index);
316 self_type& operator++();
317 self_type& operator--();
319 self_type& operator+=(difference_type n);
320 self_type& operator-=(difference_type n);
322 difference_type operator-(
const self_type& rhs)
const;
324 reference operator*()
const;
325 pointer operator->()
const;
327 bool equal(
const xiterator& rhs)
const;
328 bool less_than(
const xiterator& rhs)
const;
334 difference_type m_linear_index;
336 using checking_type =
typename detail::LAYOUT_FORBIDEN_FOR_XITERATOR<L>::type;
339 template <
class St,
class S, layout_type L>
342 template <
class St,
class S, layout_type L>
345 template <
class St,
class S, layout_type L>
354 template <
class It,
class BIt>
355 class xbounded_iterator :
public xtl::xrandom_access_iterator_base<
356 xbounded_iterator<It, BIt>,
357 typename std::iterator_traits<It>::value_type,
358 typename std::iterator_traits<It>::difference_type,
359 typename std::iterator_traits<It>::pointer,
360 typename std::iterator_traits<It>::reference>
364 using self_type = xbounded_iterator<It, BIt>;
366 using subiterator_type = It;
367 using bound_iterator_type = BIt;
368 using value_type =
typename std::iterator_traits<It>::value_type;
369 using reference =
typename std::iterator_traits<It>::reference;
370 using pointer =
typename std::iterator_traits<It>::pointer;
371 using difference_type =
typename std::iterator_traits<It>::difference_type;
372 using iterator_category = std::random_access_iterator_tag;
374 xbounded_iterator() =
default;
375 xbounded_iterator(It it, BIt bound_it);
377 self_type& operator++();
378 self_type& operator--();
380 self_type& operator+=(difference_type n);
381 self_type& operator-=(difference_type n);
383 difference_type operator-(
const self_type& rhs)
const;
385 value_type operator*()
const;
387 bool equal(
const self_type& rhs)
const;
388 bool less_than(
const self_type& rhs)
const;
392 subiterator_type m_it;
393 bound_iterator_type m_bound_it;
396 template <
class It,
class BIt>
399 template <
class It,
class BIt>
408 template <
class C,
class =
void_t<>>
409 struct has_linear_iterator : std::false_type
414 struct has_linear_iterator<C, void_t<decltype(std::declval<C>().linear_cbegin())>> : std::true_type
420 constexpr auto linear_begin(C& c)
noexcept
422 if constexpr (detail::has_linear_iterator<C>::value)
424 return c.linear_begin();
433 constexpr auto linear_end(C& c)
noexcept
435 if constexpr (detail::has_linear_iterator<C>::value)
437 return c.linear_end();
446 constexpr auto linear_begin(
const C& c)
noexcept
448 if constexpr (detail::has_linear_iterator<C>::value)
450 return c.linear_cbegin();
459 constexpr auto linear_end(
const C& c)
noexcept
461 if constexpr (detail::has_linear_iterator<C>::value)
463 return c.linear_cend();
476 inline xstepper<C>::xstepper(storage_type* c, subiterator_type it, size_type offset) noexcept
484 inline auto xstepper<C>::operator*() const -> reference
490 inline void xstepper<C>::step(size_type dim, size_type n)
494 using strides_value_type =
typename std::decay_t<
decltype(p_c->strides())>::value_type;
495 m_it += difference_type(
static_cast<strides_value_type
>(n) * p_c->strides()[dim - m_offset]);
500 inline void xstepper<C>::step_back(size_type dim, size_type n)
504 using strides_value_type =
typename std::decay_t<
decltype(p_c->strides())>::value_type;
505 m_it -= difference_type(
static_cast<strides_value_type
>(n) * p_c->strides()[dim - m_offset]);
510 inline void xstepper<C>::reset(size_type dim)
514 m_it -= difference_type(p_c->backstrides()[dim - m_offset]);
519 inline void xstepper<C>::reset_back(size_type dim)
523 m_it += difference_type(p_c->backstrides()[dim - m_offset]);
528 inline void xstepper<C>::to_begin()
530 m_it = p_c->data_xbegin();
536 m_it = p_c->data_xend(l, m_offset);
542 struct step_simd_invoker
545 static R apply(
const It& it)
548 return reg.load_unaligned(&(*it));
553 template <
bool is_const,
class T,
class S, layout_type L>
554 struct step_simd_invoker<xiterator<xscalar_stepper<is_const, T>, S, L>>
557 static R apply(
const xiterator<xscalar_stepper<is_const, T>, S, L>& it)
566 inline auto xstepper<C>::step_simd() -> simd_return_type<T>
568 using simd_type = simd_return_type<T>;
569 simd_type reg = detail::step_simd_invoker<subiterator_type>::template apply<simd_type>(m_it);
570 m_it += xt_simd::revert_simd_traits<simd_type>::size;
576 inline void xstepper<C>::store_simd(
const R& vec)
578 vec.store_unaligned(&(*m_it));
579 m_it += xt_simd::revert_simd_traits<R>::size;
584 void xstepper<C>::step_leading()
590 template <
class S,
class IT,
class ST>
591 void stepper_tools<layout_type::row_major>::increment_stepper(S& stepper, IT& index,
const ST& shape)
593 using size_type =
typename S::size_type;
594 const size_type size = index.size();
599 if (index[i] != shape[i] - 1)
616 if (size != size_type(0))
627 index[size - 1] = shape[size - 1];
634 template <
class S,
class IT,
class ST>
635 void stepper_tools<layout_type::row_major>::increment_stepper(
639 typename S::size_type n
642 using size_type =
typename S::size_type;
643 const size_type size = index.size();
644 const size_type leading_i = size - 1;
646 while (i != 0 && n != 0)
649 size_type inc = (i == leading_i) ? n : 1;
650 if (xtl::cmp_less(index[i] + inc, shape[i]))
653 stepper.step(i, inc);
655 if (i != leading_i || index.size() == 1)
664 size_type off = shape[i] - index[i] - 1;
665 stepper.step(i, off);
675 if (i == 0 && n != 0)
677 if (size != size_type(0))
688 index[leading_i] = shape[leading_i];
695 template <
class S,
class IT,
class ST>
696 void stepper_tools<layout_type::row_major>::decrement_stepper(S& stepper, IT& index,
const ST& shape)
698 using size_type =
typename S::size_type;
699 size_type i = index.size();
706 stepper.step_back(i);
711 index[i] = shape[i] - 1;
714 stepper.reset_back(i);
725 template <
class S,
class IT,
class ST>
726 void stepper_tools<layout_type::row_major>::decrement_stepper(
730 typename S::size_type n
733 using size_type =
typename S::size_type;
734 size_type i = index.size();
735 size_type leading_i = index.size() - 1;
736 while (i != 0 && n != 0)
739 size_type inc = (i == leading_i) ? n : 1;
740 if (xtl::cmp_greater_equal(index[i], inc))
743 stepper.step_back(i, inc);
745 if (i != leading_i || index.size() == 1)
754 size_type off = index[i];
755 stepper.step_back(i, off);
758 index[i] = shape[i] - 1;
761 stepper.reset_back(i);
765 if (i == 0 && n != 0)
772 template <
class S,
class IT,
class ST>
773 void stepper_tools<layout_type::column_major>::increment_stepper(S& stepper, IT& index,
const ST& shape)
775 using size_type =
typename S::size_type;
776 const size_type size = index.size();
780 if (index[i] != shape[i] - 1)
798 if (size != size_type(0))
816 template <
class S,
class IT,
class ST>
817 void stepper_tools<layout_type::column_major>::increment_stepper(
821 typename S::size_type n
824 using size_type =
typename S::size_type;
825 const size_type size = index.size();
826 const size_type leading_i = 0;
828 while (i != size && n != 0)
830 size_type inc = (i == leading_i) ? n : 1;
831 if (index[i] + inc < shape[i])
834 stepper.step(i, inc);
836 if (i != leading_i || size == 1)
846 size_type off = shape[i] - index[i] - 1;
847 stepper.step(i, off);
858 if (i == size && n != 0)
860 if (size != size_type(0))
871 index[leading_i] = shape[leading_i];
878 template <
class S,
class IT,
class ST>
879 void stepper_tools<layout_type::column_major>::decrement_stepper(S& stepper, IT& index,
const ST& shape)
881 using size_type =
typename S::size_type;
882 size_type size = index.size();
889 stepper.step_back(i);
894 index[i] = shape[i] - 1;
897 stepper.reset_back(i);
909 template <
class S,
class IT,
class ST>
910 void stepper_tools<layout_type::column_major>::decrement_stepper(
914 typename S::size_type n
917 using size_type =
typename S::size_type;
918 size_type size = index.size();
920 size_type leading_i = 0;
921 while (i != size && n != 0)
923 size_type inc = (i == leading_i) ? n : 1;
927 stepper.step_back(i, inc);
929 if (i != leading_i || index.size() == 1)
939 size_type off = index[i];
940 stepper.step_back(i, off);
943 index[i] = shape[i] - 1;
946 stepper.reset_back(i);
951 if (i == size && n != 0)
961 template <
class C,
bool is_const>
962 inline xindexed_stepper<C, is_const>::xindexed_stepper(xexpression_type* e, size_type offset,
bool end) noexcept
964 , m_index(xtl::make_sequence<index_type>(e->shape().size(), size_type(0)))
970 to_end(XTENSOR_DEFAULT_TRAVERSAL);
974 template <
class C,
bool is_const>
975 inline auto xindexed_stepper<C, is_const>::operator*() const -> reference
977 return p_e->element(m_index.cbegin(), m_index.cend());
980 template <
class C,
bool is_const>
981 inline void xindexed_stepper<C, is_const>::step(size_type dim, size_type n)
985 m_index[dim - m_offset] +=
static_cast<typename index_type::value_type
>(n);
989 template <
class C,
bool is_const>
990 inline void xindexed_stepper<C, is_const>::step_back(size_type dim, size_type n)
994 m_index[dim - m_offset] -=
static_cast<typename index_type::value_type
>(n);
998 template <
class C,
bool is_const>
999 inline void xindexed_stepper<C, is_const>::reset(size_type dim)
1001 if (dim >= m_offset)
1003 m_index[dim - m_offset] = 0;
1007 template <
class C,
bool is_const>
1008 inline void xindexed_stepper<C, is_const>::reset_back(size_type dim)
1010 if (dim >= m_offset)
1012 m_index[dim - m_offset] = p_e->shape()[dim - m_offset] - 1;
1016 template <
class C,
bool is_const>
1017 inline void xindexed_stepper<C, is_const>::to_begin()
1019 std::fill(m_index.begin(), m_index.end(), size_type(0));
1022 template <
class C,
bool is_const>
1023 inline void xindexed_stepper<C, is_const>::to_end(
layout_type l)
1025 const auto& shape = p_e->shape();
1037 m_index[l_dim] = shape[l_dim];
1047 inline shape_storage<S>::shape_storage(param_type shape)
1053 inline const S& shape_storage<S>::shape()
const
1059 inline shape_storage<S*>::shape_storage(param_type shape)
1065 inline const S& shape_storage<S*>::shape()
const
1083 template <
class St,
class S, layout_type L>
1084 inline xiterator<St, S, L>::xiterator(St st, shape_param_type shape,
bool end_index)
1085 : private_base(shape)
1088 end_index ? xtl::forward_sequence<index_type, const shape_type&>(this->shape())
1089 : xtl::make_sequence<index_type>(this->shape().size(), size_type(0))
1096 if (m_index.size() != size_type(0))
1110 m_linear_index = difference_type(std::accumulate(
1111 this->shape().cbegin(),
1112 this->shape().cend(),
1114 std::multiplies<size_type>()
1119 template <
class St,
class S, layout_type L>
1120 inline auto xiterator<St, S, L>::operator++() -> self_type&
1122 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape());
1127 template <
class St,
class S, layout_type L>
1128 inline auto xiterator<St, S, L>::operator--() -> self_type&
1130 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape());
1135 template <
class St,
class S, layout_type L>
1136 inline auto xiterator<St, S, L>::operator+=(difference_type n) -> self_type&
1140 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape(),
static_cast<size_type
>(n));
1144 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape(),
static_cast<size_type
>(-n));
1146 m_linear_index += n;
1150 template <
class St,
class S, layout_type L>
1151 inline auto xiterator<St, S, L>::operator-=(difference_type n) -> self_type&
1155 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape(),
static_cast<size_type
>(n));
1159 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape(),
static_cast<size_type
>(-n));
1161 m_linear_index -= n;
1165 template <
class St,
class S, layout_type L>
1166 inline auto xiterator<St, S, L>::operator-(
const self_type& rhs)
const -> difference_type
1168 return m_linear_index - rhs.m_linear_index;
1171 template <
class St,
class S, layout_type L>
1172 inline auto xiterator<St, S, L>::operator*() const -> reference
1177 template <
class St,
class S, layout_type L>
1178 inline auto xiterator<St, S, L>::operator->() const -> pointer
1183 template <
class St,
class S, layout_type L>
1184 inline bool xiterator<St, S, L>::equal(
const xiterator& rhs)
const
1186 XTENSOR_ASSERT(this->shape() == rhs.shape());
1187 return m_linear_index == rhs.m_linear_index;
1190 template <
class St,
class S, layout_type L>
1191 inline bool xiterator<St, S, L>::less_than(
const xiterator& rhs)
const
1193 XTENSOR_ASSERT(this->shape() == rhs.shape());
1194 return m_linear_index < rhs.m_linear_index;
1197 template <
class St,
class S, layout_type L>
1200 return lhs.equal(rhs);
1203 template <
class St,
class S, layout_type L>
1206 return lhs.less_than(rhs);
1213 template <
class It,
class BIt>
1214 xbounded_iterator<It, BIt>::xbounded_iterator(It it, BIt bound_it)
1216 , m_bound_it(bound_it)
1220 template <
class It,
class BIt>
1221 inline auto xbounded_iterator<It, BIt>::operator++() -> self_type&
1228 template <
class It,
class BIt>
1229 inline auto xbounded_iterator<It, BIt>::operator--() -> self_type&
1236 template <
class It,
class BIt>
1237 inline auto xbounded_iterator<It, BIt>::operator+=(difference_type n) -> self_type&
1244 template <
class It,
class BIt>
1245 inline auto xbounded_iterator<It, BIt>::operator-=(difference_type n) -> self_type&
1252 template <
class It,
class BIt>
1253 inline auto xbounded_iterator<It, BIt>::operator-(
const self_type& rhs)
const -> difference_type
1255 return m_it - rhs.m_it;
1258 template <
class It,
class BIt>
1259 inline auto xbounded_iterator<It, BIt>::operator*() const -> value_type
1261 using type =
decltype(*m_bound_it);
1262 return (
static_cast<type
>(*m_it) < *m_bound_it) ? *m_it :
static_cast<value_type
>((*m_bound_it) - 1);
1265 template <
class It,
class BIt>
1266 inline bool xbounded_iterator<It, BIt>::equal(
const self_type& rhs)
const
1268 return m_it == rhs.m_it && m_bound_it == rhs.m_bound_it;
1271 template <
class It,
class BIt>
1272 inline bool xbounded_iterator<It, BIt>::less_than(
const self_type& rhs)
const
1274 return m_it < rhs.m_it;
1277 template <
class It,
class BIt>
1280 return lhs.equal(rhs);
1283 template <
class It,
class BIt>
1286 return lhs.less_than(rhs);
standard mathematical functions for xexpressions