10#ifndef XTENSOR_REDUCER_HPP
11#define XTENSOR_REDUCER_HPP
15#include <initializer_list>
22#include <xtl/xfunctional.hpp>
23#include <xtl/xsequence.hpp>
25#include "../core/xaccessible.hpp"
26#include "../core/xeval.hpp"
27#include "../core/xexpression.hpp"
28#include "../core/xiterable.hpp"
29#include "../core/xtensor_config.hpp"
30#include "../generators/xbuilder.hpp"
31#include "../generators/xgenerator.hpp"
32#include "../utils/xutils.hpp"
37 auto operator|(
const A<AX...>& args,
const A<X>& rhs)
39 return std::tuple_cat(args, rhs);
46 constexpr auto keep_dims = std::tuple<keep_dims_type>{};
48 template <
class T =
double>
49 struct xinitial : xt::detail::option_base
51 constexpr xinitial(T val)
56 constexpr T value()
const
65 constexpr auto initial(T val)
70 template <std::ptrdiff_t I,
class T,
class Tuple>
73 template <std::ptrdiff_t I,
class T>
76 static constexpr std::ptrdiff_t value = -1;
79 template <std::ptrdiff_t I,
class T,
class... Types>
82 static constexpr std::ptrdiff_t value = I;
85 template <std::ptrdiff_t I,
class T,
class U,
class... Types>
88 static constexpr std::ptrdiff_t value =
tuple_idx_of_impl<I + 1, T, std::tuple<Types...>>::value;
91 template <
class S,
class... X>
94 template <
template <
class...>
class S,
class... X>
97 using type = S<std::decay_t<X>...>;
100 template <
class T,
class Tuple>
103 static constexpr std::ptrdiff_t
107 template <
class R,
class T>
108 struct reducer_options
126 using d_t = std::decay_t<T>;
128 static constexpr std::size_t initial_val_idx = xtl::mpl::find_if<initial_tester, d_t>::value;
129 reducer_options() =
default;
131 reducer_options(
const T& tpl)
133 if constexpr (initial_val_idx != std::tuple_size<T>::value)
135 initial_value = std::get < initial_val_idx != std::tuple_size<T>::value ? initial_val_idx
140 using evaluation_strategy = std::conditional_t<
141 tuple_idx_of<xt::evaluation_strategy::immediate_type, d_t>::value != -1,
145 using keep_dims = std::
146 conditional_t<tuple_idx_of<xt::keep_dims_type, d_t>::value != -1, std::true_type, std::false_type>;
148 static constexpr bool has_initial_value = initial_val_idx != std::tuple_size<d_t>::value;
153 using rebind_t = reducer_options<NR, T>;
156 auto rebind(NR initial,
const reducer_options<R, T>&)
const
158 reducer_options<NR, T> res;
159 res.initial_value = initial;
169 template <
class... X>
// Default options type used by the `reduce` overloads (`class EVS = DEFAULT_STRATEGY_REDUCERS`):
// a tuple holding only `evaluation_strategy::lazy_type`. Because it does not contain
// `evaluation_strategy::immediate_type`, `reducer_options<..., EVS>::evaluation_strategy`
// presumably selects the lazy `reduce_impl` overload that builds an `xreducer` expression.
#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
185 template <
class ST,
class X,
class KD = std::false_type>
188 template <
class S1,
class S2>
193 template <
class O,
class RS,
class R,
class E,
class AX>
194 inline void shape_computation(
199 std::enable_if_t<!detail::is_fixed<RS>::value,
int> = 0
202 if (
typename O::keep_dims())
204 resize_container(result_shape, expr.dimension());
205 for (std::size_t i = 0; i < expr.dimension(); ++i)
207 if (std::find(axes.begin(), axes.end(), i) == axes.end())
210 result_shape[i] = expr.shape()[i];
220 resize_container(result_shape, expr.dimension() - axes.size());
221 for (std::size_t i = 0, idx = 0; i < expr.dimension(); ++i)
223 if (std::find(axes.begin(), axes.end(), i) == axes.end())
226 result_shape[idx] = expr.shape()[i];
231 result.resize(result_shape, expr.layout());
235 template <
class O,
class RS,
class R,
class S,
class AX>
237 shape_computation(RS&, R&,
const S&,
const AX&, std::enable_if_t<detail::is_fixed<RS>::value,
int> = 0)
242 template <
class F,
class E,
class R, XTL_REQUIRES(std::is_convertible<
typename E::value_type,
typename R::value_type>)>
243 inline void copy_to_reduced(F&,
const E& e, R& result)
248 e.template cbegin<layout_type::row_major>(),
249 e.template cend<layout_type::row_major>(),
256 e.template cbegin<layout_type::column_major>(),
257 e.template cend<layout_type::column_major>(),
267 XTL_REQUIRES(std::negation<std::is_convertible<typename E::value_type, typename R::value_type>>)>
268 inline void copy_to_reduced(F& f,
const E& e, R& result)
273 e.template cbegin<layout_type::row_major>(),
274 e.template cend<layout_type::row_major>(),
282 e.template cbegin<layout_type::column_major>(),
283 e.template cend<layout_type::column_major>(),
290 template <
class F,
class E,
class X,
class O>
291 inline auto reduce_immediate(F&& f, E&& e, X&& axes, O&& raw_options)
293 using reduce_functor_type =
typename std::decay_t<F>::reduce_functor_type;
294 using init_functor_type =
typename std::decay_t<F>::init_functor_type;
295 using expr_value_type =
typename std::decay_t<E>::value_type;
296 using result_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
297 std::declval<init_functor_type>()(),
298 std::declval<expr_value_type>()
302 options_t options(raw_options);
305 typename std::decay_t<E>::shape_type,
307 typename options_t::keep_dims>::type;
308 using result_container_type =
typename detail::xtype_for_shape<
309 shape_type>::template type<result_type, std::decay_t<E>::static_layout>;
310 result_container_type result;
313 auto reduce_fct = xt::get<0>(f);
314 auto init_fct = xt::get<1>(f);
315 auto merge_fct = xt::get<2>(f);
317 if (axes.size() == 0)
319 result.resize(e.shape(), e.layout());
320 auto cpf = [&reduce_fct, &init_fct](
const auto& v)
322 return reduce_fct(
static_cast<result_type
>(init_fct()), v);
324 copy_to_reduced(cpf, e, result);
328 shape_type result_shape{};
329 dynamic_shape<std::size_t>
330 iter_shape = xtl::forward_sequence<dynamic_shape<std::size_t>,
decltype(e.shape())>(e.shape());
331 dynamic_shape<std::size_t> iter_strides(e.dimension());
337 if (!std::is_sorted(axes.cbegin(), axes.cend(), std::less<>()))
339 XTENSOR_THROW(std::runtime_error,
"Reducing axes should be sorted.");
341 if (std::adjacent_find(axes.cbegin(), axes.cend()) != axes.cend())
343 XTENSOR_THROW(std::runtime_error,
"Reducing axes should not contain duplicates.");
345 if (axes.size() != 0 && axes[axes.size() - 1] > e.dimension() - 1)
349 "Axis " + std::to_string(axes[axes.size() - 1]) +
" out of bounds for reduction."
353 detail::shape_computation<options_t>(result_shape, result, e, axes);
356 if (e.dimension() == axes.size())
358 result_type tmp = options_t::has_initial_value ? options.initial_value : init_fct();
359 result.data()[0] = std::accumulate(e.storage().begin(), e.storage().end(), tmp, reduce_fct);
364 auto strides_finder = e.strides().begin() +
static_cast<std::ptrdiff_t
>(leading_ax);
366 std::size_t inner_stride =
static_cast<std::size_t
>(*strides_finder);
368 while (inner_stride == 0 && strides_finder != iter_bound)
371 inner_stride =
static_cast<std::size_t
>(*strides_finder);
374 if (inner_stride == 0)
376 auto cpf = [&reduce_fct, &init_fct](
const auto& v)
378 return reduce_fct(
static_cast<result_type
>(init_fct()), v);
380 copy_to_reduced(cpf, e, result);
384 std::size_t inner_loop_size =
static_cast<std::size_t
>(inner_stride);
385 std::size_t outer_loop_size = e.shape()[leading_ax];
389 auto merge_loops = [&outer_loop_size, &e](
auto it,
auto end)
393 for (; it != end; ++it)
396 if (std::abs(std::ptrdiff_t(*it) - std::ptrdiff_t(last_ax)) == 1)
399 outer_loop_size *= e.shape()[last_ax];
405 for (std::size_t i = 0, idx = 0; i < e.dimension(); ++i)
407 if (std::find(axes.begin(), axes.end(), i) == axes.end())
410 iter_strides[i] =
static_cast<std::size_t
>(result.strides(
411 )[
typename options_t::keep_dims() ? i : idx]);
418 std::size_t last_ax = merge_loops(axes.rbegin(), axes.rend());
420 iter_shape.erase(iter_shape.begin() + std::ptrdiff_t(last_ax), iter_shape.end());
421 iter_strides.erase(iter_strides.begin() + std::ptrdiff_t(last_ax), iter_strides.end());
426 std::size_t last_ax = merge_loops(axes.begin(), axes.end());
429 iter_shape.erase(iter_shape.begin(), iter_shape.begin() + std::ptrdiff_t(last_ax + 1));
430 iter_strides.erase(iter_strides.begin(), iter_strides.begin() + std::ptrdiff_t(last_ax + 1));
433 std::reverse(iter_shape.begin(), iter_shape.end());
434 std::reverse(iter_strides.begin(), iter_strides.end());
438 XTENSOR_THROW(std::runtime_error,
"Layout not supported in immediate reduction.");
441 xindex temp_idx(iter_shape.size());
442 auto next_idx = [&iter_shape, &iter_strides, &temp_idx]()
444 std::size_t i = iter_shape.size();
447 if (std::ptrdiff_t(temp_idx[i - 1]) >= std::ptrdiff_t(iter_shape[i - 1]) - 1)
458 return std::make_pair(
460 std::inner_product(temp_idx.begin(), temp_idx.end(), iter_strides.begin(), std::ptrdiff_t(0))
464 auto begin = e.data();
465 auto out = result.data();
466 auto out_begin = result.data();
468 std::ptrdiff_t next_stride = 0;
470 std::pair<bool, std::ptrdiff_t> idx_res(
false, 0);
475 auto merge_border = out;
483 if (inner_stride == 1)
485 while (idx_res.first !=
true)
489 result_type tmp = init_fct();
490 tmp = std::accumulate(begin, begin + outer_loop_size, tmp, reduce_fct);
493 *out = merge ? merge_fct(*out, tmp) : tmp;
495 begin += outer_loop_size;
497 idx_res = next_idx();
498 next_stride = idx_res.second;
499 out = out_begin + next_stride;
501 if (out > merge_border)
515 while (idx_res.first !=
true)
519 out + inner_loop_size,
522 [merge, &init_fct, &reduce_fct](
auto&& v1,
auto&& v2)
524 return merge ? reduce_fct(v1, v2) :
526 reduce_fct(static_cast<result_type>(init_fct()), v2);
530 begin += inner_stride;
531 for (std::size_t i = 1; i < outer_loop_size; ++i)
533 std::transform(out, out + inner_loop_size, begin, out, reduce_fct);
534 begin += inner_stride;
537 idx_res = next_idx();
538 next_stride = idx_res.second;
539 out = out_begin + next_stride;
541 if (out > merge_border)
553 if (options_t::has_initial_value)
557 result.data() + result.size(),
559 [&merge_fct, &options](
auto&& v)
561 return merge_fct(v, options.initial_value);
575 using value_type = T;
577 constexpr const_value() =
default;
579 constexpr const_value(T t)
584 constexpr T operator()()
const
590 using rebind_t = const_value<NT>;
593 const_value<NT> rebind()
const;
600 template <
class T,
bool B>
601 struct evaluated_value_type
607 struct evaluated_value_type<T, true>
609 using type =
typename std::decay_t<decltype(xt::eval(std::declval<T>()))>;
612 template <
class T,
bool B>
613 using evaluated_value_type_t =
typename evaluated_value_type<T, B>::type;
616 template <
class REDUCE_FUNC,
class INIT_FUNC = const_value<
long int>,
class MERGE_FUNC = REDUCE_FUNC>
617 struct xreducer_functors :
public std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>
619 using self_type = xreducer_functors<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
620 using base_type = std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
621 using reduce_functor_type = REDUCE_FUNC;
622 using init_functor_type = INIT_FUNC;
623 using merge_functor_type = MERGE_FUNC;
624 using init_value_type =
typename init_functor_type::value_type;
632 xreducer_functors(RF&& reduce_func)
633 : base_type(std::forward<RF>(reduce_func), INIT_FUNC(), reduce_func)
637 template <
class RF,
class IF>
638 xreducer_functors(RF&& reduce_func, IF&& init_func)
639 : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), reduce_func)
643 template <
class RF,
class IF,
class MF>
644 xreducer_functors(RF&& reduce_func, IF&& init_func, MF&& merge_func)
645 : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), std::forward<MF>(merge_func))
649 reduce_functor_type get_reduce()
const
651 return std::get<0>(upcast());
654 init_functor_type get_init()
const
656 return std::get<1>(upcast());
659 merge_functor_type get_merge()
const
661 return std::get<2>(upcast());
665 using rebind_t = xreducer_functors<REDUCE_FUNC, const_value<NT>, MERGE_FUNC>;
668 rebind_t<NT> rebind()
670 return make_xreducer_functor(get_reduce(), get_init().
template rebind<NT>(), get_merge());
676 const base_type& upcast()
const
678 return static_cast<const base_type&
>(*this);
683 auto make_xreducer_functor(RF&& reduce_func)
686 return reducer_type(std::forward<RF>(reduce_func));
689 template <
class RF,
class IF>
690 auto make_xreducer_functor(RF&& reduce_func, IF&& init_func)
692 using reducer_type = xreducer_functors<std::remove_reference_t<RF>, std::remove_reference_t<IF>>;
693 return reducer_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func));
696 template <
class RF,
class IF,
class MF>
697 auto make_xreducer_functor(RF&& reduce_func, IF&& init_func, MF&& merge_func)
700 std::remove_reference_t<RF>,
701 std::remove_reference_t<IF>,
702 std::remove_reference_t<MF>>;
704 std::forward<RF>(reduce_func),
705 std::forward<IF>(init_func),
706 std::forward<MF>(merge_func)
716 template <
class Tag,
class F,
class CT,
class X,
class O>
719 template <
class F,
class CT,
class X,
class O>
725 template <
class F,
class CT,
class X,
class O>
730 template <
class F,
class CT,
class X,
class O>
738 template <
class F,
class CT,
class X,
class O>
741 template <
class F,
class CT,
class X,
class O>
744 template <
class F,
class CT,
class X,
class O>
747 using xexpression_type = std::decay_t<CT>;
749 typename xexpression_type::shape_type,
751 typename O::keep_dims>::type;
753 using stepper = const_stepper;
756 template <
class F,
class CT,
class X,
class O>
759 using xexpression_type = std::decay_t<CT>;
760 using reduce_functor_type =
typename std::decay_t<F>::reduce_functor_type;
761 using init_functor_type =
typename std::decay_t<F>::init_functor_type;
762 using merge_functor_type =
typename std::decay_t<F>::merge_functor_type;
763 using substepper_type =
typename xexpression_type::const_stepper;
764 using raw_value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
765 std::declval<init_functor_type>()(),
766 *std::declval<substepper_type>()
768 using value_type =
typename detail::evaluated_value_type_t<raw_value_type, is_xexpression<raw_value_type>::value>;
770 using reference = value_type;
771 using const_reference = value_type;
772 using size_type =
typename xexpression_type::size_type;
781 template <std::size_t... I>
784 using type = std::array<std::size_t,
sizeof...(I)>;
804 template <
class F,
class CT,
class X,
class O>
805 class xreducer :
public xsharable_expression<xreducer<F, CT, X, O>>,
807 public xaccessible<xreducer<F, CT, X, O>>,
808 public extension::xreducer_base_t<F, CT, X, O>
815 using reduce_functor_type =
typename inner_types::reduce_functor_type;
816 using init_functor_type =
typename inner_types::init_functor_type;
817 using merge_functor_type =
typename inner_types::merge_functor_type;
820 using xexpression_type =
typename inner_types::xexpression_type;
823 using extension_base = extension::xreducer_base_t<F, CT, X, O>;
824 using expression_tag =
typename extension_base::expression_tag;
826 using substepper_type =
typename inner_types::substepper_type;
827 using value_type =
typename inner_types::value_type;
828 using reference =
typename inner_types::reference;
829 using const_reference =
typename inner_types::const_reference;
830 using pointer = value_type*;
831 using const_pointer =
const value_type*;
833 using size_type =
typename inner_types::size_type;
834 using difference_type =
typename xexpression_type::difference_type;
837 using inner_shape_type =
typename iterable_base::inner_shape_type;
838 using shape_type = inner_shape_type;
840 using dim_mapping_type =
typename select_dim_mapping_type<inner_shape_type>::type;
842 using stepper =
typename iterable_base::stepper;
843 using const_stepper =
typename iterable_base::const_stepper;
844 using bool_load_type =
typename xexpression_type::bool_load_type;
847 static constexpr bool contiguous_layout =
false;
849 template <
class Func,
class CTA,
class AX,
class OX>
850 xreducer(Func&& func, CTA&& e, AX&& axes, OX&& options);
852 const inner_shape_type&
shape() const noexcept;
854 bool is_contiguous() const noexcept;
856 template <class... Args>
857 const_reference operator()(Args... args) const;
858 template <class... Args>
859 const_reference unchecked(Args... args) const;
862 const_reference element(It first, It last) const;
873 const_stepper stepper_begin(const S&
shape) const noexcept;
877 template <class E, class Func = F, class Opts = O>
878 using rebind_t =
xreducer<Func, E, X, Opts>;
881 rebind_t<E> build_reducer(E&& e) const;
883 template <class E, class Func, class Opts>
884 rebind_t<E, Func, Opts> build_reducer(E&& e, Func&& func, Opts&& opts) const;
886 xreducer_functors_type functors()
const
888 return xreducer_functors_type(m_reduce, m_init, m_merge);
893 const O& options()
const
901 reduce_functor_type m_reduce;
902 init_functor_type m_init;
903 merge_functor_type m_merge;
905 inner_shape_type m_shape;
906 dim_mapping_type m_dim_mapping;
909 friend class xreducer_stepper<F, CT, X, O>;
918 template <
class F,
class E,
class X,
class O>
919 inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::lazy_type, O&& options)
921 decltype(
auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
923 using reduce_functor_type =
typename std::decay_t<F>::reduce_functor_type;
924 using init_functor_type =
typename std::decay_t<F>::init_functor_type;
925 using value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
926 std::declval<init_functor_type>()(),
927 *std::declval<
typename std::decay_t<E>::const_stepper>()
929 using evaluated_value_type = evaluated_value_type_t<value_type, is_xexpression<value_type>::value>;
931 using reducer_type = xreducer<
934 xtl::const_closure_type_t<
decltype(normalized_axes)>,
935 reducer_options<evaluated_value_type, std::decay_t<O>>>;
939 std::forward<
decltype(normalized_axes)>(normalized_axes),
940 std::forward<O>(options)
944 template <
class F,
class E,
class X,
class O>
945 inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::immediate_type, O&& options)
947 decltype(
auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
948 return reduce_immediate(
950 eval(std::forward<E>(e)),
951 std::forward<
decltype(normalized_axes)>(normalized_axes),
952 std::forward<O>(options)
// Default options type for the `reduce` overloads that follow: lazy evaluation strategy only.
// NOTE(review): an identical #define of DEFAULT_STRATEGY_REDUCERS appears earlier in this file.
// Redefining a macro with the same replacement list is permitted, but confirm both definitions
// are intentional (they may sit in preprocessor branches not visible here).
#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
962 struct is_xreducer_functors_impl : std::false_type
966 template <
class RF,
class IF,
class MF>
967 struct is_xreducer_functors_impl<xreducer_functors<RF, IF, MF>> : std::true_type
972 using is_xreducer_functors = is_xreducer_functors_impl<std::decay_t<T>>;
992 class EVS = DEFAULT_STRATEGY_REDUCERS,
994 inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
996 return detail::reduce_impl(
999 std::forward<X>(axes),
1000 typename reducer_options<int, EVS>::evaluation_strategy{},
1001 std::forward<EVS>(options)
1009 class EVS = DEFAULT_STRATEGY_REDUCERS,
1010 XTL_REQUIRES(std::negation<is_reducer_options<X>>, std::negation<detail::is_xreducer_functors<F>>)>
1011 inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
1014 make_xreducer_functor(std::forward<F>(f)),
1016 std::forward<X>(axes),
1017 std::forward<EVS>(options)
1024 class EVS = DEFAULT_STRATEGY_REDUCERS,
1026 inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
1028 xindex_type_t<typename std::decay_t<E>::shape_type> ar;
1029 resize_container(ar, e.dimension());
1030 std::iota(ar.begin(), ar.end(), 0);
1031 return detail::reduce_impl(
1035 typename reducer_options<
int, std::decay_t<EVS>>::evaluation_strategy{},
1036 std::forward<EVS>(options)
1043 class EVS = DEFAULT_STRATEGY_REDUCERS,
1045 inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
1047 return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), std::forward<EVS>(options));
1055 class EVS = DEFAULT_STRATEGY_REDUCERS,
1056 XTL_REQUIRES(detail::is_xreducer_functors<F>)>
1057 inline auto reduce(F&& f, E&& e,
const I (&axes)[N], EVS options = EVS())
1059 using axes_type = std::array<std::size_t, N>;
1060 auto ax = xt::forward_normalize<axes_type>(e, axes);
1061 return detail::reduce_impl(
1065 typename reducer_options<int, EVS>::evaluation_strategy{},
1075 class EVS = DEFAULT_STRATEGY_REDUCERS,
1076 XTL_REQUIRES(std::negation<detail::is_xreducer_functors<F>>)>
1077 inline auto reduce(F&& f, E&& e,
const I (&axes)[N], EVS options = EVS())
1079 return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), axes, options);
1086 template <
class F,
class CT,
class X,
class O>
1087 class xreducer_stepper
1091 using self_type = xreducer_stepper<F, CT, X, O>;
1094 using value_type =
typename xreducer_type::value_type;
1095 using reference =
typename xreducer_type::value_type;
1096 using pointer =
typename xreducer_type::const_pointer;
1097 using size_type =
typename xreducer_type::size_type;
1098 using difference_type =
typename xreducer_type::difference_type;
1100 using xexpression_type =
typename xreducer_type::xexpression_type;
1101 using substepper_type =
typename xexpression_type::const_stepper;
1102 using shape_type =
typename xreducer_type::shape_type;
1105 const xreducer_type& red,
1108 layout_type l = default_assignable_layout(xexpression_type::static_layout)
1111 reference operator*()
const;
1113 void step(size_type dim);
1114 void step_back(size_type dim);
1115 void step(size_type dim, size_type n);
1116 void step_back(size_type dim, size_type n);
1117 void reset(size_type dim);
1118 void reset_back(size_type dim);
1125 reference initial_value()
const;
1126 reference aggregate(size_type dim)
const;
1127 reference aggregate_impl(size_type dim, std::false_type)
const;
1128 reference aggregate_impl(size_type dim, std::true_type)
const;
1130 substepper_type get_substepper_begin()
const;
1131 size_type get_dim(size_type dim)
const noexcept;
1132 size_type shape(size_type i)
const noexcept;
1133 size_type axis(size_type i)
const noexcept;
1135 const xreducer_type* m_reducer;
1137 mutable substepper_type m_stepper;
1146 template <std::size_t X, std::size_t... I>
1149 static constexpr bool value = std::disjunction<std::integral_constant<bool, X == I>...>::value;
1152 template <std::
size_t Z,
class S1,
class S2,
class R>
1153 struct fixed_xreducer_shape_type_impl;
1155 template <std::size_t Z, std::size_t... I, std::size_t... J, std::size_t... R>
1156 struct fixed_xreducer_shape_type_impl<Z, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
1158 using type = std::conditional_t<
1160 typename fixed_xreducer_shape_type_impl<Z - 1, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>::type,
1161 typename fixed_xreducer_shape_type_impl<
1165 fixed_shape<detail::at<Z, I...>::value, R...>>::type>;
1168 template <std::size_t... I, std::size_t... J, std::size_t... R>
1169 struct fixed_xreducer_shape_type_impl<0, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
1172 conditional_t<in<0, J...>::value, fixed_shape<R...>, fixed_shape<detail::at<0, I...>::value, R...>>;
1180 struct xreducer_size_type
1182 using type = std::size_t;
1186 using xreducer_size_type_t =
typename xreducer_size_type<T>::type;
1189 struct xreducer_temporary_type
1195 using xreducer_temporary_type_t =
typename xreducer_temporary_type<T>::type;
1201 template <
class T,
class U>
1202 struct const_value_rebinder
1204 static const_value<U> run(
const const_value<T>& t)
1206 return const_value<U>(t.m_value);
1219 return detail::const_value_rebinder<T, NT>::run(*
this);
1226 template <
class S1,
class S2>
1229 template <std::size_t... I, std::size_t... J>
1232 using type =
typename detail::
1237 template <
class ST,
class X,
class O>
1240 using type = promote_shape_t<ST, std::decay_t<X>>;
1243 template <
class I1, std::
size_t N1,
class I2, std::
size_t N2>
1246 using type = std::array<I2, N1>;
1249 template <
class I1, std::
size_t N1,
class I2, std::
size_t N2>
1252 using type = std::array<I2, N1 - N2>;
1255 template <std::size_t... I,
class I2, std::size_t N2>
1258 using type = std::conditional_t<
sizeof...(I) == N2,
fixed_shape<>, std::array<I2,
sizeof...(I) - N2>>;
1263 template <
class S1,
class S2>
1266 template <
class T, T... I1, T... I2>
1267 struct ixconcat<std::integer_sequence<T, I1...>, std::integer_sequence<T, I2...>>
1269 using type = std::integer_sequence<T, I1..., I2...>;
1272 template <
class T, T X, std::
size_t N>
1273 struct repeat_integer_sequence
1275 using type =
typename ixconcat<
1276 std::integer_sequence<T, X>,
1277 typename repeat_integer_sequence<T, X, N - 1>::type>::type;
1280 template <
class T, T X>
1281 struct repeat_integer_sequence<T, X, 0>
1283 using type = std::integer_sequence<T>;
1286 template <
class T, T X>
1287 struct repeat_integer_sequence<T, X, 2>
1289 using type = std::integer_sequence<T, X, X>;
1292 template <
class T, T X>
1293 struct repeat_integer_sequence<T, X, 1>
1295 using type = std::integer_sequence<T, X>;
1299 template <std::size_t... I,
class I2, std::size_t N2>
1302 template <std::size_t... X>
1303 static constexpr auto get_type(std::index_sequence<X...>)
1309 using type = std::conditional_t<
1311 decltype(get_type(
typename detail::repeat_integer_sequence<std::size_t, std::size_t(1), N2>::type{})),
1312 std::array<I2,
sizeof...(I)>>;
1316 template <std::size_t... I, std::size_t... J,
class O>
1324 template <
class S,
class E,
class X,
class M>
1325 inline void shape_and_mapping_computation(S& shape, E& e,
const X& axes, M& mapping, std::false_type)
1327 auto first = e.shape().begin();
1328 auto last = e.shape().end();
1329 auto exclude_it = axes.begin();
1331 using value_type =
typename S::value_type;
1332 using difference_type =
typename S::difference_type;
1333 auto d_first = shape.begin();
1334 auto map_first = mapping.begin();
1337 while (iter != last && exclude_it != axes.end())
1339 auto diff = std::distance(first, iter);
1340 if (
diff != difference_type(*exclude_it))
1342 *d_first++ = *iter++;
1343 *map_first++ = value_type(
diff);
1352 auto diff = std::distance(first, iter);
1353 auto end = std::distance(iter, last);
1354 std::iota(map_first, map_first + end,
diff);
1355 std::copy(iter, last, d_first);
1358 template <
class S,
class E,
class X,
class M>
1360 shape_and_mapping_computation_keep_dim(S& shape, E& e,
const X& axes, M& mapping, std::false_type)
1362 for (std::size_t i = 0; i < e.dimension(); ++i)
1364 if (std::find(axes.cbegin(), axes.cend(), i) == axes.cend())
1367 shape[i] = e.shape()[i];
1374 std::iota(mapping.begin(), mapping.end(), 0);
1377 template <
class S,
class E,
class X,
class M>
1378 inline void shape_and_mapping_computation(S&, E&,
const X&, M&, std::true_type)
1382 template <
class S,
class E,
class X,
class M>
1383 inline void shape_and_mapping_computation_keep_dim(S&, E&,
const X&, M&, std::true_type)
1404 template <
class F,
class CT,
class X,
class O>
1405 template <
class Func,
class CTA,
class AX,
class OX>
1407 : m_e(std::forward<CTA>(e))
1408 , m_reduce(
xt::get<0>(func))
1409 , m_init(
xt::get<1>(func))
1410 , m_merge(
xt::get<2>(func))
1411 , m_axes(std::forward<AX>(axes))
1412 , m_shape(xtl::make_sequence<inner_shape_type>(
1416 , m_dim_mapping(xtl::make_sequence<dim_mapping_type>(
1420 , m_options(std::forward<OX>(options))
1426 if (!std::is_sorted(m_axes.cbegin(), m_axes.cend(), std::less<>()))
1428 XTENSOR_THROW(std::runtime_error,
"Reducing axes should be sorted.");
1430 if (std::adjacent_find(m_axes.cbegin(), m_axes.cend()) != m_axes.cend())
1432 XTENSOR_THROW(std::runtime_error,
"Reducing axes should not contain duplicates.");
1434 if (m_axes.size() != 0 && m_axes[m_axes.size() - 1] > m_e.dimension() - 1)
1438 "Axis " + std::to_string(m_axes[m_axes.size() - 1]) +
" out of bounds for reduction."
1442 if (!
typename O::keep_dims())
1444 detail::shape_and_mapping_computation(
1449 detail::is_fixed<shape_type>{}
1454 detail::shape_and_mapping_computation_keep_dim(
1459 detail::is_fixed<shape_type>{}
1473 template <
class F,
class CT,
class X,
class O>
1482 template <
class F,
class CT,
class X,
class O>
1485 return static_layout;
1488 template <
class F,
class CT,
class X,
class O>
1489 inline bool xreducer<F, CT, X, O>::is_contiguous() const noexcept
1506 template <
class F,
class CT,
class X,
class O>
1507 template <
class... Args>
1508 inline auto xreducer<F, CT, X, O>::operator()(Args... args)
const -> const_reference
1510 XTENSOR_TRY(check_index(
shape(), args...));
1511 XTENSOR_CHECK_DIMENSION(
shape(), args...);
1512 std::array<std::size_t,
sizeof...(Args)> arg_array = {{
static_cast<std::size_t
>(args)...}};
1513 return element(arg_array.cbegin(), arg_array.cend());
1535 template <
class F,
class CT,
class X,
class O>
1536 template <
class... Args>
1537 inline auto xreducer<F, CT, X, O>::unchecked(Args... args)
const -> const_reference
1539 std::array<std::size_t,
sizeof...(Args)> arg_array = {{
static_cast<std::size_t
>(args)...}};
1540 return element(arg_array.cbegin(), arg_array.cend());
1550 template <
class F,
class CT,
class X,
class O>
1552 inline auto xreducer<F, CT, X, O>::element(It first, It last)
const -> const_reference
1554 XTENSOR_TRY(check_element_index(
shape(), first, last));
1555 auto stepper = const_stepper(*
this, 0);
1560 auto size = std::ptrdiff_t(this->
dimension()) - std::distance(first, last);
1561 auto begin = first -
size;
1562 while (begin != last)
1566 stepper.step(dim++, std::size_t(0));
1571 stepper.step(dim++, std::size_t(*begin++));
1581 template <
class F,
class CT,
class X,
class O>
1599 template <
class F,
class CT,
class X,
class O>
1603 return xt::broadcast_shape(m_shape,
shape);
1611 template <
class F,
class CT,
class X,
class O>
1620 template <
class F,
class CT,
class X,
class O>
1622 inline auto xreducer<F, CT, X, O>::stepper_begin(
const S& shape)
const noexcept -> const_stepper
1624 size_type offset = shape.size() - this->dimension();
1625 return const_stepper(*
this, offset);
1628 template <
class F,
class CT,
class X,
class O>
1630 inline auto xreducer<F, CT, X, O>::stepper_end(
const S& shape,
layout_type l)
const noexcept
1633 size_type offset = shape.size() - this->dimension();
1634 return const_stepper(*
this, offset,
true, l);
1637 template <
class F,
class CT,
class X,
class O>
1639 inline auto xreducer<F, CT, X, O>::build_reducer(E&& e)
const -> rebind_t<E>
1642 std::make_tuple(m_reduce, m_init, m_merge),
1649 template <
class F,
class CT,
class X,
class O>
1650 template <
class E,
class Func,
class Opts>
1651 inline auto xreducer<F, CT, X, O>::build_reducer(E&& e, Func&& func, Opts&& opts)
const
1652 -> rebind_t<E, Func, Opts>
1654 return rebind_t<E, Func, Opts>(
1655 std::forward<Func>(func),
1658 std::forward<Opts>(opts)
1666 template <
class F,
class CT,
class X,
class O>
1667 inline xreducer_stepper<F, CT, X, O>::xreducer_stepper(
1668 const xreducer_type& red,
1675 , m_stepper(get_substepper_begin())
1683 template <
class F,
class CT,
class X,
class O>
1684 inline auto xreducer_stepper<F, CT, X, O>::operator*() const -> reference
1686 reference r = aggregate(0);
1690 template <
class F,
class CT,
class X,
class O>
1691 inline void xreducer_stepper<F, CT, X, O>::step(size_type dim)
1693 if (dim >= m_offset)
1695 m_stepper.step(get_dim(dim - m_offset));
1699 template <
class F,
class CT,
class X,
class O>
1700 inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim)
1702 if (dim >= m_offset)
1704 m_stepper.step_back(get_dim(dim - m_offset));
1708 template <
class F,
class CT,
class X,
class O>
1709 inline void xreducer_stepper<F, CT, X, O>::step(size_type dim, size_type n)
1711 if (dim >= m_offset)
1713 m_stepper.step(get_dim(dim - m_offset), n);
1717 template <
class F,
class CT,
class X,
class O>
1718 inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim, size_type n)
1720 if (dim >= m_offset)
1722 m_stepper.step_back(get_dim(dim - m_offset), n);
1726 template <
class F,
class CT,
class X,
class O>
1727 inline void xreducer_stepper<F, CT, X, O>::reset(size_type dim)
1729 if (dim >= m_offset)
1733 if (
typename O::keep_dims()
1734 && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
1739 m_stepper.reset(get_dim(dim - m_offset));
1743 template <
class F,
class CT,
class X,
class O>
1744 inline void xreducer_stepper<F, CT, X, O>::reset_back(size_type dim)
1746 if (dim >= m_offset)
1749 if (
typename O::keep_dims()
1750 && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
1755 m_stepper.reset_back(get_dim(dim - m_offset));
1759 template <
class F,
class CT,
class X,
class O>
1760 inline void xreducer_stepper<F, CT, X, O>::to_begin()
1762 m_stepper.to_begin();
1765 template <
class F,
class CT,
class X,
class O>
1766 inline void xreducer_stepper<F, CT, X, O>::to_end(
layout_type l)
1768 m_stepper.to_end(l);
1771 template <
class F,
class CT,
class X,
class O>
1772 inline auto xreducer_stepper<F, CT, X, O>::initial_value() const -> reference
1774 return O::has_initial_value ? m_reducer->m_options.initial_value
1775 :
static_cast<reference
>(m_reducer->m_init());
1778 template <
class F,
class CT,
class X,
class O>
1779 inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim)
const -> reference
1782 if (m_reducer->m_e.size() == size_type(0))
1784 res = initial_value();
1786 else if (m_reducer->m_e.shape().empty() || m_reducer->m_axes.size() == 0)
1788 res = m_reducer->m_reduce(initial_value(), *m_stepper);
1792 res = aggregate_impl(dim,
typename O::keep_dims());
1793 if (O::has_initial_value && dim == 0)
1795 res = m_reducer->m_merge(m_reducer->m_options.initial_value, res);
1801 template <
class F,
class CT,
class X,
class O>
1802 inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::false_type)
const -> reference
1806 size_type index = axis(dim);
1807 size_type size = shape(index);
1808 if (dim != m_reducer->m_axes.size() - 1)
1810 res = aggregate_impl(dim + 1,
typename O::keep_dims());
1811 for (size_type i = 1; i != size; ++i)
1813 m_stepper.step(index);
1814 res = m_reducer->m_merge(res, aggregate_impl(dim + 1,
typename O::keep_dims()));
1819 res = m_reducer->m_reduce(
static_cast<reference
>(m_reducer->m_init()), *m_stepper);
1820 for (size_type i = 1; i != size; ++i)
1822 m_stepper.step(index);
1823 res = m_reducer->m_reduce(res, *m_stepper);
1826 m_stepper.reset(index);
1830 template <
class F,
class CT,
class X,
class O>
1831 inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::true_type)
const -> reference
1835 auto ax_it = std::find(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim);
1836 if (ax_it != m_reducer->m_axes.end())
1838 size_type index = dim;
1839 size_type size = m_reducer->m_e.shape()[index];
1840 if (ax_it != m_reducer->m_axes.end() - 1 && size != 0)
1842 res = aggregate_impl(dim + 1,
typename O::keep_dims());
1843 for (size_type i = 1; i != size; ++i)
1845 m_stepper.step(index);
1846 res = m_reducer->m_merge(res, aggregate_impl(dim + 1,
typename O::keep_dims()));
1851 res = m_reducer->m_reduce(
static_cast<reference
>(m_reducer->m_init()), *m_stepper);
1852 for (size_type i = 1; i != size; ++i)
1854 m_stepper.step(index);
1855 res = m_reducer->m_reduce(res, *m_stepper);
1858 m_stepper.reset(index);
1862 if (dim < m_reducer->m_e.dimension())
1864 res = aggregate_impl(dim + 1,
typename O::keep_dims());
1870 template <
class F,
class CT,
class X,
class O>
1871 inline auto xreducer_stepper<F, CT, X, O>::get_substepper_begin() const -> substepper_type
1873 return m_reducer->m_e.stepper_begin(m_reducer->m_e.shape());
1876 template <
class F,
class CT,
class X,
class O>
1877 inline auto xreducer_stepper<F, CT, X, O>::get_dim(size_type dim)
const noexcept -> size_type
1879 return m_reducer->m_dim_mapping[dim];
1882 template <
class F,
class CT,
class X,
class O>
1883 inline auto xreducer_stepper<F, CT, X, O>::shape(size_type i)
const noexcept -> size_type
1885 return m_reducer->m_e.shape()[i];
1888 template <
class F,
class CT,
class X,
class O>
1889 inline auto xreducer_stepper<F, CT, X, O>::axis(size_type i)
const noexcept -> size_type
1891 return m_reducer->m_axes[i];
Fixed shape implementation for compile time defined arrays.
size_type size() const noexcept
size_type dimension() const noexcept
Base class for multidimensional iterable constant expressions.
auto end() const noexcept -> const_layout_iterator< L >
Reducing function operating over specified axes.
bool broadcast_shape(S &shape, bool reuse_cache=false) const
const xexpression_type & expression() const noexcept
bool has_linear_assign(const S &strides) const noexcept
const inner_shape_type & shape() const noexcept
Returns the shape of the expression.
xreducer(Func &&func, CTA &&e, AX &&axes, OX &&options)
Constructs an xreducer expression applying the specified function to the given expression over the gi...
layout_type layout() const noexcept
auto operator|(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::bitwise_or, E1, E2 >
Bitwise or.
auto diff(const xexpression< T > &a, std::size_t n=1, std::ptrdiff_t axis=-1)
Calculate the n-th discrete difference along the given axis.
auto eval(T &&t) -> std::enable_if_t< detail::is_container< std::decay_t< T > >::value, T && >
Force evaluation of xexpression.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
standard mathematical functions for xexpressions
auto reduce(F &&f, E &&e, X &&axes, EVS &&options=EVS())
Returns an xexpression applying the specified reducing function to an expression over the given axes.