xtensor
Loading...
Searching...
No Matches
xbroadcast.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_BROADCAST_HPP
11#define XTENSOR_BROADCAST_HPP
12
13#include <algorithm>
14#include <array>
15#include <cstddef>
16#include <iterator>
17#include <numeric>
18#include <type_traits>
19#include <utility>
20
21#include <xtl/xsequence.hpp>
22
23#include "xaccessible.hpp"
24#include "xexpression.hpp"
25#include "xiterable.hpp"
26#include "xscalar.hpp"
27#include "xstrides.hpp"
28#include "xtensor_config.hpp"
29#include "xutils.hpp"
30
31namespace xt
32{
33
34 /*************
35 * broadcast *
36 *************/
37
/**
 * @brief Forward declarations of the xt::broadcast factory overloads.
 *
 * The first overload accepts any shape sequence; the second accepts a
 * built-in array of indices whose length is known at compile time.
 */
template <class Expr, class Shape>
auto broadcast(Expr&& e, const Shape& s);

template <class Expr, class Idx, std::size_t N>
auto broadcast(Expr&& e, const Idx (&s)[N]);
43
44 /*************************
45 * xbroadcast extensions *
46 *************************/
47
48 namespace extension
49 {
// Extension hook for xbroadcast: specialized per expression tag (Tag) so that
// xbroadcast_base (below) can expose a nested `type` used as an additional
// base class of xbroadcast.
// NOTE(review): the declaration following this template header (original
// line 51) and the specialization at original lines 54-57 are missing from
// this extract — restore them from upstream before compiling.
50 template <class Tag, class CT, class X>
52
53 template <class CT, class X>
58
// Selects the implementation from the expression tag of the closure type CT.
59 template <class CT, class X>
60 struct xbroadcast_base : xbroadcast_base_impl<xexpression_tag_t<CT>, CT, X>
61 {
62 };
63
// Convenience alias: the extra base class that xbroadcast<CT, X> derives from.
64 template <class CT, class X>
65 using xbroadcast_base_t = typename xbroadcast_base<CT, X>::type;
66 }
67
68 /**************
69 * xbroadcast *
70 **************/
71
// Forward declaration: the inner-types specializations and the free
// linear_begin/linear_end overloads refer to xbroadcast before its definition.
template <typename CT, typename X>
class xbroadcast;
74
// Iteration-related inner types for xbroadcast<CT, X>. Only a const stepper
// is exposed (stepper aliases const_stepper) and it is the stepper of the
// wrapped expression.
// NOTE(review): the specialization header (original line 76, presumably
// `struct xiterable_inner_types<xbroadcast<CT, X>>`) and original line 79
// (presumably the `inner_shape_type` alias that xbroadcast reads through
// iterable_base) are missing from this extract — confirm against upstream.
75 template <class CT, class X>
77 {
78 using xexpression_type = std::decay_t<CT>;
80 using const_stepper = typename xexpression_type::const_stepper;
81 using stepper = const_stepper;
82 };
83
// Access-related inner types for xbroadcast<CT, X>. Both reference and
// const_reference are the wrapped expression's const_reference: a broadcast
// is read-only.
// NOTE(review): the specialization header (original line 85, presumably
// `struct xcontainer_inner_types<xbroadcast<CT, X>>`, which xbroadcast reads
// as `inner_types`) is missing from this extract — confirm against upstream.
84 template <class CT, class X>
86 {
87 using xexpression_type = std::decay_t<CT>;
88 using reference = typename xexpression_type::const_reference;
89 using const_reference = typename xexpression_type::const_reference;
90 using size_type = typename xexpression_type::size_type;
91 };
92
93 /*****************************
94 * linear_begin / linear_end *
95 *****************************/
96
97 template <class CT, class X>
98 XTENSOR_CONSTEXPR_RETURN auto linear_begin(xbroadcast<CT, X>& c) noexcept
99 {
100 return linear_begin(c.expression());
101 }
102
103 template <class CT, class X>
104 XTENSOR_CONSTEXPR_RETURN auto linear_end(xbroadcast<CT, X>& c) noexcept
105 {
106 return linear_end(c.expression());
107 }
108
109 template <class CT, class X>
110 XTENSOR_CONSTEXPR_RETURN auto linear_begin(const xbroadcast<CT, X>& c) noexcept
111 {
112 return linear_begin(c.expression());
113 }
114
115 template <class CT, class X>
116 XTENSOR_CONSTEXPR_RETURN auto linear_end(const xbroadcast<CT, X>& c) noexcept
117 {
118 return linear_end(c.expression());
119 }
120
121 /*************************************
122 * overlapping_memory_checker_traits *
123 *************************************/
124
// Overlap check selected for xbroadcast expressions that have no memory
// address of their own (see the enable_if condition below).
// NOTE(review): the specialization header (original line 126, presumably
// `struct overlapping_memory_checker_traits<`) is missing from this extract.
125 template <class E>
127 E,
128 std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xbroadcast, E>::value>>
129 {
130 static bool check_overlap(const E& expr, const memory_range& dst_range)
131 {
// An empty broadcast has no elements, so no overlap is reported.
132 if (expr.size() == 0)
133 {
134 return false;
135 }
136 else
137 {
// Type of the wrapped expression; presumably used to delegate the check.
// NOTE(review): original line 139 (presumably the delegating `return` that
// forwards expr.expression() and dst_range to the checker for ChildE) is
// missing. As extracted, this branch returns nothing from a bool function,
// which is undefined behavior if reached.
138 using ChildE = std::decay_t<decltype(expr.expression())>;
140 }
141 }
142 };
143
// Lazy, read-only expression that presents the wrapped expression (closure
// CT, stored in m_e) with the shape stored in m_shape. Element access,
// steppers and layout are forwarded to the wrapped expression by the
// out-of-class member definitions.
// NOTE(review): several lines of the class are missing from this extract
// (original lines 165, 167, 171, 180, 198, 202, 226, 230). The names
// `inner_types` and `iterable_base`, the template header introducing `E` for
// assign_to, and `rebind_t` are used below without a visible declaration —
// restore them from upstream xtensor before compiling.
157 template <class CT, class X>
158 class xbroadcast : public xsharable_expression<xbroadcast<CT, X>>,
159 public xconst_iterable<xbroadcast<CT, X>>,
160 public xconst_accessible<xbroadcast<CT, X>>,
161 public extension::xbroadcast_base_t<CT, X>
162 {
163 public:
164
// Decayed type of the wrapped expression closure.
166 using xexpression_type = std::decay_t<CT>;
168 using extension_base = extension::xbroadcast_base_t<CT, X>;
169 using expression_tag = typename extension_base::expression_tag;
170
// Value types come from the wrapped expression; both pointer aliases are the
// wrapped expression's const_pointer because the broadcast is read-only.
172 using value_type = typename xexpression_type::value_type;
173 using reference = typename inner_types::reference;
174 using const_reference = typename inner_types::const_reference;
175 using pointer = typename xexpression_type::const_pointer;
176 using const_pointer = typename xexpression_type::const_pointer;
177 using size_type = typename inner_types::size_type;
178 using difference_type = typename xexpression_type::difference_type;
179
181 using inner_shape_type = typename iterable_base::inner_shape_type;
182 using shape_type = inner_shape_type;
183
184 using stepper = typename iterable_base::stepper;
185 using const_stepper = typename iterable_base::const_stepper;
186
187 using bool_load_type = typename xexpression_type::bool_load_type;
188
// The static layout is always dynamic and never contiguous.
189 static constexpr layout_type static_layout = layout_type::dynamic;
190 static constexpr bool contiguous_layout = false;
191
// Broadcasts `e` to the shape `s` (any sequence; copied into m_shape).
192 template <class CTA, class S>
193 xbroadcast(CTA&& e, const S& s);
194
// Broadcasts `e` to `s`, taking ownership of an already-typed shape.
195 template <class CTA>
196 xbroadcast(CTA&& e, shape_type&& s);
197
199 const inner_shape_type& shape() const noexcept;
200 layout_type layout() const noexcept;
201 bool is_contiguous() const noexcept;
203
204 template <class... Args>
205 const_reference operator()(Args... args) const;
206
207 template <class... Args>
208 const_reference unchecked(Args... args) const;
209
210 template <class It>
211 const_reference element(It first, It last) const;
212
213 const xexpression_type& expression() const noexcept;
214
215 template <class S>
216 bool broadcast_shape(S& shape, bool reuse_cache = false) const;
217
218 template <class S>
219 bool has_linear_assign(const S& strides) const noexcept;
220
221 template <class S>
222 const_stepper stepper_begin(const S& shape) const noexcept;
223 template <class S>
224 const_stepper stepper_end(const S& shape, layout_type l) const noexcept;
225
227 void assign_to(xexpression<E>& e) const;
228
229 template <class E>
231
// Builds a broadcast of another expression `e` to this broadcast's shape.
232 template <class E>
233 rebind_t<E> build_broadcast(E&& e) const;
234
235 private:
236
// Closure on the wrapped expression.
237 CT m_e;
// Shape this expression is broadcast to.
238 inner_shape_type m_shape;
239 };
240
241 /****************************
242 * broadcast implementation *
243 ****************************/
244
// Returns a lazy xbroadcast of `e` to the shape `s`. The stored shape type is
// derived from the decayed S through filter_fixed_shape_t, and `s` is passed
// through xtl::forward_sequence<shape_type, ...>, presumably so that it reaches
// the constructor as a shape_type when a conversion is needed.
// NOTE(review): original line 259, which defines `broadcast_type`, is missing
// from this extract. By analogy with the C-array overload it is presumably
// `xbroadcast<const_xclosure_t<E>, shape_type>` — confirm against upstream.
255 template <class E, class S>
256 inline auto broadcast(E&& e, const S& s)
257 {
258 using shape_type = filter_fixed_shape_t<std::decay_t<S>>;
260 return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(s)>(s));
261 }
262
263 template <class E, class I, std::size_t L>
264 inline auto broadcast(E&& e, const I (&s)[L])
265 {
266 using broadcast_type = xbroadcast<const_xclosure_t<E>, std::array<std::size_t, L>>;
267 using shape_type = typename broadcast_type::shape_type;
268 return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(s)>(s));
269 }
270
271 /*****************************
272 * xbroadcast implementation *
273 *****************************/
274
// Constructs a broadcast of `e` to the shape `s`.
// Raises xt::broadcast_error (via XTENSOR_THROW) when `s` has fewer entries
// than the wrapped expression has dimensions. Otherwise it resizes m_shape to
// s.size(), copies `s` into it, and calls
// xt::broadcast_shape(m_e.shape(), m_shape), which presumably rejects shapes
// that are not broadcast-compatible — confirm in xstrides.hpp.
// NOTE(review): the signature line (original line 288) is missing from this
// extract. Per the in-class declaration it is
// `inline xbroadcast<CT, X>::xbroadcast(CTA&& e, const S& s)`.
286 template <class CT, class X>
287 template <class CTA, class S>
289 : m_e(std::forward<CTA>(e))
290 {
291 if (s.size() < m_e.dimension())
292 {
293 XTENSOR_THROW(xt::broadcast_error, "Broadcast shape has fewer elements than original expression.");
294 }
295 xt::resize_container(m_shape, s.size());
296 std::copy(s.begin(), s.end(), m_shape.begin());
297 xt::broadcast_shape(m_e.shape(), m_shape);
298 }
299
307 template <class CT, class X>
308 template <class CTA>
309 inline xbroadcast<CT, X>::xbroadcast(CTA&& e, shape_type&& s)
310 : m_e(std::forward<CTA>(e))
311 , m_shape(std::move(s))
312 {
313 xt::broadcast_shape(m_e.shape(), m_shape);
314 }
315
317
325 template <class CT, class X>
326 inline auto xbroadcast<CT, X>::shape() const noexcept -> const inner_shape_type&
327 {
328 return m_shape;
329 }
330
// Returns the runtime layout of the wrapped expression. This is distinct from
// the class's static_layout, which is always layout_type::dynamic.
// NOTE(review): the signature line (original line 335) is missing from this
// extract. It is presumably `inline layout_type xbroadcast<CT, X>::layout()
// const noexcept` — confirm against upstream.
334 template <class CT, class X>
336 {
337 return m_e.layout();
338 }
339
// Always reports a non-contiguous expression, matching the class's
// `contiguous_layout = false`.
// NOTE(review): the signature line (original line 341) is missing from this
// extract. By declaration order it is presumably
// `inline bool xbroadcast<CT, X>::is_contiguous() const noexcept`.
340 template <class CT, class X>
342 {
343 return false;
344 }
345
347
358 template <class CT, class X>
359 template <class... Args>
360 inline auto xbroadcast<CT, X>::operator()(Args... args) const -> const_reference
361 {
362 return m_e(args...);
363 }
364
384 template <class CT, class X>
385 template <class... Args>
386 inline auto xbroadcast<CT, X>::unchecked(Args... args) const -> const_reference
387 {
388 return this->operator()(args...);
389 }
390
398 template <class CT, class X>
399 template <class It>
400 inline auto xbroadcast<CT, X>::element(It, It last) const -> const_reference
401 {
402 return m_e.element(last - this->dimension(), last);
403 }
404
408 template <class CT, class X>
409 inline auto xbroadcast<CT, X>::expression() const noexcept -> const xexpression_type&
410 {
411 return m_e;
412 }
413
415
426 template <class CT, class X>
427 template <class S>
428 inline bool xbroadcast<CT, X>::broadcast_shape(S& shape, bool) const
429 {
430 return xt::broadcast_shape(m_shape, shape);
431 }
432
438 template <class CT, class X>
439 template <class S>
440 inline bool xbroadcast<CT, X>::has_linear_assign(const S& strides) const noexcept
441 {
442 return this->dimension() == m_e.dimension()
443 && std::equal(m_shape.cbegin(), m_shape.cend(), m_e.shape().cbegin())
444 && m_e.has_linear_assign(strides);
445 }
446
448
449 template <class CT, class X>
450 template <class S>
451 inline auto xbroadcast<CT, X>::stepper_begin(const S& shape) const noexcept -> const_stepper
452 {
453 // Could check if (broadcastable(shape, m_shape)
454 return m_e.stepper_begin(shape);
455 }
456
457 template <class CT, class X>
458 template <class S>
459 inline auto xbroadcast<CT, X>::stepper_end(const S& shape, layout_type l) const noexcept -> const_stepper
460 {
461 // Could check if (broadcastable(shape, m_shape)
462 return m_e.stepper_end(shape, l);
463 }
464
465 template <class CT, class X>
466 template <class E, class XCT, class>
467 inline void xbroadcast<CT, X>::assign_to(xexpression<E>& e) const
468 {
469 auto& ed = e.derived_cast();
470 ed.resize(m_shape);
471 std::fill(ed.begin(), ed.end(), m_e());
472 }
473
474 template <class CT, class X>
475 template <class E>
476 inline auto xbroadcast<CT, X>::build_broadcast(E&& e) const -> rebind_t<E>
477 {
478 return rebind_t<E>(std::forward<E>(e), inner_shape_type(m_shape));
479 }
480}
481
482#endif
An xexpression broadcast to a specified shape.
const xexpression_type & expression() const noexcept
Returns a constant reference to the underlying expression of the broadcast expression.
const inner_shape_type & shape() const noexcept
Returns the shape of the expression.
layout_type layout() const noexcept
Returns the layout_type of the expression.
bool broadcast_shape(S &shape, bool reuse_cache=false) const
Broadcasts the shape of the expression to the specified parameter.
xbroadcast(CTA &&e, const S &s)
Constructs an xbroadcast expression broadcasting the specified xexpression to the given shape.
bool has_linear_assign(const S &strides) const noexcept
Checks whether the xbroadcast can be linearly assigned to an expression with the specified strides.
Base class for implementation of common expression constant access methods.
size_type size() const noexcept
Returns the size of the expression.
size_type shape(size_type index) const
Returns the i-th dimension of the expression.
Base class for multidimensional iterable constant expressions.
Definition xiterable.hpp:37
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Definition xstrides.hpp:248
standard mathematical functions for xexpressions
auto broadcast(E &&e, const S &s)
Returns an xexpression broadcasting the given expression to a specified shape.
layout_type
Definition xlayout.hpp:24