xtensor
 
Loading...
Searching...
No Matches
xaxis_slice_iterator.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_AXIS_SLICE_ITERATOR_HPP
11#define XTENSOR_AXIS_SLICE_ITERATOR_HPP
12
13#include "../views/xstrided_view.hpp"
14
15namespace xt
16{
17
27 template <class CT>
29 {
30 public:
31
32 using self_type = xaxis_slice_iterator<CT>;
33
34 using xexpression_type = std::decay_t<CT>;
35 using size_type = typename xexpression_type::size_type;
36 using difference_type = typename xexpression_type::difference_type;
37 using shape_type = typename xexpression_type::shape_type;
38 using strides_type = typename xexpression_type::strides_type;
39 using value_type = xstrided_view<CT, shape_type>;
40 using reference = std::remove_reference_t<xtl::apply_cv_t<CT, value_type>>;
41 using pointer = xtl::xclosure_pointer<std::remove_reference_t<xtl::apply_cv_t<CT, value_type>>>;
42
43 using iterator_category = std::forward_iterator_tag;
44
45 template <class CTA>
46 xaxis_slice_iterator(CTA&& e, size_type axis);
47 template <class CTA>
48 xaxis_slice_iterator(CTA&& e, size_type axis, size_type index, size_type offset);
49
50 self_type& operator++();
51 self_type operator++(int);
52
53 reference operator*() const;
54 pointer operator->() const;
55
56 bool equal(const self_type& rhs) const;
57
58 private:
59
60 using storing_type = xtl::ptr_closure_type_t<CT>;
61 mutable storing_type p_expression;
62 size_type m_index;
63 size_type m_offset;
64 size_type m_axis_stride;
65 size_type m_lower_shape;
66 size_type m_upper_shape;
67 size_type m_iter_size;
68 bool m_is_target_axis;
69 value_type m_sv;
70
71 template <class T, class CTA>
72 T get_storage_init(CTA&& e) const;
73 };
74
75 template <class CT>
76 bool operator==(const xaxis_slice_iterator<CT>& lhs, const xaxis_slice_iterator<CT>& rhs);
77
78 template <class CT>
79 bool operator!=(const xaxis_slice_iterator<CT>& lhs, const xaxis_slice_iterator<CT>& rhs);
80
81 template <class E>
82 auto xaxis_slice_begin(E&& e);
83
84 template <class E>
85 auto xaxis_slice_begin(E&& e, typename std::decay_t<E>::size_type axis);
86
87 template <class E>
88 auto xaxis_slice_end(E&& e);
89
90 template <class E>
91 auto xaxis_slice_end(E&& e, typename std::decay_t<E>::size_type axis);
92
93 /***************************************
94 * xaxis_slice_iterator implementation *
95 ***************************************/
96
97 template <class CT>
98 template <class T, class CTA>
99 T xaxis_slice_iterator<CT>::get_storage_init(CTA&& e) const
100 {
101 if constexpr (xtl::pointer_concept<T>)
102 {
103 return &e;
104 }
105 else
106 {
107 return e;
108 }
109 }
110
115
121 template <class CT>
122 template <class CTA>
124 : xaxis_slice_iterator(std::forward<CTA>(e), axis, 0, e.data_offset())
125 {
126 }
127
136 template <class CT>
137 template <class CTA>
138 inline xaxis_slice_iterator<CT>::xaxis_slice_iterator(CTA&& e, size_type axis, size_type index, size_type offset)
139 : p_expression(get_storage_init<storing_type>(std::forward<CTA>(e)))
140 , m_index(index)
141 , m_offset(offset)
142 , m_axis_stride(static_cast<size_type>(e.strides()[axis]) * (e.shape()[axis] - 1u))
143 , m_lower_shape(0)
144 , m_upper_shape(0)
145 , m_iter_size(0)
146 , m_is_target_axis(false)
147 , m_sv(strided_view(
148 std::forward<CT>(e),
149 std::forward<shape_type>({e.shape()[axis]}),
150 std::forward<strides_type>({e.strides()[axis]}),
151 offset,
152 e.layout()
153 ))
154 {
155 if (e.layout() == layout_type::row_major)
156 {
157 m_is_target_axis = axis == e.dimension() - 1;
158 m_lower_shape = std::accumulate(
159 e.shape().begin() + axis + 1,
160 e.shape().end(),
161 size_t(1),
162 std::multiplies<>()
163 );
164 m_iter_size = std::accumulate(e.shape().begin() + 1, e.shape().end(), size_t(1), std::multiplies<>());
165 }
166 else
167 {
168 m_is_target_axis = axis == 0;
169 m_lower_shape = std::accumulate(
170 e.shape().begin(),
171 e.shape().begin() + axis,
172 size_t(1),
173 std::multiplies<>()
174 );
175 m_iter_size = std::accumulate(e.shape().begin(), e.shape().end() - 1, size_t(1), std::multiplies<>());
176 }
177 m_upper_shape = m_lower_shape + m_axis_stride;
178 }
179
181
186
189 template <class CT>
190 inline auto xaxis_slice_iterator<CT>::operator++() -> self_type&
191 {
192 ++m_index;
193 ++m_offset;
194 auto index_compare = (m_offset % m_iter_size);
195 if (m_is_target_axis || (m_upper_shape >= index_compare && index_compare >= m_lower_shape))
196 {
197 m_offset += m_axis_stride;
198 }
199 m_sv.set_offset(m_offset);
200 return *this;
201 }
202
207 template <class CT>
208 inline auto xaxis_slice_iterator<CT>::operator++(int) -> self_type
209 {
210 self_type tmp(*this);
211 ++(*this);
212 return tmp;
213 }
214
216
221
226 template <class CT>
227 inline auto xaxis_slice_iterator<CT>::operator*() const -> reference
228 {
229 return m_sv;
230 }
231
237 template <class CT>
238 inline auto xaxis_slice_iterator<CT>::operator->() const -> pointer
239 {
240 return xtl::closure_pointer(operator*());
241 }
242
244
245 /*
246 * @name Comparisons
247 */
249
254 template <class CT>
255 inline bool xaxis_slice_iterator<CT>::equal(const self_type& rhs) const
256 {
257 return p_expression == rhs.p_expression && m_index == rhs.m_index;
258 }
259
265 template <class CT>
266 inline bool operator==(const xaxis_slice_iterator<CT>& lhs, const xaxis_slice_iterator<CT>& rhs)
267 {
268 return lhs.equal(rhs);
269 }
270
275 template <class CT>
276 inline bool operator!=(const xaxis_slice_iterator<CT>& lhs, const xaxis_slice_iterator<CT>& rhs)
277 {
278 return !(lhs == rhs);
279 }
280
282
287
293 template <class E>
294 inline auto axis_slice_begin(E&& e)
295 {
297 return return_type(std::forward<E>(e), 0);
298 }
299
307 template <class E>
308 inline auto axis_slice_begin(E&& e, typename std::decay_t<E>::size_type axis)
309 {
311 return return_type(std::forward<E>(e), axis, 0, e.data_offset());
312 }
313
321 template <class E>
322 inline auto axis_slice_end(E&& e)
323 {
325 return return_type(
326 std::forward<E>(e),
327 0,
328 std::accumulate(e.shape().begin() + 1, e.shape().end(), size_t(1), std::multiplies<>()),
329 e.size()
330 );
331 }
332
341 template <class E>
342 inline auto axis_slice_end(E&& e, typename std::decay_t<E>::size_type axis)
343 {
345 auto index_sum = std::accumulate(
346 e.shape().begin(),
347 e.shape().begin() + axis,
348 size_t(1),
349 std::multiplies<>()
350 );
351 return return_type(
352 std::forward<E>(e),
353 axis,
354 std::accumulate(e.shape().begin() + axis + 1, e.shape().end(), index_sum, std::multiplies<>()),
355 e.size() + axis
356 );
357 }
358
360}
361
362#endif
Class for iteration over one-dimensional slices.
self_type & operator++()
Increments the iterator to the next position and returns it.
reference operator*() const
Returns the strided view at the current iteration position.
pointer operator->() const
Returns a pointer to the strided view at the current iteration position.
bool equal(const self_type &rhs) const
Checks equality of the xaxis_slice_iterator and rhs.
xaxis_slice_iterator(CTA &&e, size_type axis)
Constructs an xaxis_slice_iterator.
View of an xexpression using strides.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Definition xstrides.hpp:248
standard mathematical functions for xexpressions
auto axis_slice_begin(E &&e)
Returns an iterator to the first element of the expression for axis 0.
auto strided_view(E &&e, S &&shape, X &&stride, std::size_t offset=0, layout_type layout=L) noexcept
Construct a strided view from an xexpression, shape, strides and offset.
auto axis_slice_end(E &&e)
Returns an iterator to the element following the last element of the expression for axis 0.