10#ifndef XTENSOR_REDUCER_HPP
11#define XTENSOR_REDUCER_HPP
15#include <initializer_list>
22#include <xtl/xfunctional.hpp>
23#include <xtl/xsequence.hpp>
25#include "xaccessible.hpp"
26#include "xbuilder.hpp"
28#include "xexpression.hpp"
29#include "xgenerator.hpp"
30#include "xiterable.hpp"
31#include "xtensor_config.hpp"
// Joins two option packs with `|`: the result is std::tuple_cat(args, rhs),
// i.e. every element of the left-hand pack followed by the element(s) of the
// right-hand pack.
// NOTE(review): the template header declaring A, AX and X is not visible in
// this listing - confirm the constraints placed on them.
37 auto operator|(
const A<AX...>& args,
const A<X>& rhs)
39 return std::tuple_cat(args, rhs);
46 constexpr auto keep_dims = std::tuple<keep_dims_type>{};
48 template <
class T =
double>
49 struct xinitial : xt::detail::option_base
51 constexpr xinitial(T val)
56 constexpr T value()
const
65 constexpr auto initial(T val)
70 template <std::ptrdiff_t I,
class T,
class Tuple>
73 template <std::ptrdiff_t I,
class T>
76 static constexpr std::ptrdiff_t value = -1;
79 template <std::ptrdiff_t I,
class T,
class... Types>
82 static constexpr std::ptrdiff_t value = I;
85 template <std::ptrdiff_t I,
class T,
class U,
class... Types>
88 static constexpr std::ptrdiff_t value =
tuple_idx_of_impl<I + 1, T, std::tuple<Types...>>::value;
91 template <
class S,
class... X>
94 template <
template <
class...>
class S,
class... X>
97 using type = S<std::decay_t<X>...>;
100 template <
class T,
class Tuple>
103 static constexpr std::ptrdiff_t
107 template <
class R,
class T>
108 struct reducer_options
126 using d_t = std::decay_t<T>;
128 static constexpr std::size_t initial_val_idx = xtl::mpl::find_if<initial_tester, d_t>::value;
129 reducer_options() =
default;
131 reducer_options(
const T& tpl)
133 xtl::mpl::static_if<initial_val_idx != std::tuple_size<T>::value>(
134 [
this, &tpl](
auto no_compile)
137 this->initial_value = no_compile(
138 std::get < initial_val_idx != std::tuple_size<T>::value
148 using evaluation_strategy = std::conditional_t<
149 tuple_idx_of<xt::evaluation_strategy::immediate_type, d_t>::value != -1,
150 xt::evaluation_strategy::immediate_type,
151 xt::evaluation_strategy::lazy_type>;
153 using keep_dims = std::
154 conditional_t<tuple_idx_of<xt::keep_dims_type, d_t>::value != -1, std::true_type, std::false_type>;
156 static constexpr bool has_initial_value = initial_val_idx != std::tuple_size<d_t>::value;
161 using rebind_t = reducer_options<NR, T>;
164 auto rebind(NR initial,
const reducer_options<R, T>&)
const
166 reducer_options<NR, T> res;
167 res.initial_value = initial;
177 template <
class... X>
191#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
193 template <
class ST,
class X,
class KD = std::false_type>
196 template <
class S1,
class S2>
201 template <
class O,
class RS,
class R,
class E,
class AX>
202 inline void shape_computation(
207 std::enable_if_t<!detail::is_fixed<RS>::value,
int> = 0
210 if (
typename O::keep_dims())
212 resize_container(result_shape, expr.dimension());
213 for (std::size_t i = 0; i < expr.dimension(); ++i)
215 if (std::find(axes.begin(), axes.end(), i) == axes.end())
218 result_shape[i] = expr.shape()[i];
228 resize_container(result_shape, expr.dimension() - axes.size());
229 for (std::size_t i = 0, idx = 0; i < expr.dimension(); ++i)
231 if (std::find(axes.begin(), axes.end(), i) == axes.end())
234 result_shape[idx] = expr.shape()[i];
239 result.resize(result_shape, expr.layout());
243 template <
class O,
class RS,
class R,
class S,
class AX>
245 shape_computation(RS&, R&,
const S&,
const AX&, std::enable_if_t<detail::is_fixed<RS>::value,
int> = 0)
250 template <
class F,
class E,
class R, XTL_REQUIRES(std::is_convertible<
typename E::value_type,
typename R::value_type>)>
251 inline void copy_to_reduced(F&,
const E& e, R& result)
256 e.template cbegin<layout_type::row_major>(),
257 e.template cend<layout_type::row_major>(),
264 e.template cbegin<layout_type::column_major>(),
265 e.template cend<layout_type::column_major>(),
275 XTL_REQUIRES(xtl::negation<std::is_convertible<typename E::value_type, typename R::value_type>>)>
276 inline void copy_to_reduced(F& f,
const E& e, R& result)
281 e.template cbegin<layout_type::row_major>(),
282 e.template cend<layout_type::row_major>(),
290 e.template cbegin<layout_type::column_major>(),
291 e.template cend<layout_type::column_major>(),
298 template <
class F,
class E,
class X,
class O>
299 inline auto reduce_immediate(F&& f, E&& e, X&& axes, O&& raw_options)
301 using reduce_functor_type =
typename std::decay_t<F>::reduce_functor_type;
302 using init_functor_type =
typename std::decay_t<F>::init_functor_type;
303 using expr_value_type =
typename std::decay_t<E>::value_type;
304 using result_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
305 std::declval<init_functor_type>()(),
306 std::declval<expr_value_type>()
310 options_t options(raw_options);
313 typename std::decay_t<E>::shape_type,
315 typename options_t::keep_dims>::type;
316 using result_container_type =
typename detail::xtype_for_shape<
317 shape_type>::template type<result_type, std::decay_t<E>::static_layout>;
318 result_container_type result;
321 auto reduce_fct = xt::get<0>(f);
322 auto init_fct = xt::get<1>(f);
323 auto merge_fct = xt::get<2>(f);
325 if (axes.size() == 0)
327 result.resize(e.shape(), e.layout());
328 auto cpf = [&reduce_fct, &init_fct](
const auto& v)
330 return reduce_fct(
static_cast<result_type
>(init_fct()), v);
332 copy_to_reduced(cpf, e, result);
336 shape_type result_shape{};
337 dynamic_shape<std::size_t>
338 iter_shape = xtl::forward_sequence<dynamic_shape<std::size_t>,
decltype(e.shape())>(e.shape());
339 dynamic_shape<std::size_t> iter_strides(e.dimension());
345 if (!std::is_sorted(axes.cbegin(), axes.cend(), std::less<>()))
347 XTENSOR_THROW(std::runtime_error,
"Reducing axes should be sorted.");
349 if (std::adjacent_find(axes.cbegin(), axes.cend()) != axes.cend())
351 XTENSOR_THROW(std::runtime_error,
"Reducing axes should not contain duplicates.");
353 if (axes.size() != 0 && axes[axes.size() - 1] > e.dimension() - 1)
357 "Axis " + std::to_string(axes[axes.size() - 1]) +
" out of bounds for reduction."
361 detail::shape_computation<options_t>(result_shape, result, e, axes);
364 if (e.dimension() == axes.size())
366 result_type tmp = options_t::has_initial_value ? options.initial_value : init_fct();
367 result.data()[0] = std::accumulate(e.storage().begin(), e.storage().end(), tmp, reduce_fct);
372 auto strides_finder = e.strides().begin() +
static_cast<std::ptrdiff_t
>(leading_ax);
374 std::size_t inner_stride =
static_cast<std::size_t
>(*strides_finder);
376 while (inner_stride == 0 && strides_finder != iter_bound)
379 inner_stride =
static_cast<std::size_t
>(*strides_finder);
382 if (inner_stride == 0)
384 auto cpf = [&reduce_fct, &init_fct](
const auto& v)
386 return reduce_fct(
static_cast<result_type
>(init_fct()), v);
388 copy_to_reduced(cpf, e, result);
392 std::size_t inner_loop_size =
static_cast<std::size_t
>(inner_stride);
393 std::size_t outer_loop_size = e.shape()[leading_ax];
397 auto merge_loops = [&outer_loop_size, &e](
auto it,
auto end)
401 for (; it != end; ++it)
404 if (std::abs(std::ptrdiff_t(*it) - std::ptrdiff_t(last_ax)) == 1)
407 outer_loop_size *= e.shape()[last_ax];
413 for (std::size_t i = 0, idx = 0; i < e.dimension(); ++i)
415 if (std::find(axes.begin(), axes.end(), i) == axes.end())
418 iter_strides[i] =
static_cast<std::size_t
>(result.strides(
419 )[
typename options_t::keep_dims() ? i : idx]);
426 std::size_t last_ax = merge_loops(axes.rbegin(), axes.rend());
428 iter_shape.erase(iter_shape.begin() + std::ptrdiff_t(last_ax), iter_shape.end());
429 iter_strides.erase(iter_strides.begin() + std::ptrdiff_t(last_ax), iter_strides.end());
434 std::size_t last_ax = merge_loops(axes.begin(), axes.end());
437 iter_shape.erase(iter_shape.begin(), iter_shape.begin() + std::ptrdiff_t(last_ax + 1));
438 iter_strides.erase(iter_strides.begin(), iter_strides.begin() + std::ptrdiff_t(last_ax + 1));
441 std::reverse(iter_shape.begin(), iter_shape.end());
442 std::reverse(iter_strides.begin(), iter_strides.end());
446 XTENSOR_THROW(std::runtime_error,
"Layout not supported in immediate reduction.");
449 xindex temp_idx(iter_shape.size());
450 auto next_idx = [&iter_shape, &iter_strides, &temp_idx]()
452 std::size_t i = iter_shape.size();
455 if (std::ptrdiff_t(temp_idx[i - 1]) >= std::ptrdiff_t(iter_shape[i - 1]) - 1)
466 return std::make_pair(
468 std::inner_product(temp_idx.begin(), temp_idx.end(), iter_strides.begin(), std::ptrdiff_t(0))
472 auto begin = e.data();
473 auto out = result.data();
474 auto out_begin = result.data();
476 std::ptrdiff_t next_stride = 0;
478 std::pair<bool, std::ptrdiff_t> idx_res(
false, 0);
483 auto merge_border = out;
491 if (inner_stride == 1)
493 while (idx_res.first !=
true)
497 result_type tmp = init_fct();
498 tmp = std::accumulate(begin, begin + outer_loop_size, tmp, reduce_fct);
501 *out = merge ? merge_fct(*out, tmp) : tmp;
503 begin += outer_loop_size;
505 idx_res = next_idx();
506 next_stride = idx_res.second;
507 out = out_begin + next_stride;
509 if (out > merge_border)
523 while (idx_res.first !=
true)
527 out + inner_loop_size,
530 [merge, &init_fct, &reduce_fct](
auto&& v1,
auto&& v2)
532 return merge ? reduce_fct(v1, v2) :
534 reduce_fct(static_cast<result_type>(init_fct()), v2);
538 begin += inner_stride;
539 for (std::size_t i = 1; i < outer_loop_size; ++i)
541 std::transform(out, out + inner_loop_size, begin, out, reduce_fct);
542 begin += inner_stride;
545 idx_res = next_idx();
546 next_stride = idx_res.second;
547 out = out_begin + next_stride;
549 if (out > merge_border)
561 if (options_t::has_initial_value)
565 result.data() + result.size(),
567 [&merge_fct, &options](
auto&& v)
569 return merge_fct(v, options.initial_value);
583 using value_type = T;
585 constexpr const_value() =
default;
587 constexpr const_value(T t)
592 constexpr T operator()()
const
598 using rebind_t = const_value<NT>;
601 const_value<NT> rebind()
const;
608 template <
class T,
bool B>
609 struct evaluated_value_type
615 struct evaluated_value_type<T, true>
617 using type =
typename std::decay_t<decltype(xt::eval(std::declval<T>()))>;
620 template <
class T,
bool B>
621 using evaluated_value_type_t =
typename evaluated_value_type<T, B>::type;
624 template <
class REDUCE_FUNC,
class INIT_FUNC = const_value<
long int>,
class MERGE_FUNC = REDUCE_FUNC>
625 struct xreducer_functors :
public std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>
627 using self_type = xreducer_functors<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
628 using base_type = std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
629 using reduce_functor_type = REDUCE_FUNC;
630 using init_functor_type = INIT_FUNC;
631 using merge_functor_type = MERGE_FUNC;
632 using init_value_type =
typename init_functor_type::value_type;
640 xreducer_functors(RF&& reduce_func)
641 : base_type(std::forward<RF>(reduce_func), INIT_FUNC(), reduce_func)
645 template <
class RF,
class IF>
646 xreducer_functors(RF&& reduce_func, IF&& init_func)
647 : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), reduce_func)
651 template <
class RF,
class IF,
class MF>
652 xreducer_functors(RF&& reduce_func, IF&& init_func, MF&& merge_func)
653 : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), std::forward<MF>(merge_func))
657 reduce_functor_type get_reduce()
const
659 return std::get<0>(upcast());
662 init_functor_type get_init()
const
664 return std::get<1>(upcast());
667 merge_functor_type get_merge()
const
669 return std::get<2>(upcast());
673 using rebind_t = xreducer_functors<REDUCE_FUNC, const_value<NT>, MERGE_FUNC>;
676 rebind_t<NT> rebind()
678 return make_xreducer_functor(get_reduce(), get_init().
template rebind<NT>(), get_merge());
684 const base_type& upcast()
const
686 return static_cast<const base_type&
>(*this);
691 auto make_xreducer_functor(RF&& reduce_func)
694 return reducer_type(std::forward<RF>(reduce_func));
697 template <
class RF,
class IF>
698 auto make_xreducer_functor(RF&& reduce_func, IF&& init_func)
700 using reducer_type = xreducer_functors<std::remove_reference_t<RF>, std::remove_reference_t<IF>>;
701 return reducer_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func));
704 template <
class RF,
class IF,
class MF>
705 auto make_xreducer_functor(RF&& reduce_func, IF&& init_func, MF&& merge_func)
708 std::remove_reference_t<RF>,
709 std::remove_reference_t<IF>,
710 std::remove_reference_t<MF>>;
712 std::forward<RF>(reduce_func),
713 std::forward<IF>(init_func),
714 std::forward<MF>(merge_func)
724 template <
class Tag,
class F,
class CT,
class X,
class O>
727 template <
class F,
class CT,
class X,
class O>
733 template <
class F,
class CT,
class X,
class O>
738 template <
class F,
class CT,
class X,
class O>
746 template <
class F,
class CT,
class X,
class O>
749 template <
class F,
class CT,
class X,
class O>
752 template <
class F,
class CT,
class X,
class O>
755 using xexpression_type = std::decay_t<CT>;
757 typename xexpression_type::shape_type,
759 typename O::keep_dims>::type;
761 using stepper = const_stepper;
764 template <
class F,
class CT,
class X,
class O>
767 using xexpression_type = std::decay_t<CT>;
768 using reduce_functor_type =
typename std::decay_t<F>::reduce_functor_type;
769 using init_functor_type =
typename std::decay_t<F>::init_functor_type;
770 using merge_functor_type =
typename std::decay_t<F>::merge_functor_type;
771 using substepper_type =
typename xexpression_type::const_stepper;
772 using raw_value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
773 std::declval<init_functor_type>()(),
774 *std::declval<substepper_type>()
776 using value_type =
typename detail::evaluated_value_type_t<raw_value_type, is_xexpression<raw_value_type>::value>;
778 using reference = value_type;
779 using const_reference = value_type;
780 using size_type =
typename xexpression_type::size_type;
789 template <std::size_t... I>
792 using type = std::array<std::size_t,
sizeof...(I)>;
812 template <
class F,
class CT,
class X,
class O>
813 class xreducer :
public xsharable_expression<xreducer<F, CT, X, O>>,
815 public xaccessible<xreducer<F, CT, X, O>>,
816 public extension::xreducer_base_t<F, CT, X, O>
823 using reduce_functor_type =
typename inner_types::reduce_functor_type;
824 using init_functor_type =
typename inner_types::init_functor_type;
825 using merge_functor_type =
typename inner_types::merge_functor_type;
828 using xexpression_type =
typename inner_types::xexpression_type;
831 using extension_base = extension::xreducer_base_t<F, CT, X, O>;
832 using expression_tag =
typename extension_base::expression_tag;
834 using substepper_type =
typename inner_types::substepper_type;
835 using value_type =
typename inner_types::value_type;
836 using reference =
typename inner_types::reference;
837 using const_reference =
typename inner_types::const_reference;
838 using pointer = value_type*;
839 using const_pointer =
const value_type*;
841 using size_type =
typename inner_types::size_type;
842 using difference_type =
typename xexpression_type::difference_type;
845 using inner_shape_type =
typename iterable_base::inner_shape_type;
846 using shape_type = inner_shape_type;
848 using dim_mapping_type =
typename select_dim_mapping_type<inner_shape_type>::type;
850 using stepper =
typename iterable_base::stepper;
851 using const_stepper =
typename iterable_base::const_stepper;
852 using bool_load_type =
typename xexpression_type::bool_load_type;
855 static constexpr bool contiguous_layout =
false;
857 template <
class Func,
class CTA,
class AX,
class OX>
858 xreducer(Func&& func, CTA&& e, AX&& axes, OX&& options);
860 const inner_shape_type&
shape() const noexcept;
862 bool is_contiguous() const noexcept;
864 template <class... Args>
865 const_reference operator()(Args... args) const;
866 template <class... Args>
867 const_reference unchecked(Args... args) const;
870 const_reference element(It first, It last) const;
881 const_stepper stepper_begin(const S&
shape) const noexcept;
885 template <class E, class Func = F, class Opts = O>
886 using rebind_t =
xreducer<Func, E, X, Opts>;
889 rebind_t<E> build_reducer(E&& e) const;
891 template <class E, class Func, class Opts>
892 rebind_t<E, Func, Opts> build_reducer(E&& e, Func&& func, Opts&& opts) const;
894 xreducer_functors_type functors()
const
896 return xreducer_functors_type(m_reduce, m_init, m_merge);
901 const O& options()
const
909 reduce_functor_type m_reduce;
910 init_functor_type m_init;
911 merge_functor_type m_merge;
913 inner_shape_type m_shape;
914 dim_mapping_type m_dim_mapping;
917 friend class xreducer_stepper<F, CT, X, O>;
926 template <
class F,
class E,
class X,
class O>
927 inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::lazy_type, O&& options)
929 decltype(
auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
931 using reduce_functor_type =
typename std::decay_t<F>::reduce_functor_type;
932 using init_functor_type =
typename std::decay_t<F>::init_functor_type;
933 using value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
934 std::declval<init_functor_type>()(),
935 *std::declval<
typename std::decay_t<E>::const_stepper>()
937 using evaluated_value_type = evaluated_value_type_t<value_type, is_xexpression<value_type>::value>;
939 using reducer_type = xreducer<
942 xtl::const_closure_type_t<
decltype(normalized_axes)>,
943 reducer_options<evaluated_value_type, std::decay_t<O>>>;
947 std::forward<
decltype(normalized_axes)>(normalized_axes),
948 std::forward<O>(options)
952 template <
class F,
class E,
class X,
class O>
953 inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::immediate_type, O&& options)
955 decltype(
auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
956 return reduce_immediate(
958 eval(std::forward<E>(e)),
959 std::forward<
decltype(normalized_axes)>(normalized_axes),
960 std::forward<O>(options)
965#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
970 struct is_xreducer_functors_impl : std::false_type
974 template <
class RF,
class IF,
class MF>
975 struct is_xreducer_functors_impl<xreducer_functors<RF, IF, MF>> : std::true_type
980 using is_xreducer_functors = is_xreducer_functors_impl<std::decay_t<T>>;
1000 class EVS = DEFAULT_STRATEGY_REDUCERS,
1002 inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
1004 return detail::reduce_impl(
1007 std::forward<X>(axes),
1008 typename reducer_options<int, EVS>::evaluation_strategy{},
1009 std::forward<EVS>(options)
1017 class EVS = DEFAULT_STRATEGY_REDUCERS,
1018 XTL_REQUIRES(xtl::negation<is_reducer_options<X>>, xtl::negation<detail::is_xreducer_functors<F>>)>
1019 inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
1022 make_xreducer_functor(std::forward<F>(f)),
1024 std::forward<X>(axes),
1025 std::forward<EVS>(options)
1032 class EVS = DEFAULT_STRATEGY_REDUCERS,
1034 inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
1036 xindex_type_t<typename std::decay_t<E>::shape_type> ar;
1037 resize_container(ar, e.dimension());
1038 std::iota(ar.begin(), ar.end(), 0);
1039 return detail::reduce_impl(
1043 typename reducer_options<
int, std::decay_t<EVS>>::evaluation_strategy{},
1044 std::forward<EVS>(options)
1051 class EVS = DEFAULT_STRATEGY_REDUCERS,
1053 inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
1055 return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), std::forward<EVS>(options));
1063 class EVS = DEFAULT_STRATEGY_REDUCERS,
1064 XTL_REQUIRES(detail::is_xreducer_functors<F>)>
1065 inline auto reduce(F&& f, E&& e,
const I (&axes)[N], EVS options = EVS())
1067 using axes_type = std::array<std::size_t, N>;
1068 auto ax = xt::forward_normalize<axes_type>(e, axes);
1069 return detail::reduce_impl(
1073 typename reducer_options<int, EVS>::evaluation_strategy{},
1083 class EVS = DEFAULT_STRATEGY_REDUCERS,
1084 XTL_REQUIRES(xtl::negation<detail::is_xreducer_functors<F>>)>
1085 inline auto reduce(F&& f, E&& e,
const I (&axes)[N], EVS options = EVS())
1087 return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), axes, options);
1094 template <
class F,
class CT,
class X,
class O>
1095 class xreducer_stepper
1099 using self_type = xreducer_stepper<F, CT, X, O>;
1102 using value_type =
typename xreducer_type::value_type;
1103 using reference =
typename xreducer_type::value_type;
1104 using pointer =
typename xreducer_type::const_pointer;
1105 using size_type =
typename xreducer_type::size_type;
1106 using difference_type =
typename xreducer_type::difference_type;
1108 using xexpression_type =
typename xreducer_type::xexpression_type;
1109 using substepper_type =
typename xexpression_type::const_stepper;
1110 using shape_type =
typename xreducer_type::shape_type;
1113 const xreducer_type& red,
1116 layout_type l = default_assignable_layout(xexpression_type::static_layout)
1119 reference operator*()
const;
1121 void step(size_type dim);
1122 void step_back(size_type dim);
1123 void step(size_type dim, size_type n);
1124 void step_back(size_type dim, size_type n);
1125 void reset(size_type dim);
1126 void reset_back(size_type dim);
1133 reference initial_value()
const;
1134 reference aggregate(size_type dim)
const;
1135 reference aggregate_impl(size_type dim, std::false_type)
const;
1136 reference aggregate_impl(size_type dim, std::true_type)
const;
1138 substepper_type get_substepper_begin()
const;
1139 size_type get_dim(size_type dim)
const noexcept;
1140 size_type shape(size_type i)
const noexcept;
1141 size_type axis(size_type i)
const noexcept;
1143 const xreducer_type* m_reducer;
1145 mutable substepper_type m_stepper;
1154 template <std::size_t X, std::size_t... I>
1157 static constexpr bool value = xtl::disjunction<std::integral_constant<bool, X == I>...>::value;
1160 template <std::
size_t Z,
class S1,
class S2,
class R>
1161 struct fixed_xreducer_shape_type_impl;
1163 template <std::size_t Z, std::size_t... I, std::size_t... J, std::size_t... R>
1164 struct fixed_xreducer_shape_type_impl<Z, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
1166 using type = std::conditional_t<
1168 typename fixed_xreducer_shape_type_impl<Z - 1, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>::type,
1169 typename fixed_xreducer_shape_type_impl<
1173 fixed_shape<detail::at<Z, I...>::value, R...>>::type>;
1176 template <std::size_t... I, std::size_t... J, std::size_t... R>
1177 struct fixed_xreducer_shape_type_impl<0, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
1180 conditional_t<in<0, J...>::value, fixed_shape<R...>, fixed_shape<detail::at<0, I...>::value, R...>>;
1188 struct xreducer_size_type
1190 using type = std::size_t;
1194 using xreducer_size_type_t =
typename xreducer_size_type<T>::type;
1197 struct xreducer_temporary_type
1203 using xreducer_temporary_type_t =
typename xreducer_temporary_type<T>::type;
1209 template <
class T,
class U>
1210 struct const_value_rebinder
1212 static const_value<U> run(
const const_value<T>& t)
1214 return const_value<U>(t.m_value);
1227 return detail::const_value_rebinder<T, NT>::run(*
this);
1234 template <
class S1,
class S2>
1237 template <std::size_t... I, std::size_t... J>
1240 using type =
typename detail::
1245 template <
class ST,
class X,
class O>
1248 using type = promote_shape_t<ST, std::decay_t<X>>;
1251 template <
class I1, std::
size_t N1,
class I2, std::
size_t N2>
1254 using type = std::array<I2, N1>;
1257 template <
class I1, std::
size_t N1,
class I2, std::
size_t N2>
1260 using type = std::array<I2, N1 - N2>;
1263 template <std::size_t... I,
class I2, std::size_t N2>
1266 using type = std::conditional_t<
sizeof...(I) == N2,
fixed_shape<>, std::array<I2,
sizeof...(I) - N2>>;
// Concatenation of two std::integer_sequence types that share an element type.
// Only the partial specialization below provides a `type`.
template <class S1, class S2>
struct ixconcat;

template <class T, T... Lhs, T... Rhs>
struct ixconcat<std::integer_sequence<T, Lhs...>, std::integer_sequence<T, Rhs...>>
{
    using type = std::integer_sequence<T, Lhs..., Rhs...>;
};

// std::integer_sequence<T, X, ..., X> containing N copies of X.
// General case: one X prepended to the sequence of N - 1 copies.
template <class T, T X, std::size_t N>
struct repeat_integer_sequence
{
private:

    using head_type = std::integer_sequence<T, X>;
    using tail_type = typename repeat_integer_sequence<T, X, N - 1>::type;

public:

    using type = typename ixconcat<head_type, tail_type>::type;
};

// Zero repetitions: the empty sequence (end of the recursion).
template <class T, T X>
struct repeat_integer_sequence<T, X, 0>
{
    using type = std::integer_sequence<T>;
};

// Two repetitions, spelled out directly.
template <class T, T X>
struct repeat_integer_sequence<T, X, 2>
{
    using type = std::integer_sequence<T, X, X>;
};

// One repetition, spelled out directly.
template <class T, T X>
struct repeat_integer_sequence<T, X, 1>
{
    using type = std::integer_sequence<T, X>;
};
1307 template <std::size_t... I,
class I2, std::size_t N2>
1310 template <std::size_t... X>
1311 static constexpr auto get_type(std::index_sequence<X...>)
1317 using type = std::conditional_t<
1319 decltype(get_type(
typename detail::repeat_integer_sequence<std::size_t, std::size_t(1), N2>::type{})),
1320 std::array<I2,
sizeof...(I)>>;
1324 template <std::size_t... I, std::size_t... J,
class O>
1332 template <
class S,
class E,
class X,
class M>
1333 inline void shape_and_mapping_computation(S& shape, E& e,
const X& axes, M& mapping, std::false_type)
1335 auto first = e.shape().begin();
1336 auto last = e.shape().end();
1337 auto exclude_it = axes.begin();
1339 using value_type =
typename S::value_type;
1340 using difference_type =
typename S::difference_type;
1341 auto d_first = shape.begin();
1342 auto map_first = mapping.begin();
1345 while (iter != last && exclude_it != axes.end())
1347 auto diff = std::distance(first, iter);
1348 if (
diff != difference_type(*exclude_it))
1350 *d_first++ = *iter++;
1351 *map_first++ = value_type(
diff);
1360 auto diff = std::distance(first, iter);
1361 auto end = std::distance(iter, last);
1362 std::iota(map_first, map_first + end,
diff);
1363 std::copy(iter, last, d_first);
1366 template <
class S,
class E,
class X,
class M>
1368 shape_and_mapping_computation_keep_dim(S& shape, E& e,
const X& axes, M& mapping, std::false_type)
1370 for (std::size_t i = 0; i < e.dimension(); ++i)
1372 if (std::find(axes.cbegin(), axes.cend(), i) == axes.cend())
1375 shape[i] = e.shape()[i];
1382 std::iota(mapping.begin(), mapping.end(), 0);
1385 template <
class S,
class E,
class X,
class M>
1386 inline void shape_and_mapping_computation(S&, E&,
const X&, M&, std::true_type)
1390 template <
class S,
class E,
class X,
class M>
1391 inline void shape_and_mapping_computation_keep_dim(S&, E&,
const X&, M&, std::true_type)
// Constructor of xreducer<F, CT, X, O> (the signature line itself is not
// visible in this listing). It stores the reduced expression, the three
// functors, the axes and the options, validates the axes, and then fills in the
// reducer's shape and dimension mapping.
1412 template <
class F,
class CT,
class X,
class O>
1413 template <
class Func,
class CTA,
class AX,
class OX>
1415 : m_e(std::forward<CTA>(e))
// The functor argument is unpacked positionally: 0 = reduce, 1 = init, 2 = merge.
1416 , m_reduce(
xt::get<0>(func))
1417 , m_init(
xt::get<1>(func))
1418 , m_merge(
xt::get<2>(func))
1419 , m_axes(std::forward<AX>(axes))
1420 , m_shape(xtl::make_sequence<inner_shape_type>(
1424 , m_dim_mapping(xtl::make_sequence<dim_mapping_type>(
1428 , m_options(std::forward<OX>(options))
// Axes must be given in non-decreasing order ...
1434 if (!std::is_sorted(m_axes.cbegin(), m_axes.cend(), std::less<>()))
1436 XTENSOR_THROW(std::runtime_error,
"Reducing axes should be sorted.");
// ... and, being sorted, any repeated axis shows up as two equal neighbours.
1438 if (std::adjacent_find(m_axes.cbegin(), m_axes.cend()) != m_axes.cend())
1440 XTENSOR_THROW(std::runtime_error,
"Reducing axes should not contain duplicates.");
// Only the last (largest) axis is compared against the expression's dimension.
// NOTE(review): if dimension() returns an unsigned 0, `m_e.dimension() - 1`
// wraps around and this check can never fire for a 0-dimensional expression;
// comparing with `>= m_e.dimension()` would avoid the subtraction - confirm.
1442 if (m_axes.size() != 0 && m_axes[m_axes.size() - 1] > m_e.dimension() - 1)
1446 "Axis " + std::to_string(m_axes[m_axes.size() - 1]) +
" out of bounds for reduction."
// Shape/mapping computation is selected by the keep_dims option; the
// is_fixed<shape_type> tag picks the overload for fixed or non-fixed shapes.
1450 if (!
typename O::keep_dims())
1452 detail::shape_and_mapping_computation(
1457 detail::is_fixed<shape_type>{}
1462 detail::shape_and_mapping_computation_keep_dim(
1467 detail::is_fixed<shape_type>{}
1481 template <
class F,
class CT,
class X,
class O>
1490 template <
class F,
class CT,
class X,
class O>
1493 return static_layout;
1496 template <
class F,
class CT,
class X,
class O>
1497 inline bool xreducer<F, CT, X, O>::is_contiguous() const noexcept
1514 template <
class F,
class CT,
class X,
class O>
1515 template <
class... Args>
1516 inline auto xreducer<F, CT, X, O>::operator()(Args... args)
const -> const_reference
1518 XTENSOR_TRY(check_index(
shape(), args...));
1519 XTENSOR_CHECK_DIMENSION(
shape(), args...);
1520 std::array<std::size_t,
sizeof...(Args)> arg_array = {{
static_cast<std::size_t
>(args)...}};
1521 return element(arg_array.cbegin(), arg_array.cend());
1543 template <
class F,
class CT,
class X,
class O>
1544 template <
class... Args>
1545 inline auto xreducer<F, CT, X, O>::unchecked(Args... args)
const -> const_reference
1547 std::array<std::size_t,
sizeof...(Args)> arg_array = {{
static_cast<std::size_t
>(args)...}};
1548 return element(arg_array.cbegin(), arg_array.cend());
1558 template <
class F,
class CT,
class X,
class O>
1560 inline auto xreducer<F, CT, X, O>::element(It first, It last)
const -> const_reference
1562 XTENSOR_TRY(check_element_index(
shape(), first, last));
1563 auto stepper = const_stepper(*
this, 0);
1568 auto size = std::ptrdiff_t(this->
dimension()) - std::distance(first, last);
1569 auto begin = first -
size;
1570 while (begin != last)
1574 stepper.step(dim++, std::size_t(0));
1579 stepper.step(dim++, std::size_t(*begin++));
1589 template <
class F,
class CT,
class X,
class O>
1607 template <
class F,
class CT,
class X,
class O>
1611 return xt::broadcast_shape(m_shape,
shape);
1619 template <
class F,
class CT,
class X,
class O>
1628 template <
class F,
class CT,
class X,
class O>
1630 inline auto xreducer<F, CT, X, O>::stepper_begin(
const S& shape)
const noexcept -> const_stepper
1632 size_type offset = shape.size() - this->dimension();
1633 return const_stepper(*
this, offset);
1636 template <
class F,
class CT,
class X,
class O>
1638 inline auto xreducer<F, CT, X, O>::stepper_end(
const S& shape,
layout_type l)
const noexcept
1641 size_type offset = shape.size() - this->dimension();
1642 return const_stepper(*
this, offset,
true, l);
1645 template <
class F,
class CT,
class X,
class O>
1647 inline auto xreducer<F, CT, X, O>::build_reducer(E&& e)
const -> rebind_t<E>
1650 std::make_tuple(m_reduce, m_init, m_merge),
1657 template <
class F,
class CT,
class X,
class O>
1658 template <
class E,
class Func,
class Opts>
1659 inline auto xreducer<F, CT, X, O>::build_reducer(E&& e, Func&& func, Opts&& opts)
const
1660 -> rebind_t<E, Func, Opts>
1662 return rebind_t<E, Func, Opts>(
1663 std::forward<Func>(func),
1666 std::forward<Opts>(opts)
1674 template <
class F,
class CT,
class X,
class O>
1675 inline xreducer_stepper<F, CT, X, O>::xreducer_stepper(
1676 const xreducer_type& red,
1683 , m_stepper(get_substepper_begin())
1691 template <
class F,
class CT,
class X,
class O>
1692 inline auto xreducer_stepper<F, CT, X, O>::operator*() const -> reference
1694 reference r = aggregate(0);
1698 template <
class F,
class CT,
class X,
class O>
1699 inline void xreducer_stepper<F, CT, X, O>::step(size_type dim)
1701 if (dim >= m_offset)
1703 m_stepper.step(get_dim(dim - m_offset));
1707 template <
class F,
class CT,
class X,
class O>
1708 inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim)
1710 if (dim >= m_offset)
1712 m_stepper.step_back(get_dim(dim - m_offset));
1716 template <
class F,
class CT,
class X,
class O>
1717 inline void xreducer_stepper<F, CT, X, O>::step(size_type dim, size_type n)
1719 if (dim >= m_offset)
1721 m_stepper.step(get_dim(dim - m_offset), n);
1725 template <
class F,
class CT,
class X,
class O>
1726 inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim, size_type n)
1728 if (dim >= m_offset)
1730 m_stepper.step_back(get_dim(dim - m_offset), n);
// Resets the wrapped stepper along one dimension of the reducer.
// Only acts when dim >= m_offset.
1734 template <
class F,
class CT,
class X,
class O>
1735 inline void xreducer_stepper<F, CT, X, O>::reset(size_type dim)
1737 if (dim >= m_offset)
// keep_dims case: tests whether the offset-corrected dimension is one of the
// reducing axes. std::binary_search needs a sorted range.
// NOTE(review): the body of this branch is not visible in this listing -
// presumably it leaves the function without resetting; confirm.
1741 if (
typename O::keep_dims()
1742 && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
// Otherwise the reset is forwarded, with the dimension translated by get_dim().
1747 m_stepper.reset(get_dim(dim - m_offset));
// Backward counterpart of reset(): calls reset_back on the wrapped stepper
// along one dimension of the reducer. Only acts when dim >= m_offset.
1751 template <
class F,
class CT,
class X,
class O>
1752 inline void xreducer_stepper<F, CT, X, O>::reset_back(size_type dim)
1754 if (dim >= m_offset)
// keep_dims case: tests whether the offset-corrected dimension is one of the
// reducing axes. std::binary_search needs a sorted range.
// NOTE(review): the body of this branch is not visible in this listing -
// presumably it leaves the function without resetting; confirm.
1757 if (
typename O::keep_dims()
1758 && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
// Otherwise the call is forwarded, with the dimension translated by get_dim().
1763 m_stepper.reset_back(get_dim(dim - m_offset));
// Moves the wrapped stepper to its beginning; the call is forwarded unchanged.
1767 template <
class F,
class CT,
class X,
class O>
1768 inline void xreducer_stepper<F, CT, X, O>::to_begin()
1770 m_stepper.to_begin();
1773 template <
class F,
class CT,
class X,
class O>
1774 inline void xreducer_stepper<F, CT, X, O>::to_end(
layout_type l)
1776 m_stepper.to_end(l);
1779 template <
class F,
class CT,
class X,
class O>
1780 inline auto xreducer_stepper<F, CT, X, O>::initial_value() const -> reference
1782 return O::has_initial_value ? m_reducer->m_options.initial_value
1783 :
static_cast<reference
>(m_reducer->m_init());
// Computes the reduced value for the current position of the wrapped
// stepper. `dim` is passed on to aggregate_impl; operator* calls this
// function with 0.
// NOTE(review): the declaration of `res`, the final `else` and the return
// statement are not visible in this listing -- confirm the exact nesting
// against the full header.
1786 template <
class F,
class CT,
class X,
class O>
1787 inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim)
const -> reference
// Empty expression: the result is just the initial value.
1790 if (m_reducer->m_e.size() == size_type(0))
1792 res = initial_value();
// Expression with an empty shape, or no axes to reduce over: apply the
// reduce functor once to the initial value and the current element.
1794 else if (m_reducer->m_e.shape().empty() || m_reducer->m_axes.size() == 0)
1796 res = m_reducer->m_reduce(initial_value(), *m_stepper);
// General case (presumably the trailing else branch): walk the reduced axes.
// The keep_dims tag selects which aggregate_impl overload is used.
1800 res = aggregate_impl(dim,
typename O::keep_dims());
// A user-supplied initial value is merged in only at the outermost call
// (dim == 0).
1801 if (O::has_initial_value && dim == 0)
1803 res = m_reducer->m_merge(m_reducer->m_options.initial_value, res);
// Recursive aggregation used when keep_dims is not set (std::false_type tag).
// Here `dim` is a position in the reducer's list of axes: axis(dim) reads
// m_axes[dim], and shape(index) reads the expression's extent along it.
// NOTE(review): the declaration of `res`, the `else` separating the two
// branches and the return statement are not visible in this listing.
1809 template <
class F,
class CT,
class X,
class O>
1810 inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::false_type)
const -> reference
1814 size_type index = axis(dim);
1815 size_type size = shape(index);
// Not the last reduced axis: aggregate the remaining axes at each position
// along this one and combine the partial results with the merge functor.
1816 if (dim != m_reducer->m_axes.size() - 1)
1818 res = aggregate_impl(dim + 1,
typename O::keep_dims());
// The loops start at 1 and test `i != size`, so they rely on size being
// non-zero; aggregate() checks for an empty expression before calling here.
1819 for (size_type i = 1; i != size; ++i)
1821 m_stepper.step(index);
1822 res = m_reducer->m_merge(res, aggregate_impl(dim + 1,
typename O::keep_dims()));
// Last reduced axis (presumably the else branch): fold the elements along
// it with the reduce functor, starting from the init functor's value.
1827 res = m_reducer->m_reduce(
static_cast<reference
>(m_reducer->m_init()), *m_stepper);
1828 for (size_type i = 1; i != size; ++i)
1830 m_stepper.step(index);
1831 res = m_reducer->m_reduce(res, *m_stepper);
// Reset the wrapped stepper along this axis once it has been traversed.
1834 m_stepper.reset(index);
// Recursive aggregation used when keep_dims is set (std::true_type tag).
// Here `dim` is a dimension of the reduced expression itself: it is looked
// up in m_axes and used directly as an index into the expression's shape.
// NOTE(review): the declaration of `res`, the `else` lines and the return
// statement are not visible in this listing -- confirm the exact nesting.
1838 template <
class F,
class CT,
class X,
class O>
1839 inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::true_type)
const -> reference
1843 auto ax_it = std::find(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim);
// dim is one of the axes being reduced.
1844 if (ax_it != m_reducer->m_axes.end())
1846 size_type index = dim;
1847 size_type size = m_reducer->m_e.shape()[index];
// Not the last entry of m_axes (and a non-zero extent): aggregate the
// following dimensions at each position and combine with the merge functor.
1848 if (ax_it != m_reducer->m_axes.end() - 1 && size != 0)
1850 res = aggregate_impl(dim + 1,
typename O::keep_dims());
1851 for (size_type i = 1; i != size; ++i)
1853 m_stepper.step(index);
1854 res = m_reducer->m_merge(res, aggregate_impl(dim + 1,
typename O::keep_dims()));
// Otherwise (presumably the else branch): fold the elements along this axis
// with the reduce functor, starting from the init functor's value.
1859 res = m_reducer->m_reduce(
static_cast<reference
>(m_reducer->m_init()), *m_stepper);
1860 for (size_type i = 1; i != size; ++i)
1862 m_stepper.step(index);
1863 res = m_reducer->m_reduce(res, *m_stepper);
// Reset the wrapped stepper along this axis once it has been traversed.
1866 m_stepper.reset(index);
// dim is not a reduced axis (presumably the else branch of the find test):
// move on to the next dimension while dim is below the expression's
// dimension count.
1870 if (dim < m_reducer->m_e.dimension())
1872 res = aggregate_impl(dim + 1,
typename O::keep_dims());
1878 template <
class F,
class CT,
class X,
class O>
1879 inline auto xreducer_stepper<F, CT, X, O>::get_substepper_begin() const -> substepper_type
1881 return m_reducer->m_e.stepper_begin(m_reducer->m_e.shape());
1884 template <
class F,
class CT,
class X,
class O>
1885 inline auto xreducer_stepper<F, CT, X, O>::get_dim(size_type dim)
const noexcept -> size_type
1887 return m_reducer->m_dim_mapping[dim];
1890 template <
class F,
class CT,
class X,
class O>
1891 inline auto xreducer_stepper<F, CT, X, O>::shape(size_type i)
const noexcept -> size_type
1893 return m_reducer->m_e.shape()[i];
1896 template <
class F,
class CT,
class X,
class O>
1897 inline auto xreducer_stepper<F, CT, X, O>::axis(size_type i)
const noexcept -> size_type
1899 return m_reducer->m_axes[i];
Fixed shape implementation for compile time defined arrays.
size_type size() const noexcept
size_type dimension() const noexcept
Base class for multidimensional iterable constant expressions.
auto end() const noexcept -> const_layout_iterator< L >
Reducing function operating over specified axes.
bool broadcast_shape(S &shape, bool reuse_cache=false) const
const xexpression_type & expression() const noexcept
bool has_linear_assign(const S &strides) const noexcept
const inner_shape_type & shape() const noexcept
Returns the shape of the expression.
xreducer(Func &&func, CTA &&e, AX &&axes, OX &&options)
Constructs an xreducer expression applying the specified function to the given expression over the given axes.
layout_type layout() const noexcept
auto operator|(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::bitwise_or, E1, E2 >
Bitwise or.
auto diff(const xexpression< T > &a, std::size_t n=1, std::ptrdiff_t axis=-1)
Calculate the n-th discrete difference along the given axis.
auto eval(T &&t) -> std::enable_if_t< detail::is_container< std::decay_t< T > >::value, T && >
Force evaluation of xexpression.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
standard mathematical functions for xexpressions
auto reduce(F &&f, E &&e, X &&axes, EVS &&options=EVS())
Returns an xexpression applying the specified reducing function to an expression over the given axes.