xtensor
 
Loading...
Searching...
No Matches
xbroadcast.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_BROADCAST_HPP
11#define XTENSOR_BROADCAST_HPP
12
13#include <algorithm>
14#include <array>
15#include <cstddef>
16#include <type_traits>
17#include <utility>
18
19#include <xtl/xsequence.hpp>
20
21#include "../containers/xscalar.hpp"
22#include "../core/xaccessible.hpp"
23#include "../core/xexpression.hpp"
24#include "../core/xiterable.hpp"
25#include "../core/xstrides.hpp"
26#include "../core/xtensor_config.hpp"
27#include "../utils/xutils.hpp"
28
29namespace xt
30{
31
/*************
 * broadcast *
 *************/

/**
 * Returns an xexpression broadcasting the given expression @p e to the
 * shape @p s (declaration; defined further down in this header).
 */
template <class E, class S>
auto broadcast(E&& e, const S& s);

/**
 * Overload of broadcast taking the target shape as a C-style array of
 * @p L extents (declaration; defined further down in this header).
 */
template <class E, class I, std::size_t L>
auto broadcast(E&& e, const I (&s)[L]);
41
42 /*************************
43 * xbroadcast extensions *
44 *************************/
45
46 namespace extension
47 {
48 template <class Tag, class CT, class X>
50
51 template <class CT, class X>
53 {
54 using type = xtensor_empty_base;
55 };
56
57 template <class CT, class X>
58 struct xbroadcast_base : xbroadcast_base_impl<xexpression_tag_t<CT>, CT, X>
59 {
60 };
61
62 template <class CT, class X>
63 using xbroadcast_base_t = typename xbroadcast_base<CT, X>::type;
64 }
65
66 /**************
67 * xbroadcast *
68 **************/
69
70 template <class CT, class X>
71 class xbroadcast;
72
73 template <class E>
75
76 template <class CT, class X>
78 {
79 using xexpression_type = std::decay_t<CT>;
80 using inner_shape_type = promote_shape_t<typename xexpression_type::shape_type, X>;
81 using const_stepper = typename xexpression_type::const_stepper;
82 using stepper = const_stepper;
83 };
84
85 template <class CT, class X>
87 {
88 using xexpression_type = std::decay_t<CT>;
89 using reference = typename xexpression_type::const_reference;
90 using const_reference = typename xexpression_type::const_reference;
91 using size_type = typename xexpression_type::size_type;
92 };
93
94 /*****************************
95 * linear_begin / linear_end *
96 *****************************/
97
98 template <class CT, class X>
99 constexpr auto linear_begin(xbroadcast<CT, X>& c) noexcept
100 {
101 return linear_begin(c.expression());
102 }
103
104 template <class CT, class X>
105 constexpr auto linear_end(xbroadcast<CT, X>& c) noexcept
106 {
107 return linear_end(c.expression());
108 }
109
110 template <class CT, class X>
111 constexpr auto linear_begin(const xbroadcast<CT, X>& c) noexcept
112 {
113 return linear_begin(c.expression());
114 }
115
116 template <class CT, class X>
117 constexpr auto linear_end(const xbroadcast<CT, X>& c) noexcept
118 {
119 return linear_end(c.expression());
120 }
121
122 /*************************************
123 * overlapping_memory_checker_traits *
124 *************************************/
125
126 template <xbroadcast_concept E>
129 {
130 static bool check_overlap(const E& expr, const memory_range& dst_range)
131 {
132 if (expr.size() == 0)
133 {
134 return false;
135 }
136 else
137 {
138 using ChildE = std::decay_t<decltype(expr.expression())>;
139 return overlapping_memory_checker_traits<ChildE>::check_overlap(expr.expression(), dst_range);
140 }
141 }
142 };
143
157 template <class CT, class X>
158 class xbroadcast : public xsharable_expression<xbroadcast<CT, X>>,
159 public xconst_iterable<xbroadcast<CT, X>>,
160 public xconst_accessible<xbroadcast<CT, X>>,
161 public extension::xbroadcast_base_t<CT, X>
162 {
163 public:
164
165 using self_type = xbroadcast<CT, X>;
166 using xexpression_type = std::decay_t<CT>;
167 using accessible_base = xconst_accessible<self_type>;
168 using extension_base = extension::xbroadcast_base_t<CT, X>;
169 using expression_tag = typename extension_base::expression_tag;
170
171 using inner_types = xcontainer_inner_types<self_type>;
172 using value_type = typename xexpression_type::value_type;
173 using reference = typename inner_types::reference;
174 using const_reference = typename inner_types::const_reference;
175 using pointer = typename xexpression_type::const_pointer;
176 using const_pointer = typename xexpression_type::const_pointer;
177 using size_type = typename inner_types::size_type;
178 using difference_type = typename xexpression_type::difference_type;
179
180 using iterable_base = xconst_iterable<self_type>;
181 using inner_shape_type = typename iterable_base::inner_shape_type;
182 using shape_type = inner_shape_type;
183
184 using stepper = typename iterable_base::stepper;
185 using const_stepper = typename iterable_base::const_stepper;
186
187 using bool_load_type = typename xexpression_type::bool_load_type;
188
189 static constexpr layout_type static_layout = layout_type::dynamic;
190 static constexpr bool contiguous_layout = false;
191
192 template <class CTA, class S>
193 xbroadcast(CTA&& e, const S& s);
194
195 template <class CTA>
196 xbroadcast(CTA&& e, shape_type&& s);
197
199 const inner_shape_type& shape() const noexcept;
200 layout_type layout() const noexcept;
201 bool is_contiguous() const noexcept;
202 using accessible_base::shape;
203
204 template <class... Args>
205 const_reference operator()(Args... args) const;
206
207 template <class... Args>
208 const_reference unchecked(Args... args) const;
209
210 template <class It>
211 const_reference element(It first, It last) const;
212
213 const xexpression_type& expression() const noexcept;
214
215 template <class S>
216 bool broadcast_shape(S& shape, bool reuse_cache = false) const;
217
218 template <class S>
219 bool has_linear_assign(const S& strides) const noexcept;
220
221 template <class S>
222 const_stepper stepper_begin(const S& shape) const noexcept;
223 template <class S>
224 const_stepper stepper_end(const S& shape, layout_type l) const noexcept;
225
226 template <class E, xscalar_concept XCT = CT>
227 void assign_to(xexpression<E>& e) const;
228
229 template <class E>
230 using rebind_t = xbroadcast<E, X>;
231
232 template <class E>
233 rebind_t<E> build_broadcast(E&& e) const;
234
235 private:
236
237 CT m_e;
238 inner_shape_type m_shape;
239 };
240
241 /****************************
242 * broadcast implementation *
243 ****************************/
244
255 template <class E, class S>
256 inline auto broadcast(E&& e, const S& s)
257 {
258 using shape_type = filter_fixed_shape_t<std::decay_t<S>>;
259 using broadcast_type = xbroadcast<const_xclosure_t<E>, shape_type>;
260 return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(s)>(s));
261 }
262
263 template <class E, class I, std::size_t L>
264 inline auto broadcast(E&& e, const I (&s)[L])
265 {
266 using broadcast_type = xbroadcast<const_xclosure_t<E>, std::array<std::size_t, L>>;
267 using shape_type = typename broadcast_type::shape_type;
268 return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(s)>(s));
269 }
270
271 /*****************************
272 * xbroadcast implementation *
273 *****************************/
274
279
286 template <class CT, class X>
287 template <class CTA, class S>
288 inline xbroadcast<CT, X>::xbroadcast(CTA&& e, const S& s)
289 : m_e(std::forward<CTA>(e))
290 {
291 if (s.size() < m_e.dimension())
292 {
293 XTENSOR_THROW(xt::broadcast_error, "Broadcast shape has fewer elements than original expression.");
294 }
295 xt::resize_container(m_shape, s.size());
296 std::copy(s.begin(), s.end(), m_shape.begin());
297 xt::broadcast_shape(m_e.shape(), m_shape);
298 }
299
307 template <class CT, class X>
308 template <class CTA>
309 inline xbroadcast<CT, X>::xbroadcast(CTA&& e, shape_type&& s)
310 : m_e(std::forward<CTA>(e))
311 , m_shape(std::move(s))
312 {
313 xt::broadcast_shape(m_e.shape(), m_shape);
314 }
315
317
322
325 template <class CT, class X>
326 inline auto xbroadcast<CT, X>::shape() const noexcept -> const inner_shape_type&
327 {
328 return m_shape;
329 }
330
334 template <class CT, class X>
336 {
337 return m_e.layout();
338 }
339
340 template <class CT, class X>
341 inline bool xbroadcast<CT, X>::is_contiguous() const noexcept
342 {
343 return false;
344 }
345
347
352
358 template <class CT, class X>
359 template <class... Args>
360 inline auto xbroadcast<CT, X>::operator()(Args... args) const -> const_reference
361 {
362 return m_e(args...);
363 }
364
384 template <class CT, class X>
385 template <class... Args>
386 inline auto xbroadcast<CT, X>::unchecked(Args... args) const -> const_reference
387 {
388 return this->operator()(args...);
389 }
390
398 template <class CT, class X>
399 template <class It>
400 inline auto xbroadcast<CT, X>::element(It, It last) const -> const_reference
401 {
402 return m_e.element(last - this->dimension(), last);
403 }
404
408 template <class CT, class X>
409 inline auto xbroadcast<CT, X>::expression() const noexcept -> const xexpression_type&
410 {
411 return m_e;
412 }
413
415
420
426 template <class CT, class X>
427 template <class S>
428 inline bool xbroadcast<CT, X>::broadcast_shape(S& shape, bool) const
429 {
430 return xt::broadcast_shape(m_shape, shape);
431 }
432
438 template <class CT, class X>
439 template <class S>
440 inline bool xbroadcast<CT, X>::has_linear_assign(const S& strides) const noexcept
441 {
442 return this->dimension() == m_e.dimension()
443 && std::equal(m_shape.cbegin(), m_shape.cend(), m_e.shape().cbegin())
444 && m_e.has_linear_assign(strides);
445 }
446
448
449 template <class CT, class X>
450 template <class S>
451 inline auto xbroadcast<CT, X>::stepper_begin(const S& shape) const noexcept -> const_stepper
452 {
453 // Could check if (broadcastable(shape, m_shape)
454 return m_e.stepper_begin(shape);
455 }
456
457 template <class CT, class X>
458 template <class S>
459 inline auto xbroadcast<CT, X>::stepper_end(const S& shape, layout_type l) const noexcept -> const_stepper
460 {
461 // Could check if (broadcastable(shape, m_shape)
462 return m_e.stepper_end(shape, l);
463 }
464
465 template <class CT, class X>
466 template <class E, xscalar_concept XCT>
467 inline void xbroadcast<CT, X>::assign_to(xexpression<E>& e) const
468 {
469 auto& ed = e.derived_cast();
470 ed.resize(m_shape);
471 std::fill(ed.begin(), ed.end(), m_e());
472 }
473
474 template <class CT, class X>
475 template <class E>
476 inline auto xbroadcast<CT, X>::build_broadcast(E&& e) const -> rebind_t<E>
477 {
478 return rebind_t<E>(std::forward<E>(e), inner_shape_type(m_shape));
479 }
480}
481
482#endif
Broadcasted xexpression to a specified shape.
const xexpression_type & expression() const noexcept
const inner_shape_type & shape() const noexcept
Returns the shape of the expression.
layout_type layout() const noexcept
bool broadcast_shape(S &shape, bool reuse_cache=false) const
xbroadcast(CTA &&e, shape_type &&s)
Constructs an xbroadcast expression broadcasting the specified xexpression to the given shape.
xbroadcast(CTA &&e, const S &s)
Constructs an xbroadcast expression broadcasting the specified xexpression to the given shape.
bool has_linear_assign(const S &strides) const noexcept
size_type size() const noexcept(noexcept(derived_cast().shape()))
size_type dimension() const noexcept
Returns the number of dimensions of the expression.
Base class for multidimensional iterable constant expressions.
Definition xiterable.hpp:37
Base class for xexpressions.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Definition xstrides.hpp:250
standard mathematical functions for xexpressions
auto broadcast(E &&e, const S &s)
Returns an xexpression broadcasting the given expression to a specified shape.
layout_type
Definition xlayout.hpp:24