xtensor
 
Loading...
Searching...
No Matches
xstrided_view_base.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_STRIDED_VIEW_BASE_HPP
11#define XTENSOR_STRIDED_VIEW_BASE_HPP
12
13#include <type_traits>
14#include <variant>
15
16#include <xtl/xsequence.hpp>
17
18#include "../core/xaccessible.hpp"
19#include "../core/xstrides.hpp"
20#include "../core/xtensor_config.hpp"
21#include "../core/xtensor_forward.hpp"
22#include "../utils/xutils.hpp"
23#include "../views/xslice.hpp"
24
25namespace xt
26{
27 namespace detail
28 {
29 template <class CT, layout_type L>
30 class flat_expression_adaptor
31 {
32 public:
33
34 using xexpression_type = std::decay_t<CT>;
35 using shape_type = typename xexpression_type::shape_type;
36 using inner_strides_type = get_strides_t<shape_type>;
37 using index_type = inner_strides_type;
38 using size_type = typename xexpression_type::size_type;
39 using value_type = typename xexpression_type::value_type;
40 using const_reference = typename xexpression_type::const_reference;
41 using reference = std::conditional_t<
42 std::is_const<std::remove_reference_t<CT>>::value,
43 typename xexpression_type::const_reference,
44 typename xexpression_type::reference>;
45
46 using iterator = decltype(std::declval<std::remove_reference_t<CT>>().template begin<L>());
47 using const_iterator = decltype(std::declval<std::decay_t<CT>>().template cbegin<L>());
48 using reverse_iterator = decltype(std::declval<std::remove_reference_t<CT>>().template rbegin<L>());
49 using const_reverse_iterator = decltype(std::declval<std::decay_t<CT>>().template crbegin<L>());
50
51 explicit flat_expression_adaptor(CT* e);
52
53 template <class FST>
54 flat_expression_adaptor(CT* e, FST&& strides);
55
56 void update_pointer(CT* ptr) const;
57
58 size_type size() const;
59 reference operator[](size_type idx);
60 const_reference operator[](size_type idx) const;
61
62 iterator begin();
63 iterator end();
64 const_iterator begin() const;
65 const_iterator end() const;
66 const_iterator cbegin() const;
67 const_iterator cend() const;
68
69 private:
70
71 static index_type& get_index();
72
73 mutable CT* m_e;
74 inner_strides_type m_strides;
75 size_type m_size;
76 };
77
78 template <class T>
79 struct is_flat_expression_adaptor : std::false_type
80 {
81 };
82
83 template <class CT, layout_type L>
84 struct is_flat_expression_adaptor<flat_expression_adaptor<CT, L>> : std::true_type
85 {
86 };
87
88 template <class E, class ST>
89 struct provides_data_interface
90 : std::conjunction<has_data_interface<std::decay_t<E>>, std::negation<is_flat_expression_adaptor<ST>>>
91 {
92 };
93 }
94
/**
 * Common base class of strided views.
 *
 * D is presumably the derived view type (the base reads its type aliases from
 * xcontainer_inner_types<D>) — confirm.
 *
 * The view holds the underlying expression (m_e), a storage object obtained
 * from it through inner_types::storage_getter (m_storage), and its own shape,
 * strides, backstrides, offset and layout. Element access adds the
 * stride-weighted indices to m_offset and indexes m_storage with the result.
 */
95 template <class D>
96 class xstrided_view_base : public xaccessible<D>
97 {
98 public:
99
100 using base_type = xaccessible<D>;
101 using inner_types = xcontainer_inner_types<D>;
102 using xexpression_type = typename inner_types::xexpression_type;
103 using undecay_expression = typename inner_types::undecay_expression;
// True when the expression is held as const; `pointer` is then a const pointer.
104 static constexpr bool is_const = std::is_const<std::remove_reference_t<undecay_expression>>::value;
105
106 using value_type = typename xexpression_type::value_type;
107 using reference = typename inner_types::reference;
108 using const_reference = typename inner_types::const_reference;
109 using pointer = std::
110 conditional_t<is_const, typename xexpression_type::const_pointer, typename xexpression_type::pointer>;
111 using const_pointer = typename xexpression_type::const_pointer;
112 using size_type = typename inner_types::size_type;
113 using difference_type = typename xexpression_type::difference_type;
114
// Policy that extracts the flat storage from the expression; presumably one of
// detail::inner_storage_getter / detail::flat_adaptor_getter — confirm.
115 using storage_getter = typename inner_types::storage_getter;
116 using inner_storage_type = typename inner_types::inner_storage_type;
117 using storage_type = std::remove_reference_t<inner_storage_type>;
118
119 using shape_type = typename inner_types::shape_type;
120 using strides_type = get_strides_t<shape_type>;
121 using backstrides_type = strides_type;
122
123 using inner_shape_type = shape_type;
124 using inner_strides_type = strides_type;
125 using inner_backstrides_type = backstrides_type;
126
127 using undecay_shape = typename inner_types::undecay_shape;
128
129 using simd_value_type = xt_simd::simd_type<value_type>;
130 using bool_load_type = typename xexpression_type::bool_load_type;
131
132 static constexpr layout_type static_layout = inner_types::layout;
// Contiguous at compile time only if the static layout is not dynamic and the
// underlying expression is itself contiguous.
133 static constexpr bool contiguous_layout = static_layout != layout_type::dynamic
134 && xexpression_type::contiguous_layout;
135
// data() is only offered when the expression has a data interface and the
// storage is not a detail::flat_expression_adaptor.
136 static constexpr bool
137 provides_data_interface = detail::provides_data_interface<xexpression_type, storage_type>::value;
138
// Builds the view from an expression, its shape and strides, the offset of the
// view's first element in the storage, and the layout.
139 template <class CTA, class SA>
140 xstrided_view_base(CTA&& e, SA&& shape, strides_type&& strides, size_type offset, layout_type layout) noexcept;
141
143
145
146 const inner_shape_type& shape() const noexcept;
147 const inner_strides_type& strides() const noexcept;
148 const inner_backstrides_type& backstrides() const noexcept;
149 layout_type layout() const noexcept;
150 bool is_contiguous() const noexcept;
// Also expose the shape() overloads declared in xaccessible.
151 using base_type::shape;
152
// Element access. The Args and iterator overloads validate their indices through
// the XTENSOR_TRY / XTENSOR_CHECK_DIMENSION macros; unchecked() does not.
153 reference operator()();
154 const_reference operator()() const;
155
156 template <class... Args>
157 reference operator()(Args... args);
158
159 template <class... Args>
160 const_reference operator()(Args... args) const;
161
162 template <class... Args>
163 reference unchecked(Args... args);
164
165 template <class... Args>
166 const_reference unchecked(Args... args) const;
167
168 template <class It>
169 reference element(It first, It last);
171 template <class It>
172 const_reference element(It first, It last) const;
173
174 storage_type& storage() noexcept;
175 const storage_type& storage() const noexcept;
176
// Raw pointer to the expression's data; only available when provides_data_interface holds.
177 pointer data() noexcept
178 requires(provides_data_interface);
179 const_pointer data() const noexcept
180 requires(provides_data_interface);
181
// Offset of the view's first element in storage().
182 size_type data_offset() const noexcept;
183
184 xexpression_type& expression() noexcept;
185 const xexpression_type& expression() const noexcept;
186
// reuse_cache is not used by this class's implementation.
187 template <class O>
188 bool broadcast_shape(O& shape, bool reuse_cache = false) const;
189
// True when the expression has a data interface and `strides` equals this view's strides.
190 template <class O>
191 bool has_linear_assign(const O& strides) const noexcept;
192
193 protected:
194
195 using offset_type = typename strides_type::value_type;
196
// Positions in m_storage: m_offset plus the stride-weighted indices.
197 template <class... Args>
198 offset_type compute_index(Args... args) const;
200 template <class... Args>
201 offset_type compute_unchecked_index(Args... args) const;
203 template <class It>
204 offset_type compute_element_index(It first, It last) const;
206 void set_offset(size_type offset);
207
208 private:
209
// Declaration order matters: the constructors initialize m_storage from m_e.
210 undecay_expression m_e;
211 inner_storage_type m_storage;
212 inner_shape_type m_shape;
213 inner_strides_type m_strides;
// Derived from m_shape and m_strides by adapt_strides() in the constructor.
214 inner_backstrides_type m_backstrides;
215 size_type m_offset;
216 layout_type m_layout;
217 };
219 /************************
220 * flat storage getters *
221 ************************/
222
223 namespace detail
224 {
/**
 * Storage policy for expressions whose own storage can be used directly.
 *
 * Every accessor simply forwards to the expression: storage(), data_offset()
 * and strides().
 *
 * @tparam CT the expression type
 */
template <class CT>
struct inner_storage_getter
{
    // Exactly the type returned by CT::storage() (reference-ness preserved).
    using type = decltype(std::declval<CT>().storage());
    using reference = std::add_lvalue_reference_t<CT>;

    template <class E>
    using rebind_t = inner_storage_getter<E>;

    // Returns e.storage() with its exact return type.
    static decltype(auto) get_flat_storage(reference e)
    {
        return e.storage();
    }

    // Returns e.data_offset() by value.
    static auto get_offset(reference e)
    {
        return e.data_offset();
    }

    // Returns e.strides() with its exact return type.
    static decltype(auto) get_strides(reference e)
    {
        return e.strides();
    }
};
249
250 template <class CT, layout_type L>
251 struct flat_adaptor_getter
252 {
253 using type = flat_expression_adaptor<std::remove_reference_t<CT>, L>;
254 using reference = std::add_lvalue_reference_t<CT>;
255
256 template <class E>
257 using rebind_t = flat_adaptor_getter<E, L>;
258
259 static type get_flat_storage(reference e)
260 {
261 // moved to addressof because ampersand on xview returns a closure pointer
262 return type(std::addressof(e));
263 }
264
265 static auto get_offset(reference)
266 {
267 return typename std::decay_t<CT>::size_type(0);
268 }
269
270 static auto get_strides(reference e)
271 {
272 dynamic_shape<std::ptrdiff_t> strides;
273 strides.resize(e.shape().size());
274 compute_strides(e.shape(), L, strides);
275 return strides;
276 }
277 };
278
// Storage policy used by strided views over CT: inner_storage_getter reuses
// CT's own storage(), while flat_adaptor_getter wraps CT in a
// flat_expression_adaptor indexed with layout L.
// NOTE(review): the selecting condition presumably tests whether CT exposes a
// data interface — confirm.
279 template <class CT, layout_type L>
280 using flat_storage_getter = std::conditional_t<
282 inner_storage_getter<CT>,
283 flat_adaptor_getter<CT, L>>;
284
285 template <layout_type L, class E>
286 inline auto get_offset(E& e)
287 {
288 return flat_storage_getter<E, L>::get_offset(e);
289 }
290
291 template <layout_type L, class E>
292 inline decltype(auto) get_strides(E& e)
293 {
294 return flat_storage_getter<E, L>::get_strides(e);
295 }
296 }
297
298 /*************************************
299 * xstrided_view_base implementation *
300 *************************************/
301
306
/**
 * Constructs an xstrided_view_base.
 *
 * Stores the expression, obtains the storage from it through storage_getter,
 * takes the shape and strides, and derives the backstrides.
 *
 * @param e       the underlying expression (forwarded)
 * @param shape   the shape of the view (forwarded)
 * @param strides the strides of the view (moved from)
 * @param offset  offset of the view's first element in the storage
 * @param layout  the layout of the view
 */
315 template <class D>
316 template <class CTA, class SA>
318 CTA&& e,
319 SA&& shape,
320 strides_type&& strides,
321 size_type offset,
323 ) noexcept
324 : m_e(std::forward<CTA>(e))
325 ,
326 // m_storage(detail::get_flat_storage<undecay_expression>(m_e)),
// m_e is declared before m_storage, so it is already initialized here and the
// storage is taken from the stored expression rather than from `e`.
327 m_storage(storage_getter::get_flat_storage(m_e))
328 , m_shape(std::forward<SA>(shape))
329 , m_strides(std::move(strides))
330 , m_offset(offset)
331 , m_layout(layout)
332 {
// Backstrides start as one zero per dimension; adapt_strides then fills them
// (and may adjust m_strides) from the shape.
333 m_backstrides = xtl::make_sequence<backstrides_type>(m_shape.size(), 0);
334 adapt_strides(m_shape, m_strides, m_backstrides);
335 }
336
337 namespace detail
338 {
339 template <class T, class S>
340 auto& copy_move_storage(T& expr, const S& /*storage*/)
341 {
342 return expr.storage();
343 }
344
345 template <class T, class E, layout_type L>
346 auto copy_move_storage(T& expr, const detail::flat_expression_adaptor<E, L>& storage)
347 {
348 detail::flat_expression_adaptor<E, L> new_storage = storage; // copy storage
349 new_storage.update_pointer(std::addressof(expr));
350 return new_storage;
351 }
352 }
353
/**
 * Constructs the view by moving the members of rhs.
 *
 * The storage is not moved from rhs. It is rebuilt from this object's m_e via
 * detail::copy_move_storage, so that a flat_expression_adaptor storage points
 * at this object's expression rather than at rhs's.
 */
