10#ifndef XTENSOR_STRIDED_VIEW_BASE_HPP 
   11#define XTENSOR_STRIDED_VIEW_BASE_HPP 
   16#include <xtl/xsequence.hpp> 
   18#include "../core/xaccessible.hpp" 
   19#include "../core/xstrides.hpp" 
   20#include "../core/xtensor_config.hpp" 
   21#include "../core/xtensor_forward.hpp" 
   22#include "../utils/xutils.hpp" 
   23#include "../views/xslice.hpp" 
   29        template <
class CT, layout_type L>
 
   30        class flat_expression_adaptor
 
   34            using xexpression_type = std::decay_t<CT>;
 
   35            using shape_type = 
typename xexpression_type::shape_type;
 
   36            using inner_strides_type = get_strides_t<shape_type>;
 
   37            using index_type = inner_strides_type;
 
   38            using size_type = 
typename xexpression_type::size_type;
 
   39            using value_type = 
typename xexpression_type::value_type;
 
   40            using const_reference = 
typename xexpression_type::const_reference;
 
   41            using reference = std::conditional_t<
 
   42                std::is_const<std::remove_reference_t<CT>>::value,
 
   43                typename xexpression_type::const_reference,
 
   44                typename xexpression_type::reference>;
 
   46            using iterator = 
decltype(std::declval<std::remove_reference_t<CT>>().template begin<L>());
 
   47            using const_iterator = 
decltype(std::declval<std::decay_t<CT>>().template cbegin<L>());
 
   48            using reverse_iterator = 
decltype(std::declval<std::remove_reference_t<CT>>().template rbegin<L>());
 
   49            using const_reverse_iterator = 
decltype(std::declval<std::decay_t<CT>>().template crbegin<L>());
 
   51            explicit flat_expression_adaptor(CT* e);
 
   54            flat_expression_adaptor(CT* e, FST&& 
strides);
 
   56            void update_pointer(CT* ptr) 
const;
 
   58            size_type size() 
const;
 
   59            reference operator[](size_type idx);
 
   60            const_reference operator[](size_type idx) 
const;
 
   64            const_iterator begin() 
const;
 
   65            const_iterator end() 
const;
 
   66            const_iterator cbegin() 
const;
 
   67            const_iterator cend() 
const;
 
   71            static index_type& get_index();
 
   74            inner_strides_type m_strides;
 
   79        struct is_flat_expression_adaptor : std::false_type
 
   83        template <
class CT, layout_type L>
 
   84        struct is_flat_expression_adaptor<flat_expression_adaptor<CT, L>> : std::true_type
 
   88        template <
class E, 
class ST>
 
   89        struct provides_data_interface
 
   90            : std::conjunction<has_data_interface<std::decay_t<E>>, std::negation<is_flat_expression_adaptor<ST>>>
 
  100        using base_type = xaccessible<D>;
 
  102        using xexpression_type = 
typename inner_types::xexpression_type;
 
  103        using undecay_expression = 
typename inner_types::undecay_expression;
 
  104        static constexpr bool is_const = std::is_const<std::remove_reference_t<undecay_expression>>::value;
 
  106        using value_type = 
typename xexpression_type::value_type;
 
  107        using reference = 
typename inner_types::reference;
 
  108        using const_reference = 
typename inner_types::const_reference;
 
  109        using pointer = std::
 
  110            conditional_t<is_const, typename xexpression_type::const_pointer, typename xexpression_type::pointer>;
 
  111        using const_pointer = 
typename xexpression_type::const_pointer;
 
  112        using size_type = 
typename inner_types::size_type;
 
  113        using difference_type = 
typename xexpression_type::difference_type;
 
  115        using storage_getter = 
typename inner_types::storage_getter;
 
  116        using inner_storage_type = 
typename inner_types::inner_storage_type;
 
  117        using storage_type = std::remove_reference_t<inner_storage_type>;
 
  119        using shape_type = 
typename inner_types::shape_type;
 
  120        using strides_type = get_strides_t<shape_type>;
 
  121        using backstrides_type = strides_type;
 
  123        using inner_shape_type = shape_type;
 
  124        using inner_strides_type = strides_type;
 
  125        using inner_backstrides_type = backstrides_type;
 
  127        using undecay_shape = 
typename inner_types::undecay_shape;
 
  129        using simd_value_type = xt_simd::simd_type<value_type>;
 
  130        using bool_load_type = 
