xtensor
 
Loading...
Searching...
No Matches
xblockwise_reducer_functors.hpp
1#ifndef XTENSOR_XBLOCKWISE_REDUCER_FUNCTORS_HPP
2#define XTENSOR_XBLOCKWISE_REDUCER_FUNCTORS_HPP
3
4
5#include <tuple>
6
7#include "../chunk/xchunked_array.hpp"
8#include "../chunk/xchunked_assign.hpp"
9#include "../chunk/xchunked_view.hpp"
10#include "../containers/xarray.hpp"
11#include "../core/xexpression.hpp"
12#include "../core/xmath.hpp"
13#include "../generators/xbuilder.hpp"
14#include "../reducers/xnorm.hpp"
15#include "../reducers/xreducer.hpp"
16#include "../utils/xutils.hpp"
17#include "xtl/xclosure.hpp"
18#include "xtl/xsequence.hpp"
19
20namespace xt
21{
22 namespace detail
23 {
24 namespace blockwise
25 {
26
/**
 * Stateless tag returned by reduction_variable() for functors whose
 * merge step needs no auxiliary state between blocks.
 */
struct empty_reduction_variable {};
30
31 struct simple_functor_base
32 {
33 template <class E>
34 auto reduction_variable(const E&) const
35 {
36 return empty_reduction_variable();
37 }
38
39 template <class MR, class E, class R>
40 void finalize(const MR&, E&, const R&) const
41 {
42 }
43 };
44
45 template <class T_E, class T_I = void>
46 struct sum_functor : public simple_functor_base
47 {
48 using value_type = typename std::decay_t<decltype(xt::sum<T_I>(std::declval<xarray<T_E>>()))>::value_type;
49
50 template <class E, class A, class O>
51 auto compute(const E& input, const A& axes, const O& options) const
52 {
53 return xt::sum<value_type>(input, axes, options);
54 }
55
56 template <class BR, class E, class MR>
57 auto merge(const BR& block_result, bool first, E& result, MR&) const
58 {
59 if (first)
60 {
61 xt::noalias(result) = block_result;
62 }
63 else
64 {
65 xt::noalias(result) += block_result;
66 }
67 }
68 };
69
70 template <class T_E, class T_I = void>
71 struct prod_functor : public simple_functor_base
72 {
73 using value_type = typename std::decay_t<decltype(xt::sum<T_I>(std::declval<xarray<T_E>>()))>::value_type;
74
75 template <class E, class A, class O>
76 auto compute(const E& input, const A& axes, const O& options) const
77 {
78 return xt::prod<value_type>(input, axes, options);
79 }
80
81 template <class BR, class E, class MR>
82 auto merge(const BR& block_result, bool first, E& result, MR&) const
83 {
84 if (first)
85 {
86 xt::noalias(result) = block_result;
87 }
88 else
89 {
90 xt::noalias(result) *= block_result;
91 }
92 }
93 };
94
95 template <class T_E, class T_I = void>
96 struct amin_functor : public simple_functor_base
97 {
98 using value_type = typename std::decay_t<decltype(xt::amin<T_I>(std::declval<xarray<T_E>>()))>::value_type;
99
100 template <class E, class A, class O>
101 auto compute(const E& input, const A& axes, const O& options) const
102 {
103 return xt::amin(input, axes, options);
104 }
105
106 template <class BR, class E, class MR>
107 auto merge(const BR& block_result, bool first, E& result, MR&) const
108 {
109 if (first)
110 {
111 xt::noalias(result) = block_result;
112 }
113 else
114 {
115 xt::noalias(result) = xt::minimum(block_result, result);
116 }
117 }
118 };
119
120 template <class T_E, class T_I = void>
121 struct amax_functor : public simple_functor_base
122 {
123 using value_type = typename std::decay_t<decltype(xt::amax<T_I>(std::declval<xarray<T_E>>()))>::value_type;
124
125 template <class E, class A, class O>
126 auto compute(const E& input, const A& axes, const O& options) const
127 {
128 return xt::amax(input, axes, options);
129 }
130
131 template <class BR, class E, class MR>
132 auto merge(const BR& block_result, bool first, E& result, MR&) const
133 {
134 if (first)
135 {
136 xt::noalias(result) = block_result;
137 }
138 else
139 {
140 xt::noalias(result) = xt::maximum(block_result, result);
141 }
142 }
143 };
144
145 template <class T_E, class T_I = void>
146 struct mean_functor
147 {
148 using value_type = typename std::decay_t<decltype(xt::mean<T_I>(std::declval<xarray<T_E>>()))>::value_type;
149
150 template <class E, class A, class O>
151 auto compute(const E& input, const A& axes, const O& options) const
152 {
153 return xt::sum<value_type>(input, axes, options);
154 }
155
156 template <class E>
157 auto reduction_variable(const E&) const
158 {
159 return empty_reduction_variable();
160 }
161
162 template <class BR, class E>
163 auto merge(const BR& block_result, bool first, E& result, empty_reduction_variable&) const
164 {
165 if (first)
166 {
167 xt::noalias(result) = block_result;
168 }
169 else
170 {
171 xt::noalias(result) += block_result;
172 }
173 }
174
175 template <class E, class R>
176 void finalize(const empty_reduction_variable&, E& results, const R& reducer) const
177 {
178 const auto& axes = reducer.axes();
179 std::decay_t<decltype(reducer.input_shape()[0])> factor = 1;
180 for (auto a : axes)
181 {
182 factor *= reducer.input_shape()[a];
183 }
184 xt::noalias(results) /= static_cast<typename E::value_type>(factor);
185 }
186 };
187
188 template <class T_E, class T_I = void>
189 struct variance_functor
190 {
191 using value_type = typename std::decay_t<decltype(xt::variance<T_I>(std::declval<xarray<T_E>>())
192 )>::value_type;
193
194 template <class E, class A, class O>
195 auto compute(const E& input, const A& axes, const O& options) const
196 {
197 double weight = 1.0;
198 for (auto a : axes)
199 {
200 weight *= static_cast<double>(input.shape()[a]);
201 }
202
203
204 return std::make_tuple(
205 xt::variance<value_type>(input, axes, options),
206 xt::mean<value_type>(input, axes, options),
207 weight
208 );
209 }
210
211 template <class E>
212 auto reduction_variable(const E&) const
213 {
214 return std::make_tuple(xarray<value_type>(), 0.0);
215 }
216
217 template <class BR, class E, class MR>
218 auto merge(const BR& block_result, bool first, E& variance_a, MR& mr) const
219 {
220 auto& mean_a = std::get<0>(mr);
221 auto& n_a = std::get<1>(mr);
222
223 const auto& variance_b = std::get<0>(block_result);
224 const auto& mean_b = std::get<1>(block_result);
225 const auto& n_b = std::get<2>(block_result);
226 if (first)
227 {
228 xt::noalias(variance_a) = variance_b;
229 xt::noalias(mean_a) = mean_b;
230 n_a += n_b;
231 }
232 else
233 {
234 auto new_mean = (n_a * mean_a + n_b * mean_b) / (n_a + n_b);
235 auto new_variance = (n_a * variance_a + n_b * variance_b
236 + n_a * xt::pow(mean_a - new_mean, 2)
237 + n_b * xt::pow(mean_b - new_mean, 2))
238 / (n_a + n_b);
239 xt::noalias(variance_a) = new_variance;
240 xt::noalias(mean_a) = new_mean;
241 n_a += n_b;
242 }
243 }
244
245 template <class MR, class E, class R>
246 void finalize(const MR&, E&, const R&) const
247 {
248 }
249 };
250
251 template <class T_E, class T_I = void>
252 struct stddev_functor : public variance_functor<T_E, T_I>
253 {
254 template <class MR, class E, class R>
255 void finalize(const MR&, E& results, const R&) const
256 {
257 xt::noalias(results) = xt::sqrt(results);
258 }
259 };
260
261 template <class T_E>
262 struct norm_l0_functor : public simple_functor_base
263 {
264 using value_type = typename std::decay_t<decltype(xt::norm_l0(std::declval<xarray<T_E>>()))>::value_type;
265
266 template <class E, class A, class O>
267 auto compute(const E& input, const A& axes, const O& options) const
268 {
269 return xt::sum<value_type>(xt::not_equal(input, xt::zeros<T_E>(input.shape())), axes, options);
270 }
271
272 template <class BR, class E, class MR>
273 auto merge(const BR& block_result, bool first, E& result, MR&) const
274 {
275 if (first)
276 {
277 xt::noalias(result) = block_result;
278 }
279 else
280 {
281 xt::noalias(result) += block_result;
282 }
283 }
284 };
285
286 template <class T_E>
287 struct norm_l1_functor : public simple_functor_base
288 {
289 using value_type = typename std::decay_t<decltype(xt::norm_l1(std::declval<xarray<T_E>>()))>::value_type;
290
291 template <class E, class A, class O>
292 auto compute(const E& input, const A& axes, const O& options) const
293 {
294 return xt::sum<value_type>(xt::abs(input), axes, options);
295 }
296
297 template <class BR, class E, class MR>
298 auto merge(const BR& block_result, bool first, E& result, MR&) const
299 {
300 if (first)
301 {
302 xt::noalias(result) = block_result;
303 }
304 else
305 {
306 xt::noalias(result) += block_result;
307 }
308 }
309 };
310
311 template <class T_E>
312 struct norm_l2_functor
313 {
314 using value_type = typename std::decay_t<decltype(xt::norm_l2(std::declval<xarray<T_E>>()))>::value_type;
315
316 template <class E, class A, class O>
317 auto compute(const E& input, const A& axes, const O& options) const
318 {
319 return xt::sum<value_type>(xt::square(input), axes, options);
320 }
321
322 template <class E>
323 auto reduction_variable(const E&) const
324 {
325 return empty_reduction_variable();
326 }
327
328 template <class BR, class E>
329 auto merge(const BR& block_result, bool first, E& result, empty_reduction_variable&) const
330 {
331 if (first)
332 {
333 xt::noalias(result) = block_result;
334 }
335 else
336 {
337 xt::noalias(result) += block_result;
338 }
339 }
340
341 template <class E, class R>
342 void finalize(const empty_reduction_variable&, E& results, const R&) const
343 {
344 xt::noalias(results) = xt::sqrt(results);
345 }
346 };
347
348 template <class T_E>
349 struct norm_sq_functor : public simple_functor_base
350 {
351 using value_type = typename std::decay_t<decltype(xt::norm_sq(std::declval<xarray<T_E>>()))>::value_type;
352
353 template <class E, class A, class O>
354 auto compute(const E& input, const A& axes, const O& options) const
355 {
356 return xt::sum<value_type>(xt::square(input), axes, options);
357 }
358
359 template <class BR, class E, class MR>
360 auto merge(const BR& block_result, bool first, E& result, MR&) const
361 {
362 if (first)
363 {
364 xt::noalias(result) = block_result;
365 }
366 else
367 {
368 xt::noalias(result) += block_result;
369 }
370 }
371 };
372
373 template <class T_E>
374 struct norm_linf_functor : public simple_functor_base
375 {
376 using value_type = typename std::decay_t<decltype(xt::norm_linf(std::declval<xarray<T_E>>()))>::value_type;
377
378 template <class E, class A, class O>
379 auto compute(const E& input, const A& axes, const O& options) const
380 {
381 return xt::amax<value_type>(xt::abs(input), axes, options);
382 }
383
384 template <class BR, class E, class MR>
385 auto merge(const BR& block_result, bool first, E& result, MR&) const
386 {
387 if (first)
388 {
389 xt::noalias(result) = block_result;
390 }
391 else
392 {
393 xt::noalias(result) = xt::maximum(block_result, result);
394 }
395 }
396 };
397
398 template <class T_E>
399 class norm_lp_to_p_functor
400 {
401 public:
402
403 using value_type = typename std::decay_t<
404 decltype(xt::norm_lp_to_p(std::declval<xarray<T_E>>(), 1.0))>::value_type;
405
406 norm_lp_to_p_functor(double p)
407 : m_p(p)
408 {
409 }
410
411 template <class E, class A, class O>
412 auto compute(const E& input, const A& axes, const O& options) const
413 {
414 return xt::sum<value_type>(xt::pow(input, m_p), axes, options);
415 }
416
417 template <class E>
418 auto reduction_variable(const E&) const
419 {
420 return empty_reduction_variable();
421 }
422
423 template <class BR, class E>
424 auto merge(const BR& block_result, bool first, E& result, empty_reduction_variable&) const
425 {
426 if (first)
427 {
428 xt::noalias(result) = block_result;
429 }
430 else
431 {
432 xt::noalias(result) += block_result;
433 }
434 }
435
436 template <class E, class R>
437 void finalize(const empty_reduction_variable&, E&, const R&) const
438 {
439 }
440
441 private:
442
443 double m_p;
444 };
445
446 template <class T_E>
447 class norm_lp_functor
448 {
449 public:
450
451 norm_lp_functor(double p)
452 : m_p(p)
453 {
454 }
455
456 using value_type = typename std::decay_t<decltype(xt::norm_lp(std::declval<xarray<T_E>>(), 1.0)
457 )>::value_type;
458
459 template <class E, class A, class O>
460 auto compute(const E& input, const A& axes, const O& options) const
461 {
462 return xt::sum<value_type>(xt::pow(input, m_p), axes, options);
463 }
464
465 template <class E>
466 auto reduction_variable(const E&) const
467 {
468 return empty_reduction_variable();
469 }
470
471 template <class BR, class E>
472 auto merge(const BR& block_result, bool first, E& result, empty_reduction_variable&) const
473 {
474 if (first)
475 {
476 xt::noalias(result) = block_result;
477 }
478 else
479 {
480 xt::noalias(result) += block_result;
481 }
482 }
483
484 template <class E, class R>
485 void finalize(const empty_reduction_variable&, E& results, const R&) const
486 {
487 results = xt::pow(results, 1.0 / m_p);
488 }
489
490 private:
491
492 double m_p;
493 };
494
495
496 }
497 }
498}
499
500#endif
auto amax(E &&e, X &&axes, EVS es=EVS())
Maximum element along given axis.
Definition xmath.hpp:782
auto abs(E &&e) noexcept -> detail::xfunction_type_t< math::abs_fun, E >
Absolute value function.
Definition xmath.hpp:443
auto minimum(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::minimum< void >, E1, E2 >
Elementwise minimum.
Definition xmath.hpp:761
auto maximum(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::maximum< void >, E1, E2 >
Elementwise maximum.
Definition xmath.hpp:745
auto amin(E &&e, X &&axes, EVS es=EVS())
Minimum element along given axis.
Definition xmath.hpp:800
auto not_equal(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::not_equal_to, E1, E2 >
Element-wise inequality.
auto sqrt(E &&e) noexcept -> detail::xfunction_type_t< math::sqrt_fun, E >
Square root function.
Definition xmath.hpp:1201
auto square(E1 &&e1) noexcept
Square power function, equivalent to e1 * e1.
Definition xmath.hpp:1101
auto pow(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::pow_fun, E1, E2 >
Power function.
Definition xmath.hpp:1015
auto sum(E &&e, X &&axes, EVS es=EVS())
Sum of elements over given axes.
Definition xmath.hpp:1803
auto norm_sq(E &&e, X &&axes, EVS es) noexcept
Squared L2 norm of an array-like argument over given axes.
auto norm_lp(E &&e, double p, X &&axes, EVS es=EVS())
Lp norm of an array-like argument over given axes.
Definition xnorm.hpp:601
auto norm_l2(E &&e, EVS es=EVS()) noexcept
L2 norm of a scalar or array-like argument.
Definition xnorm.hpp:494
auto norm_l1(E &&e, X &&axes, EVS es) noexcept
L1 norm of an array-like argument over given axes.
auto prod(E &&e, X &&axes, EVS es=EVS())
Product of elements over given axes.
Definition xmath.hpp:1823
auto norm_lp_to_p(E &&e, double p, X &&axes, EVS es=EVS()) noexcept
p-th power of the Lp norm of an array-like argument over given axes.
Definition xnorm.hpp:557
auto mean(E &&e, X &&axes, EVS es=EVS())
Mean of elements over given axes.
Definition xmath.hpp:1897
auto norm_l0(E &&e, X &&axes, EVS es) noexcept
L0 (count) pseudo-norm of an array-like argument over given axes.
auto norm_linf(E &&e, X &&axes, EVS es) noexcept
Infinity (maximum) norm of an array-like argument over given axes.
Standard mathematical functions for xexpressions.
xarray_container< uvector< T, A >, L, xt::svector< typename uvector< T, A >::size_type, 4, SA, true > > xarray
Alias template on xarray_container with default parameters for data container type and shape / stride...
auto zeros(S shape) noexcept
Returns an xexpression containing zeros of the specified shape.
Definition xbuilder.hpp:66