10#ifndef XTENSOR_REDUCER_HPP
11#define XTENSOR_REDUCER_HPP
21#include <xtl/xfunctional.hpp>
22#include <xtl/xsequence.hpp>
24#include "../core/xaccessible.hpp"
25#include "../core/xeval.hpp"
26#include "../core/xexpression.hpp"
27#include "../core/xiterable.hpp"
28#include "../core/xtensor_config.hpp"
29#include "../generators/xbuilder.hpp"
30#include "../generators/xgenerator.hpp"
31#include "../utils/xutils.hpp"
// Pipe operator: concatenates the option tuple `args` with the option tuple
// `rhs` using std::tuple_cat. Presumably used to combine reducer options such
// as keep_dims — TODO confirm against callers.
// NOTE(review): the template header declaring A, AX... and X is not visible
// in this view.
36 auto operator|(
const A<AX...>& args,
const A<X>& rhs)
38 return std::tuple_cat(args, rhs);
// Reducer option requesting that reduced axes stay in the result shape
// (reducer_options::keep_dims tests for keep_dims_type in the options tuple).
// It is wrapped in a one-element tuple, the same form operator| concatenates.
45 constexpr auto keep_dims = std::tuple<keep_dims_type>{};
// xinitial<T>: reducer option carrying a user-supplied initial value of type T
// (double by default). It derives from xt::detail::option_base.
47 template <
class T =
double>
48 struct xinitial : xt::detail::option_base
// Constructs the option from `val` (member initializer not visible in this view).
50 constexpr xinitial(T val)
// Accessor for the initial value (body not visible in this view).
55 constexpr T value()
const
// Factory building an initial-value option from `val`; the exact return type
// is not visible in this view.
64 constexpr auto initial(T val)
// tuple_idx_of_impl<I, T, Tuple>: compile-time search for the position of T in
// Tuple, counting from I. `value` is -1 when T is not present.
69 template <std::ptrdiff_t I,
class T,
class Tuple>
// Exhausted (empty) tuple: T was not found.
72 template <std::ptrdiff_t I,
class T>
75 static constexpr std::ptrdiff_t value = -1;
// Presumably the specialization whose first element is T: the answer is the
// current counter I.
78 template <std::ptrdiff_t I,
class T,
class... Types>
81 static constexpr std::ptrdiff_t value = I;
// First element U differs from T: recurse on the remaining Types... with the
// counter incremented.
84 template <std::ptrdiff_t I,
class T,
class U,
class... Types>
87 static constexpr std::ptrdiff_t value =
tuple_idx_of_impl<I + 1, T, std::tuple<Types...>>::value;
// Type helper over S and a pack X...; the specialization for a class template
// S rebuilds it with every argument decayed: S<std::decay_t<X>...>.
90 template <
class S,
class... X>
93 template <
template <
class...>
class S,
class... X>
96 using type = S<std::decay_t<X>...>;
// tuple_idx_of<T, Tuple>: public entry point exposing a std::ptrdiff_t index
// of T in Tuple; callers compare it against -1 to test presence.
99 template <
class T,
class Tuple>
102 static constexpr std::ptrdiff_t
// reducer_options<R, T>: decodes a tuple of reducer options T into compile-time
// traits (evaluation strategy, keep_dims, presence of an initial value). When an
// initial-value option is present, it holds that value in `initial_value`
// (type R).
106 template <
class R,
class T>
107 struct reducer_options
125 using d_t = std::decay_t<T>;
// Index of the initial-value option inside the options tuple. It equals the
// tuple size when no such option was given (see has_initial_value).
127 static constexpr std::size_t initial_val_idx = xtl::mpl::find_if<initial_tester, d_t>::value;
128 reducer_options() =
default;
// Builds the options from a tuple instance and copies the initial value out of
// it when one is present.
130 reducer_options(
const T& tpl)
132 if constexpr (initial_val_idx != std::tuple_size<T>::value)
134 initial_value = std::get < initial_val_idx != std::tuple_size<T>::value ? initial_val_idx
// Evaluation strategy is chosen from whether immediate_type appears among the
// options (the alternative branch is not visible in this view).
139 using evaluation_strategy = std::conditional_t<
140 tuple_idx_of<xt::evaluation_strategy::immediate_type, d_t>::value != -1,
// std::true_type iff keep_dims_type appears among the options.
144 using keep_dims = std::
145 conditional_t<tuple_idx_of<xt::keep_dims_type, d_t>::value != -1, std::true_type, std::false_type>;
147 static constexpr bool has_initial_value = initial_val_idx != std::tuple_size<d_t>::value;
// Same options with the initial-value type rebound to NR.
152 using rebind_t = reducer_options<NR, T>;
// Builds a reducer_options<NR, T> whose initial_value is `initial`; the
// second (unnamed) argument is unused.
155 auto rebind(NR initial,
const reducer_options<R, T>&)
const
157 reducer_options<NR, T> res;
158 res.initial_value = initial;
// Template over an options pack X... (the rest of its declaration is not
// visible in this view).
168 template <
class... X>
// Default options pack for reducers: lazy evaluation.
182#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
// Declarations of reducer shape-type metafunctions. The first is parameterized
// by an input shape ST, axes X and a keep-dims flag KD (std::false_type by
// default); the second by two shape types S1, S2. Their names are not visible
// in this view.
184 template <
class ST,
class X,
class KD = std::false_type>
187 template <
class S1,
class S2>
// detail::shape_computation, dynamic-shape overload (enabled when RS is not a
// fixed shape): computes the shape of `expr` reduced over `axes` into
// result_shape, then resizes `result` with expr's layout.
192 template <
class O,
class RS,
class R,
class E,
class AX>
193 inline void shape_computation(
198 std::enable_if_t<!detail::is_fixed<RS>::value,
int> = 0
// keep_dims: the result keeps expr.dimension() axes; non-reduced axes copy
// expr's extent (the value written for reduced axes is not visible here).
201 if (
typename O::keep_dims())
203 resize_container(result_shape, expr.dimension());
204 for (std::size_t i = 0; i < expr.dimension(); ++i)
206 if (std::find(axes.begin(), axes.end(), i) == axes.end())
209 result_shape[i] = expr.shape()[i];
// Otherwise reduced axes are dropped: only non-reduced extents are kept, in order.
219 resize_container(result_shape, expr.dimension() - axes.size());
220 for (std::size_t i = 0, idx = 0; i < expr.dimension(); ++i)
222 if (std::find(axes.begin(), axes.end(), i) == axes.end())
225 result_shape[idx] = expr.shape()[i];
230 result.resize(result_shape, expr.layout());
// Fixed-shape overload (RS is a fixed shape): all parameters are unnamed, so
// nothing is computed at run time (body not visible in this view).
234 template <
class O,
class RS,
class R,
class S,
class AX>
236 shape_computation(RS&, R&,
const S&,
const AX&, std::enable_if_t<detail::is_fixed<RS>::value,
int> = 0)
// copy_to_reduced: fills `result` element-wise from `e` when no reduction is
// needed, traversing through row-major or column-major const iterators (the
// condition selecting the layout is not visible in this view).
// This overload is enabled when E::value_type converts to R::value_type; the
// functor parameter is unnamed, i.e. unused.
241 template <
class F,
class E,
class R, XTL_REQUIRES(std::is_convertible<
typename E::value_type,
typename R::value_type>)>
242 inline void copy_to_reduced(F&,
const E& e, R& result)
247 e.template cbegin<layout_type::row_major>(),
248 e.template cend<layout_type::row_major>(),
255 e.template cbegin<layout_type::column_major>(),
256 e.template cend<layout_type::column_major>(),
// Overload for non-convertible value types: receives the functor `f`
// (presumably applied to each element to produce the result type — TODO confirm).
266 XTL_REQUIRES(std::negation<std::is_convertible<typename E::value_type, typename R::value_type>>)>
267 inline void copy_to_reduced(F& f,
const E& e, R& result)
272 e.template cbegin<layout_type::row_major>(),
273 e.template cend<layout_type::row_major>(),
281 e.template cbegin<layout_type::column_major>(),
282 e.template cend<layout_type::column_major>(),
// reduce_immediate: eagerly evaluates the reduction of `e` over `axes` into a
// newly built container. `f` bundles three functors read with xt::get<0..2>:
// reduce, init and merge. `raw_options` is wrapped into options_t (its
// definition is not visible in this view).
289 template <
class F,
class E,
class X,
class O>
290 inline auto reduce_immediate(F&& f, E&& e, X&& axes, O&& raw_options)
292 using reduce_functor_type =
typename std::decay_t<F>::reduce_functor_type;
293 using init_functor_type =
typename std::decay_t<F>::init_functor_type;
294 using expr_value_type =
typename std::decay_t<E>::value_type;
// Element type of the result: whatever reduce(init(), element) returns, decayed.
295 using result_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
296 std::declval<init_functor_type>()(),
297 std::declval<expr_value_type>()
301 options_t options(raw_options);
304 typename std::decay_t<E>::shape_type,
306 typename options_t::keep_dims>::type;
307 using result_container_type =
typename detail::xtype_for_shape<
308 shape_type>::template type<result_type, std::decay_t<E>::static_layout>;
309 result_container_type result;
312 auto reduce_fct = xt::get<0>(f);
313 auto init_fct = xt::get<1>(f);
314 auto merge_fct = xt::get<2>(f);
// No reduction axes: the result has e's shape and each element becomes
// reduce(init(), element).
316 if (axes.size() == 0)
318 result.resize(e.shape(), e.layout());
319 auto cpf = [&reduce_fct, &init_fct](
const auto& v)
321 return reduce_fct(
static_cast<result_type
>(init_fct()), v);
323 copy_to_reduced(cpf, e, result);
327 shape_type result_shape{};
328 dynamic_shape<std::size_t>
329 iter_shape = xtl::forward_sequence<dynamic_shape<std::size_t>,
decltype(e.shape())>(e.shape());
330 dynamic_shape<std::size_t> iter_strides(e.dimension());
// Validate axes: they must be sorted, free of duplicates and in bounds.
336 if (!std::is_sorted(axes.cbegin(), axes.cend(), std::less<>()))
338 XTENSOR_THROW(std::runtime_error,
"Reducing axes should be sorted.");
340 if (std::adjacent_find(axes.cbegin(), axes.cend()) != axes.cend())
342 XTENSOR_THROW(std::runtime_error,
"Reducing axes should not contain duplicates.");
// NOTE(review): if e.dimension() is an unsigned type equal to 0,
// `e.dimension() - 1` wraps around and this bounds check can never fire for a
// 0-d expression with a non-empty axes list. Confirm whether such input can
// reach this point.
344 if (axes.size() != 0 && axes[axes.size() - 1] > e.dimension() - 1)
348 "Axis " + std::to_string(axes[axes.size() - 1]) +
" out of bounds for reduction."
352 detail::shape_computation<options_t>(result_shape, result, e, axes);
// Every axis reduced: fold the whole storage into the single result element.
// The fold is seeded with the user initial value when given, otherwise init().
355 if (e.dimension() == axes.size())
357 result_type tmp = options_t::has_initial_value ? options.initial_value : init_fct();
358 result.data()[0] = std::accumulate(e.storage().begin(), e.storage().end(), tmp, reduce_fct);
// Search for the first non-zero stride starting at leading_ax (leading_ax and
// iter_bound are defined on lines not visible here).
363 auto strides_finder = e.strides().begin() +
static_cast<std::ptrdiff_t
>(leading_ax);
365 std::size_t inner_stride =
static_cast<std::size_t
>(*strides_finder);
367 while (inner_stride == 0 && strides_finder != iter_bound)
370 inner_stride =
static_cast<std::size_t
>(*strides_finder);
// Only zero strides were found: fall back to the element-wise copy path.
373 if (inner_stride == 0)
375 auto cpf = [&reduce_fct, &init_fct](
const auto& v)
377 return reduce_fct(
static_cast<result_type
>(init_fct()), v);
379 copy_to_reduced(cpf, e, result);
383 std::size_t inner_loop_size =
static_cast<std::size_t
>(inner_stride);
384 std::size_t outer_loop_size = e.shape()[leading_ax];
// merge_loops: walks axes from `it` to `end`. Each axis whose index differs by
// exactly 1 from last_ax multiplies outer_loop_size by e.shape()[last_ax], so
// adjacent reduced axes are folded as one loop. Its result is stored in
// last_ax by the callers below (return statement not visible).
388 auto merge_loops = [&outer_loop_size, &e](
auto it,
auto end)
392 for (; it != end; ++it)
395 if (std::abs(std::ptrdiff_t(*it) - std::ptrdiff_t(last_ax)) == 1)
398 outer_loop_size *= e.shape()[last_ax];
// Map non-reduced input axes onto result strides. With keep_dims the result
// axis index is i; otherwise it is the compacted index idx.
404 for (std::size_t i = 0, idx = 0; i < e.dimension(); ++i)
406 if (std::find(axes.begin(), axes.end(), i) == axes.end())
409 iter_strides[i] =
static_cast<std::size_t
>(result.strides(
410 )[
typename options_t::keep_dims() ? i : idx]);
// Merge from the last axis backwards and drop iteration dimensions from
// last_ax onward (enclosing layout condition not visible; presumably row-major).
417 std::size_t last_ax = merge_loops(axes.rbegin(), axes.rend());
419 iter_shape.erase(iter_shape.begin() + std::ptrdiff_t(last_ax), iter_shape.end());
420 iter_strides.erase(iter_strides.begin() + std::ptrdiff_t(last_ax), iter_strides.end());
// Merge from the first axis forwards, drop dimensions up to and including
// last_ax, then reverse the remaining shape/strides (presumably column-major).
425 std::size_t last_ax = merge_loops(axes.begin(), axes.end());
428 iter_shape.erase(iter_shape.begin(), iter_shape.begin() + std::ptrdiff_t(last_ax + 1));
429 iter_strides.erase(iter_strides.begin(), iter_strides.begin() + std::ptrdiff_t(last_ax + 1));
432 std::reverse(iter_shape.begin(), iter_shape.end());
433 std::reverse(iter_strides.begin(), iter_strides.end());
437 XTENSOR_THROW(std::runtime_error,
"Layout not supported in immediate reduction.");
// next_idx: advances the multi-index temp_idx over iter_shape (starting from
// the last dimension). Returns a pair whose second member is the offset
// inner_product(temp_idx, iter_strides); the loops below stop when `first`
// becomes true.
440 xindex temp_idx(iter_shape.size());
441 auto next_idx = [&iter_shape, &iter_strides, &temp_idx]()
443 std::size_t i = iter_shape.size();
446 if (std::ptrdiff_t(temp_idx[i - 1]) >= std::ptrdiff_t(iter_shape[i - 1]) - 1)
457 return std::make_pair(
459 std::inner_product(temp_idx.begin(), temp_idx.end(), iter_strides.begin(), std::ptrdiff_t(0))
463 auto begin = e.data();
464 auto out = result.data();
465 auto out_begin = result.data();
467 std::ptrdiff_t next_stride = 0;
469 std::pair<bool, std::ptrdiff_t> idx_res(
false, 0);
// merge_border tracks the furthest output position written so far. When `out`
// jumps past it, presumably the next write is a fresh store, not a merge.
474 auto merge_border = out;
// Contiguous case: each run of outer_loop_size input elements is folded with
// std::accumulate, then stored or merged into *out depending on `merge`
// (defined on lines not visible here).
482 if (inner_stride == 1)
484 while (idx_res.first !=
true)
488 result_type tmp = init_fct();
489 tmp = std::accumulate(begin, begin + outer_loop_size, tmp, reduce_fct);
492 *out = merge ? merge_fct(*out, tmp) : tmp;
494 begin += outer_loop_size;
496 idx_res = next_idx();
497 next_stride = idx_res.second;
498 out = out_begin + next_stride;
500 if (out > merge_border)
// Strided case: a slice of inner_loop_size outputs is updated at once. The
// first slice either seeds (reduce(init(), v)) or reduces into `out`; the
// remaining outer_loop_size - 1 slices are folded in with std::transform.
514 while (idx_res.first !=
true)
518 out + inner_loop_size,
521 [merge, &init_fct, &reduce_fct](
auto&& v1,
auto&& v2)
523 return merge ? reduce_fct(v1, v2) :
525 reduce_fct(static_cast<result_type>(init_fct()), v2);
529 begin += inner_stride;
530 for (std::size_t i = 1; i < outer_loop_size; ++i)
532 std::transform(out, out + inner_loop_size, begin, out, reduce_fct);
533 begin += inner_stride;
536 idx_res = next_idx();
537 next_stride = idx_res.second;
538 out = out_begin + next_stride;
540 if (out > merge_border)
// Finally fold the user-supplied initial value into every result element with
// merge_fct.
552 if (options_t::has_initial_value)
556 result.data() + result.size(),
558 [&merge_fct, &options](
auto&& v)
560 return merge_fct(v, options.initial_value);
// const_value<T> members: a nullary functor that yields a constant of type T.
// It is the default INIT_FUNC of xreducer_functors.
574 using value_type = T;
576 constexpr const_value() =
default;
// Constructs the functor holding `t` (initializer not visible in this view).
578 constexpr const_value(T t)
583 constexpr T operator()()
const
// Same functor with the value type changed to NT.
589 using rebind_t = const_value<NT>;
592 const_value<NT> rebind()
const;
// evaluated_value_type<T, B>: yields the element type stored by a reducer.
// Callers pass B = is_xexpression<T>::value. The primary template's `type` is
// not visible here (presumably T itself).
599 template <
class T,
bool B>
600 struct evaluated_value_type
// When T is an expression, use the decayed type produced by xt::eval(T).
606 struct evaluated_value_type<T, true>
608 using type =
typename std::decay_t<decltype(xt::eval(std::declval<T>()))>;
611 template <
class T,
bool B>
612 using evaluated_value_type_t =
typename evaluated_value_type<T, B>::type;
// xreducer_functors: a std::tuple of (reduce, init, merge) functors.
// INIT_FUNC defaults to const_value<long int>; MERGE_FUNC defaults to the
// reduce functor type.
615 template <
class REDUCE_FUNC,
class INIT_FUNC = const_value<
long int>,
class MERGE_FUNC = REDUCE_FUNC>
616 struct xreducer_functors :
public std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>
618 using self_type = xreducer_functors<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
619 using base_type = std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
620 using reduce_functor_type = REDUCE_FUNC;
621 using init_functor_type = INIT_FUNC;
622 using merge_functor_type = MERGE_FUNC;
623 using init_value_type =
typename init_functor_type::value_type;
// Reduce functor only: init is a default-constructed INIT_FUNC, and merge
// reuses reduce_func.
// NOTE(review): reduce_func is std::forward-ed into slot 0 and also used to
// initialize slot 2. For an rvalue RF, the tuple converting constructor
// move-constructs slot 0 from it. Whether slot 2 is then copied from a
// moved-from functor depends on the implementation's element construction
// order. Passing reduce_func as an lvalue to both slots would avoid this for
// stateful functors.
631 xreducer_functors(RF&& reduce_func)
632 : base_type(std::forward<RF>(reduce_func), INIT_FUNC(), reduce_func)
// Reduce and init functors; merge reuses reduce_func.
// NOTE(review): same forward-then-reuse hazard as the single-argument constructor.
636 template <
class RF,
class IF>
637 xreducer_functors(RF&& reduce_func, IF&& init_func)
638 : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), reduce_func)
// All three functors supplied explicitly.
642 template <
class RF,
class IF,
class MF>
643 xreducer_functors(RF&& reduce_func, IF&& init_func, MF&& merge_func)
644 : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), std::forward<MF>(merge_func))
// Accessors returning copies of tuple slots 0 (reduce), 1 (init) and 2 (merge).
648 reduce_functor_type get_reduce()
const
650 return std::get<0>(upcast());
653 init_functor_type get_init()
const
655 return std::get<1>(upcast());
658 merge_functor_type get_merge()
const
660 return std::get<2>(upcast());
// Same functors with the init functor replaced by const_value<NT>.
664 using rebind_t = xreducer_functors<REDUCE_FUNC, const_value<NT>, MERGE_FUNC>;
667 rebind_t<NT> rebind()
669 return make_xreducer_functor(get_reduce(), get_init().
template rebind<NT>(), get_merge());
// View of *this as the underlying std::tuple, for std::get.
675 const base_type& upcast()
const
677 return static_cast<const base_type&
>(*this);
// make_xreducer_functor: factory overloads that build an xreducer_functors from
// one, two or three functors. Functor types are stored with references removed
// (std::remove_reference_t), so the bundle owns copies or moved values.
682 auto make_xreducer_functor(RF&& reduce_func)
685 return reducer_type(std::forward<RF>(reduce_func));
688 template <
class RF,
class IF>
689 auto make_xreducer_functor(RF&& reduce_func, IF&& init_func)
691 using reducer_type = xreducer_functors<std::remove_reference_t<RF>, std::remove_reference_t<IF>>;
692 return reducer_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func));
695 template <
class RF,
class IF,
class MF>
696 auto make_xreducer_functor(RF&& reduce_func, IF&& init_func, MF&& merge_func)
699 std::remove_reference_t<RF>,
700 std::remove_reference_t<IF>,
701 std::remove_reference_t<MF>>;
703 std::forward<RF>(reduce_func),
704 std::forward<IF>(init_func),
705 std::forward<MF>(merge_func)
// Templates over (Tag,) F, CT, X, O used to declare the reducer's extension base
// and forward declarations. Their names and bodies are mostly not visible in
// this view.
715 template <
class Tag,
class F,
class CT,
class X,
class O>
718 template <
class F,
class CT,
class X,
class O>
724 template <
class F,
class CT,
class X,
class O>
729 template <
class F,
class CT,
class X,
class O>
737 template <
class F,
class CT,
class X,
class O>
740 template <
class F,
class CT,
class X,
class O>
// Iterable traits: the inner shape type is computed from the expression's
// shape_type and O::keep_dims, and the mutable stepper aliases const_stepper
// (reducers are read-only).
743 template <
class F,
class CT,
class X,
class O>
746 using xexpression_type = std::decay_t<CT>;
748 typename xexpression_type::shape_type,
750 typename O::keep_dims>::type;
752 using stepper = const_stepper;
// Container inner types for a reducer over expression CT with functors F.
755 template <
class F,
class CT,
class X,
class O>
758 using xexpression_type = std::decay_t<CT>;
759 using reduce_functor_type =
typename std::decay_t<F>::reduce_functor_type;
760 using init_functor_type =
typename std::decay_t<F>::init_functor_type;
761 using merge_functor_type =
typename std::decay_t<F>::merge_functor_type;
762 using substepper_type =
typename xexpression_type::const_stepper;
// Value type = decayed result of reduce(init(), *substepper), evaluated when
// that result is itself an xexpression.
763 using raw_value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
764 std::declval<init_functor_type>()(),
765 *std::declval<substepper_type>()
767 using value_type =
typename detail::evaluated_value_type_t<raw_value_type, is_xexpression<raw_value_type>::value>;
// Elements are computed on access, so references are plain values.
769 using reference = value_type;
770 using const_reference = value_type;
771 using size_type =
typename xexpression_type::size_type;
// For a compile-time index pack I..., the mapping type is a std::array of
// sizeof...(I) std::size_t entries (enclosing template not visible).
780 template <std::size_t... I>
783 using type = std::array<std::size_t,
sizeof...(I)>;
// xreducer<F, CT, X, O>: lazy reduction expression. Elements are computed on
// access via xreducer_stepper. It holds copies of the reduce/init/merge
// functors, the computed result shape and a mapping from result dimensions to
// input dimensions.
803 template <
class F,
class CT,
class X,
class O>
804 class xreducer :
public xsharable_expression<xreducer<F, CT, X, O>>,
806 public xaccessible<xreducer<F, CT, X, O>>,
807 public extension::xreducer_base_t<F, CT, X, O>
814 using reduce_functor_type =
typename inner_types::reduce_functor_type;
815 using init_functor_type =
typename inner_types::init_functor_type;
816 using merge_functor_type =
typename inner_types::merge_functor_type;
819 using xexpression_type =
typename inner_types::xexpression_type;
822 using extension_base = extension::xreducer_base_t<F, CT, X, O>;
823 using expression_tag =
typename extension_base::expression_tag;
825 using substepper_type =
typename inner_types::substepper_type;
826 using value_type =
typename inner_types::value_type;
827 using reference =
typename inner_types::reference;
828 using const_reference =
typename inner_types::const_reference;
829 using pointer = value_type*;
830 using const_pointer =
const value_type*;
832 using size_type =
typename inner_types::size_type;
833 using difference_type =
typename xexpression_type::difference_type;
836 using inner_shape_type =
typename iterable_base::inner_shape_type;
837 using shape_type = inner_shape_type;
// Maps each result dimension to the corresponding input dimension.
839 using dim_mapping_type =
typename select_dim_mapping_type<inner_shape_type>::type;
841 using stepper =
typename iterable_base::stepper;
842 using const_stepper =
typename iterable_base::const_stepper;
843 using bool_load_type =
typename xexpression_type::bool_load_type;
// A reducer never exposes contiguous storage.
846 static constexpr bool contiguous_layout =
false;
848 template <
class Func,
class CTA,
class AX,
class OX>
849 xreducer(Func&& func, CTA&& e, AX&& axes, OX&& options);
851 const inner_shape_type&
shape() const noexcept;
853 bool is_contiguous() const noexcept;
855 template <class... Args>
856 const_reference operator()(Args... args) const;
857 template <class... Args>
858 const_reference unchecked(Args... args) const;
861 const_reference element(It first, It last) const;
872 const_stepper stepper_begin(const S&
shape) const noexcept;
// Same reducer type applied to another expression E, optionally with other
// functors / options.
876 template <class E, class Func = F, class Opts = O>
877 using rebind_t =
xreducer<Func, E, X, Opts>;
880 rebind_t<E> build_reducer(E&& e) const;
882 template <class E, class Func, class Opts>
883 rebind_t<E, Func, Opts> build_reducer(E&& e, Func&& func, Opts&& opts) const;
// Re-bundles the stored functors.
885 xreducer_functors_type functors()
const
887 return xreducer_functors_type(m_reduce, m_init, m_merge);
892 const O& options()
const
900 reduce_functor_type m_reduce;
901 init_functor_type m_init;
902 merge_functor_type m_merge;
904 inner_shape_type m_shape;
905 dim_mapping_type m_dim_mapping;
// The stepper reads the reducer's private members (functors, axes, options).
908 friend class xreducer_stepper<F, CT, X, O>;
// reduce_impl, lazy strategy: normalizes the axes and builds an xreducer
// expression. Its options type carries the evaluated element type as the
// initial-value type.
917 template <
class F,
class E,
class X,
class O>
918 inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::lazy_type, O&& options)
920 decltype(
auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
922 using reduce_functor_type =
typename std::decay_t<F>::reduce_functor_type;
923 using init_functor_type =
typename std::decay_t<F>::init_functor_type;
924 using value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
925 std::declval<init_functor_type>()(),
926 *std::declval<
typename std::decay_t<E>::const_stepper>()
928 using evaluated_value_type = evaluated_value_type_t<value_type, is_xexpression<value_type>::value>;
930 using reducer_type = xreducer<
933 xtl::const_closure_type_t<
decltype(normalized_axes)>,
934 reducer_options<evaluated_value_type, std::decay_t<O>>>;
938 std::forward<
decltype(normalized_axes)>(normalized_axes),
939 std::forward<O>(options)
// reduce_impl, immediate strategy: normalizes the axes, evaluates `e`, and
// delegates to reduce_immediate. The axes are normalized before `e` is forwarded.
943 template <
class F,
class E,
class X,
class O>
944 inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::immediate_type, O&& options)
946 decltype(
auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
947 return reduce_immediate(
949 eval(std::forward<E>(e)),
950 std::forward<
decltype(normalized_axes)>(normalized_axes),
951 std::forward<O>(options)
// NOTE(review): identical redefinition of DEFAULT_STRATEGY_REDUCERS (legal
// because the token sequence matches an earlier #define); kept as-is.
956#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
// is_xreducer_functors<T>: true iff std::decay_t<T> is an xreducer_functors
// instantiation. Used to dispatch between raw functors and functor bundles.
961 struct is_xreducer_functors_impl : std::false_type
965 template <
class RF,
class IF,
class MF>
966 struct is_xreducer_functors_impl<xreducer_functors<RF, IF, MF>> : std::true_type
971 using is_xreducer_functors = is_xreducer_functors_impl<std::decay_t<T>>;
// reduce(functors, e, axes, options): dispatches to detail::reduce_impl using
// the evaluation strategy found in the options.
// NOTE(review): this overload and the C-array overload use
// reducer_options<int, EVS> with EVS possibly a reference type, while the
// all-axes overload below uses std::decay_t<EVS>. The decayed form is the
// consistent choice.
991 class EVS = DEFAULT_STRATEGY_REDUCERS,
993 inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
995 return detail::reduce_impl(
998 std::forward<X>(axes),
999 typename reducer_options<int, EVS>::evaluation_strategy{},
1000 std::forward<EVS>(options)
// reduce(raw functor, e, axes, options): wraps a plain functor with
// make_xreducer_functor (X must not be an options pack).
1008 class EVS = DEFAULT_STRATEGY_REDUCERS,
1009 XTL_REQUIRES(std::negation<is_reducer_options<X>>, std::negation<detail::is_xreducer_functors<F>>)>
1010 inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
1013 make_xreducer_functor(std::forward<F>(f)),
1015 std::forward<X>(axes),
1016 std::forward<EVS>(options)
// reduce(functors, e, options): reduces over all axes 0..dimension()-1.
1023 class EVS = DEFAULT_STRATEGY_REDUCERS,
1025 inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
1027 xindex_type_t<typename std::decay_t<E>::shape_type> ar;
1028 resize_container(ar, e.dimension());
1029 std::iota(ar.begin(), ar.end(), 0);
1030 return detail::reduce_impl(
1034 typename reducer_options<
int, std::decay_t<EVS>>::evaluation_strategy{},
1035 std::forward<EVS>(options)
// reduce(raw functor, e, options): all-axes variant for a plain functor.
1042 class EVS = DEFAULT_STRATEGY_REDUCERS,
1044 inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
1046 return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), std::forward<EVS>(options));
// reduce(functors, e, {axes...}, options): axes given as a C array are
// normalized into a std::array<std::size_t, N>.
1054 class EVS = DEFAULT_STRATEGY_REDUCERS,
1055 XTL_REQUIRES(detail::is_xreducer_functors<F>)>
1056 inline auto reduce(F&& f, E&& e,
const I (&axes)[N], EVS options = EVS())
1058 using axes_type = std::array<std::size_t, N>;
1059 auto ax = xt::forward_normalize<axes_type>(e, axes);
1060 return detail::reduce_impl(
1064 typename reducer_options<int, EVS>::evaluation_strategy{},
// reduce(raw functor, e, {axes...}, options): wraps the functor and forwards.
1074 class EVS = DEFAULT_STRATEGY_REDUCERS,
1075 XTL_REQUIRES(std::negation<detail::is_xreducer_functors<F>>)>
1076 inline auto reduce(F&& f, E&& e,
const I (&axes)[N], EVS options = EVS())
1078 return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), axes, options);
// xreducer_stepper: stepper over a reducer. It drives a sub-stepper on the
// underlying expression and aggregates along the reduced axes when
// dereferenced. Dereferencing returns by value (reference == value_type).
1085 template <
class F,
class CT,
class X,
class O>
1086 class xreducer_stepper
1090 using self_type = xreducer_stepper<F, CT, X, O>;
1093 using value_type =
typename xreducer_type::value_type;
1094 using reference =
typename xreducer_type::value_type;
1095 using pointer =
typename xreducer_type::const_pointer;
1096 using size_type =
typename xreducer_type::size_type;
1097 using difference_type =
typename xreducer_type::difference_type;
1099 using xexpression_type =
typename xreducer_type::xexpression_type;
1100 using substepper_type =
typename xexpression_type::const_stepper;
1101 using shape_type =
typename xreducer_type::shape_type;
1104 const xreducer_type& red,
1107 layout_type l = default_assignable_layout(xexpression_type::static_layout)
1110 reference operator*()
const;
// Navigation over result dimensions; dimensions below the broadcast offset are
// ignored (see the definitions).
1112 void step(size_type dim);
1113 void step_back(size_type dim);
1114 void step(size_type dim, size_type n);
1115 void step_back(size_type dim, size_type n);
1116 void reset(size_type dim);
1117 void reset_back(size_type dim);
1124 reference initial_value()
const;
1125 reference aggregate(size_type dim)
const;
// Recursive aggregation, dispatched on O::keep_dims.
1126 reference aggregate_impl(size_type dim, std::false_type)
const;
1127 reference aggregate_impl(size_type dim, std::true_type)
const;
1129 substepper_type get_substepper_begin()
const;
1130 size_type get_dim(size_type dim)
const noexcept;
1131 size_type shape(size_type i)
const noexcept;
1132 size_type axis(size_type i)
const noexcept;
1134 const xreducer_type* m_reducer;
// Mutable because dereferencing (a const operation) walks the sub-stepper.
1136 mutable substepper_type m_stepper;
// in<X, I...>: true iff X equals one of I... (compile-time membership test).
1145 template <std::size_t X, std::size_t... I>
1148 static constexpr bool value = std::disjunction<std::integral_constant<bool, X == I>...>::value;
// fixed_xreducer_shape_type_impl<Z, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>:
// builds the reduced fixed shape by walking dimension Z downwards. Extents
// I_Z whose index is presumably among the reduced axes J... are skipped; the
// others are prepended to the accumulator R... (selection condition for Z > 0
// not visible).
1151 template <std::
size_t Z,
class S1,
class S2,
class R>
1152 struct fixed_xreducer_shape_type_impl;
1154 template <std::size_t Z, std::size_t... I, std::size_t... J, std::size_t... R>
1155 struct fixed_xreducer_shape_type_impl<Z, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
1157 using type = std::conditional_t<
1159 typename fixed_xreducer_shape_type_impl<Z - 1, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>::type,
1160 typename fixed_xreducer_shape_type_impl<
1164 fixed_shape<detail::at<Z, I...>::value, R...>>::type>;
// Base case Z == 0: drop dimension 0 if it is reduced, otherwise prepend I_0.
1167 template <std::size_t... I, std::size_t... J, std::size_t... R>
1168 struct fixed_xreducer_shape_type_impl<0, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
1171 conditional_t<in<0, J...>::value, fixed_shape<R...>, fixed_shape<detail::at<0, I...>::value, R...>>;
// Size type used by reducers: std::size_t in the form visible here.
1179 struct xreducer_size_type
1181 using type = std::size_t;
1185 using xreducer_size_type_t =
typename xreducer_size_type<T>::type;
// Temporary container type for reducer results (definition not visible).
1188 struct xreducer_temporary_type
1194 using xreducer_temporary_type_t =
typename xreducer_temporary_type<T>::type;
// const_value_rebinder<T, U>::run: converts a const_value<T> into a
// const_value<U> holding the same stored value.
1200 template <
class T,
class U>
1201 struct const_value_rebinder
1203 static const_value<U> run(
const const_value<T>& t)
1205 return const_value<U>(t.m_value);
// Body of const_value<T>::rebind<NT>(): delegates to const_value_rebinder.
1218 return detail::const_value_rebinder<T, NT>::run(*
this);
// Shape-type metafunctions for reducer results. Specialization patterns are
// mostly not visible in this view; the comments below describe only the
// visible `type` definitions.
1225 template <
class S1,
class S2>
1228 template <std::size_t... I, std::size_t... J>
1231 using type =
typename detail::
// General case: the result shape type promotes ST with the decayed axes type X.
1236 template <
class ST,
class X,
class O>
1239 using type = promote_shape_t<ST, std::decay_t<X>>;
// Array input with N1 dims: same-rank result (presumably the keep_dims case).
1242 template <
class I1, std::
size_t N1,
class I2, std::
size_t N2>
1245 using type = std::array<I2, N1>;
// Array input with N1 dims and N2 reduced axes: rank N1 - N2.
1248 template <
class I1, std::
size_t N1,
class I2, std::
size_t N2>
1251 using type = std::array<I2, N1 - N2>;
// Fixed input shape: a full reduction gives fixed_shape<>, otherwise an array
// of rank sizeof...(I) - N2.
1254 template <std::size_t... I,
class I2, std::size_t N2>
1257 using type = std::conditional_t<
sizeof...(I) == N2,
fixed_shape<>, std::array<I2,
sizeof...(I) - N2>>;
// ixconcat: concatenates two std::integer_sequence packs.
1262 template <
class S1,
class S2>
1265 template <
class T, T... I1, T... I2>
1266 struct ixconcat<std::integer_sequence<T, I1...>, std::integer_sequence<T, I2...>>
1268 using type = std::integer_sequence<T, I1..., I2...>;
// repeat_integer_sequence<T, X, N>: integer_sequence with X repeated N times.
// Recurses by prepending one X; N = 0, 1, 2 are specialized directly.
1271 template <
class T, T X, std::
size_t N>
1272 struct repeat_integer_sequence
1274 using type =
typename ixconcat<
1275 std::integer_sequence<T, X>,
1276 typename repeat_integer_sequence<T, X, N - 1>::type>::type;
1279 template <
class T, T X>
1280 struct repeat_integer_sequence<T, X, 0>
1282 using type = std::integer_sequence<T>;
1285 template <
class T, T X>
1286 struct repeat_integer_sequence<T, X, 2>
1288 using type = std::integer_sequence<T, X, X>;
1291 template <
class T, T X>
1292 struct repeat_integer_sequence<T, X, 1>
1294 using type = std::integer_sequence<T, X>;
// Fixed input shape, same-rank variant: one alternative is derived from a
// sequence of N2 ones (via get_type); the other is std::array<I2, sizeof...(I)>.
// The selecting condition is not visible.
1298 template <std::size_t... I,
class I2, std::size_t N2>
1301 template <std::size_t... X>
1302 static constexpr auto get_type(std::index_sequence<X...>)
1308 using type = std::conditional_t<
1310 decltype(get_type(
typename detail::repeat_integer_sequence<std::size_t, std::size_t(1), N2>::type{})),
1311 std::array<I2,
sizeof...(I)>>;
1315 template <std::size_t... I, std::size_t... J,
class O>
// shape_and_mapping_computation (dynamic shape, axes dropped): copies the
// extents of non-reduced axes of e into `shape`, and records in `mapping` the
// input index of each result dimension. `axes` is walked in order, so it must
// be sorted.
1323 template <
class S,
class E,
class X,
class M>
1324 inline void shape_and_mapping_computation(S& shape, E& e,
const X& axes, M& mapping, std::false_type)
1326 auto first = e.shape().begin();
1327 auto last = e.shape().end();
1328 auto exclude_it = axes.begin();
1330 using value_type =
typename S::value_type;
1331 using difference_type =
typename S::difference_type;
1332 auto d_first = shape.begin();
1333 auto map_first = mapping.begin();
1336 while (iter != last && exclude_it != axes.end())
1338 auto diff = std::distance(first, iter);
1339 if (
diff != difference_type(*exclude_it))
1341 *d_first++ = *iter++;
1342 *map_first++ = value_type(
diff);
// Past the last reduced axis: the remaining dimensions map one-to-one.
1351 auto diff = std::distance(first, iter);
1352 auto end = std::distance(iter, last);
1353 std::iota(map_first, map_first + end,
diff);
1354 std::copy(iter, last, d_first);
// keep_dims variant (dynamic shape): non-reduced axes copy e's extent (the
// value written for reduced axes is not visible here); mapping is the identity.
1357 template <
class S,
class E,
class X,
class M>
1359 shape_and_mapping_computation_keep_dim(S& shape, E& e,
const X& axes, M& mapping, std::false_type)
1361 for (std::size_t i = 0; i < e.dimension(); ++i)
1363 if (std::find(axes.cbegin(), axes.cend(), i) == axes.cend())
1366 shape[i] = e.shape()[i];
1373 std::iota(mapping.begin(), mapping.end(), 0);
// Fixed-shape overloads: all parameters are unnamed, so no run-time work is
// done on them (bodies not visible).
1376 template <
class S,
class E,
class X,
class M>
1377 inline void shape_and_mapping_computation(S&, E&,
const X&, M&, std::true_type)
1381 template <
class S,
class E,
class X,
class M>
1382 inline void shape_and_mapping_computation_keep_dim(S&, E&,
const X&, M&, std::true_type)
// xreducer constructor: stores the expression and the reduce/init/merge
// functors (slots 0/1/2 of `func`), the axes and the options. It then validates
// the axes and computes the result shape and dimension mapping.
1404 template <
class F,
class CT,
class X,
class O>
1405 template <
class Func,
class CTA,
class AX,
class OX>
1407 : m_e(std::forward<CTA>(e))
1408 , m_reduce(
xt::get<0>(func))
1409 , m_init(
xt::get<1>(func))
1410 , m_merge(
xt::get<2>(func))
1411 , m_axes(std::forward<AX>(axes))
1412 , m_shape(xtl::make_sequence<inner_shape_type>(
1416 , m_dim_mapping(xtl::make_sequence<dim_mapping_type>(
1420 , m_options(std::forward<OX>(options))
// Axes must be sorted, unique and in bounds. The stepper later relies on
// sortedness for std::binary_search.
1426 if (!std::is_sorted(m_axes.cbegin(), m_axes.cend(), std::less<>()))
1428 XTENSOR_THROW(std::runtime_error,
"Reducing axes should be sorted.");
1430 if (std::adjacent_find(m_axes.cbegin(), m_axes.cend()) != m_axes.cend())
1432 XTENSOR_THROW(std::runtime_error,
"Reducing axes should not contain duplicates.");
// NOTE(review): for an unsigned dimension() of 0, `m_e.dimension() - 1` wraps
// and this check cannot reject axes of a 0-d expression. Confirm whether that
// input is possible.
1434 if (m_axes.size() != 0 && m_axes[m_axes.size() - 1] > m_e.dimension() - 1)
1438 "Axis " + std::to_string(m_axes[m_axes.size() - 1]) +
" out of bounds for reduction."
// Compute shape and mapping; the is_fixed tag selects the no-op overloads for
// fixed shapes.
1442 if (!
typename O::keep_dims())
1444 detail::shape_and_mapping_computation(
1449 detail::is_fixed<shape_type>{}
1454 detail::shape_and_mapping_computation_keep_dim(
1459 detail::is_fixed<shape_type>{}
// Out-of-line xreducer member definitions (some signatures are not visible in
// this view).
1473 template <
class F,
class CT,
class X,
class O>
// Layout accessor: reports the class's static_layout.
1482 template <
class F,
class CT,
class X,
class O>
1485 return static_layout;
1488 template <
class F,
class CT,
class X,
class O>
1489 inline bool xreducer<F, CT, X, O>::is_contiguous() const noexcept
// operator(): bounds-checks the indices (under XTENSOR_TRY /
// XTENSOR_CHECK_DIMENSION), converts them to std::size_t and forwards to element().
1506 template <
class F,
class CT,
class X,
class O>
1507 template <
class... Args>
1508 inline auto xreducer<F, CT, X, O>::operator()(Args... args)
const -> const_reference
1510 XTENSOR_TRY(check_index(
shape(), args...));
1511 XTENSOR_CHECK_DIMENSION(
shape(), args...);
1512 std::array<std::size_t,
sizeof...(Args)> arg_array = {{
static_cast<std::size_t
>(args)...}};
1513 return element(arg_array.cbegin(), arg_array.cend());
// unchecked: like operator() but without the index checks.
1535 template <
class F,
class CT,
class X,
class O>
1536 template <
class... Args>
1537 inline auto xreducer<F, CT, X, O>::unchecked(Args... args)
const -> const_reference
1539 std::array<std::size_t,
sizeof...(Args)> arg_array = {{
static_cast<std::size_t
>(args)...}};
1540 return element(arg_array.cbegin(), arg_array.cend());
// element: positions a fresh stepper at the index range [first, last) and
// dereferences it. When fewer indices than dimensions are given, the missing
// leading dimensions are stepped by 0.
1550 template <
class F,
class CT,
class X,
class O>
1552 inline auto xreducer<F, CT, X, O>::element(It first, It last)
const -> const_reference
1554 XTENSOR_TRY(check_element_index(
shape(), first, last));
1555 auto stepper = const_stepper(*
this, 0);
// NOTE(review): when size > 0, `first - size` forms an iterator before
// `first`. That is formally undefined for pointer/array iterators even if it
// is never dereferenced.
1560 auto size = std::ptrdiff_t(this->
dimension()) - std::distance(first, last);
1561 auto begin = first -
size;
1562 while (begin != last)
1566 stepper.step(dim++, std::size_t(0));
1571 stepper.step(dim++, std::size_t(*begin++));
1581 template <
class F,
class CT,
class X,
class O>
// Broadcasting: merges the reducer's shape m_shape into the target `shape`.
1599 template <
class F,
class CT,
class X,
class O>
1603 return xt::broadcast_shape(m_shape,
shape);
1611 template <
class F,
class CT,
class X,
class O>
// stepper_begin / stepper_end: the offset is the number of leading broadcast
// dimensions in `shape` beyond the reducer's own dimension.
1620 template <
class F,
class CT,
class X,
class O>
1622 inline auto xreducer<F, CT, X, O>::stepper_begin(
const S& shape)
const noexcept -> const_stepper
1624 size_type offset = shape.size() - this->dimension();
1625 return const_stepper(*
this, offset);
1628 template <
class F,
class CT,
class X,
class O>
1630 inline auto xreducer<F, CT, X, O>::stepper_end(
const S& shape,
layout_type l)
const noexcept
1633 size_type offset = shape.size() - this->dimension();
1634 return const_stepper(*
this, offset,
true, l);
// build_reducer(e): same functors (re-bundled with std::make_tuple) applied to
// another expression.
1637 template <
class F,
class CT,
class X,
class O>
1639 inline auto xreducer<F, CT, X, O>::build_reducer(E&& e)
const -> rebind_t<E>
1642 std::make_tuple(m_reduce, m_init, m_merge),
// build_reducer(e, func, opts): rebinds expression, functors and options.
1649 template <
class F,
class CT,
class X,
class O>
1650 template <
class E,
class Func,
class Opts>
1651 inline auto xreducer<F, CT, X, O>::build_reducer(E&& e, Func&& func, Opts&& opts)
const
1652 -> rebind_t<E, Func, Opts>
1654 return rebind_t<E, Func, Opts>(
1655 std::forward<Func>(func),
1658 std::forward<Opts>(opts)
// xreducer_stepper constructor: binds to the reducer `red` and starts the
// sub-stepper at get_substepper_begin() (other initializers not visible).
1666 template <
class F,
class CT,
class X,
class O>
1667 inline xreducer_stepper<F, CT, X, O>::xreducer_stepper(
1668 const xreducer_type& red,
1675 , m_stepper(get_substepper_begin())
// Dereference: computes the reduced value by aggregating from reduction level 0.
1683 template <
class F,
class CT,
class X,
class O>
1684 inline auto xreducer_stepper<F, CT, X, O>::operator*() const -> reference
1686 reference r = aggregate(0);
// step / step_back (by one or by n): dimensions below m_offset are broadcast
// dimensions and are ignored. Otherwise the result dimension is mapped to the
// input dimension via get_dim and forwarded to the sub-stepper.
1690 template <
class F,
class CT,
class X,
class O>
1691 inline void xreducer_stepper<F, CT, X, O>::step(size_type dim)
1693 if (dim >= m_offset)
1695 m_stepper.step(get_dim(dim - m_offset));
1699 template <
class F,
class CT,
class X,
class O>
1700 inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim)
1702 if (dim >= m_offset)
1704 m_stepper.step_back(get_dim(dim - m_offset));
1708 template <
class F,
class CT,
class X,
class O>
1709 inline void xreducer_stepper<F, CT, X, O>::step(size_type dim, size_type n)
1711 if (dim >= m_offset)
1713 m_stepper.step(get_dim(dim - m_offset), n);
1717 template <
class F,
class CT,
class X,
class O>
1718 inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim, size_type n)
1720 if (dim >= m_offset)
1722 m_stepper.step_back(get_dim(dim - m_offset), n);
// reset / reset_back: with keep_dims, a dimension that is itself a reduced axis
// is looked up with std::binary_search (valid because the constructor requires
// sorted axes) and handled separately (branch body not visible). Otherwise the
// mapped sub-stepper dimension is reset.
1726 template <
class F,
class CT,
class X,
class O>
1727 inline void xreducer_stepper<F, CT, X, O>::reset(size_type dim)
1729 if (dim >= m_offset)
1733 if (
typename O::keep_dims()
1734 && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
1739 m_stepper.reset(get_dim(dim - m_offset));
1743 template <
class F,
class CT,
class X,
class O>
1744 inline void xreducer_stepper<F, CT, X, O>::reset_back(size_type dim)
1746 if (dim >= m_offset)
1749 if (
typename O::keep_dims()
1750 && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
1755 m_stepper.reset_back(get_dim(dim - m_offset));
// to_begin / to_end: forwarded directly to the sub-stepper.
1759 template <
class F,
class CT,
class X,
class O>
1760 inline void xreducer_stepper<F, CT, X, O>::to_begin()
1762 m_stepper.to_begin();
1765 template <
class F,
class CT,
class X,
class O>
1766 inline void xreducer_stepper<F, CT, X, O>::to_end(
layout_type l)
1768 m_stepper.to_end(l);
1771 template <
class F,
class CT,
class X,
class O>
1772 inline auto xreducer_stepper<F, CT, X, O>::initial_value() const -> reference
1774 return O::has_initial_value ? m_reducer->m_options.initial_value
1775 :
static_cast<reference
>(m_reducer->m_init());
1778 template <
class F,
class CT,
class X,
class O>
1779 inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim)
const -> reference
1782 if (m_reducer->m_e.size() == size_type(0))
1784 res = initial_value();
1786 else if (m_reducer->m_e.shape().empty() || m_reducer->m_axes.size() == 0)
1788 res = m_reducer->m_reduce(initial_value(), *m_stepper);
1792 res = aggregate_impl(dim,
typename O::keep_dims());
1793 if (O::has_initial_value && dim == 0)
1795 res = m_reducer->m_merge(m_reducer->m_options.initial_value, res);
1801 template <
class F,
class CT,
class X,
class O>
1802 inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::false_type)
const -> reference
1806 size_type index = axis(dim);
1807 size_type size = shape(index);
1808 if (dim != m_reducer->m_axes.size() - 1)
1810 res = aggregate_impl(dim + 1,
typename O::keep_dims());
1811 for (size_type i = 1; i != size; ++i)
1813 m_stepper.step(index);
1814 res = m_reducer->m_merge(res, aggregate_impl(dim + 1,
typename O::keep_dims()));
1819 res = m_reducer->m_reduce(
static_cast<reference
>(m_reducer->m_init()), *m_stepper);
1820 for (size_type i = 1; i != size; ++i)
1822 m_stepper.step(index);
1823 res = m_reducer->m_reduce(res, *m_stepper);
1826 m_stepper.reset(index);
1830 template <
class F,
class CT,
class X,
class O>
1831 inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::true_type)
const -> reference
1835 auto ax_it = std::find(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim);
1836 if (ax_it != m_reducer->m_axes.end())
1838 size_type index = dim;
1839 size_type size = m_reducer->m_e.shape()[index];
1840 if (ax_it != m_reducer->m_axes.end() - 1 && size != 0)
1842 res = aggregate_impl(dim + 1,
typename O::keep_dims());
1843 for (size_type i = 1; i != size; ++i)
1845 m_stepper.step(index);
1846 res = m_reducer->m_merge(res, aggregate_impl(dim + 1,
typename O::keep_dims()));
1851 res = m_reducer->m_reduce(
static_cast<reference
>(m_reducer->m_init()), *m_stepper);
1852 for (size_type i = 1; i != size; ++i)
1854 m_stepper.step(index);
1855 res = m_reducer->m_reduce(res, *m_stepper);
1858 m_stepper.reset(index);
1862 if (dim < m_reducer->m_e.dimension())
1864 res = aggregate_impl(dim + 1,
typename O::keep_dims());
1870 template <
class F,
class CT,
class X,
class O>
1871 inline auto xreducer_stepper<F, CT, X, O>::get_substepper_begin() const -> substepper_type
1873 return m_reducer->m_e.stepper_begin(m_reducer->m_e.shape());
1876 template <
class F,
class CT,
class X,
class O>
1877 inline auto xreducer_stepper<F, CT, X, O>::get_dim(size_type dim)
const noexcept -> size_type
1879 return m_reducer->m_dim_mapping[dim];
1882 template <
class F,
class CT,
class X,
class O>
1883 inline auto xreducer_stepper<F, CT, X, O>::shape(size_type i)
const noexcept -> size_type
1885 return m_reducer->m_e.shape()[i];
1888 template <
class F,
class CT,
class X,
class O>
1889 inline auto xreducer_stepper<F, CT, X, O>::axis(size_type i)
const noexcept -> size_type
1891 return m_reducer->m_axes[i];
Fixed shape implementation for compile time defined arrays.
size_type size() const noexcept(noexcept(derived_cast().shape()))
size_type dimension() const noexcept
Base class for multidimensional iterable constant expressions.
auto end() const noexcept -> const_layout_iterator< L >
Reducing function operating over specified axes.
bool broadcast_shape(S &shape, bool reuse_cache=false) const
const xexpression_type & expression() const noexcept
bool has_linear_assign(const S &strides) const noexcept
const inner_shape_type & shape() const noexcept
Returns the shape of the expression.
xreducer(Func &&func, CTA &&e, AX &&axes, OX &&options)
Constructs an xreducer expression applying the specified function to the given expression over the gi...
layout_type layout() const noexcept
auto operator|(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::bitwise_or, E1, E2 >
Bitwise or.
auto diff(const xexpression< T > &a, std::size_t n=1, std::ptrdiff_t axis=-1)
Calculate the n-th discrete difference along the given axis.
auto eval(T &&t) -> std::enable_if_t< detail::is_container< std::decay_t< T > >::value, T && >
Force evaluation of xexpression.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
standard mathematical functions for xexpressions
auto reduce(F &&f, E &&e, X &&axes, EVS &&options=EVS())
Returns an xexpression applying the specified reducing function to an expression over the given axes.