10#ifndef XTENSOR_BROADCAST_HPP
11#define XTENSOR_BROADCAST_HPP
21#include <xtl/xsequence.hpp>
23#include "xaccessible.hpp"
24#include "xexpression.hpp"
25#include "xiterable.hpp"
27#include "xstrides.hpp"
28#include "xtensor_config.hpp"
38 template <
class E,
class S>
41 template <
class E,
class I, std::
size_t L>
50 template <
class Tag,
class CT,
class X>
53 template <
class CT,
class X>
59 template <
class CT,
class X>
64 template <
class CT,
class X>
72 template <
class CT,
class X>
75 template <
class CT,
class X>
78 using xexpression_type = std::decay_t<CT>;
79 using inner_shape_type = promote_shape_t<typename xexpression_type::shape_type, X>;
80 using const_stepper =
typename xexpression_type::const_stepper;
81 using stepper = const_stepper;
84 template <
class CT,
class X>
87 using xexpression_type = std::decay_t<CT>;
88 using reference =
typename xexpression_type::const_reference;
89 using const_reference =
typename xexpression_type::const_reference;
90 using size_type =
typename xexpression_type::size_type;
97 template <
class CT,
class X>
100 return linear_begin(c.expression());
103 template <
class CT,
class X>
104 XTENSOR_CONSTEXPR_RETURN
auto linear_end(xbroadcast<CT, X>& c)
noexcept
106 return linear_end(c.expression());
109 template <
class CT,
class X>
112 return linear_begin(c.expression());
115 template <
class CT,
class X>
118 return linear_end(c.expression());
128 std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xbroadcast, E>::value>>
130 static bool check_overlap(
const E& expr,
const memory_range& dst_range)
132 if (expr.size() == 0)
138 using ChildE = std::decay_t<
decltype(expr.expression())>;
139 return overlapping_memory_checker_traits<ChildE>::check_overlap(expr.expression(), dst_range);
157 template <
class CT,
class X>
158 class xbroadcast :
public xsharable_expression<xbroadcast<CT, X>>,
160 public xconst_accessible<xbroadcast<CT, X>>,
161 public extension::xbroadcast_base_t<CT, X>
166 using xexpression_type = std::decay_t<CT>;
167 using accessible_base = xconst_accessible<self_type>;
168 using extension_base = extension::xbroadcast_base_t<CT, X>;
169 using expression_tag =
typename extension_base::expression_tag;
172 using value_type =
typename xexpression_type::value_type;
173 using reference =
typename inner_types::reference;
174 using const_reference =
typename inner_types::const_reference;
175 using pointer =
typename xexpression_type::const_pointer;
176 using const_pointer =
typename xexpression_type::const_pointer;
177 using size_type =
typename inner_types::size_type;
178 using difference_type =
typename xexpression_type::difference_type;
181 using inner_shape_type =
typename iterable_base::inner_shape_type;
182 using shape_type = inner_shape_type;
184 using stepper =
typename iterable_base::stepper;
185 using const_stepper =
typename iterable_base::const_stepper;
187 using bool_load_type =
typename xexpression_type::bool_load_type;
190 static constexpr bool contiguous_layout =
false;
192 template <
class CTA,
class S>
199 const inner_shape_type&
shape() const noexcept;
201 bool is_contiguous() const noexcept;
202 using accessible_base::
shape;
204 template <class... Args>
205 const_reference operator()(Args... args) const;
207 template <class... Args>
208 const_reference unchecked(Args... args) const;
211 const_reference element(It first, It last) const;
222 const_stepper stepper_begin(const S&
shape) const noexcept;
226 template <class E, class XCT = CT, class = std::enable_if_t<
xt::is_xscalar<XCT>::value>>
233 rebind_t<E> build_broadcast(E&& e) const;
238 inner_shape_type m_shape;
255 template <class E, class S>
258 using shape_type = filter_fixed_shape_t<std::decay_t<S>>;
260 return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type,
decltype(s)>(s));
263 template <
class E,
class I, std::
size_t L>
264 inline auto broadcast(E&& e,
const I (&s)[L])
267 using shape_type =
typename broadcast_type::shape_type;
268 return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type,
decltype(s)>(s));
286 template <
class CT,
class X>
287 template <
class CTA,
class S>
289 : m_e(std::forward<CTA>(e))
291 if (s.size() < m_e.dimension())
293 XTENSOR_THROW(xt::broadcast_error,
"Broadcast shape has fewer elements than original expression.");
295 xt::resize_container(m_shape, s.size());
296 std::copy(s.begin(), s.end(), m_shape.begin());
297 xt::broadcast_shape(m_e.shape(), m_shape);
307 template <
class CT,
class X>
310 : m_e(std::forward<CTA>(e))
311 , m_shape(std::move(s))
313 xt::broadcast_shape(m_e.shape(), m_shape);
325 template <
class CT,
class X>
334 template <
class CT,
class X>
340 template <
class CT,
class X>
341 inline bool xbroadcast<CT, X>::is_contiguous() const noexcept
358 template <
class CT,
class X>
359 template <
class... Args>
360 inline auto xbroadcast<CT, X>::operator()(Args... args)
const -> const_reference
384 template <
class CT,
class X>
385 template <
class... Args>
386 inline auto xbroadcast<CT, X>::unchecked(Args... args)
const -> const_reference
388 return this->operator()(args...);
398 template <
class CT,
class X>
400 inline auto xbroadcast<CT, X>::element(It, It last)
const -> const_reference
402 return m_e.element(last - this->
dimension(), last);
408 template <
class CT,
class X>
426 template <
class CT,
class X>
430 return xt::broadcast_shape(m_shape,
shape);
438 template <
class CT,
class X>
442 return this->
dimension() == m_e.dimension()
443 && std::equal(m_shape.cbegin(), m_shape.cend(), m_e.shape().cbegin())
444 && m_e.has_linear_assign(
strides);
449 template <
class CT,
class X>
451 inline auto xbroadcast<CT, X>::stepper_begin(
const S& shape)
const noexcept -> const_stepper
454 return m_e.stepper_begin(shape);
457 template <
class CT,
class X>
459 inline auto xbroadcast<CT, X>::stepper_end(
const S& shape,
layout_type l)
const noexcept -> const_stepper
462 return m_e.stepper_end(shape, l);
465 template <
class CT,
class X>
466 template <
class E,
class XCT,
class>
469 auto& ed = e.derived_cast();
471 std::fill(ed.begin(), ed.end(), m_e());
474 template <
class CT,
class X>
476 inline auto xbroadcast<CT, X>::build_broadcast(E&& e)
const -> rebind_t<E>
478 return rebind_t<E>(std::forward<E>(e), inner_shape_type(m_shape));
An xexpression broadcast to a specified shape.
const xexpression_type & expression() const noexcept
const inner_shape_type & shape() const noexcept
Returns the shape of the expression.
layout_type layout() const noexcept
bool broadcast_shape(S &shape, bool reuse_cache=false) const
xbroadcast(CTA &&e, shape_type &&s)
Constructs an xbroadcast expression broadcasting the specified xexpression to the given shape.
xbroadcast(CTA &&e, const S &s)
Constructs an xbroadcast expression broadcasting the specified xexpression to the given shape.
bool has_linear_assign(const S &strides) const noexcept
size_type size() const noexcept
size_type dimension() const noexcept
Returns the number of dimensions of the expression.
Base class for multidimensional iterable constant expressions.
Base class for xexpressions.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Standard mathematical functions for xexpressions.
auto broadcast(E &&e, const S &s)
Returns an xexpression broadcasting the given expression to a specified shape.