xtensor
 
Loading...
Searching...
No Matches
xiterator.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_ITERATOR_HPP
11#define XTENSOR_ITERATOR_HPP
12
#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <iterator>
#include <numeric>
#include <type_traits>
#include <utility>

#include <xtl/xcompare.hpp>
#include <xtl/xiterator_base.hpp>
#include <xtl/xmeta_utils.hpp>
#include <xtl/xsequence.hpp>

#include "../core/xlayout.hpp"
#include "../core/xshape.hpp"
#include "../utils/xexception.hpp"
#include "../utils/xutils.hpp"
28
29namespace xt
30{
31
/***********************
 * iterator meta utils *
 ***********************/

template <class CT>
class xscalar;

template <bool is_const, class CT>
class xscalar_stepper;

namespace detail
{
    // Maps a storage type to the iterator a stepper walks it with:
    // mutable storages provide `container_iterator`, const storages
    // `const_container_iterator`. xscalar is special-cased and supplies
    // its dummy iterators instead.
    template <class Storage>
    struct get_stepper_iterator_impl
    {
        using type = typename Storage::container_iterator;
    };

    template <class Storage>
    struct get_stepper_iterator_impl<const Storage>
    {
        using type = typename Storage::const_container_iterator;
    };

    template <class ValueCT>
    struct get_stepper_iterator_impl<xscalar<ValueCT>>
    {
        using type = typename xscalar<ValueCT>::dummy_iterator;
    };

    template <class ValueCT>
    struct get_stepper_iterator_impl<const xscalar<ValueCT>>
    {
        using type = typename xscalar<ValueCT>::const_dummy_iterator;
    };
}

/// Iterator type used by xstepper<C> to traverse the data of `C`.
template <class C>
using get_stepper_iterator = typename detail::get_stepper_iterator_impl<C>::type;
71
72 /********************************
73 * xindex_type_t implementation *
74 ********************************/
75
76 namespace detail
77 {
78 template <class ST>
79 struct index_type_impl
80 {
81 using type = dynamic_shape<typename ST::value_type>;
82 };
83
84 template <class V, std::size_t L>
85 struct index_type_impl<std::array<V, L>>
86 {
87 using type = std::array<V, L>;
88 };
89
90 template <std::size_t... I>
91 struct index_type_impl<fixed_shape<I...>>
92 {
93 using type = std::array<std::size_t, sizeof...(I)>;
94 };
95 }
96
97 template <class C>
98 using xindex_type_t = typename detail::index_type_impl<C>::type;
99
100 /************
101 * xstepper *
102 ************/
103
104 template <class C>
105 class xstepper
106 {
107 public:
108
109 using storage_type = C;
110 using subiterator_type = get_stepper_iterator<C>;
111 using subiterator_traits = std::iterator_traits<subiterator_type>;
112 using value_type = typename subiterator_traits::value_type;
113 using reference = typename subiterator_traits::reference;
114 using pointer = typename subiterator_traits::pointer;
115 using difference_type = typename subiterator_traits::difference_type;
116 using size_type = typename storage_type::size_type;
117 using shape_type = typename storage_type::shape_type;
118 using simd_value_type = xt_simd::simd_type<value_type>;
119
120 template <class requested_type>
121 using simd_return_type = xt_simd::simd_return_type<value_type, requested_type>;
122
123 xstepper() = default;
124 xstepper(storage_type* c, subiterator_type it, size_type offset) noexcept;
125
126 reference operator*() const;
127
128 void step(size_type dim, size_type n = 1);
129 void step_back(size_type dim, size_type n = 1);
130 void reset(size_type dim);
131 void reset_back(size_type dim);
132
133 void to_begin();
134 void to_end(layout_type l);
135
136 template <class T>
137 simd_return_type<T> step_simd();
138
139 void step_leading();
140
141 template <class R>
142 void store_simd(const R& vec);
143
144 private:
145
146 storage_type* p_c;
147 subiterator_type m_it;
148 size_type m_offset;
149 };
150
151 template <layout_type L>
153 {
154 // For performance reasons, increment_stepper and decrement_stepper are
155 // specialized for the case where n=1, which underlies operator++ and
156 // operator-- on xiterators.
157
158 template <class S, class IT, class ST>
159 static void increment_stepper(S& stepper, IT& index, const ST& shape);
160
161 template <class S, class IT, class ST>
162 static void decrement_stepper(S& stepper, IT& index, const ST& shape);
163
164 template <class S, class IT, class ST>
165 static void increment_stepper(S& stepper, IT& index, const ST& shape, typename S::size_type n);
166
167 template <class S, class IT, class ST>
168 static void decrement_stepper(S& stepper, IT& index, const ST& shape, typename S::size_type n);
169 };
170
171 /********************
172 * xindexed_stepper *
173 ********************/
174
175 template <class E, bool is_const>
176 class xindexed_stepper
177 {
178 public:
179
180 using self_type = xindexed_stepper<E, is_const>;
181 using xexpression_type = std::conditional_t<is_const, const E, E>;
182
183 using value_type = typename xexpression_type::value_type;
184 using reference = std::
185 conditional_t<is_const, typename xexpression_type::const_reference, typename xexpression_type::reference>;
186 using pointer = std::
187 conditional_t<is_const, typename xexpression_type::const_pointer, typename xexpression_type::pointer>;
188 using size_type = typename xexpression_type::size_type;
189 using difference_type = typename xexpression_type::difference_type;
190
191 using shape_type = typename xexpression_type::shape_type;
192 using index_type = xindex_type_t<shape_type>;
193
194 xindexed_stepper() = default;
195 xindexed_stepper(xexpression_type* e, size_type offset, bool end = false) noexcept;
196
197 reference operator*() const;
198
199 void step(size_type dim, size_type n = 1);
200 void step_back(size_type dim, size_type n = 1);
201 void reset(size_type dim);
202 void reset_back(size_type dim);
203
204 void to_begin();
205 void to_end(layout_type l);
206
207 private:
208
209 xexpression_type* p_e;
210 index_type m_index;
211 size_type m_offset;
212 };
213
214 template <class T>
216 {
217 static const bool value = false;
218 };
219
220 template <class T, bool B>
222 {
223 static const bool value = true;
224 };
225
226 template <class T, class R = T>
227 struct enable_indexed_stepper : std::enable_if<is_indexed_stepper<T>::value, R>
228 {
229 };
230
231 template <class T, class R = T>
232 using enable_indexed_stepper_t = typename enable_indexed_stepper<T, R>::type;
233
234 template <class T, class R = T>
235 struct disable_indexed_stepper : std::enable_if<!is_indexed_stepper<T>::value, R>
236 {
237 };
238
239 template <class T, class R = T>
240 using disable_indexed_stepper_t = typename disable_indexed_stepper<T, R>::type;
241
242 /*************
243 * xiterator *
244 *************/
245
246 namespace detail
247 {
248 template <class S>
249 class shape_storage
250 {
251 public:
252
253 using shape_type = S;
254 using param_type = const S&;
255
256 shape_storage() = default;
257 shape_storage(param_type shape);
258 const S& shape() const;
259
260 private:
261
262 S m_shape;
263 };
264
265 template <class S>
266 class shape_storage<S*>
267 {
268 public:
269
270 using shape_type = S;
271 using param_type = const S*;
272
273 shape_storage(param_type shape = 0);
274 const S& shape() const;
275
276 private:
277
278 const S* p_shape;
279 };
280
281 template <layout_type L>
282 struct LAYOUT_FORBIDEN_FOR_XITERATOR;
283 }
284
285 template <class St, class S, layout_type L>
286 class xiterator : public xtl::xrandom_access_iterator_base<
287 xiterator<St, S, L>,
288 typename St::value_type,
289 typename St::difference_type,
290 typename St::pointer,
291 typename St::reference>,
292 private detail::shape_storage<S>
293 {
294 public:
295
296 using self_type = xiterator<St, S, L>;
297
298 using stepper_type = St;
299 using value_type = typename stepper_type::value_type;
300 using reference = typename stepper_type::reference;
301 using pointer = typename stepper_type::pointer;
302 using difference_type = typename stepper_type::difference_type;
303 using size_type = typename stepper_type::size_type;
304 using iterator_category = std::random_access_iterator_tag;
305
306 using private_base = detail::shape_storage<S>;
307 using shape_type = typename private_base::shape_type;
308 using shape_param_type = typename private_base::param_type;
309 using index_type = xindex_type_t<shape_type>;
310
311 xiterator() = default;
312
313 // end_index means either reverse_iterator && !end or !reverse_iterator && end
314 xiterator(St st, shape_param_type shape, bool end_index);
315
316 self_type& operator++();
317 self_type& operator--();
318
319 self_type& operator+=(difference_type n);
320 self_type& operator-=(difference_type n);
321
322 difference_type operator-(const self_type& rhs) const;
323
324 reference operator*() const;
325 pointer operator->() const;
326
327 bool equal(const xiterator& rhs) const;
328 bool less_than(const xiterator& rhs) const;
329
330 private:
331
332 stepper_type m_st;
333 index_type m_index;
334 difference_type m_linear_index;
335
336 using checking_type = typename detail::LAYOUT_FORBIDEN_FOR_XITERATOR<L>::type;
337 };
338
339 template <class St, class S, layout_type L>
340 bool operator==(const xiterator<St, S, L>& lhs, const xiterator<St, S, L>& rhs);
341
342 template <class St, class S, layout_type L>
343 bool operator<(const xiterator<St, S, L>& lhs, const xiterator<St, S, L>& rhs);
344
345 template <class St, class S, layout_type L>
346 struct is_contiguous_container<xiterator<St, S, L>> : std::false_type
347 {
348 };
349
350 /*********************
351 * xbounded_iterator *
352 *********************/
353
354 template <class It, class BIt>
355 class xbounded_iterator : public xtl::xrandom_access_iterator_base<
356 xbounded_iterator<It, BIt>,
357 typename std::iterator_traits<It>::value_type,
358 typename std::iterator_traits<It>::difference_type,
359 typename std::iterator_traits<It>::pointer,
360 typename std::iterator_traits<It>::reference>
361 {
362 public:
363
364 using self_type = xbounded_iterator<It, BIt>;
365
366 using subiterator_type = It;
367 using bound_iterator_type = BIt;
368 using value_type = typename std::iterator_traits<It>::value_type;
369 using reference = typename std::iterator_traits<It>::reference;
370 using pointer = typename std::iterator_traits<It>::pointer;
371 using difference_type = typename std::iterator_traits<It>::difference_type;
372 using iterator_category = std::random_access_iterator_tag;
373
374 xbounded_iterator() = default;
375 xbounded_iterator(It it, BIt bound_it);
376
377 self_type& operator++();
378 self_type& operator--();
379
380 self_type& operator+=(difference_type n);
381 self_type& operator-=(difference_type n);
382
383 difference_type operator-(const self_type& rhs) const;
384
385 value_type operator*() const;
386
387 bool equal(const self_type& rhs) const;
388 bool less_than(const self_type& rhs) const;
389
390 private:
391
392 subiterator_type m_it;
393 bound_iterator_type m_bound_it;
394 };
395
396 template <class It, class BIt>
397 bool operator==(const xbounded_iterator<It, BIt>& lhs, const xbounded_iterator<It, BIt>& rhs);
398
399 template <class It, class BIt>
400 bool operator<(const xbounded_iterator<It, BIt>& lhs, const xbounded_iterator<It, BIt>& rhs);
401
402 /*****************************
403 * linear_begin / linear_end *
404 *****************************/
405
406 namespace detail
407 {
408 template <class C, class = void_t<>>
409 struct has_linear_iterator : std::false_type
410 {
411 };
412
413 template <class C>
414 struct has_linear_iterator<C, void_t<decltype(std::declval<C>().linear_cbegin())>> : std::true_type
415 {
416 };
417 }
418
419 template <class C>
420 constexpr auto linear_begin(C& c) noexcept
421 {
422 if constexpr (detail::has_linear_iterator<C>::value)
423 {
424 return c.linear_begin();
425 }
426 else
427 {
428 return c.begin();
429 }
430 }
431
432 template <class C>
433 constexpr auto linear_end(C& c) noexcept
434 {
435 if constexpr (detail::has_linear_iterator<C>::value)
436 {
437 return c.linear_end();
438 }
439 else
440 {
441 return c.end();
442 }
443 }
444
445 template <class C>
446 constexpr auto linear_begin(const C& c) noexcept
447 {
448 if constexpr (detail::has_linear_iterator<C>::value)
449 {
450 return c.linear_cbegin();
451 }
452 else
453 {
454 return c.cbegin();
455 }
456 }
457
458 template <class C>
459 constexpr auto linear_end(const C& c) noexcept
460 {
461 if constexpr (detail::has_linear_iterator<C>::value)
462 {
463 return c.linear_cend();
464 }
465 else
466 {
467 return c.cend();
468 }
469 }
470
471 /***************************
472 * xstepper implementation *
473 ***************************/
474
475 template <class C>
476 inline xstepper<C>::xstepper(storage_type* c, subiterator_type it, size_type offset) noexcept
477 : p_c(c)
478 , m_it(it)
479 , m_offset(offset)
480 {
481 }
482
483 template <class C>
484 inline auto xstepper<C>::operator*() const -> reference
485 {
486 return *m_it;
487 }
488
489 template <class C>
490 inline void xstepper<C>::step(size_type dim, size_type n)
491 {
492 if (dim >= m_offset)
493 {
494 using strides_value_type = typename std::decay_t<decltype(p_c->strides())>::value_type;
495 m_it += difference_type(static_cast<strides_value_type>(n) * p_c->strides()[dim - m_offset]);
496 }
497 }
498
499 template <class C>
500 inline void xstepper<C>::step_back(size_type dim, size_type n)
501 {
502 if (dim >= m_offset)
503 {
504 using strides_value_type = typename std::decay_t<decltype(p_c->strides())>::value_type;
505 m_it -= difference_type(static_cast<strides_value_type>(n) * p_c->strides()[dim - m_offset]);
506 }
507 }
508
509 template <class C>
510 inline void xstepper<C>::reset(size_type dim)
511 {
512 if (dim >= m_offset)
513 {
514 m_it -= difference_type(p_c->backstrides()[dim - m_offset]);
515 }
516 }
517
518 template <class C>
519 inline void xstepper<C>::reset_back(size_type dim)
520 {
521 if (dim >= m_offset)
522 {
523 m_it += difference_type(p_c->backstrides()[dim - m_offset]);
524 }
525 }
526
527 template <class C>
528 inline void xstepper<C>::to_begin()
529 {
530 m_it = p_c->data_xbegin();
531 }
532
533 template <class C>
534 inline void xstepper<C>::to_end(layout_type l)
535 {
536 m_it = p_c->data_xend(l, m_offset);
537 }
538
539 namespace detail
540 {
541 template <class It>
542 struct step_simd_invoker
543 {
544 template <class R>
545 static R apply(const It& it)
546 {
547 R reg;
548 return reg.load_unaligned(&(*it));
549 // return reg;
550 }
551 };
552
553 template <bool is_const, class T, class S, layout_type L>
554 struct step_simd_invoker<xiterator<xscalar_stepper<is_const, T>, S, L>>
555 {
556 template <class R>
557 static R apply(const xiterator<xscalar_stepper<is_const, T>, S, L>& it)
558 {
559 return R(*it);
560 }
561 };
562 }
563
564 template <class C>
565 template <class T>
566 inline auto xstepper<C>::step_simd() -> simd_return_type<T>
567 {
568 using simd_type = simd_return_type<T>;
569 simd_type reg = detail::step_simd_invoker<subiterator_type>::template apply<simd_type>(m_it);
570 m_it += xt_simd::revert_simd_traits<simd_type>::size;
571 return reg;
572 }
573
574 template <class C>
575 template <class R>
576 inline void xstepper<C>::store_simd(const R& vec)
577 {
578 vec.store_unaligned(&(*m_it));
579 m_it += xt_simd::revert_simd_traits<R>::size;
580 ;
581 }
582
583 template <class C>
584 void xstepper<C>::step_leading()
585 {
586 ++m_it;
587 }
588
589 template <>
590 template <class S, class IT, class ST>
591 void stepper_tools<layout_type::row_major>::increment_stepper(S& stepper, IT& index, const ST& shape)
592 {
593 using size_type = typename S::size_type;
594 const size_type size = index.size();
595 size_type i = size;
596 while (i != 0)
597 {
598 --i;
599 if (index[i] != shape[i] - 1)
600 {
601 ++index[i];
602 stepper.step(i);
603 return;
604 }
605 else
606 {
607 index[i] = 0;
608 if (i != 0)
609 {
610 stepper.reset(i);
611 }
612 }
613 }
614 if (i == 0)
615 {
616 if (size != size_type(0))
617 {
618 std::transform(
619 shape.cbegin(),
620 shape.cend() - 1,
621 index.begin(),
622 [](const auto& v)
623 {
624 return v - 1;
625 }
626 );
627 index[size - 1] = shape[size - 1];
628 }
629 stepper.to_end(layout_type::row_major);
630 }
631 }
632
633 template <>
634 template <class S, class IT, class ST>
635 void stepper_tools<layout_type::row_major>::increment_stepper(
636 S& stepper,
637 IT& index,
638 const ST& shape,
639 typename S::size_type n
640 )
641 {
642 using size_type = typename S::size_type;
643 const size_type size = index.size();
644 const size_type leading_i = size - 1;
645 size_type i = size;
646 while (i != 0 && n != 0)
647 {
648 --i;
649 size_type inc = (i == leading_i) ? n : 1;
650 if (xtl::cmp_less(index[i] + inc, shape[i]))
651 {
652 index[i] += inc;
653 stepper.step(i, inc);
654 n -= inc;
655 if (i != leading_i || index.size() == 1)
656 {
657 i = index.size();
658 }
659 }
660 else
661 {
662 if (i == leading_i)
663 {
664 size_type off = shape[i] - index[i] - 1;
665 stepper.step(i, off);
666 n -= off;
667 }
668 index[i] = 0;
669 if (i != 0)
670 {
671 stepper.reset(i);
672 }
673 }
674 }
675 if (i == 0 && n != 0)
676 {
677 if (size != size_type(0))
678 {
679 std::transform(
680 shape.cbegin(),
681 shape.cend() - 1,
682 index.begin(),
683 [](const auto& v)
684 {
685 return v - 1;
686 }
687 );
688 index[leading_i] = shape[leading_i];
689 }
690 stepper.to_end(layout_type::row_major);
691 }
692 }
693
694 template <>
695 template <class S, class IT, class ST>
696 void stepper_tools<layout_type::row_major>::decrement_stepper(S& stepper, IT& index, const ST& shape)
697 {
698 using size_type = typename S::size_type;
699 size_type i = index.size();
700 while (i != 0)
701 {
702 --i;
703 if (index[i] != 0)
704 {
705 --index[i];
706 stepper.step_back(i);
707 return;
708 }
709 else
710 {
711 index[i] = shape[i] - 1;
712 if (i != 0)
713 {
714 stepper.reset_back(i);
715 }
716 }
717 }
718 if (i == 0)
719 {
720 stepper.to_begin();
721 }
722 }
723
724 template <>
725 template <class S, class IT, class ST>
726 void stepper_tools<layout_type::row_major>::decrement_stepper(
727 S& stepper,
728 IT& index,
729 const ST& shape,
730 typename S::size_type n
731 )
732 {
733 using size_type = typename S::size_type;
734 size_type i = index.size();
735 size_type leading_i = index.size() - 1;
736 while (i != 0 && n != 0)
737 {
738 --i;
739 size_type inc = (i == leading_i) ? n : 1;
740 if (xtl::cmp_greater_equal(index[i], inc))
741 {
742 index[i] -= inc;
743 stepper.step_back(i, inc);
744 n -= inc;
745 if (i != leading_i || index.size() == 1)
746 {
747 i = index.size();
748 }
749 }
750 else
751 {
752 if (i == leading_i)
753 {
754 size_type off = index[i];
755 stepper.step_back(i, off);
756 n -= off;
757 }
758 index[i] = shape[i] - 1;
759 if (i != 0)
760 {
761 stepper.reset_back(i);
762 }
763 }
764 }
765 if (i == 0 && n != 0)
766 {
767 stepper.to_begin();
768 }
769 }
770
771 template <>
772 template <class S, class IT, class ST>
773 void stepper_tools<layout_type::column_major>::increment_stepper(S& stepper, IT& index, const ST& shape)
774 {
775 using size_type = typename S::size_type;
776 const size_type size = index.size();
777 size_type i = 0;
778 while (i != size)
779 {
780 if (index[i] != shape[i] - 1)
781 {
782 ++index[i];
783 stepper.step(i);
784 return;
785 }
786 else
787 {
788 index[i] = 0;
789 if (i != size - 1)
790 {
791 stepper.reset(i);
792 }
793 }
794 ++i;
795 }
796 if (i == size)
797 {
798 if (size != size_type(0))
799 {
800 std::transform(
801 shape.cbegin() + 1,
802 shape.cend(),
803 index.begin() + 1,
804 [](const auto& v)
805 {
806 return v - 1;
807 }
808 );
809 index[0] = shape[0];
810 }
811 stepper.to_end(layout_type::column_major);
812 }
813 }
814
815 template <>
816 template <class S, class IT, class ST>
817 void stepper_tools<layout_type::column_major>::increment_stepper(
818 S& stepper,
819 IT& index,
820 const ST& shape,
821 typename S::size_type n
822 )
823 {
824 using size_type = typename S::size_type;
825 const size_type size = index.size();
826 const size_type leading_i = 0;
827 size_type i = 0;
828 while (i != size && n != 0)
829 {
830 size_type inc = (i == leading_i) ? n : 1;
831 if (index[i] + inc < shape[i])
832 {
833 index[i] += inc;
834 stepper.step(i, inc);
835 n -= inc;
836 if (i != leading_i || size == 1)
837 {
838 i = 0;
839 continue;
840 }
841 }
842 else
843 {
844 if (i == leading_i)
845 {
846 size_type off = shape[i] - index[i] - 1;
847 stepper.step(i, off);
848 n -= off;
849 }
850 index[i] = 0;
851 if (i != size - 1)
852 {
853 stepper.reset(i);
854 }
855 }
856 ++i;
857 }
858 if (i == size && n != 0)
859 {
860 if (size != size_type(0))
861 {
862 std::transform(
863 shape.cbegin() + 1,
864 shape.cend(),
865 index.begin() + 1,
866 [](const auto& v)
867 {
868 return v - 1;
869 }
870 );
871 index[leading_i] = shape[leading_i];
872 }
873 stepper.to_end(layout_type::column_major);
874 }
875 }
876
877 template <>
878 template <class S, class IT, class ST>
879 void stepper_tools<layout_type::column_major>::decrement_stepper(S& stepper, IT& index, const ST& shape)
880 {
881 using size_type = typename S::size_type;
882 size_type size = index.size();
883 size_type i = 0;
884 while (i != size)
885 {
886 if (index[i] != 0)
887 {
888 --index[i];
889 stepper.step_back(i);
890 return;
891 }
892 else
893 {
894 index[i] = shape[i] - 1;
895 if (i != size - 1)
896 {
897 stepper.reset_back(i);
898 }
899 }
900 ++i;
901 }
902 if (i == size)
903 {
904 stepper.to_begin();
905 }
906 }
907
908 template <>
909 template <class S, class IT, class ST>
910 void stepper_tools<layout_type::column_major>::decrement_stepper(
911 S& stepper,
912 IT& index,
913 const ST& shape,
914 typename S::size_type n
915 )
916 {
917 using size_type = typename S::size_type;
918 size_type size = index.size();
919 size_type i = 0;
920 size_type leading_i = 0;
921 while (i != size && n != 0)
922 {
923 size_type inc = (i == leading_i) ? n : 1;
924 if (index[i] >= inc)
925 {
926 index[i] -= inc;
927 stepper.step_back(i, inc);
928 n -= inc;
929 if (i != leading_i || index.size() == 1)
930 {
931 i = 0;
932 continue;
933 }
934 }
935 else
936 {
937 if (i == leading_i)
938 {
939 size_type off = index[i];
940 stepper.step_back(i, off);
941 n -= off;
942 }
943 index[i] = shape[i] - 1;
944 if (i != size - 1)
945 {
946 stepper.reset_back(i);
947 }
948 }
949 ++i;
950 }
951 if (i == size && n != 0)
952 {
953 stepper.to_begin();
954 }
955 }
956
957 /***********************************
958 * xindexed_stepper implementation *
959 ***********************************/
960
961 template <class C, bool is_const>
962 inline xindexed_stepper<C, is_const>::xindexed_stepper(xexpression_type* e, size_type offset, bool end) noexcept
963 : p_e(e)
964 , m_index(xtl::make_sequence<index_type>(e->shape().size(), size_type(0)))
965 , m_offset(offset)
966 {
967 if (end)
968 {
969 // Note: the layout here doesn't matter (unused) but using default traversal looks more "correct".
970 to_end(XTENSOR_DEFAULT_TRAVERSAL);
971 }
972 }
973
974 template <class C, bool is_const>
975 inline auto xindexed_stepper<C, is_const>::operator*() const -> reference
976 {
977 return p_e->element(m_index.cbegin(), m_index.cend());
978 }
979
980 template <class C, bool is_const>
981 inline void xindexed_stepper<C, is_const>::step(size_type dim, size_type n)
982 {
983 if (dim >= m_offset)
984 {
985 m_index[dim - m_offset] += static_cast<typename index_type::value_type>(n);
986 }
987 }
988
989 template <class C, bool is_const>
990 inline void xindexed_stepper<C, is_const>::step_back(size_type dim, size_type n)
991 {
992 if (dim >= m_offset)
993 {
994 m_index[dim - m_offset] -= static_cast<typename index_type::value_type>(n);
995 }
996 }
997
998 template <class C, bool is_const>
999 inline void xindexed_stepper<C, is_const>::reset(size_type dim)
1000 {
1001 if (dim >= m_offset)
1002 {
1003 m_index[dim - m_offset] = 0;
1004 }
1005 }
1006
1007 template <class C, bool is_const>
1008 inline void xindexed_stepper<C, is_const>::reset_back(size_type dim)
1009 {
1010 if (dim >= m_offset)
1011 {
1012 m_index[dim - m_offset] = p_e->shape()[dim - m_offset] - 1;
1013 }
1014 }
1015
1016 template <class C, bool is_const>
1017 inline void xindexed_stepper<C, is_const>::to_begin()
1018 {
1019 std::fill(m_index.begin(), m_index.end(), size_type(0));
1020 }
1021
1022 template <class C, bool is_const>
1023 inline void xindexed_stepper<C, is_const>::to_end(layout_type l)
1024 {
1025 const auto& shape = p_e->shape();
1026 std::transform(
1027 shape.cbegin(),
1028 shape.cend(),
1029 m_index.begin(),
1030 [](const auto& v)
1031 {
1032 return v - 1;
1033 }
1034 );
1035
1036 size_type l_dim = (l == layout_type::row_major) ? shape.size() - 1 : 0;
1037 m_index[l_dim] = shape[l_dim];
1038 }
1039
1040 /****************************
1041 * xiterator implementation *
1042 ****************************/
1043
1044 namespace detail
1045 {
1046 template <class S>
1047 inline shape_storage<S>::shape_storage(param_type shape)
1048 : m_shape(shape)
1049 {
1050 }
1051
1052 template <class S>
1053 inline const S& shape_storage<S>::shape() const
1054 {
1055 return m_shape;
1056 }
1057
1058 template <class S>
1059 inline shape_storage<S*>::shape_storage(param_type shape)
1060 : p_shape(shape)
1061 {
1062 }
1063
1064 template <class S>
1065 inline const S& shape_storage<S*>::shape() const
1066 {
1067 return *p_shape;
1068 }
1069
1070 template <>
1071 struct LAYOUT_FORBIDEN_FOR_XITERATOR<layout_type::row_major>
1072 {
1073 using type = int;
1074 };
1075
1076 template <>
1077 struct LAYOUT_FORBIDEN_FOR_XITERATOR<layout_type::column_major>
1078 {
1079 using type = int;
1080 };
1081 }
1082
1083 template <class St, class S, layout_type L>
1084 inline xiterator<St, S, L>::xiterator(St st, shape_param_type shape, bool end_index)
1085 : private_base(shape)
1086 , m_st(st)
1087 , m_index(
1088 end_index ? xtl::forward_sequence<index_type, const shape_type&>(this->shape())
1089 : xtl::make_sequence<index_type>(this->shape().size(), size_type(0))
1090 )
1091 , m_linear_index(0)
1092 {
1093 // end_index means either reverse_iterator && !end or !reverse_iterator && end
1094 if (end_index)
1095 {
1096 if (m_index.size() != size_type(0))
1097 {
1098 auto iter_begin = (L == layout_type::row_major) ? m_index.begin() : m_index.begin() + 1;
1099 auto iter_end = (L == layout_type::row_major) ? m_index.end() - 1 : m_index.end();
1100 std::transform(
1101 iter_begin,
1102 iter_end,
1103 iter_begin,
1104 [](const auto& v)
1105 {
1106 return v - 1;
1107 }
1108 );
1109 }
1110 m_linear_index = difference_type(std::accumulate(
1111 this->shape().cbegin(),
1112 this->shape().cend(),
1113 size_type(1),
1114 std::multiplies<size_type>()
1115 ));
1116 }
1117 }
1118
1119 template <class St, class S, layout_type L>
1120 inline auto xiterator<St, S, L>::operator++() -> self_type&
1121 {
1122 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape());
1123 ++m_linear_index;
1124 return *this;
1125 }
1126
1127 template <class St, class S, layout_type L>
1128 inline auto xiterator<St, S, L>::operator--() -> self_type&
1129 {
1130 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape());
1131 --m_linear_index;
1132 return *this;
1133 }
1134
1135 template <class St, class S, layout_type L>
1136 inline auto xiterator<St, S, L>::operator+=(difference_type n) -> self_type&
1137 {
1138 if (n >= 0)
1139 {
1140 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape(), static_cast<size_type>(n));
1141 }
1142 else
1143 {
1144 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape(), static_cast<size_type>(-n));
1145 }
1146 m_linear_index += n;
1147 return *this;
1148 }
1149
1150 template <class St, class S, layout_type L>
1151 inline auto xiterator<St, S, L>::operator-=(difference_type n) -> self_type&
1152 {
1153 if (n >= 0)
1154 {
1155 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape(), static_cast<size_type>(n));
1156 }
1157 else
1158 {
1159 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape(), static_cast<size_type>(-n));
1160 }
1161 m_linear_index -= n;
1162 return *this;
1163 }
1164
1165 template <class St, class S, layout_type L>
1166 inline auto xiterator<St, S, L>::operator-(const self_type& rhs) const -> difference_type
1167 {
1168 return m_linear_index - rhs.m_linear_index;
1169 }
1170
1171 template <class St, class S, layout_type L>
1172 inline auto xiterator<St, S, L>::operator*() const -> reference
1173 {
1174 return *m_st;
1175 }
1176
1177 template <class St, class S, layout_type L>
1178 inline auto xiterator<St, S, L>::operator->() const -> pointer
1179 {
1180 return &(*m_st);
1181 }
1182
1183 template <class St, class S, layout_type L>
1184 inline bool xiterator<St, S, L>::equal(const xiterator& rhs) const
1185 {
1186 XTENSOR_ASSERT(this->shape() == rhs.shape());
1187 return m_linear_index == rhs.m_linear_index;
1188 }
1189
1190 template <class St, class S, layout_type L>
1191 inline bool xiterator<St, S, L>::less_than(const xiterator& rhs) const
1192 {
1193 XTENSOR_ASSERT(this->shape() == rhs.shape());
1194 return m_linear_index < rhs.m_linear_index;
1195 }
1196
1197 template <class St, class S, layout_type L>
1198 inline bool operator==(const xiterator<St, S, L>& lhs, const xiterator<St, S, L>& rhs)
1199 {
1200 return lhs.equal(rhs);
1201 }
1202
1203 template <class St, class S, layout_type L>
1204 bool operator<(const xiterator<St, S, L>& lhs, const xiterator<St, S, L>& rhs)
1205 {
1206 return lhs.less_than(rhs);
1207 }
1208
1209 /************************************
1210 * xbounded_iterator implementation *
1211 ************************************/
1212
1213 template <class It, class BIt>
1214 xbounded_iterator<It, BIt>::xbounded_iterator(It it, BIt bound_it)
1215 : m_it(it)
1216 , m_bound_it(bound_it)
1217 {
1218 }
1219
1220 template <class It, class BIt>
1221 inline auto xbounded_iterator<It, BIt>::operator++() -> self_type&
1222 {
1223 ++m_it;
1224 ++m_bound_it;
1225 return *this;
1226 }
1227
1228 template <class It, class BIt>
1229 inline auto xbounded_iterator<It, BIt>::operator--() -> self_type&
1230 {
1231 --m_it;
1232 --m_bound_it;
1233 return *this;
1234 }
1235
1236 template <class It, class BIt>
1237 inline auto xbounded_iterator<It, BIt>::operator+=(difference_type n) -> self_type&
1238 {
1239 m_it += n;
1240 m_bound_it += n;
1241 return *this;
1242 }
1243
1244 template <class It, class BIt>
1245 inline auto xbounded_iterator<It, BIt>::operator-=(difference_type n) -> self_type&
1246 {
1247 m_it -= n;
1248 m_bound_it -= n;
1249 return *this;
1250 }
1251
1252 template <class It, class BIt>
1253 inline auto xbounded_iterator<It, BIt>::operator-(const self_type& rhs) const -> difference_type
1254 {
1255 return m_it - rhs.m_it;
1256 }
1257
1258 template <class It, class BIt>
1259 inline auto xbounded_iterator<It, BIt>::operator*() const -> value_type
1260 {
1261 using type = decltype(*m_bound_it);
1262 return (static_cast<type>(*m_it) < *m_bound_it) ? *m_it : static_cast<value_type>((*m_bound_it) - 1);
1263 }
1264
1265 template <class It, class BIt>
1266 inline bool xbounded_iterator<It, BIt>::equal(const self_type& rhs) const
1267 {
1268 return m_it == rhs.m_it && m_bound_it == rhs.m_bound_it;
1269 }
1270
1271 template <class It, class BIt>
1272 inline bool xbounded_iterator<It, BIt>::less_than(const self_type& rhs) const
1273 {
1274 return m_it < rhs.m_it;
1275 }
1276
1277 template <class It, class BIt>
1278 inline bool operator==(const xbounded_iterator<It, BIt>& lhs, const xbounded_iterator<It, BIt>& rhs)
1279 {
1280 return lhs.equal(rhs);
1281 }
1282
1283 template <class It, class BIt>
1284 inline bool operator<(const xbounded_iterator<It, BIt>& lhs, const xbounded_iterator<It, BIt>& rhs)
1285 {
1286 return lhs.less_than(rhs);
1287 }
1288}
1289
1290#endif
standard mathematical functions for xexpressions
layout_type
Definition xlayout.hpp:24