xtensor
 
Loading...
Searching...
No Matches
xaxis_iterator.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_AXIS_ITERATOR_HPP
11#define XTENSOR_AXIS_ITERATOR_HPP
12
13#include "../views/xstrided_view.hpp"
14
15namespace xt
16{
17
18 /******************
19 * xaxis_iterator *
20 ******************/
21
32 template <class CT>
34 {
35 public:
36
37 using self_type = xaxis_iterator<CT>;
38
39 using xexpression_type = std::decay_t<CT>;
40 using size_type = typename xexpression_type::size_type;
41 using difference_type = typename xexpression_type::difference_type;
42 using shape_type = typename xexpression_type::shape_type;
43 using value_type = xstrided_view<CT, shape_type>;
44 using reference = std::remove_reference_t<xtl::apply_cv_t<CT, value_type>>;
45 using pointer = xtl::xclosure_pointer<std::remove_reference_t<xtl::apply_cv_t<CT, value_type>>>;
46
47 using iterator_category = std::forward_iterator_tag;
48
49 template <class CTA>
50 xaxis_iterator(CTA&& e, size_type axis);
51 template <class CTA>
52 xaxis_iterator(CTA&& e, size_type axis, size_type index, size_type offset);
53
54 self_type& operator++();
55 self_type operator++(int);
56
57 reference operator*() const;
58 pointer operator->() const;
59
60 bool equal(const self_type& rhs) const;
61
62 private:
63
64 using storing_type = xtl::ptr_closure_type_t<CT>;
65 mutable storing_type p_expression;
66 size_type m_index;
67 size_type m_add_offset;
68 value_type m_sv;
69
70 template <class T, class CTA>
71 T get_storage_init(CTA&& e) const;
72 };
73
74 template <class CT>
75 bool operator==(const xaxis_iterator<CT>& lhs, const xaxis_iterator<CT>& rhs);
76
77 template <class CT>
78 bool operator!=(const xaxis_iterator<CT>& lhs, const xaxis_iterator<CT>& rhs);
79
80 template <class E>
81 auto axis_begin(E&& e);
82
83 template <class E>
84 auto axis_begin(E&& e, typename std::decay_t<E>::size_type axis);
85
86 template <class E>
87 auto axis_end(E&& e);
88
89 template <class E>
90 auto axis_end(E&& e, typename std::decay_t<E>::size_type axis);
91
92 /*********************************
93 * xaxis_iterator implementation *
94 *********************************/
95
namespace detail
{
    /**
     * Builds the strided view on one slice of `e`: the shape and strides of
     * `e` with dimension `axis` removed, starting at data offset `offset`
     * and using the layout of `e`.
     */
    template <class CT>
    auto derive_xstrided_view(
        CT&& e,
        typename std::decay_t<CT>::size_type axis,
        typename std::decay_t<CT>::size_type offset
    )
    {
        using expr_t = std::decay_t<CT>;

        // Copies every element of `src` except the one at position `axis` into `dst`.
        auto skip_axis = [axis](const auto& src, auto& dst)
        {
            auto head_end = src.cbegin() + axis;
            auto out = std::copy(src.cbegin(), head_end, dst.begin());
            std::copy(head_end + 1, src.cend(), out);
        };

        const auto& src_shape = e.shape();
        typename expr_t::shape_type reduced_shape(src_shape.size() - 1);
        skip_axis(src_shape, reduced_shape);

        const auto& src_strides = e.strides();
        typename expr_t::strides_type reduced_strides(src_strides.size() - 1);
        skip_axis(src_strides, reduced_strides);

        return strided_view(
            std::forward<CT>(e),
            std::move(reduced_shape),
            std::move(reduced_strides),
            offset,
            e.layout()
        );
    }
}
122
123 template <class CT>
124 template <class T, class CTA>
125 inline T xaxis_iterator<CT>::get_storage_init(CTA&& e) const
126 {
127 if constexpr (xtl::pointer_concept<T>)
128 {
129 return &e;
130 }
131 else
132 {
133 return e;
134 }
135 }
136
141
/**
 * Constructs an xaxis_iterator positioned on the first slice along `axis`.
 *
 * Delegates to the four-argument constructor with index 0 and the data
 * offset of `e` as the starting offset of the slice view.
 *
 * @param e the expression to iterate over
 * @param axis the axis along which the expression is sliced
 */
template <class CT>
template <class CTA>
inline xaxis_iterator<CT>::xaxis_iterator(CTA&& e, size_type axis)
    : xaxis_iterator(std::forward<CTA>(e), axis, 0, e.data_offset())
{
}

/**
 * Constructs an xaxis_iterator at a given position.
 *
 * @param e the expression to iterate over
 * @param axis the axis along which the expression is sliced
 * @param index the position of the iterator along `axis`
 * @param offset the data offset of the view on the current slice
 */
template <class CT>
template <class CTA>
inline xaxis_iterator<CT>::xaxis_iterator(CTA&& e, size_type axis, size_type index, size_type offset)
    // Members are initialized in declaration order. `e` is stored (by address
    // or by copy) and its strides are read BEFORE it is forwarded into the
    // strided view, which may consume it when `e` is an rvalue; keep this order.
    : p_expression(get_storage_init<storing_type>(std::forward<CTA>(e)))
    , m_index(index)
    , m_add_offset(static_cast<size_type>(e.strides()[axis]))
    , m_sv(detail::derive_xstrided_view<CTA>(std::forward<CTA>(e), axis, offset))
{
}
171
173
178
181 template <class CT>
182 inline auto xaxis_iterator<CT>::operator++() -> self_type&
183 {
184 m_sv.set_offset(m_sv.data_offset() + m_add_offset);
185 ++m_index;
186 return *this;
187 }
188
193 template <class CT>
194 inline auto xaxis_iterator<CT>::operator++(int) -> self_type
195 {
196 self_type tmp(*this);
197 ++(*this);
198 return tmp;
199 }
200
202
207
212 template <class CT>
213 inline auto xaxis_iterator<CT>::operator*() const -> reference
214 {
215 return m_sv;
216 }
217
223 template <class CT>
224 inline auto xaxis_iterator<CT>::operator->() const -> pointer
225 {
226 return xtl::closure_pointer(operator*());
227 }
228
230
231 /*
232 * @name Comparisons
233 */
235
241 template <class CT>
242 inline bool xaxis_iterator<CT>::equal(const self_type& rhs) const
243 {
244 return p_expression == rhs.p_expression && m_index == rhs.m_index
245 && m_sv.data_offset() == rhs.m_sv.data_offset();
246 }
247
253 template <class CT>
254 inline bool operator==(const xaxis_iterator<CT>& lhs, const xaxis_iterator<CT>& rhs)
255 {
256 return lhs.equal(rhs);
257 }
258
263 template <class CT>
264 inline bool operator!=(const xaxis_iterator<CT>& lhs, const xaxis_iterator<CT>& rhs)
265 {
266 return !(lhs == rhs);
267 }
268
270
275
281 template <class E>
282 inline auto axis_begin(E&& e)
283 {
284 using return_type = xaxis_iterator<xtl::closure_type_t<E>>;
285 return return_type(std::forward<E>(e), 0);
286 }
287
295 template <class E>
296 inline auto axis_begin(E&& e, typename std::decay_t<E>::size_type axis)
297 {
298 using return_type = xaxis_iterator<xtl::closure_type_t<E>>;
299 return return_type(std::forward<E>(e), axis);
300 }
301
309 template <class E>
310 inline auto axis_end(E&& e)
311 {
312 using size_type = typename std::decay_t<E>::size_type;
313 using return_type = xaxis_iterator<xtl::closure_type_t<E>>;
314 return return_type(
315 std::forward<E>(e),
316 0,
317 e.shape()[0],
318 static_cast<size_type>(e.strides()[0]) * e.shape()[0]
319 );
320 }
321
330 template <class E>
331 inline auto axis_end(E&& e, typename std::decay_t<E>::size_type axis)
332 {
333 using size_type = typename std::decay_t<E>::size_type;
334 using return_type = xaxis_iterator<xtl::closure_type_t<E>>;
335 return return_type(
336 std::forward<E>(e),
337 axis,
338 e.shape()[axis],
339 static_cast<size_type>(e.strides()[axis]) * e.shape()[axis]
340 );
341 }
342
344}
345
346#endif
Class for iteration over (N-1)-dimensional slices, where N is the dimension of the underlying expression.
reference operator*() const
Returns the strided view at the current iteration position.
pointer operator->() const
Returns a pointer to the strided view at the current iteration position.
xaxis_iterator(CTA &&e, size_type axis)
Constructs an xaxis_iterator.
self_type & operator++()
Increments the iterator to the next position and returns it.
bool equal(const self_type &rhs) const
Checks equality of the xaxis_iterator and rhs.
View of an xexpression using strides.
size_type data_offset() const noexcept
Returns the offset to the first element in the view.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Definition xstrides.hpp:248
standard mathematical functions for xexpressions
auto axis_begin(E &&e)
Returns an iterator to the first element of the expression for axis 0.
auto axis_end(E &&e)
Returns an iterator to the element following the last element of the expression for axis 0.
auto strided_view(E &&e, S &&shape, X &&stride, std::size_t offset=0, layout_type layout=L) noexcept
Construct a strided view from an xexpression, shape, strides and offset.