typename xexpression_type::bool_load_type;
 
  132        static constexpr layout_type static_layout = inner_types::layout;
 
  134                                                  && xexpression_type::contiguous_layout;
 
  136        static constexpr bool 
  137            provides_data_interface = detail::provides_data_interface<xexpression_type, storage_type>::value;
 
  139        template <
class CTA, 
class SA>
 
  146        const inner_shape_type& 
shape() const noexcept;
 
  147        const inner_strides_type& 
strides() const noexcept;
 
  150        bool is_contiguous() const noexcept;
 
  151        using base_type::
shape;
 
  153        reference operator()();
 
  154        const_reference operator()() const;
 
  156        template <class... Args>
 
  157        reference operator()(Args... args);
 
  159        template <class... Args>
 
  160        const_reference operator()(Args... args) const;
 
  162        template <class... Args>
 
  163        reference unchecked(Args... args);
 
  165        template <class... Args>
 
  166        const_reference unchecked(Args... args) const;
 
  169        reference element(It first, It last);
 
  172        const_reference element(It first, It last) const;
 
  178            requires(provides_data_interface);
 
  179        const_pointer 
data() const noexcept
 
  180            requires(provides_data_interface);
 
  195        using offset_type = typename strides_type::value_type;
 
  197        template <class... Args>
 
  198        offset_type compute_index(Args... args) const;
 
  200        template <class... Args>
 
  201        offset_type compute_unchecked_index(Args... args) const;
 
  204        offset_type compute_element_index(It first, It last) const;
 
  206        void set_offset(size_type offset);
 
  210        undecay_expression m_e;
 
  211        inner_storage_type m_storage;
 
  212        inner_shape_type m_shape;
 
  213        inner_strides_type m_strides;
 
  214        inner_backstrides_type m_backstrides;
 
 
  226        struct inner_storage_getter
 
  228            using type = 
decltype(std::declval<CT>().storage());
 
  229            using reference = std::add_lvalue_reference_t<CT>;
 
  232            using rebind_t = inner_storage_getter<E>;
 
  234            static decltype(
auto) get_flat_storage(reference e)
 
  239            static auto get_offset(reference e)
 
  241                return e.data_offset();
 
  244            static decltype(
auto) get_strides(reference e)
 
  250        template <
class CT, layout_type L>
 
  251        struct flat_adaptor_getter
 
  253            using type = flat_expression_adaptor<std::remove_reference_t<CT>, L>;
 
  254            using reference = std::add_lvalue_reference_t<CT>;
 
  257            using rebind_t = flat_adaptor_getter<E, L>;
 
  259            static type get_flat_storage(reference e)
 
  262                return type(std::addressof(e));
 
  265            static auto get_offset(reference)
 
  267                return typename std::decay_t<CT>::size_type(0);
 
  270            static auto get_strides(reference e)
 
  272                dynamic_shape<std::ptrdiff_t> strides;
 
  273                strides.resize(e.shape().size());
 
  279        template <
class CT, layout_type L>
 
  280        using flat_storage_getter = std::conditional_t<
 
  282            inner_storage_getter<CT>,
 
  283            flat_adaptor_getter<CT, L>>;
 
  285        template <layout_type L, 
class E>
 
  286        inline auto get_offset(E& e)
 
  288            return flat_storage_getter<E, L>::get_offset(e);
 
  291        template <layout_type L, 
class E>
 
  292        inline decltype(
auto) get_strides(E& e)
 
  294            return flat_storage_getter<E, L>::get_strides(e);
 
  316    template <
class CTA, 
class SA>
 
  324        : m_e(std::forward<CTA>(e))
 
  327        m_storage(storage_getter::get_flat_storage(m_e))
 
  328        , m_shape(std::forward<SA>(
shape))
 
  329        , m_strides(std::move(
strides))
 
  333        m_backstrides = xtl::make_sequence<backstrides_type>(m_shape.size(), 0);
 
  334        adapt_strides(m_shape, m_strides, m_backstrides);
 
 
  339        template <
class T, 
class S>
 
  340        auto& copy_move_storage(T& expr, 
const S& )
 
  342            return expr.storage();
 
  345        template <
