1#ifndef XTENSOR_XBLOCKWISE_REDUCER_FUNCTORS_HPP
2#define XTENSOR_XBLOCKWISE_REDUCER_FUNCTORS_HPP
11#include "xbuilder.hpp"
12#include "xchunked_array.hpp"
13#include "xchunked_assign.hpp"
14#include "xchunked_view.hpp"
15#include "xexpression.hpp"
18#include "xreducer.hpp"
19#include "xtl/xclosure.hpp"
20#include "xtl/xsequence.hpp"
// Empty tag type returned by reduction_variable() for functors that need no
// auxiliary state between merge() calls. Merge overloads of such functors take
// it as their last argument. (Its body is not visible in this listing.)
30 struct empty_reduction_variable
// Common base for the "simple" blockwise reduction functors in this file.
// It supplies two defaults:
//  - reduction_variable(): no auxiliary merge state is required, so it
//    returns an empty_reduction_variable.
//  - finalize(): every parameter is unnamed, so it cannot touch the merged
//    result; for these functors the merged block results are already final.
34 struct simple_functor_base
// Returns the per-reduction state object; presumably handed by the caller to
// merge() as its trailing (MR&) argument.
37 auto reduction_variable(
const E&)
const
39 return empty_reduction_variable();
// No post-processing step (body not visible here, but with all parameters
// unnamed it cannot modify the results).
42 template <
class MR,
class E,
class R>
43 void finalize(
const MR&, E&,
const R&)
const
// Blockwise sum reduction.
// T_E: element type of the reduced expression.
// T_I: template argument forwarded to xt::sum when deducing value_type.
// Each block is reduced by compute(); merge() folds the per-block results
// into `result` by addition.
48 template <
class T_E,
class T_I =
void>
49 struct sum_functor :
public simple_functor_base
// Element type of a full (non-blockwise) xt::sum over an xarray<T_E>.
51 using value_type =
typename std::decay_t<decltype(xt::sum<T_I>(std::declval<xarray<T_E>>()))>::value_type;
// Reduces one block over `axes` with `options` (body not visible in this listing).
53 template <
class E,
class A,
class O>
54 auto compute(
const E& input,
const A& axes,
const O& options)
const
// Folds one block result into `result`. The if/else on `first` is not visible
// here; presumably the first block initialises `result` and later blocks add to it.
59 template <
class BR,
class E,
class MR>
60 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
// first block: initialise the accumulator
64 xt::noalias(result) = block_result;
// subsequent blocks: accumulate elementwise
68 xt::noalias(result) += block_result;
// Blockwise product reduction.
// T_E: element type of the reduced expression.
// T_I: template argument forwarded to xt::prod when deducing value_type.
// Each block is reduced by compute(); merge() folds the per-block results
// into `result` by elementwise multiplication.
73 template <
class T_E,
class T_I =
void>
74 struct prod_functor :
public simple_functor_base
// Element type of a full (non-blockwise) xt::prod over an xarray<T_E>.
// Deduced from xt::prod, the reduction this functor actually performs,
// rather than xt::sum, so the alias cannot drift if the two ever promote differently.
76 using value_type =
typename std::decay_t<decltype(xt::prod<T_I>(std::declval<xarray<T_E>>()))>::value_type;
// Reduces one block over `axes` with `options` (body not visible in this listing).
78 template <
class E,
class A,
class O>
79 auto compute(
const E& input,
const A& axes,
const O& options)
const
// Folds one block result into `result`. The if/else on `first` is not visible
// here; presumably the first block initialises `result` and later blocks multiply into it.
84 template <
class BR,
class E,
class MR>
85 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
// first block: initialise the accumulator
89 xt::noalias(result) = block_result;
// subsequent blocks: accumulate elementwise
93 xt::noalias(result) *= block_result;
// Blockwise minimum reduction: each block is reduced with xt::amin and the
// per-block minima are combined with an elementwise xt::minimum.
98 template <
class T_E,
class T_I =
void>
99 struct amin_functor :
public simple_functor_base
// Element type of a full (non-blockwise) xt::amin over an xarray<T_E>.
101 using value_type =
typename std::decay_t<decltype(xt::amin<T_I>(std::declval<xarray<T_E>>()))>::value_type;
// Minimum of one block over `axes`.
// NOTE(review): T_I is used for value_type but not passed to xt::amin here;
// confirm that this is intended.
103 template <
class E,
class A,
class O>
104 auto compute(
const E& input,
const A& axes,
const O& options)
const
106 return xt::amin(input, axes, options);
// Folds one block result into `result`. The if/else on `first` is not visible
// here; presumably the first block initialises `result`.
109 template <
class BR,
class E,
class MR>
110 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
114 xt::noalias(result) = block_result;
// Elementwise: each output element reads only the same element of `result`,
// so writing in place through noalias is safe.
118 xt::noalias(result) =
xt::minimum(block_result, result);
// Blockwise maximum reduction: each block is reduced with xt::amax and the
// per-block maxima are combined with an elementwise xt::maximum.
123 template <
class T_E,
class T_I =
void>
124 struct amax_functor :
public simple_functor_base
// Element type of a full (non-blockwise) xt::amax over an xarray<T_E>.
126 using value_type =
typename std::decay_t<decltype(xt::amax<T_I>(std::declval<xarray<T_E>>()))>::value_type;
// Maximum of one block over `axes`.
// NOTE(review): T_I is used for value_type but not passed to xt::amax here;
// confirm that this is intended.
128 template <
class E,
class A,
class O>
129 auto compute(
const E& input,
const A& axes,
const O& options)
const
131 return xt::amax(input, axes, options);
// Folds one block result into `result`. The if/else on `first` is not visible
// here; presumably the first block initialises `result`.
134 template <
class BR,
class E,
class MR>
135 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
139 xt::noalias(result) = block_result;
// Elementwise: each output element reads only the same element of `result`,
// so writing in place through noalias is safe.
143 xt::noalias(result) =
xt::maximum(block_result, result);
// Blockwise mean reduction. (The struct declaration line that follows this
// template header is missing from this listing; value_type below is deduced from xt::mean.)
// merge() adds block results together and finalize() divides by the number of
// reduced elements, so compute() presumably returns per-block sums; its body
// is not visible here.
148 template <
class T_E,
class T_I =
void>
// Element type of a full (non-blockwise) xt::mean over an xarray<T_E>.
151 using value_type =
typename std::decay_t<decltype(xt::mean<T_I>(std::declval<xarray<T_E>>()))>::value_type;
// Reduces one block over `axes` with `options` (body not visible in this listing).
153 template <
class E,
class A,
class O>
154 auto compute(
const E& input,
const A& axes,
const O& options)
const
// No auxiliary merge state is needed.
160 auto reduction_variable(
const E&)
const
162 return empty_reduction_variable();
// Accumulates block results by addition. The if/else on `first` is not visible
// here; presumably the first block initialises `result`.
165 template <
class BR,
class E>
166 auto merge(
const BR& block_result,
bool first, E& result, empty_reduction_variable&)
const
170 xt::noalias(result) = block_result;
174 xt::noalias(result) += block_result;
// Turns the accumulated sum into a mean by dividing by the product of the
// input extents along the reduced axes (the loop over `axes` that drives the
// multiplication below is not visible in this listing).
178 template <
class E,
class R>
179 void finalize(
const empty_reduction_variable&, E& results,
const R& reducer)
const
181 const auto& axes = reducer.axes();
// The count uses the shape's own size type; it is converted to the result's
// value type only for the division.
182 std::decay_t<
decltype(reducer.input_shape()[0])> factor = 1;
185 factor *= reducer.input_shape()[a];
187 xt::noalias(results) /=
static_cast<typename E::value_type
>(factor);
// Blockwise variance reduction by pairwise combination of per-block statistics.
// merge() reads each block result as a tuple (variance, mean, element count),
// which is what compute() returns (only `std::make_tuple(` survives in this
// listing). reduction_variable() carries the running (mean, count) across merges.
191 template <
class T_E,
class T_I =
void>
192 struct variance_functor
// Element type of a full (non-blockwise) xt::variance over an xarray<T_E>
// (the tail of this declaration is truncated in this listing).
194 using value_type =
typename std::decay_t<decltype(xt::variance<T_I>(std::declval<xarray<T_E>>())
// Per-block statistics. `weight` becomes the number of elements reduced in
// this block: the product of the block's extents along `axes`, as double
// (its declaration and the loop header are not visible here).
197 template <
class E,
class A,
class O>
198 auto compute(
const E& input,
const A& axes,
const O& options)
const
203 weight *=
static_cast<double>(input.shape()[a]);
207 return std::make_tuple(
// Running state: (mean of the blocks merged so far, their element count).
215 auto reduction_variable(
const E&)
const
217 return std::make_tuple(xarray<value_type>(), 0.0);
// Merges one block's (variance_b, mean_b, n_b) into the running
// (variance_a, mean_a, n_a). The if/else on `first` and the update of n_a are
// not visible in this listing.
220 template <
class BR,
class E,
class MR>
221 auto merge(
const BR& block_result,
bool first, E& variance_a, MR& mr)
const
223 auto& mean_a = std::get<0>(mr);
224 auto& n_a = std::get<1>(mr);
226 const auto& variance_b = std::get<0>(block_result);
227 const auto& mean_b = std::get<1>(block_result);
228 const auto& n_b = std::get<2>(block_result);
// First block: adopt its statistics as the running state.
231 xt::noalias(variance_a) = variance_b;
232 xt::noalias(mean_a) = mean_b;
// Later blocks: the combined mean is the count-weighted mean; the combined
// variance adds each part's variance plus the spread of its mean around the
// combined mean, all weighted by counts. The trailing divisor is not visible
// here (presumably (n_a + n_b)). This recombination is exact for population
// (ddof = 0) variances; confirm that compute() produces those.
// NOTE(review): xt::square(x) would be cheaper than xt::pow(x, 2).
237 auto new_mean = (n_a * mean_a + n_b * mean_b) / (n_a + n_b);
238 auto new_variance = (n_a * variance_a + n_b * variance_b
239 + n_a *
xt::pow(mean_a - new_mean, 2)
240 + n_b *
xt::pow(mean_b - new_mean, 2))
// Order matters: new_variance and new_mean are lazy expressions that read the
// OLD mean_a, so variance_a must be written before mean_a is overwritten.
// Each element reads only its own index, which makes noalias safe.
242 xt::noalias(variance_a) = new_variance;
243 xt::noalias(mean_a) = new_mean;
// No post-processing: all parameters are unnamed, so the merged variance is
// left as-is.
248 template <
class MR,
class E,
class R>
249 void finalize(
const MR&, E&,
const R&)
const
// Blockwise standard deviation: reuses variance_functor's compute(),
// reduction_variable() and merge(), and replaces only finalize() to take the
// elementwise square root of the merged variance.
254 template <
class T_E,
class T_I =
void>
255 struct stddev_functor :
public variance_functor<T_E, T_I>
// sqrt is elementwise, so writing in place through noalias is safe.
257 template <
class MR,
class E,
class R>
258 void finalize(
const MR&, E& results,
const R&)
const
260 xt::noalias(results) =
xt::sqrt(results);
// Blockwise L0 (count) pseudo-norm: per-block counts add up, so merge()
// accumulates by addition. (The template header declaring T_E is not visible
// in this listing.)
265 struct norm_l0_functor :
public simple_functor_base
// Element type of a full (non-blockwise) xt::norm_l0 over an xarray<T_E>.
267 using value_type =
typename std::decay_t<
decltype(
xt::norm_l0(std::declval<xarray<T_E>>()))>::value_type;
// Reduces one block over `axes` with `options` (body not visible in this listing).
269 template <
class E,
class A,
class O>
270 auto compute(
const E& input,
const A& axes,
const O& options)
const
// The if/else on `first` is not visible here; presumably the first block
// initialises `result` and later blocks add to it.
275 template <
class BR,
class E,
class MR>
276 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
280 xt::noalias(result) = block_result;
284 xt::noalias(result) += block_result;
// Blockwise L1 norm: a sum of absolute values is additive over disjoint
// blocks, so merge() accumulates by addition. (The template header declaring
// T_E is not visible in this listing.)
290 struct norm_l1_functor :
public simple_functor_base
// Element type of a full (non-blockwise) xt::norm_l1 over an xarray<T_E>.
292 using value_type =
typename std::decay_t<
decltype(
xt::norm_l1(std::declval<xarray<T_E>>()))>::value_type;
// Reduces one block over `axes` with `options` (body not visible in this listing).
294 template <
class E,
class A,
class O>
295 auto compute(
const E& input,
const A& axes,
const O& options)
const
// The if/else on `first` is not visible here; presumably the first block
// initialises `result` and later blocks add to it.
300 template <
class BR,
class E,
class MR>
301 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
305 xt::noalias(result) = block_result;
309 xt::noalias(result) += block_result;
// Blockwise L2 norm. The L2 norm itself is not additive across blocks, so
// merge() adds block results and finalize() takes the square root; compute()
// therefore presumably returns per-block squared norms (its body is not
// visible in this listing, nor is the template header declaring T_E).
315 struct norm_l2_functor
// Element type of a full (non-blockwise) xt::norm_l2 over an xarray<T_E>.
317 using value_type =
typename std::decay_t<
decltype(
xt::norm_l2(std::declval<xarray<T_E>>()))>::value_type;
319 template <
class E,
class A,
class O>
320 auto compute(
const E& input,
const A& axes,
const O& options)
const
// No auxiliary merge state is needed.
326 auto reduction_variable(
const E&)
const
328 return empty_reduction_variable();
// The if/else on `first` is not visible here; presumably the first block
// initialises `result` and later blocks add to it.
331 template <
class BR,
class E>
332 auto merge(
const BR& block_result,
bool first, E& result, empty_reduction_variable&)
const
336 xt::noalias(result) = block_result;
340 xt::noalias(result) += block_result;
// Converts the accumulated sum into the norm; sqrt is elementwise, so writing
// in place through noalias is safe.
344 template <
class E,
class R>
345 void finalize(
const empty_reduction_variable&, E& results,
const R&)
const
347 xt::noalias(results) =
xt::sqrt(results);
// Blockwise squared L2 norm: a sum of squares is additive over disjoint
// blocks, so merge() accumulates by addition. (The template header declaring
// T_E is not visible in this listing.)
352 struct norm_sq_functor :
public simple_functor_base
// Element type of a full (non-blockwise) xt::norm_sq over an xarray<T_E>.
354 using value_type =
typename std::decay_t<
decltype(
xt::norm_sq(std::declval<xarray<T_E>>()))>::value_type;
// Reduces one block over `axes` with `options` (body not visible in this listing).
356 template <
class E,
class A,
class O>
357 auto compute(
const E& input,
const A& axes,
const O& options)
const
// The if/else on `first` is not visible here; presumably the first block
// initialises `result` and later blocks add to it.
362 template <
class BR,
class E,
class MR>
363 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
367 xt::noalias(result) = block_result;
371 xt::noalias(result) += block_result;
// Blockwise infinity (maximum) norm: per-block maxima are combined with an
// elementwise xt::maximum. (The template header declaring T_E is not visible
// in this listing.)
377 struct norm_linf_functor :
public simple_functor_base
// Element type of a full (non-blockwise) xt::norm_linf over an xarray<T_E>.
379 using value_type =
typename std::decay_t<
decltype(
xt::norm_linf(std::declval<xarray<T_E>>()))>::value_type;
// Reduces one block over `axes` with `options` (body not visible in this listing).
381 template <
class E,
class A,
class O>
382 auto compute(
const E& input,
const A& axes,
const O& options)
const
// The if/else on `first` is not visible here; presumably the first block
// initialises `result`.
387 template <
class BR,
class E,
class MR>
388 auto merge(
const BR& block_result,
bool first, E& result, MR&)
const
392 xt::noalias(result) = block_result;
// Elementwise: each output element reads only the same element of `result`.
396 xt::noalias(result) =
xt::maximum(block_result, result);
// Blockwise p-th power of the Lp norm. This quantity is additive over
// disjoint blocks, so merge() accumulates by addition and finalize() does no
// post-processing. The exponent is supplied to the constructor; its storage
// and the value_type tail are not visible in this listing.
402 class norm_lp_to_p_functor
406 using value_type =
typename std::decay_t<
// p: exponent of the norm (the constructor body/initializer is not visible here).
409 norm_lp_to_p_functor(
double p)
// Reduces one block over `axes` with `options` (body not visible in this listing).
414 template <
class E,
class A,
class O>
415 auto compute(
const E& input,
const A& axes,
const O& options)
const
// No auxiliary merge state is needed.
421 auto reduction_variable(
const E&)
const
423 return empty_reduction_variable();
// The if/else on `first` is not visible here; presumably the first block
// initialises `result` and later blocks add to it.
426 template <
class BR,
class E>
427 auto merge(
const BR& block_result,
bool first, E& result, empty_reduction_variable&)
const
431 xt::noalias(result) = block_result;
435 xt::noalias(result) += block_result;
// No post-processing: the results parameter is unnamed, so the accumulated
// sum is already the answer.
439 template <
class E,
class R>
440 void finalize(
const empty_reduction_variable&, E&,
const R&)
const
// Blockwise Lp norm. Block results are added together in merge() and
// finalize() raises the sum to the power 1/p, so compute() presumably returns
// per-block p-th powers (its body, the m_p member declaration and the
// constructor initializer are not visible in this listing).
450 class norm_lp_functor
// p: exponent of the norm; finalize() reads it back as m_p.
454 norm_lp_functor(
double p)
// Element type of a full xt::norm_lp over an xarray<T_E> (declaration tail
// truncated in this listing).
459 using value_type =
typename std::decay_t<
decltype(
xt::norm_lp(std::declval<xarray<T_E>>(), 1.0)
// Reduces one block over `axes` with `options` (body not visible in this listing).
462 template <
class E,
class A,
class O>
463 auto compute(
const E& input,
const A& axes,
const O& options)
const
// No auxiliary merge state is needed.
469 auto reduction_variable(
const E&)
const
471 return empty_reduction_variable();
// The if/else on `first` is not visible here; presumably the first block
// initialises `result` and later blocks add to it.
474 template <
class BR,
class E>
475 auto merge(
const BR& block_result,
bool first, E& result, empty_reduction_variable&)
const
479 xt::noalias(result) = block_result;
483 xt::noalias(result) += block_result;
// Takes the p-th root of the accumulated sum. pow is elementwise (each output
// element reads only the same element of `results`), so it is written in place
// through noalias. This matches the sqrt finalizers and avoids the temporary
// that plain assignment would materialise for alias safety.
487 template <
class E,
class R>
488 void finalize(
const empty_reduction_variable&, E& results,
const R&)
const
490 xt::noalias(results) =
xt::pow(results, 1.0 / m_p);
auto amax(E &&e, X &&axes, EVS es=EVS())
Maximum element along given axis.
auto abs(E &&e) noexcept -> detail::xfunction_type_t< math::abs_fun, E >
Absolute value function.
auto minimum(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::minimum< void >, E1, E2 >
Elementwise minimum.
auto maximum(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::maximum< void >, E1, E2 >
Elementwise maximum.
auto amin(E &&e, X &&axes, EVS es=EVS())
Minimum element along given axis.
auto not_equal(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::not_equal_to, E1, E2 >
Element-wise inequality.
auto sqrt(E &&e) noexcept -> detail::xfunction_type_t< math::sqrt_fun, E >
Square root function.
auto square(E1 &&e1) noexcept
Square power function, equivalent to e1 * e1.
auto pow(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::pow_fun, E1, E2 >
Power function.
auto norm_sq(E &&e, X &&axes, EVS es) noexcept
Squared L2 norm of an array-like argument over given axes.
auto norm_lp(E &&e, double p, X &&axes, EVS es=EVS())
Lp norm of an array-like argument over given axes.
auto norm_l2(E &&e, EVS es=EVS()) noexcept
L2 norm of a scalar or array-like argument.
auto norm_l1(E &&e, X &&axes, EVS es) noexcept
L1 norm of an array-like argument over given axes.
auto norm_lp_to_p(E &&e, double p, X &&axes, EVS es=EVS()) noexcept
p-th power of the Lp norm of an array-like argument over given axes.
auto norm_l0(E &&e, X &&axes, EVS es) noexcept
L0 (count) pseudo-norm of an array-like argument over given axes.
auto norm_linf(E &&e, X &&axes, EVS es) noexcept
Infinity (maximum) norm of an array-like argument over given axes.
standard mathematical functions for xexpressions