10#ifndef XTENSOR_ITERATOR_HPP
11#define XTENSOR_ITERATOR_HPP
20#include <xtl/xcompare.hpp>
21#include <xtl/xiterator_base.hpp>
22#include <xtl/xmeta_utils.hpp>
23#include <xtl/xsequence.hpp>
25#include "xexception.hpp"
40 template <
bool is_const,
class CT>
41 class xscalar_stepper;
46 struct get_stepper_iterator_impl
48 using type =
typename C::container_iterator;
52 struct get_stepper_iterator_impl<const C>
54 using type =
typename C::const_container_iterator;
58 struct get_stepper_iterator_impl<xscalar<CT>>
60 using type =
typename xscalar<CT>::dummy_iterator;
64 struct get_stepper_iterator_impl<const xscalar<CT>>
66 using type =
typename xscalar<CT>::const_dummy_iterator;
71 using get_stepper_iterator =
typename detail::get_stepper_iterator_impl<C>::type;
80 struct index_type_impl
82 using type = dynamic_shape<typename ST::value_type>;
85 template <
class V, std::
size_t L>
86 struct index_type_impl<std::array<V, L>>
88 using type = std::array<V, L>;
91 template <std::size_t... I>
92 struct index_type_impl<fixed_shape<I...>>
94 using type = std::array<std::size_t,
sizeof...(I)>;
99 using xindex_type_t =
typename detail::index_type_impl<C>::type;
110 using storage_type =
C;
112 using subiterator_traits = std::iterator_traits<subiterator_type>;
113 using value_type =
typename subiterator_traits::value_type;
114 using reference =
typename subiterator_traits::reference;
115 using pointer =
typename subiterator_traits::pointer;
116 using difference_type =
typename subiterator_traits::difference_type;
117 using size_type =
typename storage_type::size_type;
118 using shape_type =
typename storage_type::shape_type;
119 using simd_value_type = xt_simd::simd_type<value_type>;
121 template <
class requested_type>
122 using simd_return_type = xt_simd::simd_return_type<value_type, requested_type>;
127 reference operator*()
const;
129 void step(size_type
dim, size_type
n = 1);
130 void step_back(size_type
dim, size_type
n = 1);
131 void reset(size_type
dim);
132 void reset_back(size_type
dim);
143 void store_simd(
const R&
vec);
148 subiterator_type m_it;
152 template <layout_type L>
159 template <
class S,
class IT,
class ST>
160 static void increment_stepper(
S& stepper,
IT& index,
const ST& shape);
162 template <
class S,
class IT,
class ST>
163 static void decrement_stepper(
S& stepper,
IT& index,
const ST& shape);
165 template <
class S,
class IT,
class ST>
166 static void increment_stepper(
S& stepper,
IT& index,
const ST& shape,
typename S::size_type
n);
168 template <
class S,
class IT,
class ST>
169 static void decrement_stepper(
S& stepper,
IT& index,
const ST& shape,
typename S::size_type
n);
176 template <
class E,
bool is_const>
182 using xexpression_type = std::conditional_t<is_const, const E, E>;
184 using value_type =
typename xexpression_type::value_type;
185 using reference = std::
186 conditional_t<is_const, typename xexpression_type::const_reference, typename xexpression_type::reference>;
187 using pointer = std::
188 conditional_t<is_const, typename xexpression_type::const_pointer, typename xexpression_type::pointer>;
189 using size_type =
typename xexpression_type::size_type;
190 using difference_type =
typename xexpression_type::difference_type;
192 using shape_type =
typename xexpression_type::shape_type;
198 reference operator*()
const;
200 void step(size_type
dim, size_type
n = 1);
201 void step_back(size_type
dim, size_type
n = 1);
202 void reset(size_type
dim);
203 void reset_back(size_type
dim);
210 xexpression_type* p_e;
218 static const bool value =
false;
221 template <
class T,
bool B>
224 static const bool value =
true;
227 template <
class T,
class R = T>
232 template <
class T,
class R = T>
235 template <
class T,
class R = T>
240 template <
class T,
class R = T>
254 using shape_type =
S;
255 using param_type =
const S&;
257 shape_storage() =
default;
258 shape_storage(param_type shape);
259 const S& shape()
const;
267 class shape_storage<S*>
271 using shape_type = S;
272 using param_type =
const S*;
274 shape_storage(param_type shape = 0);
275 const S& shape()
const;
282 template <layout_type L>
283 struct LAYOUT_FORBIDEN_FOR_XITERATOR;
286 template <
class St,
class S, layout_type L>
287 class xiterator :
public xtl::xrandom_access_iterator_base<
289 typename St::value_type,
290 typename St::difference_type,
291 typename St::pointer,
292 typename St::reference>,
293 private detail::shape_storage<S>
300 using value_type =
typename stepper_type::value_type;
301 using reference =
typename stepper_type::reference;
302 using pointer =
typename stepper_type::pointer;
303 using difference_type =
typename stepper_type::difference_type;
304 using size_type =
typename stepper_type::size_type;
305 using iterator_category = std::random_access_iterator_tag;
307 using private_base = detail::shape_storage<S>;
309 using shape_param_type =
typename private_base::param_type;
325 reference operator*()
const;
326 pointer operator->()
const;
335 difference_type m_linear_index;
337 using checking_type =
typename detail::LAYOUT_FORBIDEN_FOR_XITERATOR<L>::type;
340 template <
class St,
class S, layout_type L>
343 template <
class St,
class S, layout_type L>
346 template <
class St,
class S, layout_type L>
355 template <
class It,
class BIt>
357 xbounded_iterator<It, BIt>,
358 typename std::iterator_traits<It>::value_type,
359 typename std::iterator_traits<It>::difference_type,
360 typename std::iterator_traits<It>::pointer,
361 typename std::iterator_traits<It>::reference>
369 using value_type =
typename std::iterator_traits<It>::value_type;
370 using reference =
typename std::iterator_traits<It>::reference;
371 using pointer =
typename std::iterator_traits<It>::pointer;
372 using difference_type =
typename std::iterator_traits<It>::difference_type;
373 using iterator_category = std::random_access_iterator_tag;
386 value_type operator*()
const;
397 template <
class It,
class BIt>
400 template <
class It,
class BIt>
409 template <
class C,
class =
void_t<>>
410 struct has_linear_iterator : std::false_type
415 struct has_linear_iterator<C, void_t<decltype(std::declval<C>().linear_cbegin())>> : std::true_type
421 XTENSOR_CONSTEXPR_RETURN
auto linear_begin(C& c)
noexcept
423 return xtl::mpl::static_if<detail::has_linear_iterator<C>::value>(
426 return self(c).linear_begin();
431 return self(c).begin();
437 XTENSOR_CONSTEXPR_RETURN
auto linear_end(C& c)
noexcept
439 return xtl::mpl::static_if<detail::has_linear_iterator<C>::value>(
442 return self(c).linear_end();
447 return self(c).end();
453 XTENSOR_CONSTEXPR_RETURN
auto linear_begin(
const C& c)
noexcept
455 return xtl::mpl::static_if<detail::has_linear_iterator<C>::value>(
458 return self(c).linear_cbegin();
463 return self(c).cbegin();
469 XTENSOR_CONSTEXPR_RETURN
auto linear_end(
const C& c)
noexcept
471 return xtl::mpl::static_if<detail::has_linear_iterator<C>::value>(
474 return self(c).linear_cend();
479 return self(c).cend();
489 inline xstepper<C>::xstepper(storage_type* c, subiterator_type it, size_type offset) noexcept
497 inline auto xstepper<C>::operator*() const -> reference
503 inline void xstepper<C>::step(size_type dim, size_type n)
507 using strides_value_type =
typename std::decay_t<
decltype(p_c->strides())>::value_type;
508 m_it += difference_type(
static_cast<strides_value_type
>(n) * p_c->strides()[dim - m_offset]);
513 inline void xstepper<C>::step_back(size_type dim, size_type n)
517 using strides_value_type =
typename std::decay_t<
decltype(p_c->strides())>::value_type;
518 m_it -= difference_type(
static_cast<strides_value_type
>(n) * p_c->strides()[dim - m_offset]);
523 inline void xstepper<C>::reset(size_type dim)
527 m_it -= difference_type(p_c->backstrides()[dim - m_offset]);
532 inline void xstepper<C>::reset_back(size_type dim)
536 m_it += difference_type(p_c->backstrides()[dim - m_offset]);
541 inline void xstepper<C>::to_begin()
543 m_it = p_c->data_xbegin();
549 m_it = p_c->data_xend(l, m_offset);
555 struct step_simd_invoker
558 static R apply(
const It& it)
561 return reg.load_unaligned(&(*it));
566 template <
bool is_const,
class T,
class S, layout_type L>
567 struct step_simd_invoker<xiterator<xscalar_stepper<is_const, T>, S, L>>
570 static R apply(
const xiterator<xscalar_stepper<is_const, T>, S, L>& it)
579 inline auto xstepper<C>::step_simd() -> simd_return_type<T>
581 using simd_type = simd_return_type<T>;
582 simd_type reg = detail::step_simd_invoker<subiterator_type>::template apply<simd_type>(m_it);
589 inline void xstepper<C>::store_simd(
const R& vec)
591 vec.store_unaligned(&(*m_it));
597 void xstepper<C>::step_leading()
603 template <
class S,
class IT,
class ST>
604 void stepper_tools<layout_type::row_major>::increment_stepper(S& stepper, IT& index,
const ST& shape)
606 using size_type =
typename S::size_type;
607 const size_type size = index.size();
612 if (index[i] != shape[i] - 1)
629 if (size != size_type(0))
640 index[size - 1] = shape[size - 1];
647 template <
class S,
class IT,
class ST>
648 void stepper_tools<layout_type::row_major>::increment_stepper(
652 typename S::size_type n
655 using size_type =
typename S::size_type;
656 const size_type size = index.size();
657 const size_type leading_i = size - 1;
659 while (i != 0 && n != 0)
662 size_type inc = (i == leading_i) ? n : 1;
663 if (xtl::cmp_less(index[i] + inc, shape[i]))
666 stepper.step(i, inc);
668 if (i != leading_i || index.size() == 1)
677 size_type off = shape[i] - index[i] - 1;
678 stepper.step(i, off);
688 if (i == 0 && n != 0)
690 if (size != size_type(0))
701 index[leading_i] = shape[leading_i];
708 template <
class S,
class IT,
class ST>
709 void stepper_tools<layout_type::row_major>::decrement_stepper(S& stepper, IT& index,
const ST& shape)
711 using size_type =
typename S::size_type;
712 size_type i = index.size();
719 stepper.step_back(i);
724 index[i] = shape[i] - 1;
727 stepper.reset_back(i);
738 template <
class S,
class IT,
class ST>
739 void stepper_tools<layout_type::row_major>::decrement_stepper(
743 typename S::size_type n
746 using size_type =
typename S::size_type;
747 size_type i = index.size();
748 size_type leading_i = index.size() - 1;
749 while (i != 0 && n != 0)
752 size_type inc = (i == leading_i) ? n : 1;
753 if (xtl::cmp_greater_equal(index[i], inc))
756 stepper.step_back(i, inc);
758 if (i != leading_i || index.size() == 1)
767 size_type off = index[i];
768 stepper.step_back(i, off);
771 index[i] = shape[i] - 1;
774 stepper.reset_back(i);
778 if (i == 0 && n != 0)
785 template <
class S,
class IT,
class ST>
786 void stepper_tools<layout_type::column_major>::increment_stepper(S& stepper, IT& index,
const ST& shape)
788 using size_type =
typename S::size_type;
789 const size_type size = index.size();
793 if (index[i] != shape[i] - 1)
811 if (size != size_type(0))
829 template <
class S,
class IT,
class ST>
830 void stepper_tools<layout_type::column_major>::increment_stepper(
834 typename S::size_type n
837 using size_type =
typename S::size_type;
838 const size_type size = index.size();
839 const size_type leading_i = 0;
841 while (i != size && n != 0)
843 size_type inc = (i == leading_i) ? n : 1;
844 if (index[i] + inc < shape[i])
847 stepper.step(i, inc);
849 if (i != leading_i || size == 1)
859 size_type off = shape[i] - index[i] - 1;
860 stepper.step(i, off);
871 if (i == size && n != 0)
873 if (size != size_type(0))
884 index[leading_i] = shape[leading_i];
891 template <
class S,
class IT,
class ST>
892 void stepper_tools<layout_type::column_major>::decrement_stepper(S& stepper, IT& index,
const ST& shape)
894 using size_type =
typename S::size_type;
895 size_type size = index.size();
902 stepper.step_back(i);
907 index[i] = shape[i] - 1;
910 stepper.reset_back(i);
922 template <
class S,
class IT,
class ST>
923 void stepper_tools<layout_type::column_major>::decrement_stepper(
927 typename S::size_type n
930 using size_type =
typename S::size_type;
931 size_type size = index.size();
933 size_type leading_i = 0;
934 while (i != size && n != 0)
936 size_type inc = (i == leading_i) ? n : 1;
940 stepper.step_back(i, inc);
942 if (i != leading_i || index.size() == 1)
952 size_type off = index[i];
953 stepper.step_back(i, off);
956 index[i] = shape[i] - 1;
959 stepper.reset_back(i);
964 if (i == size && n != 0)
974 template <
class C,
bool is_const>
975 inline xindexed_stepper<C, is_const>::xindexed_stepper(xexpression_type* e, size_type offset,
bool end) noexcept
977 , m_index(xtl::make_sequence<index_type>(e->shape().size(), size_type(0)))
983 to_end(XTENSOR_DEFAULT_TRAVERSAL);
987 template <
class C,
bool is_const>
988 inline auto xindexed_stepper<C, is_const>::operator*() const -> reference
990 return p_e->element(m_index.cbegin(), m_index.cend());
993 template <
class C,
bool is_const>
994 inline void xindexed_stepper<C, is_const>::step(size_type dim, size_type n)
998 m_index[dim - m_offset] +=
static_cast<typename index_type::value_type
>(n);
1002 template <
class C,
bool is_const>
1003 inline void xindexed_stepper<C, is_const>::step_back(size_type dim, size_type n)
1005 if (dim >= m_offset)
1007 m_index[dim - m_offset] -=
static_cast<typename index_type::value_type
>(n);
1011 template <
class C,
bool is_const>
1012 inline void xindexed_stepper<C, is_const>::reset(size_type dim)
1014 if (dim >= m_offset)
1016 m_index[dim - m_offset] = 0;
1020 template <
class C,
bool is_const>
1021 inline void xindexed_stepper<C, is_const>::reset_back(size_type dim)
1023 if (dim >= m_offset)
1025 m_index[dim - m_offset] = p_e->shape()[dim - m_offset] - 1;
1029 template <
class C,
bool is_const>
1030 inline void xindexed_stepper<C, is_const>::to_begin()
1032 std::fill(m_index.begin(), m_index.end(), size_type(0));
1035 template <
class C,
bool is_const>
1036 inline void xindexed_stepper<C, is_const>::to_end(
layout_type l)
1038 const auto& shape = p_e->shape();
1050 m_index[l_dim] = shape[l_dim];
1060 inline shape_storage<S>::shape_storage(param_type shape)
1066 inline const S& shape_storage<S>::shape()
const
1072 inline shape_storage<S*>::shape_storage(param_type shape)
1078 inline const S& shape_storage<S*>::shape()
const
1096 template <
class St,
class S, layout_type L>
1097 inline xiterator<St, S, L>::xiterator(St st, shape_param_type shape,
bool end_index)
1098 : private_base(shape)
1101 end_index ? xtl::forward_sequence<index_type, const shape_type&>(this->shape())
1102 : xtl::make_sequence<index_type>(this->shape().size(), size_type(0))
1109 if (m_index.size() != size_type(0))
1123 m_linear_index = difference_type(std::accumulate(
1124 this->shape().cbegin(),
1125 this->shape().cend(),
1127 std::multiplies<size_type>()
1132 template <
class St,
class S, layout_type L>
1133 inline auto xiterator<St, S, L>::operator++() -> self_type&
1135 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape());
1140 template <
class St,
class S, layout_type L>
1141 inline auto xiterator<St, S, L>::operator--() -> self_type&
1143 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape());
1148 template <
class St,
class S, layout_type L>
1149 inline auto xiterator<St, S, L>::operator+=(difference_type n) -> self_type&
1153 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape(),
static_cast<size_type
>(n));
1157 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape(),
static_cast<size_type
>(-n));
1159 m_linear_index += n;
1163 template <
class St,
class S, layout_type L>
1164 inline auto xiterator<St, S, L>::operator-=(difference_type n) -> self_type&
1168 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape(),
static_cast<size_type
>(n));
1172 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape(),
static_cast<size_type
>(-n));
1174 m_linear_index -= n;
1178 template <
class St,
class S, layout_type L>
1179 inline auto xiterator<St, S, L>::operator-(
const self_type& rhs)
const -> difference_type
1181 return m_linear_index - rhs.m_linear_index;
1184 template <
class St,
class S, layout_type L>
1185 inline auto xiterator<St, S, L>::operator*() const -> reference
1190 template <
class St,
class S, layout_type L>
1191 inline auto xiterator<St, S, L>::operator->() const -> pointer
1196 template <
class St,
class S, layout_type L>
1197 inline bool xiterator<St, S, L>::equal(
const xiterator& rhs)
const
1199 XTENSOR_ASSERT(this->shape() == rhs.shape());
1200 return m_linear_index == rhs.m_linear_index;
1203 template <
class St,
class S, layout_type L>
1204 inline bool xiterator<St, S, L>::less_than(
const xiterator& rhs)
const
1206 XTENSOR_ASSERT(this->shape() == rhs.shape());
1207 return m_linear_index < rhs.m_linear_index;
1210 template <
class St,
class S, layout_type L>
1211 inline bool operator==(
const xiterator<St, S, L>& lhs,
const xiterator<St, S, L>& rhs)
1213 return lhs.equal(rhs);
1216 template <
class St,
class S, layout_type L>
1217 bool operator<(
const xiterator<St, S, L>& lhs,
const xiterator<St, S, L>& rhs)
1219 return lhs.less_than(rhs);
1226 template <
class It,
class BIt>
1227 xbounded_iterator<It, BIt>::xbounded_iterator(It it, BIt bound_it)
1229 , m_bound_it(bound_it)
1233 template <
class It,
class BIt>
1234 inline auto xbounded_iterator<It, BIt>::operator++() -> self_type&
1241 template <
class It,
class BIt>
1242 inline auto xbounded_iterator<It, BIt>::operator--() -> self_type&
1249 template <
class It,
class BIt>
1250 inline auto xbounded_iterator<It, BIt>::operator+=(difference_type n) -> self_type&
1257 template <
class It,
class BIt>
1258 inline auto xbounded_iterator<It, BIt>::operator-=(difference_type n) -> self_type&
1265 template <
class It,
class BIt>
1266 inline auto xbounded_iterator<It, BIt>::operator-(
const self_type& rhs)
const -> difference_type
1268 return m_it - rhs.m_it;
1271 template <
class It,
class BIt>
1272 inline auto xbounded_iterator<It, BIt>::operator*() const -> value_type
1274 using type =
decltype(*m_bound_it);
1275 return (
static_cast<type
>(*m_it) < *m_bound_it) ? *m_it :
static_cast<value_type
>((*m_bound_it) - 1);
1278 template <
class It,
class BIt>
1279 inline bool xbounded_iterator<It, BIt>::equal(
const self_type& rhs)
const
1281 return m_it == rhs.m_it && m_bound_it == rhs.m_bound_it;
1284 template <
class It,
class BIt>
1285 inline bool xbounded_iterator<It, BIt>::less_than(
const self_type& rhs)
const
1287 return m_it < rhs.m_it;
1290 template <
class It,
class BIt>
1291 inline bool operator==(
const xbounded_iterator<It, BIt>& lhs,
const xbounded_iterator<It, BIt>& rhs)
1293 return lhs.equal(rhs);
1296 template <
class It,
class BIt>
1297 inline bool operator<(
const xbounded_iterator<It, BIt>& lhs,
const xbounded_iterator<It, BIt>& rhs)
1299 return lhs.less_than(rhs);
standard mathematical functions for xexpressions
bool operator==(const xaxis_iterator< CT > &lhs, const xaxis_iterator< CT > &rhs)
Checks equality of the iterators.