xtensor
Loading...
Searching...
No Matches
xaxis_iterator.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_AXIS_ITERATOR_HPP
11#define XTENSOR_AXIS_ITERATOR_HPP
12
#include <memory>

#include "xstrided_view.hpp"
14
15namespace xt
16{
17
18 /******************
19 * xaxis_iterator *
20 ******************/
21
// Forward iterator over the (N-1)-dimensional slices of an expression along
// one axis. Each position is represented by a strided view (m_sv) from which
// dimension `axis` has been removed; advancing the iterator shifts that view's
// data offset by the expression's stride along `axis`.
// NOTE(review): several lines (class head, self_type, value_type and the two
// operator++ declarations) are missing from this extract — confirm against upstream.
32 template <class CT>
34 {
35 public:
36
38
39 using xexpression_type = std::decay_t<CT>;
40 using size_type = typename xexpression_type::size_type;
41 using difference_type = typename xexpression_type::difference_type;
42 using shape_type = typename xexpression_type::shape_type;
// `reference` has its reference stripped: operator* hands out the (possibly
// cv-qualified) view by value, not a reference to the stored m_sv.
44 using reference = std::remove_reference_t<apply_cv_t<CT, value_type>>;
45 using pointer = xtl::xclosure_pointer<std::remove_reference_t<apply_cv_t<CT, value_type>>>;
46
47 using iterator_category = std::forward_iterator_tag;
48
// Iterator at index 0 along `axis`, whose view starts at e.data_offset().
49 template <class CTA>
50 xaxis_iterator(CTA&& e, size_type axis);
// Iterator at position `index` along `axis`, whose view starts at `offset`.
51 template <class CTA>
52 xaxis_iterator(CTA&& e, size_type axis, size_type index, size_type offset);
53
56
57 reference operator*() const;
58 pointer operator->() const;
59
// True when both iterators hold the same expression, index and view offset.
60 bool equal(const self_type& rhs) const;
61
62 private:
63
// Either a pointer to the expression or a copy of it, depending on which
// get_storage_init overload applies (pointer vs. non-pointer storing_type).
64 using storing_type = xtl::ptr_closure_type_t<CT>;
65 mutable storing_type p_expression;
// Number of increments performed from the starting position.
66 size_type m_index;
// Offset added to the view at each increment: the expression's stride along `axis`.
67 size_type m_add_offset;
// The (N-1)-dimensional view for the current position.
68 value_type m_sv;
69
// Storage initializers selected by whether storing_type is a pointer.
70 template <class T, class CTA>
71 std::enable_if_t<std::is_pointer<T>::value, T> get_storage_init(CTA&& e) const;
72
73 template <class T, class CTA>
74 std::enable_if_t<!std::is_pointer<T>::value, T> get_storage_init(CTA&& e) const;
75 };
76
// NOTE(review): the declarations following these two template heads are
// missing from this extract (presumably operator== and operator!= for
// xaxis_iterator) — confirm against upstream.
77 template <class CT>
79
80 template <class CT>
82
// Factory functions returning xaxis_iterator objects over `e`.
// The overloads without an `axis` argument iterate along axis 0.
83 template <class E>
84 auto axis_begin(E&& e);
85
86 template <class E>
87 auto axis_begin(E&& e, typename std::decay_t<E>::size_type axis);
88
89 template <class E>
90 auto axis_end(E&& e);
91
92 template <class E>
93 auto axis_end(E&& e, typename std::decay_t<E>::size_type axis);
94
95 /*********************************
96 * xaxis_iterator implementation *
97 *********************************/
98
namespace detail
{
    /**
     * Builds the (N-1)-dimensional strided view of `e` obtained by removing
     * dimension `axis` from both its shape and its strides.
     *
     * @param e      expression to view (forwarded into strided_view)
     * @param axis   index of the dimension to drop
     * @param offset data offset at which the returned view starts
     * @return strided_view(e, reduced shape, reduced strides, offset, e.layout())
     */
    template <class CT>
    auto derive_xstrided_view(
        CT&& e,
        typename std::decay_t<CT>::size_type axis,
        typename std::decay_t<CT>::size_type offset
    )
    {
        using expression_t = std::decay_t<CT>;

        // Copies every element of `src` except the one at position `axis`
        // into `dst`, which must already hold src.size() - 1 elements.
        auto copy_without_axis = [axis](const auto& src, auto& dst)
        {
            auto tail = std::copy(src.cbegin(), src.cbegin() + axis, dst.begin());
            std::copy(src.cbegin() + axis + 1, src.cend(), tail);
        };

        const auto& full_shape = e.shape();
        typename expression_t::shape_type reduced_shape(full_shape.size() - 1);
        copy_without_axis(full_shape, reduced_shape);

        const auto& full_strides = e.strides();
        typename expression_t::strides_type reduced_strides(full_strides.size() - 1);
        copy_without_axis(full_strides, reduced_strides);

        return strided_view(
            std::forward<CT>(e),
            std::move(reduced_shape),
            std::move(reduced_strides),
            offset,
            e.layout()
        );
    }
}
125
126 template <class CT>
127 template <class T, class CTA>
128 inline std::enable_if_t<std::is_pointer<T>::value, T> xaxis_iterator<CT>::get_storage_init(CTA&& e) const
129 {
130 return &e;
131 }
132
133 template <class CT>
134 template <class T, class CTA>
135 inline std::enable_if_t<!std::is_pointer<T>::value, T> xaxis_iterator<CT>::get_storage_init(CTA&& e) const
136 {
137 return e;
138 }
139
150 template <class CT>
151 template <class CTA>
152 inline xaxis_iterator<CT>::xaxis_iterator(CTA&& e, size_type axis)
153 : xaxis_iterator(std::forward<CTA>(e), axis, 0, e.data_offset())
154 {
155 }
156
// Constructs an iterator at position `index` along `axis`, whose slice view
// starts at data offset `offset`.
//   e      expression to iterate over
//   axis   axis along which slices are taken
//   index  initial position counter (compared in equal())
//   offset data offset of the initial slice view
// Note: `e` is forwarded twice (into get_storage_init and into
// derive_xstrided_view) and e.strides() is read in between. This relies on
// get_storage_init leaving `e` intact, i.e. copying rather than moving it.
// Member initialization order follows the declaration order in the class.
165 template <class CT>
166 template <class CTA>
167 inline xaxis_iterator<CT>::xaxis_iterator(CTA&& e, size_type axis, size_type index, size_type offset)
168 : p_expression(get_storage_init<storing_type>(std::forward<CTA>(e)))
169 , m_index(index)
// Step between consecutive slices: the expression's stride along `axis`.
170 , m_add_offset(static_cast<size_type>(e.strides()[axis]))
// View over `e` with dimension `axis` removed, starting at `offset`.
171 , m_sv(detail::derive_xstrided_view<CTA>(std::forward<CTA>(e), axis, offset))
172 {
173 }
174
176
// Prefix increment: moves to the next slice by shifting the view's data offset
// by m_add_offset (the stride along the iteration axis) and bumping the
// position index. Returns the advanced iterator.
// NOTE(review): the signature line of this definition is missing from this extract.
184 template <class CT>
186 {
187 m_sv.set_offset(m_sv.data_offset() + m_add_offset);
188 ++m_index;
189 return *this;
190 }
191
// Postfix increment: advances the iterator and returns a copy of its state
// from before the increment.
// NOTE(review): the signature line of this definition is missing from this extract.
196 template <class CT>
198 {
199 self_type tmp(*this);
200 ++(*this);
201 return tmp;
202 }
203
205
215 template <class CT>
216 inline auto xaxis_iterator<CT>::operator*() const -> reference
217 {
218 return m_sv;
219 }
220
226 template <class CT>
227 inline auto xaxis_iterator<CT>::operator->() const -> pointer
228 {
229 return xtl::closure_pointer(operator*());
230 }
231
233
234 /*
235 * @name Comparisons
236 */
238
244 template <class CT>
245 inline bool xaxis_iterator<CT>::equal(const self_type& rhs) const
246 {
247 return p_expression == rhs.p_expression && m_index == rhs.m_index
248 && m_sv.data_offset() == rhs.m_sv.data_offset();
249 }
250
// Equality of two xaxis_iterators: delegates to lhs.equal(rhs).
// NOTE(review): the signature line of this definition is missing from this extract.
256 template <class CT>
258 {
259 return lhs.equal(rhs);
260 }
261
// Inequality of two xaxis_iterators: the negation of operator==.
// NOTE(review): the signature line of this definition is missing from this extract.
266 template <class CT>
268 {
269 return !(lhs == rhs);
270 }
271
273
284 template <class E>
285 inline auto axis_begin(E&& e)
286 {
287 using return_type = xaxis_iterator<xtl::closure_type_t<E>>;
288 return return_type(std::forward<E>(e), 0);
289 }
290
298 template <class E>
299 inline auto axis_begin(E&& e, typename std::decay_t<E>::size_type axis)
300 {
301 using return_type = xaxis_iterator<xtl::closure_type_t<E>>;
302 return return_type(std::forward<E>(e), axis);
303 }
304
312 template <class E>
313 inline auto axis_end(E&& e)
314 {
315 using size_type = typename std::decay_t<E>::size_type;
316 using return_type = xaxis_iterator<xtl::closure_type_t<E>>;
317 return return_type(
318 std::forward<E>(e),
319 0,
320 e.shape()[0],
321 static_cast<size_type>(e.strides()[0]) * e.shape()[0]
322 );
323 }
324
333 template <class E>
334 inline auto axis_end(E&& e, typename std::decay_t<E>::size_type axis)
335 {
336 using size_type = typename std::decay_t<E>::size_type;
337 using return_type = xaxis_iterator<xtl::closure_type_t<E>>;
338 return return_type(
339 std::forward<E>(e),
340 axis,
341 e.shape()[axis],
342 static_cast<size_type>(e.strides()[axis]) * e.shape()[axis]
343 );
344 }
345
347}
348
349#endif
Class for iteration over (N-1)-dimensional slices, where N is the dimension of the underlying expression.
reference operator*() const
Returns the strided view at the current iteration position.
pointer operator->() const
Returns a pointer to the strided view at the current iteration position.
xaxis_iterator(CTA &&e, size_type axis)
Constructs an xaxis_iterator.
self_type & operator++()
Increments the iterator to the next position and returns it.
bool equal(const self_type &rhs) const
Checks equality of the xaxis_iterator and rhs.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Definition xstrides.hpp:248
standard mathematical functions for xexpressions
auto axis_begin(E &&e)
Returns an iterator to the first element of the expression for axis 0.
bool operator==(const xaxis_iterator< CT > &lhs, const xaxis_iterator< CT > &rhs)
Checks equality of the iterators.
auto axis_end(E &&e)
Returns an iterator to the element following the last element of the expression for axis 0.
bool operator!=(const xaxis_iterator< CT > &lhs, const xaxis_iterator< CT > &rhs)
Checks inequality of the iterators.
auto strided_view(E &&e, S &&shape, X &&stride, std::size_t offset=0, layout_type layout=L) noexcept
Construct a strided view from an xexpression, shape, strides and offset.