10#ifndef XTENSOR_XREPEAT
11#define XTENSOR_XREPEAT
16#include "xaccessible.hpp"
17#include "xexpression.hpp"
18#include "xiterable.hpp"
22 template <
class CT,
class R>
25 template <
class S,
class R>
26 class xrepeat_stepper;
34 template <
class Tag,
class CT,
class X>
37 template <
class CT,
class X>
43 template <
class CT,
class X>
48 template <
class CT,
class X>
56 template <
class CT,
class R>
59 using xexpression_type = std::decay_t<CT>;
60 using reference =
typename xexpression_type::const_reference;
61 using const_reference =
typename xexpression_type::const_reference;
62 using size_type =
typename xexpression_type::size_type;
63 using temporary_type =
typename xexpression_type::temporary_type;
65 static constexpr bool is_const = std::is_const<std::remove_reference_t<CT>>::value;
67 using extract_storage_type = xtl::mpl::eval_if_t<
69 detail::expr_storage_type<xexpression_type>,
71 using storage_type = std::conditional_t<is_const, const extract_storage_type, extract_storage_type>;
74 template <
class CT,
class R>
77 using xexpression_type = std::decay_t<CT>;
78 using repeats_type = std::decay_t<R>;
79 using inner_shape_type =
typename xexpression_type::inner_shape_type;
95 template <
class CT,
class R>
99 public extension::xrepeat_base_t<CT, R>
104 using xexpression_type = std::decay_t<CT>;
106 using extension_base = extension::xrepeat_base_t<CT, R>;
107 using expression_tag =
typename extension_base::expression_tag;
109 using value_type =
typename xexpression_type::value_type;
110 using shape_type =
typename xexpression_type::shape_type;
111 using repeats_type = xtl::const_closure_type_t<R>;
114 using reference =
typename container_type::reference;
115 using const_reference =
typename container_type::const_reference;
116 using size_type =
typename container_type::size_type;
117 using temporary_type =
typename container_type::temporary_type;
119 static constexpr layout_type static_layout = xexpression_type::static_layout;
120 static constexpr bool contiguous_layout =
false;
122 using bool_load_type =
typename xexpression_type::bool_load_type;
123 using pointer =
typename xexpression_type::pointer;
124 using const_pointer =
typename xexpression_type::const_pointer;
125 using difference_type =
typename xexpression_type::difference_type;
128 using stepper =
typename iterable_type::stepper;
129 using const_stepper =
typename iterable_type::const_stepper;
135 const shape_type&
shape()
const noexcept;
137 bool is_contiguous()
const noexcept;
140 template <
class...
Args>
141 const_reference operator()(
Args...
args)
const;
143 template <
class...
Args>
144 const_reference unchecked(
Args...
args)
const;
149 const xexpression_type&
expression()
const noexcept;
157 const_stepper stepper_begin()
const;
158 const_stepper stepper_begin(
const shape_type&
s)
const;
161 const_stepper stepper_end(
const shape_type&
s,
layout_type l)
const;
166 size_type m_repeating_axis;
167 repeats_type m_repeats;
170 const_reference access()
const;
172 template <
class Arg,
class...
Args>
173 const_reference access(
Arg arg,
Args...
args)
const;
175 template <std::size_t I,
class Arg,
class...
Args>
176 const_reference access_impl(stepper&&
s,
Arg arg,
Args...
args)
const;
178 template <std::
size_t I>
179 const_reference access_impl(stepper&&
s)
const;
186 template <
class S,
class R>
191 using repeats_type =
R;
192 using storage_type =
typename S::storage_type;
193 using subiterator_type =
typename S::subiterator_type;
194 using subiterator_traits =
typename S::subiterator_traits;
195 using value_type =
typename subiterator_traits::value_type;
196 using reference =
typename subiterator_traits::reference;
197 using pointer =
typename subiterator_traits::pointer;
198 using difference_type =
typename subiterator_traits::difference_type;
199 using size_type =
typename storage_type::size_type;
200 using shape_type =
typename storage_type::shape_type;
201 using simd_value_type = xt_simd::simd_type<value_type>;
203 template <
class requested_type>
204 using simd_return_type = xt_simd::simd_return_type<value_type, requested_type>;
208 reference operator*()
const;
210 void step(size_type
dim, size_type
n = 1);
211 void step_back(size_type
dim, size_type
n = 1);
212 void reset(size_type
dim);
213 void reset_back(size_type
dim);
224 void store_simd(
const V&
vec);
229 const shape_type& m_shape;
231 std::ptrdiff_t m_repeating_steps;
232 std::vector<size_type> m_positions;
233 size_type m_subposition;
235 size_type m_repeating_axis;
236 const repeats_type& m_repeats;
238 void make_step(size_type
dim, size_type
n);
239 void make_step_back(size_type
dim, size_type
n);
241 std::vector<size_type> get_next_positions(size_type
dim, size_type
steps_to_go)
const;
242 std::vector<size_type> get_next_positions_back(size_type
dim, size_type
steps_to_go)
const;
257 template <
class CT,
class R>
261 , m_repeating_axis(axis)
278 template <
class CT,
class R>
287 template <
class CT,
class R>
293 template <
class CT,
class R>
311 template <
class CT,
class R>
312 template <
class... Args>
315 return access(
args...);
337 template <
class CT,
class R>
338 template <
class...
Args>
341 return this->operator()(
args...);
351 template <
class CT,
class R>
355 auto s = stepper_begin(m_e.shape());
356 std::size_t dimension = 0;
360 s.step(dimension, *
iter);
370 template <
class CT,
class R>
388 template <
class CT,
class R>
392 return xt::broadcast_shape(m_shape, shape);
400 template <
class CT,
class R>
409 template <
class CT,
class R>
415 template <
class CT,
class R>
416 template <
class Arg,
class... Args>
417 inline auto xrepeat<CT, R>::access(Arg arg, Args... args)
const -> const_reference
419 constexpr size_t number_of_arguments = 1 +
sizeof...(Args);
420 if (number_of_arguments > this->dimension())
422 return access(args...);
424 return access_impl<0>(stepper_begin(m_e.shape()), arg, args...);
427 template <
class CT,
class R>
428 inline auto xrepeat<CT, R>::stepper_begin() const -> const_stepper
430 return stepper_begin(m_e.shape());
433 template <
class CT,
class R>
434 inline auto xrepeat<CT, R>::stepper_begin(
const shape_type& s)
const -> const_stepper
436 return const_stepper(m_e.stepper_begin(s), m_shape, m_repeats, m_repeating_axis);
439 template <
class CT,
class R>
440 inline auto xrepeat<CT, R>::stepper_end(
layout_type l)
const -> const_stepper
442 return stepper_end(m_e.shape(), l);
445 template <
class CT,
class R>
446 inline auto xrepeat<CT, R>::stepper_end(
const shape_type& s,
layout_type l)
const -> const_stepper
448 auto st = const_stepper(m_e.stepper_begin(s), m_shape, m_repeats, m_repeating_axis);
453 template <
class CT,
class R>
454 template <std::size_t I,
class Arg,
class... Args>
455 inline auto xrepeat<CT, R>::access_impl(stepper&& s, Arg arg, Args... args)
const -> const_reference
457 s.step(I,
static_cast<size_type
>(arg));
458 return access_impl<I + 1>(std::forward<stepper>(s), args...);
461 template <
class CT,
class R>
462 template <std::
size_t I>
463 inline auto xrepeat<CT, R>::access_impl(stepper&& s)
const -> const_reference
472 template <
class S,
class R>
473 xrepeat_stepper<S, R>::xrepeat_stepper(S&& s,
const shape_type& shape,
const repeats_type& repeats, size_type axis)
474 : m_substepper(std::forward<S>(s))
476 , m_repeating_steps(0)
477 , m_positions(shape.size())
479 , m_repeating_axis(axis)
484 template <
class S,
class R>
485 inline auto xrepeat_stepper<S, R>::operator*() const -> reference
487 return m_substepper.operator*();
490 template <
class S,
class R>
491 inline void xrepeat_stepper<S, R>::step(size_type dim, size_type steps_to_go)
493 if (m_positions[dim] + steps_to_go >= m_shape[dim])
495 const auto next_positions = get_next_positions(dim, steps_to_go);
496 if (next_positions[dim] > m_positions[dim])
498 make_step(dim, next_positions[dim] - m_positions[dim]);
502 make_step_back(dim, m_positions[dim] - next_positions[dim]);
504 for (size_type d = 0; d < dim; ++d)
506 make_step(d, next_positions[d] - m_positions[d]);
511 make_step(dim, steps_to_go);
515 template <
class S,
class R>
516 inline void xrepeat_stepper<S, R>::step_back(size_type dim, size_type steps_to_go)
518 if (m_positions[dim] < steps_to_go)
520 const auto next_positions = get_next_positions_back(dim, steps_to_go);
521 if (next_positions[dim] < m_positions[dim])
523 make_step_back(dim, m_positions[dim] - next_positions[dim]);
527 make_step(dim, next_positions[dim] - m_positions[dim]);
529 for (size_type d = 0; d < dim; ++d)
531 make_step_back(d, m_positions[d] - next_positions[d]);
536 make_step_back(dim, steps_to_go);
540 template <
class S,
class R>
541 inline void xrepeat_stepper<S, R>::reset(size_type dim)
543 m_substepper.reset(dim);
544 m_positions[dim] = 0;
545 if (dim == m_repeating_axis)
548 m_repeating_steps = 0;
552 template <
class S,
class R>
553 inline void xrepeat_stepper<S, R>::reset_back(size_type dim)
555 m_substepper.reset_back(dim);
556 m_positions[dim] = m_shape[dim] - 1;
557 if (dim == m_repeating_axis)
559 m_subposition = m_repeats.size() - 1;
560 m_repeating_steps =
static_cast<std::ptrdiff_t
>(m_repeats.back()) - 1;
564 template <
class S,
class R>
565 inline void xrepeat_stepper<S, R>::to_begin()
567 m_substepper.to_begin();
568 std::fill(m_positions.begin(), m_positions.end(), 0);
570 m_repeating_steps = 0;
573 template <
class S,
class R>
574 inline void xrepeat_stepper<S, R>::to_end(
layout_type l)
576 m_substepper.to_end(l);
588 ++m_positions.front();
592 ++m_positions.back();
594 m_subposition = m_repeats.size();
595 m_repeating_steps = 0;
598 template <
class S,
class R>
599 inline void xrepeat_stepper<S, R>::step_leading()
601 step(m_shape.size() - 1, 1);
604 template <
class S,
class R>
605 inline void xrepeat_stepper<S, R>::make_step(size_type dim, size_type steps_to_go)
609 if (dim == m_repeating_axis)
611 size_type subposition = m_subposition;
612 m_repeating_steps +=
static_cast<std::ptrdiff_t
>(steps_to_go);
613 while (m_repeating_steps >=
static_cast<ptrdiff_t
>(m_repeats[subposition]))
615 m_repeating_steps -=
static_cast<ptrdiff_t
>(m_repeats[subposition]);
618 m_substepper.step(dim, subposition - m_subposition);
619 m_subposition = subposition;
623 m_substepper.step(dim, steps_to_go);
625 m_positions[dim] += steps_to_go;
629 template <
class S,
class R>
630 inline void xrepeat_stepper<S, R>::make_step_back(size_type dim, size_type steps_to_go)
634 if (dim == m_repeating_axis)
636 size_type subposition = m_subposition;
637 m_repeating_steps -=
static_cast<std::ptrdiff_t
>(steps_to_go);
638 while (m_repeating_steps < 0)
641 m_repeating_steps +=
static_cast<ptrdiff_t
>(m_repeats[subposition]);
643 m_substepper.step_back(dim, m_subposition - subposition);
644 m_subposition = subposition;
648 m_substepper.step_back(dim, steps_to_go);
650 m_positions[dim] -= steps_to_go;
654 template <
class S,
class R>
655 inline auto xrepeat_stepper<S, R>::get_next_positions(size_type dim, size_type steps_to_go)
const
656 -> std::vector<size_type>
658 size_type next_position_for_dim = m_positions[dim] + steps_to_go;
661 size_type steps_in_previous_dim = 0;
662 while (next_position_for_dim >= m_shape[dim])
664 next_position_for_dim -= m_shape[dim];
665 ++steps_in_previous_dim;
667 if (steps_in_previous_dim > 0)
669 auto next_positions = get_next_positions(dim - 1, steps_in_previous_dim);
670 next_positions[dim] = next_position_for_dim;
671 return next_positions;
674 std::vector<size_type> next_positions = m_positions;
675 next_positions[dim] = next_position_for_dim;
676 return next_positions;
679 template <
class S,
class R>
680 inline auto xrepeat_stepper<S, R>::get_next_positions_back(size_type dim, size_type steps_to_go)
const
681 -> std::vector<size_type>
683 auto next_position_for_dim =
static_cast<std::ptrdiff_t
>(m_positions[dim] - steps_to_go);
686 size_type steps_in_previous_dim = 0;
687 while (next_position_for_dim < 0)
689 next_position_for_dim +=
static_cast<std::ptrdiff_t
>(m_shape[dim]);
690 ++steps_in_previous_dim;
692 if (steps_in_previous_dim > 0)
694 auto next_positions = get_next_positions_back(dim - 1, steps_in_previous_dim);
695 next_positions[dim] =
static_cast<size_type
>(next_position_for_dim);
696 return next_positions;
699 std::vector<size_type> next_positions = m_positions;
700 next_positions[dim] =
static_cast<size_type
>(next_position_for_dim);
701 return next_positions;
Base class for implementation of common expression constant access methods.
size_type size() const noexcept
Returns the size of the expression.
size_type shape(size_type index) const
Returns the i-th dimension of the expression.
Base class for multidimensional iterable constant expressions.
Base class for multidimensional iterable expressions.
Expression with repeated values along an axis.
bool has_linear_assign(const S &strides) const noexcept
Checks whether the xrepeat expression can be linearly assigned to an expression with the specified strides.
bool broadcast_shape(S &shape, bool reuse_cache=false) const
Broadcast the shape of the function to the specified parameter.
layout_type layout() const noexcept
Returns the layout_type of the expression.
const xexpression_type & expression() const noexcept
Returns a constant reference to the underlying expression of the xrepeat expression.
xrepeat(CTA &&e, R &&repeats, size_type axis)
Constructs an xrepeat expression repeating the elements of the specified xexpression along the given axis.
const shape_type & shape() const noexcept
Returns the shape of the expression.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Standard mathematical functions for xexpressions.