354 template <class D>
356 : base_type(std::move(rhs))
357 , m_e(std::forward<undecay_expression>(rhs.m_e))
// m_e (declared before m_storage) is already initialized here.
358 , m_storage(detail::copy_move_storage(m_e, rhs.m_storage))
359 , m_shape(std::move(rhs.m_shape))
360 , m_strides(std::move(rhs.m_strides))
361 , m_backstrides(std::move(rhs.m_backstrides))
362 , m_offset(std::move(rhs.m_offset))
363 , m_layout(std::move(rhs.m_layout))
364 {
365 }
366
/**
 * Constructs the view by copying the members of rhs.
 *
 * The storage is not copied verbatim. It is rebuilt from this object's m_e via
 * detail::copy_move_storage, so that a flat_expression_adaptor storage points
 * at this object's expression rather than at rhs's.
 */
367 template <class D>
369 : base_type(rhs)
370 , m_e(rhs.m_e)
// m_e (declared before m_storage) is already initialized here.
371 , m_storage(detail::copy_move_storage(m_e, rhs.m_storage))
372 , m_shape(rhs.m_shape)
373 , m_strides(rhs.m_strides)
374 , m_backstrides(rhs.m_backstrides)
375 , m_offset(rhs.m_offset)
376 , m_layout(rhs.m_layout)
377 {
378 }
379
381
386
389 template <class D>
390 inline auto xstrided_view_base<D>::shape() const noexcept -> const inner_shape_type&
391 {
392 return m_shape;
393 }
394
398 template <class D>
399 inline auto xstrided_view_base<D>::strides() const noexcept -> const inner_strides_type&
400 {
401 return m_strides;
402 }
403
407 template <class D>
408 inline auto xstrided_view_base<D>::backstrides() const noexcept -> const inner_backstrides_type&
409 {
410 return m_backstrides;
411 }
412
416 template <class D>
417 inline auto xstrided_view_base<D>::layout() const noexcept -> layout_type
418 {
419 return m_layout;
420 }
421
422 template <class D>
423 inline bool xstrided_view_base<D>::is_contiguous() const noexcept
424 {
425 return m_layout != layout_type::dynamic && m_e.is_contiguous();
426 }
427
429
434 template <class D>
435 inline auto xstrided_view_base<D>::operator()() -> reference
436 {
437 return m_storage[static_cast<size_type>(m_offset)];
438 }
439
440 template <class D>
441 inline auto xstrided_view_base<D>::operator()() const -> const_reference
442 {
443 return m_storage[static_cast<size_type>(m_offset)];
444 }
445
452 template <class D>
453 template <class... Args>
454 inline auto xstrided_view_base<D>::operator()(Args... args) -> reference
455 {
456 XTENSOR_TRY(check_index(shape(), args...));
457 XTENSOR_CHECK_DIMENSION(shape(), args...);
458 offset_type index = compute_index(args...);
459 return m_storage[static_cast<size_type>(index)];
460 }
461
468 template <class D>
469 template <class... Args>
470 inline auto xstrided_view_base<D>::operator()(Args... args) const -> const_reference
471 {
472 XTENSOR_TRY(check_index(shape(), args...));
473 XTENSOR_CHECK_DIMENSION(shape(), args...);
474 offset_type index = compute_index(args...);
475 return m_storage[static_cast<size_type>(index)];
476 }
477
497 template <class D>
498 template <class... Args>
499 inline auto xstrided_view_base<D>::unchecked(Args... args) -> reference
500 {
501 offset_type index = compute_unchecked_index(args...);
502 return m_storage[static_cast<size_type>(index)];
503 }
504
524 template <class D>
525 template <class... Args>
526 inline auto xstrided_view_base<D>::unchecked(Args... args) const -> const_reference
527 {
528 offset_type index = compute_unchecked_index(args...);
529 return m_storage[static_cast<size_type>(index)];
530 }
531
539 template <class D>
540 template <class It>
541 inline auto xstrided_view_base<D>::element(It first, It last) -> reference
542 {
543 XTENSOR_TRY(check_element_index(shape(), first, last));
544 return m_storage[static_cast<size_type>(compute_element_index(first, last))];
545 }
546
554 template <class D>
555 template <class It>
556 inline auto xstrided_view_base<D>::element(It first, It last) const -> const_reference
557 {
558 XTENSOR_TRY(check_element_index(shape(), first, last));
559 return m_storage[static_cast<size_type>(compute_element_index(first, last))];
560 }
561
565 template <class D>
566 inline auto xstrided_view_base<D>::storage() noexcept -> storage_type&
567 {
568 return m_storage;
569 }
570
574 template <class D>
575 inline auto xstrided_view_base<D>::storage() const noexcept -> const storage_type&
576 {
577 return m_storage;
578 }
579
584 template <class D>
585 inline auto xstrided_view_base<D>::data() noexcept -> pointer
586 requires(provides_data_interface)
587 {
588 return m_e.data();
589 }
590
595 template <class D>
596 inline auto xstrided_view_base<D>::data() const noexcept -> const_pointer
597 requires(provides_data_interface)
598 {
599 return m_e.data();
600 }
601
605 template <class D>
606 inline auto xstrided_view_base<D>::data_offset() const noexcept -> size_type
607 {
608 return m_offset;
609 }
610
614 template <class D>
615 inline auto xstrided_view_base<D>::expression() noexcept -> xexpression_type&
616 {
617 return m_e;
618 }
619
623 template <class D>
624 inline auto xstrided_view_base<D>::expression() const noexcept -> const xexpression_type&
625 {
626 return m_e;
627 }
628
630
635
/**
 * Broadcasts the view's shape into @p shape.
 *
 * @param shape       the shape to broadcast into
 * @param reuse_cache not used here: this class keeps no broadcast cache
 * @return whatever xt::broadcast_shape(m_shape, shape) returns
 */
