10#ifndef XTENSOR_ITERATOR_HPP
11#define XTENSOR_ITERATOR_HPP
20#include <xtl/xcompare.hpp>
21#include <xtl/xiterator_base.hpp>
22#include <xtl/xmeta_utils.hpp>
23#include <xtl/xsequence.hpp>
25#include "../core/xlayout.hpp"
26#include "../core/xshape.hpp"
27#include "../utils/xexception.hpp"
28#include "../utils/xutils.hpp"
40 template <
bool is_const,
class CT>
46 struct get_stepper_iterator_impl
48 using type =
typename C::container_iterator;
52 struct get_stepper_iterator_impl<const C>
54 using type =
typename C::const_container_iterator;
58 struct get_stepper_iterator_impl<xscalar<CT>>
60 using type =
typename xscalar<CT>::dummy_iterator;
64 struct get_stepper_iterator_impl<const xscalar<CT>>
66 using type =
typename xscalar<CT>::const_dummy_iterator;
71 using get_stepper_iterator =
typename detail::get_stepper_iterator_impl<C>::type;
80 struct index_type_impl
82 using type = dynamic_shape<typename ST::value_type>;
85 template <
class V, std::
size_t L>
86 struct index_type_impl<std::array<V, L>>
88 using type = std::array<V, L>;
91 template <std::size_t... I>
92 struct index_type_impl<fixed_shape<I...>>
94 using type = std::array<std::size_t,
sizeof...(I)>;
99 using xindex_type_t =
typename detail::index_type_impl<C>::type;
110 using storage_type = C;
111 using subiterator_type = get_stepper_iterator<C>;
112 using subiterator_traits = std::iterator_traits<subiterator_type>;
113 using value_type =
typename subiterator_traits::value_type;
114 using reference =
typename subiterator_traits::reference;
115 using pointer =
typename subiterator_traits::pointer;
116 using difference_type =
typename subiterator_traits::difference_type;
117 using size_type =
typename storage_type::size_type;
118 using shape_type =
typename storage_type::shape_type;
119 using simd_value_type = xt_simd::simd_type<value_type>;
121 template <
class requested_type>
122 using simd_return_type = xt_simd::simd_return_type<value_type, requested_type>;
124 xstepper() =
default;
125 xstepper(storage_type* c, subiterator_type it, size_type offset)
noexcept;
127 reference operator*()
const;
129 void step(size_type dim, size_type n = 1);
130 void step_back(size_type dim, size_type n = 1);
131 void reset(size_type dim);
132 void reset_back(size_type dim);
138 simd_return_type<T> step_simd();
143 void store_simd(
const R& vec);
148 subiterator_type m_it;
152 template <layout_type L>
159 template <
class S,
class IT,
class ST>
160 static void increment_stepper(S& stepper, IT& index,
const ST& shape);
162 template <
class S,
class IT,
class ST>
163 static void decrement_stepper(S& stepper, IT& index,
const ST& shape);
165 template <
class S,
class IT,
class ST>
166 static void increment_stepper(S& stepper, IT& index,
const ST& shape,
typename S::size_type n);
168 template <
class S,
class IT,
class ST>
169 static void decrement_stepper(S& stepper, IT& index,
const ST& shape,
typename S::size_type n);
176 template <
class E,
bool is_const>
177 class xindexed_stepper
181 using self_type = xindexed_stepper<E, is_const>;
182 using xexpression_type = std::conditional_t<is_const, const E, E>;
184 using value_type =
typename xexpression_type::value_type;
185 using reference = std::
186 conditional_t<is_const, typename xexpression_type::const_reference, typename xexpression_type::reference>;
187 using pointer = std::
188 conditional_t<is_const, typename xexpression_type::const_pointer, typename xexpression_type::pointer>;
189 using size_type =
typename xexpression_type::size_type;
190 using difference_type =
typename xexpression_type::difference_type;
192 using shape_type =
typename xexpression_type::shape_type;
193 using index_type = xindex_type_t<shape_type>;
195 xindexed_stepper() =
default;
196 xindexed_stepper(xexpression_type* e, size_type offset,
bool end =
false)
noexcept;
198 reference operator*()
const;
200 void step(size_type dim, size_type n = 1);
201 void step_back(size_type dim, size_type n = 1);
202 void reset(size_type dim);
203 void reset_back(size_type dim);
210 xexpression_type* p_e;
218 static const bool value =
false;
221 template <
class T,
bool B>
224 static const bool value =
true;
227 template <
class T,
class R = T>
232 template <
class T,
class R = T>
235 template <
class T,
class R = T>
240 template <
class T,
class R = T>
254 using shape_type = S;
255 using param_type =
const S&;
257 shape_storage() =
default;
258 shape_storage(param_type shape);
259 const S& shape()
const;
267 class shape_storage<S*>
271 using shape_type = S;
272 using param_type =
const S*;
274 shape_storage(param_type shape = 0);
275 const S& shape()
const;
282 template <layout_type L>
283 struct LAYOUT_FORBIDEN_FOR_XITERATOR;
286 template <
class St,
class S, layout_type L>
287 class xiterator :
public xtl::xrandom_access_iterator_base<
289 typename St::value_type,
290 typename St::difference_type,
291 typename St::pointer,
292 typename St::reference>,
293 private detail::shape_storage<S>
297 using self_type = xiterator<St, S, L>;
299 using stepper_type = St;
300 using value_type =
typename stepper_type::value_type;
301 using reference =
typename stepper_type::reference;
302 using pointer =
typename stepper_type::pointer;
303 using difference_type =
typename stepper_type::difference_type;
304 using size_type =
typename stepper_type::size_type;
305 using iterator_category = std::random_access_iterator_tag;
307 using private_base = detail::shape_storage<S>;
308 using shape_type =
typename private_base::shape_type;
309 using shape_param_type =
typename private_base::param_type;
310 using index_type = xindex_type_t<shape_type>;
312 xiterator() =
default;
315 xiterator(St st, shape_param_type shape,
bool end_index);
317 self_type& operator++();
318 self_type& operator--();
320 self_type& operator+=(difference_type n);
321 self_type& operator-=(difference_type n);
323 difference_type operator-(
const self_type& rhs)
const;
325 reference operator*()
const;
326 pointer operator->()
const;
328 bool equal(
const xiterator& rhs)
const;
329 bool less_than(
const xiterator& rhs)
const;
335 difference_type m_linear_index;
337 using checking_type =
typename detail::LAYOUT_FORBIDEN_FOR_XITERATOR<L>::type;
340 template <
class St,
class S, layout_type L>
343 template <
class St,
class S, layout_type L>
346 template <
class St,
class S, layout_type L>
355 template <
class It,
class BIt>
356 class xbounded_iterator :
public xtl::xrandom_access_iterator_base<
357 xbounded_iterator<It, BIt>,
358 typename std::iterator_traits<It>::value_type,
359 typename std::iterator_traits<It>::difference_type,
360 typename std::iterator_traits<It>::pointer,
361 typename std::iterator_traits<It>::reference>
365 using self_type = xbounded_iterator<It, BIt>;
367 using subiterator_type = It;
368 using bound_iterator_type = BIt;
369 using value_type =
typename std::iterator_traits<It>::value_type;
370 using reference =
typename std::iterator_traits<It>::reference;
371 using pointer =
typename std::iterator_traits<It>::pointer;
372 using difference_type =
typename std::iterator_traits<It>::difference_type;
373 using iterator_category = std::random_access_iterator_tag;
375 xbounded_iterator() =
default;
376 xbounded_iterator(It it, BIt bound_it);
378 self_type& operator++();
379 self_type& operator--();
381 self_type& operator+=(difference_type n);
382 self_type& operator-=(difference_type n);
384 difference_type operator-(
const self_type& rhs)
const;
386 value_type operator*()
const;
388 bool equal(
const self_type& rhs)
const;
389 bool less_than(
const self_type& rhs)
const;
393 subiterator_type m_it;
394 bound_iterator_type m_bound_it;
397 template <
class It,
class BIt>
400 template <
class It,
class BIt>
409 template <
class C,
class =
void_t<>>
410 struct has_linear_iterator : std::false_type
415 struct has_linear_iterator<C, void_t<decltype(std::declval<C>().linear_cbegin())>> : std::true_type
421 XTENSOR_CONSTEXPR_RETURN
auto linear_begin(C& c)
noexcept
423 if constexpr (detail::has_linear_iterator<C>::value)
425 return c.linear_begin();
434 XTENSOR_CONSTEXPR_RETURN
auto linear_end(C& c)
noexcept
436 if constexpr (detail::has_linear_iterator<C>::value)
438 return c.linear_end();
447 XTENSOR_CONSTEXPR_RETURN
auto linear_begin(
const C& c)
noexcept
449 if constexpr (detail::has_linear_iterator<C>::value)
451 return c.linear_cbegin();
460 XTENSOR_CONSTEXPR_RETURN
auto linear_end(
const C& c)
noexcept
462 if constexpr (detail::has_linear_iterator<C>::value)
464 return c.linear_cend();
477 inline xstepper<C>::xstepper(storage_type* c, subiterator_type it, size_type offset) noexcept
485 inline auto xstepper<C>::operator*() const -> reference
491 inline void xstepper<C>::step(size_type dim, size_type n)
495 using strides_value_type =
typename std::decay_t<
decltype(p_c->strides())>::value_type;
496 m_it += difference_type(
static_cast<strides_value_type
>(n) * p_c->strides()[dim - m_offset]);
501 inline void xstepper<C>::step_back(size_type dim, size_type n)
505 using strides_value_type =
typename std::decay_t<
decltype(p_c->strides())>::value_type;
506 m_it -= difference_type(
static_cast<strides_value_type
>(n) * p_c->strides()[dim - m_offset]);
511 inline void xstepper<C>::reset(size_type dim)
515 m_it -= difference_type(p_c->backstrides()[dim - m_offset]);
520 inline void xstepper<C>::reset_back(size_type dim)
524 m_it += difference_type(p_c->backstrides()[dim - m_offset]);
529 inline void xstepper<C>::to_begin()
531 m_it = p_c->data_xbegin();
537 m_it = p_c->data_xend(l, m_offset);
543 struct step_simd_invoker
546 static R apply(
const It& it)
549 return reg.load_unaligned(&(*it));
554 template <
bool is_const,
class T,
class S, layout_type L>
555 struct step_simd_invoker<xiterator<xscalar_stepper<is_const, T>, S, L>>
558 static R apply(
const xiterator<xscalar_stepper<is_const, T>, S, L>& it)
567 inline auto xstepper<C>::step_simd() -> simd_return_type<T>
569 using simd_type = simd_return_type<T>;
570 simd_type reg = detail::step_simd_invoker<subiterator_type>::template apply<simd_type>(m_it);
571 m_it += xt_simd::revert_simd_traits<simd_type>::size;
577 inline void xstepper<C>::store_simd(
const R& vec)
579 vec.store_unaligned(&(*m_it));
580 m_it += xt_simd::revert_simd_traits<R>::size;
585 void xstepper<C>::step_leading()
591 template <
class S,
class IT,
class ST>
592 void stepper_tools<layout_type::row_major>::increment_stepper(S& stepper, IT& index,
const ST& shape)
594 using size_type =
typename S::size_type;
595 const size_type size = index.size();
600 if (index[i] != shape[i] - 1)
617 if (size != size_type(0))
628 index[size - 1] = shape[size - 1];
635 template <
class S,
class IT,
class ST>
636 void stepper_tools<layout_type::row_major>::increment_stepper(
640 typename S::size_type n
643 using size_type =
typename S::size_type;
644 const size_type size = index.size();
645 const size_type leading_i = size - 1;
647 while (i != 0 && n != 0)
650 size_type inc = (i == leading_i) ? n : 1;
651 if (xtl::cmp_less(index[i] + inc, shape[i]))
654 stepper.step(i, inc);
656 if (i != leading_i || index.size() == 1)
665 size_type off = shape[i] - index[i] - 1;
666 stepper.step(i, off);
676 if (i == 0 && n != 0)
678 if (size != size_type(0))
689 index[leading_i] = shape[leading_i];
696 template <
class S,
class IT,
class ST>
697 void stepper_tools<layout_type::row_major>::decrement_stepper(S& stepper, IT& index,
const ST& shape)
699 using size_type =
typename S::size_type;
700 size_type i = index.size();
707 stepper.step_back(i);
712 index[i] = shape[i] - 1;
715 stepper.reset_back(i);
726 template <
class S,
class IT,
class ST>
727 void stepper_tools<layout_type::row_major>::decrement_stepper(
731 typename S::size_type n
734 using size_type =
typename S::size_type;
735 size_type i = index.size();
736 size_type leading_i = index.size() - 1;
737 while (i != 0 && n != 0)
740 size_type inc = (i == leading_i) ? n : 1;
741 if (xtl::cmp_greater_equal(index[i], inc))
744 stepper.step_back(i, inc);
746 if (i != leading_i || index.size() == 1)
755 size_type off = index[i];
756 stepper.step_back(i, off);
759 index[i] = shape[i] - 1;
762 stepper.reset_back(i);
766 if (i == 0 && n != 0)
773 template <
class S,
class IT,
class ST>
774 void stepper_tools<layout_type::column_major>::increment_stepper(S& stepper, IT& index,
const ST& shape)
776 using size_type =
typename S::size_type;
777 const size_type size = index.size();
781 if (index[i] != shape[i] - 1)
799 if (size != size_type(0))
817 template <
class S,
class IT,
class ST>
818 void stepper_tools<layout_type::column_major>::increment_stepper(
822 typename S::size_type n
825 using size_type =
typename S::size_type;
826 const size_type size = index.size();
827 const size_type leading_i = 0;
829 while (i != size && n != 0)
831 size_type inc = (i == leading_i) ? n : 1;
832 if (index[i] + inc < shape[i])
835 stepper.step(i, inc);
837 if (i != leading_i || size == 1)
847 size_type off = shape[i] - index[i] - 1;
848 stepper.step(i, off);
859 if (i == size && n != 0)
861 if (size != size_type(0))
872 index[leading_i] = shape[leading_i];
879 template <
class S,
class IT,
class ST>
880 void stepper_tools<layout_type::column_major>::decrement_stepper(S& stepper, IT& index,
const ST& shape)
882 using size_type =
typename S::size_type;
883 size_type size = index.size();
890 stepper.step_back(i);
895 index[i] = shape[i] - 1;
898 stepper.reset_back(i);
910 template <
class S,
class IT,
class ST>
911 void stepper_tools<layout_type::column_major>::decrement_stepper(
915 typename S::size_type n
918 using size_type =
typename S::size_type;
919 size_type size = index.size();
921 size_type leading_i = 0;
922 while (i != size && n != 0)
924 size_type inc = (i == leading_i) ? n : 1;
928 stepper.step_back(i, inc);
930 if (i != leading_i || index.size() == 1)
940 size_type off = index[i];
941 stepper.step_back(i, off);
944 index[i] = shape[i] - 1;
947 stepper.reset_back(i);
952 if (i == size && n != 0)
962 template <
class C,
bool is_const>
963 inline xindexed_stepper<C, is_const>::xindexed_stepper(xexpression_type* e, size_type offset,
bool end) noexcept
965 , m_index(xtl::make_sequence<index_type>(e->shape().size(), size_type(0)))
971 to_end(XTENSOR_DEFAULT_TRAVERSAL);
975 template <
class C,
bool is_const>
976 inline auto xindexed_stepper<C, is_const>::operator*() const -> reference
978 return p_e->element(m_index.cbegin(), m_index.cend());
981 template <
class C,
bool is_const>
982 inline void xindexed_stepper<C, is_const>::step(size_type dim, size_type n)
986 m_index[dim - m_offset] +=
static_cast<typename index_type::value_type
>(n);
990 template <
class C,
bool is_const>
991 inline void xindexed_stepper<C, is_const>::step_back(size_type dim, size_type n)
995 m_index[dim - m_offset] -=
static_cast<typename index_type::value_type
>(n);
999 template <
class C,
bool is_const>
1000 inline void xindexed_stepper<C, is_const>::reset(size_type dim)
1002 if (dim >= m_offset)
1004 m_index[dim - m_offset] = 0;
1008 template <
class C,
bool is_const>
1009 inline void xindexed_stepper<C, is_const>::reset_back(size_type dim)
1011 if (dim >= m_offset)
1013 m_index[dim - m_offset] = p_e->shape()[dim - m_offset] - 1;
1017 template <
class C,
bool is_const>
1018 inline void xindexed_stepper<C, is_const>::to_begin()
1020 std::fill(m_index.begin(), m_index.end(), size_type(0));
1023 template <
class C,
bool is_const>
1024 inline void xindexed_stepper<C, is_const>::to_end(
layout_type l)
1026 const auto& shape = p_e->shape();
1038 m_index[l_dim] = shape[l_dim];
1048 inline shape_storage<S>::shape_storage(param_type shape)
1054 inline const S& shape_storage<S>::shape()
const
1060 inline shape_storage<S*>::shape_storage(param_type shape)
1066 inline const S& shape_storage<S*>::shape()
const
1084 template <
class St,
class S, layout_type L>
1085 inline xiterator<St, S, L>::xiterator(St st, shape_param_type shape,
bool end_index)
1086 : private_base(shape)
1089 end_index ? xtl::forward_sequence<index_type, const shape_type&>(this->shape())
1090 : xtl::make_sequence<index_type>(this->shape().size(), size_type(0))
1097 if (m_index.size() != size_type(0))
1111 m_linear_index = difference_type(std::accumulate(
1112 this->shape().cbegin(),
1113 this->shape().cend(),
1115 std::multiplies<size_type>()
1120 template <
class St,
class S, layout_type L>
1121 inline auto xiterator<St, S, L>::operator++() -> self_type&
1123 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape());
1128 template <
class St,
class S, layout_type L>
1129 inline auto xiterator<St, S, L>::operator--() -> self_type&
1131 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape());
1136 template <
class St,
class S, layout_type L>
1137 inline auto xiterator<St, S, L>::operator+=(difference_type n) -> self_type&
1141 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape(),
static_cast<size_type
>(n));
1145 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape(),
static_cast<size_type
>(-n));
1147 m_linear_index += n;
1151 template <
class St,
class S, layout_type L>
1152 inline auto xiterator<St, S, L>::operator-=(difference_type n) -> self_type&
1156 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape(),
static_cast<size_type
>(n));
1160 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape(),
static_cast<size_type
>(-n));
1162 m_linear_index -= n;
1166 template <
class St,
class S, layout_type L>
1167 inline auto xiterator<St, S, L>::operator-(
const self_type& rhs)
const -> difference_type
1169 return m_linear_index - rhs.m_linear_index;
1172 template <
class St,
class S, layout_type L>
1173 inline auto xiterator<St, S, L>::operator*() const -> reference
1178 template <
class St,
class S, layout_type L>
1179 inline auto xiterator<St, S, L>::operator->() const -> pointer
1184 template <
class St,
class S, layout_type L>
1185 inline bool xiterator<St, S, L>::equal(
const xiterator& rhs)
const
1187 XTENSOR_ASSERT(this->shape() == rhs.shape());
1188 return m_linear_index == rhs.m_linear_index;
1191 template <
class St,
class S, layout_type L>
1192 inline bool xiterator<St, S, L>::less_than(
const xiterator& rhs)
const
1194 XTENSOR_ASSERT(this->shape() == rhs.shape());
1195 return m_linear_index < rhs.m_linear_index;
1198 template <
class St,
class S, layout_type L>
1201 return lhs.equal(rhs);
1204 template <
class St,
class S, layout_type L>
1207 return lhs.less_than(rhs);
1214 template <
class It,
class BIt>
1215 xbounded_iterator<It, BIt>::xbounded_iterator(It it, BIt bound_it)
1217 , m_bound_it(bound_it)
1221 template <
class It,
class BIt>
1222 inline auto xbounded_iterator<It, BIt>::operator++() -> self_type&
1229 template <
class It,
class BIt>
1230 inline auto xbounded_iterator<It, BIt>::operator--() -> self_type&
1237 template <
class It,
class BIt>
1238 inline auto xbounded_iterator<It, BIt>::operator+=(difference_type n) -> self_type&
1245 template <
class It,
class BIt>
1246 inline auto xbounded_iterator<It, BIt>::operator-=(difference_type n) -> self_type&
1253 template <
class It,
class BIt>
1254 inline auto xbounded_iterator<It, BIt>::operator-(
const self_type& rhs)
const -> difference_type
1256 return m_it - rhs.m_it;
1259 template <
class It,
class BIt>
1260 inline auto xbounded_iterator<It, BIt>::operator*() const -> value_type
1262 using type =
decltype(*m_bound_it);
1263 return (
static_cast<type
>(*m_it) < *m_bound_it) ? *m_it :
static_cast<value_type
>((*m_bound_it) - 1);
1266 template <
class It,
class BIt>
1267 inline bool xbounded_iterator<It, BIt>::equal(
const self_type& rhs)
const
1269 return m_it == rhs.m_it && m_bound_it == rhs.m_bound_it;
1272 template <
class It,
class BIt>
1273 inline bool xbounded_iterator<It, BIt>::less_than(
const self_type& rhs)
const
1275 return m_it < rhs.m_it;
1278 template <
class It,
class BIt>
1281 return lhs.equal(rhs);
1284 template <
class It,
class BIt>
1287 return lhs.less_than(rhs);
standard mathematical functions for xexpressions