class T, 
class E, layout_type L>
 
  346        auto copy_move_storage(T& expr, 
const detail::flat_expression_adaptor<E, L>& storage)
 
  348            detail::flat_expression_adaptor<E, L> new_storage = storage;  
 
  349            new_storage.update_pointer(std::addressof(expr));
 
  356        : base_type(std::move(rhs))
 
  357        , m_e(std::forward<undecay_expression>(rhs.m_e))
 
  358        , m_storage(detail::copy_move_storage(m_e, rhs.m_storage))
 
  359        , m_shape(std::move(rhs.m_shape))
 
  360        , m_strides(std::move(rhs.m_strides))
 
  361        , m_backstrides(std::move(rhs.m_backstrides))
 
  362        , m_offset(std::move(rhs.m_offset))
 
  363        , m_layout(std::move(rhs.m_layout))
 
  371        , m_storage(detail::copy_move_storage(m_e, rhs.m_storage))
 
  372        , m_shape(rhs.m_shape)
 
  373        , m_strides(rhs.m_strides)
 
  374        , m_backstrides(rhs.m_backstrides)
 
  375        , m_offset(rhs.m_offset)
 
  376        , m_layout(rhs.m_layout)
 
  410        return m_backstrides;
 
 
  423    inline bool xstrided_view_base<D>::is_contiguous() const noexcept
 
  435    inline auto xstrided_view_base<D>::operator()() -> reference
 
  437        return m_storage[
static_cast<size_type
>(m_offset)];
 
  441    inline auto xstrided_view_base<D>::operator()() const -> const_reference
 
  443        return m_storage[
static_cast<size_type
>(m_offset)];
 
  453    template <
class... Args>
 
  454    inline auto xstrided_view_base<D>::operator()(Args... args) -> reference
 
  456        XTENSOR_TRY(check_index(
shape(), args...));
 
  457        XTENSOR_CHECK_DIMENSION(
shape(), args...);
 
  458        offset_type index = compute_index(args...);
 
  459        return m_storage[
static_cast<size_type
>(index)];
 
 
  469    template <
class... Args>
 
  470    inline auto xstrided_view_base<D>::operator()(Args... args) 
const -> const_reference
 
  472        XTENSOR_TRY(check_index(
shape(), args...));
 
  473        XTENSOR_CHECK_DIMENSION(
shape(), args...);
 
  474        offset_type index = compute_index(args...);
 
  475        return m_storage[
static_cast<size_type
>(index)];
 
 
  498    template <
class... Args>
 
  499    inline auto xstrided_view_base<D>::unchecked(Args... args) -> reference
 
  501        offset_type index = compute_unchecked_index(args...);
 
  502        return m_storage[
static_cast<size_type
>(index)];
 
 
  525    template <
class... Args>
 
  526    inline auto xstrided_view_base<D>::unchecked(Args... args) 
const -> const_reference
 
  528        offset_type index = compute_unchecked_index(args...);
 
  529        return m_storage[
static_cast<size_type
>(index)];
 
 
  541    inline auto xstrided_view_base<D>::element(It first, It last) -> reference
 
  543        XTENSOR_TRY(check_element_index(
shape(), first, last));
 
  544        return m_storage[
static_cast<size_type
>(compute_element_index(first, last))];
 
 
  556    inline auto xstrided_view_base<D>::element(It first, It last) 
const -> const_reference
 
  558        XTENSOR_TRY(check_element_index(
shape(), first, last));
 
  559        return m_storage[
static_cast<size_type
>(compute_element_index(first, last))];
 
 
  586        requires(provides_data_interface)
 
 
  597        requires(provides_data_interface)
 
 
  645        return xt::broadcast_shape(m_shape, 
shape);
 
 
  658               && std::equal(str.cbegin(), str.cend(), 
strides().begin());
 
 
  664    template <
class... Args>
 
  665    inline auto xstrided_view_base<D>::compute_index(Args... args) 
const -> offset_type
 
  667        return static_cast<offset_type
>(m_offset)
 
  668               + xt::data_offset<offset_type>(
strides(), 
static_cast<offset_type
>(args)...);
 
  672    template <
class... Args>
 
  673    inline auto xstrided_view_base<D>::compute_unchecked_index(Args... args) 
const -> offset_type
 
  675        return static_cast<offset_type
>(m_offset)
 
  676               + xt::unchecked_data_offset<offset_type>(
strides(), 
static_cast<offset_type
>(args)...);
 
  681    inline auto xstrided_view_base<D>::compute_element_index(It first, It last) 
const -> offset_type
 
  683        return static_cast<offset_type
>(m_offset) + xt::element_offset<offset_type>(
strides(), first, last);
 
  687    void xstrided_view_base<D>::set_offset(size_type offset)
 
  698        template <
class CT, layout_type L>
 
  699        inline flat_expression_adaptor<CT, L>::flat_expression_adaptor(CT* e)
 
  702            resize_container(get_index(), m_e->dimension());
 
  703            resize_container(m_strides, m_e->dimension());
 
  704            m_size = compute_size(m_e->shape());
 
  708        template <
