xtensor
Loading...
Searching...
No Matches
xaxis_slice_iterator.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_AXIS_SLICE_ITERATOR_HPP
11#define XTENSOR_AXIS_SLICE_ITERATOR_HPP
12
13#include "xstrided_view.hpp"
14
15namespace xt
16{
17
/**
 * Class for iteration over one-dimensional slices.
 *
 * NOTE(review): this listing is missing several lines of the class (the class-head on
 * listing line 28, the aliases on lines 32 and 39, and the declarations on lines 50-51).
 * `self_type` and `value_type` are used below but their definitions are not visible here;
 * presumably they and the increment operators live on the missing lines - confirm against
 * the real header.
 */
27 template <class CT>
29 {
30 public:
31
33
34 using xexpression_type = std::decay_t<CT>;
35 using size_type = typename xexpression_type::size_type;
36 using difference_type = typename xexpression_type::difference_type;
37 using shape_type = typename xexpression_type::shape_type;
38 using strides_type = typename xexpression_type::strides_type;
// Dereferencing yields the view type itself (with the cv-qualifiers of CT applied), not a C++ reference.
40 using reference = std::remove_reference_t<apply_cv_t<CT, value_type>>;
41 using pointer = xtl::xclosure_pointer<std::remove_reference_t<apply_cv_t<CT, value_type>>>;
42
43 using iterator_category = std::forward_iterator_tag;
44
// Starts at index 0 with the expression's data offset (delegates to the overload below).
45 template <class CTA>
46 xaxis_slice_iterator(CTA&& e, size_type axis);
// Starts at an explicit index / offset; used by the end-iterator factories.
47 template <class CTA>
48 xaxis_slice_iterator(CTA&& e, size_type axis, size_type index, size_type offset);
49
52
// Returns the strided view at the current iteration position.
53 reference operator*() const;
// Returns a pointer to the strided view at the current iteration position.
54 pointer operator->() const;
55
// Checks equality of the xaxis_slice_iterator and rhs.
56 bool equal(const self_type& rhs) const;
57
58 private:
59
60 using storing_type = xtl::ptr_closure_type_t<CT>;
// The iterated expression, stored either as a pointer or by value (see get_storage_init).
61 mutable storing_type p_expression;
// Position counter; the only positional state compared by equal().
62 size_type m_index;
// Offset forwarded to the strided view through set_offset().
63 size_type m_offset;
// strides[axis] * (shape[axis] - 1), computed in the constructor.
64 size_type m_axis_stride;
// Bounds and modulus used by operator++ to decide when to add m_axis_stride to the offset.
65 size_type m_lower_shape;
66 size_type m_upper_shape;
67 size_type m_iter_size;
// True when `axis` is the last axis (row-major layout) or axis 0 (any other layout).
68 bool m_is_target_axis;
// One-dimensional strided view that operator* hands out.
69 value_type m_sv;
70
// Selected when the storage type T is a pointer: returns the address of e.
71 template <class T, class CTA>
72 std::enable_if_t<std::is_pointer<T>::value, T> get_storage_init(CTA&& e) const;
73
// Selected when the storage type T is not a pointer: returns e itself.
74 template <class T, class CTA>
75 std::enable_if_t<!std::is_pointer<T>::value, T> get_storage_init(CTA&& e) const;
76 };
77
78 template <class CT>
80
81 template <class CT>
83
/*
 * Forward declarations of the iterator factory functions.
 *
 * The names must match the definitions further down in this header
 * (axis_slice_begin / axis_slice_end). They were previously declared with an
 * `x` prefix, which introduced four function templates that are never defined
 * and left the real factories without a forward declaration.
 */

/// Returns an iterator to the first one-dimensional slice of @p e along axis 0.
template <class E>
auto axis_slice_begin(E&& e);

/// Returns an iterator to the first one-dimensional slice of @p e along @p axis.
template <class E>
auto axis_slice_begin(E&& e, typename std::decay_t<E>::size_type axis);

/// Returns the past-the-end iterator matching axis_slice_begin(e).
template <class E>
auto axis_slice_end(E&& e);

/// Returns the past-the-end iterator matching axis_slice_begin(e, axis).
template <class E>
auto axis_slice_end(E&& e, typename std::decay_t<E>::size_type axis);
95
96 /***************************************
97 * xaxis_slice_iterator implementation *
98 ***************************************/
99
// Storage initialiser used when the storage type T is a pointer: the iterator keeps the
// address of the expression rather than a copy.
// NOTE(review): the declarator line (listing line 103) is missing from this extract;
// judging from the sibling overload it is presumably
// `xaxis_slice_iterator<CT>::get_storage_init(CTA&& e) const` - confirm.
100 template <class CT>
101 template <class T, class CTA>
102 inline std::enable_if_t<std::is_pointer<T>::value, T>
104 {
105 return &e;
106 }
107
// Storage initialiser used when the storage type T is not a pointer: the expression
// itself is returned and converted to T by the return statement.
108 template <class CT>
109 template <class T, class CTA>
110 inline std::enable_if_t<!std::is_pointer<T>::value, T>
111 xaxis_slice_iterator<CT>::get_storage_init(CTA&& e) const
112 {
113 return e;
114 }
115
// Constructs an xaxis_slice_iterator.
// Delegates to the four-argument constructor with index 0 and the expression's
// data_offset() as the starting offset.
// NOTE(review): the declarator line (listing line 128) is missing from this extract; the
// in-class declaration is `xaxis_slice_iterator(CTA&& e, size_type axis)`.
126 template <class CT>
127 template <class CTA>
129 : xaxis_slice_iterator(std::forward<CTA>(e), axis, 0, e.data_offset())
130 {
131 }
132
141 template <class CT>
142 template <class CTA>
143 inline xaxis_slice_iterator<CT>::xaxis_slice_iterator(CTA&& e, size_type axis, size_type index, size_type offset)
144 : p_expression(get_storage_init<storing_type>(std::forward<CTA>(e)))
145 , m_index(index)
146 , m_offset(offset)
147 , m_axis_stride(static_cast<size_type>(e.strides()[axis]) * (e.shape()[axis] - 1u))
148 , m_lower_shape(0)
149 , m_upper_shape(0)
150 , m_iter_size(0)
151 , m_is_target_axis(false)
152 , m_sv(strided_view(
153 std::forward<CT>(e),
154 std::forward<shape_type>({e.shape()[axis]}),
155 std::forward<strides_type>({e.strides()[axis]}),
156 offset,
157 e.layout()
158 ))
159 {
160 if (e.layout() == layout_type::row_major)
161 {
162 m_is_target_axis = axis == e.dimension() - 1;
163 m_lower_shape = std::accumulate(
164 e.shape().begin() + axis + 1,
165 e.shape().end(),
166 size_t(1),
167 std::multiplies<>()
168 );
169 m_iter_size = std::accumulate(e.shape().begin() + 1, e.shape().end(), size_t(1), std::multiplies<>());
170 }
171 else
172 {
173 m_is_target_axis = axis == 0;
174 m_lower_shape = std::accumulate(
175 e.shape().begin(),
176 e.shape().begin() + axis,
177 size_t(1),
178 std::multiplies<>()
179 );
180 m_iter_size = std::accumulate(e.shape().begin(), e.shape().end() - 1, size_t(1), std::multiplies<>());
181 }
182 m_upper_shape = m_lower_shape + m_axis_stride;
183 }
184
186
// Increments the iterator to the next position and returns it.
// NOTE(review): the declarator line (listing line 195) is missing from this extract; the
// member index at the end of the page lists it as `self_type & operator++()`.
194 template <class CT>
196 {
// Advance both the position counter and the raw offset by one.
197 ++m_index;
198 ++m_offset;
// Position of the new offset within one block of m_iter_size elements.
199 auto index_compare = (m_offset % m_iter_size);
// Add m_axis_stride when iterating over the target axis, or when the position falls
// inside the closed range [m_lower_shape, m_upper_shape].
200 if (m_is_target_axis || (m_upper_shape >= index_compare && index_compare >= m_lower_shape))
201 {
202 m_offset += m_axis_stride;
203 }
// Re-point the stored view at the new offset.
204 m_sv.set_offset(m_offset);
205 return *this;
206 }
207
// Copies the current iterator, pre-increments *this, and returns the copy.
// NOTE(review): the declarator line (listing line 213) is missing from this extract; the
// body has the shape of a post-increment operator, presumably `operator++(int)` - confirm.
212 template <class CT>
214 {
215 self_type tmp(*this);
216 ++(*this);
217 return tmp;
218 }
219
221
// Returns the strided view at the current iteration position.
// NOTE(review): the declarator line (listing line 232) is missing from this extract; the
// in-class declaration is `reference operator*() const`.
231 template <class CT>
233 {
234 return m_sv;
235 }
236
// Returns a pointer to the strided view at the current iteration position.
// The result of operator*() is wrapped with xtl::closure_pointer.
// NOTE(review): the declarator line (listing line 243) is missing from this extract; the
// in-class declaration is `pointer operator->() const`.
242 template <class CT>
244 {
245 return xtl::closure_pointer(operator*());
246 }
247
249
250 /*
251 * @name Comparisons
252 */
254
// Checks equality of the xaxis_slice_iterator and rhs.
// Two iterators are equal when their stored expressions compare equal and their position
// counters match; the offset is not part of the comparison.
// NOTE(review): the declarator line (listing line 260) is missing from this extract; the
// in-class declaration is `bool equal(const self_type& rhs) const`.
259 template <class CT>
261 {
262 return p_expression == rhs.p_expression && m_index == rhs.m_index;
263 }
264
// Checks equality of the iterators by forwarding to lhs.equal(rhs).
// NOTE(review): the declarator line (listing line 271) is missing from this extract;
// presumably this is operator== taking two xaxis_slice_iterator<CT> arguments - confirm.
270 template <class CT>
272 {
273 return lhs.equal(rhs);
274 }
275
// Checks inequality of the iterators as the negation of lhs == rhs.
// NOTE(review): the declarator line (listing line 281) is missing from this extract;
// presumably this is operator!= taking two xaxis_slice_iterator<CT> arguments - confirm.
280 template <class CT>
282 {
283 return !(lhs == rhs);
284 }
285
287
// Returns an iterator to the first element of the expression for axis 0.
// NOTE(review): listing line 301, which presumably defines `return_type`, is missing
// from this extract - confirm against the real header.
298 template <class E>
299 inline auto axis_slice_begin(E&& e)
300 {
// Two-argument constructor: axis 0, starting index 0, offset taken from e.data_offset().
302 return return_type(std::forward<E>(e), 0);
303 }
304
// Returns an iterator to the first element of the expression for the given axis.
// NOTE(review): listing line 315, which presumably defines `return_type`, is missing
// from this extract - confirm against the real header.
312 template <class E>
313 inline auto axis_slice_begin(E&& e, typename std::decay_t<E>::size_type axis)
314 {
// Starting index 0; the starting offset is the expression's data offset.
316 return return_type(std::forward<E>(e), axis, 0, e.data_offset());
317 }
318
// Returns an iterator to the element following the last element of the expression for axis 0.
// NOTE(review): listing line 329, which presumably defines `return_type`, is missing
// from this extract - confirm against the real header.
326 template <class E>
327 inline auto axis_slice_end(E&& e)
328 {
330 return return_type(
331 std::forward<E>(e),
332 0,
// End index: product of every extent except the first.
333 std::accumulate(e.shape().begin() + 1, e.shape().end(), size_t(1), std::multiplies<>()),
// End offset; iterator equality only looks at the index, not at this value.
334 e.size()
335 );
336 }
337
// Returns an iterator to the element following the last element of the expression for
// the given axis.
// NOTE(review): listing line 349, which presumably defines `return_type`, is missing
// from this extract - confirm against the real header.
346 template <class E>
347 inline auto axis_slice_end(E&& e, typename std::decay_t<E>::size_type axis)
348 {
// Product of the extents that come before `axis`.
350 auto index_sum = std::accumulate(
351 e.shape().begin(),
352 e.shape().begin() + axis,
353 size_t(1),
354 std::multiplies<>()
355 );
356 return return_type(
357 std::forward<E>(e),
358 axis,
// End index: index_sum multiplied by the extents after `axis`, i.e. the product of
// every extent except shape[axis].
359 std::accumulate(e.shape().begin() + axis + 1, e.shape().end(), index_sum, std::multiplies<>()),
// End offset; iterator equality only looks at the index, not at this value.
360 e.size() + axis
361 );
362 }
363
365}
366
367#endif
Class for iteration over one-dimensional slices.
self_type & operator++()
Increments the iterator to the next position and returns it.
reference operator*() const
Returns the strided view at the current iteration position.
pointer operator->() const
Returns a pointer to the strided view at the current iteration position.
bool equal(const self_type &rhs) const
Checks equality of the xaxis_slice_iterator and rhs.
xaxis_slice_iterator(CTA &&e, size_type axis)
Constructs an xaxis_slice_iterator.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Definition xstrides.hpp:248
standard mathematical functions for xexpressions
bool operator==(const xaxis_iterator< CT > &lhs, const xaxis_iterator< CT > &rhs)
Checks equality of the iterators.
bool operator!=(const xaxis_iterator< CT > &lhs, const xaxis_iterator< CT > &rhs)
Checks inequality of the iterators.
auto axis_slice_begin(E &&e)
Returns an iterator to the first element of the expression for axis 0.
auto strided_view(E &&e, S &&shape, X &&stride, std::size_t offset=0, layout_type layout=L) noexcept
Construct a strided view from an xexpression, shape, strides and offset.
auto axis_slice_end(E &&e)
Returns an iterator to the element following the last element of the expression for axis 0.