#ifndef XTENSOR_STRIDED_VIEW_BASE_HPP
#define XTENSOR_STRIDED_VIEW_BASE_HPP

#include <xtl/xsequence.hpp>

#include "../core/xaccessible.hpp"
#include "../core/xstrides.hpp"
#include "../core/xtensor_config.hpp"
#include "../core/xtensor_forward.hpp"
#include "../utils/xutils.hpp"
#include "../views/xslice.hpp"

template <class CT, layout_type L>
class flat_expression_adaptor
{
public:

    using xexpression_type = std::decay_t<CT>;
    using shape_type = typename xexpression_type::shape_type;
    using inner_strides_type = get_strides_t<shape_type>;
    using index_type = inner_strides_type;
    using size_type = typename xexpression_type::size_type;
    using value_type = typename xexpression_type::value_type;
    using const_reference = typename xexpression_type::const_reference;
    using reference = std::conditional_t<
        std::is_const<std::remove_reference_t<CT>>::value,
        typename xexpression_type::const_reference,
        typename xexpression_type::reference>;

    using iterator = decltype(std::declval<std::remove_reference_t<CT>>().template begin<L>());
    using const_iterator = decltype(std::declval<std::decay_t<CT>>().template cbegin<L>());
    using reverse_iterator = decltype(std::declval<std::remove_reference_t<CT>>().template rbegin<L>());
    using const_reverse_iterator = decltype(std::declval<std::decay_t<CT>>().template crbegin<L>());

    explicit flat_expression_adaptor(CT* e);

    template <class FST>
    flat_expression_adaptor(CT* e, FST&& strides);

    void update_pointer(CT* ptr) const;

    size_type size() const;
    reference operator[](size_type idx);
    const_reference operator[](size_type idx) const;

    iterator begin();
    iterator end();
    const_iterator begin() const;
    const_iterator end() const;
    const_iterator cbegin() const;
    const_iterator cend() const;

private:

    static index_type& get_index();

    mutable CT* m_e;
    inner_strides_type m_strides;
    size_type m_size;
};
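
// flat_expression_adaptor presents an arbitrary expression as a flat, random-access sequence:
// operator[] unravels the flat index into a multi-index (using strides computed for layout L)
// and forwards to element(). The storage-getter machinery below only selects this adaptor when
// the wrapped expression does not expose a raw data()/strides() interface of its own.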

template <class E>
struct is_flat_expression_adaptor : std::false_type
{
};

template <class CT, layout_type L>
struct is_flat_expression_adaptor<flat_expression_adaptor<CT, L>> : std::true_type
{
};

template <class E, class ST>
struct provides_data_interface
    : std::conjunction<has_data_interface<std::decay_t<E>>, std::negation<is_flat_expression_adaptor<ST>>>
{
};
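
// provides_data_interface is satisfied only when the underlying expression exposes raw data
// (has_data_interface) and the view's storage is not a flat_expression_adaptor; it gates the
// data() accessors of xstrided_view_base below.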

template <class D>
class xstrided_view_base : private xaccessible<D>
{
public:

    using base_type = xaccessible<D>;
    using inner_types = xcontainer_inner_types<D>;

    using xexpression_type = typename inner_types::xexpression_type;
    using undecay_expression = typename inner_types::undecay_expression;
    static constexpr bool is_const = std::is_const<std::remove_reference_t<undecay_expression>>::value;

    using value_type = typename xexpression_type::value_type;
    using reference = typename inner_types::reference;
    using const_reference = typename inner_types::const_reference;
    using pointer = std::conditional_t<is_const, typename xexpression_type::const_pointer, typename xexpression_type::pointer>;
    using const_pointer = typename xexpression_type::const_pointer;
    using size_type = typename inner_types::size_type;
    using difference_type = typename xexpression_type::difference_type;

    using storage_getter = typename inner_types::storage_getter;
    using inner_storage_type = typename inner_types::inner_storage_type;
    using storage_type = std::remove_reference_t<inner_storage_type>;

    using shape_type = typename inner_types::shape_type;
    using strides_type = get_strides_t<shape_type>;
    using backstrides_type = strides_type;

    using inner_shape_type = shape_type;
    using inner_strides_type = strides_type;
    using inner_backstrides_type = backstrides_type;

    using undecay_shape = typename inner_types::undecay_shape;

    using simd_value_type = xt_simd::simd_type<value_type>;
    using bool_load_type = typename xexpression_type::bool_load_type;

    static constexpr layout_type static_layout = inner_types::layout;
    static constexpr bool contiguous_layout = static_layout != layout_type::dynamic
                                              && xexpression_type::contiguous_layout;

    static constexpr bool
        provides_data_interface = detail::provides_data_interface<xexpression_type, storage_type>::value;

    template <class CTA, class SA>
    xstrided_view_base(CTA&& e, SA&& shape, strides_type&& strides, size_type offset, layout_type layout) noexcept;

    xstrided_view_base(xstrided_view_base&& rhs);
    xstrided_view_base(const xstrided_view_base& rhs);

    const inner_shape_type& shape() const noexcept;
    const inner_strides_type& strides() const noexcept;
    const inner_backstrides_type& backstrides() const noexcept;
    layout_type layout() const noexcept;
    bool is_contiguous() const noexcept;
    using base_type::shape;

    reference operator()();
    const_reference operator()() const;

    template <class... Args>
    reference operator()(Args... args);

    template <class... Args>
    const_reference operator()(Args... args) const;

    template <class... Args>
    reference unchecked(Args... args);

    template <class... Args>
    const_reference unchecked(Args... args) const;

    template <class It>
    reference element(It first, It last);

    template <class It>
    const_reference element(It first, It last) const;

    pointer data() noexcept
        requires(provides_data_interface);
    const_pointer data() const noexcept
        requires(provides_data_interface);

    template <class O>
    bool broadcast_shape(O& shape, bool reuse_cache = false) const;

    template <class O>
    bool has_linear_assign(const O& strides) const noexcept;

private:

    using offset_type = typename strides_type::value_type;

    template <class... Args>
    offset_type compute_index(Args... args) const;
    template <class... Args>
    offset_type compute_unchecked_index(Args... args) const;
    template <class It>
    offset_type compute_element_index(It first, It last) const;

    void set_offset(size_type offset);

    undecay_expression m_e;
    inner_storage_type m_storage;
    inner_shape_type m_shape;
    inner_strides_type m_strides;
    inner_backstrides_type m_backstrides;
    size_type m_offset;
    layout_type m_layout;
};
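
// Usage sketch (illustrative, not part of this header): xstrided_view_base is the base layer
// that xstrided_view builds on, and it is normally reached through the xt::strided_view
// factory declared in xstrided_view.hpp, e.g.
//
//     xt::xarray<double> a = {{1., 2., 3.}, {4., 5., 6.}};
//     auto v = xt::strided_view(a, {xt::all(), xt::range(1, 3)});
//     double x = v(1, 0);  // reads a(1, 1) through the recomputed strides and offset
//
// The view stores the adapted shape, strides and backstrides plus an offset into the
// underlying storage; operator(), unchecked() and element() below combine them.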

template <class CT>
struct inner_storage_getter
{
    using type = decltype(std::declval<CT>().storage());
    using reference = std::add_lvalue_reference_t<CT>;

    template <class E>
    using rebind_t = inner_storage_getter<E>;

    static decltype(auto) get_flat_storage(reference e)
    {
        return e.storage();
    }
    static auto get_offset(reference e)
    {
        return e.data_offset();
    }
    static decltype(auto) get_strides(reference e)
    {
        return e.strides();
    }
};

template <class CT, layout_type L>
struct flat_adaptor_getter
{
    using type = flat_expression_adaptor<std::remove_reference_t<CT>, L>;
    using reference = std::add_lvalue_reference_t<CT>;

    template <class E>
    using rebind_t = flat_adaptor_getter<E, L>;

    static type get_flat_storage(reference e)
    {
        return type(std::addressof(e));
    }
    static auto get_offset(reference)
    {
        return typename std::decay_t<CT>::size_type(0);
    }
    static auto get_strides(reference e)
    {
        dynamic_shape<std::ptrdiff_t> strides;
        strides.resize(e.shape().size());
        compute_strides(e.shape(), L, strides);
        return strides;
    }
};

template <class CT, layout_type L>
using flat_storage_getter = std::conditional_t<
    has_data_interface<std::decay_t<CT>>::value,
    inner_storage_getter<CT>,
    flat_adaptor_getter<CT, L>>;

template <layout_type L, class E>
inline auto get_offset(E& e)
{
    return flat_storage_getter<E, L>::get_offset(e);
}

template <layout_type L, class E>
inline decltype(auto) get_strides(E& e)
{
    return flat_storage_getter<E, L>::get_strides(e);
}
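
// Dispatch summary: flat_storage_getter picks inner_storage_getter when the expression has a
// raw data interface, so its own storage, data offset and strides are used directly; otherwise
// it falls back to flat_adaptor_getter, where the offset starts at 0 and the strides are
// recomputed for layout L over a flat_expression_adaptor.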

template <class D>
template <class CTA, class SA>
inline xstrided_view_base<D>::xstrided_view_base(CTA&& e, SA&& shape, strides_type&& strides, size_type offset, layout_type layout) noexcept
    : m_e(std::forward<CTA>(e))
    , m_storage(storage_getter::get_flat_storage(m_e))
    , m_shape(std::forward<SA>(shape))
    , m_strides(std::move(strides))
    , m_offset(offset)
    , m_layout(layout)
{
    m_backstrides = xtl::make_sequence<backstrides_type>(m_shape.size(), 0);
    adapt_strides(m_shape, m_strides, m_backstrides);
}
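
// The backstrides are not taken from the caller: the member is sized to the view's dimension
// and then filled in by adapt_strides from the final shape and strides, keeping reverse
// iteration consistent with whatever strides were passed in.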

template <class T, class S>
auto& copy_move_storage(T& expr, const S& /*storage*/)
{
    return expr.storage();
}

template <class T, class E, layout_type L>
auto copy_move_storage(T& expr, const detail::flat_expression_adaptor<E, L>& storage)
{
    detail::flat_expression_adaptor<E, L> new_storage = storage;
    new_storage.update_pointer(std::addressof(expr));
    return new_storage;
}
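
// copy_move_storage exists because a flat_expression_adaptor holds a raw pointer to the viewed
// expression: after a view is copied or moved, the adaptor must be re-pointed at the new
// owner's expression, whereas plain storage is simply taken from the new expression again.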

template <class D>
inline xstrided_view_base<D>::xstrided_view_base(xstrided_view_base&& rhs)
    : base_type(std::move(rhs))
    , m_e(std::forward<undecay_expression>(rhs.m_e))
    , m_storage(detail::copy_move_storage(m_e, rhs.m_storage))
    , m_shape(std::move(rhs.m_shape))
    , m_strides(std::move(rhs.m_strides))
    , m_backstrides(std::move(rhs.m_backstrides))
    , m_offset(std::move(rhs.m_offset))
    , m_layout(std::move(rhs.m_layout))
{
}

template <class D>
inline xstrided_view_base<D>::xstrided_view_base(const xstrided_view_base& rhs)
    : base_type(rhs)
    , m_e(rhs.m_e)
    , m_storage(detail::copy_move_storage(m_e, rhs.m_storage))
    , m_shape(rhs.m_shape)
    , m_strides(rhs.m_strides)
    , m_backstrides(rhs.m_backstrides)
    , m_offset(rhs.m_offset)
    , m_layout(rhs.m_layout)
{
}

template <class D>
inline auto xstrided_view_base<D>::backstrides() const noexcept -> const inner_backstrides_type&
{
    return m_backstrides;
}

template <class D>
inline bool xstrided_view_base<D>::is_contiguous() const noexcept
{
    return m_layout != layout_type::dynamic && m_e.is_contiguous();
}

template <class D>
inline auto xstrided_view_base<D>::operator()() -> reference
{
    return m_storage[static_cast<size_type>(m_offset)];
}

template <class D>
inline auto xstrided_view_base<D>::operator()() const -> const_reference
{
    return m_storage[static_cast<size_type>(m_offset)];
}

template <class D>
template <class... Args>
inline auto xstrided_view_base<D>::operator()(Args... args) -> reference
{
    XTENSOR_TRY(check_index(shape(), args...));
    XTENSOR_CHECK_DIMENSION(shape(), args...);
    offset_type index = compute_index(args...);
    return m_storage[static_cast<size_type>(index)];
}

template <class D>
template <class... Args>
inline auto xstrided_view_base<D>::operator()(Args... args) const -> const_reference
{
    XTENSOR_TRY(check_index(shape(), args...));
    XTENSOR_CHECK_DIMENSION(shape(), args...);
    offset_type index = compute_index(args...);
    return m_storage[static_cast<size_type>(index)];
}

template <class D>
template <class... Args>
inline auto xstrided_view_base<D>::unchecked(Args... args) -> reference
{
    offset_type index = compute_unchecked_index(args...);
    return m_storage[static_cast<size_type>(index)];
}

template <class D>
template <class... Args>
inline auto xstrided_view_base<D>::unchecked(Args... args) const -> const_reference
{
    offset_type index = compute_unchecked_index(args...);
    return m_storage[static_cast<size_type>(index)];
}

template <class D>
template <class It>
inline auto xstrided_view_base<D>::element(It first, It last) -> reference
{
    XTENSOR_TRY(check_element_index(shape(), first, last));
    return m_storage[static_cast<size_type>(compute_element_index(first, last))];
}

template <class D>
template <class It>
inline auto xstrided_view_base<D>::element(It first, It last) const -> const_reference
{
    XTENSOR_TRY(check_element_index(shape(), first, last));
    return m_storage[static_cast<size_type>(compute_element_index(first, last))];
}

template <class D>
inline auto xstrided_view_base<D>::data() noexcept -> pointer
    requires(provides_data_interface)
{
    return m_e.data();
}

template <class D>
inline auto xstrided_view_base<D>::data() const noexcept -> const_pointer
    requires(provides_data_interface)
{
    return m_e.data();
}

template <class D>
template <class O>
inline bool xstrided_view_base<D>::broadcast_shape(O& shape, bool /*reuse_cache*/) const
{
    return xt::broadcast_shape(m_shape, shape);
}

template <class D>
template <class O>
inline bool xstrided_view_base<D>::has_linear_assign(const O& str) const noexcept
{
    return str.size() == strides().size()
           && std::equal(str.cbegin(), str.cend(), strides().begin());
}

template <class D>
template <class... Args>
inline auto xstrided_view_base<D>::compute_index(Args... args) const -> offset_type
{
    return static_cast<offset_type>(m_offset)
           + xt::data_offset<offset_type>(strides(), static_cast<offset_type>(args)...);
}

template <class D>
template <class... Args>
inline auto xstrided_view_base<D>::compute_unchecked_index(Args... args) const -> offset_type
{
    return static_cast<offset_type>(m_offset)
           + xt::unchecked_data_offset<offset_type>(strides(), static_cast<offset_type>(args)...);
}

template <class D>
template <class It>
inline auto xstrided_view_base<D>::compute_element_index(It first, It last) const -> offset_type
{
    return static_cast<offset_type>(m_offset) + xt::element_offset<offset_type>(strides(), first, last);
}
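
// Index arithmetic: the storage index is m_offset + sum_i(strides()[i] * index[i]). For
// example, with strides {6, 3, 1} and an offset of 2, the multi-index (1, 0, 2) maps to
// 2 + 6 * 1 + 3 * 0 + 1 * 2 = 10.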

template <class D>
void xstrided_view_base<D>::set_offset(size_type offset)
{
    m_offset = offset;
}

template <class CT, layout_type L>
inline flat_expression_adaptor<CT, L>::flat_expression_adaptor(CT* e)
    : m_e(e)
{
    resize_container(get_index(), m_e->dimension());
    resize_container(m_strides, m_e->dimension());
    m_size = compute_size(m_e->shape());
    compute_strides(m_e->shape(), L, m_strides);
}

template <class CT, layout_type L>
template <class FST>
inline flat_expression_adaptor<CT, L>::flat_expression_adaptor(CT* e, FST&& strides)
    : m_e(e)
    , m_strides(xtl::forward_sequence<inner_strides_type, FST>(strides))
{
    resize_container(get_index(), m_e->dimension());
    m_size = m_e->size();
}

template <class CT, layout_type L>
inline void flat_expression_adaptor<CT, L>::update_pointer(CT* ptr) const
{
    m_e = ptr;
}

template <class CT, layout_type L>
inline auto flat_expression_adaptor<CT, L>::size() const -> size_type
{
    return m_size;
}

template <class CT, layout_type L>
inline auto flat_expression_adaptor<CT, L>::operator[](size_type idx) -> reference
{
    auto i = static_cast<typename index_type::value_type>(idx);
    get_index() = detail::unravel_noexcept(i, m_strides, L);
    return m_e->element(get_index().cbegin(), get_index().cend());
}

template <class CT, layout_type L>
inline auto flat_expression_adaptor<CT, L>::operator[](size_type idx) const -> const_reference
{
    auto i = static_cast<typename index_type::value_type>(idx);
    get_index() = detail::unravel_noexcept(i, m_strides, L);
    return m_e->element(get_index().cbegin(), get_index().cend());
}
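
// Example of the unravel step: for a row-major adaptor over an expression of shape (2, 3),
// m_strides is {3, 1}, so operator[](4) rebuilds the multi-index (1, 1) and returns
// element(1, 1).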

template <class CT, layout_type L>
inline auto flat_expression_adaptor<CT, L>::begin() -> iterator
{
    return m_e->template begin<L>();
}

template <class CT, layout_type L>
inline auto flat_expression_adaptor<CT, L>::end() -> iterator
{
    return m_e->template end<L>();
}

template <class CT, layout_type L>
inline auto flat_expression_adaptor<CT, L>::begin() const -> const_iterator
{
    return m_e->template cbegin<L>();
}

template <class CT, layout_type L>
inline auto flat_expression_adaptor<CT, L>::end() const -> const_iterator
{
    return m_e->template cend<L>();
}

template <class CT, layout_type L>
inline auto flat_expression_adaptor<CT, L>::cbegin() const -> const_iterator
{
    return m_e->template cbegin<L>();
}

template <class CT, layout_type L>
inline auto flat_expression_adaptor<CT, L>::cend() const -> const_iterator
{
    return m_e->template cend<L>();
}

template <class CT, layout_type L>
inline auto flat_expression_adaptor<CT, L>::get_index() -> index_type&
{
    thread_local static index_type index;
    return index;
}
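
// get_index() hands out a thread_local scratch multi-index so that operator[] does not have to
// allocate a new index container on every call; each thread reuses its own buffer.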

template <class S>
struct slice_getter_impl
{
    const S& m_shape;
    mutable std::size_t idx;
    using array_type = std::array<std::ptrdiff_t, 3>;

    explicit slice_getter_impl(const S& shape)
        : m_shape(shape)
        , idx(0)
    {
    }

    template <class T>
    array_type operator()(const T& /*t*/) const
    {
        return array_type{{0, 0, 0}};
    }

    template <class A, class B, class C>
    array_type operator()(const xrange_adaptor<A, B, C>& range) const
    {
        auto sl = range.get(static_cast<std::size_t>(m_shape[idx]));
        return array_type({sl(0), sl.size(), sl.step_size()});
    }

    template <class T>
    array_type operator()(const xrange<T>& range) const
    {
        return array_type({range(T(0)), range.size(), T(1)});
    }

    template <class T>
    array_type operator()(const xstepped_range<T>& range) const
    {
        return array_type({range(T(0)), range.size(), range.step_size(T(0))});
    }
};
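
// slice_getter_impl normalizes each slice variant into the triple {start, size, step} for the
// axis stored in idx; fill_args below folds start into the offset, size into the new shape and
// step into the new strides.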

template <class adj_strides_policy>
struct strided_view_args : adj_strides_policy
{
    using base_type = adj_strides_policy;

    template <class S, class ST, class V>
    void fill_args(const S& shape, ST&& old_strides, std::size_t base_offset, layout_type layout, const V& slices)
    {
        // Compute the dimension of the view.
        std::size_t dimension = shape.size(), n_newaxis = 0, n_add_all = 0;
        std::ptrdiff_t dimension_check = static_cast<std::ptrdiff_t>(shape.size());

        bool has_ellipsis = false;
        for (const auto& el : slices)
        {
            if (std::get_if<xt::xnewaxis_tag>(&el) != nullptr)
            {
                ++dimension;
                ++n_newaxis;
            }
            else if (std::get_if<std::ptrdiff_t>(&el) != nullptr)
            {
                --dimension;
                --dimension_check;
            }
            else if (std::get_if<xt::xellipsis_tag>(&el) != nullptr)
            {
                if (has_ellipsis == true)
                {
                    XTENSOR_THROW(std::runtime_error, "Ellipsis can only appear once.");
                }
                has_ellipsis = true;
            }
            else
            {
                --dimension_check;
            }
        }

        if (dimension_check < 0)
        {
            XTENSOR_THROW(std::runtime_error, "Too many slices for view.");
        }

        if (has_ellipsis)
        {
            n_add_all = shape.size() - (slices.size() - 1 - n_newaxis);
        }

        new_offset = base_offset;
        new_shape.resize(dimension);
        new_strides.resize(dimension);
        base_type::resize(dimension);

        auto old_shape = shape;
        using old_strides_value_type = typename std::decay_t<ST>::value_type;

        std::ptrdiff_t axis_skip = 0;
        std::size_t idx = 0, i = 0, i_ax = 0;

        auto slice_getter = detail::slice_getter_impl<S>(shape);

        for (; i < slices.size(); ++i)
        {
            i_ax = static_cast<std::size_t>(static_cast<std::ptrdiff_t>(i) - axis_skip);
            auto ptr = std::get_if<std::ptrdiff_t>(&slices[i]);
            if (ptr != nullptr)
            {
                auto slice0 = static_cast<old_strides_value_type>(*ptr);
                new_offset += static_cast<std::size_t>(slice0 * old_strides[i_ax]);
            }
            else if (std::get_if<xt::xnewaxis_tag>(&slices[i]) != nullptr)
            {
                new_shape[idx] = 1;
                base_type::set_fake_slice(idx);
                ++axis_skip, ++idx;
            }
            else if (std::get_if<xt::xellipsis_tag>(&slices[i]) != nullptr)
            {
                for (std::size_t j = 0; j < n_add_all; ++j)
                {
                    new_shape[idx] = old_shape[i_ax];
                    new_strides[idx] = old_strides[i_ax];
                    base_type::set_fake_slice(idx);
                    ++i_ax, ++idx;
                }
                axis_skip = axis_skip - static_cast<std::ptrdiff_t>(n_add_all) + 1;
            }
            else if (std::get_if<xt::xall_tag>(&slices[i]) != nullptr)
            {
                new_shape[idx] = old_shape[i_ax];
                new_strides[idx] = old_strides[i_ax];
                base_type::set_fake_slice(idx);
                ++idx;
            }
            else if (base_type::fill_args(slices, i, idx, old_shape[i_ax], old_strides[i_ax], new_shape, new_strides))
            {
                ++idx;
            }
            else
            {
                slice_getter.idx = i_ax;
                auto info = std::visit(slice_getter, slices[i]);
                new_offset += static_cast<std::size_t>(info[0] * old_strides[i_ax]);
                new_shape[idx] = static_cast<std::size_t>(info[1]);
                new_strides[idx] = info[2] * old_strides[i_ax];
                base_type::set_fake_slice(idx);
                ++idx;
            }
        }

        i_ax = static_cast<std::size_t>(static_cast<std::ptrdiff_t>(i) - axis_skip);
        for (; i_ax < old_shape.size(); ++i_ax, ++idx)
        {
            new_shape[idx] = old_shape[i_ax];
            new_strides[idx] = old_strides[i_ax];
            base_type::set_fake_slice(idx);
        }

        new_layout = do_strides_match(new_shape, new_strides, layout, true) ? layout : layout_type::dynamic;
    }

    using shape_type = dynamic_shape<std::size_t>;
    shape_type new_shape;
    using strides_type = dynamic_shape<std::ptrdiff_t>;
    strides_type new_strides;
    std::size_t new_offset;
    layout_type new_layout;
};
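
// Usage sketch (illustrative; assumes the no_adj_strides_policy and xstrided_slice_vector
// defined alongside xstrided_view, with placeholder shape/strides/offset/layout describing the
// underlying expression):
//
//     xt::detail::strided_view_args<xt::detail::no_adj_strides_policy> args;
//     xt::xstrided_slice_vector sv = {xt::all(), 1, xt::newaxis()};
//     args.fill_args(shape, strides, offset, layout, sv);
//     // args.new_shape, args.new_strides, args.new_offset and args.new_layout now describe the view
//
// Integral slices drop an axis and only move the offset, newaxis inserts an axis of length 1,
// and an ellipsis expands to as many xt::all() as needed to cover the remaining axes.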