class CT, layout_type L>
 
  710        inline flat_expression_adaptor<CT, L>::flat_expression_adaptor(CT* e, FST&& 
strides)
 
  712            , m_strides(xtl::forward_sequence<inner_strides_type, FST>(
strides))
 
  714            resize_container(get_index(), m_e->dimension());
 
  715            m_size = m_e->size();
 
  718        template <
class CT, layout_type L>
 
  719        inline void flat_expression_adaptor<CT, L>::update_pointer(CT* ptr)
 const 
  724        template <
class CT, layout_type L>
 
  725        inline auto flat_expression_adaptor<CT, L>::size() const -> size_type
 
  730        template <
class CT, layout_type L>
 
  731        inline auto flat_expression_adaptor<CT, L>::operator[](size_type idx) -> reference
 
  733            auto i = 
static_cast<typename index_type::value_type
>(idx);
 
  734            get_index() = detail::unravel_noexcept(i, m_strides, L);
 
  735            return m_e->element(get_index().cbegin(), get_index().cend());
 
  738        template <
class CT, layout_type L>
 
  739        inline auto flat_expression_adaptor<CT, L>::operator[](size_type idx) 
const -> const_reference
 
  741            auto i = 
static_cast<typename index_type::value_type
>(idx);
 
  742            get_index() = detail::unravel_noexcept(i, m_strides, L);
 
  743            return m_e->element(get_index().cbegin(), get_index().cend());
 
  746        template <
class CT, layout_type L>
 
  747        inline auto flat_expression_adaptor<CT, L>::begin() -> iterator
 
  749            return m_e->template begin<L>();
 
  752        template <
class CT, layout_type L>
 
  753        inline auto flat_expression_adaptor<CT, L>::end() -> iterator
 
  755            return m_e->template end<L>();
 
  758        template <
class CT, layout_type L>
 
  759        inline auto flat_expression_adaptor<CT, L>::begin() const -> const_iterator
 
  761            return m_e->template cbegin<L>();
 
  764        template <
class CT, layout_type L>
 
  765        inline auto flat_expression_adaptor<CT, L>::end() const -> const_iterator
 
  767            return m_e->template cend<L>();
 
  770        template <
class CT, layout_type L>
 
  771        inline auto flat_expression_adaptor<CT, L>::cbegin() const -> const_iterator
 
  773            return m_e->template cbegin<L>();
 
  776        template <
class CT, layout_type L>
 
  777        inline auto flat_expression_adaptor<CT, L>::cend() const -> const_iterator
 
  779            return m_e->template cend<L>();
 
  782        template <
class CT, layout_type L>
 
  783        inline auto flat_expression_adaptor<CT, L>::get_index() -> index_type&
 
  785            thread_local static index_type index;
 
  797        struct slice_getter_impl
 
  800            mutable std::size_t idx;
 
  801            using array_type = std::array<std::ptrdiff_t, 3>;
 
  803            explicit slice_getter_impl(
const S& shape)
 
  810            array_type operator()(
const T& )
 const 
  812                return array_type{{0, 0, 0}};
 
  815            template <
class A, 
class B, 
class C>
 
  816            array_type operator()(
const xrange_adaptor<A, B, C>& range)
 const 
  818                auto sl = 
range.get(
static_cast<std::size_t
>(m_shape[idx]));
 
  819                return array_type({sl(0), sl.size(), sl.step_size()});
 
  823            array_type operator()(
const xrange<T>& range)
 const 
  825                return array_type({
range(T(0)), 
range.size(), T(1)});
 
  829            array_type operator()(
const xstepped_range<T>& range)
 const 
  835        template <
class adj_str
ides_policy>
 
  836        struct strided_view_args : adj_strides_policy
 
  838            using base_type = adj_strides_policy;
 
  840            template <
class S, 
class ST, 
class V>
 
  842            fill_args(
const S& shape, ST&& old_strides, std::size_t base_offset, layout_type layout, 
const V& slices)
 
  845                std::size_t dimension = shape.size(), n_newaxis = 0, n_add_all = 0;
 
  846                std::ptrdiff_t dimension_check = 
static_cast<std::ptrdiff_t
>(shape.size());
 
  848                bool has_ellipsis = 
