xtensor
 
Loading...
Searching...
No Matches
xbroadcast.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_BROADCAST_HPP
11#define XTENSOR_BROADCAST_HPP
12
13#include <algorithm>
14#include <array>
15#include <cstddef>
16#include <iterator>
17#include <numeric>
18#include <type_traits>
19#include <utility>
20
21#include <xtl/xsequence.hpp>
22
23#include "../containers/xscalar.hpp"
24#include "../core/xaccessible.hpp"
25#include "../core/xexpression.hpp"
26#include "../core/xiterable.hpp"
27#include "../core/xstrides.hpp"
28#include "../core/xtensor_config.hpp"
29#include "../utils/xutils.hpp"
30
31namespace xt
32{
33
34 /*************
35 * broadcast *
36 *************/
37
38 template <class E, class S>
39 auto broadcast(E&& e, const S& s);
40
41 template <class E, class I, std::size_t L>
42 auto broadcast(E&& e, const I (&s)[L]);
43
44 /*************************
45 * xbroadcast extensions *
46 *************************/
47
48 namespace extension
49 {
50 template <class Tag, class CT, class X>
52
53 template <class CT, class X>
55 {
56 using type = xtensor_empty_base;
57 };
58
59 template <class CT, class X>
60 struct xbroadcast_base : xbroadcast_base_impl<xexpression_tag_t<CT>, CT, X>
61 {
62 };
63
64 template <class CT, class X>
65 using xbroadcast_base_t = typename xbroadcast_base<CT, X>::type;
66 }
67
68 /**************
69 * xbroadcast *
70 **************/
71
72 template <class CT, class X>
73 class xbroadcast;
74
75 template <class E>
77
78 template <class CT, class X>
80 {
81 using xexpression_type = std::decay_t<CT>;
82 using inner_shape_type = promote_shape_t<typename xexpression_type::shape_type, X>;
83 using const_stepper = typename xexpression_type::const_stepper;
84 using stepper = const_stepper;
85 };
86
87 template <class CT, class X>
89 {
90 using xexpression_type = std::decay_t<CT>;
91 using reference = typename xexpression_type::const_reference;
92 using const_reference = typename xexpression_type::const_reference;
93 using size_type = typename xexpression_type::size_type;
94 };
95
96 /*****************************
97 * linear_begin / linear_end *
98 *****************************/
99
100 template <class CT, class X>
101 XTENSOR_CONSTEXPR_RETURN auto linear_begin(xbroadcast<CT, X>& c) noexcept
102 {
103 return linear_begin(c.expression());
104 }
105
106 template <class CT, class X>
107 XTENSOR_CONSTEXPR_RETURN auto linear_end(xbroadcast<CT, X>& c) noexcept
108 {
109 return linear_end(c.expression());
110 }
111
112 template <class CT, class X>
113 XTENSOR_CONSTEXPR_RETURN auto linear_begin(const xbroadcast<CT, X>& c) noexcept
114 {
115 return linear_begin(c.expression());
116 }
117
118 template <class CT, class X>
119 XTENSOR_CONSTEXPR_RETURN auto linear_end(const xbroadcast<CT, X>& c) noexcept
120 {
121 return linear_end(c.expression());
122 }
123
124 /*************************************
125 * overlapping_memory_checker_traits *
126 *************************************/
127
128 template <xbroadcast_concept E>
131 {
132 static bool check_overlap(const E& expr, const memory_range& dst_range)
133 {
134 if (expr.size() == 0)
135 {
136 return false;
137 }
138 else
139 {
140 using ChildE = std::decay_t<decltype(expr.expression())>;
141 return overlapping_memory_checker_traits<ChildE>::check_overlap(expr.expression(), dst_range);
142 }
143 }
144 };
145
159 template <class CT, class X>
160 class xbroadcast : public xsharable_expression<xbroadcast<CT, X>>,
161 public xconst_iterable<xbroadcast<CT, X>>,
162 public xconst_accessible<xbroadcast<CT, X>>,
163 public extension::xbroadcast_base_t<CT, X>
164 {
165 public:
166
167 using self_type = xbroadcast<CT, X>;
168 using xexpression_type = std::decay_t<CT>;
169 using accessible_base = xconst_accessible<self_type>;
170 using extension_base = extension::xbroadcast_base_t<CT, X>;
171 using expression_tag = typename extension_base::expression_tag;
172
173 using inner_types = xcontainer_inner_types<self_type>;
174 using value_type = typename xexpression_type::value_type;
175 using reference = typename inner_types::reference;
176 using const_reference = typename inner_types::const_reference;
177 using pointer = typename xexpression_type::const_pointer;
178 using const_pointer = typename xexpression_type::const_pointer;
179 using size_type = typename inner_types::size_type;
180 using difference_type = typename xexpression_type::difference_type;
181
182 using iterable_base = xconst_iterable<self_type>;
183 using inner_shape_type = typename iterable_base::inner_shape_type;
184 using shape_type = inner_shape_type;
185
186 using stepper = typename iterable_base::stepper;
187 using const_stepper = typename iterable_base::const_stepper;
188
189 using bool_load_type = typename xexpression_type::bool_load_type;
190
191 static constexpr layout_type static_layout = layout_type::dynamic;
192 static constexpr bool contiguous_layout = false;
193
194 template <class CTA, class S>
195 xbroadcast(CTA&& e, const S& s);
196
197 template <class CTA>
198 xbroadcast(CTA&& e, shape_type&& s);
199
201 const inner_shape_type& shape() const noexcept;
202 layout_type layout() const noexcept;
203 bool is_contiguous() const noexcept;
204 using accessible_base::shape;
205
206 template <class... Args>
207 const_reference operator()(Args... args) const;
208
209 template <class... Args>
210 const_reference unchecked(Args... args) const;
211
212 template <class It>
213 const_reference element(It first, It last) const;
214
215 const xexpression_type& expression() const noexcept;
216
217 template <class S>
218 bool broadcast_shape(S& shape, bool reuse_cache = false) const;
219
220 template <class S>
221 bool has_linear_assign(const S& strides) const noexcept;
222
223 template <class S>
224 const_stepper stepper_begin(const S& shape) const noexcept;
225 template <class S>
226 const_stepper stepper_end(const S& shape, layout_type l) const noexcept;
227
228 template <class E, xscalar_concept XCT = CT>
229 void assign_to(xexpression<E>& e) const;
230
231 template <class E>
232 using rebind_t = xbroadcast<E, X>;
233
234 template <class E>
235 rebind_t<E> build_broadcast(E&& e) const;
236
237 private:
238
239 CT m_e;
240 inner_shape_type m_shape;
241 };
242
243 /****************************
244 * broadcast implementation *
245 ****************************/
246
257 template <class E, class S>
258 inline auto broadcast(E&& e, const S& s)
259 {
260 using shape_type = filter_fixed_shape_t<std::decay_t<S>>;
261 using broadcast_type = xbroadcast<const_xclosure_t<E>, shape_type>;
262 return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(s)>(s));
263 }
264
265 template <class E, class I, std::size_t L>
266 inline auto broadcast(E&& e, const I (&s)[L])
267 {
268 using broadcast_type = xbroadcast<const_xclosure_t<E>, std::array<std::size_t, L>>;
269 using shape_type = typename broadcast_type::shape_type;
270 return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(s)>(s));
271 }
272
273 /*****************************
274 * xbroadcast implementation *
275 *****************************/
276
281
288 template <class CT, class X>
289 template <class CTA, class S>
290 inline xbroadcast<CT, X>::xbroadcast(CTA&& e, const S& s)
291 : m_e(std::forward<CTA>(e))
292 {
293 if (s.size() < m_e.dimension())
294 {
295 XTENSOR_THROW(xt::broadcast_error, "Broadcast shape has fewer elements than original expression.");
296 }
297 xt::resize_container(m_shape, s.size());
298 std::copy(s.begin(), s.end(), m_shape.begin());
299 xt::broadcast_shape(m_e.shape(), m_shape);
300 }
301
309 template <class CT, class X>
310 template <class CTA>
311 inline xbroadcast<CT, X>::xbroadcast(CTA&& e, shape_type&& s)
312 : m_e(std::forward<CTA>(e))
313 , m_shape(std::move(s))
314 {
315 xt::broadcast_shape(m_e.shape(), m_shape);
316 }
317
319
324
327 template <class CT, class X>
328 inline auto xbroadcast<CT, X>::shape() const noexcept -> const inner_shape_type&
329 {
330 return m_shape;
331 }
332
336 template <class CT, class X>
338 {
339 return m_e.layout();
340 }
341
342 template <class CT, class X>
343 inline bool xbroadcast<CT, X>::is_contiguous() const noexcept
344 {
345 return false;
346 }
347
349
354
360 template <class CT, class X>
361 template <class... Args>
362 inline auto xbroadcast<CT, X>::operator()(Args... args) const -> const_reference
363 {
364 return m_e(args...);
365 }
366
386 template <class CT, class X>
387 template <class... Args>
388 inline auto xbroadcast<CT, X>::unchecked(Args... args) const -> const_reference
389 {
390 return this->operator()(args...);
391 }
392
400 template <class CT, class X>
401 template <class It>
402 inline auto xbroadcast<CT, X>::element(It, It last) const -> const_reference
403 {
404 return m_e.element(last - this->dimension(), last);
405 }
406
410 template <class CT, class X>
411 inline auto xbroadcast<CT, X>::expression() const noexcept -> const xexpression_type&
412 {
413 return m_e;
414 }
415
417
422
428 template <class CT, class X>
429 template <class S>
430 inline bool xbroadcast<CT, X>::broadcast_shape(S& shape, bool) const
431 {
432 return xt::broadcast_shape(m_shape, shape);
433 }
434
440 template <class CT, class X>
441 template <class S>
442 inline bool xbroadcast<CT, X>::has_linear_assign(const S& strides) const noexcept
443 {
444 return this->dimension() == m_e.dimension()
445 && std::equal(m_shape.cbegin(), m_shape.cend(), m_e.shape().cbegin())
446 && m_e.has_linear_assign(strides);
447 }
448
450
451 template <class CT, class X>
452 template <class S>
453 inline auto xbroadcast<CT, X>::stepper_begin(const S& shape) const noexcept -> const_stepper
454 {
455 // Could check if (broadcastable(shape, m_shape)
456 return m_e.stepper_begin(shape);
457 }
458
459 template <class CT, class X>
460 template <class S>
461 inline auto xbroadcast<CT, X>::stepper_end(const S& shape, layout_type l) const noexcept -> const_stepper
462 {
463 // Could check if (broadcastable(shape, m_shape)
464 return m_e.stepper_end(shape, l);
465 }
466
467 template <class CT, class X>
468 template <class E, xscalar_concept XCT>
469 inline void xbroadcast<CT, X>::assign_to(xexpression<E>& e) const
470 {
471 auto& ed = e.derived_cast();
472 ed.resize(m_shape);
473 std::fill(ed.begin(), ed.end(), m_e());
474 }
475
476 template <class CT, class X>
477 template <class E>
478 inline auto xbroadcast<CT, X>::build_broadcast(E&& e) const -> rebind_t<E>
479 {
480 return rebind_t<E>(std::forward<E>(e), inner_shape_type(m_shape));
481 }
482}
483
484#endif
Broadcasted xexpression to a specified shape.
const xexpression_type & expression() const noexcept
const inner_shape_type & shape() const noexcept
Returns the shape of the expression.
layout_type layout() const noexcept
bool broadcast_shape(S &shape, bool reuse_cache=false) const
xbroadcast(CTA &&e, shape_type &&s)
Constructs an xbroadcast expression broadcasting the specified xexpression to the given shape.
xbroadcast(CTA &&e, const S &s)
Constructs an xbroadcast expression broadcasting the specified xexpression to the given shape.
bool has_linear_assign(const S &strides) const noexcept
size_type size() const noexcept
size_type dimension() const noexcept
Returns the number of dimensions of the expression.
Base class for multidimensional iterable constant expressions.
Definition xiterable.hpp:37
Base class for xexpressions.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Definition xstrides.hpp:248
standard mathematical functions for xexpressions
auto broadcast(E &&e, const S &s)
Returns an xexpression broadcasting the given expression to a specified shape.
layout_type
Definition xlayout.hpp:24