xtensor
 
Loading...
Searching...
No Matches
xiterator.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_ITERATOR_HPP
11#define XTENSOR_ITERATOR_HPP
12
#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <iterator>
#include <numeric>
#include <type_traits>
#include <vector>
19
20#include <xtl/xcompare.hpp>
21#include <xtl/xiterator_base.hpp>
22#include <xtl/xmeta_utils.hpp>
23#include <xtl/xsequence.hpp>
24
25#include "../core/xlayout.hpp"
26#include "../core/xshape.hpp"
27#include "../utils/xexception.hpp"
28#include "../utils/xutils.hpp"
29
30namespace xt
31{
32
33 /***********************
34 * iterator meta utils *
35 ***********************/
36
template <class CT>
class xscalar;

template <bool is_const, class CT>
class xscalar_stepper;

namespace detail
{
    // Selects the iterator type an xstepper uses to walk the storage of a type:
    // non-const storage exposes `container_iterator`, const storage exposes
    // `const_container_iterator`, and xscalar provides dummy iterators.
    template <class Container>
    struct get_stepper_iterator_impl
    {
        using type = typename Container::container_iterator;
    };

    template <class Container>
    struct get_stepper_iterator_impl<const Container>
    {
        using type = typename Container::const_container_iterator;
    };

    template <class Closure>
    struct get_stepper_iterator_impl<xscalar<Closure>>
    {
        using type = typename xscalar<Closure>::dummy_iterator;
    };

    template <class Closure>
    struct get_stepper_iterator_impl<const xscalar<Closure>>
    {
        using type = typename xscalar<Closure>::const_dummy_iterator;
    };
}

/// Iterator type used by xstepper<C> to traverse the storage of C.
template <class C>
using get_stepper_iterator = typename detail::get_stepper_iterator_impl<C>::type;
72
73 /********************************
74 * xindex_type_t implementation *
75 ********************************/
76
77 namespace detail
78 {
79 template <class ST>
80 struct index_type_impl
81 {
82 using type = dynamic_shape<typename ST::value_type>;
83 };
84
85 template <class V, std::size_t L>
86 struct index_type_impl<std::array<V, L>>
87 {
88 using type = std::array<V, L>;
89 };
90
91 template <std::size_t... I>
92 struct index_type_impl<fixed_shape<I...>>
93 {
94 using type = std::array<std::size_t, sizeof...(I)>;
95 };
96 }
97
98 template <class C>
99 using xindex_type_t = typename detail::index_type_impl<C>::type;
100
101 /************
102 * xstepper *
103 ************/
104
105 template <class C>
106 class xstepper
107 {
108 public:
109
110 using storage_type = C;
111 using subiterator_type = get_stepper_iterator<C>;
112 using subiterator_traits = std::iterator_traits<subiterator_type>;
113 using value_type = typename subiterator_traits::value_type;
114 using reference = typename subiterator_traits::reference;
115 using pointer = typename subiterator_traits::pointer;
116 using difference_type = typename subiterator_traits::difference_type;
117 using size_type = typename storage_type::size_type;
118 using shape_type = typename storage_type::shape_type;
119 using simd_value_type = xt_simd::simd_type<value_type>;
120
121 template <class requested_type>
122 using simd_return_type = xt_simd::simd_return_type<value_type, requested_type>;
123
124 xstepper() = default;
125 xstepper(storage_type* c, subiterator_type it, size_type offset) noexcept;
126
127 reference operator*() const;
128
129 void step(size_type dim, size_type n = 1);
130 void step_back(size_type dim, size_type n = 1);
131 void reset(size_type dim);
132 void reset_back(size_type dim);
133
134 void to_begin();
135 void to_end(layout_type l);
136
137 template <class T>
138 simd_return_type<T> step_simd();
139
140 void step_leading();
141
142 template <class R>
143 void store_simd(const R& vec);
144
145 private:
146
147 storage_type* p_c;
148 subiterator_type m_it;
149 size_type m_offset;
150 };
151
152 template <layout_type L>
154 {
155 // For performance reasons, increment_stepper and decrement_stepper are
156 // specialized for the case where n=1, which underlies operator++ and
157 // operator-- on xiterators.
158
159 template <class S, class IT, class ST>
160 static void increment_stepper(S& stepper, IT& index, const ST& shape);
161
162 template <class S, class IT, class ST>
163 static void decrement_stepper(S& stepper, IT& index, const ST& shape);
164
165 template <class S, class IT, class ST>
166 static void increment_stepper(S& stepper, IT& index, const ST& shape, typename S::size_type n);
167
168 template <class S, class IT, class ST>
169 static void decrement_stepper(S& stepper, IT& index, const ST& shape, typename S::size_type n);
170 };
171
172 /********************
173 * xindexed_stepper *
174 ********************/
175
176 template <class E, bool is_const>
177 class xindexed_stepper
178 {
179 public:
180
181 using self_type = xindexed_stepper<E, is_const>;
182 using xexpression_type = std::conditional_t<is_const, const E, E>;
183
184 using value_type = typename xexpression_type::value_type;
185 using reference = std::
186 conditional_t<is_const, typename xexpression_type::const_reference, typename xexpression_type::reference>;
187 using pointer = std::
188 conditional_t<is_const, typename xexpression_type::const_pointer, typename xexpression_type::pointer>;
189 using size_type = typename xexpression_type::size_type;
190 using difference_type = typename xexpression_type::difference_type;
191
192 using shape_type = typename xexpression_type::shape_type;
193 using index_type = xindex_type_t<shape_type>;
194
195 xindexed_stepper() = default;
196 xindexed_stepper(xexpression_type* e, size_type offset, bool end = false) noexcept;
197
198 reference operator*() const;
199
200 void step(size_type dim, size_type n = 1);
201 void step_back(size_type dim, size_type n = 1);
202 void reset(size_type dim);
203 void reset_back(size_type dim);
204
205 void to_begin();
206 void to_end(layout_type l);
207
208 private:
209
210 xexpression_type* p_e;
211 index_type m_index;
212 size_type m_offset;
213 };
214
215 template <class T>
217 {
218 static const bool value = false;
219 };
220
221 template <class T, bool B>
223 {
224 static const bool value = true;
225 };
226
227 template <class T, class R = T>
228 struct enable_indexed_stepper : std::enable_if<is_indexed_stepper<T>::value, R>
229 {
230 };
231
232 template <class T, class R = T>
233 using enable_indexed_stepper_t = typename enable_indexed_stepper<T, R>::type;
234
235 template <class T, class R = T>
236 struct disable_indexed_stepper : std::enable_if<!is_indexed_stepper<T>::value, R>
237 {
238 };
239
240 template <class T, class R = T>
241 using disable_indexed_stepper_t = typename disable_indexed_stepper<T, R>::type;
242
243 /*************
244 * xiterator *
245 *************/
246
247 namespace detail
248 {
249 template <class S>
250 class shape_storage
251 {
252 public:
253
254 using shape_type = S;
255 using param_type = const S&;
256
257 shape_storage() = default;
258 shape_storage(param_type shape);
259 const S& shape() const;
260
261 private:
262
263 S m_shape;
264 };
265
266 template <class S>
267 class shape_storage<S*>
268 {
269 public:
270
271 using shape_type = S;
272 using param_type = const S*;
273
274 shape_storage(param_type shape = 0);
275 const S& shape() const;
276
277 private:
278
279 const S* p_shape;
280 };
281
282 template <layout_type L>
283 struct LAYOUT_FORBIDEN_FOR_XITERATOR;
284 }
285
286 template <class St, class S, layout_type L>
287 class xiterator : public xtl::xrandom_access_iterator_base<
288 xiterator<St, S, L>,
289 typename St::value_type,
290 typename St::difference_type,
291 typename St::pointer,
292 typename St::reference>,
293 private detail::shape_storage<S>
294 {
295 public:
296
297 using self_type = xiterator<St, S, L>;
298
299 using stepper_type = St;
300 using value_type = typename stepper_type::value_type;
301 using reference = typename stepper_type::reference;
302 using pointer = typename stepper_type::pointer;
303 using difference_type = typename stepper_type::difference_type;
304 using size_type = typename stepper_type::size_type;
305 using iterator_category = std::random_access_iterator_tag;
306
307 using private_base = detail::shape_storage<S>;
308 using shape_type = typename private_base::shape_type;
309 using shape_param_type = typename private_base::param_type;
310 using index_type = xindex_type_t<shape_type>;
311
312 xiterator() = default;
313
314 // end_index means either reverse_iterator && !end or !reverse_iterator && end
315 xiterator(St st, shape_param_type shape, bool end_index);
316
317 self_type& operator++();
318 self_type& operator--();
319
320 self_type& operator+=(difference_type n);
321 self_type& operator-=(difference_type n);
322
323 difference_type operator-(const self_type& rhs) const;
324
325 reference operator*() const;
326 pointer operator->() const;
327
328 bool equal(const xiterator& rhs) const;
329 bool less_than(const xiterator& rhs) const;
330
331 private:
332
333 stepper_type m_st;
334 index_type m_index;
335 difference_type m_linear_index;
336
337 using checking_type = typename detail::LAYOUT_FORBIDEN_FOR_XITERATOR<L>::type;
338 };
339
340 template <class St, class S, layout_type L>
341 bool operator==(const xiterator<St, S, L>& lhs, const xiterator<St, S, L>& rhs);
342
343 template <class St, class S, layout_type L>
344 bool operator<(const xiterator<St, S, L>& lhs, const xiterator<St, S, L>& rhs);
345
346 template <class St, class S, layout_type L>
347 struct is_contiguous_container<xiterator<St, S, L>> : std::false_type
348 {
349 };
350
351 /*********************
352 * xbounded_iterator *
353 *********************/
354
355 template <class It, class BIt>
356 class xbounded_iterator : public xtl::xrandom_access_iterator_base<
357 xbounded_iterator<It, BIt>,
358 typename std::iterator_traits<It>::value_type,
359 typename std::iterator_traits<It>::difference_type,
360 typename std::iterator_traits<It>::pointer,
361 typename std::iterator_traits<It>::reference>
362 {
363 public:
364
365 using self_type = xbounded_iterator<It, BIt>;
366
367 using subiterator_type = It;
368 using bound_iterator_type = BIt;
369 using value_type = typename std::iterator_traits<It>::value_type;
370 using reference = typename std::iterator_traits<It>::reference;
371 using pointer = typename std::iterator_traits<It>::pointer;
372 using difference_type = typename std::iterator_traits<It>::difference_type;
373 using iterator_category = std::random_access_iterator_tag;
374
375 xbounded_iterator() = default;
376 xbounded_iterator(It it, BIt bound_it);
377
378 self_type& operator++();
379 self_type& operator--();
380
381 self_type& operator+=(difference_type n);
382 self_type& operator-=(difference_type n);
383
384 difference_type operator-(const self_type& rhs) const;
385
386 value_type operator*() const;
387
388 bool equal(const self_type& rhs) const;
389 bool less_than(const self_type& rhs) const;
390
391 private:
392
393 subiterator_type m_it;
394 bound_iterator_type m_bound_it;
395 };
396
397 template <class It, class BIt>
398 bool operator==(const xbounded_iterator<It, BIt>& lhs, const xbounded_iterator<It, BIt>& rhs);
399
400 template <class It, class BIt>
401 bool operator<(const xbounded_iterator<It, BIt>& lhs, const xbounded_iterator<It, BIt>& rhs);
402
403 /*****************************
404 * linear_begin / linear_end *
405 *****************************/
406
407 namespace detail
408 {
409 template <class C, class = void_t<>>
410 struct has_linear_iterator : std::false_type
411 {
412 };
413
414 template <class C>
415 struct has_linear_iterator<C, void_t<decltype(std::declval<C>().linear_cbegin())>> : std::true_type
416 {
417 };
418 }
419
420 template <class C>
421 XTENSOR_CONSTEXPR_RETURN auto linear_begin(C& c) noexcept
422 {
423 if constexpr (detail::has_linear_iterator<C>::value)
424 {
425 return c.linear_begin();
426 }
427 else
428 {
429 return c.begin();
430 }
431 }
432
433 template <class C>
434 XTENSOR_CONSTEXPR_RETURN auto linear_end(C& c) noexcept
435 {
436 if constexpr (detail::has_linear_iterator<C>::value)
437 {
438 return c.linear_end();
439 }
440 else
441 {
442 return c.end();
443 }
444 }
445
446 template <class C>
447 XTENSOR_CONSTEXPR_RETURN auto linear_begin(const C& c) noexcept
448 {
449 if constexpr (detail::has_linear_iterator<C>::value)
450 {
451 return c.linear_cbegin();
452 }
453 else
454 {
455 return c.cbegin();
456 }
457 }
458
459 template <class C>
460 XTENSOR_CONSTEXPR_RETURN auto linear_end(const C& c) noexcept
461 {
462 if constexpr (detail::has_linear_iterator<C>::value)
463 {
464 return c.linear_cend();
465 }
466 else
467 {
468 return c.cend();
469 }
470 }
471
472 /***************************
473 * xstepper implementation *
474 ***************************/
475
476 template <class C>
477 inline xstepper<C>::xstepper(storage_type* c, subiterator_type it, size_type offset) noexcept
478 : p_c(c)
479 , m_it(it)
480 , m_offset(offset)
481 {
482 }
483
484 template <class C>
485 inline auto xstepper<C>::operator*() const -> reference
486 {
487 return *m_it;
488 }
489
490 template <class C>
491 inline void xstepper<C>::step(size_type dim, size_type n)
492 {
493 if (dim >= m_offset)
494 {
495 using strides_value_type = typename std::decay_t<decltype(p_c->strides())>::value_type;
496 m_it += difference_type(static_cast<strides_value_type>(n) * p_c->strides()[dim - m_offset]);
497 }
498 }
499
500 template <class C>
501 inline void xstepper<C>::step_back(size_type dim, size_type n)
502 {
503 if (dim >= m_offset)
504 {
505 using strides_value_type = typename std::decay_t<decltype(p_c->strides())>::value_type;
506 m_it -= difference_type(static_cast<strides_value_type>(n) * p_c->strides()[dim - m_offset]);
507 }
508 }
509
510 template <class C>
511 inline void xstepper<C>::reset(size_type dim)
512 {
513 if (dim >= m_offset)
514 {
515 m_it -= difference_type(p_c->backstrides()[dim - m_offset]);
516 }
517 }
518
519 template <class C>
520 inline void xstepper<C>::reset_back(size_type dim)
521 {
522 if (dim >= m_offset)
523 {
524 m_it += difference_type(p_c->backstrides()[dim - m_offset]);
525 }
526 }
527
528 template <class C>
529 inline void xstepper<C>::to_begin()
530 {
531 m_it = p_c->data_xbegin();
532 }
533
534 template <class C>
535 inline void xstepper<C>::to_end(layout_type l)
536 {
537 m_it = p_c->data_xend(l, m_offset);
538 }
539
540 namespace detail
541 {
542 template <class It>
543 struct step_simd_invoker
544 {
545 template <class R>
546 static R apply(const It& it)
547 {
548 R reg;
549 return reg.load_unaligned(&(*it));
550 // return reg;
551 }
552 };
553
554 template <bool is_const, class T, class S, layout_type L>
555 struct step_simd_invoker<xiterator<xscalar_stepper<is_const, T>, S, L>>
556 {
557 template <class R>
558 static R apply(const xiterator<xscalar_stepper<is_const, T>, S, L>& it)
559 {
560 return R(*it);
561 }
562 };
563 }
564
565 template <class C>
566 template <class T>
567 inline auto xstepper<C>::step_simd() -> simd_return_type<T>
568 {
569 using simd_type = simd_return_type<T>;
570 simd_type reg = detail::step_simd_invoker<subiterator_type>::template apply<simd_type>(m_it);
571 m_it += xt_simd::revert_simd_traits<simd_type>::size;
572 return reg;
573 }
574
575 template <class C>
576 template <class R>
577 inline void xstepper<C>::store_simd(const R& vec)
578 {
579 vec.store_unaligned(&(*m_it));
580 m_it += xt_simd::revert_simd_traits<R>::size;
581 ;
582 }
583
584 template <class C>
585 void xstepper<C>::step_leading()
586 {
587 ++m_it;
588 }
589
590 template <>
591 template <class S, class IT, class ST>
592 void stepper_tools<layout_type::row_major>::increment_stepper(S& stepper, IT& index, const ST& shape)
593 {
594 using size_type = typename S::size_type;
595 const size_type size = index.size();
596 size_type i = size;
597 while (i != 0)
598 {
599 --i;
600 if (index[i] != shape[i] - 1)
601 {
602 ++index[i];
603 stepper.step(i);
604 return;
605 }
606 else
607 {
608 index[i] = 0;
609 if (i != 0)
610 {
611 stepper.reset(i);
612 }
613 }
614 }
615 if (i == 0)
616 {
617 if (size != size_type(0))
618 {
619 std::transform(
620 shape.cbegin(),
621 shape.cend() - 1,
622 index.begin(),
623 [](const auto& v)
624 {
625 return v - 1;
626 }
627 );
628 index[size - 1] = shape[size - 1];
629 }
630 stepper.to_end(layout_type::row_major);
631 }
632 }
633
634 template <>
635 template <class S, class IT, class ST>
636 void stepper_tools<layout_type::row_major>::increment_stepper(
637 S& stepper,
638 IT& index,
639 const ST& shape,
640 typename S::size_type n
641 )
642 {
643 using size_type = typename S::size_type;
644 const size_type size = index.size();
645 const size_type leading_i = size - 1;
646 size_type i = size;
647 while (i != 0 && n != 0)
648 {
649 --i;
650 size_type inc = (i == leading_i) ? n : 1;
651 if (xtl::cmp_less(index[i] + inc, shape[i]))
652 {
653 index[i] += inc;
654 stepper.step(i, inc);
655 n -= inc;
656 if (i != leading_i || index.size() == 1)
657 {
658 i = index.size();
659 }
660 }
661 else
662 {
663 if (i == leading_i)
664 {
665 size_type off = shape[i] - index[i] - 1;
666 stepper.step(i, off);
667 n -= off;
668 }
669 index[i] = 0;
670 if (i != 0)
671 {
672 stepper.reset(i);
673 }
674 }
675 }
676 if (i == 0 && n != 0)
677 {
678 if (size != size_type(0))
679 {
680 std::transform(
681 shape.cbegin(),
682 shape.cend() - 1,
683 index.begin(),
684 [](const auto& v)
685 {
686 return v - 1;
687 }
688 );
689 index[leading_i] = shape[leading_i];
690 }
691 stepper.to_end(layout_type::row_major);
692 }
693 }
694
695 template <>
696 template <class S, class IT, class ST>
697 void stepper_tools<layout_type::row_major>::decrement_stepper(S& stepper, IT& index, const ST& shape)
698 {
699 using size_type = typename S::size_type;
700 size_type i = index.size();
701 while (i != 0)
702 {
703 --i;
704 if (index[i] != 0)
705 {
706 --index[i];
707 stepper.step_back(i);
708 return;
709 }
710 else
711 {
712 index[i] = shape[i] - 1;
713 if (i != 0)
714 {
715 stepper.reset_back(i);
716 }
717 }
718 }
719 if (i == 0)
720 {
721 stepper.to_begin();
722 }
723 }
724
725 template <>
726 template <class S, class IT, class ST>
727 void stepper_tools<layout_type::row_major>::decrement_stepper(
728 S& stepper,
729 IT& index,
730 const ST& shape,
731 typename S::size_type n
732 )
733 {
734 using size_type = typename S::size_type;
735 size_type i = index.size();
736 size_type leading_i = index.size() - 1;
737 while (i != 0 && n != 0)
738 {
739 --i;
740 size_type inc = (i == leading_i) ? n : 1;
741 if (xtl::cmp_greater_equal(index[i], inc))
742 {
743 index[i] -= inc;
744 stepper.step_back(i, inc);
745 n -= inc;
746 if (i != leading_i || index.size() == 1)
747 {
748 i = index.size();
749 }
750 }
751 else
752 {
753 if (i == leading_i)
754 {
755 size_type off = index[i];
756 stepper.step_back(i, off);
757 n -= off;
758 }
759 index[i] = shape[i] - 1;
760 if (i != 0)
761 {
762 stepper.reset_back(i);
763 }
764 }
765 }
766 if (i == 0 && n != 0)
767 {
768 stepper.to_begin();
769 }
770 }
771
772 template <>
773 template <class S, class IT, class ST>
774 void stepper_tools<layout_type::column_major>::increment_stepper(S& stepper, IT& index, const ST& shape)
775 {
776 using size_type = typename S::size_type;
777 const size_type size = index.size();
778 size_type i = 0;
779 while (i != size)
780 {
781 if (index[i] != shape[i] - 1)
782 {
783 ++index[i];
784 stepper.step(i);
785 return;
786 }
787 else
788 {
789 index[i] = 0;
790 if (i != size - 1)
791 {
792 stepper.reset(i);
793 }
794 }
795 ++i;
796 }
797 if (i == size)
798 {
799 if (size != size_type(0))
800 {
801 std::transform(
802 shape.cbegin() + 1,
803 shape.cend(),
804 index.begin() + 1,
805 [](const auto& v)
806 {
807 return v - 1;
808 }
809 );
810 index[0] = shape[0];
811 }
812 stepper.to_end(layout_type::column_major);
813 }
814 }
815
816 template <>
817 template <class S, class IT, class ST>
818 void stepper_tools<layout_type::column_major>::increment_stepper(
819 S& stepper,
820 IT& index,
821 const ST& shape,
822 typename S::size_type n
823 )
824 {
825 using size_type = typename S::size_type;
826 const size_type size = index.size();
827 const size_type leading_i = 0;
828 size_type i = 0;
829 while (i != size && n != 0)
830 {
831 size_type inc = (i == leading_i) ? n : 1;
832 if (index[i] + inc < shape[i])
833 {
834 index[i] += inc;
835 stepper.step(i, inc);
836 n -= inc;
837 if (i != leading_i || size == 1)
838 {
839 i = 0;
840 continue;
841 }
842 }
843 else
844 {
845 if (i == leading_i)
846 {
847 size_type off = shape[i] - index[i] - 1;
848 stepper.step(i, off);
849 n -= off;
850 }
851 index[i] = 0;
852 if (i != size - 1)
853 {
854 stepper.reset(i);
855 }
856 }
857 ++i;
858 }
859 if (i == size && n != 0)
860 {
861 if (size != size_type(0))
862 {
863 std::transform(
864 shape.cbegin() + 1,
865 shape.cend(),
866 index.begin() + 1,
867 [](const auto& v)
868 {
869 return v - 1;
870 }
871 );
872 index[leading_i] = shape[leading_i];
873 }
874 stepper.to_end(layout_type::column_major);
875 }
876 }
877
878 template <>
879 template <class S, class IT, class ST>
880 void stepper_tools<layout_type::column_major>::decrement_stepper(S& stepper, IT& index, const ST& shape)
881 {
882 using size_type = typename S::size_type;
883 size_type size = index.size();
884 size_type i = 0;
885 while (i != size)
886 {
887 if (index[i] != 0)
888 {
889 --index[i];
890 stepper.step_back(i);
891 return;
892 }
893 else
894 {
895 index[i] = shape[i] - 1;
896 if (i != size - 1)
897 {
898 stepper.reset_back(i);
899 }
900 }
901 ++i;
902 }
903 if (i == size)
904 {
905 stepper.to_begin();
906 }
907 }
908
909 template <>
910 template <class S, class IT, class ST>
911 void stepper_tools<layout_type::column_major>::decrement_stepper(
912 S& stepper,
913 IT& index,
914 const ST& shape,
915 typename S::size_type n
916 )
917 {
918 using size_type = typename S::size_type;
919 size_type size = index.size();
920 size_type i = 0;
921 size_type leading_i = 0;
922 while (i != size && n != 0)
923 {
924 size_type inc = (i == leading_i) ? n : 1;
925 if (index[i] >= inc)
926 {
927 index[i] -= inc;
928 stepper.step_back(i, inc);
929 n -= inc;
930 if (i != leading_i || index.size() == 1)
931 {
932 i = 0;
933 continue;
934 }
935 }
936 else
937 {
938 if (i == leading_i)
939 {
940 size_type off = index[i];
941 stepper.step_back(i, off);
942 n -= off;
943 }
944 index[i] = shape[i] - 1;
945 if (i != size - 1)
946 {
947 stepper.reset_back(i);
948 }
949 }
950 ++i;
951 }
952 if (i == size && n != 0)
953 {
954 stepper.to_begin();
955 }
956 }
957
958 /***********************************
959 * xindexed_stepper implementation *
960 ***********************************/
961
962 template <class C, bool is_const>
963 inline xindexed_stepper<C, is_const>::xindexed_stepper(xexpression_type* e, size_type offset, bool end) noexcept
964 : p_e(e)
965 , m_index(xtl::make_sequence<index_type>(e->shape().size(), size_type(0)))
966 , m_offset(offset)
967 {
968 if (end)
969 {
970 // Note: the layout here doesn't matter (unused) but using default traversal looks more "correct".
971 to_end(XTENSOR_DEFAULT_TRAVERSAL);
972 }
973 }
974
975 template <class C, bool is_const>
976 inline auto xindexed_stepper<C, is_const>::operator*() const -> reference
977 {
978 return p_e->element(m_index.cbegin(), m_index.cend());
979 }
980
981 template <class C, bool is_const>
982 inline void xindexed_stepper<C, is_const>::step(size_type dim, size_type n)
983 {
984 if (dim >= m_offset)
985 {
986 m_index[dim - m_offset] += static_cast<typename index_type::value_type>(n);
987 }
988 }
989
990 template <class C, bool is_const>
991 inline void xindexed_stepper<C, is_const>::step_back(size_type dim, size_type n)
992 {
993 if (dim >= m_offset)
994 {
995 m_index[dim - m_offset] -= static_cast<typename index_type::value_type>(n);
996 }
997 }
998
999 template <class C, bool is_const>
1000 inline void xindexed_stepper<C, is_const>::reset(size_type dim)
1001 {
1002 if (dim >= m_offset)
1003 {
1004 m_index[dim - m_offset] = 0;
1005 }
1006 }
1007
1008 template <class C, bool is_const>
1009 inline void xindexed_stepper<C, is_const>::reset_back(size_type dim)
1010 {
1011 if (dim >= m_offset)
1012 {
1013 m_index[dim - m_offset] = p_e->shape()[dim - m_offset] - 1;
1014 }
1015 }
1016
1017 template <class C, bool is_const>
1018 inline void xindexed_stepper<C, is_const>::to_begin()
1019 {
1020 std::fill(m_index.begin(), m_index.end(), size_type(0));
1021 }
1022
1023 template <class C, bool is_const>
1024 inline void xindexed_stepper<C, is_const>::to_end(layout_type l)
1025 {
1026 const auto& shape = p_e->shape();
1027 std::transform(
1028 shape.cbegin(),
1029 shape.cend(),
1030 m_index.begin(),
1031 [](const auto& v)
1032 {
1033 return v - 1;
1034 }
1035 );
1036
1037 size_type l_dim = (l == layout_type::row_major) ? shape.size() - 1 : 0;
1038 m_index[l_dim] = shape[l_dim];
1039 }
1040
1041 /****************************
1042 * xiterator implementation *
1043 ****************************/
1044
1045 namespace detail
1046 {
1047 template <class S>
1048 inline shape_storage<S>::shape_storage(param_type shape)
1049 : m_shape(shape)
1050 {
1051 }
1052
1053 template <class S>
1054 inline const S& shape_storage<S>::shape() const
1055 {
1056 return m_shape;
1057 }
1058
1059 template <class S>
1060 inline shape_storage<S*>::shape_storage(param_type shape)
1061 : p_shape(shape)
1062 {
1063 }
1064
1065 template <class S>
1066 inline const S& shape_storage<S*>::shape() const
1067 {
1068 return *p_shape;
1069 }
1070
1071 template <>
1072 struct LAYOUT_FORBIDEN_FOR_XITERATOR<layout_type::row_major>
1073 {
1074 using type = int;
1075 };
1076
1077 template <>
1078 struct LAYOUT_FORBIDEN_FOR_XITERATOR<layout_type::column_major>
1079 {
1080 using type = int;
1081 };
1082 }
1083
1084 template <class St, class S, layout_type L>
1085 inline xiterator<St, S, L>::xiterator(St st, shape_param_type shape, bool end_index)
1086 : private_base(shape)
1087 , m_st(st)
1088 , m_index(
1089 end_index ? xtl::forward_sequence<index_type, const shape_type&>(this->shape())
1090 : xtl::make_sequence<index_type>(this->shape().size(), size_type(0))
1091 )
1092 , m_linear_index(0)
1093 {
1094 // end_index means either reverse_iterator && !end or !reverse_iterator && end
1095 if (end_index)
1096 {
1097 if (m_index.size() != size_type(0))
1098 {
1099 auto iter_begin = (L == layout_type::row_major) ? m_index.begin() : m_index.begin() + 1;
1100 auto iter_end = (L == layout_type::row_major) ? m_index.end() - 1 : m_index.end();
1101 std::transform(
1102 iter_begin,
1103 iter_end,
1104 iter_begin,
1105 [](const auto& v)
1106 {
1107 return v - 1;
1108 }
1109 );
1110 }
1111 m_linear_index = difference_type(std::accumulate(
1112 this->shape().cbegin(),
1113 this->shape().cend(),
1114 size_type(1),
1115 std::multiplies<size_type>()
1116 ));
1117 }
1118 }
1119
1120 template <class St, class S, layout_type L>
1121 inline auto xiterator<St, S, L>::operator++() -> self_type&
1122 {
1123 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape());
1124 ++m_linear_index;
1125 return *this;
1126 }
1127
1128 template <class St, class S, layout_type L>
1129 inline auto xiterator<St, S, L>::operator--() -> self_type&
1130 {
1131 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape());
1132 --m_linear_index;
1133 return *this;
1134 }
1135
1136 template <class St, class S, layout_type L>
1137 inline auto xiterator<St, S, L>::operator+=(difference_type n) -> self_type&
1138 {
1139 if (n >= 0)
1140 {
1141 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape(), static_cast<size_type>(n));
1142 }
1143 else
1144 {
1145 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape(), static_cast<size_type>(-n));
1146 }
1147 m_linear_index += n;
1148 return *this;
1149 }
1150
1151 template <class St, class S, layout_type L>
1152 inline auto xiterator<St, S, L>::operator-=(difference_type n) -> self_type&
1153 {
1154 if (n >= 0)
1155 {
1156 stepper_tools<L>::decrement_stepper(m_st, m_index, this->shape(), static_cast<size_type>(n));
1157 }
1158 else
1159 {
1160 stepper_tools<L>::increment_stepper(m_st, m_index, this->shape(), static_cast<size_type>(-n));
1161 }
1162 m_linear_index -= n;
1163 return *this;
1164 }
1165
1166 template <class St, class S, layout_type L>
1167 inline auto xiterator<St, S, L>::operator-(const self_type& rhs) const -> difference_type
1168 {
1169 return m_linear_index - rhs.m_linear_index;
1170 }
1171
1172 template <class St, class S, layout_type L>
1173 inline auto xiterator<St, S, L>::operator*() const -> reference
1174 {
1175 return *m_st;
1176 }
1177
1178 template <class St, class S, layout_type L>
1179 inline auto xiterator<St, S, L>::operator->() const -> pointer
1180 {
1181 return &(*m_st);
1182 }
1183
1184 template <class St, class S, layout_type L>
1185 inline bool xiterator<St, S, L>::equal(const xiterator& rhs) const
1186 {
1187 XTENSOR_ASSERT(this->shape() == rhs.shape());
1188 return m_linear_index == rhs.m_linear_index;
1189 }
1190
1191 template <class St, class S, layout_type L>
1192 inline bool xiterator<St, S, L>::less_than(const xiterator& rhs) const
1193 {
1194 XTENSOR_ASSERT(this->shape() == rhs.shape());
1195 return m_linear_index < rhs.m_linear_index;
1196 }
1197
1198 template <class St, class S, layout_type L>
1199 inline bool operator==(const xiterator<St, S, L>& lhs, const xiterator<St, S, L>& rhs)
1200 {
1201 return lhs.equal(rhs);
1202 }
1203
1204 template <class St, class S, layout_type L>
1205 bool operator<(const xiterator<St, S, L>& lhs, const xiterator<St, S, L>& rhs)
1206 {
1207 return lhs.less_than(rhs);
1208 }
1209
1210 /************************************
1211 * xbounded_iterator implementation *
1212 ************************************/
1213
1214 template <class It, class BIt>
1215 xbounded_iterator<It, BIt>::xbounded_iterator(It it, BIt bound_it)
1216 : m_it(it)
1217 , m_bound_it(bound_it)
1218 {
1219 }
1220
1221 template <class It, class BIt>
1222 inline auto xbounded_iterator<It, BIt>::operator++() -> self_type&
1223 {
1224 ++m_it;
1225 ++m_bound_it;
1226 return *this;
1227 }
1228
1229 template <class It, class BIt>
1230 inline auto xbounded_iterator<It, BIt>::operator--() -> self_type&
1231 {
1232 --m_it;
1233 --m_bound_it;
1234 return *this;
1235 }
1236
1237 template <class It, class BIt>
1238 inline auto xbounded_iterator<It, BIt>::operator+=(difference_type n) -> self_type&
1239 {
1240 m_it += n;
1241 m_bound_it += n;
1242 return *this;
1243 }
1244
1245 template <class It, class BIt>
1246 inline auto xbounded_iterator<It, BIt>::operator-=(difference_type n) -> self_type&
1247 {
1248 m_it -= n;
1249 m_bound_it -= n;
1250 return *this;
1251 }
1252
1253 template <class It, class BIt>
1254 inline auto xbounded_iterator<It, BIt>::operator-(const self_type& rhs) const -> difference_type
1255 {
1256 return m_it - rhs.m_it;
1257 }
1258
1259 template <class It, class BIt>
1260 inline auto xbounded_iterator<It, BIt>::operator*() const -> value_type
1261 {
1262 using type = decltype(*m_bound_it);
1263 return (static_cast<type>(*m_it) < *m_bound_it) ? *m_it : static_cast<value_type>((*m_bound_it) - 1);
1264 }
1265
1266 template <class It, class BIt>
1267 inline bool xbounded_iterator<It, BIt>::equal(const self_type& rhs) const
1268 {
1269 return m_it == rhs.m_it && m_bound_it == rhs.m_bound_it;
1270 }
1271
1272 template <class It, class BIt>
1273 inline bool xbounded_iterator<It, BIt>::less_than(const self_type& rhs) const
1274 {
1275 return m_it < rhs.m_it;
1276 }
1277
1278 template <class It, class BIt>
1279 inline bool operator==(const xbounded_iterator<It, BIt>& lhs, const xbounded_iterator<It, BIt>& rhs)
1280 {
1281 return lhs.equal(rhs);
1282 }
1283
1284 template <class It, class BIt>
1285 inline bool operator<(const xbounded_iterator<It, BIt>& lhs, const xbounded_iterator<It, BIt>& rhs)
1286 {
1287 return lhs.less_than(rhs);
1288 }
1289}
1290
1291#endif
standard mathematical functions for xexpressions
layout_type
Definition xlayout.hpp:24