10#ifndef XTENSOR_AXIS_SLICE_ITERATOR_HPP
11#define XTENSOR_AXIS_SLICE_ITERATOR_HPP
13#include "xstrided_view.hpp"
// Member type aliases of the iterator. The enclosing class header and the
// declaration of value_type are outside this excerpt.
//
// xexpression_type is the decayed form of the closure type CT.
34 using xexpression_type = std::decay_t<CT>;
// size_type, difference_type, shape_type and strides_type are taken
// unchanged from the underlying expression type.
35 using size_type =
typename xexpression_type::size_type;
36 using difference_type =
typename xexpression_type::difference_type;
37 using shape_type =
typename xexpression_type::shape_type;
38 using strides_type =
typename xexpression_type::strides_type;
// `reference` is not a C++ reference type: std::remove_reference_t is applied
// to the result of apply_cv_t<CT, value_type>.
// NOTE(review): apply_cv_t presumably transfers CT's cv-qualification onto
// value_type -- confirm against its definition.
40 using reference = std::remove_reference_t<apply_cv_t<CT, value_type>>;
// `pointer` wraps the same type as `reference` in xtl::xclosure_pointer.
41 using pointer = xtl::xclosure_pointer<std::remove_reference_t<apply_cv_t<CT, value_type>>>;
// The iterator is tagged as a forward iterator.
43 using iterator_category = std::forward_iterator_tag;
// Storage type for the iterated expression, derived from CT through
// xtl::ptr_closure_type_t.
60 using storing_type = xtl::ptr_closure_type_t<CT>;
// Handle to the iterated expression; initialised in the constructor from
// get_storage_init<storing_type>(...) and compared in the equality check.
// Declared mutable, so it may be modified through a const iterator.
61 mutable storing_type p_expression;
// Quantity added to m_offset by the `m_offset += m_axis_stride` statement and
// used to derive m_upper_shape. The condition guarding that addition is not
// visible in this excerpt.
64 size_type m_axis_stride;
// Result of a std::accumulate over part of e.shape(), delimited by `axis`;
// the constructor has two alternative assignments (one starting at
// begin() + axis + 1, one involving begin() + axis).
// NOTE(review): the branch condition selecting between them is not visible
// here -- presumably the expression layout; confirm.
65 size_type m_lower_shape;
// Set to m_lower_shape + m_axis_stride at the end of the constructor.
66 size_type m_upper_shape;
// Product of the shape extents with one end dropped: either shape[1..end)
// or shape[0..end-1), depending on the same constructor branch.
67 size_type m_iter_size;
// Initialised to false, then set to `axis == e.dimension() - 1` in one
// constructor branch and to `axis == 0` in the other.
68 bool m_is_target_axis;
// Helper that produces the value used to initialise p_expression; the
// constructor calls it as get_storage_init<storing_type>(std::forward<CTA>(e)).
// Two overloads are declared and selected via std::enable_if_t on the
// requested storage type T.
//
// This overload takes part in overload resolution only when T is a pointer
// type, and returns T.
71 template <
class T,
class CTA>
72 std::enable_if_t<std::is_pointer<T>::value, T> get_storage_init(
CTA&&
e)
const;
// Counterpart overload: enabled only when T is NOT a pointer type; also
// returns T.
74 template <
class T,
class CTA>
75 std::enable_if_t<!std::is_pointer<T>::value, T> get_storage_init(
CTA&&
e)
const;
85 auto xaxis_slice_begin(E&&
e);
88 auto xaxis_slice_begin(E&&
e,
typename std::decay_t<E>::size_type axis);
91 auto xaxis_slice_end(E&&
e);
94 auto xaxis_slice_end(E&&
e,
typename std::decay_t<E>::size_type axis);
101 template <
class T,
class CTA>
102 inline std::enable_if_t<std::is_pointer<T>::value, T>
109 template <
class T,
class CTA>
110 inline std::enable_if_t<!std::is_pointer<T>::value, T>
111 xaxis_slice_iterator<CT>::get_storage_init(CTA&& e)
const
144 : p_expression(get_storage_init<storing_type>(std::
forward<
CTA>(
e)))
151 , m_is_target_axis(
false)
154 std::
forward<shape_type>({
e.shape()[axis]}),
155 std::forward<strides_type>({
e.strides()[axis]}),
162 m_is_target_axis = axis == e.dimension() - 1;
163 m_lower_shape = std::accumulate(
164 e.shape().begin() + axis + 1,
169 m_iter_size = std::accumulate(e.shape().begin() + 1, e.shape().end(),
size_t(1), std::multiplies<>());
173 m_is_target_axis = axis == 0;
174 m_lower_shape = std::accumulate(
176 e.shape().begin() + axis,
180 m_iter_size = std::accumulate(e.shape().begin(), e.shape().end() - 1,
size_t(1), std::multiplies<>());
182 m_upper_shape = m_lower_shape + m_axis_stride;
202 m_offset += m_axis_stride;
204 m_sv.set_offset(m_offset);
245 return xtl::closure_pointer(
operator*());
262 return p_expression ==
rhs.p_expression && m_index ==
rhs.m_index;
302 return return_type(std::forward<E>(
e), 0);
316 return return_type(std::forward<E>(
e), axis, 0,
e.data_offset());
333 std::accumulate(
e.shape().begin() + 1,
e.shape().end(),
size_t(1), std::multiplies<>()),
352 e.shape().begin() + axis,
359 std::accumulate(
e.shape().begin() + axis + 1,
e.shape().end(),
index_sum, std::multiplies<>()),
Class for iteration over one-dimensional slices.
self_type & operator++()
Increments the iterator to the next position and returns it.
reference operator*() const
Returns the strided view at the current iteration position.
pointer operator->() const
Returns a pointer to the strided view at the current iteration position.
bool equal(const self_type &rhs) const
Checks equality of the xaxis_slice_iterator and rhs.
xaxis_slice_iterator(CTA &&e, size_type axis)
Constructs an xaxis_slice_iterator.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Standard mathematical functions for xexpressions.
bool operator==(const xaxis_iterator< CT > &lhs, const xaxis_iterator< CT > &rhs)
Checks equality of the iterators.
bool operator!=(const xaxis_iterator< CT > &lhs, const xaxis_iterator< CT > &rhs)
Checks inequality of the iterators.
auto axis_slice_begin(E &&e)
Returns an iterator to the first element of the expression for axis 0.
auto strided_view(E &&e, S &&shape, X &&stride, std::size_t offset=0, layout_type layout=L) noexcept
Construct a strided view from an xexpression, shape, strides and offset.
auto axis_slice_end(E &&e)
Returns an iterator to the element following the last element of the expression for axis 0.