xtensor
Loading...
Searching...
No Matches
xiterator.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_ITERATOR_HPP
11#define XTENSOR_ITERATOR_HPP
12
13#include <algorithm>
14#include <array>
15#include <cstddef>
16#include <iterator>
17#include <numeric>
18#include <vector>
19
20#include <xtl/xcompare.hpp>
21#include <xtl/xiterator_base.hpp>
22#include <xtl/xmeta_utils.hpp>
23#include <xtl/xsequence.hpp>
24
25#include "xexception.hpp"
26#include "xlayout.hpp"
27#include "xshape.hpp"
28#include "xutils.hpp"
29
30namespace xt
31{
32
33 /***********************
34 * iterator meta utils *
35 ***********************/
36
37 template <class CT>
38 class xscalar;
39
40 template <bool is_const, class CT>
41 class xscalar_stepper;
42
43 namespace detail
44 {
45 template <class C>
46 struct get_stepper_iterator_impl
47 {
48 using type = typename C::container_iterator;
49 };
50
51 template <class C>
52 struct get_stepper_iterator_impl<const C>
53 {
54 using type = typename C::const_container_iterator;
55 };
56
57 template <class CT>
58 struct get_stepper_iterator_impl<xscalar<CT>>
59 {
60 using type = typename xscalar<CT>::dummy_iterator;
61 };
62
63 template <class CT>
64 struct get_stepper_iterator_impl<const xscalar<CT>>
65 {
66 using type = typename xscalar<CT>::const_dummy_iterator;
67 };
68 }
69
70 template <class C>
71 using get_stepper_iterator = typename detail::get_stepper_iterator_impl<C>::type;
72
73 /********************************
74 * xindex_type_t implementation *
75 ********************************/
76
77 namespace detail
78 {
79 template <class ST>
80 struct index_type_impl
81 {
82 using type = dynamic_shape<typename ST::value_type>;
83 };
84
85 template <class V, std::size_t L>
86 struct index_type_impl<std::array<V, L>>
87 {
88 using type = std::array<V, L>;
89 };
90
91 template <std::size_t... I>
92 struct index_type_impl<fixed_shape<I...>>
93 {
94 using type = std::array<std::size_t, sizeof...(I)>;
95 };
96 }
97
98 template <class C>
99 using xindex_type_t = typename detail::index_type_impl<C>::type;
100
101 /************
102 * xstepper *
103 ************/
104
105 template <class C>
107 {
108 public:
109
110 using storage_type = C;
111 using subiterator_type = get_stepper_iterator<C>;
112 using subiterator_traits = std::iterator_traits<subiterator_type>;
113 using value_type = typename subiterator_traits::value_type;
114 using reference = typename subiterator_traits::reference;
115 using pointer = typename subiterator_traits::pointer;
116 using difference_type = typename subiterator_traits::difference_type;
117 using size_type = typename storage_type::size_type;
118 using shape_type = typename storage_type::shape_type;
119 using simd_value_type = xt_simd::simd_type<value_type>;
120
121 template <class requested_type>
122 using simd_return_type = xt_simd::simd_return_type<value_type, requested_type>;
123
124 xstepper() = default;
125 xstepper(storage_type* c, subiterator_type it, size_type offset) noexcept;
126
127 reference operator*() const;
128
129 void step(size_type dim, size_type n = 1);
130 void step_back(size_type dim, size_type n = 1);
131 void reset(size_type dim);
132 void reset_back(size_type dim);
133
134 void to_begin();
135 void to_end(layout_type l);
136
137 template <class T>
138 simd_return_type<T> step_simd();
139
140 void step_leading();
141
142 template <class R>
143 void store_simd(const R& vec);
144
145 private:
146
147 storage_type* p_c;
148 subiterator_type m_it;
149 size_type m_offset;
150 };
151
152 template <layout_type L>
154 {
155 // For performance reasons, increment_stepper and decrement_stepper are
156 // specialized for the case where n=1, which underlies operator++ and
157 // operator-- on xiterators.
158
159 template <class S, class IT, class ST>
160 static void increment_stepper(S& stepper, IT& index, const ST& shape);
161
162 template <class S, class IT, class ST>
163 static void decrement_stepper(S& stepper, IT& index, const ST& shape);
164
165 template <class S, class IT, class ST>
166 static void increment_stepper(S& stepper, IT& index, const ST& shape, typename S::size_type n);
167
168 template <class S, class IT, class ST>
169 static void decrement_stepper(S& stepper, IT& index, const ST& shape, typename S::size_type n);
170 };
171
172 /********************
173 * xindexed_stepper *
174 ********************/
175
176 template <class E, bool is_const>
178 {
179 public:
180
182 using xexpression_type = std::conditional_t<is_const, const E, E>;
183
184 using value_type = typename xexpression_type::value_type;
185 using reference = std::
186 conditional_t<is_const, typename xexpression_type::const_reference, typename xexpression_type::reference>;
187 using pointer = std::
188 conditional_t<is_const, typename xexpression_type::const_pointer, typename xexpression_type::pointer>;
189 using size_type = typename xexpression_type::size_type;
190 using difference_type = typename xexpression_type::difference_type;
191
192 using shape_type = typename xexpression_type::shape_type;
194
195 xindexed_stepper() = default;
196 xindexed_stepper(xexpression_type* e, size_type offset, bool end = false) noexcept;
197
198 reference operator*() const;
199
200 void step(size_type dim, size_type n = 1);
201 void step_back(size_type dim, size_type n = 1);
202 void reset(size_type dim);
203 void reset_back(size_type dim);
204
205 void to_begin();
206 void to_end(layout_type l);
207
208 private:
209
210 xexpression_type* p_e;
211 index_type m_index;
212 size_type m_offset;
213 };
214
215 template <class T>
217 {
218 static const bool value = false;
219 };
220
221 template <class T, bool B>
223 {
224 static const bool value = true;
225 };
226
227 template <class T, class R = T>
228 struct enable_indexed_stepper : std::enable_if<is_indexed_stepper<T>::value, R>
229 {
230 };
231
232 template <class T, class R = T>
233 using enable_indexed_stepper_t = typename enable_indexed_stepper<T, R>::type;
234
235 template <class T, class R = T>
236 struct disable_indexed_stepper : std::enable_if<!is_indexed_stepper<T>::value, R>
237 {
238 };
239
240 template <class T, class R = T>
241 using disable_indexed_stepper_t = typename disable_indexed_stepper<T, R>::type;
242
243 /*************
244 * xiterator *
245 *************/
246
247 namespace detail
248 {
249 template <class S>
250 class shape_storage
251 {
252 public:
253
254 using shape_type = S;
255 using param_type = const S&;
256
257 shape_storage() = default;
258 shape_storage(param_type shape);
259 const S& shape() const;
260
261 private:
262
263 S m_shape;
264 };
265
266 template <class S>
267 class shape_storage<S*>
268 {
269 public:
270
271 using shape_type = S;
272 using param_type = const S*;
273
274 shape_storage(param_type shape = 0);
275 const S& shape() const;
276
277 private:
278
279 const S* p_shape;
280 };
281
282 template <layout_type L>
283 struct LAYOUT_FORBIDEN_FOR_XITERATOR;
284 }
285
286 template <class St, class S, layout_type L>
287 class xiterator : public xtl::xrandom_access_iterator_base<
288 xiterator<St, S, L>,
289 typename St::value_type,
290 typename St::difference_type,
291 typename St::pointer,
292 typename St::reference>,
293 private detail::shape_storage<S>
294 {
295 public:
296
298
299 using stepper_type = St;
300 using value_type = typename stepper_type::value_type;
301 using reference = typename stepper_type::reference;
302 using pointer = typename stepper_type::pointer;
303 using difference_type = typename stepper_type::difference_type;
304 using size_type = typename stepper_type::size_type;
305 using iterator_category = std::random_access_iterator_tag;
306
307 using private_base = detail::shape_storage<S>;
308 using shape_type = typename private_base::shape_type;
309 using shape_param_type = typename private_base::param_type;
311
312 xiterator() = default;
313
314 // end_index means either reverse_iterator && !end or !reverse_iterator && end
315 xiterator(St st, shape_param_type shape, bool end_index);
316
317 self_type& operator++();
318 self_type& operator--();
319
320 self_type& operator+=(difference_type n);
321 self_type& operator-=(difference_type n);
322
323 difference_type operator-(const self_type& rhs) const;
324
325 reference operator*() const;
326 pointer operator->() const;
327
328 bool equal(const xiterator& rhs) const;
329 bool less_than(const xiterator& rhs) const;
330
331 private:
332
333 stepper_type m_st;
334 index_type m_index;
335 difference_type m_linear_index;
336
337 using checking_type = typename detail::LAYOUT_FORBIDEN_FOR_XITERATOR<L>::type;
338 };
339
340 template <class St, class S, layout_type L>
342
343 template <class St, class S, layout_type L>
345
346 template <class St, class S, layout_type L>
347 struct is_contiguous_container<xiterator<St, S, L>> : std::false_type
348 {
349 };
350
351 /*********************
352 * xbounded_iterator *
353 *********************/
354
355 template <class It, class BIt>
356 class xbounded_iterator : public xtl::xrandom_access_iterator_base<
357 xbounded_iterator<It, BIt>,
358 typename std::iterator_traits<It>::value_type,
359 typename std::iterator_traits<It>::difference_type,
360 typename std::iterator_traits<It>::pointer,
361 typename std::iterator_traits<It>::reference>
362 {
363 public:
364
366
367 using subiterator_type = It;
368 using bound_iterator_type = BIt;
369 using value_type = typename std::iterator_traits<It>::value_type;
370 using reference = typename std::iterator_traits<It>::reference;
371 using pointer = typename std::iterator_traits<It>::pointer;
372 using difference_type = typename std::iterator_traits<It>::difference_type;
373 using iterator_category = std::random_access_iterator_tag;
374
375 xbounded_iterator() = default;
377
378 self_type& operator++();
379 self_type& operator--();
380
381 self_type& operator+=(difference_type n);
382 self_type& operator-=(difference_type n);
383
384 difference_type operator-(const self_type& rhs) const;
385
386 value_type operator*() const;
387
388 bool equal(const self_type& rhs) const;
389 bool less_than(const self_type& rhs) const;
390
391 private:
392
393 subiterator_type m_it;
394 bound_iterator_type m_bound_it;
395 };
396
397 template <class It, class BIt>
399
400 template <class It, class BIt>
402
403 /*****************************
404 * linear_begin / linear_end *
405 *****************************/
406
407 namespace detail
408 {
409 template <class C, class = void_t<>>
410 struct has_linear_iterator : std::false_type
411 {
412 };
413
414 template <class C>
415 struct has_linear_iterator<C, void_t<decltype(std::declval<C>().linear_cbegin())>> : std::true_type
416 {
417 };
418 }
419
420 template <class C>
421 XTENSOR_CONSTEXPR_RETURN auto linear_begin(C& c) noexcept
422 {
423 return xtl::mpl::static_if<detail::has_linear_iterator<C>::value>(
424 [&](auto self)
425 {
426 return self(c).linear_begin();
427 },
428 /*else*/
429 [&](auto self)
430 {
431 return self(c).begin();
432 }
433 );
434 }
435
436 template <class C>
437 XTENSOR_CONSTEXPR_RETURN auto linear_end(C& c) noexcept
438 {
439 return xtl::mpl::static_if<detail::has_linear_iterator<C>::value>(
440 [&](auto self)
441 {
442 return self(c).linear_end();
443 },
444 /*else*/
445 [&](auto self)
446 {
447 return self(c).end();
448 }
449 );
450 }
451
452 template <class C>
453 XTENSOR_CONSTEXPR_RETURN auto linear_begin(const C& c) noexcept
454 {
455 return xtl::mpl::static_if<detail::has_linear_iterator<C>::value>(
456 [&](auto self)
457 {
458 return self(c).linear_cbegin();
459 },
460 /*else*/
461 [&](auto self)
462 {
463 return self(c).cbegin();
464 }
465 );
466 }
467
468 template <class C>
469 XTENSOR_CONSTEXPR_RETURN auto linear_end(const C& c) noexcept
470 {
471 return xtl::mpl::static_if<detail::has_linear_iterator<C>::value>(
472 [&](auto self)
473 {
474 return self(c).linear_cend();
475 },
476 /*else*/
477 [&](auto self)
478 {
479 return self(c).cend();
480 }
481 );
482 }
483
484 /***************************
485 * xstepper implementation *
486 ***************************/
487
488 template <class C>
489 inline xstepper<C>::xstepper(storage_type* c, subiterator_type it, size_type offset) noexcept
490 : p_c(c)
491 , m_it(it)
492 , m_offset(offset)
493 {
494 }
495
496 template <class C>
497 inline auto xstepper<C>::operator*() const -> reference
498 {
499 return *m_it;
500 }
501
502 template <class C>
503 inline void xstepper<C>::step(size_type dim, size_type n)
504 {
505 if (dim >= m_offset)
506 {
507 using strides_value_type = typename std::decay_t<decltype(p_c->strides())>::value_type;
508 m_it += difference_type(static_cast<strides_value_type>(n) * p_c->strides()[dim - m_offset]);
509 }
510 }
511
512 template <class C>
513 inline void xstepper<C>::step_back(size_type dim, size_type n)
514 {
515 if (dim >= m_offset)
516 {
517 using strides_value_type = typename std::decay_t<decltype(p_c->strides())>::value_type;
518 m_it -= difference_type(static_cast<strides_value_type>(n) * p_c->strides()[dim - m_offset]);
519 }
520 }
521
522 template <class C>
523 inline void xstepper<C>::reset(size_type dim)
524 {
525 if (dim >= m_offset)
526 {
527 m_it -= difference_type(p_c->backstrides()[dim - m_offset]);
528 }
529 }
530
531 template <class C>
532 inline void xstepper<C>::reset_back(size_type dim)
533 {
534 if (dim >= m_offset)
535 {
536 m_it += difference_type(p_c->backstrides()[dim - m_offset]);
537 }
538 }
539
540 template <class C>
541 inline void xstepper<C>::to_begin()
542 {
543 m_it = p_c->data_xbegin();
544 }
545
546 template <class C>
547 inline void xstepper<C>::to_end(layout_type l)
548 {
549 m_it = p_c->data_xend(l, m_offset);
550 }
551
552 namespace detail
553 {
554 template <class It>
555 struct step_simd_invoker
556 {
557 template <class R>
558 static R apply(const It& it)
559 {
560 R reg;
561 return reg.load_unaligned(&(*it));
562 // return reg;
563 }
564 };
565
566 template <bool is_const, class T, class S, layout_type L>
567 struct step_simd_invoker<xiterator<xscalar_stepper<is_const, T>, S, L>>
568 {
569 template <class R>
570 static R apply(const xiterator<xscalar_stepper<is_const, T>, S, L>& it)
571 {
572 return R(*it);
573 }
574 };
575 }
576
577 template <class C>
578 template <class T>
579 inline auto xstepper<C>::step_simd() -> simd_return_type<T>
580 {
581 using simd_type = simd_return_type<T>;
582 simd_type reg = detail::step_simd_invoker<subiterator_type>::template apply<simd_type>(m_it);
584 return reg;
585 }
586
587 template <class C>
588 template <class R>
589 inline void xstepper<C>::store_simd(const R& vec)
590 {
591 vec.store_unaligned(&(*m_it));
593 ;
594 }
595
596 template <class C>
597 void xstepper<C>::step_leading()
598 {
599 ++m_it;
600 }
601
602 template <>
603 template <class S, class IT, class ST>
604 void stepper_tools<layout_type::row_major>::increment_stepper(S& stepper, IT& index, const ST& shape)
605 {
606 using size_type = typename S::size_type;
607 const size_type size = index.size();
608 size_type i = size;
609 while (i != 0)
610 {
611 --i;
612 if (index[i] != shape[i] - 1)
613 {
614 ++index[i];
615 stepper.step(i);
616 return;
617 }
618 else
619 {
620 index[i] = 0;
621 if (i != 0)
622 {
623 stepper.reset(i);
624 }
625 }
626 }
627 if (i == 0)
628 {
629 if (size != size_type(0))
630 {
631 std::transform(
632 shape.cbegin(),
633 shape.cend() - 1,
634 index.begin(),
635 [](const auto& v)
636 {
637 return v - 1;
638 }
639 );
640 index[size - 1] = shape[size - 1];
641 }
642 stepper.to_end(layout_type::row_major);
643 }
644 }
645
646 template <>
647 template <class S, class IT, class ST>
648 void stepper_tools<layout_type::row_major>::increment_stepper(
649 S& stepper,
650 IT& index,
651 const ST& shape,
652 typename S::size_type n
653 )
654 {
655 using size_type = typename S::size_type;
656 const size_type size = index.size();
657 const size_type leading_i = size - 1;
658 size_type i = size;
659 while (i != 0 && n != 0)
660 {
661 --i;
662 size_type inc = (i == leading_i) ? n : 1;
663 if (xtl::cmp_less(index[i] + inc, shape[i]))
664 {
665 index[i] += inc;
666 stepper.step(i, inc);
667 n -= inc;
668 if (i != leading_i || index.size() == 1)
669 {
670 i = index.size();
671 }
672 }
673 else
674 {
675 if (i == leading_i)
676 {
677 size_type off = shape[i] - index[i] - 1;
678 stepper.step(i, off);
679 n -= off;
680 }
681 index[i] = 0;
682 if (i != 0)
683 {
684 stepper.reset(i);
685 }
686 }
687 }
688 if (i == 0 && n != 0)
689 {
690 if (size != size_type(0))
691 {
692 std::transform(
693 shape.cbegin(),
694 shape.cend() - 1,
695 index.begin(),
696 [](const auto& v)
697 {
698 return v - 1;
699 }
700 );
701 index[leading_i] = shape[leading_i];
702 }
703 stepper.to_end(layout_type::row_major);
704 }
705 }
706
707 template <>
708 template <class S, class IT, class ST>
709 void stepper_tools<layout_type::row_major>::decrement_stepper(S& stepper, IT& index, const ST& shape)
710 {
711 using size_type = typename S::size_type;
712 size_type i = index.size();
713 while (i != 0)
714 {
715 --i;
716 if (index[i] != 0)
717 {
718 --index[i];
719 stepper.step_back(i);
720 return;
721 }
722 else
723 {
724 index[i] = shape[i] - 1;
725 if (i != 0)
726 {
727 stepper.reset_back(i);
728 }
729 }
730 }
731 if (i == 0)
732 {
733 stepper.to_begin();
734 }
735 }
736
737 template <>
738 template <class S, class IT, class ST>
739 void stepper_tools<layout_type::row_major>::decrement_stepper(
740 S& stepper,
741 IT& index,
742 const ST& shape,
743 typename S::size_type n
744 )
745 {
746 using size_type = typename S::size_type;
747 size_type i = index.size();
748 size_type leading_i = index.size() - 1;
749 while (i != 0 && n != 0)
750 {
751 --i;
752 size_type inc = (i == leading_i) ? n : 1;
753 if (xtl::cmp_greater_equal(index[i], inc))
754 {
755 index[i] -= inc;
756 stepper.step_back(i, inc);
757 n -= inc;
758 if (i != leading_i || index.size() == 1)
759 {
760 i = index.size();
761 }
762 }
763 else
764 {
765 if (i == leading_i)
766 {
767 size_type off = index[i];
768 stepper.step_back(i, off);
769 n -= off;
770 }
771 index[i] = shape[i] - 1;
772 if (i != 0)
773 {
774 stepper.reset_back(i);
775 }
776 }
777 }
778 if (i == 0 && n != 0)
779 {
780 stepper.to_begin();
781 }
782 }
783
784 template <>
785 template <class S, class IT, class ST>
786 void stepper_tools<layout_type::column_major>::increment_stepper(S& stepper, IT& index, const ST& shape)
787 {
788 using size_type = typename S::size_type;
789 const size_type size = index.size();
790 size_type i = 0;
791 while (i != size)
792 {
793 if (index[i] != shape[i] - 1)
794 {
795 ++index[i];
796 stepper.step(i);
797 return;
798 }
799 else
800 {
801 index[i] = 0;
802 if (i != size - 1)
803 {
804 stepper.reset(i);
805 }
806 }
807 ++i;
808 }
809 if (i == size)
810 {
811 if (size != size_type(0))
812 {
813 std::transform(
814 shape.cbegin() + 1,
815 shape.cend(),
816 index.begin() + 1,
817 [](const auto& v)
818 {
819 return v - 1;
820 }
821 );
822 index[0] = shape[0];
823 }
824 stepper.to_end(layout_type::column_major);
825 }
826 }
827
828 template <>
829 template <class S, class IT, class ST>
830 void stepper_tools<layout_type::column_major>::increment_stepper(
831 S& stepper,
832 IT& index,
833 const ST& shape,
834 typename S::size_type n
835 )
836 {
837 using size_type = typename S::size_type;
838 const size_type size = index.size();
839 const size_type leading_i = 0;
840 size_type i = 0;
841 while (i != size && n != 0)
842 {
843 size_type inc = (i == leading_i) ? n : 1;
844 if (index[i] + inc < shape[i])
845 {
846 index[i] += inc;
847 stepper.step(i, inc);
848 n -= inc;
849 if (i != leading_i || size == 1)
850 {
851 i = 0;
852 continue;
853 }
854 }
855 else
856 {
857 if (i == leading_i)
858 {
859 size_type off = shape[i] - index[i] - 1;
860 stepper.step(i, off);
861 n -= off;
862 }
863 index[i] = 0;
864 if (i != size - 1)
865 {
866 stepper.reset(i);
867 }
868 }
869 ++i;
870 }
871 if (i == size && n != 0)
872 {
873 if (size != size_type(0))
874 {
875 std::transform(
876 shape.cbegin() + 1,
877 shape.cend(),
878 index.begin() + 1,
879 [](const auto& v)
880 {
881 return v - 1;
882 }
883 );
884 index[leading_i] = shape[leading_i];
885 }
886 stepper.to_end(layout_type::column_major);
887 }
888 }
889
890 template <>
891 template <class S, class IT, class ST>
892 void stepper_tools<layout_type::column_major>::decrement_stepper(S& stepper, IT& index, const ST& shape)
893 {
894 using size_type = typename S::size_type;
895 size_type size = index.size();
896 size_type i = 0;
897 while (i != size)
898 {
899 if (index[i] != 0)
900 {
901 --index[i];
902 stepper.step_back(i);
903 return;
904 }
905 else
906 {
907 index[i] = shape[i] - 1;
908 if (i != size - 1)
909 {
910 stepper.reset_back(i);
911 }
912 }
913 ++i;
914 }
915 if (i == size)
916 {
917 stepper.to_begin();
918 }
919 }
920
921 template <>
922 template <class S, class IT, class ST>
923 void stepper_tools<layout_type::column_major>::decrement_stepper(
924 S& stepper,
925 IT& index,
926 const ST& shape,
927 typename S::size_type n
928 )
929 {
930 using size_type = typename S::size_type;
931 size_type size = index.size();
932 size_type i = 0;
933 size_type leading_i = 0;
934 while (i != size && n != 0)
935 {
936 size_type inc = (i == leading_i) ? n : 1;
937 if (index[i] >= inc)
938 {
939 index[i] -= inc;
940 stepper.step_back(i, inc);
941 n -= inc;
942 if (i != leading_i || index.size() == 1)
943 {
944 i = 0;
945 continue;
946 }
947 }
948 else
949 {
950 if (i == leading_i)
951 {
952 size_type off = index[i];
953 stepper.step_back(i, off);
954 n -= off;
955 }
956 index[i] = shape[i] - 1;
957 if (i != size - 1)
958 {
959 stepper.reset_back(i);
960 }
961 }
962 ++i;
963 }
964 if (i == size && n != 0)
965 {
966 stepper.to_begin();
967 }
968 }
969
970 /***********************************
971 * xindexed_stepper implementation *
972 ***********************************/
973
974 template <class C, bool is_const>
975 inline xindexed_stepper<C, is_const>::xindexed_stepper(xexpression_type* e, size_type offset, bool end) noexcept
976 : p_e(e)
977 , m_index(xtl::make_sequence<index_type>(e->shape().size(), size_type(0)))
978 , m_offset(offset)
979 {
980 if (end)
981 {
982 // Note: the layout here doesn't matter (unused) but using default traversal looks more "correct".
983 to_end(XTENSOR_DEFAULT_TRAVERSAL);
984 }
985 }
986
987 template <class C, bool is_const>
988 inline auto xindexed_stepper<C, is_const>::operator*() const -> reference
989 {
990 return p_e->element(m_index.cbegin(), m_index.cend());
991 }
992
993 template <class C, bool is_const>
994 inline void xindexed_stepper<C, is_const>::step(size_type dim, size_type n)
995 {
996 if (dim >= m_offset)
997 {
998 m_index[dim - m_offset] += static_cast<typename index_type::value_type>(n);
999 }
1000 }
1001
1002 template <class C, bool is_const>
1003 inline void xindexed_stepper<C, is_const>::step_back(size_type dim, size_type n)
1004 {
1005 if (dim >= m_offset)
1006 {
1007 m_index[dim - m_offset] -= static_cast<typename index_type::value_type>(n);
1008 }
1009 }
1010
1011 template <class C, bool is_const>
1012 inline void xindexed_stepper<C, is_const>::reset(size_type dim)
1013 {
1014 if (dim >= m_offset)
1015 {
1016 m_index[dim - m_offset] = 0;
1017 }
1018 }
1019
1020 template <class C, bool is_const>
1021 inline void xindexed_stepper<C, is_const>::reset_back(size_type dim)
1022 {
1023 if (dim >= m_offset)
1024 {
1025 m_index[dim - m_offset] = p_e->shape()[dim - m_offset] - 1;
1026 }
1027 }
1028
1029 template <class C, bool is_const>
1030 inline void xindexed_stepper<C, is_const>::to_begin()
1031 {
1032 std::fill(m_index.begin(), m_index.end(), size_type(0));
1033 }
1034
1035 template <class C, bool is_const>
1036 inline void xindexed_stepper<C, is_const>::to_end(layout_type l)
1037 {
1038 const auto& shape = p_e->shape();
1039 std::transform(
1040 shape.cbegin(),
1041 shape.cend(),
1042 m_index.begin(),
1043 [](const auto& v)
1044 {
1045 return v - 1;
1046 }
1047 );
1048
1049 size_type l_dim = (l == layout_type::row_major) ? shape.size() - 1 : 0;
1050 m_index[l_dim] = shape[l_dim];
1051 }
1052
1053 /****************************
1054 * xiterator implementation *
1055 ****************************/
1056
1057 namespace detail
1058 {
1059 template <class S>
1060 inline shape_storage<S>::shape_storage(param_type shape)
1061 : m_shape(shape)
1062 {
1063 }
1064
1065 template <class S>
1066 inline const S& shape_storage<S>::shape() const
1067 {
1068 return m_shape;
1069 }
1070
1071 template <class S>
1072 inline shape_storage<S*>::shape_storage(param_type shape)
1073 : p_shape(shape)
1074 {
1075 }
1076
1077 template <class S>
1078 inline const S& shape_storage<S*>::shape() const
1079 {
1080 return *p_shape;
1081 }
1082
1083 template <>
1084 struct LAYOUT_FORBIDEN_FOR_XITERATOR<layout_type::row_major>
1085 {
1086 using type = int;
1087 };
1088
1089 template <>
1090 struct LAYOUT_FORBIDEN_FOR_XITERATOR<layout_type::column_major>
1091 {
1092 using type = int;
1093 };
1094 }
1095
1096 template <class St, class S, layout_type L>
1097 inline xiterator<St, S, L>::xiterator(St st, shape_param_type shape, bool end_index)
1098 : private_base(shape)
1099 , m_st(st)
1100 , m_index(
1101 end_index ? xtl::forward_sequence<index_type, const shape_type&>(this->shape())
1102 : xtl::make_sequence<index_type>(this->shape().size(), size_type(0))
1103 )
1104 , m_linear_index(0)
1105 {
1106 // end_index means either reverse_iterator && !end or !reverse_iterator && end
1107 if (end_index)
1108 {
1109 if (m_index.size() != size_type(0))
1110 {
1111 auto iter_begin = (L == layout_type::row_major) ? m_index.begin() : m_index.begin() + 1;
1112 auto iter_end = (L == layout_type::row_major) ? m_index.end() - 1 : m_index.end();
1113 std::transform(
1114 iter_begin,
1115 iter_end,
1116 iter_begin,
1117 [](const auto& v)
1118 {
1119 return v - 1;
1120 }
1121 );
1122 }
1123 m_linear_index = difference_type(std::accumulate(
1124 this->shape().cbegin(),
1125 this->shape().cend(),
1126 size_type(1),
1127 std::multiplies<size_type>()
1128 ));
1129 }
1130 }
1131
1132 template <class St, class S, layout_type L>
1133 inline auto xiterator<St, S, L>::operator++() -> self_type&
1134 {
1135 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape());
1136 ++m_linear_index;
1137 return *this;
1138 }
1139
1140 template <class St, class S, layout_type L>
1141 inline auto xiterator<St, S, L>::operator--() -> self_type&
1142 {
1143 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape());
1144 --m_linear_index;
1145 return *this;
1146 }
1147
1148 template <class St, class S, layout_type L>
1149 inline auto xiterator<St, S, L>::operator+=(difference_type n) -> self_type&
1150 {
1151 if (n >= 0)
1152 {
1153 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape(), static_cast<size_type>(n));
1154 }
1155 else
1156 {
1157 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape(), static_cast<size_type>(-n));
1158 }
1159 m_linear_index += n;
1160 return *this;
1161 }
1162
1163 template <class St, class S, layout_type L>
1164 inline auto xiterator<St, S, L>::operator-=(difference_type n) -> self_type&
1165 {
1166 if (n >= 0)
1167 {
1168 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape(), static_cast<size_type>(n));
1169 }
1170 else
1171 {
1172 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape(), static_cast<size_type>(-n));
1173 }
1174 m_linear_index -= n;
1175 return *this;
1176 }
1177
1178 template <class St, class S, layout_type L>
1179 inline auto xiterator<St, S, L>::operator-(const self_type& rhs) const -> difference_type
1180 {
1181 return m_linear_index - rhs.m_linear_index;
1182 }
1183
1184 template <class St, class S, layout_type L>
1185 inline auto xiterator<St, S, L>::operator*() const -> reference
1186 {
1187 return *m_st;
1188 }
1189
1190 template <class St, class S, layout_type L>
1191 inline auto xiterator<St, S, L>::operator->() const -> pointer
1192 {
1193 return &(*m_st);
1194 }
1195
1196 template <class St, class S, layout_type L>
1197 inline bool xiterator<St, S, L>::equal(const xiterator& rhs) const
1198 {
1199 XTENSOR_ASSERT(this->shape() == rhs.shape());
1200 return m_linear_index == rhs.m_linear_index;
1201 }
1202
1203 template <class St, class S, layout_type L>
1204 inline bool xiterator<St, S, L>::less_than(const xiterator& rhs) const
1205 {
1206 XTENSOR_ASSERT(this->shape() == rhs.shape());
1207 return m_linear_index < rhs.m_linear_index;
1208 }
1209
1210 template <class St, class S, layout_type L>
1211 inline bool operator==(const xiterator<St, S, L>& lhs, const xiterator<St, S, L>& rhs)
1212 {
1213 return lhs.equal(rhs);
1214 }
1215
1216 template <class St, class S, layout_type L>
1217 bool operator<(const xiterator<St, S, L>& lhs, const xiterator<St, S, L>& rhs)
1218 {
1219 return lhs.less_than(rhs);
1220 }
1221
1222 /************************************
1223 * xbounded_iterator implementation *
1224 ************************************/
1225
1226 template <class It, class BIt>
1227 xbounded_iterator<It, BIt>::xbounded_iterator(It it, BIt bound_it)
1228 : m_it(it)
1229 , m_bound_it(bound_it)
1230 {
1231 }
1232
1233 template <class It, class BIt>
1234 inline auto xbounded_iterator<It, BIt>::operator++() -> self_type&
1235 {
1236 ++m_it;
1237 ++m_bound_it;
1238 return *this;
1239 }
1240
1241 template <class It, class BIt>
1242 inline auto xbounded_iterator<It, BIt>::operator--() -> self_type&
1243 {
1244 --m_it;
1245 --m_bound_it;
1246 return *this;
1247 }
1248
1249 template <class It, class BIt>
1250 inline auto xbounded_iterator<It, BIt>::operator+=(difference_type n) -> self_type&
1251 {
1252 m_it += n;
1253 m_bound_it += n;
1254 return *this;
1255 }
1256
1257 template <class It, class BIt>
1258 inline auto xbounded_iterator<It, BIt>::operator-=(difference_type n) -> self_type&
1259 {
1260 m_it -= n;
1261 m_bound_it -= n;
1262 return *this;
1263 }
1264
1265 template <class It, class BIt>
1266 inline auto xbounded_iterator<It, BIt>::operator-(const self_type& rhs) const -> difference_type
1267 {
1268 return m_it - rhs.m_it;
1269 }
1270
1271 template <class It, class BIt>
1272 inline auto xbounded_iterator<It, BIt>::operator*() const -> value_type
1273 {
1274 using type = decltype(*m_bound_it);
1275 return (static_cast<type>(*m_it) < *m_bound_it) ? *m_it : static_cast<value_type>((*m_bound_it) - 1);
1276 }
1277
1278 template <class It, class BIt>
1279 inline bool xbounded_iterator<It, BIt>::equal(const self_type& rhs) const
1280 {
1281 return m_it == rhs.m_it && m_bound_it == rhs.m_bound_it;
1282 }
1283
1284 template <class It, class BIt>
1285 inline bool xbounded_iterator<It, BIt>::less_than(const self_type& rhs) const
1286 {
1287 return m_it < rhs.m_it;
1288 }
1289
1290 template <class It, class BIt>
1291 inline bool operator==(const xbounded_iterator<It, BIt>& lhs, const xbounded_iterator<It, BIt>& rhs)
1292 {
1293 return lhs.equal(rhs);
1294 }
1295
1296 template <class It, class BIt>
1297 inline bool operator<(const xbounded_iterator<It, BIt>& lhs, const xbounded_iterator<It, BIt>& rhs)
1298 {
1299 return lhs.less_than(rhs);
1300 }
1301}
1302
1303#endif
standard mathematical functions for xexpressions
bool operator==(const xaxis_iterator< CT > &lhs, const xaxis_iterator< CT > &rhs)
Checks equality of the iterators.
layout_type
Definition xlayout.hpp:24