10#ifndef XTENSOR_BROADCAST_HPP
11#define XTENSOR_BROADCAST_HPP
21#include <xtl/xsequence.hpp>
23#include "../containers/xscalar.hpp"
24#include "../core/xaccessible.hpp"
25#include "../core/xexpression.hpp"
26#include "../core/xiterable.hpp"
27#include "../core/xstrides.hpp"
28#include "../core/xtensor_config.hpp"
29#include "../utils/xutils.hpp"
38 template <
class E,
class S>
41 template <
class E,
class I, std::
size_t L>
50 template <
class Tag,
class CT,
class X>
53 template <
class CT,
class X>
59 template <
class CT,
class X>
64 template <
class CT,
class X>
72 template <
class CT,
class X>
78 template <
class CT,
class X>
81 using xexpression_type = std::decay_t<CT>;
82 using inner_shape_type = promote_shape_t<typename xexpression_type::shape_type, X>;
83 using const_stepper =
typename xexpression_type::const_stepper;
84 using stepper = const_stepper;
87 template <
class CT,
class X>
90 using xexpression_type = std::decay_t<CT>;
91 using reference =
typename xexpression_type::const_reference;
92 using const_reference =
typename xexpression_type::const_reference;
93 using size_type =
typename xexpression_type::size_type;
100 template <
class CT,
class X>
103 return linear_begin(c.expression());
106 template <
class CT,
class X>
107 XTENSOR_CONSTEXPR_RETURN
auto linear_end(xbroadcast<CT, X>& c)
noexcept
109 return linear_end(c.expression());
112 template <
class CT,
class X>
115 return linear_begin(c.expression());
118 template <
class CT,
class X>
121 return linear_end(c.expression());
128 template <xbroadcast_concept E>
132 static bool check_overlap(
const E& expr,
const memory_range& dst_range)
134 if (expr.size() == 0)
140 using ChildE = std::decay_t<
decltype(expr.expression())>;
141 return overlapping_memory_checker_traits<ChildE>::check_overlap(expr.expression(), dst_range);
159 template <
class CT,
class X>
160 class xbroadcast :
public xsharable_expression<xbroadcast<CT, X>>,
162 public xconst_accessible<xbroadcast<CT, X>>,
163 public extension::xbroadcast_base_t<CT, X>
168 using xexpression_type = std::decay_t<CT>;
169 using accessible_base = xconst_accessible<self_type>;
170 using extension_base = extension::xbroadcast_base_t<CT, X>;
171 using expression_tag =
typename extension_base::expression_tag;
174 using value_type =
typename xexpression_type::value_type;
175 using reference =
typename inner_types::reference;
176 using const_reference =
typename inner_types::const_reference;
177 using pointer =
typename xexpression_type::const_pointer;
178 using const_pointer =
typename xexpression_type::const_pointer;
179 using size_type =
typename inner_types::size_type;
180 using difference_type =
typename xexpression_type::difference_type;
183 using inner_shape_type =
typename iterable_base::inner_shape_type;
184 using shape_type = inner_shape_type;
186 using stepper =
typename iterable_base::stepper;
187 using const_stepper =
typename iterable_base::const_stepper;
189 using bool_load_type =
typename xexpression_type::bool_load_type;
192 static constexpr bool contiguous_layout =
false;
194 template <
class CTA,
class S>
201 const inner_shape_type&
shape() const noexcept;
203 bool is_contiguous() const noexcept;
204 using accessible_base::
shape;
206 template <class... Args>
207 const_reference operator()(Args... args) const;
209 template <class... Args>
210 const_reference unchecked(Args... args) const;
213 const_reference element(It first, It last) const;
224 const_stepper stepper_begin(const S&
shape) const noexcept;
235 rebind_t<E> build_broadcast(E&& e) const;
240 inner_shape_type m_shape;
257 template <class E, class S>
260 using shape_type = filter_fixed_shape_t<std::decay_t<S>>;
262 return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type,
decltype(s)>(s));
265 template <
class E,
class I, std::
size_t L>
266 inline auto broadcast(E&& e,
const I (&s)[L])
269 using shape_type =
typename broadcast_type::shape_type;
270 return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type,
decltype(s)>(s));
288 template <
class CT,
class X>
289 template <
class CTA,
class S>
291 : m_e(std::forward<CTA>(e))
293 if (s.size() < m_e.dimension())
295 XTENSOR_THROW(xt::broadcast_error,
"Broadcast shape has fewer elements than original expression.");
297 xt::resize_container(m_shape, s.size());
298 std::copy(s.begin(), s.end(), m_shape.begin());
299 xt::broadcast_shape(m_e.shape(), m_shape);
309 template <
class CT,
class X>
312 : m_e(std::forward<CTA>(e))
313 , m_shape(std::move(s))
315 xt::broadcast_shape(m_e.shape(), m_shape);
327 template <
class CT,
class X>
336 template <
class CT,
class X>
342 template <
class CT,
class X>
343 inline bool xbroadcast<CT, X>::is_contiguous() const noexcept
360 template <
class CT,
class X>
361 template <
class... Args>
362 inline auto xbroadcast<CT, X>::operator()(Args... args)
const -> const_reference
386 template <
class CT,
class X>
387 template <
class... Args>
388 inline auto xbroadcast<CT, X>::unchecked(Args... args)
const -> const_reference
390 return this->operator()(args...);
400 template <
class CT,
class X>
402 inline auto xbroadcast<CT, X>::element(It, It last)
const -> const_reference
404 return m_e.element(last - this->
dimension(), last);
410 template <
class CT,
class X>
428 template <
class CT,
class X>
432 return xt::broadcast_shape(m_shape,
shape);
440 template <
class CT,
class X>
444 return this->
dimension() == m_e.dimension()
445 && std::equal(m_shape.cbegin(), m_shape.cend(), m_e.shape().cbegin())
446 && m_e.has_linear_assign(
strides);
451 template <
class CT,
class X>
453 inline auto xbroadcast<CT, X>::stepper_begin(
const S& shape)
const noexcept -> const_stepper
456 return m_e.stepper_begin(shape);
459 template <
class CT,
class X>
461 inline auto xbroadcast<CT, X>::stepper_end(
const S& shape,
layout_type l)
const noexcept -> const_stepper
464 return m_e.stepper_end(shape, l);
467 template <
class CT,
class X>
468 template <
class E, xscalar_concept XCT>
471 auto& ed = e.derived_cast();
473 std::fill(ed.begin(), ed.end(), m_e());
476 template <
class CT,
class X>
478 inline auto xbroadcast<CT, X>::build_broadcast(E&& e)
const -> rebind_t<E>
480 return rebind_t<E>(std::forward<E>(e), inner_shape_type(m_shape));
Broadcasted xexpression to a specified shape.
const xexpression_type & expression() const noexcept
const inner_shape_type & shape() const noexcept
Returns the shape of the expression.
layout_type layout() const noexcept
bool broadcast_shape(S &shape, bool reuse_cache=false) const
xbroadcast(CTA &&e, shape_type &&s)
Constructs an xbroadcast expression broadcasting the specified xexpression to the given shape.
xbroadcast(CTA &&e, const S &s)
Constructs an xbroadcast expression broadcasting the specified xexpression to the given shape.
bool has_linear_assign(const S &strides) const noexcept
size_type size() const noexcept
size_type dimension() const noexcept
Returns the number of dimensions of the expression.
Base class for multidimensional iterable constant expressions.
Base class for xexpressions.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
standard mathematical functions for xexpressions
auto broadcast(E &&e, const S &s)
Returns an xexpression broadcasting the given expression to a specified shape.