1#ifndef XTENSOR_XBLOCKWISE_REDUCER_FUNCTORS_HPP
2#define XTENSOR_XBLOCKWISE_REDUCER_FUNCTORS_HPP
7#include "../chunk/xchunked_array.hpp"
8#include "../chunk/xchunked_assign.hpp"
9#include "../chunk/xchunked_view.hpp"
10#include "../containers/xarray.hpp"
11#include "../core/xexpression.hpp"
12#include "../core/xmath.hpp"
13#include "../generators/xbuilder.hpp"
14#include "../reducers/xnorm.hpp"
15#include "../reducers/xreducer.hpp"
16#include "../utils/xutils.hpp"
17#include "xtl/xclosure.hpp"
18#include "xtl/xsequence.hpp"
// Tag type returned by the reduction_variable() members in this file for
// functors that carry no extra state between block merges.
// NOTE(review): the struct body is not shown in this listing.
27 struct empty_reduction_variable
// Common base for the functors in this file that need neither a reduction
// variable nor a finalization step (sum, prod, amin, amax and several norm
// functors derive from it).
31 struct simple_functor_base
// Ignores its (unnamed) argument and returns the empty tag object.
34 auto reduction_variable(
const E&)
const
36 return empty_reduction_variable();
// finalize hook: all three parameters are unnamed, so the body cannot read
// them. NOTE(review): the body is not shown in this listing; presumably empty.
39 template <
class MR,
class E,
class R>
40 void finalize(
const MR&, E&,
const R&)
const
// Blockwise sum functor.
//   T_E: element type of the input (used as xarray<T_E>).
//   T_I: template argument forwarded to xt::sum (defaults to void).
45 template <
class T_E,
class T_I =
void>
46 struct sum_functor :
public simple_functor_base
// Element type produced by xt::sum<T_I> applied to an xarray<T_E>.
48 using value_type =
typename std::decay_t<decltype(xt::sum<T_I>(std::declval<
xarray<T_E>>()))>::value_type;
// Per-block reduction of `input` over `axes` with `options`.
// NOTE(review): the body is not shown in this listing.
50 template <
class E,
class A,
class O>
51 auto compute(
const E& input,
const A& axes,
const O& options)
const
// Folds one block result into `result`: either a plain assignment or an
// in-place addition. NOTE(review): the branch selecting between the two is
// not shown here; presumably it tests `first`.
56 template <
class BR,
class E,
class MR>
57 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
61 xt::noalias(result) = block_result;
65 xt::noalias(result) += block_result;
// Blockwise product functor.
//   T_E: element type of the input (used as xarray<T_E>).
//   T_I: template argument forwarded to xt::prod (defaults to void).
70 template <
class T_E,
class T_I =
void>
71 struct prod_functor :
public simple_functor_base
// Element type produced by xt::prod<T_I> applied to an xarray<T_E>.
// Derived from xt::prod (not xt::sum) so that the type always tracks the
// reducer this functor represents, as the sibling functors do.
73 using value_type =
typename std::decay_t<decltype(xt::prod<T_I>(std::declval<
xarray<T_E>>()))>::value_type;
// Per-block reduction of `input` over `axes` with `options`.
// NOTE(review): the body is not shown in this listing.
75 template <
class E,
class A,
class O>
76 auto compute(
const E& input,
const A& axes,
const O& options)
const
// Folds one block result into `result`: either a plain assignment or an
// in-place multiplication. NOTE(review): the branch selecting between the two
// is not shown here; presumably it tests `first`.
81 template <
class BR,
class E,
class MR>
82 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
86 xt::noalias(result) = block_result;
90 xt::noalias(result) *= block_result;
// Blockwise minimum functor.
//   T_E: element type of the input (used as xarray<T_E>).
//   T_I: template argument forwarded to xt::amin (defaults to void).
95 template <
class T_E,
class T_I =
void>
96 struct amin_functor :
public simple_functor_base
// Element type produced by xt::amin<T_I> applied to an xarray<T_E>.
98 using value_type =
typename std::decay_t<decltype(xt::amin<T_I>(std::declval<
xarray<T_E>>()))>::value_type;
// Per-block reduction: forwards input, axes and options to xt::amin.
100 template <
class E,
class A,
class O>
101 auto compute(
const E& input,
const A& axes,
const O& options)
const
103 return xt::amin(input, axes, options);
// Folds one block result into `result`: either a plain assignment or an
// element-wise minimum of the block result and the running result.
// NOTE(review): the branch selecting between the two is not shown here;
// presumably it tests `first`.
106 template <
class BR,
class E,
class MR>
107 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
111 xt::noalias(result) = block_result;
115 xt::noalias(result) =
xt::minimum(block_result, result);
// Blockwise maximum functor.
//   T_E: element type of the input (used as xarray<T_E>).
//   T_I: template argument forwarded to xt::amax (defaults to void).
120 template <
class T_E,
class T_I =
void>
121 struct amax_functor :
public simple_functor_base
// Element type produced by xt::amax<T_I> applied to an xarray<T_E>.
123 using value_type =
typename std::decay_t<decltype(xt::amax<T_I>(std::declval<
xarray<T_E>>()))>::value_type;
// Per-block reduction: forwards input, axes and options to xt::amax.
125 template <
class E,
class A,
class O>
126 auto compute(
const E& input,
const A& axes,
const O& options)
const
128 return xt::amax(input, axes, options);
// Folds one block result into `result`: either a plain assignment or an
// element-wise maximum of the block result and the running result.
// NOTE(review): the branch selecting between the two is not shown here;
// presumably it tests `first`.
131 template <
class BR,
class E,
class MR>
132 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
136 xt::noalias(result) = block_result;
140 xt::noalias(result) =
xt::maximum(block_result, result);
// Blockwise mean functor.
// NOTE(review): the struct header line is missing from this listing; the
// name is presumably mean_functor, given that value_type is built on xt::mean.
//   T_E: element type of the input (used as xarray<T_E>).
//   T_I: template argument forwarded to xt::mean (defaults to void).
145 template <
class T_E,
class T_I =
void>
// Element type produced by xt::mean<T_I> applied to an xarray<T_E>.
148 using value_type =
typename std::decay_t<decltype(xt::mean<T_I>(std::declval<
xarray<T_E>>()))>::value_type;
// Per-block reduction of `input` over `axes` with `options`.
// NOTE(review): the body is not shown in this listing; since merge() adds
// block results and finalize() divides by an element count, it presumably
// returns a per-block sum - confirm.
150 template <
class E,
class A,
class O>
151 auto compute(
const E& input,
const A& axes,
const O& options)
const
// No state is carried between merges: returns the empty tag object.
157 auto reduction_variable(
const E&)
const
159 return empty_reduction_variable();
// Folds one block result into `result`: either a plain assignment or an
// in-place addition. NOTE(review): the branch selecting between the two is
// not shown here; presumably it tests `first`.
162 template <
class BR,
class E>
163 auto merge(
const BR& block_result,
bool first, E& result, empty_reduction_variable&)
const
167 xt::noalias(result) = block_result;
171 xt::noalias(result) += block_result;
// Final step: divides the accumulated `results` in place by `factor`.
175 template <
class E,
class R>
176 void finalize(
const empty_reduction_variable&, E& results,
const R& reducer)
const
178 const auto& axes = reducer.axes();
// factor starts at 1 and has the type of one input-shape extent.
179 std::decay_t<
decltype(reducer.input_shape()[0])> factor = 1;
// Multiplies in the input extent at index `a`. NOTE(review): the loop header
// is not shown; presumably `a` iterates over `axes`.
182 factor *= reducer.input_shape()[a];
// Cast to the result element type before dividing.
184 xt::noalias(results) /=
static_cast<typename E::value_type
>(factor);
// Blockwise variance functor. Per-block results are tuples that merge()
// combines with a running mean and a running count held in the reduction
// variable.
//   T_E: element type of the input (used as xarray<T_E>).
//   T_I: template argument forwarded to xt::variance (defaults to void).
188 template <
class T_E,
class T_I =
void>
189 struct variance_functor
// Built on the result type of xt::variance<T_I> applied to an xarray<T_E>.
// NOTE(review): the remainder of this alias is truncated in the listing.
191 using value_type =
typename std::decay_t<decltype(xt::variance<T_I>(std::declval<
xarray<T_E>>())
// Per-block reduction. Returns a tuple whose first element is the block
// variance; merge() reads three elements from it (variance, mean, count).
194 template <
class E,
class A,
class O>
195 auto compute(
const E& input,
const A& axes,
const O& options)
const
// Accumulates a double-valued weight from input extents.
// NOTE(review): the loop header is not shown; presumably over `axes`.
200 weight *=
static_cast<double>(input.shape()[a]);
204 return std::make_tuple(
205 xt::variance<value_type>(input, axes, options),
// NOTE(review): the body is not shown in this listing; merge() accesses
// elements 0 and 1 of the returned object through std::get.
212 auto reduction_variable(
const E&)
const
// Folds one block result into the running variance `variance_a`, using the
// running mean and count stored in `mr`.
217 template <
class BR,
class E,
class MR>
218 auto merge(
const BR& block_result,
bool first, E& variance_a, MR& mr)
const
// Running state: mean at index 0, count at index 1.
220 auto& mean_a = std::get<0>(mr);
221 auto& n_a = std::get<1>(mr);
// Block result layout: (variance, mean, count).
223 const auto& variance_b = std::get<0>(block_result);
224 const auto& mean_b = std::get<1>(block_result);
225 const auto& n_b = std::get<2>(block_result);
// Plain copy of the block's variance and mean into the running state.
// NOTE(review): presumably the `first` branch; the condition is not shown.
228 xt::noalias(variance_a) = variance_b;
229 xt::noalias(mean_a) = mean_b;
// Count-weighted combination of the two means.
234 auto new_mean = (n_a * mean_a + n_b * mean_b) / (n_a + n_b);
// Count-weighted variances plus count-weighted squared deviations of each
// mean from the combined mean. NOTE(review): the end of this expression is
// truncated in the listing.
235 auto new_variance = (n_a * variance_a + n_b * variance_b
236 + n_a *
xt::pow(mean_a - new_mean, 2)
237 + n_b *
xt::pow(mean_b - new_mean, 2))
239 xt::noalias(variance_a) = new_variance;
240 xt::noalias(mean_a) = new_mean;
// finalize hook: all three parameters are unnamed, so the body cannot read
// them. NOTE(review): the body is not shown in this listing.
245 template <
class MR,
class E,
class R>
246 void finalize(
const MR&, E&,
const R&)
const
// Blockwise standard-deviation functor: inherits everything from
// variance_functor<T_E, T_I> and supplies its own finalize().
251 template <
class T_E,
class T_I =
void>
252 struct stddev_functor :
public variance_functor<T_E, T_I>
// Final step: replaces `results` with its element-wise square root. The other
// two parameters are unnamed and unused.
254 template <
class MR,
class E,
class R>
255 void finalize(
const MR&, E& results,
const R&)
const
257 xt::noalias(results) =
xt::sqrt(results);
// Blockwise L0 pseudo-norm functor; uses the default reduction_variable()
// and finalize() inherited from simple_functor_base.
262 struct norm_l0_functor :
public simple_functor_base
// Per-block reduction of `input` over `axes` with `options`.
// NOTE(review): the body is not shown in this listing.
266 template <
class E,
class A,
class O>
267 auto compute(
const E& input,
const A& axes,
const O& options)
const
// Folds one block result into `result`: either a plain assignment or an
// in-place addition. NOTE(review): the branch selecting between the two is
// not shown here; presumably it tests `first`.
272 template <
class BR,
class E,
class MR>
273 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
277 xt::noalias(result) = block_result;
281 xt::noalias(result) += block_result;
// Blockwise L1 norm functor; uses the default reduction_variable() and
// finalize() inherited from simple_functor_base.
287 struct norm_l1_functor :
public simple_functor_base
// Per-block reduction of `input` over `axes` with `options`.
// NOTE(review): the body is not shown in this listing.
291 template <
class E,
class A,
class O>
292 auto compute(
const E& input,
const A& axes,
const O& options)
const
// Folds one block result into `result`: either a plain assignment or an
// in-place addition. NOTE(review): the branch selecting between the two is
// not shown here; presumably it tests `first`.
297 template <
class BR,
class E,
class MR>
298 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
302 xt::noalias(result) = block_result;
306 xt::noalias(result) += block_result;
// Blockwise L2 norm functor. Block results are added together in merge() and
// the square root is taken once, in finalize().
312 struct norm_l2_functor
// Per-block reduction of `input` over `axes` with `options`.
// NOTE(review): the body is not shown in this listing; since merge() adds and
// finalize() takes a square root, it presumably returns a squared quantity.
316 template <
class E,
class A,
class O>
317 auto compute(
const E& input,
const A& axes,
const O& options)
const
// No state is carried between merges: returns the empty tag object.
323 auto reduction_variable(
const E&)
const
325 return empty_reduction_variable();
// Folds one block result into `result`: either a plain assignment or an
// in-place addition. NOTE(review): the branch selecting between the two is
// not shown here; presumably it tests `first`.
328 template <
class BR,
class E>
329 auto merge(
const BR& block_result,
bool first, E& result, empty_reduction_variable&)
const
333 xt::noalias(result) = block_result;
337 xt::noalias(result) += block_result;
// Final step: replaces `results` with its element-wise square root.
341 template <
class E,
class R>
342 void finalize(
const empty_reduction_variable&, E& results,
const R&)
const
344 xt::noalias(results) =
xt::sqrt(results);
// Blockwise squared-L2 norm functor; uses the default reduction_variable()
// and finalize() inherited from simple_functor_base.
349 struct norm_sq_functor :
public simple_functor_base
// Per-block reduction of `input` over `axes` with `options`.
// NOTE(review): the body is not shown in this listing.
353 template <
class E,
class A,
class O>
354 auto compute(
const E& input,
const A& axes,
const O& options)
const
// Folds one block result into `result`: either a plain assignment or an
// in-place addition. NOTE(review): the branch selecting between the two is
// not shown here; presumably it tests `first`.
359 template <
class BR,
class E,
class MR>
360 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
364 xt::noalias(result) = block_result;
368 xt::noalias(result) += block_result;
// Blockwise infinity-norm functor; uses the default reduction_variable() and
// finalize() inherited from simple_functor_base.
374 struct norm_linf_functor :
public simple_functor_base
// Per-block reduction of `input` over `axes` with `options`.
// NOTE(review): the body is not shown in this listing.
378 template <
class E,
class A,
class O>
379 auto compute(
const E& input,
const A& axes,
const O& options)
const
// Folds one block result into `result`: either a plain assignment or an
// element-wise maximum of the block result and the running result.
// NOTE(review): the branch selecting between the two is not shown here;
// presumably it tests `first`.
384 template <
class BR,
class E,
class MR>
385 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
389 xt::noalias(result) = block_result;
393 xt::noalias(result) =
xt::maximum(block_result, result);
// Blockwise functor for the p-th power of the Lp norm. Constructed from the
// exponent p; block results are added together in merge().
399 class norm_lp_to_p_functor
// NOTE(review): this alias is truncated in the listing.
403 using value_type =
typename std::decay_t<
// Takes the exponent p. NOTE(review): the constructor body/initializer is
// not shown in this listing.
406 norm_lp_to_p_functor(
double p)
// Per-block reduction of `input` over `axes` with `options`.
// NOTE(review): the body is not shown in this listing.
411 template <
class E,
class A,
class O>
412 auto compute(
const E& input,
const A& axes,
const O& options)
const
// No state is carried between merges: returns the empty tag object.
418 auto reduction_variable(
const E&)
const
420 return empty_reduction_variable();
// Folds one block result into `result`: either a plain assignment or an
// in-place addition. NOTE(review): the branch selecting between the two is
// not shown here; presumably it tests `first`.
423 template <
class BR,
class E>
424 auto merge(
const BR& block_result,
bool first, E& result, empty_reduction_variable&)
const
428 xt::noalias(result) = block_result;
432 xt::noalias(result) += block_result;
// finalize hook: the result and reducer parameters are unnamed, so the body
// cannot read them. NOTE(review): the body is not shown in this listing.
436 template <
class E,
class R>
437 void finalize(
const empty_reduction_variable&, E&,
const R&)
const
// Blockwise Lp norm functor. Constructed from the exponent p; block results
// are added together in merge() and raised to the power 1/p in finalize().
447 class norm_lp_functor
// Takes the exponent p. NOTE(review): the constructor body/initializer is
// not shown in this listing; finalize() reads a member named m_p.
451 norm_lp_functor(
double p)
// Per-block reduction of `input` over `axes` with `options`.
// NOTE(review): the body is not shown in this listing.
459 template <
class E,
class A,
class O>
460 auto compute(
const E& input,
const A& axes,
const O& options)
const
// No state is carried between merges: returns the empty tag object.
466 auto reduction_variable(
const E&)
const
468 return empty_reduction_variable();
// Folds one block result into `result`: either a plain assignment or an
// in-place addition. NOTE(review): the branch selecting between the two is
// not shown here; presumably it tests `first`.
471 template <
class BR,
class E>
472 auto merge(
const BR& block_result,
bool first, E& result, empty_reduction_variable&)
const
476 xt::noalias(result) = block_result;
480 xt::noalias(result) += block_result;
// Final step: raises the accumulated `results` element-wise to the power
// 1/m_p. Unlike the other finalize() members in this file, the assignment
// does not go through xt::noalias.
484 template <
class E,
class R>
485 void finalize(
const empty_reduction_variable&, E& results,
const R&)
const
487 results =
xt::pow(results, 1.0 / m_p);
auto amax(E &&e, X &&axes, EVS es=EVS())
Maximum element along given axis.
auto abs(E &&e) noexcept -> detail::xfunction_type_t< math::abs_fun, E >
Absolute value function.
auto minimum(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::minimum< void >, E1, E2 >
Elementwise minimum.
auto maximum(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::maximum< void >, E1, E2 >
Elementwise maximum.
auto amin(E &&e, X &&axes, EVS es=EVS())
Minimum element along given axis.
auto not_equal(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::not_equal_to, E1, E2 >
Element-wise inequality.
auto sqrt(E &&e) noexcept -> detail::xfunction_type_t< math::sqrt_fun, E >
Square root function.
auto square(E1 &&e1) noexcept
Square power function, equivalent to e1 * e1.
auto pow(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::pow_fun, E1, E2 >
Power function.
auto sum(E &&e, X &&axes, EVS es=EVS())
Sum of elements over given axes.
auto norm_sq(E &&e, X &&axes, EVS es) noexcept
Squared L2 norm of an array-like argument over given axes.
auto norm_lp(E &&e, double p, X &&axes, EVS es=EVS())
Lp norm of an array-like argument over given axes.
auto norm_l2(E &&e, EVS es=EVS()) noexcept
L2 norm of a scalar or array-like argument.
auto norm_l1(E &&e, X &&axes, EVS es) noexcept
L1 norm of an array-like argument over given axes.
auto prod(E &&e, X &&axes, EVS es=EVS())
Product of elements over given axes.
auto norm_lp_to_p(E &&e, double p, X &&axes, EVS es=EVS()) noexcept
p-th power of the Lp norm of an array-like argument over given axes.
auto mean(E &&e, X &&axes, EVS es=EVS())
Mean of elements over given axes.
auto norm_l0(E &&e, X &&axes, EVS es) noexcept
L0 (count) pseudo-norm of an array-like argument over given axes.
auto norm_linf(E &&e, X &&axes, EVS es) noexcept
Infinity (maximum) norm of an array-like argument over given axes.
Standard mathematical functions for xexpressions.
xarray_container< uvector< T, A >, L, xt::svector< typename uvector< T, A >::size_type, 4, SA, true > > xarray
Alias template on xarray_container with default parameters for data container type and shape / stride...
auto zeros(S shape) noexcept
Returns an xexpression containing zeros of the specified shape.