10#ifndef XTENSOR_REDUCER_HPP
11#define XTENSOR_REDUCER_HPP
15#include <initializer_list>
22#include <xtl/xfunctional.hpp>
23#include <xtl/xsequence.hpp>
25#include "xaccessible.hpp"
26#include "xbuilder.hpp"
28#include "xexpression.hpp"
29#include "xgenerator.hpp"
30#include "xiterable.hpp"
31#include "xtensor_config.hpp"
36 template <
template <
class...>
class A,
class... AX,
class X, XTL_REQUIRES(is_evaluation_strategy<AX>..., is_evaluation_strategy<X>)>
37 auto operator|(
const A<AX...>& args,
const A<X>& rhs)
39 return std::tuple_cat(args, rhs);
46 constexpr auto keep_dims = std::tuple<keep_dims_type>{};
48 template <
class T =
double>
56 constexpr T value()
const
65 constexpr auto initial(T
val)
70 template <std::ptrdiff_t I,
class T,
class Tuple>
73 template <std::ptrdiff_t I,
class T>
76 static constexpr std::ptrdiff_t value = -1;
79 template <std::ptrdiff_t I,
class T,
class...
Types>
82 static constexpr std::ptrdiff_t value = I;
85 template <std::ptrdiff_t I,
class T,
class U,
class...
Types>
91 template <
class S,
class...
X>
94 template <
template <
class...>
class S,
class...
X>
100 template <
class T,
class Tuple>
103 static constexpr std::ptrdiff_t
107 template <
class R,
class T>
126 using d_t = std::decay_t<T>;
128 static constexpr std::size_t initial_val_idx = xtl::mpl::find_if<initial_tester, d_t>::value;
133 xtl::mpl::static_if<initial_val_idx != std::tuple_size<T>::value>(
138 std::get < initial_val_idx != std::tuple_size<T>::value
148 using evaluation_strategy = std::conditional_t<
149 tuple_idx_of<xt::evaluation_strategy::immediate_type, d_t>::value != -1,
153 using keep_dims = std::
154 conditional_t<tuple_idx_of<xt::keep_dims_type, d_t>::value != -1, std::true_type, std::false_type>;
156 static constexpr bool has_initial_value = initial_val_idx != std::tuple_size<d_t>::value;
161 using rebind_t = reducer_options<NR, T>;
164 auto rebind(NR initial,
const reducer_options<R, T>&)
const
166 reducer_options<NR, T> res;
167 res.initial_value = initial;
177 template <
class...
X>
191#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
193 template <
class ST,
class X,
class KD = std::false_type>
196 template <
class S1,
class S2>
201 template <
class O,
class RS,
class R,
class E,
class AX>
202 inline void shape_computation(
207 std::enable_if_t<!detail::is_fixed<RS>::value,
int> = 0
210 if (
typename O::keep_dims())
213 for (std::size_t
i = 0;
i <
expr.dimension(); ++
i)
215 if (std::find(axes.begin(), axes.end(),
i) == axes.end())
228 resize_container(result_shape, expr.dimension() - axes.size());
229 for (std::size_t i = 0, idx = 0; i < expr.dimension(); ++i)
231 if (std::find(axes.begin(), axes.end(), i) == axes.end())
234 result_shape[idx] = expr.shape()[i];
239 result.resize(result_shape, expr.layout());
243 template <
class O,
class RS,
class R,
class S,
class AX>
245 shape_computation(RS&, R&,
const S&,
const AX&, std::enable_if_t<detail::is_fixed<RS>::value,
int> = 0)
250 template <
class F,
class E,
class R, XTL_REQUIRES(std::is_convertible<
typename E::value_type,
typename R::value_type>)>
251 inline void copy_to_reduced(F&,
const E& e, R& result)
256 e.template cbegin<layout_type::row_major>(),
257 e.template cend<layout_type::row_major>(),
264 e.template cbegin<layout_type::column_major>(),
265 e.template cend<layout_type::column_major>(),
275 XTL_REQUIRES(xtl::negation<std::is_convertible<typename E::value_type, typename R::value_type>>)>
276 inline void copy_to_reduced(F& f,
const E& e, R& result)
281 e.template cbegin<layout_type::row_major>(),
282 e.template cend<layout_type::row_major>(),
290 e.template cbegin<layout_type::column_major>(),
291 e.template cend<layout_type::column_major>(),
298 template <
class F,
class E,
class X,
class O>
299 inline auto reduce_immediate(F&& f, E&& e, X&& axes, O&& raw_options)
301 using reduce_functor_type =
typename std::decay_t<F>::reduce_functor_type;
302 using init_functor_type =
typename std::decay_t<F>::init_functor_type;
303 using expr_value_type =
typename std::decay_t<E>::value_type;
304 using result_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
305 std::declval<init_functor_type>()(),
306 std::declval<expr_value_type>()
309 using options_t = reducer_options<result_type, std::decay_t<O>>;
310 options_t options(raw_options);
312 using shape_type =
typename xreducer_shape_type<
313 typename std::decay_t<E>::shape_type,
315 typename options_t::keep_dims>::type;
316 using result_container_type =
typename detail::xtype_for_shape<
317 shape_type>::template type<result_type, std::decay_t<E>::static_layout>;
318 result_container_type result;
325 if (axes.size() == 0)
327 result.resize(e.shape(), e.layout());
328 auto cpf = [&reduce_fct, &init_fct](
const auto& v)
330 return reduce_fct(
static_cast<result_type
>(init_fct()), v);
332 copy_to_reduced(cpf, e, result);
336 shape_type result_shape{};
337 dynamic_shape<std::size_t>
338 iter_shape = xtl::forward_sequence<dynamic_shape<std::size_t>,
decltype(e.shape())>(e.shape());
339 dynamic_shape<std::size_t> iter_strides(e.dimension());
345 if (!std::is_sorted(axes.cbegin(), axes.cend(), std::less<>()))
347 XTENSOR_THROW(std::runtime_error,
"Reducing axes should be sorted.");
349 if (std::adjacent_find(axes.cbegin(), axes.cend()) != axes.cend())
351 XTENSOR_THROW(std::runtime_error,
"Reducing axes should not contain duplicates.");
353 if (axes.size() != 0 && axes[axes.size() - 1] > e.dimension() - 1)
357 "Axis " + std::to_string(axes[axes.size() - 1]) +
" out of bounds for reduction."
361 detail::shape_computation<options_t>(result_shape, result, e, axes);
364 if (e.dimension() == axes.size())
366 result_type tmp = options_t::has_initial_value ? options.initial_value : init_fct();
367 result.data()[0] = std::accumulate(e.storage().begin(), e.storage().end(), tmp, reduce_fct);
372 auto strides_finder = e.strides().begin() +
static_cast<std::ptrdiff_t
>(leading_ax);
374 std::size_t inner_stride =
static_cast<std::size_t
>(*strides_finder);
376 while (inner_stride == 0 && strides_finder != iter_bound)
379 inner_stride =
static_cast<std::size_t
>(*strides_finder);
382 if (inner_stride == 0)
384 auto cpf = [&reduce_fct, &init_fct](
const auto& v)
386 return reduce_fct(
static_cast<result_type
>(init_fct()), v);
388 copy_to_reduced(cpf, e, result);
392 std::size_t inner_loop_size =
static_cast<std::size_t
>(inner_stride);
393 std::size_t outer_loop_size = e.shape()[leading_ax];
397 auto merge_loops = [&outer_loop_size, &e](
auto it,
auto end)
401 for (; it != end; ++it)
404 if (std::abs(std::ptrdiff_t(*it) - std::ptrdiff_t(last_ax)) == 1)
407 outer_loop_size *= e.shape()[last_ax];
413 for (std::size_t i = 0, idx = 0; i < e.dimension(); ++i)
415 if (std::find(axes.begin(), axes.end(), i) == axes.end())
418 iter_strides[i] =
static_cast<std::size_t
>(result.strides(
419 )[
typename options_t::keep_dims() ? i : idx]);
426 std::size_t last_ax = merge_loops(axes.rbegin(), axes.rend());
428 iter_shape.erase(iter_shape.begin() + std::ptrdiff_t(last_ax), iter_shape.end());
429 iter_strides.erase(iter_strides.begin() + std::ptrdiff_t(last_ax), iter_strides.end());
434 std::size_t last_ax = merge_loops(axes.begin(), axes.end());
437 iter_shape.erase(iter_shape.begin(), iter_shape.begin() + std::ptrdiff_t(last_ax + 1));
438 iter_strides.erase(iter_strides.begin(), iter_strides.begin() + std::ptrdiff_t(last_ax + 1));
441 std::reverse(iter_shape.begin(), iter_shape.end());
442 std::reverse(iter_strides.begin(), iter_strides.end());
446 XTENSOR_THROW(std::runtime_error,
"Layout not supported in immediate reduction.");
449 xindex temp_idx(iter_shape.size());
450 auto next_idx = [&iter_shape, &iter_strides, &temp_idx]()
452 std::size_t i = iter_shape.size();
455 if (std::ptrdiff_t(temp_idx[i - 1]) >= std::ptrdiff_t(iter_shape[i - 1]) - 1)
466 return std::make_pair(
468 std::inner_product(temp_idx.begin(), temp_idx.end(), iter_strides.begin(), std::ptrdiff_t(0))
472 auto begin = e.data();
473 auto out = result.data();
474 auto out_begin = result.data();
476 std::ptrdiff_t next_stride = 0;
478 std::pair<bool, std::ptrdiff_t> idx_res(
false, 0);
483 auto merge_border = out;
491 if (inner_stride == 1)
493 while (idx_res.first !=
true)
497 result_type tmp = init_fct();
498 tmp = std::accumulate(begin, begin + outer_loop_size, tmp, reduce_fct);
501 *out = merge ? merge_fct(*out, tmp) : tmp;
503 begin += outer_loop_size;
505 idx_res = next_idx();
506 next_stride = idx_res.second;
507 out = out_begin + next_stride;
509 if (out > merge_border)
523 while (idx_res.first !=
true)
527 out + inner_loop_size,
530 [merge, &init_fct, &reduce_fct](
auto&& v1,
auto&& v2)
532 return merge ? reduce_fct(v1, v2) :
534 reduce_fct(static_cast<result_type>(init_fct()), v2);
538 begin += inner_stride;
539 for (std::size_t i = 1; i < outer_loop_size; ++i)
541 std::transform(out, out + inner_loop_size, begin, out, reduce_fct);
542 begin += inner_stride;
545 idx_res = next_idx();
546 next_stride = idx_res.second;
547 out = out_begin + next_stride;
549 if (out > merge_border)
561 if (options_t::has_initial_value)
565 result.data() + result.size(),
567 [&merge_fct, &options](
auto&& v)
569 return merge_fct(v, options.initial_value);
583 using value_type = T;
592 constexpr T operator()()
const
608 template <
class T,
bool B>
609 struct evaluated_value_type
615 struct evaluated_value_type<T, true>
617 using type =
typename std::decay_t<decltype(xt::eval(std::declval<T>()))>;
620 template <
class T,
bool B>
621 using evaluated_value_type_t =
typename evaluated_value_type<T, B>::type;
624 template <
class REDUCE_FUNC,
class INIT_FUNC = const_value<
long int>,
class MERGE_FUNC = REDUCE_FUNC>
628 using base_type = std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
632 using init_value_type =
typename init_functor_type::value_type;
645 template <
class RF,
class IF>
651 template <
class RF,
class IF,
class MF>
659 return std::get<0>(upcast());
664 return std::get<1>(upcast());
669 return std::get<2>(upcast());
678 return make_xreducer_functor(get_reduce(), get_init().
template rebind<NT>(), get_merge());
684 const base_type& upcast()
const
686 return static_cast<const base_type&
>(*this);
697 template <
class RF,
class IF>
698 auto make_xreducer_functor(RF&& reduce_func, IF&& init_func)
700 using reducer_type = xreducer_functors<std::remove_reference_t<RF>, std::remove_reference_t<IF>>;
701 return reducer_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func));
704 template <
class RF,
class IF,
class MF>
705 auto make_xreducer_functor(RF&& reduce_func, IF&& init_func, MF&& merge_func)
707 using reducer_type = xreducer_functors<
708 std::remove_reference_t<RF>,
709 std::remove_reference_t<IF>,
710 std::remove_reference_t<MF>>;
712 std::forward<RF>(reduce_func),
713 std::forward<IF>(init_func),
714 std::forward<MF>(merge_func)
724 template <
class Tag,
class F,
class CT,
class X,
class O>
727 template <
class F,
class CT,
class X,
class O>
733 template <
class F,
class CT,
class X,
class O>
738 template <
class F,
class CT,
class X,
class O>
746 template <
class F,
class CT,
class X,
class O>
749 template <
class F,
class CT,
class X,
class O>
752 template <
class F,
class CT,
class X,
class O>
755 using xexpression_type = std::decay_t<CT>;
757 typename xexpression_type::shape_type,
759 typename O::keep_dims>::type;
764 template <
class F,
class CT,
class X,
class O>
767 using xexpression_type = std::decay_t<CT>;
768 using reduce_functor_type =
typename std::decay_t<F>::reduce_functor_type;
769 using init_functor_type =
typename std::decay_t<F>::init_functor_type;
770 using merge_functor_type =
typename std::decay_t<F>::merge_functor_type;
771 using substepper_type =
typename xexpression_type::const_stepper;
772 using raw_value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
773 std::declval<init_functor_type>()(),
774 *std::declval<substepper_type>()
776 using value_type =
typename detail::evaluated_value_type_t<raw_value_type, is_xexpression<raw_value_type>::value>;
778 using reference = value_type;
779 using const_reference = value_type;
780 using size_type =
typename xexpression_type::size_type;
789 template <std::size_t... I>
792 using type = std::array<std::size_t,
sizeof...(I)>;
812 template <
class F,
class CT,
class X,
class O>
816 public extension::xreducer_base_t<F, CT, X, O>
823 using reduce_functor_type =
typename inner_types::reduce_functor_type;
825 using merge_functor_type =
typename inner_types::merge_functor_type;
828 using xexpression_type =
typename inner_types::xexpression_type;
831 using extension_base = extension::xreducer_base_t<F, CT, X, O>;
832 using expression_tag =
typename extension_base::expression_tag;
834 using substepper_type =
typename inner_types::substepper_type;
835 using value_type =
typename inner_types::value_type;
836 using reference =
typename inner_types::reference;
837 using const_reference =
typename inner_types::const_reference;
838 using pointer = value_type*;
839 using const_pointer =
const value_type*;
841 using size_type =
typename inner_types::size_type;
842 using difference_type =
typename xexpression_type::difference_type;
845 using inner_shape_type =
typename iterable_base::inner_shape_type;
846 using shape_type = inner_shape_type;
848 using dim_mapping_type =
typename select_dim_mapping_type<inner_shape_type>::type;
850 using stepper =
typename iterable_base::stepper;
851 using const_stepper =
typename iterable_base::const_stepper;
852 using bool_load_type =
typename xexpression_type::bool_load_type;
855 static constexpr bool contiguous_layout =
false;
857 template <
class Func,
class CTA,
class AX,
class OX>
860 const inner_shape_type&
shape()
const noexcept;
862 bool is_contiguous()
const noexcept;
864 template <
class...
Args>
865 const_reference operator()(
Args...
args)
const;
866 template <
class...
Args>
867 const_reference unchecked(
Args...
args)
const;
872 const xexpression_type&
expression()
const noexcept;
881 const_stepper stepper_begin(
const S&
shape)
const noexcept;
885 template <
class E,
class Func = F,
class Opts = O>
891 template <
class E,
class Func,
class Opts>
901 const O& options()
const
909 reduce_functor_type m_reduce;
910 init_functor_type m_init;
911 merge_functor_type m_merge;
913 inner_shape_type m_shape;
914 dim_mapping_type m_dim_mapping;
926 template <
class F,
class E,
class X,
class O>
931 using reduce_functor_type =
typename std::decay_t<F>::reduce_functor_type;
932 using init_functor_type =
typename std::decay_t<F>::init_functor_type;
933 using value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
934 std::declval<init_functor_type>()(),
935 *std::declval<
typename std::decay_t<E>::const_stepper>()
948 std::forward<O>(options)
952 template <
class F,
class E,
class X,
class O>
953 inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::immediate_type, O&& options)
955 decltype(
auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
956 return reduce_immediate(
958 eval(std::forward<E>(e)),
959 std::forward<
decltype(normalized_axes)>(normalized_axes),
960 std::forward<O>(options)
965#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
970 struct is_xreducer_functors_impl : std::false_type
974 template <
class RF,
class IF,
class MF>
975 struct is_xreducer_functors_impl<xreducer_functors<RF, IF, MF>> : std::true_type
980 using is_xreducer_functors = is_xreducer_functors_impl<std::decay_t<T>>;
1000 class EVS = DEFAULT_STRATEGY_REDUCERS,
1001 XTL_REQUIRES(xtl::negation<is_reducer_options<X>>, detail::is_xreducer_functors<F>)>
1004 return detail::reduce_impl(
1007 std::forward<X>(axes),
1008 typename reducer_options<int, EVS>::evaluation_strategy{},
1009 std::forward<EVS>(options)
1017 class EVS = DEFAULT_STRATEGY_REDUCERS,
1018 XTL_REQUIRES(xtl::negation<is_reducer_options<X>>, xtl::negation<detail::is_xreducer_functors<F>>)>
1019 inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
1022 make_xreducer_functor(std::forward<F>(f)),
1024 std::forward<X>(axes),
1025 std::forward<EVS>(options)
1032 class EVS = DEFAULT_STRATEGY_REDUCERS,
1033 XTL_REQUIRES(is_reducer_options<EVS>, detail::is_xreducer_functors<F>)>
1034 inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
1036 xindex_type_t<typename std::decay_t<E>::shape_type> ar;
1037 resize_container(ar, e.dimension());
1038 std::iota(ar.begin(), ar.end(), 0);
1039 return detail::reduce_impl(
1043 typename reducer_options<
int, std::decay_t<EVS>>::evaluation_strategy{},
1044 std::forward<EVS>(options)
1051 class EVS = DEFAULT_STRATEGY_REDUCERS,
1052 XTL_REQUIRES(is_reducer_options<EVS>, xtl::negation<detail::is_xreducer_functors<F>>)>
1053 inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
1055 return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), std::forward<EVS>(options));
1063 class EVS = DEFAULT_STRATEGY_REDUCERS,
1064 XTL_REQUIRES(detail::is_xreducer_functors<F>)>
1065 inline auto reduce(F&& f, E&& e,
const I (&axes)[N], EVS options = EVS())
1067 using axes_type = std::array<std::size_t, N>;
1069 return detail::reduce_impl(
1073 typename reducer_options<int, EVS>::evaluation_strategy{},
1083 class EVS = DEFAULT_STRATEGY_REDUCERS,
1084 XTL_REQUIRES(xtl::negation<detail::is_xreducer_functors<F>>)>
1085 inline auto reduce(F&& f, E&& e,
const I (&axes)[N], EVS options = EVS())
1087 return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), axes, options);
1094 template <
class F,
class CT,
class X,
class O>
1102 using value_type =
typename xreducer_type::value_type;
1103 using reference =
typename xreducer_type::value_type;
1104 using pointer =
typename xreducer_type::const_pointer;
1105 using size_type =
typename xreducer_type::size_type;
1106 using difference_type =
typename xreducer_type::difference_type;
1108 using xexpression_type =
typename xreducer_type::xexpression_type;
1109 using substepper_type =
typename xexpression_type::const_stepper;
1110 using shape_type =
typename xreducer_type::shape_type;
1116 layout_type l = default_assignable_layout(xexpression_type::static_layout)
1119 reference operator*()
const;
1121 void step(size_type
dim);
1122 void step_back(size_type
dim);
1123 void step(size_type
dim, size_type
n);
1124 void step_back(size_type
dim, size_type
n);
1125 void reset(size_type
dim);
1126 void reset_back(size_type
dim);
1133 reference initial_value()
const;
1134 reference aggregate(size_type
dim)
const;
1135 reference aggregate_impl(size_type
dim, std::false_type)
const;
1136 reference aggregate_impl(size_type
dim, std::true_type)
const;
1138 substepper_type get_substepper_begin()
const;
1139 size_type get_dim(size_type
dim)
const noexcept;
1140 size_type shape(size_type
i)
const noexcept;
1141 size_type axis(size_type
i)
const noexcept;
1145 mutable substepper_type m_stepper;
1154 template <std::size_t
X, std::size_t... I>
1157 static constexpr bool value = xtl::disjunction<std::integral_constant<bool, X == I>...>::value;
1160 template <std::
size_t Z,
class S1,
class S2,
class R>
1161 struct fixed_xreducer_shape_type_impl;
1163 template <std::size_t Z, std::size_t... I, std::size_t... J, std::size_t... R>
1164 struct fixed_xreducer_shape_type_impl<Z, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
1166 using type = std::conditional_t<
1168 typename fixed_xreducer_shape_type_impl<Z - 1, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>::type,
1169 typename fixed_xreducer_shape_type_impl<
1173 fixed_shape<detail::at<Z, I...>::value, R...>>::type>;
1176 template <std::size_t... I, std::size_t... J, std::size_t... R>
1177 struct fixed_xreducer_shape_type_impl<0, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
1180 conditional_t<in<0, J...>::value, fixed_shape<R...>, fixed_shape<detail::at<0, I...>::value, R...>>;
1188 struct xreducer_size_type
1190 using type = std::size_t;
1194 using xreducer_size_type_t =
typename xreducer_size_type<T>::type;
1197 struct xreducer_temporary_type
1203 using xreducer_temporary_type_t =
typename xreducer_temporary_type<T>::type;
1209 template <
class T,
class U>
1210 struct const_value_rebinder
1212 static const_value<U> run(
const const_value<T>& t)
1214 return const_value<U>(t.m_value);
1225 const_value<NT> const_value<T>::rebind()
const
1227 return detail::const_value_rebinder<T, NT>::run(*
this);
1234 template <
class S1,
class S2>
1235 struct fixed_xreducer_shape_type;
1237 template <std::size_t... I, std::size_t... J>
1240 using type =
typename detail::
1245 template <
class ST,
class X,
class O>
1251 template <
class I1, std::
size_t N1,
class I2, std::
size_t N2>
1254 using type = std::array<I2, N1>;
1257 template <
class I1, std::
size_t N1,
class I2, std::
size_t N2>
1260 using type = std::array<
I2,
N1 -
N2>;
1263 template <std::size_t... I,
class I2, std::size_t
N2>
1266 using type = std::conditional_t<
sizeof...(I) ==
N2,
fixed_shape<>, std::array<
I2,
sizeof...(I) -
N2>>;
1271 template <
class S1,
class S2>
1274 template <
class T, T...
I1, T...
I2>
1275 struct ixconcat<std::
integer_sequence<T, I1...>, std::integer_sequence<T, I2...>>
1277 using type = std::integer_sequence<T,
I1...,
I2...>;
1280 template <
class T, T X, std::
size_t N>
1281 struct repeat_integer_sequence
1283 using type =
typename ixconcat<
1284 std::integer_sequence<T, X>,
1285 typename repeat_integer_sequence<T, X, N - 1>::type>::type;
1288 template <
class T, T X>
1289 struct repeat_integer_sequence<T, X, 0>
1291 using type = std::integer_sequence<T>;
1294 template <
class T, T X>
1295 struct repeat_integer_sequence<T, X, 2>
1297 using type = std::integer_sequence<T, X, X>;
1300 template <
class T, T X>
1301 struct repeat_integer_sequence<T, X, 1>
1303 using type = std::integer_sequence<T, X>;
1307 template <std::size_t... I,
class I2, std::size_t N2>
1310 template <std::size_t...
X>
1311 static constexpr auto get_type(std::index_sequence<X...>)
1317 using type = std::conditional_t<
1319 decltype(get_type(
typename detail::repeat_integer_sequence<std::size_t, std::size_t(1),
N2>::type{})),
1320 std::array<
I2,
sizeof...(I)>>;
1324 template <std::size_t... I, std::size_t...
J,
class O>
1332 template <
class S,
class E,
class X,
class M>
1333 inline void shape_and_mapping_computation(
S& shape, E&
e,
const X& axes,
M&
mapping, std::false_type)
1335 auto first =
e.shape().begin();
1336 auto last =
e.shape().end();
1339 using value_type =
typename S::value_type;
1340 using difference_type =
typename S::difference_type;
1360 auto diff = std::distance(first, iter);
1361 auto end = std::distance(iter, last);
1362 std::iota(map_first, map_first + end,
diff);
1363 std::copy(iter, last, d_first);
1366 template <
class S,
class E,
class X,
class M>
1368 shape_and_mapping_computation_keep_dim(S& shape, E& e,
const X& axes, M& mapping, std::false_type)
1370 for (std::size_t i = 0; i < e.dimension(); ++i)
1372 if (std::find(axes.cbegin(), axes.cend(), i) == axes.cend())
1375 shape[i] = e.shape()[i];
1382 std::iota(mapping.begin(), mapping.end(), 0);
1385 template <
class S,
class E,
class X,
class M>
1386 inline void shape_and_mapping_computation(S&, E&,
const X&, M&, std::true_type)
1390 template <
class S,
class E,
class X,
class M>
1391 inline void shape_and_mapping_computation_keep_dim(S&, E&,
const X&, M&, std::true_type)
1412 template <
class F,
class CT,
class X,
class O>
1413 template <
class Func,
class CTA,
class AX,
class OX>
1416 , m_reduce(
xt::get<0>(
func))
1417 , m_init(
xt::get<1>(
func))
1418 , m_merge(
xt::get<2>(
func))
1421 typename O::keep_dims() ? m_e.dimension() : m_e.dimension() - m_axes.size(),
1425 typename O::keep_dims() ? m_e.dimension() : m_e.dimension() - m_axes.size(),
1434 if (!std::is_sorted(m_axes.cbegin(), m_axes.cend(), std::less<>()))
1436 XTENSOR_THROW(std::runtime_error,
"Reducing axes should be sorted.");
1438 if (std::adjacent_find(m_axes.cbegin(), m_axes.cend()) != m_axes.cend())
1440 XTENSOR_THROW(std::runtime_error,
"Reducing axes should not contain duplicates.");
1442 if (m_axes.size() != 0 && m_axes[m_axes.size() - 1] > m_e.dimension() - 1)
1446 "Axis " + std::to_string(m_axes[m_axes.size() - 1]) +
" out of bounds for reduction."
1450 if (!
typename O::keep_dims())
1452 detail::shape_and_mapping_computation(
1457 detail::is_fixed<shape_type>{}
1462 detail::shape_and_mapping_computation_keep_dim(
1467 detail::is_fixed<shape_type>{}
1481 template <
class F,
class CT,
class X,
class O>
1490 template <
class F,
class CT,
class X,
class O>
1493 return static_layout;
1496 template <
class F,
class CT,
class X,
class O>
1514 template <
class F,
class CT,
class X,
class O>
1515 template <
class... Args>
1518 XTENSOR_TRY(check_index(shape(),
args...));
1519 XTENSOR_CHECK_DIMENSION(shape(),
args...);
1520 std::array<std::size_t,
sizeof...(Args)>
arg_array = {{
static_cast<std::size_t
>(
args)...}};
1543 template <
class F,
class CT,
class X,
class O>
1544 template <
class...
Args>
1547 std::array<std::size_t,
sizeof...(Args)>
arg_array = {{
static_cast<std::size_t
>(
args)...}};
1558 template <
class F,
class CT,
class X,
class O>
1562 XTENSOR_TRY(check_element_index(shape(),
first,
last));
1563 auto stepper = const_stepper(*
this, 0);
1568 auto size = std::ptrdiff_t(this->dimension()) - std::distance(
first,
last);
1569 auto begin =
first - size;
1570 while (begin !=
last)
1574 stepper.step(
dim++, std::size_t(0));
1579 stepper.step(
dim++, std::size_t(*begin++));
1589 template <
class F,
class CT,
class X,
class O>
1607 template <
class F,
class CT,
class X,
class O>
1611 return xt::broadcast_shape(m_shape, shape);
1619 template <
class F,
class CT,
class X,
class O>
1628 template <
class F,
class CT,
class X,
class O>
1632 size_type
offset = shape.size() - this->dimension();
1633 return const_stepper(*
this,
offset);
1636 template <
class F,
class CT,
class X,
class O>
1638 inline auto xreducer<F, CT, X, O>::stepper_end(
const S& shape,
layout_type l)
const noexcept
1641 size_type offset = shape.size() - this->dimension();
1642 return const_stepper(*
this, offset,
true, l);
1645 template <
class F,
class CT,
class X,
class O>
1647 inline auto xreducer<F, CT, X, O>::build_reducer(E&& e)
const -> rebind_t<E>
1650 std::make_tuple(m_reduce, m_init, m_merge),
1657 template <
class F,
class CT,
class X,
class O>
1658 template <
class E,
class Func,
class Opts>
1659 inline auto xreducer<F, CT, X, O>::build_reducer(E&& e, Func&& func, Opts&& opts)
const
1660 -> rebind_t<E, Func, Opts>
1662 return rebind_t<E, Func, Opts>(
1663 std::forward<Func>(func),
1666 std::forward<Opts>(opts)
1674 template <
class F,
class CT,
class X,
class O>
1675 inline xreducer_stepper<F, CT, X, O>::xreducer_stepper(
1676 const xreducer_type& red,
1683 , m_stepper(get_substepper_begin())
1691 template <
class F,
class CT,
class X,
class O>
1692 inline auto xreducer_stepper<F, CT, X, O>::operator*() const -> reference
1694 reference r = aggregate(0);
1698 template <
class F,
class CT,
class X,
class O>
1699 inline void xreducer_stepper<F, CT, X, O>::step(size_type dim)
1701 if (dim >= m_offset)
1703 m_stepper.step(get_dim(dim - m_offset));
1707 template <
class F,
class CT,
class X,
class O>
1708 inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim)
1710 if (dim >= m_offset)
1712 m_stepper.step_back(get_dim(dim - m_offset));
1716 template <
class F,
class CT,
class X,
class O>
1717 inline void xreducer_stepper<F, CT, X, O>::step(size_type dim, size_type n)
1719 if (dim >= m_offset)
1721 m_stepper.step(get_dim(dim - m_offset), n);
1725 template <
class F,
class CT,
class X,
class O>
1726 inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim, size_type n)
1728 if (dim >= m_offset)
1730 m_stepper.step_back(get_dim(dim - m_offset), n);
1734 template <
class F,
class CT,
class X,
class O>
1735 inline void xreducer_stepper<F, CT, X, O>::reset(size_type dim)
1737 if (dim >= m_offset)
1741 if (
typename O::keep_dims()
1742 && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
1747 m_stepper.reset(get_dim(dim - m_offset));
1751 template <
class F,
class CT,
class X,
class O>
1752 inline void xreducer_stepper<F, CT, X, O>::reset_back(size_type dim)
1754 if (dim >= m_offset)
1757 if (
typename O::keep_dims()
1758 && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
1763 m_stepper.reset_back(get_dim(dim - m_offset));
1767 template <
class F,
class CT,
class X,
class O>
1768 inline void xreducer_stepper<F, CT, X, O>::to_begin()
1770 m_stepper.to_begin();
1773 template <
class F,
class CT,
class X,
class O>
1774 inline void xreducer_stepper<F, CT, X, O>::to_end(
layout_type l)
1776 m_stepper.to_end(l);
1779 template <
class F,
class CT,
class X,
class O>
1780 inline auto xreducer_stepper<F, CT, X, O>::initial_value() const -> reference
1782 return O::has_initial_value ? m_reducer->m_options.initial_value
1783 :
static_cast<reference
>(m_reducer->m_init());
1786 template <
class F,
class CT,
class X,
class O>
1787 inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim)
const -> reference
1790 if (m_reducer->m_e.size() == size_type(0))
1792 res = initial_value();
1794 else if (m_reducer->m_e.shape().empty() || m_reducer->m_axes.size() == 0)
1796 res = m_reducer->m_reduce(initial_value(), *m_stepper);
1800 res = aggregate_impl(dim,
typename O::keep_dims());
1801 if (O::has_initial_value && dim == 0)
1803 res = m_reducer->m_merge(m_reducer->m_options.initial_value, res);
1809 template <
class F,
class CT,
class X,
class O>
1810 inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::false_type)
const -> reference
1814 size_type index = axis(dim);
1815 size_type size = shape(index);
1816 if (dim != m_reducer->m_axes.size() - 1)
1818 res = aggregate_impl(dim + 1,
typename O::keep_dims());
1819 for (size_type i = 1; i != size; ++i)
1821 m_stepper.step(index);
1822 res = m_reducer->m_merge(res, aggregate_impl(dim + 1,
typename O::keep_dims()));
1827 res = m_reducer->m_reduce(
static_cast<reference
>(m_reducer->m_init()), *m_stepper);
1828 for (size_type i = 1; i != size; ++i)
1830 m_stepper.step(index);
1831 res = m_reducer->m_reduce(res, *m_stepper);
1834 m_stepper.reset(index);
1838 template <
class F,
class CT,
class X,
class O>
1839 inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::true_type)
const -> reference
1843 auto ax_it = std::find(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim);
1844 if (ax_it != m_reducer->m_axes.end())
1846 size_type index = dim;
1847 size_type size = m_reducer->m_e.shape()[index];
1848 if (ax_it != m_reducer->m_axes.end() - 1 && size != 0)
1850 res = aggregate_impl(dim + 1,
typename O::keep_dims());
1851 for (size_type i = 1; i != size; ++i)
1853 m_stepper.step(index);
1854 res = m_reducer->m_merge(res, aggregate_impl(dim + 1,
typename O::keep_dims()));
1859 res = m_reducer->m_reduce(
static_cast<reference
>(m_reducer->m_init()), *m_stepper);
1860 for (size_type i = 1; i != size; ++i)
1862 m_stepper.step(index);
1863 res = m_reducer->m_reduce(res, *m_stepper);
1866 m_stepper.reset(index);
1870 if (dim < m_reducer->m_e.dimension())
1872 res = aggregate_impl(dim + 1,
typename O::keep_dims());
1878 template <
class F,
class CT,
class X,
class O>
1879 inline auto xreducer_stepper<F, CT, X, O>::get_substepper_begin() const -> substepper_type
1881 return m_reducer->m_e.stepper_begin(m_reducer->m_e.shape());
1884 template <
class F,
class CT,
class X,
class O>
1885 inline auto xreducer_stepper<F, CT, X, O>::get_dim(size_type dim)
const noexcept -> size_type
1887 return m_reducer->m_dim_mapping[dim];
1890 template <
class F,
class CT,
class X,
class O>
1891 inline auto xreducer_stepper<F, CT, X, O>::shape(size_type i)
const noexcept -> size_type
1893 return m_reducer->m_e.shape()[i];
1896 template <
class F,
class CT,
class X,
class O>
1897 inline auto xreducer_stepper<F, CT, X, O>::axis(size_type i)
const noexcept -> size_type
1899 return m_reducer->m_axes[i];
Fixed shape implementation for compile-time-defined arrays.
Base class for implementation of common expression access methods.
Base class for multidimensional iterable constant expressions.
Reducing function operating over specified axes.
bool broadcast_shape(S &shape, bool reuse_cache=false) const
Broadcast the shape of the reducer to the specified parameter.
const xexpression_type & expression() const noexcept
Returns a constant reference to the underlying expression of the reducer.
bool has_linear_assign(const S &strides) const noexcept
Checks whether the xreducer can be linearly assigned to an expression with the specified strides.
const inner_shape_type & shape() const noexcept
Returns the shape of the expression.
xreducer(Func &&func, CTA &&e, AX &&axes, OX &&options)
Constructs an xreducer expression applying the specified function to the given expression over the given axes.
layout_type layout() const noexcept
Returns the layout of the expression.
auto operator|(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::bitwise_or, E1, E2 >
Bitwise or.
auto diff(const xexpression< T > &a, std::size_t n=1, std::ptrdiff_t axis=-1)
Calculate the n-th discrete difference along the given axis.
auto eval(T &&t) -> std::enable_if_t< detail::is_container< std::decay_t< T > >::value, T && >
Force evaluation of xexpression.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Standard mathematical functions for xexpressions.
auto reduce(F &&f, E &&e, X &&axes, EVS &&options=EVS())
Returns an xexpression applying the specified reducing function to an expression over the given axes.