641 template <class D>
642 template <class O>
644 {
645 return xt::broadcast_shape(m_shape, shape);
646 }
647
653 template <class D>
654 template <class O>
655 inline bool xstrided_view_base<D>::has_linear_assign(const O& str) const noexcept
656 {
657 return has_data_interface<xexpression_type>::value && str.size() == strides().size()
658 && std::equal(str.cbegin(), str.cend(), strides().begin());
659 }
660
662
663 template <class D>
664 template <class... Args>
665 inline auto xstrided_view_base<D>::compute_index(Args... args) const -> offset_type
666 {
667 return static_cast<offset_type>(m_offset)
668 + xt::data_offset<offset_type>(strides(), static_cast<offset_type>(args)...);
669 }
670
671 template <class D>
672 template <class... Args>
673 inline auto xstrided_view_base<D>::compute_unchecked_index(Args... args) const -> offset_type
674 {
675 return static_cast<offset_type>(m_offset)
676 + xt::unchecked_data_offset<offset_type>(strides(), static_cast<offset_type>(args)...);
677 }
678
679 template <class D>
680 template <class It>
681 inline auto xstrided_view_base<D>::compute_element_index(It first, It last) const -> offset_type
682 {
683 return static_cast<offset_type>(m_offset) + xt::element_offset<offset_type>(strides(), first, last);
684 }
685
686 template <class D>
687 void xstrided_view_base<D>::set_offset(size_type offset)
688 {
689 m_offset = offset;
690 }
691
692 /******************************************
693 * flat_expression_adaptor implementation *
694 ******************************************/
695
696 namespace detail
697 {
698 template <class CT, layout_type L>
699 inline flat_expression_adaptor<CT, L>::flat_expression_adaptor(CT* e)
700 : m_e(e)
701 {
702 resize_container(get_index(), m_e->dimension());
703 resize_container(m_strides, m_e->dimension());
704 m_size = compute_size(m_e->shape());
705 compute_strides(m_e->shape(), L, m_strides);
706 }
707
708 template <class CT, layout_type L>
709 template <class FST>
710 inline flat_expression_adaptor<CT, L>::flat_expression_adaptor(CT* e, FST&& strides)
711 : m_e(e)
712 , m_strides(xtl::forward_sequence<inner_strides_type, FST>(strides))
713 {
714 resize_container(get_index(), m_e->dimension());
715 m_size = m_e->size();
716 }
717
718 template <class CT, layout_type L>
719 inline void flat_expression_adaptor<CT, L>::update_pointer(CT* ptr) const
720 {
721 m_e = ptr;
722 }
723
724 template <class CT, layout_type L>
725 inline auto flat_expression_adaptor<CT, L>::size() const -> size_type
726 {
727 return m_size;
728 }
729
730 template <class CT, layout_type L>
731 inline auto flat_expression_adaptor<CT, L>::operator[](size_type idx) -> reference
732 {
733 auto i = static_cast<typename index_type::value_type>(idx);
734 get_index() = detail::unravel_noexcept(i, m_strides, L);
735 return m_e->element(get_index().cbegin(), get_index().cend());
736 }
737
738 template <class CT, layout_type L>
739 inline auto flat_expression_adaptor<CT, L>::operator[](size_type idx) const -> const_reference
740 {
741 auto i = static_cast<typename index_type::value_type>(idx);
742 get_index() = detail::unravel_noexcept(i, m_strides, L);
743 return m_e->element(get_index().cbegin(), get_index().cend());
744 }
745
746 template <class CT, layout_type L>
747 inline auto flat_expression_adaptor<CT, L>::begin() -> iterator
748 {
749 return m_e->template begin<L>();
750 }
751
752 template <class CT, layout_type L>
753 inline auto flat_expression_adaptor<CT, L>::end() -> iterator
754 {
755 return m_e->template end<L>();
756 }
757
758 template <class CT, layout_type L>
759 inline auto flat_expression_adaptor<CT, L>::begin() const -> const_iterator
760 {
761 return m_e->template cbegin<L>();
762 }
763
764 template <class CT, layout_type L>
765 inline auto flat_expression_adaptor<CT, L>::end() const -> const_iterator
766 {
767 return m_e->template cend<L>();
768 }
769
770 template <class CT, layout_type L>
771 inline auto flat_expression_adaptor<CT, L>::cbegin() const -> const_iterator
772 {
773 return m_e->template cbegin<L>();
774 }
775
776 template <class CT, layout_type L>
777 inline auto flat_expression_adaptor<CT, L>::cend() const -> const_iterator
778 {
779 return m_e->template cend<L>();
780 }
781
782 template <class CT, layout_type L>
783 inline auto flat_expression_adaptor<CT, L>::get_index() -> index_type&
784 {
785 thread_local static index_type index;
786 return index;
787 }
788 }
789
790 /**********************************
791 * Builder helpers implementation *
792 **********************************/
793
794 namespace detail
795 {
796 template <class S>
797 struct slice_getter_impl
798 {
799 const S& m_shape;
800 mutable std::size_t idx;
801 using array_type = std::array<std::ptrdiff_t, 3>;
802
803 explicit slice_getter_impl(const S& shape)
804 : m_shape(shape)
805 , idx(0)
806 {
807 }
808
809 template <class T>
810 array_type operator()(const T& /*t*/) const
811 {
812 return array_type{{0, 0, 0}};
813 }
814
815 template <class A, class B, class C>
816 array_type operator()(const xrange_adaptor<A, B, C>& range) const
817 {
818 auto sl = range.get(static_cast<std::size_t>(m_shape[idx]));
819 return array_type({sl(0), sl.size(), sl.step_size()});
820 }
821
822 template <class T>
823 array_type operator()(const xrange<T>& range) const
824 {
825 return array_type({range(T(0)), range.size(), T(1)});
826 }
827
828 template <class T>
829 array_type operator()(const xstepped_range<T>& range) const
830 {
831 return array_type({range(T(0)), range.size(), range.step_size(T(0))});
832 }
833 };
834
835 template <class adj_strides_policy>
836 struct strided_view_args : adj_strides_policy
837 {
838 using base_type = adj_strides_policy;
839
840 template <class S, class ST, class V>
841 void
842 fill_args(const S& shape, ST&& old_strides, std::size_t base_offset, layout_type layout, const V& slices)
843 {
844 // Compute dimension
845 std::size_t dimension = shape.size(), n_newaxis = 0, n_add_all = 0;
846 std::ptrdiff_t dimension_check = static_cast<std::ptrdiff_t>(shape.size());
847
848 bool has_ellipsis = false;
849 for (const auto& el : slices)
850 {
851 if (std::get_if<xt::xnewaxis_tag>(&el) != nullptr)
852 {
853 ++dimension;
854 ++n_newaxis;
855 }
856 else if (std::get_if<std::ptrdiff_t>(&el) != nullptr)
857 {
858 --dimension;
859 --dimension_check;
860 }
861 else if (std::get_if<xt::xellipsis_tag>(&el) != nullptr)
862 {
863 if (has_ellipsis == true)
864 {
865 XTENSOR_THROW(std::runtime_error, "Ellipsis can only appear once.");
866 }
867 has_ellipsis = true;
868 }
869 else
870 {
871 --dimension_check;
872 }
873 }
874
875 if (dimension_check < 0)
876 {
877 XTENSOR_THROW(std::runtime_error, "Too many slices for view.");
878 }
879
880 if (has_ellipsis)
881 {
882 // replace ellipsis with N * xt::all
883 // remove -1 because of the ellipsis slize itself
884 n_add_all = shape.size() - (slices.size() - 1 - n_newaxis);
885 }
886
887 // Compute strided view
888 new_offset = base_offset;
889 new_shape.resize(dimension);
890 new_strides.resize(dimension);
891 base_type::resize(dimension);
892
893 auto old_shape = shape;
894 using old_strides_value_type = typename std::decay_t<ST>::value_type;
895
896 std::ptrdiff_t axis_skip = 0;
897 std::size_t idx = 0, i = 0, i_ax = 0;
898
899 auto slice_getter = detail::slice_getter_impl<S>(shape);
900
901 for (; i < slices.size(); ++i)
902 {
903 i_ax = static_cast<std::size_t>(static_cast<std::ptrdiff_t>(i) - axis_skip);
904 auto ptr = std::get_if<std::ptrdiff_t>(&slices[i]);
905 if (ptr != nullptr)
906 {
907 auto slice0 = static_cast<old_strides_value_type>(*ptr);
908 new_offset += static_cast<std::size_t>(slice0 * old_strides[i_ax]);
909 }
910 else if (std::get_if<xt::xnewaxis_tag>(&slices[i]) != nullptr)
911 {
912 new_shape[idx] = 1;
913 base_type::set_fake_slice(idx);
914 ++axis_skip, ++idx;
915 }
916 else if (std::get_if<xt::xellipsis_tag>(&slices[i]) != nullptr)
917 {
918 for (std::size_t j = 0; j < n_add_all; ++j)
919 {
920 new_shape[idx] = old_shape[i_ax];
921 new_strides[idx] = old_strides[i_ax];
922 base_type::set_fake_slice(idx);
923 ++idx, ++i_ax;
924 }
925 axis_skip = axis_skip - static_cast<std::ptrdiff_t>(n_add_all) + 1;
926 }
927 else if (std::get_if<xt::xall_tag>(&slices[i]) != nullptr)
928 {
929 new_shape[idx] = old_shape[i_ax];
930 new_strides[idx] = old_strides[i_ax];
931 base_type::set_fake_slice(idx);
932 ++idx;
933 }
934 else if (base_type::fill_args(slices, i, idx, old_shape[i_ax], old_strides[i_ax], new_shape, new_strides))
935 {
936 ++idx;
937 }
938 else
939 {
940 slice_getter.idx = i_ax;
941 auto info = std::visit(slice_getter, slices[i]);
942 new_offset += static_cast<std::size_t>(info[0] * old_strides[i_ax]);
943 new_shape[idx] = static_cast<std::size_t>(info[1]);
944 new_strides[idx] = info[2] * old_strides[i_ax];
945 base_type::set_fake_slice(idx);
946 ++idx;
947 }
948 }
949
950 i_ax = static_cast<std::size_t>(static_cast<std::ptrdiff_t>(i) - axis_skip);
951 for (; i_ax < old_shape.size(); ++i_ax, ++idx)
952 {
953 new_shape[idx] = old_shape[i_ax];
954 new_strides[idx] = old_strides[i_ax];
955 base_type::set_fake_slice(idx);
956 }
957
958 new_layout = do_strides_match(new_shape, new_strides, layout, true) ? layout
959 : layout_type::dynamic;
960 }
961
962 using shape_type = dynamic_shape<std::size_t>;
963 shape_type new_shape;
964 using strides_type = dynamic_shape<std::ptrdiff_t>;
965 strides_type new_strides;
966 std::size_t new_offset;
967 layout_type new_layout;
968 };
969 }
970}
971
972#endif
layout_type layout() const noexcept
xstrided_view_base(CTA &&e, SA &&shape, strides_type &&strides, size_type offset, layout_type layout) noexcept
Constructs an xstrided_view_base.
bool has_linear_assign(const O &strides) const noexcept
const inner_strides_type & strides() const noexcept
bool broadcast_shape(O &shape, bool reuse_cache=false) const
const inner_backstrides_type & backstrides() const noexcept
const inner_shape_type & shape() const noexcept
size_type data_offset() const noexcept
storage_type & storage() noexcept
xexpression_type & expression() noexcept
std::size_t compute_strides(const shape_type &shape, layout_type l, strides_type &strides)
Compute the strides given the shape and the layout of an array.
Definition xstrides.hpp:566
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Definition xstrides.hpp:248
standard mathematical functions for xexpressions
auto range(A start_val, B stop_val)
Select a range from start_val to stop_val (excluded).
Definition xslice.hpp:744
layout_type
Definition xlayout.hpp:24