10#ifndef XTENSOR_BROADCAST_HPP
11#define XTENSOR_BROADCAST_HPP
21#include <xtl/xsequence.hpp>
23#include "xaccessible.hpp"
24#include "xexpression.hpp"
25#include "xiterable.hpp"
27#include "xstrides.hpp"
28#include "xtensor_config.hpp"
38 template <
class E,
class S>
41 template <
class E,
class I, std::
size_t L>
50 template <
class Tag,
class CT,
class X>
53 template <
class CT,
class X>
59 template <
class CT,
class X>
64 template <
class CT,
class X>
72 template <
class CT,
class X>
75 template <
class CT,
class X>
78 using xexpression_type = std::decay_t<CT>;
80 using const_stepper =
typename xexpression_type::const_stepper;
81 using stepper = const_stepper;
84 template <
class CT,
class X>
87 using xexpression_type = std::decay_t<CT>;
88 using reference =
typename xexpression_type::const_reference;
89 using const_reference =
typename xexpression_type::const_reference;
90 using size_type =
typename xexpression_type::size_type;
97 template <
class CT,
class X>
100 return linear_begin(
c.expression());
103 template <
class CT,
class X>
104 XTENSOR_CONSTEXPR_RETURN
auto linear_end(xbroadcast<CT, X>& c)
noexcept
106 return linear_end(c.expression());
109 template <
class CT,
class X>
110 XTENSOR_CONSTEXPR_RETURN
auto linear_begin(
const xbroadcast<CT, X>& c)
noexcept
112 return linear_begin(c.expression());
115 template <
class CT,
class X>
116 XTENSOR_CONSTEXPR_RETURN
auto linear_end(
const xbroadcast<CT, X>& c)
noexcept
118 return linear_end(c.expression());
128 std::
enable_if_t<!has_memory_address<E>::value && is_specialization_of<xbroadcast, E>::value>>
132 if (
expr.size() == 0)
138 using ChildE = std::decay_t<
decltype(
expr.expression())>;
157 template <
class CT,
class X>
161 public extension::xbroadcast_base_t<CT, X>
166 using xexpression_type = std::decay_t<CT>;
168 using extension_base = extension::xbroadcast_base_t<CT, X>;
169 using expression_tag =
typename extension_base::expression_tag;
172 using value_type =
typename xexpression_type::value_type;
173 using reference =
typename inner_types::reference;
174 using const_reference =
typename inner_types::const_reference;
175 using pointer =
typename xexpression_type::const_pointer;
176 using const_pointer =
typename xexpression_type::const_pointer;
177 using size_type =
typename inner_types::size_type;
178 using difference_type =
typename xexpression_type::difference_type;
181 using inner_shape_type =
typename iterable_base::inner_shape_type;
182 using shape_type = inner_shape_type;
184 using stepper =
typename iterable_base::stepper;
185 using const_stepper =
typename iterable_base::const_stepper;
187 using bool_load_type =
typename xexpression_type::bool_load_type;
190 static constexpr bool contiguous_layout =
false;
192 template <
class CTA,
class S>
199 const inner_shape_type&
shape()
const noexcept;
201 bool is_contiguous()
const noexcept;
204 template <
class...
Args>
205 const_reference operator()(
Args...
args)
const;
207 template <
class...
Args>
208 const_reference unchecked(
Args...
args)
const;
213 const xexpression_type&
expression()
const noexcept;
222 const_stepper stepper_begin(
const S&
shape)
const noexcept;
238 inner_shape_type m_shape;
255 template <
class E,
class S>
260 return broadcast_type(std::forward<E>(
e), xtl::forward_sequence<shape_type,
decltype(
s)>(
s));
263 template <
class E,
class I, std::
size_t L>
264 inline auto broadcast(E&& e,
const I (&s)[L])
266 using broadcast_type = xbroadcast<const_xclosure_t<E>, std::array<std::size_t, L>>;
267 using shape_type =
typename broadcast_type::shape_type;
268 return broadcast_type(std::forward<E>(e), xtl::forward_sequence<shape_type,
decltype(s)>(s));
286 template <
class CT,
class X>
287 template <
class CTA,
class S>
291 if (
s.size() < m_e.dimension())
293 XTENSOR_THROW(
xt::broadcast_error,
"Broadcast shape has fewer elements than original expression.");
295 xt::resize_container(m_shape,
s.size());
296 std::copy(
s.begin(),
s.end(), m_shape.begin());
297 xt::broadcast_shape(m_e.shape(), m_shape);
307 template <
class CT,
class X>
311 , m_shape(std::
move(
s))
313 xt::broadcast_shape(m_e.shape(), m_shape);
325 template <
class CT,
class X>
334 template <
class CT,
class X>
340 template <
class CT,
class X>
358 template <
class CT,
class X>
359 template <
class... Args>
384 template <
class CT,
class X>
385 template <
class...
Args>
388 return this->operator()(
args...);
398 template <
class CT,
class X>
402 return m_e.element(
last - this->dimension(),
last);
408 template <
class CT,
class X>
426 template <
class CT,
class X>
430 return xt::broadcast_shape(m_shape, shape);
438 template <
class CT,
class X>
442 return this->dimension() == m_e.dimension()
443 && std::equal(m_shape.cbegin(), m_shape.cend(), m_e.shape().cbegin())
444 && m_e.has_linear_assign(
strides);
449 template <
class CT,
class X>
454 return m_e.stepper_begin(shape);
457 template <
class CT,
class X>
459 inline auto xbroadcast<CT, X>::stepper_end(
const S& shape,
layout_type l)
const noexcept -> const_stepper
462 return m_e.stepper_end(shape, l);
465 template <
class CT,
class X>
466 template <
class E,
class XCT,
class>
467 inline void xbroadcast<CT, X>::assign_to(xexpression<E>& e)
const
469 auto& ed = e.derived_cast();
471 std::fill(ed.begin(), ed.end(), m_e());
474 template <
class CT,
class X>
476 inline auto xbroadcast<CT, X>::build_broadcast(E&& e)
const -> rebind_t<E>
478 return rebind_t<E>(std::forward<E>(e), inner_shape_type(m_shape));
An xexpression broadcast to a specified shape.
const xexpression_type & expression() const noexcept
Returns a constant reference to the underlying expression of the broadcast expression.
const inner_shape_type & shape() const noexcept
Returns the shape of the expression.
layout_type layout() const noexcept
Returns the layout_type of the expression.
bool broadcast_shape(S &shape, bool reuse_cache=false) const
Broadcasts the shape of the expression into the specified shape parameter.
xbroadcast(CTA &&e, const S &s)
Constructs an xbroadcast expression broadcasting the specified xexpression to the given shape.
bool has_linear_assign(const S &strides) const noexcept
Checks whether the xbroadcast can be linearly assigned to an expression with the specified strides.
Base class implementing the common constant-access methods of expressions.
size_type size() const noexcept
Returns the size of the expression.
size_type shape(size_type index) const
Returns the extent of the expression along the dimension given by index.
Base class for multidimensional iterable constant expressions.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Returns the strides of an object.
Standard mathematical functions for xexpressions.
auto broadcast(E &&e, const S &s)
Returns an xexpression broadcasting the given expression to a specified shape.