xtensor
Loading...
Searching...
No Matches
xblockwise_reducer_functors.hpp
1#ifndef XTENSOR_XBLOCKWISE_REDUCER_FUNCTORS_HPP
2#define XTENSOR_XBLOCKWISE_REDUCER_FUNCTORS_HPP
3
4
5#include <sstream>
6#include <string>
7#include <tuple>
8#include <typeinfo>
9
10#include "xarray.hpp"
11#include "xbuilder.hpp"
12#include "xchunked_array.hpp"
13#include "xchunked_assign.hpp"
14#include "xchunked_view.hpp"
15#include "xexpression.hpp"
16#include "xmath.hpp"
17#include "xnorm.hpp"
18#include "xreducer.hpp"
19#include "xtl/xclosure.hpp"
20#include "xtl/xsequence.hpp"
21#include "xutils.hpp"
22
23namespace xt
24{
25 namespace detail
26 {
27 namespace blockwise
28 {
29
// Placeholder returned by functors that keep no state between blocks.
struct empty_reduction_variable {};
33
34 struct simple_functor_base
35 {
36 template <class E>
37 auto reduction_variable(const E&) const
38 {
39 return empty_reduction_variable();
40 }
41
42 template <class MR, class E, class R>
43 void finalize(const MR&, E&, const R&) const
44 {
45 }
46 };
47
48 template <class T_E, class T_I = void>
49 struct sum_functor : public simple_functor_base
50 {
51 using value_type = typename std::decay_t<decltype(xt::sum<T_I>(std::declval<xarray<T_E>>()))>::value_type;
52
53 template <class E, class A, class O>
54 auto compute(const E& input, const A& axes, const O& options) const
55 {
56 return xt::sum<value_type>(input, axes, options);
57 }
58
59 template <class BR, class E, class MR>
60 auto merge(const BR& block_result, bool first, E& result, MR&) const
61 {
62 if (first)
63 {
64 xt::noalias(result) = block_result;
65 }
66 else
67 {
68 xt::noalias(result) += block_result;
69 }
70 }
71 };
72
73 template <class T_E, class T_I = void>
74 struct prod_functor : public simple_functor_base
75 {
76 using value_type = typename std::decay_t<decltype(xt::sum<T_I>(std::declval<xarray<T_E>>()))>::value_type;
77
78 template <class E, class A, class O>
79 auto compute(const E& input, const A& axes, const O& options) const
80 {
81 return xt::prod<value_type>(input, axes, options);
82 }
83
84 template <class BR, class E, class MR>
85 auto merge(const BR& block_result, bool first, E& result, MR&) const
86 {
87 if (first)
88 {
89 xt::noalias(result) = block_result;
90 }
91 else
92 {
93 xt::noalias(result) *= block_result;
94 }
95 }
96 };
97
98 template <class T_E, class T_I = void>
99 struct amin_functor : public simple_functor_base
100 {
101 using value_type = typename std::decay_t<decltype(xt::amin<T_I>(std::declval<xarray<T_E>>()))>::value_type;
102
103 template <class E, class A, class O>
104 auto compute(const E& input, const A& axes, const O& options) const
105 {
106 return xt::amin(input, axes, options);
107 }
108
109 template <class BR, class E, class MR>
110 auto merge(const BR& block_result, bool first, E& result, MR&) const
111 {
112 if (first)
113 {
114 xt::noalias(result) = block_result;
115 }
116 else
117 {
118 xt::noalias(result) = xt::minimum(block_result, result);
119 }
120 }
121 };
122
123 template <class T_E, class T_I = void>
124 struct amax_functor : public simple_functor_base
125 {
126 using value_type = typename std::decay_t<decltype(xt::amax<T_I>(std::declval<xarray<T_E>>()))>::value_type;
127
128 template <class E, class A, class O>
129 auto compute(const E& input, const A& axes, const O& options) const
130 {
131 return xt::amax(input, axes, options);
132 }
133
134 template <class BR, class E, class MR>
135 auto merge(const BR& block_result, bool first, E& result, MR&) const
136 {
137 if (first)
138 {
139 xt::noalias(result) = block_result;
140 }
141 else
142 {
143 xt::noalias(result) = xt::maximum(block_result, result);
144 }
145 }
146 };
147
148 template <class T_E, class T_I = void>
149 struct mean_functor
150 {
151 using value_type = typename std::decay_t<decltype(xt::mean<T_I>(std::declval<xarray<T_E>>()))>::value_type;
152
153 template <class E, class A, class O>
154 auto compute(const E& input, const A& axes, const O& options) const
155 {
156 return xt::sum<value_type>(input, axes, options);
157 }
158
159 template <class E>
160 auto reduction_variable(const E&) const
161 {
162 return empty_reduction_variable();
163 }
164
165 template <class BR, class E>
166 auto merge(const BR& block_result, bool first, E& result, empty_reduction_variable&) const
167 {
168 if (first)
169 {
170 xt::noalias(result) = block_result;
171 }
172 else
173 {
174 xt::noalias(result) += block_result;
175 }
176 }
177
178 template <class E, class R>
179 void finalize(const empty_reduction_variable&, E& results, const R& reducer) const
180 {
181 const auto& axes = reducer.axes();
182 std::decay_t<decltype(reducer.input_shape()[0])> factor = 1;
183 for (auto a : axes)
184 {
185 factor *= reducer.input_shape()[a];
186 }
187 xt::noalias(results) /= static_cast<typename E::value_type>(factor);
188 }
189 };
190
191 template <class T_E, class T_I = void>
192 struct variance_functor
193 {
194 using value_type = typename std::decay_t<decltype(xt::variance<T_I>(std::declval<xarray<T_E>>())
195 )>::value_type;
196
197 template <class E, class A, class O>
198 auto compute(const E& input, const A& axes, const O& options) const
199 {
200 double weight = 1.0;
201 for (auto a : axes)
202 {
203 weight *= static_cast<double>(input.shape()[a]);
204 }
205
206
207 return std::make_tuple(
208 xt::variance<value_type>(input, axes, options),
209 xt::mean<value_type>(input, axes, options),
210 weight
211 );
212 }
213
214 template <class E>
215 auto reduction_variable(const E&) const
216 {
217 return std::make_tuple(xarray<value_type>(), 0.0);
218 }
219
220 template <class BR, class E, class MR>
221 auto merge(const BR& block_result, bool first, E& variance_a, MR& mr) const
222 {
223 auto& mean_a = std::get<0>(mr);
224 auto& n_a = std::get<1>(mr);
225
226 const auto& variance_b = std::get<0>(block_result);
227 const auto& mean_b = std::get<1>(block_result);
228 const auto& n_b = std::get<2>(block_result);
229 if (first)
230 {
231 xt::noalias(variance_a) = variance_b;
232 xt::noalias(mean_a) = mean_b;
233 n_a += n_b;
234 }
235 else
236 {
237 auto new_mean = (n_a * mean_a + n_b * mean_b) / (n_a + n_b);
238 auto new_variance = (n_a * variance_a + n_b * variance_b
239 + n_a * xt::pow(mean_a - new_mean, 2)
240 + n_b * xt::pow(mean_b - new_mean, 2))
241 / (n_a + n_b);
242 xt::noalias(variance_a) = new_variance;
243 xt::noalias(mean_a) = new_mean;
244 n_a += n_b;
245 }
246 }
247
248 template <class MR, class E, class R>
249 void finalize(const MR&, E&, const R&) const
250 {
251 }
252 };
253
254 template <class T_E, class T_I = void>
255 struct stddev_functor : public variance_functor<T_E, T_I>
256 {
257 template <class MR, class E, class R>
258 void finalize(const MR&, E& results, const R&) const
259 {
260 xt::noalias(results) = xt::sqrt(results);
261 }
262 };
263
264 template <class T_E>
265 struct norm_l0_functor : public simple_functor_base
266 {
267 using value_type = typename std::decay_t<decltype(xt::norm_l0(std::declval<xarray<T_E>>()))>::value_type;
268
269 template <class E, class A, class O>
270 auto compute(const E& input, const A& axes, const O& options) const
271 {
272 return xt::sum<value_type>(xt::not_equal(input, xt::zeros<T_E>(input.shape())), axes, options);
273 }
274
275 template <class BR, class E, class MR>
276 auto merge(const BR& block_result, bool first, E& result, MR&) const
277 {
278 if (first)
279 {
280 xt::noalias(result) = block_result;
281 }
282 else
283 {
284 xt::noalias(result) += block_result;
285 }
286 }
287 };
288
289 template <class T_E>
290 struct norm_l1_functor : public simple_functor_base
291 {
292 using value_type = typename std::decay_t<decltype(xt::norm_l1(std::declval<xarray<T_E>>()))>::value_type;
293
294 template <class E, class A, class O>
295 auto compute(const E& input, const A& axes, const O& options) const
296 {
297 return xt::sum<value_type>(xt::abs(input), axes, options);
298 }
299
300 template <class BR, class E, class MR>
301 auto merge(const BR& block_result, bool first, E& result, MR&) const
302 {
303 if (first)
304 {
305 xt::noalias(result) = block_result;
306 }
307 else
308 {
309 xt::noalias(result) += block_result;
310 }
311 }
312 };
313
314 template <class T_E>
315 struct norm_l2_functor
316 {
317 using value_type = typename std::decay_t<decltype(xt::norm_l2(std::declval<xarray<T_E>>()))>::value_type;
318
319 template <class E, class A, class O>
320 auto compute(const E& input, const A& axes, const O& options) const
321 {
322 return xt::sum<value_type>(xt::square(input), axes, options);
323 }
324
325 template <class E>
326 auto reduction_variable(const E&) const
327 {
328 return empty_reduction_variable();
329 }
330
331 template <class BR, class E>
332 auto merge(const BR& block_result, bool first, E& result, empty_reduction_variable&) const
333 {
334 if (first)
335 {
336 xt::noalias(result) = block_result;
337 }
338 else
339 {
340 xt::noalias(result) += block_result;
341 }
342 }
343
344 template <class E, class R>
345 void finalize(const empty_reduction_variable&, E& results, const R&) const
346 {
347 xt::noalias(results) = xt::sqrt(results);
348 }
349 };
350
351 template <class T_E>
352 struct norm_sq_functor : public simple_functor_base
353 {
354 using value_type = typename std::decay_t<decltype(xt::norm_sq(std::declval<xarray<T_E>>()))>::value_type;
355
356 template <class E, class A, class O>
357 auto compute(const E& input, const A& axes, const O& options) const
358 {
359 return xt::sum<value_type>(xt::square(input), axes, options);
360 }
361
362 template <class BR, class E, class MR>
363 auto merge(const BR& block_result, bool first, E& result, MR&) const
364 {
365 if (first)
366 {
367 xt::noalias(result) = block_result;
368 }
369 else
370 {
371 xt::noalias(result) += block_result;
372 }
373 }
374 };
375
376 template <class T_E>
377 struct norm_linf_functor : public simple_functor_base
378 {
379 using value_type = typename std::decay_t<decltype(xt::norm_linf(std::declval<xarray<T_E>>()))>::value_type;
380
381 template <class E, class A, class O>
382 auto compute(const E& input, const A& axes, const O& options) const
383 {
384 return xt::amax<value_type>(xt::abs(input), axes, options);
385 }
386
387 template <class BR, class E, class MR>
388 auto merge(const BR& block_result, bool first, E& result, MR&) const
389 {
390 if (first)
391 {
392 xt::noalias(result) = block_result;
393 }
394 else
395 {
396 xt::noalias(result) = xt::maximum(block_result, result);
397 }
398 }
399 };
400
401 template <class T_E>
402 class norm_lp_to_p_functor
403 {
404 public:
405
406 using value_type = typename std::decay_t<
407 decltype(xt::norm_lp_to_p(std::declval<xarray<T_E>>(), 1.0))>::value_type;
408
409 norm_lp_to_p_functor(double p)
410 : m_p(p)
411 {
412 }
413
414 template <class E, class A, class O>
415 auto compute(const E& input, const A& axes, const O& options) const
416 {
417 return xt::sum<value_type>(xt::pow(input, m_p), axes, options);
418 }
419
420 template <class E>
421 auto reduction_variable(const E&) const
422 {
423 return empty_reduction_variable();
424 }
425
426 template <class BR, class E>
427 auto merge(const BR& block_result, bool first, E& result, empty_reduction_variable&) const
428 {
429 if (first)
430 {
431 xt::noalias(result) = block_result;
432 }
433 else
434 {
435 xt::noalias(result) += block_result;
436 }
437 }
438
439 template <class E, class R>
440 void finalize(const empty_reduction_variable&, E&, const R&) const
441 {
442 }
443
444 private:
445
446 double m_p;
447 };
448
449 template <class T_E>
450 class norm_lp_functor
451 {
452 public:
453
454 norm_lp_functor(double p)
455 : m_p(p)
456 {
457 }
458
459 using value_type = typename std::decay_t<decltype(xt::norm_lp(std::declval<xarray<T_E>>(), 1.0)
460 )>::value_type;
461
462 template <class E, class A, class O>
463 auto compute(const E& input, const A& axes, const O& options) const
464 {
465 return xt::sum<value_type>(xt::pow(input, m_p), axes, options);
466 }
467
468 template <class E>
469 auto reduction_variable(const E&) const
470 {
471 return empty_reduction_variable();
472 }
473
474 template <class BR, class E>
475 auto merge(const BR& block_result, bool first, E& result, empty_reduction_variable&) const
476 {
477 if (first)
478 {
479 xt::noalias(result) = block_result;
480 }
481 else
482 {
483 xt::noalias(result) += block_result;
484 }
485 }
486
487 template <class E, class R>
488 void finalize(const empty_reduction_variable&, E& results, const R&) const
489 {
490 results = xt::pow(results, 1.0 / m_p);
491 }
492
493 private:
494
495 double m_p;
496 };
497
498
499 }
500 }
501}
502
503#endif
auto amax(E &&e, X &&axes, EVS es=EVS())
Maximum element along given axis.
Definition xmath.hpp:782
auto abs(E &&e) noexcept -> detail::xfunction_type_t< math::abs_fun, E >
Absolute value function.
Definition xmath.hpp:443
auto minimum(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::minimum< void >, E1, E2 >
Elementwise minimum.
Definition xmath.hpp:761
auto maximum(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::maximum< void >, E1, E2 >
Elementwise maximum.
Definition xmath.hpp:745
auto amin(E &&e, X &&axes, EVS es=EVS())
Minimum element along given axis.
Definition xmath.hpp:800
auto not_equal(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::not_equal_to, E1, E2 >
Element-wise inequality.
auto sqrt(E &&e) noexcept -> detail::xfunction_type_t< math::sqrt_fun, E >
Square root function.
Definition xmath.hpp:1239
auto square(E1 &&e1) noexcept
Square power function, equivalent to e1 * e1.
Definition xmath.hpp:1128
auto pow(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::pow_fun, E1, E2 >
Power function.
Definition xmath.hpp:1015
auto norm_sq(E &&e, X &&axes, EVS es) noexcept
Squared L2 norm of an array-like argument over given axes.
auto norm_lp(E &&e, double p, X &&axes, EVS es=EVS())
Lp norm of an array-like argument over given axes.
Definition xnorm.hpp:601
auto norm_l2(E &&e, EVS es=EVS()) noexcept
L2 norm of a scalar or array-like argument.
Definition xnorm.hpp:494
auto norm_l1(E &&e, X &&axes, EVS es) noexcept
L1 norm of an array-like argument over given axes.
auto norm_lp_to_p(E &&e, double p, X &&axes, EVS es=EVS()) noexcept
p-th power of the Lp norm of an array-like argument over given axes.
Definition xnorm.hpp:557
auto norm_l0(E &&e, X &&axes, EVS es) noexcept
L0 (count) pseudo-norm of an array-like argument over given axes.
auto norm_linf(E &&e, X &&axes, EVS es) noexcept
Infinity (maximum) norm of an array-like argument over given axes.
standard mathematical functions for xexpressions