xtensor — chunked-assignment machinery (xchunked_assign.hpp).
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_CHUNKED_ASSIGN_HPP
11#define XTENSOR_CHUNKED_ASSIGN_HPP
12
13#include "xnoalias.hpp"
14#include "xstrided_view.hpp"
15
16namespace xt
17{
18
19 /*******************
20 * xchunk_assigner *
21 *******************/
22
23 template <class T, class chunk_storage>
25 {
26 public:
27
28 using temporary_type = T;
29
30 template <class E, class DST>
31 void build_and_assign_temporary(const xexpression<E>& e, DST& dst);
32 };
33
34 /*********************************
35 * xchunked_semantic declaration *
36 *********************************/
37
38 template <class D>
40 {
41 public:
42
44 using derived_type = D;
45 using temporary_type = typename base_type::temporary_type;
46
47 template <class E>
48 derived_type& assign_xexpression(const xexpression<E>& e);
49
50 template <class E>
51 derived_type& computed_assign(const xexpression<E>& e);
52
53 template <class E, class F>
54 derived_type& scalar_computed_assign(const E& e, F&& f);
55
56 protected:
57
58 xchunked_semantic() = default;
59 ~xchunked_semantic() = default;
60
61 xchunked_semantic(const xchunked_semantic&) = default;
62 xchunked_semantic& operator=(const xchunked_semantic&) = default;
63
65 xchunked_semantic& operator=(xchunked_semantic&&) = default;
66
67 template <class E>
68 derived_type& operator=(const xexpression<E>& e);
69
70 private:
71
72 template <class CS>
73 xchunked_assigner<temporary_type, CS> get_assigner(const CS&) const;
74 };
75
76 /*******************
77 * xchunk_iterator *
78 *******************/
79
80 template <class CS>
81 class xchunked_array;
82
83 template <class E>
84 class xchunked_view;
85
86 namespace detail
87 {
88 template <class T>
89 struct is_xchunked_array : std::false_type
90 {
91 };
92
93 template <class CS>
94 struct is_xchunked_array<xchunked_array<CS>> : std::true_type
95 {
96 };
97
98 template <class T>
99 struct is_xchunked_view : std::false_type
100 {
101 };
102
103 template <class E>
104 struct is_xchunked_view<xchunked_view<E>> : std::true_type
105 {
106 };
107
108 struct invalid_chunk_iterator
109 {
110 };
111
112 template <class A>
113 struct xchunk_iterator_array
114 {
115 using reference = decltype(*(std::declval<A>().chunks().begin()));
116
117 inline decltype(auto) get_chunk(A& arr, typename A::size_type i, const xstrided_slice_vector&) const
118 {
119 using difference_type = typename A::difference_type;
120 return *(arr.chunks().begin() + static_cast<difference_type>(i));
121 }
122 };
123
124 template <class V>
125 struct xchunk_iterator_view
126 {
127 using reference = decltype(xt::strided_view(
128 std::declval<V>().expression(),
129 std::declval<xstrided_slice_vector>()
130 ));
131
132 inline auto get_chunk(V& view, typename V::size_type, const xstrided_slice_vector& sv) const
133 {
134 return xt::strided_view(view.expression(), sv);
135 }
136 };
137
138 template <class T>
139 struct xchunk_iterator_base
140 : std::conditional_t<
141 is_xchunked_array<std::decay_t<T>>::value,
142 xchunk_iterator_array<T>,
143 std::conditional_t<is_xchunked_view<std::decay_t<T>>::value, xchunk_iterator_view<T>, invalid_chunk_iterator>>
144 {
145 };
146 }
147
148 template <class E>
149 class xchunk_iterator : private detail::xchunk_iterator_base<E>
150 {
151 public:
152
153 using base_type = detail::xchunk_iterator_base<E>;
155 using size_type = typename E::size_type;
156 using shape_type = typename E::shape_type;
157 using slice_vector = xstrided_slice_vector;
158
159 using reference = typename base_type::reference;
160 using value_type = std::remove_reference_t<reference>;
161 using pointer = value_type*;
162 using difference_type = typename E::difference_type;
163 using iterator_category = std::forward_iterator_tag;
164
165
166 xchunk_iterator() = default;
167 xchunk_iterator(E& chunked_expression, shape_type&& chunk_index, size_type chunk_linear_index);
168
169 self_type& operator++();
170 self_type operator++(int);
171 decltype(auto) operator*() const;
172
173 bool operator==(const self_type& rhs) const;
174 bool operator!=(const self_type& rhs) const;
175
176 const shape_type& chunk_index() const;
177
178 const slice_vector& get_slice_vector() const;
179 slice_vector get_chunk_slice_vector() const;
180
181 private:
182
183 void fill_slice_vector(size_type index);
184
185 E* p_chunked_expression;
186 shape_type m_chunk_index;
187 size_type m_chunk_linear_index;
188 xstrided_slice_vector m_slice_vector;
189 };
190
191 /************************************
192 * xchunked_semantic implementation *
193 ************************************/
194
195 template <class T, class CS>
196 template <class E, class DST>
198 {
199 temporary_type tmp(e, CS(), dst.chunk_shape());
200 dst = std::move(tmp);
201 }
202
203 template <class D>
204 template <class E>
205 inline auto xchunked_semantic<D>::assign_xexpression(const xexpression<E>& e) -> derived_type&
206 {
207 auto& d = this->derived_cast();
208 const auto& chunk_shape = d.chunk_shape();
209 size_t i = 0;
210 auto it_end = d.chunk_end();
211 for (auto it = d.chunk_begin(); it != it_end; ++it, ++i)
212 {
213 auto rhs = strided_view(e.derived_cast(), it.get_slice_vector());
214 if (rhs.shape() != chunk_shape)
215 {
216 noalias(strided_view(*it, it.get_chunk_slice_vector())) = rhs;
217 }
218 else
219 {
220 noalias(*it) = rhs;
221 }
222 }
223
224 return this->derived_cast();
225 }
226
227 template <class D>
228 template <class E>
229 inline auto xchunked_semantic<D>::computed_assign(const xexpression<E>& e) -> derived_type&
230 {
231 D& d = this->derived_cast();
232 if (e.derived_cast().dimension() > d.dimension() || e.derived_cast().shape() > d.shape())
233 {
234 return operator=(e);
235 }
236 else
237 {
238 return assign_xexpression(e);
239 }
240 }
241
242 template <class D>
243 template <class E, class F>
244 inline auto xchunked_semantic<D>::scalar_computed_assign(const E& e, F&& f) -> derived_type&
245 {
246 for (auto& c : this->derived_cast().chunks())
247 {
248 c.scalar_computed_assign(e, f);
249 }
250 return this->derived_cast();
251 }
252
253 template <class D>
254 template <class E>
255 inline auto xchunked_semantic<D>::operator=(const xexpression<E>& e) -> derived_type&
256 {
257 D& d = this->derived_cast();
258 get_assigner(d.chunks()).build_and_assign_temporary(e, d);
259 return d;
260 }
261
262 template <class D>
263 template <class CS>
264 inline auto xchunked_semantic<D>::get_assigner(const CS&) const -> xchunked_assigner<temporary_type, CS>
265 {
266 return xchunked_assigner<temporary_type, CS>();
267 }
268
269 /**********************************
270 * xchunk_iterator implementation *
271 **********************************/
272
273 template <class E>
274 inline xchunk_iterator<E>::xchunk_iterator(E& expression, shape_type&& chunk_index, size_type chunk_linear_index)
275 : p_chunked_expression(&expression)
276 , m_chunk_index(std::move(chunk_index))
277 , m_chunk_linear_index(chunk_linear_index)
278 , m_slice_vector(m_chunk_index.size())
279 {
280 for (size_type i = 0; i < m_chunk_index.size(); ++i)
281 {
282 fill_slice_vector(i);
283 }
284 }
285
286 template <class E>
287 inline xchunk_iterator<E>& xchunk_iterator<E>::operator++()
288 {
289 if (m_chunk_linear_index + 1u != p_chunked_expression->grid_size())
290 {
291 size_type i = p_chunked_expression->dimension();
292 while (i != 0)
293 {
294 --i;
295 if (m_chunk_index[i] + 1u == p_chunked_expression->grid_shape()[i])
296 {
297 m_chunk_index[i] = 0;
298 fill_slice_vector(i);
299 }
300 else
301 {
302 m_chunk_index[i] += 1;
303 fill_slice_vector(i);
304 break;
305 }
306 }
307 }
308 m_chunk_linear_index++;
309 return *this;
310 }
311
312 template <class E>
313 inline xchunk_iterator<E> xchunk_iterator<E>::operator++(int)
314 {
315 xchunk_iterator<E> it = *this;
316 ++(*this);
317 return it;
318 }
319
320 template <class E>
321 inline decltype(auto) xchunk_iterator<E>::operator*() const
322 {
323 return base_type::get_chunk(*p_chunked_expression, m_chunk_linear_index, m_slice_vector);
324 }
325
326 template <class E>
327 inline bool xchunk_iterator<E>::operator==(const xchunk_iterator& other) const
328 {
329 return m_chunk_linear_index == other.m_chunk_linear_index;
330 }
331
332 template <class E>
333 inline bool xchunk_iterator<E>::operator!=(const xchunk_iterator& other) const
334 {
335 return !(*this == other);
336 }
337
338 template <class E>
339 inline auto xchunk_iterator<E>::get_slice_vector() const -> const slice_vector&
340 {
341 return m_slice_vector;
342 }
343
344 template <class E>
345 auto xchunk_iterator<E>::chunk_index() const -> const shape_type&
346 {
347 return m_chunk_index;
348 }
349
350 template <class E>
351 inline auto xchunk_iterator<E>::get_chunk_slice_vector() const -> slice_vector
352 {
353 slice_vector slices(m_chunk_index.size());
354 for (size_type i = 0; i < m_chunk_index.size(); ++i)
355 {
356 size_type chunk_shape = p_chunked_expression->chunk_shape()[i];
357 size_type end = std::min(
358 chunk_shape,
359 p_chunked_expression->shape()[i] - m_chunk_index[i] * chunk_shape
360 );
361 slices[i] = range(0u, end);
362 }
363 return slices;
364 }
365
366 template <class E>
367 inline void xchunk_iterator<E>::fill_slice_vector(size_type i)
368 {
369 size_type range_start = m_chunk_index[i] * p_chunked_expression->chunk_shape()[i];
370 size_type range_end = std::min(
371 (m_chunk_index[i] + 1) * p_chunked_expression->chunk_shape()[i],
372 p_chunked_expression->shape()[i]
373 );
374 m_slice_vector[i] = range(range_start, range_end);
375 }
376}
377
378#endif
Referenced symbols (from the documentation tooltips):
- xsemantic_base: base interface for assignable xexpressions (xsemantic.hpp:58).
- xt::range(start_val, stop_val): selects the half-open range [start_val, stop_val) (xslice.hpp:818).
- xstrided_slice_vector: std::vector<xstrided_slice<std::ptrdiff_t>>, a vector of slices used to build an xstrided_view.
- xt::strided_view(e, shape, stride, offset, layout): constructs a strided view from an xexpression, shape, strides and offset.
- xt::view(e, slices...): constructs and returns a view on the specified xexpression (xview.hpp:1834).