false;
 
  849                for (
const auto& el : slices)
 
  851                    if (std::get_if<xt::xnewaxis_tag>(&el) != 
nullptr)
 
  856                    else if (std::get_if<std::ptrdiff_t>(&el) != 
nullptr)
 
  861                    else if (std::get_if<xt::xellipsis_tag>(&el) != 
nullptr)
 
  863                        if (has_ellipsis == 
true)
 
  865                            XTENSOR_THROW(std::runtime_error, 
"Ellipsis can only appear once.");
 
  875                if (dimension_check < 0)
 
  877                    XTENSOR_THROW(std::runtime_error, 
"Too many slices for view.");
 
  884                    n_add_all = shape.size() - (slices.size() - 1 - n_newaxis);
 
  888                new_offset = base_offset;
 
  889                new_shape.resize(dimension);
 
  890                new_strides.resize(dimension);
 
  891                base_type::resize(dimension);
 
  893                auto old_shape = shape;
 
  894                using old_strides_value_type = 
typename std::decay_t<ST>::value_type;
 
  896                std::ptrdiff_t axis_skip = 0;
 
  897                std::size_t idx = 0, i = 0, i_ax = 0;
 
  899                auto slice_getter = detail::slice_getter_impl<S>(shape);
 
  901                for (; i < slices.size(); ++i)
 
  903                    i_ax = 
static_cast<std::size_t
>(
static_cast<std::ptrdiff_t
>(i) - axis_skip);
 
  904                    auto ptr = std::get_if<std::ptrdiff_t>(&slices[i]);
 
  907                        auto slice0 = 
static_cast<old_strides_value_type
>(*ptr);
 
  908                        new_offset += 
static_cast<std::size_t
>(slice0 * old_strides[i_ax]);
 
  910                    else if (std::get_if<xt::xnewaxis_tag>(&slices[i]) != 
nullptr)
 
  913                        base_type::set_fake_slice(idx);
 
  916                    else if (std::get_if<xt::xellipsis_tag>(&slices[i]) != 
nullptr)
 
  918                        for (std::size_t j = 0; j < n_add_all; ++j)
 
  920                            new_shape[idx] = old_shape[i_ax];
 
  921                            new_strides[idx] = old_strides[i_ax];
 
  922                            base_type::set_fake_slice(idx);
 
  925                        axis_skip = axis_skip - 
static_cast<std::ptrdiff_t
>(n_add_all) + 1;
 
  927                    else if (std::get_if<xt::xall_tag>(&slices[i]) != 
nullptr)
 
  929                        new_shape[idx] = old_shape[i_ax];
 
  930                        new_strides[idx] = old_strides[i_ax];
 
  931                        base_type::set_fake_slice(idx);
 
  934                    else if (base_type::fill_args(slices, i, idx, old_shape[i_ax], old_strides[i_ax], new_shape, new_strides))
 
  940                        slice_getter.idx = i_ax;
 
  941                        auto info = std::visit(slice_getter, slices[i]);
 
  942                        new_offset += 
static_cast<std::size_t
>(info[0] * old_strides[i_ax]);
 
  943                        new_shape[idx] = 
static_cast<std::size_t
>(info[1]);
 
  944                        new_strides[idx] = info[2] * old_strides[i_ax];
 
  945                        base_type::set_fake_slice(idx);
 
  950                i_ax = 
static_cast<std::size_t
>(
static_cast<std::ptrdiff_t
>(i) - axis_skip);
 
  951                for (; i_ax < old_shape.size(); ++i_ax, ++idx)
 
  953                    new_shape[idx] = old_shape[i_ax];
 
  954                    new_strides[idx] = old_strides[i_ax];
 
  955                    base_type::set_fake_slice(idx);
 
  958                new_layout = do_strides_match(new_shape, new_strides, layout, 
true) ? layout
 
  959                                                                                    : layout_type::dynamic;
 
  962            using shape_type = dynamic_shape<std::size_t>;
 
  963            shape_type new_shape;
 
  964            using strides_type = dynamic_shape<std::ptrdiff_t>;
 
  965            strides_type new_strides;
 
  966            std::size_t new_offset;
 
layout_type layout() const noexcept
xstrided_view_base(CTA &&e, SA &&shape, strides_type &&strides, size_type offset, layout_type layout) noexcept
Constructs an xstrided_view_base.
bool has_linear_assign(const O &strides) const noexcept
const inner_strides_type & strides() const noexcept
bool broadcast_shape(O &shape, bool reuse_cache=false) const
const inner_backstrides_type & backstrides() const noexcept
const inner_shape_type & shape() const noexcept
size_type data_offset() const noexcept
storage_type & storage() noexcept
xexpression_type & expression() noexcept
std::size_t compute_strides(const shape_type &shape, layout_type l, strides_type &strides)
Compute the strides given the shape and the layout of an array.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
standard mathematical functions for xexpressions
auto range(A start_val, B stop_val)
Select a range from start_val to stop_val (excluded).