10#ifndef XTENSOR_STRIDED_VIEW_BASE_HPP
11#define XTENSOR_STRIDED_VIEW_BASE_HPP
15#include <xtl/xsequence.hpp>
16#include <xtl/xvariant.hpp>
18#include "xaccessible.hpp"
20#include "xstrides.hpp"
21#include "xtensor_config.hpp"
22#include "xtensor_forward.hpp"
// Adaptor constructed from a pointer to an expression of type CT. The member
// definitions later in this file access the expression as m_e; operator[]
// unravels a flat index with m_strides and layout L and forwards to the
// expression's element(), while begin()/end()/cbegin()/cend() forward to the
// expression's begin<L>()/end<L>()/cbegin<L>()/cend<L>().
// NOTE(review): brace lines and some declarations are missing from this listing.
29 template <
class CT, layout_type L>
30 class flat_expression_adaptor
// Types taken from the decayed expression type.
34 using xexpression_type = std::decay_t<CT>;
35 using shape_type =
typename xexpression_type::shape_type;
// The multi-index type is the same container type as the strides.
36 using inner_strides_type = get_strides_t<shape_type>;
37 using index_type = inner_strides_type;
38 using size_type =
typename xexpression_type::size_type;
39 using value_type =
typename xexpression_type::value_type;
40 using const_reference =
typename xexpression_type::const_reference;
// reference degrades to const_reference when CT is const-qualified.
41 using reference = std::conditional_t<
42 std::is_const<std::remove_reference_t<CT>>::value,
43 typename xexpression_type::const_reference,
44 typename xexpression_type::reference>;
// Iterator types are deduced from the expression's begin<L>/cbegin<L>/
// rbegin<L>/crbegin<L> member templates.
46 using iterator =
decltype(std::declval<std::remove_reference_t<CT>>().template begin<L>());
47 using const_iterator =
decltype(std::declval<std::decay_t<CT>>().template cbegin<L>());
48 using reverse_iterator =
decltype(std::declval<std::remove_reference_t<CT>>().template rbegin<L>());
49 using const_reverse_iterator =
decltype(std::declval<std::decay_t<CT>>().template crbegin<L>());
// Builds the adaptor from a pointer to the expression.
51 explicit flat_expression_adaptor(CT* e);
// Overload that also receives strides; FST is a template parameter whose
// header line is not present in this listing.
54 flat_expression_adaptor(CT* e, FST&&
strides);
// Declared const although it receives a new expression pointer.
// NOTE(review): presumably the stored pointer is mutable — its declaration is
// not visible here; confirm.
56 void update_pointer(CT* ptr)
const;
58 size_type size()
const;
// Flat-index element access (mutable and const).
59 reference operator[](size_type idx);
60 const_reference operator[](size_type idx)
const;
// Const iteration over the expression in layout L.
64 const_iterator begin()
const;
65 const_iterator end()
const;
66 const_iterator cbegin()
const;
67 const_iterator cend()
const;
// Static accessor for the scratch multi-index used by operator[].
71 static index_type& get_index();
// Strides used to unravel a flat index in operator[].
74 inner_strides_type m_strides;
// Type trait: derives from std::false_type in the primary template (its
// template header line is not present in this listing).
79 struct is_flat_expression_adaptor : std::false_type
// Specialization: std::true_type when the queried type is a
// flat_expression_adaptor<CT, L>.
83 template <
class CT, layout_type L>
84 struct is_flat_expression_adaptor<flat_expression_adaptor<CT, L>> : std::true_type
// Trait that holds when the decayed expression type E has a data interface
// AND the storage type ST is not a flat_expression_adaptor. It gates the
// data() overloads of xstrided_view_base via enable_if.
88 template <
class E,
class ST>
89 struct provides_data_interface
90 : xtl::conjunction<has_data_interface<std::decay_t<E>>, xtl::negation<is_flat_expression_adaptor<ST>>>
// Member types and constants of a class template parameterised on D.
// NOTE(review): the class header is not part of this listing; presumably this
// is xstrided_view_base<D>, whose member definitions appear further below.
100 using base_type = xaccessible<D>;
// Expression types come from inner_types (declared outside this listing).
102 using xexpression_type =
typename inner_types::xexpression_type;
103 using undecay_expression =
typename inner_types::undecay_expression;
// True when the undecayed expression type is const-qualified.
104 static constexpr bool is_const = std::is_const<std::remove_reference_t<undecay_expression>>::value;
106 using value_type =
typename xexpression_type::value_type;
107 using reference =
typename inner_types::reference;
108 using const_reference =
typename inner_types::const_reference;
// pointer is the expression's const_pointer when is_const, else its pointer.
109 using pointer = std::
110 conditional_t<is_const, typename xexpression_type::const_pointer, typename xexpression_type::pointer>;
111 using const_pointer =
typename xexpression_type::const_pointer;
112 using size_type =
typename inner_types::size_type;
113 using difference_type =
typename xexpression_type::difference_type;
// storage_getter supplies get_flat_storage(), used by the constructor.
115 using storage_getter =
typename inner_types::storage_getter;
116 using inner_storage_type =
typename inner_types::inner_storage_type;
117 using storage_type = std::remove_reference_t<inner_storage_type>;
// Strides and backstrides share one container type derived from shape_type.
119 using shape_type =
typename inner_types::shape_type;
120 using strides_type = get_strides_t<shape_type>;
121 using backstrides_type = strides_type;
123 using inner_shape_type = shape_type;
124 using inner_strides_type = strides_type;
125 using inner_backstrides_type = backstrides_type;
127 using undecay_shape =
typename inner_types::undecay_shape;
129 using simd_value_type = xt_simd::simd_type<value_type>;
130 using bool_load_type =
typename xexpression_type::bool_load_type;
132 static constexpr layout_type static_layout = inner_types::layout;
// Continuation of a declaration whose first line is missing from this listing.
134 && xexpression_type::contiguous_layout;
// Template header of a member whose declarator is missing from this listing.
// NOTE(review): presumably the constructor defined later with the same
// CTA/SA parameters — confirm.
136 template <
class CTA,
class SA>
// Read-only accessors for the view's shape and strides.
143 const inner_shape_type&
shape() const noexcept;
144 const inner_strides_type&
strides() const noexcept;
147 bool is_contiguous() const noexcept;
// Also exposes the shape overloads inherited from base_type (xaccessible<D>).
148 using base_type::
shape;
// Element access declarations, each in a mutable and a const flavour.
// No-argument access: the definitions read m_storage at m_offset.
150 reference operator()();
151 const_reference operator()() const;
// Multi-index access; the definitions perform index/dimension checks through
// the XTENSOR_TRY / XTENSOR_CHECK_DIMENSION macros.
153 template <class... Args>
154 reference operator()(Args... args);
156 template <class... Args>
157 const_reference operator()(Args... args) const;
// Multi-index access without those checks.
159 template <class... Args>
160 reference unchecked(Args... args);
162 template <class... Args>
163 const_reference unchecked(Args... args) const;
// Access through an iterator range holding the index; the template header
// introducing It is missing from this listing.
166 reference element(It first, It last);
169 const_reference element(It first, It last) const;
// data() overloads, enabled only when detail::provides_data_interface<E, ST>
// holds for the defaulted template arguments (the expression type and the
// storage type).
174 template <class E = xexpression_type, class ST = storage_type>
175 std::enable_if_t<detail::provides_data_interface<E, ST>::value, pointer> data() noexcept;
176 template <class E = xexpression_type, class ST = storage_type>
177 std::enable_if_t<detail::provides_data_interface<E, ST>::value, const_pointer> data() const noexcept;
// Helpers that turn an index into a position in the flat storage.
// offset_type is the element type of the strides container.
191 using offset_type = typename strides_type::value_type;
// Checked and unchecked offset computation from a pack of indices.
193 template <class... Args>
194 offset_type compute_index(Args... args) const;
196 template <class... Args>
197 offset_type compute_unchecked_index(Args... args) const;
// Offset computation from an iterator range (template header for It is
// missing from this listing).
200 offset_type compute_element_index(It first, It last) const;
202 void set_offset(size_type offset);
// Data members. m_e is declared before m_storage, so it is initialized first;
// the constructors build m_storage from m_e and rely on that order.
// NOTE(review): m_offset and m_layout are used by the constructors but their
// declarations are missing from this listing.
206 undecay_expression m_e;
207 inner_storage_type m_storage;
208 inner_shape_type m_shape;
209 inner_strides_type m_strides;
210 inner_backstrides_type m_backstrides;
// Getter policy for expressions that expose storage(): `type` is whatever
// storage() returns for CT. The template header and the bodies of
// get_flat_storage/get_strides are missing from this listing.
222 struct inner_storage_getter
224 using type =
decltype(std::declval<CT>().storage());
225 using reference = std::add_lvalue_reference_t<CT>;
// Rebinds the policy to another expression type E (template header missing).
228 using rebind_t = inner_storage_getter<E>;
230 static decltype(
auto) get_flat_storage(reference e)
// The offset is taken from the expression's own data_offset().
235 static auto get_offset(reference e)
237 return e.data_offset();
240 static decltype(
auto) get_strides(reference e)
// Getter policy that wraps the expression in a flat_expression_adaptor.
246 template <
class CT, layout_type L>
247 struct flat_adaptor_getter
249 using type = flat_expression_adaptor<std::remove_reference_t<CT>, L>;
250 using reference = std::add_lvalue_reference_t<CT>;
// Rebinds the policy to another expression type E, keeping the layout L
// (template header missing from this listing).
253 using rebind_t = flat_adaptor_getter<E, L>;
// Builds the adaptor from the address of the expression.
255 static type get_flat_storage(reference e)
258 return type(std::addressof(e));
// The offset is always zero for this policy.
261 static auto get_offset(reference)
263 return typename std::decay_t<CT>::size_type(0);
// Creates a strides container with one entry per dimension of e; the
// statements that fill and return it are missing from this listing.
266 static auto get_strides(reference e)
268 dynamic_shape<std::ptrdiff_t> strides;
269 strides.resize(e.shape().size());
// Selects inner_storage_getter<CT> when the condition holds, otherwise
// flat_adaptor_getter<CT, L>. The condition line of the conditional_t is
// missing from this listing.
275 template <
class CT, layout_type L>
276 using flat_storage_getter = std::conditional_t<
278 inner_storage_getter<CT>,
279 flat_adaptor_getter<CT, L>>;
// Free function: forwards to get_offset() of the getter policy that
// flat_storage_getter selects for E and layout L.
281 template <layout_type L,
class E>
282 inline auto get_offset(E& e)
284 return flat_storage_getter<E, L>::get_offset(e);
// Free function: forwards to get_strides() of the getter policy that
// flat_storage_getter selects for E and layout L. decltype(auto) keeps the
// policy's exact return type.
287 template <layout_type L,
class E>
288 inline decltype(
auto) get_strides(E& e)
290 return flat_storage_getter<E, L>::get_strides(e);
// Constructor definition (the declarator line and part of the initializer
// list are missing from this listing).
312 template <
class CTA,
class SA>
// Stores the forwarded expression, then derives the flat storage from the
// stored m_e through the storage_getter policy.
320 : m_e(std::forward<CTA>(e))
323 m_storage(storage_getter::get_flat_storage(m_e))
324 , m_shape(std::forward<SA>(
shape))
325 , m_strides(std::move(
strides))
// Body: create zero-filled backstrides with one entry per dimension, then let
// adapt_strides process shape, strides and backstrides together.
329 m_backstrides = xtl::make_sequence<backstrides_type>(m_shape.size(), 0);
330 adapt_strides(m_shape, m_strides, m_backstrides);
// Generic overload used when copying/moving a view: the old storage argument
// is ignored and the storage of the given expression is returned by reference.
335 template <
class T,
class S>
336 auto& copy_move_storage(T& expr,
const S& )
338 return expr.storage();
// Overload for flat_expression_adaptor storage: copies the adaptor and
// re-points the copy at `expr`. The return statement is missing from this
// listing; the function returns by value (plain `auto`).
341 template <
class T,
class E, layout_type L>
342 auto copy_move_storage(T& expr,
const detail::flat_expression_adaptor<E, L>& storage)
344 detail::flat_expression_adaptor<E, L> new_storage = storage;
345 new_storage.update_pointer(std::addressof(expr));
// Move-constructor initializer list (declarator line missing from this
// listing). m_storage is rebuilt from this object's own m_e rather than taken
// from rhs, which relies on m_e being initialized before m_storage.
352 : base_type(std::move(rhs))
353 , m_e(std::forward<undecay_expression>(rhs.m_e))
354 , m_storage(detail::copy_move_storage(m_e, rhs.m_storage))
355 , m_shape(std::move(rhs.m_shape))
356 , m_strides(std::move(rhs.m_strides))
357 , m_backstrides(std::move(rhs.m_backstrides))
358 , m_offset(std::move(rhs.m_offset))
359 , m_layout(std::move(rhs.m_layout))
// Copy-constructor initializer list (declarator and first initializers are
// missing from this listing). As in the move constructor, m_storage is rebuilt
// from this object's own m_e; the remaining members are copied from rhs.
367 , m_storage(detail::copy_move_storage(m_e, rhs.m_storage))
368 , m_shape(rhs.m_shape)
369 , m_strides(rhs.m_strides)
370 , m_backstrides(rhs.m_backstrides)
371 , m_offset(rhs.m_offset)
372 , m_layout(rhs.m_layout)
// Body of an accessor whose header is missing from this listing.
// NOTE(review): presumably xstrided_view_base<D>::backstrides() — confirm.
406 return m_backstrides;
// Definition header of is_contiguous(); its template header and body are
// missing from this listing.
419 inline bool xstrided_view_base<D>::is_contiguous() const noexcept
// No-argument access: returns the storage element located at m_offset.
431 inline auto xstrided_view_base<D>::operator()() -> reference
433 return m_storage[
static_cast<size_type
>(m_offset)];
// Const no-argument access: returns the storage element located at m_offset.
437 inline auto xstrided_view_base<D>::operator()() const -> const_reference
439 return m_storage[
static_cast<size_type
>(m_offset)];
// Multi-index access. The indices go through check_index (wrapped in
// XTENSOR_TRY) and XTENSOR_CHECK_DIMENSION; whether those macros do anything
// depends on their definitions, which are not shown here. The flat position is
// then computed by compute_index and used to index m_storage.
449 template <
class... Args>
450 inline auto xstrided_view_base<D>::operator()(Args... args) -> reference
452 XTENSOR_TRY(check_index(
shape(), args...));
453 XTENSOR_CHECK_DIMENSION(
shape(), args...);
454 offset_type index = compute_index(args...);
455 return m_storage[
static_cast<size_type
>(index)];
// Const multi-index access. Same steps as the mutable overload: macro-wrapped
// index and dimension checks, then compute_index and a lookup in m_storage.
465 template <
class... Args>
466 inline auto xstrided_view_base<D>::operator()(Args... args)
const -> const_reference
468 XTENSOR_TRY(check_index(
shape(), args...));
469 XTENSOR_CHECK_DIMENSION(
shape(), args...);
470 offset_type index = compute_index(args...);
471 return m_storage[
static_cast<size_type
>(index)];
// Multi-index access with no check macros: the flat position comes from
// compute_unchecked_index and is used directly to index m_storage.
494 template <
class... Args>
495 inline auto xstrided_view_base<D>::unchecked(Args... args) -> reference
497 offset_type index = compute_unchecked_index(args...);
498 return m_storage[
static_cast<size_type
>(index)];
// Const multi-index access with no check macros: the flat position comes from
// compute_unchecked_index and is used directly to index m_storage.
521 template <
class... Args>
522 inline auto xstrided_view_base<D>::unchecked(Args... args)
const -> const_reference
524 offset_type index = compute_unchecked_index(args...);
525 return m_storage[
static_cast<size_type
>(index)];
// Access through an index given as the iterator range [first, last). The range
// goes through check_element_index (wrapped in XTENSOR_TRY), then
// compute_element_index yields the position in m_storage.
537 inline auto xstrided_view_base<D>::element(It first, It last) -> reference
539 XTENSOR_TRY(check_element_index(
shape(), first, last));
540 return m_storage[
static_cast<size_type
>(compute_element_index(first, last))];
// Const access through an index given as the iterator range [first, last):
// macro-wrapped check_element_index, then compute_element_index yields the
// position in m_storage.
552 inline auto xstrided_view_base<D>::element(It first, It last)
const -> const_reference
554 XTENSOR_TRY(check_element_index(
shape(), first, last));
555 return m_storage[
static_cast<size_type
>(compute_element_index(first, last))];
// Definition header of the mutable data() overload; it exists only when
// detail::provides_data_interface<E, ST> holds. The body is missing from this
// listing.
581 template <
class E,
class ST>
582 inline auto xstrided_view_base<D>::data() noexcept
583 -> std::enable_if_t<detail::provides_data_interface<E, ST>::value, pointer>
// Definition header of the const data() overload; it exists only when
// detail::provides_data_interface<E, ST> holds. The body is missing from this
// listing.
593 template <
class E,
class ST>
594 inline auto xstrided_view_base<D>::data() const noexcept
595 -> std::enable_if_t<detail::provides_data_interface<E, ST>::value, const_pointer>
// Body fragment (header missing from this listing): delegates to
// xt::broadcast_shape with this view's m_shape and the caller's shape.
// NOTE(review): presumably the member broadcast_shape — confirm.
643 return xt::broadcast_shape(m_shape,
shape);
// Tail of a boolean expression (start missing from this listing): compares the
// range `str` element-wise with this view's strides.
// NOTE(review): presumably part of has_linear_assign — confirm.
656 && std::equal(str.cbegin(), str.cend(),
strides().begin());
// Flat position for a pack of indices: m_offset plus xt::data_offset of the
// indices against strides(). Every index is cast to offset_type first.
662 template <
class... Args>
663 inline auto xstrided_view_base<D>::compute_index(Args... args)
const -> offset_type
665 return static_cast<offset_type
>(m_offset)
666 + xt::data_offset<offset_type>(
strides(),
static_cast<offset_type
>(args)...);
// Flat position for a pack of indices: m_offset plus
// xt::unchecked_data_offset of the indices against strides(). Every index is
// cast to offset_type first.
670 template <
class... Args>
671 inline auto xstrided_view_base<D>::compute_unchecked_index(Args... args)
const -> offset_type
673 return static_cast<offset_type
>(m_offset)
674 + xt::unchecked_data_offset<offset_type>(
strides(),
static_cast<offset_type
>(args)...);
// Flat position for an index given as the iterator range [first, last):
// m_offset plus xt::element_offset of that range against strides().
679 inline auto xstrided_view_base<D>::compute_element_index(It first, It last)
const -> offset_type
681 return static_cast<offset_type
>(m_offset) + xt::element_offset<offset_type>(
strides(), first, last);
// Definition header of set_offset; template header and body are missing from
// this listing. Unlike the neighbouring member definitions, this one is not
// marked `inline`.
685 void xstrided_view_base<D>::set_offset(size_type offset)
// Constructor from an expression pointer (the member initializer list is
// missing from this listing). Sizes the shared scratch index and m_strides to
// the expression's dimension and caches the element count computed from its
// shape. NOTE(review): further statements may be missing here.
696 template <
class CT, layout_type L>
697 inline flat_expression_adaptor<CT, L>::flat_expression_adaptor(CT* e)
700 resize_container(get_index(), m_e->dimension());
701 resize_container(m_strides, m_e->dimension());
702 m_size = compute_size(m_e->shape());
// Constructor from an expression pointer plus caller-provided strides (the
// FST template header and the first initializer are missing from this
// listing). m_strides is built from the forwarded strides; the scratch index
// is sized to the expression's dimension and m_size is taken from the
// expression's size().
706 template <
class CT, layout_type L>
708 inline flat_expression_adaptor<CT, L>::flat_expression_adaptor(CT* e, FST&&
strides)
710 , m_strides(xtl::forward_sequence<inner_strides_type, FST>(
strides))
712 resize_container(get_index(), m_e->dimension());
713 m_size = m_e->size();
// Definition header of update_pointer; the body is missing from this listing.
// NOTE(review): presumably re-seats the stored expression pointer to `ptr`
// despite the const qualifier — confirm.
716 template <
class CT, layout_type L>
717 inline void flat_expression_adaptor<CT, L>::update_pointer(CT* ptr)
const
// Definition header of size(); the body is missing from this listing.
// NOTE(review): presumably returns the m_size cached by the constructors —
// confirm.
722 template <
class CT, layout_type L>
723 inline auto flat_expression_adaptor<CT, L>::size() const -> size_type
// Flat-index access: converts idx to the index value type, unravels it into a
// multi-index using m_strides and layout L, stores that in the shared scratch
// index, and forwards the scratch range to the expression's element().
728 template <
class CT, layout_type L>
729 inline auto flat_expression_adaptor<CT, L>::operator[](size_type idx) -> reference
731 auto i =
static_cast<typename index_type::value_type
>(idx);
732 get_index() = detail::unravel_noexcept(i, m_strides, L);
733 return m_e->element(get_index().cbegin(), get_index().cend());
// Const flat-index access. It can still write the scratch index because
// get_index() is a static member returning a non-const reference. Otherwise
// identical to the mutable overload.
736 template <
class CT, layout_type L>
737 inline auto flat_expression_adaptor<CT, L>::operator[](size_type idx)
const -> const_reference
739 auto i =
static_cast<typename index_type::value_type
>(idx);
740 get_index() = detail::unravel_noexcept(i, m_strides, L);
741 return m_e->element(get_index().cbegin(), get_index().cend());
// Mutable begin: forwards to the expression's begin<L>().
744 template <
class CT, layout_type L>
745 inline auto flat_expression_adaptor<CT, L>::begin() -> iterator
747 return m_e->template begin<L>();
// Mutable end: forwards to the expression's end<L>().
750 template <
class CT, layout_type L>
751 inline auto flat_expression_adaptor<CT, L>::end() -> iterator
753 return m_e->template end<L>();
// Const begin: forwards to the expression's cbegin<L>().
756 template <
class CT, layout_type L>
757 inline auto flat_expression_adaptor<CT, L>::begin() const -> const_iterator
759 return m_e->template cbegin<L>();
// Const end: forwards to the expression's cend<L>().
762 template <
class CT, layout_type L>
763 inline auto flat_expression_adaptor<CT, L>::end() const -> const_iterator
765 return m_e->template cend<L>();
// cbegin: forwards to the expression's cbegin<L>().
768 template <
class CT, layout_type L>
769 inline auto flat_expression_adaptor<CT, L>::cbegin() const -> const_iterator
771 return m_e->template cbegin<L>();
// cend: forwards to the expression's cend<L>().
774 template <
class CT, layout_type L>
775 inline auto flat_expression_adaptor<CT, L>::cend() const -> const_iterator
777 return m_e->template cend<L>();
// Scratch multi-index used by operator[]. It is a function-local
// `thread_local static` inside a static member function, so there is one
// instance per thread for each <CT, L> instantiation, shared by all adaptors
// of that type. The return statement is missing from this listing.
780 template <
class CT, layout_type L>
781 inline auto flat_expression_adaptor<CT, L>::get_index() -> index_type&
783 thread_local static index_type index;
// Visitor applied to one slice of a strided view; each overload returns three
// std::ptrdiff_t values. fill_args uses them as: [0] multiplied by the old
// stride and added to the offset, [1] as the new extent, [2] multiplied by the
// old stride to give the new stride.
// NOTE(review): template headers, the m_shape member and some bodies are
// missing from this listing.
795 struct slice_getter_impl
// Axis currently being processed; mutable so the caller can update it between
// visits (fill_args assigns it before each xtl::visit).
798 mutable std::size_t idx;
799 using array_type = std::array<std::ptrdiff_t, 3>;
801 explicit slice_getter_impl(
const S& shape)
// Fallback for slice types with no dedicated overload: all three values zero.
808 array_type operator()(
const T& )
const
810 return array_type{{0, 0, 0}};
// Range adaptor: resolved against the extent of axis `idx`, then reported as
// (first value, size, step size).
813 template <
class A,
class B,
class C>
814 array_type operator()(
const xrange_adaptor<A, B, C>& range)
const
816 auto sl =
range.get(
static_cast<std::size_t
>(m_shape[idx]));
817 return array_type({sl(0), sl.size(), sl.step_size()});
// Plain range: (value at position 0, size, step of 1).
821 array_type operator()(
const xrange<T>& range)
const
823 return array_type({
range(T(0)),
range.size(), T(1)});
// Stepped range: body missing from this listing.
827 array_type operator()(
const xstepped_range<T>& range)
const
// Computes the shape, strides, offset and layout of a strided view from the
// shape/strides of the underlying expression and a sequence of slices. It
// derives from a policy class (adj_strides_policy) whose resize(),
// set_fake_slice() and fill_args() hooks are called while the slices are
// processed.
// NOTE(review): brace lines and a number of statements are missing from this
// listing; comments below only describe what is visible.
833 template <
class adj_str
ides_policy>
834 struct strided_view_args : adj_strides_policy
836 using base_type = adj_strides_policy;
// fill_args: fills new_shape / new_strides / new_offset / new_layout.
//   shape, old_strides : shape and strides of the underlying expression
//   base_offset        : starting offset, copied into new_offset
//   layout             : layout kept only if the new strides still match it
//   slices             : sequence of variant slices (inspected via xtl::get_if)
838 template <
class S,
class ST,
class V>
840 fill_args(
const S& shape, ST&& old_strides, std::size_t base_offset, layout_type layout,
const V& slices)
// First pass: classify the slices. The statements inside each branch are
// missing from this listing, apart from the two error paths.
843 std::size_t dimension = shape.size(), n_newaxis = 0, n_add_all = 0;
844 std::ptrdiff_t dimension_check =
static_cast<std::ptrdiff_t
>(shape.size());
846 bool has_ellipsis =
false;
847 for (
const auto& el : slices)
849 if (xtl::get_if<xt::xnewaxis_tag>(&el) !=
nullptr)
854 else if (xtl::get_if<std::ptrdiff_t>(&el) !=
nullptr)
859 else if (xtl::get_if<xt::xellipsis_tag>(&el) !=
nullptr)
// A second ellipsis is rejected.
861 if (has_ellipsis ==
true)
863 XTENSOR_THROW(std::runtime_error,
"Ellipsis can only appear once.");
// More slices than available dimensions is rejected.
873 if (dimension_check < 0)
875 XTENSOR_THROW(std::runtime_error,
"Too many slices for view.");
// Number of axes copied for the ellipsis: total axes minus the slices that
// are neither the ellipsis itself nor a newaxis.
// NOTE(review): presumably guarded by a has_ellipsis check not shown here.
882 n_add_all = shape.size() - (slices.size() - 1 - n_newaxis);
// Size the outputs (and the policy's own state) for `dimension` axes.
886 new_offset = base_offset;
887 new_shape.resize(dimension);
888 new_strides.resize(dimension);
889 base_type::resize(dimension);
891 auto old_shape = shape;
892 using old_strides_value_type =
typename std::decay_t<ST>::value_type;
// Second pass. i walks the slices, idx the output axes, and i_ax the input
// axes; i_ax is derived as i - axis_skip.
894 std::ptrdiff_t axis_skip = 0;
895 std::size_t idx = 0, i = 0, i_ax = 0;
897 auto slice_getter = detail::slice_getter_impl<S>(shape);
899 for (; i < slices.size(); ++i)
901 i_ax =
static_cast<std::size_t
>(
static_cast<std::ptrdiff_t
>(i) - axis_skip);
// Integer slice: only shifts the offset by index * stride of that axis.
// (The null check on ptr is missing from this listing.)
902 auto ptr = xtl::get_if<std::ptrdiff_t>(&slices[i]);
905 auto slice0 =
static_cast<old_strides_value_type
>(*ptr);
906 new_offset +=
static_cast<std::size_t
>(slice0 * old_strides[i_ax]);
// newaxis slice: notifies the policy; other statements of this branch are
// missing from this listing.
908 else if (xtl::get_if<xt::xnewaxis_tag>(&slices[i]) !=
nullptr)
911 base_type::set_fake_slice(idx);
// Ellipsis: copies extent and stride for n_add_all axes, then adjusts
// axis_skip by (1 - n_add_all). The loop's index increments are missing
// from this listing.
914 else if (xtl::get_if<xt::xellipsis_tag>(&slices[i]) !=
nullptr)
916 for (std::size_t j = 0; j < n_add_all; ++j)
918 new_shape[idx] = old_shape[i_ax];
919 new_strides[idx] = old_strides[i_ax];
920 base_type::set_fake_slice(idx);
923 axis_skip = axis_skip -
static_cast<std::ptrdiff_t
>(n_add_all) + 1;
// "all" slice: the axis is carried over unchanged.
925 else if (xtl::get_if<xt::xall_tag>(&slices[i]) !=
nullptr)
927 new_shape[idx] = old_shape[i_ax];
928 new_strides[idx] = old_strides[i_ax];
929 base_type::set_fake_slice(idx);
// Policy hook: when it returns true, the slice is treated as handled by the
// policy (the branch body is missing from this listing).
932 else if (base_type::fill_args(slices, i, idx, old_shape[i_ax], old_strides[i_ax], new_shape, new_strides))
// Remaining slice kinds: visit the slice to obtain three values, used as
// start (info[0]), size (info[1]) and step (info[2]).
938 slice_getter.idx = i_ax;
939 auto info = xtl::visit(slice_getter, slices[i]);
940 new_offset +=
static_cast<std::size_t
>(info[0] * old_strides[i_ax]);
941 new_shape[idx] =
static_cast<std::size_t
>(info[1]);
942 new_strides[idx] = info[2] * old_strides[i_ax];
943 base_type::set_fake_slice(idx);
// Axes not covered by any slice are carried over unchanged.
948 i_ax =
static_cast<std::size_t
>(
static_cast<std::ptrdiff_t
>(i) - axis_skip);
949 for (; i_ax < old_shape.size(); ++i_ax, ++idx)
951 new_shape[idx] = old_shape[i_ax];
952 new_strides[idx] = old_strides[i_ax];
953 base_type::set_fake_slice(idx);
// Keep the requested layout only if the new strides match it; otherwise
// fall back to layout_type::dynamic.
956 new_layout = do_strides_match(new_shape, new_strides, layout,
true) ? layout
957 : layout_type::dynamic;
// Results of fill_args (the declaration of new_layout is missing from this
// listing).
960 using shape_type = dynamic_shape<std::size_t>;
961 shape_type new_shape;
962 using strides_type = dynamic_shape<std::ptrdiff_t>;
963 strides_type new_strides;
964 std::size_t new_offset;
layout_type layout() const noexcept
xstrided_view_base(CTA &&e, SA &&shape, strides_type &&strides, size_type offset, layout_type layout) noexcept
Constructs an xstrided_view_base.
bool has_linear_assign(const O &strides) const noexcept
const inner_strides_type & strides() const noexcept
bool broadcast_shape(O &shape, bool reuse_cache=false) const
const inner_backstrides_type & backstrides() const noexcept
const inner_shape_type & shape() const noexcept
size_type data_offset() const noexcept
storage_type & storage() noexcept
xexpression_type & expression() noexcept
std::size_t compute_strides(const shape_type &shape, layout_type l, strides_type &strides)
Compute the strides given the shape and the layout of an array.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
standard mathematical functions for xexpressions
auto range(A start_val, B stop_val)
Select a range from start_val to stop_val (excluded).