xtensor
 
Loading...
Searching...
No Matches
xreducer.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_REDUCER_HPP
11#define XTENSOR_REDUCER_HPP
12
13#include <algorithm>
14#include <cstddef>
15#include <iterator>
16#include <stdexcept>
17#include <tuple>
18#include <type_traits>
19#include <utility>
20
21#include <xtl/xfunctional.hpp>
22#include <xtl/xsequence.hpp>
23
24#include "../core/xaccessible.hpp"
25#include "../core/xeval.hpp"
26#include "../core/xexpression.hpp"
27#include "../core/xiterable.hpp"
28#include "../core/xtensor_config.hpp"
29#include "../generators/xbuilder.hpp"
30#include "../generators/xgenerator.hpp"
31#include "../utils/xutils.hpp"
32
33namespace xt
34{
35 template <template <class...> class A, class... AX, class X, XTL_REQUIRES(is_evaluation_strategy<AX>..., is_evaluation_strategy<X>)>
36 auto operator|(const A<AX...>& args, const A<X>& rhs)
37 {
38 return std::tuple_cat(args, rhs);
39 }
40
41 struct keep_dims_type : xt::detail::option_base
42 {
43 };
44
45 constexpr auto keep_dims = std::tuple<keep_dims_type>{};
46
47 template <class T = double>
48 struct xinitial : xt::detail::option_base
49 {
50 constexpr xinitial(T val)
51 : m_val(val)
52 {
53 }
54
55 constexpr T value() const
56 {
57 return m_val;
58 }
59
60 T m_val;
61 };
62
63 template <class T>
64 constexpr auto initial(T val)
65 {
66 return std::make_tuple(xinitial<T>(val));
67 }
68
// Primary template of tuple_idx_of_impl<I, T, Tuple>; the partial
// specializations that follow search Tuple for T, starting at index I.
// NOTE(review): the declarator line of this primary template (source line 70)
// is missing from this listing; restore it from the original file, otherwise
// the template header below is dangling.
69 template <std::ptrdiff_t I, class T, class Tuple>
71
72 template <std::ptrdiff_t I, class T>
73 struct tuple_idx_of_impl<I, T, std::tuple<>>
74 {
75 static constexpr std::ptrdiff_t value = -1;
76 };
77
78 template <std::ptrdiff_t I, class T, class... Types>
79 struct tuple_idx_of_impl<I, T, std::tuple<T, Types...>>
80 {
81 static constexpr std::ptrdiff_t value = I;
82 };
83
84 template <std::ptrdiff_t I, class T, class U, class... Types>
85 struct tuple_idx_of_impl<I, T, std::tuple<U, Types...>>
86 {
87 static constexpr std::ptrdiff_t value = tuple_idx_of_impl<I + 1, T, std::tuple<Types...>>::value;
88 };
89
/**
 * decay_all<S<X...>>::type is S<std::decay_t<X>...>: every template argument
 * of S has references and cv-qualifiers stripped (plus array/function decay).
 */
template <class S, class... X>
struct decay_all;

template <template <class...> class Outer, class... Args>
struct decay_all<Outer<Args...>>
{
    using type = Outer<typename std::decay<Args>::type...>;
};
98
// tuple_idx_of<T, Tuple>::value: index of type T inside the std::tuple type
// Tuple, or -1 when T does not occur (callers compare against -1).
// Presumably forwards to tuple_idx_of_impl starting at index 0.
// NOTE(review): source lines 100 (struct declarator) and 103 (initializer of
// `value`) are missing from this listing; restore them from the original file.
99 template <class T, class Tuple>
101 {
102 static constexpr std::ptrdiff_t
104 };
105
// reducer_options<R, T>: decodes a reducer option tuple type T (e.g. keep_dims
// or initial(v) packs) into compile-time flags plus an initial value of type R.
//   initial_val_idx     : index of the xinitial<...> option in T, or the tuple
//                         size of T when there is none.
//   has_initial_value   : whether T contains an xinitial<...> option.
//   evaluation_strategy : chosen from whether T contains
//                         evaluation_strategy::immediate_type.
//   keep_dims           : std::true_type iff T contains keep_dims_type.
// NOTE(review): source lines 141-142 (the remaining std::conditional_t
// arguments of `evaluation_strategy`) are missing from this listing;
// presumably immediate_type when found and lazy_type otherwise. Restore them
// from the original file.
106 template <class R, class T>
107 struct reducer_options
108 {
// initial_tester<X>: true for xinitial<...> (also when const-qualified); used
// with xtl::mpl::find_if to locate the initial-value option in T.
109 template <class X>
110 struct initial_tester : std::false_type
111 {
112 };
113
114 template <class X>
115 struct initial_tester<xinitial<X>> : std::true_type
116 {
117 };
118
119 // Workaround for Apple because tuple_cat is buggy!
120 template <class X>
121 struct initial_tester<const xinitial<X>> : std::true_type
122 {
123 };
124
125 using d_t = std::decay_t<T>;
126
127 static constexpr std::size_t initial_val_idx = xtl::mpl::find_if<initial_tester, d_t>::value;
128 reducer_options() = default;
129
// Copies the xinitial option's value out of `tpl` when T contains one;
// otherwise initial_value is left default-initialized (indeterminate for a
// scalar R).
130 reducer_options(const T& tpl)
131 {
132 if constexpr (initial_val_idx != std::tuple_size<T>::value)
133 {
// The conditional keeps the std::get index inside the tuple bounds even
// for option packs without an xinitial.
134 initial_value = std::get < initial_val_idx != std::tuple_size<T>::value ? initial_val_idx
135 : 0 > (tpl).value();
136 }
137 }
138
139 using evaluation_strategy = std::conditional_t<
140 tuple_idx_of<xt::evaluation_strategy::immediate_type, d_t>::value != -1,
143
144 using keep_dims = std::
145 conditional_t<tuple_idx_of<xt::keep_dims_type, d_t>::value != -1, std::true_type, std::false_type>;
146
147 static constexpr bool has_initial_value = initial_val_idx != std::tuple_size<d_t>::value;
148
// Initial value taken from the xinitial option; only meaningful when
// has_initial_value is true.
149 R initial_value;
150
151 template <class NR>
152 using rebind_t = reducer_options<NR, T>;
153
// Returns a reducer_options<NR, T> whose initial_value is `initial`; the
// second argument is not used.
154 template <class NR>
155 auto rebind(NR initial, const reducer_options<R, T>&) const
156 {
157 reducer_options<NR, T> res;
158 res.initial_value = initial;
159 return res;
160 }
161 };
162
/**
 * Trait: true exactly for std::tuple<...> types (the representation used for
 * reducer option packs), false for everything else.
 */
template <class T>
struct is_reducer_options_impl : std::integral_constant<bool, false>
{
};

template <class... Options>
struct is_reducer_options_impl<std::tuple<Options...>> : std::integral_constant<bool, true>
{
};
172
// is_reducer_options<T>: used by the reduce() overloads to tell an option
// pack apart from other arguments; presumably is_reducer_options_impl applied
// to the decayed T.
// NOTE(review): the declarator line (source line 174) is missing from this
// listing; restore it from the original file.
173 template <class T>
175 {
176 };
177
178 /**********
179 * reduce *
180 **********/
181
// Default option pack of the reduce() overloads: lazy evaluation and no other
// options.
// NOTE: this header may define DEFAULT_STRATEGY_REDUCERS again later; any
// unguarded redefinition must keep an identical replacement list.
182#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
183
// Forward declaration. xreducer_shape_type<Shape, Axes, KeepDims>::type is the
// shape type of a reduction result, given the input shape type, the axes type
// and a keep-dims flag (std::false_type unless specified).
template <class ShapeT, class AxesT, class KeepDims = std::false_type>
struct xreducer_shape_type;
186
// NOTE(review): the declaration governed by this template header (source line
// 188) is missing from this listing; restore it from the original file,
// otherwise the header is dangling.
187 template <class S1, class S2>
189
190 namespace detail
191 {
192 template <class O, class RS, class R, class E, class AX>
193 inline void shape_computation(
194 RS& result_shape,
195 R& result,
196 E& expr,
197 const AX& axes,
198 std::enable_if_t<!detail::is_fixed<RS>::value, int> = 0
199 )
200 {
201 if (typename O::keep_dims())
202 {
203 resize_container(result_shape, expr.dimension());
204 for (std::size_t i = 0; i < expr.dimension(); ++i)
205 {
206 if (std::find(axes.begin(), axes.end(), i) == axes.end())
207 {
208 // i not in axes!
209 result_shape[i] = expr.shape()[i];
210 }
211 else
212 {
213 result_shape[i] = 1;
214 }
215 }
216 }
217 else
218 {
219 resize_container(result_shape, expr.dimension() - axes.size());
220 for (std::size_t i = 0, idx = 0; i < expr.dimension(); ++i)
221 {
222 if (std::find(axes.begin(), axes.end(), i) == axes.end())
223 {
224 // i not in axes!
225 result_shape[idx] = expr.shape()[i];
226 ++idx;
227 }
228 }
229 }
230 result.resize(result_shape, expr.layout());
231 }
232
233 // skip shape computation if already done at compile time
234 template <class O, class RS, class R, class S, class AX>
235 inline void
236 shape_computation(RS&, R&, const S&, const AX&, std::enable_if_t<detail::is_fixed<RS>::value, int> = 0)
237 {
238 }
239 }
240
241 template <class F, class E, class R, XTL_REQUIRES(std::is_convertible<typename E::value_type, typename R::value_type>)>
242 inline void copy_to_reduced(F&, const E& e, R& result)
243 {
244 if (e.layout() == layout_type::row_major)
245 {
246 std::copy(
247 e.template cbegin<layout_type::row_major>(),
248 e.template cend<layout_type::row_major>(),
249 result.data()
250 );
251 }
252 else
253 {
254 std::copy(
255 e.template cbegin<layout_type::column_major>(),
256 e.template cend<layout_type::column_major>(),
257 result.data()
258 );
259 }
260 }
261
262 template <
263 class F,
264 class E,
265 class R,
266 XTL_REQUIRES(std::negation<std::is_convertible<typename E::value_type, typename R::value_type>>)>
267 inline void copy_to_reduced(F& f, const E& e, R& result)
268 {
269 if (e.layout() == layout_type::row_major)
270 {
271 std::transform(
272 e.template cbegin<layout_type::row_major>(),
273 e.template cend<layout_type::row_major>(),
274 result.data(),
275 f
276 );
277 }
278 else
279 {
280 std::transform(
281 e.template cbegin<layout_type::column_major>(),
282 e.template cend<layout_type::column_major>(),
283 result.data(),
284 f
285 );
286 }
287 }
288
// reduce_immediate: eagerly reduces `e` along `axes` into a newly created
// result container and returns it.
//   f            functor bundle; xt::get<0>, <1>, <2> give the reduce, init and
//                merge functors.
//   e            expression to reduce (reduce_impl passes an evaluated one).
//   axes         reduction axes; checked to be sorted, duplicate-free and in
//                range.
//   raw_options  reducer option tuple, wrapped into `options_t`.
// Errors are reported through XTENSOR_THROW(std::runtime_error, ...) for
// unsorted, duplicate or out-of-range axes, and for layouts other than
// row_major / column_major.
// NOTE(review): source line 300, which declares the `options_t` alias used
// below, is missing from this listing; presumably
// `using options_t = reducer_options<result_type, std::decay_t<O>>;`. Restore
// it from the original file.
289 template <class F, class E, class X, class O>
290 inline auto reduce_immediate(F&& f, E&& e, X&& axes, O&& raw_options)
291 {
292 using reduce_functor_type = typename std::decay_t<F>::reduce_functor_type;
293 using init_functor_type = typename std::decay_t<F>::init_functor_type;
294 using expr_value_type = typename std::decay_t<E>::value_type;
// Element type of the result: what reduce(init(), element) yields.
295 using result_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
296 std::declval<init_functor_type>()(),
297 std::declval<expr_value_type>()
298 ))>;
299
301 options_t options(raw_options);
302
303 using shape_type = typename xreducer_shape_type<
304 typename std::decay_t<E>::shape_type,
305 std::decay_t<X>,
306 typename options_t::keep_dims>::type;
307 using result_container_type = typename detail::xtype_for_shape<
308 shape_type>::template type<result_type, std::decay_t<E>::static_layout>;
309 result_container_type result;
310
311 // retrieve functors from triple struct
312 auto reduce_fct = xt::get<0>(f);
313 auto init_fct = xt::get<1>(f);
314 auto merge_fct = xt::get<2>(f);
315
// No axes: element-wise result with the shape of `e`.
// NOTE(review): when E's value_type converts to the result value_type, the
// convertible copy_to_reduced overload ignores `cpf`, so reduce_fct is not
// applied and initial_value is not merged on this path; confirm this is
// intended.
316 if (axes.size() == 0)
317 {
318 result.resize(e.shape(), e.layout());
319 auto cpf = [&reduce_fct, &init_fct](const auto& v)
320 {
321 return reduce_fct(static_cast<result_type>(init_fct()), v);
322 };
323 copy_to_reduced(cpf, e, result);
324 return result;
325 }
326
327 shape_type result_shape{};
328 dynamic_shape<std::size_t>
329 iter_shape = xtl::forward_sequence<dynamic_shape<std::size_t>, decltype(e.shape())>(e.shape());
330 dynamic_shape<std::size_t> iter_strides(e.dimension());
331
332 // std::less is used, because as the standard says (24.4.5):
333 // A sequence is sorted with respect to a comparator comp if for any iterator i pointing to the
334 // sequence and any non-negative integer n such that i + n is a valid iterator pointing to an element
335 // of the sequence, comp(*(i + n), *i) == false. Therefore less is required to detect duplicates.
// NOTE(review): with std::less, is_sorted accepts equal neighbours; duplicates
// are caught by the adjacent_find check that follows, not by this one.
336 if (!std::is_sorted(axes.cbegin(), axes.cend(), std::less<>()))
337 {
338 XTENSOR_THROW(std::runtime_error, "Reducing axes should be sorted.");
339 }
340 if (std::adjacent_find(axes.cbegin(), axes.cend()) != axes.cend())
341 {
342 XTENSOR_THROW(std::runtime_error, "Reducing axes should not contain duplicates.");
343 }
// NOTE(review): for a 0-d expression `e.dimension() - 1` wraps around
// (size_t arithmetic), so this check cannot fire; comparing
// `axes[axes.size() - 1] >= e.dimension()` would avoid that.
344 if (axes.size() != 0 && axes[axes.size() - 1] > e.dimension() - 1)
345 {
346 XTENSOR_THROW(
347 std::runtime_error,
348 "Axis " + std::to_string(axes[axes.size() - 1]) + " out of bounds for reduction."
349 );
350 }
351
352 detail::shape_computation<options_t>(result_shape, result, e, axes);
353
// Complete reduction: the user initial value (if any) seeds the accumulation
// directly, and the function returns before the final initial-value merge at
// the end.
354 // Fast track for complete reduction
355 if (e.dimension() == axes.size())
356 {
357 result_type tmp = options_t::has_initial_value ? options.initial_value : init_fct();
358 result.data()[0] = std::accumulate(e.storage().begin(), e.storage().end(), tmp, reduce_fct);
359 return result;
360 }
361
// Reduced axis that is innermost in memory: the last one for row_major, the
// first one for any other layout.
362 std::size_t leading_ax = axes[(e.layout() == layout_type::row_major) ? axes.size() - 1 : 0];
363 auto strides_finder = e.strides().begin() + static_cast<std::ptrdiff_t>(leading_ax);
364 // The computed strides contain "0" where the shape is 1 -- therefore find the next none-zero number
365 std::size_t inner_stride = static_cast<std::size_t>(*strides_finder);
366 auto iter_bound = e.layout() == layout_type::row_major ? e.strides().begin() : (e.strides().end() - 1);
367 while (inner_stride == 0 && strides_finder != iter_bound)
368 {
369 (e.layout() == layout_type::row_major) ? --strides_finder : ++strides_finder;
370 inner_stride = static_cast<std::size_t>(*strides_finder);
371 }
372
// No non-zero stride was found in the scanned range: fall back to an
// element-wise copy, as in the no-axes case.
373 if (inner_stride == 0)
374 {
375 auto cpf = [&reduce_fct, &init_fct](const auto& v)
376 {
377 return reduce_fct(static_cast<result_type>(init_fct()), v);
378 };
379 copy_to_reduced(cpf, e, result);
380 return result;
381 }
382
383 std::size_t inner_loop_size = static_cast<std::size_t>(inner_stride);
384 std::size_t outer_loop_size = e.shape()[leading_ax];
385
386 // The following code merges reduction axes "at the end" (or the beginning for col_major)
387 // together by increasing the size of the outer loop where appropriate
388 auto merge_loops = [&outer_loop_size, &e](auto it, auto end)
389 {
390 auto last_ax = *it;
391 ++it;
392 for (; it != end; ++it)
393 {
394 // note that we check is_sorted, so this condition is valid
395 if (std::abs(std::ptrdiff_t(*it) - std::ptrdiff_t(last_ax)) == 1)
396 {
397 last_ax = *it;
398 outer_loop_size *= e.shape()[last_ax];
399 }
400 }
401 return last_ax;
402 };
403
// Strides into `result` for every axis of `e` that is not reduced; the result
// axis index is i with keep_dims, otherwise the running counter idx.
404 for (std::size_t i = 0, idx = 0; i < e.dimension(); ++i)
405 {
406 if (std::find(axes.begin(), axes.end(), i) == axes.end())
407 {
408 // i not in axes!
409 iter_strides[i] = static_cast<std::size_t>(result.strides(
410 )[typename options_t::keep_dims() ? i : idx]);
411 ++idx;
412 }
413 }
414
415 if (e.layout() == layout_type::row_major)
416 {
417 std::size_t last_ax = merge_loops(axes.rbegin(), axes.rend());
418
419 iter_shape.erase(iter_shape.begin() + std::ptrdiff_t(last_ax), iter_shape.end());
420 iter_strides.erase(iter_strides.begin() + std::ptrdiff_t(last_ax), iter_strides.end());
421 }
422 else if (e.layout() == layout_type::column_major)
423 {
424 // we got column_major here
425 std::size_t last_ax = merge_loops(axes.begin(), axes.end());
426
427 // erasing the front vs the back
428 iter_shape.erase(iter_shape.begin(), iter_shape.begin() + std::ptrdiff_t(last_ax + 1));
429 iter_strides.erase(iter_strides.begin(), iter_strides.begin() + std::ptrdiff_t(last_ax + 1));
430
431 // and reversing, to make it work with the same next_idx function
432 std::reverse(iter_shape.begin(), iter_shape.end());
433 std::reverse(iter_strides.begin(), iter_strides.end());
434 }
435 else
436 {
437 XTENSOR_THROW(std::runtime_error, "Layout not supported in immediate reduction.");
438 }
439
// Odometer over iter_shape: advances temp_idx (last position fastest) and
// returns {true once every position has wrapped, offset into result}.
440 xindex temp_idx(iter_shape.size());
441 auto next_idx = [&iter_shape, &iter_strides, &temp_idx]()
442 {
443 std::size_t i = iter_shape.size();
444 for (; i > 0; --i)
445 {
446 if (std::ptrdiff_t(temp_idx[i - 1]) >= std::ptrdiff_t(iter_shape[i - 1]) - 1)
447 {
448 temp_idx[i - 1] = 0;
449 }
450 else
451 {
452 temp_idx[i - 1]++;
453 break;
454 }
455 }
456
457 return std::make_pair(
458 i == 0,
459 std::inner_product(temp_idx.begin(), temp_idx.end(), iter_strides.begin(), std::ptrdiff_t(0))
460 );
461 };
462
463 auto begin = e.data();
464 auto out = result.data();
465 auto out_begin = result.data();
466
467 std::ptrdiff_t next_stride = 0;
468
469 std::pair<bool, std::ptrdiff_t> idx_res(false, 0);
470
471 // Remark: eventually some modifications here to make conditions faster where merge + accumulate is
472 // the same function (e.g. check std::is_same<decltype(merge_fct), decltype(reduce_fct)>::value) ...
473
// `merge` becomes true when the next output position does not lie beyond
// merge_border, i.e. it was already written; the new partial result is then
// combined with merge_fct instead of overwriting it.
474 auto merge_border = out;
475 bool merge = false;
476
477 // TODO there could be some performance gain by removing merge checking
478 // when axes.size() == 1 and even next_idx could be removed for something simpler (next_stride
479 // always the same) best way to do this would be to create a function that takes (begin, out,
480 // outer_loop_size, inner_loop_size, next_idx_lambda)
481 // Decide if going about it row-wise or col-wise
482 if (inner_stride == 1)
483 {
484 while (idx_res.first != true)
485 {
486 // for unknown reasons it's much faster to use a temporary variable and
487 // std::accumulate here -- probably some cache behavior
488 result_type tmp = init_fct();
489 tmp = std::accumulate(begin, begin + outer_loop_size, tmp, reduce_fct);
490
491 // use merge function if necessary
492 *out = merge ? merge_fct(*out, tmp) : tmp;
493
494 begin += outer_loop_size;
495
496 idx_res = next_idx();
497 next_stride = idx_res.second;
498 out = out_begin + next_stride;
499
500 if (out > merge_border)
501 {
502 // looped over once
503 merge = false;
504 merge_border = out;
505 }
506 else
507 {
508 merge = true;
509 }
510 };
511 }
512 else
513 {
514 while (idx_res.first != true)
515 {
516 std::transform(
517 out,
518 out + inner_loop_size,
519 begin,
520 out,
521 [merge, &init_fct, &reduce_fct](auto&& v1, auto&& v2)
522 {
523 return merge ? reduce_fct(v1, v2) :
524 // cast because return type of identity function is not upcasted
525 reduce_fct(static_cast<result_type>(init_fct()), v2);
526 }
527 );
528
529 begin += inner_stride;
530 for (std::size_t i = 1; i < outer_loop_size; ++i)
531 {
532 std::transform(out, out + inner_loop_size, begin, out, reduce_fct);
533 begin += inner_stride;
534 }
535
536 idx_res = next_idx();
537 next_stride = idx_res.second;
538 out = out_begin + next_stride;
539
540 if (out > merge_border)
541 {
542 // looped over once
543 merge = false;
544 merge_border = out;
545 }
546 else
547 {
548 merge = true;
549 }
550 };
551 }
// Fold the user-supplied initial value into every output element with the
// merge functor.
552 if (options_t::has_initial_value)
553 {
554 std::transform(
555 result.data(),
556 result.data() + result.size(),
557 result.data(),
558 [&merge_fct, &options](auto&& v)
559 {
560 return merge_fct(v, options.initial_value);
561 }
562 );
563 }
564 return result;
565 }
566
567 /*********************
568 * xreducer functors *
569 *********************/
570
/**
 * Nullary functor that returns a stored constant.
 *
 * Used as the default init functor of xreducer_functors
 * (const_value<long int>).
 *
 * @tparam T type of the stored constant.
 */
template <class T>
struct const_value
{
    using value_type = T;

    /// Default construction yields T{}: the member is value-initialized below,
    /// so operator() never reads an indeterminate value, even for
    /// `const_value<long> c;`.
    constexpr const_value() = default;

    constexpr const_value(T t)
        : m_value(t)
    {
    }

    /// Returns the stored constant.
    constexpr T operator()() const
    {
        return m_value;
    }

    template <class NT>
    using rebind_t = const_value<NT>;

    /// Converts to a const_value<NT>; declared here, defined elsewhere.
    template <class NT>
    const_value<NT> rebind() const;

    T m_value{};
};
596
597 namespace detail
598 {
599 template <class T, bool B>
600 struct evaluated_value_type
601 {
602 using type = T;
603 };
604
605 template <class T>
606 struct evaluated_value_type<T, true>
607 {
608 using type = typename std::decay_t<decltype(xt::eval(std::declval<T>()))>;
609 };
610
611 template <class T, bool B>
612 using evaluated_value_type_t = typename evaluated_value_type<T, B>::type;
613 }
614
615 template <class REDUCE_FUNC, class INIT_FUNC = const_value<long int>, class MERGE_FUNC = REDUCE_FUNC>
616 struct xreducer_functors : public std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>
617 {
618 using self_type = xreducer_functors<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
619 using base_type = std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
620 using reduce_functor_type = REDUCE_FUNC;
621 using init_functor_type = INIT_FUNC;
622 using merge_functor_type = MERGE_FUNC;
623 using init_value_type = typename init_functor_type::value_type;
624
625 xreducer_functors()
626 : base_type()
627 {
628 }
629
630 template <class RF>
631 xreducer_functors(RF&& reduce_func)
632 : base_type(std::forward<RF>(reduce_func), INIT_FUNC(), reduce_func)
633 {
634 }
635
636 template <class RF, class IF>
637 xreducer_functors(RF&& reduce_func, IF&& init_func)
638 : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), reduce_func)
639 {
640 }
641
642 template <class RF, class IF, class MF>
643 xreducer_functors(RF&& reduce_func, IF&& init_func, MF&& merge_func)
644 : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), std::forward<MF>(merge_func))
645 {
646 }
647
648 reduce_functor_type get_reduce() const
649 {
650 return std::get<0>(upcast());
651 }
652
653 init_functor_type get_init() const
654 {
655 return std::get<1>(upcast());
656 }
657
658 merge_functor_type get_merge() const
659 {
660 return std::get<2>(upcast());
661 }
662
663 template <class NT>
664 using rebind_t = xreducer_functors<REDUCE_FUNC, const_value<NT>, MERGE_FUNC>;
665
666 template <class NT>
667 rebind_t<NT> rebind()
668 {
669 return make_xreducer_functor(get_reduce(), get_init().template rebind<NT>(), get_merge());
670 }
671
672 private:
673
674 // Workaround for clang-cl
675 const base_type& upcast() const
676 {
677 return static_cast<const base_type&>(*this);
678 }
679 };
680
// Builds an xreducer_functors from a single reduce functor; presumably that
// functor also serves as the merge functor, with a default init functor
// (matching the one-argument xreducer_functors constructor).
// NOTE(review): source line 684, which declares the `reducer_type` alias used
// below, is missing from this listing; presumably
// `using reducer_type = xreducer_functors<std::remove_reference_t<RF>>;`,
// mirroring the two- and three-argument overloads. Restore it from the
// original file.
681 template <class RF>
682 auto make_xreducer_functor(RF&& reduce_func)
683 {
685 return reducer_type(std::forward<RF>(reduce_func));
686 }
687
688 template <class RF, class IF>
689 auto make_xreducer_functor(RF&& reduce_func, IF&& init_func)
690 {
691 using reducer_type = xreducer_functors<std::remove_reference_t<RF>, std::remove_reference_t<IF>>;
692 return reducer_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func));
693 }
694
695 template <class RF, class IF, class MF>
696 auto make_xreducer_functor(RF&& reduce_func, IF&& init_func, MF&& merge_func)
697 {
698 using reducer_type = xreducer_functors<
699 std::remove_reference_t<RF>,
700 std::remove_reference_t<IF>,
701 std::remove_reference_t<MF>>;
702 return reducer_type(
703 std::forward<RF>(reduce_func),
704 std::forward<IF>(init_func),
705 std::forward<MF>(merge_func)
706 );
707 }
708
709 /**********************
710 * xreducer extension *
711 **********************/
712
// Extension point for xreducer. xreducer_base<F, CT, X, O>::type is an extra
// base class selected through xreducer_base_impl from the expression tag of
// CT; the specialization shown below yields xtensor_empty_base.
// NOTE(review): source lines 716 (declarator of the primary xreducer_base_impl
// template) and 719 (declarator of the specialization, which names the tag)
// are missing from this listing; restore them from the original file.
713 namespace extension
714 {
715 template <class Tag, class F, class CT, class X, class O>
717
718 template <class F, class CT, class X, class O>
720 {
721 using type = xtensor_empty_base;
722 };
723
724 template <class F, class CT, class X, class O>
725 struct xreducer_base : xreducer_base_impl<xexpression_tag_t<CT>, F, CT, X, O>
726 {
727 };
728
729 template <class F, class CT, class X, class O>
730 using xreducer_base_t = typename xreducer_base<F, CT, X, O>::type;
731 }
732
733 /************
734 * xreducer *
735 ************/
736
// Forward declarations of the reduction expression and of its stepper; both
// are defined further down in this header.
template <class Func, class Closure, class Axes, class Options>
class xreducer;

template <class Func, class Closure, class Axes, class Options>
class xreducer_stepper;
742
743 template <class F, class CT, class X, class O>
744 struct xiterable_inner_types<xreducer<F, CT, X, O>>
745 {
746 using xexpression_type = std::decay_t<CT>;
747 using inner_shape_type = typename xreducer_shape_type<
748 typename xexpression_type::shape_type,
749 std::decay_t<X>,
750 typename O::keep_dims>::type;
751 using const_stepper = xreducer_stepper<F, CT, X, O>;
752 using stepper = const_stepper;
753 };
754
755 template <class F, class CT, class X, class O>
756 struct xcontainer_inner_types<xreducer<F, CT, X, O>>
757 {
758 using xexpression_type = std::decay_t<CT>;
759 using reduce_functor_type = typename std::decay_t<F>::reduce_functor_type;
760 using init_functor_type = typename std::decay_t<F>::init_functor_type;
761 using merge_functor_type = typename std::decay_t<F>::merge_functor_type;
762 using substepper_type = typename xexpression_type::const_stepper;
763 using raw_value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
764 std::declval<init_functor_type>()(),
765 *std::declval<substepper_type>()
766 ))>;
767 using value_type = typename detail::evaluated_value_type_t<raw_value_type, is_xexpression<raw_value_type>::value>;
768
769 using reference = value_type;
770 using const_reference = value_type;
771 using size_type = typename xexpression_type::size_type;
772 };
773
// select_dim_mapping_type<S>::type is the type xreducer uses for its dimension
// mapping: S itself in general, and std::array<std::size_t, sizeof...(I)> for
// the specialization over an index pack I... (presumably fixed_shape<I...>).
// NOTE(review): source lines 775 and 781 (the two declarators) are missing
// from this listing; restore them from the original file.
774 template <class T>
776 {
777 using type = T;
778 };
779
780 template <std::size_t... I>
782 {
783 using type = std::array<std::size_t, sizeof...(I)>;
784 };
785
// xreducer<F, CT, X, O>: reduction expression over the closure CT along the
// axes X, using the functor bundle F (reduce / init / merge) and the reducer
// options O. Returned by reduce() for the lazy evaluation strategy. Most
// members are only declared here; their definitions are not part of this
// listing.
// NOTE(review): source lines 786-802 and 817 are missing from this listing;
// line 817 presumably declares the `xreducer_functors_type` alias used by
// functors() below. Restore them from the original file.
803 template <class F, class CT, class X, class O>
804 class xreducer : public xsharable_expression<xreducer<F, CT, X, O>>,
805 public xconst_iterable<xreducer<F, CT, X, O>>,
806 public xaccessible<xreducer<F, CT, X, O>>,
807 public extension::xreducer_base_t<F, CT, X, O>
808 {
809 public:
810
811 using self_type = xreducer<F, CT, X, O>;
812 using inner_types = xcontainer_inner_types<self_type>;
813
814 using reduce_functor_type = typename inner_types::reduce_functor_type;
815 using init_functor_type = typename inner_types::init_functor_type;
816 using merge_functor_type = typename inner_types::merge_functor_type;
818
819 using xexpression_type = typename inner_types::xexpression_type;
820 using axes_type = X;
821
822 using extension_base = extension::xreducer_base_t<F, CT, X, O>;
823 using expression_tag = typename extension_base::expression_tag;
824
825 using substepper_type = typename inner_types::substepper_type;
826 using value_type = typename inner_types::value_type;
827 using reference = typename inner_types::reference;
828 using const_reference = typename inner_types::const_reference;
829 using pointer = value_type*;
830 using const_pointer = const value_type*;
831
832 using size_type = typename inner_types::size_type;
833 using difference_type = typename xexpression_type::difference_type;
834
835 using iterable_base = xconst_iterable<self_type>;
836 using inner_shape_type = typename iterable_base::inner_shape_type;
837 using shape_type = inner_shape_type;
838
839 using dim_mapping_type = typename select_dim_mapping_type<inner_shape_type>::type;
840
841 using stepper = typename iterable_base::stepper;
842 using const_stepper = typename iterable_base::const_stepper;
843 using bool_load_type = typename xexpression_type::bool_load_type;
844
// A reduction exposes no static layout and is never contiguous.
845 static constexpr layout_type static_layout = layout_type::dynamic;
846 static constexpr bool contiguous_layout = false;
847
848 template <class Func, class CTA, class AX, class OX>
849 xreducer(Func&& func, CTA&& e, AX&& axes, OX&& options);
850
851 const inner_shape_type& shape() const noexcept;
852 layout_type layout() const noexcept;
853 bool is_contiguous() const noexcept;
854
855 template <class... Args>
856 const_reference operator()(Args... args) const;
857 template <class... Args>
858 const_reference unchecked(Args... args) const;
859
860 template <class It>
861 const_reference element(It first, It last) const;
862
863 const xexpression_type& expression() const noexcept;
864
865 template <class S>
866 bool broadcast_shape(S& shape, bool reuse_cache = false) const;
867
868 template <class S>
869 bool has_linear_assign(const S& strides) const noexcept;
870
871 template <class S>
872 const_stepper stepper_begin(const S& shape) const noexcept;
873 template <class S>
874 const_stepper stepper_end(const S& shape, layout_type) const noexcept;
875
// Same reduction type applied to another expression E (optionally with other
// functors / options).
876 template <class E, class Func = F, class Opts = O>
877 using rebind_t = xreducer<Func, E, X, Opts>;
878
879 template <class E>
880 rebind_t<E> build_reducer(E&& e) const;
881
882 template <class E, class Func, class Opts>
883 rebind_t<E, Func, Opts> build_reducer(E&& e, Func&& func, Opts&& opts) const;
884
// Rebuilds a functor bundle from the stored reduce / init / merge functors.
885 xreducer_functors_type functors() const
886 {
887 return xreducer_functors_type(m_reduce, m_init, m_merge); // TODO: understand why
888 // make_xreducer_functor is throwing an
889 // error
890 }
891
892 const O& options() const
893 {
894 return m_options;
895 }
896
897 private:
898
// State: the reduced expression (held as closure type CT), the three functors,
// the axes, the result shape, the dimension mapping and the options.
899 CT m_e;
900 reduce_functor_type m_reduce;
901 init_functor_type m_init;
902 merge_functor_type m_merge;
903 axes_type m_axes;
904 inner_shape_type m_shape;
905 dim_mapping_type m_dim_mapping;
906 O m_options;
907
// The stepper reads the private state above.
908 friend class xreducer_stepper<F, CT, X, O>;
909 };
910
911 /*************************
912 * reduce implementation *
913 *************************/
914
915 namespace detail
916 {
917 template <class F, class E, class X, class O>
918 inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::lazy_type, O&& options)
919 {
920 decltype(auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
921
922 using reduce_functor_type = typename std::decay_t<F>::reduce_functor_type;
923 using init_functor_type = typename std::decay_t<F>::init_functor_type;
924 using value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
925 std::declval<init_functor_type>()(),
926 *std::declval<typename std::decay_t<E>::const_stepper>()
927 ))>;
928 using evaluated_value_type = evaluated_value_type_t<value_type, is_xexpression<value_type>::value>;
929
930 using reducer_type = xreducer<
931 F,
932 const_xclosure_t<E>,
933 xtl::const_closure_type_t<decltype(normalized_axes)>,
934 reducer_options<evaluated_value_type, std::decay_t<O>>>;
935 return reducer_type(
936 std::forward<F>(f),
937 std::forward<E>(e),
938 std::forward<decltype(normalized_axes)>(normalized_axes),
939 std::forward<O>(options)
940 );
941 }
942
943 template <class F, class E, class X, class O>
944 inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::immediate_type, O&& options)
945 {
946 decltype(auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
947 return reduce_immediate(
948 std::forward<F>(f),
949 eval(std::forward<E>(e)),
950 std::forward<decltype(normalized_axes)>(normalized_axes),
951 std::forward<O>(options)
952 );
953 }
954 }
955
// Default option pack of the reduce() overloads (lazy evaluation). The macro
// is also defined earlier in this header; the guard keeps it to a single
// definition instead of relying on an identical redefinition.
#ifndef DEFAULT_STRATEGY_REDUCERS
#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
#endif
957
958 namespace detail
959 {
960 template <class T>
961 struct is_xreducer_functors_impl : std::false_type
962 {
963 };
964
965 template <class RF, class IF, class MF>
966 struct is_xreducer_functors_impl<xreducer_functors<RF, IF, MF>> : std::true_type
967 {
968 };
969
970 template <class T>
971 using is_xreducer_functors = is_xreducer_functors_impl<std::decay_t<T>>;
972 }
973
986
987 template <
988 class F,
989 class E,
990 class X,
991 class EVS = DEFAULT_STRATEGY_REDUCERS,
992 XTL_REQUIRES(std::negation<is_reducer_options<X>>, detail::is_xreducer_functors<F>)>
993 inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
994 {
995 return detail::reduce_impl(
996 std::forward<F>(f),
997 std::forward<E>(e),
998 std::forward<X>(axes),
999 typename reducer_options<int, EVS>::evaluation_strategy{},
1000 std::forward<EVS>(options)
1001 );
1002 }
1003
1004 template <
1005 class F,
1006 class E,
1007 class X,
1008 class EVS = DEFAULT_STRATEGY_REDUCERS,
1009 XTL_REQUIRES(std::negation<is_reducer_options<X>>, std::negation<detail::is_xreducer_functors<F>>)>
1010 inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
1011 {
1012 return reduce(
1013 make_xreducer_functor(std::forward<F>(f)),
1014 std::forward<E>(e),
1015 std::forward<X>(axes),
1016 std::forward<EVS>(options)
1017 );
1018 }
1019
1020 template <
1021 class F,
1022 class E,
1023 class EVS = DEFAULT_STRATEGY_REDUCERS,
1024 XTL_REQUIRES(is_reducer_options<EVS>, detail::is_xreducer_functors<F>)>
1025 inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
1026 {
1027 xindex_type_t<typename std::decay_t<E>::shape_type> ar;
1028 resize_container(ar, e.dimension());
1029 std::iota(ar.begin(), ar.end(), 0);
1030 return detail::reduce_impl(
1031 std::forward<F>(f),
1032 std::forward<E>(e),
1033 std::move(ar),
1034 typename reducer_options<int, std::decay_t<EVS>>::evaluation_strategy{},
1035 std::forward<EVS>(options)
1036 );
1037 }
1038
1039 template <
1040 class F,
1041 class E,
1042 class EVS = DEFAULT_STRATEGY_REDUCERS,
1043 XTL_REQUIRES(is_reducer_options<EVS>, std::negation<detail::is_xreducer_functors<F>>)>
1044 inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
1045 {
1046 return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), std::forward<EVS>(options));
1047 }
1048
1049 template <
1050 class F,
1051 class E,
1052 class I,
1053 std::size_t N,
1054 class EVS = DEFAULT_STRATEGY_REDUCERS,
1055 XTL_REQUIRES(detail::is_xreducer_functors<F>)>
1056 inline auto reduce(F&& f, E&& e, const I (&axes)[N], EVS options = EVS())
1057 {
1058 using axes_type = std::array<std::size_t, N>;
1059 auto ax = xt::forward_normalize<axes_type>(e, axes);
1060 return detail::reduce_impl(
1061 std::forward<F>(f),
1062 std::forward<E>(e),
1063 std::move(ax),
1064 typename reducer_options<int, EVS>::evaluation_strategy{},
1065 options
1066 );
1067 }
1068
1069 template <
1070 class F,
1071 class E,
1072 class I,
1073 std::size_t N,
1074 class EVS = DEFAULT_STRATEGY_REDUCERS,
1075 XTL_REQUIRES(std::negation<detail::is_xreducer_functors<F>>)>
1076 inline auto reduce(F&& f, E&& e, const I (&axes)[N], EVS options = EVS())
1077 {
1078 return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), axes, options);
1079 }
1080
1081 /********************
1082 * xreducer_stepper *
1083 ********************/
1084
1085 template <class F, class CT, class X, class O>
1086 class xreducer_stepper
1087 {
1088 public:
1089
1090 using self_type = xreducer_stepper<F, CT, X, O>;
1091 using xreducer_type = xreducer<F, CT, X, O>;
1092
1093 using value_type = typename xreducer_type::value_type;
1094 using reference = typename xreducer_type::value_type;
1095 using pointer = typename xreducer_type::const_pointer;
1096 using size_type = typename xreducer_type::size_type;
1097 using difference_type = typename xreducer_type::difference_type;
1098
1099 using xexpression_type = typename xreducer_type::xexpression_type;
1100 using substepper_type = typename xexpression_type::const_stepper;
1101 using shape_type = typename xreducer_type::shape_type;
1102
1103 xreducer_stepper(
1104 const xreducer_type& red,
1105 size_type offset,
1106 bool end = false,
1107 layout_type l = default_assignable_layout(xexpression_type::static_layout)
1108 );
1109
1110 reference operator*() const;
1111
1112 void step(size_type dim);
1113 void step_back(size_type dim);
1114 void step(size_type dim, size_type n);
1115 void step_back(size_type dim, size_type n);
1116 void reset(size_type dim);
1117 void reset_back(size_type dim);
1118
1119 void to_begin();
1120 void to_end(layout_type l);
1121
1122 private:
1123
1124 reference initial_value() const;
1125 reference aggregate(size_type dim) const;
1126 reference aggregate_impl(size_type dim, /*keep_dims=*/std::false_type) const;
1127 reference aggregate_impl(size_type dim, /*keep_dims=*/std::true_type) const;
1128
1129 substepper_type get_substepper_begin() const;
1130 size_type get_dim(size_type dim) const noexcept;
1131 size_type shape(size_type i) const noexcept;
1132 size_type axis(size_type i) const noexcept;
1133
1134 const xreducer_type* m_reducer;
1135 size_type m_offset;
1136 mutable substepper_type m_stepper;
1137 };
1138
1139 /******************
1140 * xreducer utils *
1141 ******************/
1142
1143 namespace detail
1144 {
1145 template <std::size_t X, std::size_t... I>
1146 struct in
1147 {
1148 static constexpr bool value = std::disjunction<std::integral_constant<bool, X == I>...>::value;
1149 };
1150
1151 template <std::size_t Z, class S1, class S2, class R>
1152 struct fixed_xreducer_shape_type_impl;
1153
1154 template <std::size_t Z, std::size_t... I, std::size_t... J, std::size_t... R>
1155 struct fixed_xreducer_shape_type_impl<Z, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
1156 {
1157 using type = std::conditional_t<
1158 in<Z, J...>::value,
1159 typename fixed_xreducer_shape_type_impl<Z - 1, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>::type,
1160 typename fixed_xreducer_shape_type_impl<
1161 Z - 1,
1162 fixed_shape<I...>,
1163 fixed_shape<J...>,
1164 fixed_shape<detail::at<Z, I...>::value, R...>>::type>;
1165 };
1166
1167 template <std::size_t... I, std::size_t... J, std::size_t... R>
1168 struct fixed_xreducer_shape_type_impl<0, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
1169 {
1170 using type = std::
1171 conditional_t<in<0, J...>::value, fixed_shape<R...>, fixed_shape<detail::at<0, I...>::value, R...>>;
1172 };
1173
1174 /***************************
1175 * helper for return types *
1176 ***************************/
1177
1178 template <class T>
1179 struct xreducer_size_type
1180 {
1181 using type = std::size_t;
1182 };
1183
1184 template <class T>
1185 using xreducer_size_type_t = typename xreducer_size_type<T>::type;
1186
1187 template <class T>
1188 struct xreducer_temporary_type
1189 {
1190 using type = T;
1191 };
1192
1193 template <class T>
1194 using xreducer_temporary_type_t = typename xreducer_temporary_type<T>::type;
1195
1196 /********************************
1197 * Default const_value rebinder *
1198 ********************************/
1199
1200 template <class T, class U>
1201 struct const_value_rebinder
1202 {
1203 static const_value<U> run(const const_value<T>& t)
1204 {
1205 return const_value<U>(t.m_value);
1206 }
1207 };
1208 }
1209
1210 /*******************************************
1211 * Init functor const_value implementation *
1212 *******************************************/
1213
1214 template <class T>
1215 template <class NT>
1216 const_value<NT> const_value<T>::rebind() const
1217 {
1218 return detail::const_value_rebinder<T, NT>::run(*this);
1219 }
1220
1221 /*****************************
1222 * fixed_xreducer_shape_type *
1223 *****************************/
1224
1225 template <class S1, class S2>
1227
1228 template <std::size_t... I, std::size_t... J>
1230 {
1231 using type = typename detail::
1232 fixed_xreducer_shape_type_impl<sizeof...(I) - 1, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<>>::type;
1233 };
1234
1235 // meta-function returning the shape type for an xreducer
1236 template <class ST, class X, class O>
1238 {
1239 using type = promote_shape_t<ST, std::decay_t<X>>;
1240 };
1241
1242 template <class I1, std::size_t N1, class I2, std::size_t N2>
1243 struct xreducer_shape_type<std::array<I1, N1>, std::array<I2, N2>, std::true_type>
1244 {
1245 using type = std::array<I2, N1>;
1246 };
1247
1248 template <class I1, std::size_t N1, class I2, std::size_t N2>
1249 struct xreducer_shape_type<std::array<I1, N1>, std::array<I2, N2>, std::false_type>
1250 {
1251 using type = std::array<I2, N1 - N2>;
1252 };
1253
1254 template <std::size_t... I, class I2, std::size_t N2>
1255 struct xreducer_shape_type<fixed_shape<I...>, std::array<I2, N2>, std::false_type>
1256 {
1257 using type = std::conditional_t<sizeof...(I) == N2, fixed_shape<>, std::array<I2, sizeof...(I) - N2>>;
1258 };
1259
namespace detail
{
    // ixconcat<integer_sequence<T, A...>, integer_sequence<T, B...>>::type is
    // integer_sequence<T, A..., B...>.
    template <class S1, class S2>
    struct ixconcat;

    template <class T, T... I1, T... I2>
    struct ixconcat<std::integer_sequence<T, I1...>, std::integer_sequence<T, I2...>>
    {
        using type = std::integer_sequence<T, I1..., I2...>;
    };

    // repeat_integer_sequence<T, X, N>::type is integer_sequence<T, X, ..., X> (N times).
    template <class T, T X, std::size_t N>
    struct repeat_integer_sequence
    {
        // Maps every index of make_index_sequence<N> to the constant X.
        // Declared only: it is used solely in the unevaluated decltype below.
        template <std::size_t... Is>
        static auto expand(std::index_sequence<Is...>)
            -> std::integer_sequence<T, (static_cast<void>(Is), X)...>;

        using type = decltype(expand(std::make_index_sequence<N>{}));
    };
}
1297
1298 template <std::size_t... I, class I2, std::size_t N2>
1299 struct xreducer_shape_type<fixed_shape<I...>, std::array<I2, N2>, std::true_type>
1300 {
1301 template <std::size_t... X>
1302 static constexpr auto get_type(std::index_sequence<X...>)
1303 {
1304 return fixed_shape<X...>{};
1305 }
1306
1307 // if all axes reduced
1308 using type = std::conditional_t<
1309 sizeof...(I) == N2,
1310 decltype(get_type(typename detail::repeat_integer_sequence<std::size_t, std::size_t(1), N2>::type{})),
1311 std::array<I2, sizeof...(I)>>;
1312 };
1313
1314 // Note adding "A" to prevent compilation in case nothing else matches
1315 template <std::size_t... I, std::size_t... J, class O>
1317 {
1318 using type = typename fixed_xreducer_shape_type<fixed_shape<I...>, fixed_shape<J...>>::type;
1319 };
1320
namespace detail
{
    // Dynamic-shape case without keep_dims: copies into `shape` the extents of e's
    // dimensions that are not in `axes`, and records in `mapping`, for every output
    // dimension, the index of the input dimension it comes from.
    // Assumes `axes` is sorted and duplicate-free (checked by the xreducer constructor).
    template <class S, class E, class X, class M>
    inline void shape_and_mapping_computation(S& shape, E& e, const X& axes, M& mapping, std::false_type)
    {
        auto first = e.shape().begin();
        auto last = e.shape().end();
        auto exclude_it = axes.begin();

        // Mapping entries are converted with the mapping's own value type.
        using mapping_value_type = typename M::value_type;
        using difference_type = typename S::difference_type;
        auto d_first = shape.begin();
        auto map_first = mapping.begin();

        auto iter = first;
        while (iter != last && exclude_it != axes.end())
        {
            auto diff = std::distance(first, iter);
            if (diff != difference_type(*exclude_it))
            {
                // dimension kept
                *d_first++ = *iter++;
                *map_first++ = mapping_value_type(diff);
            }
            else
            {
                // dimension reduced: skip it and move to the next axis
                ++iter;
                ++exclude_it;
            }
        }

        // All axes consumed: the remaining dimensions are kept as-is.
        auto diff = std::distance(first, iter);
        auto end = std::distance(iter, last);
        std::iota(map_first, map_first + end, mapping_value_type(diff));
        std::copy(iter, last, d_first);
    }

    // Dynamic-shape case with keep_dims: reduced dimensions get extent 1, the others
    // keep their extent, and the mapping is the identity.
    template <class S, class E, class X, class M>
    inline void
    shape_and_mapping_computation_keep_dim(S& shape, E& e, const X& axes, M& mapping, std::false_type)
    {
        for (std::size_t i = 0; i < e.dimension(); ++i)
        {
            if (std::find(axes.cbegin(), axes.cend(), i) == axes.cend())
            {
                // i not in axes!
                shape[i] = e.shape()[i];
            }
            else
            {
                shape[i] = 1;
            }
        }
        std::iota(mapping.begin(), mapping.end(), typename M::value_type(0));
    }

    // Fixed-shape case: shape and mapping are fully determined by their types.
    template <class S, class E, class X, class M>
    inline void shape_and_mapping_computation(S&, E&, const X&, M&, std::true_type)
    {
    }

    template <class S, class E, class X, class M>
    inline void shape_and_mapping_computation_keep_dim(S&, E&, const X&, M&, std::true_type)
    {
    }
}
1386
1387 /***************************
1388 * xreducer implementation *
1389 ***************************/
1390
1395
1404 template <class F, class CT, class X, class O>
1405 template <class Func, class CTA, class AX, class OX>
1406 inline xreducer<F, CT, X, O>::xreducer(Func&& func, CTA&& e, AX&& axes, OX&& options)
1407 : m_e(std::forward<CTA>(e))
1408 , m_reduce(xt::get<0>(func))
1409 , m_init(xt::get<1>(func))
1410 , m_merge(xt::get<2>(func))
1411 , m_axes(std::forward<AX>(axes))
1412 , m_shape(xtl::make_sequence<inner_shape_type>(
1413 typename O::keep_dims() ? m_e.dimension() : m_e.dimension() - m_axes.size(),
1414 0
1415 ))
1416 , m_dim_mapping(xtl::make_sequence<dim_mapping_type>(
1417 typename O::keep_dims() ? m_e.dimension() : m_e.dimension() - m_axes.size(),
1418 0
1419 ))
1420 , m_options(std::forward<OX>(options))
1421 {
1422 // std::less is used, because as the standard says (24.4.5):
1423 // A sequence is sorted with respect to a comparator comp if for any iterator i pointing to the
1424 // sequence and any non-negative integer n such that i + n is a valid iterator pointing to an element
1425 // of the sequence, comp(*(i + n), *i) == false. Therefore less is required to detect duplicates.
1426 if (!std::is_sorted(m_axes.cbegin(), m_axes.cend(), std::less<>()))
1427 {
1428 XTENSOR_THROW(std::runtime_error, "Reducing axes should be sorted.");
1429 }
1430 if (std::adjacent_find(m_axes.cbegin(), m_axes.cend()) != m_axes.cend())
1431 {
1432 XTENSOR_THROW(std::runtime_error, "Reducing axes should not contain duplicates.");
1433 }
1434 if (m_axes.size() != 0 && m_axes[m_axes.size() - 1] > m_e.dimension() - 1)
1435 {
1436 XTENSOR_THROW(
1437 std::runtime_error,
1438 "Axis " + std::to_string(m_axes[m_axes.size() - 1]) + " out of bounds for reduction."
1439 );
1440 }
1441
1442 if (!typename O::keep_dims())
1443 {
1444 detail::shape_and_mapping_computation(
1445 m_shape,
1446 m_e,
1447 m_axes,
1448 m_dim_mapping,
1449 detail::is_fixed<shape_type>{}
1450 );
1451 }
1452 else
1453 {
1454 detail::shape_and_mapping_computation_keep_dim(
1455 m_shape,
1456 m_e,
1457 m_axes,
1458 m_dim_mapping,
1459 detail::is_fixed<shape_type>{}
1460 );
1461 }
1462 }
1463
1465
1469
1473 template <class F, class CT, class X, class O>
1474 inline auto xreducer<F, CT, X, O>::shape() const noexcept -> const inner_shape_type&
1475 {
1476 return m_shape;
1477 }
1478
1482 template <class F, class CT, class X, class O>
1484 {
1485 return static_layout;
1486 }
1487
1488 template <class F, class CT, class X, class O>
1489 inline bool xreducer<F, CT, X, O>::is_contiguous() const noexcept
1490 {
1491 return false;
1492 }
1493
1495
1499
1506 template <class F, class CT, class X, class O>
1507 template <class... Args>
1508 inline auto xreducer<F, CT, X, O>::operator()(Args... args) const -> const_reference
1509 {
1510 XTENSOR_TRY(check_index(shape(), args...));
1511 XTENSOR_CHECK_DIMENSION(shape(), args...);
1512 std::array<std::size_t, sizeof...(Args)> arg_array = {{static_cast<std::size_t>(args)...}};
1513 return element(arg_array.cbegin(), arg_array.cend());
1514 }
1515
1535 template <class F, class CT, class X, class O>
1536 template <class... Args>
1537 inline auto xreducer<F, CT, X, O>::unchecked(Args... args) const -> const_reference
1538 {
1539 std::array<std::size_t, sizeof...(Args)> arg_array = {{static_cast<std::size_t>(args)...}};
1540 return element(arg_array.cbegin(), arg_array.cend());
1541 }
1542
1550 template <class F, class CT, class X, class O>
1551 template <class It>
1552 inline auto xreducer<F, CT, X, O>::element(It first, It last) const -> const_reference
1553 {
1554 XTENSOR_TRY(check_element_index(shape(), first, last));
1555 auto stepper = const_stepper(*this, 0);
1556 if (first != last)
1557 {
1558 size_type dim = 0;
1559 // drop left most elements
1560 auto size = std::ptrdiff_t(this->dimension()) - std::distance(first, last);
1561 auto begin = first - size;
1562 while (begin != last)
1563 {
1564 if (begin < first)
1565 {
1566 stepper.step(dim++, std::size_t(0));
1567 begin++;
1568 }
1569 else
1570 {
1571 stepper.step(dim++, std::size_t(*begin++));
1572 }
1573 }
1574 }
1575 return *stepper;
1576 }
1577
1581 template <class F, class CT, class X, class O>
1582 inline auto xreducer<F, CT, X, O>::expression() const noexcept -> const xexpression_type&
1583 {
1584 return m_e;
1585 }
1586
1588
1593
1599 template <class F, class CT, class X, class O>
1600 template <class S>
1602 {
1603 return xt::broadcast_shape(m_shape, shape);
1604 }
1605
1611 template <class F, class CT, class X, class O>
1612 template <class S>
1613 inline bool xreducer<F, CT, X, O>::has_linear_assign(const S& /*strides*/) const noexcept
1614 {
1615 return false;
1616 }
1617
1619
1620 template <class F, class CT, class X, class O>
1621 template <class S>
1622 inline auto xreducer<F, CT, X, O>::stepper_begin(const S& shape) const noexcept -> const_stepper
1623 {
1624 size_type offset = shape.size() - this->dimension();
1625 return const_stepper(*this, offset);
1626 }
1627
1628 template <class F, class CT, class X, class O>
1629 template <class S>
1630 inline auto xreducer<F, CT, X, O>::stepper_end(const S& shape, layout_type l) const noexcept
1631 -> const_stepper
1632 {
1633 size_type offset = shape.size() - this->dimension();
1634 return const_stepper(*this, offset, true, l);
1635 }
1636
1637 template <class F, class CT, class X, class O>
1638 template <class E>
1639 inline auto xreducer<F, CT, X, O>::build_reducer(E&& e) const -> rebind_t<E>
1640 {
1641 return rebind_t<E>(
1642 std::make_tuple(m_reduce, m_init, m_merge),
1643 std::forward<E>(e),
1644 axes_type(m_axes),
1645 m_options
1646 );
1647 }
1648
1649 template <class F, class CT, class X, class O>
1650 template <class E, class Func, class Opts>
1651 inline auto xreducer<F, CT, X, O>::build_reducer(E&& e, Func&& func, Opts&& opts) const
1652 -> rebind_t<E, Func, Opts>
1653 {
1654 return rebind_t<E, Func, Opts>(
1655 std::forward<Func>(func),
1656 std::forward<E>(e),
1657 axes_type(m_axes),
1658 std::forward<Opts>(opts)
1659 );
1660 }
1661
1662 /***********************************
1663 * xreducer_stepper implementation *
1664 ***********************************/
1665
1666 template <class F, class CT, class X, class O>
1667 inline xreducer_stepper<F, CT, X, O>::xreducer_stepper(
1668 const xreducer_type& red,
1669 size_type offset,
1670 bool end,
1671 layout_type l
1672 )
1673 : m_reducer(&red)
1674 , m_offset(offset)
1675 , m_stepper(get_substepper_begin())
1676 {
1677 if (end)
1678 {
1679 to_end(l);
1680 }
1681 }
1682
1683 template <class F, class CT, class X, class O>
1684 inline auto xreducer_stepper<F, CT, X, O>::operator*() const -> reference
1685 {
1686 reference r = aggregate(0);
1687 return r;
1688 }
1689
1690 template <class F, class CT, class X, class O>
1691 inline void xreducer_stepper<F, CT, X, O>::step(size_type dim)
1692 {
1693 if (dim >= m_offset)
1694 {
1695 m_stepper.step(get_dim(dim - m_offset));
1696 }
1697 }
1698
1699 template <class F, class CT, class X, class O>
1700 inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim)
1701 {
1702 if (dim >= m_offset)
1703 {
1704 m_stepper.step_back(get_dim(dim - m_offset));
1705 }
1706 }
1707
1708 template <class F, class CT, class X, class O>
1709 inline void xreducer_stepper<F, CT, X, O>::step(size_type dim, size_type n)
1710 {
1711 if (dim >= m_offset)
1712 {
1713 m_stepper.step(get_dim(dim - m_offset), n);
1714 }
1715 }
1716
1717 template <class F, class CT, class X, class O>
1718 inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim, size_type n)
1719 {
1720 if (dim >= m_offset)
1721 {
1722 m_stepper.step_back(get_dim(dim - m_offset), n);
1723 }
1724 }
1725
1726 template <class F, class CT, class X, class O>
1727 inline void xreducer_stepper<F, CT, X, O>::reset(size_type dim)
1728 {
1729 if (dim >= m_offset)
1730 {
1731 // Because the reducer uses `reset` to reset the non-reducing axes,
1732 // we need to prevent that here for the KD case where.
1733 if (typename O::keep_dims()
1734 && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
1735 {
1736 // If keep dim activated, and dim is in the axes, do nothing!
1737 return;
1738 }
1739 m_stepper.reset(get_dim(dim - m_offset));
1740 }
1741 }
1742
1743 template <class F, class CT, class X, class O>
1744 inline void xreducer_stepper<F, CT, X, O>::reset_back(size_type dim)
1745 {
1746 if (dim >= m_offset)
1747 {
1748 // Note that for *not* KD this is not going to do anything
1749 if (typename O::keep_dims()
1750 && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
1751 {
1752 // If keep dim activated, and dim is in the axes, do nothing!
1753 return;
1754 }
1755 m_stepper.reset_back(get_dim(dim - m_offset));
1756 }
1757 }
1758
1759 template <class F, class CT, class X, class O>
1760 inline void xreducer_stepper<F, CT, X, O>::to_begin()
1761 {
1762 m_stepper.to_begin();
1763 }
1764
1765 template <class F, class CT, class X, class O>
1766 inline void xreducer_stepper<F, CT, X, O>::to_end(layout_type l)
1767 {
1768 m_stepper.to_end(l);
1769 }
1770
1771 template <class F, class CT, class X, class O>
1772 inline auto xreducer_stepper<F, CT, X, O>::initial_value() const -> reference
1773 {
1774 return O::has_initial_value ? m_reducer->m_options.initial_value
1775 : static_cast<reference>(m_reducer->m_init());
1776 }
1777
1778 template <class F, class CT, class X, class O>
1779 inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim) const -> reference
1780 {
1781 reference res;
1782 if (m_reducer->m_e.size() == size_type(0))
1783 {
1784 res = initial_value();
1785 }
1786 else if (m_reducer->m_e.shape().empty() || m_reducer->m_axes.size() == 0)
1787 {
1788 res = m_reducer->m_reduce(initial_value(), *m_stepper);
1789 }
1790 else
1791 {
1792 res = aggregate_impl(dim, typename O::keep_dims());
1793 if (O::has_initial_value && dim == 0)
1794 {
1795 res = m_reducer->m_merge(m_reducer->m_options.initial_value, res);
1796 }
1797 }
1798 return res;
1799 }
1800
1801 template <class F, class CT, class X, class O>
1802 inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::false_type) const -> reference
1803 {
1804 // reference can be std::array, hence the {} initializer
1805 reference res = {};
1806 size_type index = axis(dim);
1807 size_type size = shape(index);
1808 if (dim != m_reducer->m_axes.size() - 1)
1809 {
1810 res = aggregate_impl(dim + 1, typename O::keep_dims());
1811 for (size_type i = 1; i != size; ++i)
1812 {
1813 m_stepper.step(index);
1814 res = m_reducer->m_merge(res, aggregate_impl(dim + 1, typename O::keep_dims()));
1815 }
1816 }
1817 else
1818 {
1819 res = m_reducer->m_reduce(static_cast<reference>(m_reducer->m_init()), *m_stepper);
1820 for (size_type i = 1; i != size; ++i)
1821 {
1822 m_stepper.step(index);
1823 res = m_reducer->m_reduce(res, *m_stepper);
1824 }
1825 }
1826 m_stepper.reset(index);
1827 return res;
1828 }
1829
1830 template <class F, class CT, class X, class O>
1831 inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::true_type) const -> reference
1832 {
1833 // reference can be std::array, hence the {} initializer
1834 reference res = {};
1835 auto ax_it = std::find(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim);
1836 if (ax_it != m_reducer->m_axes.end())
1837 {
1838 size_type index = dim;
1839 size_type size = m_reducer->m_e.shape()[index];
1840 if (ax_it != m_reducer->m_axes.end() - 1 && size != 0)
1841 {
1842 res = aggregate_impl(dim + 1, typename O::keep_dims());
1843 for (size_type i = 1; i != size; ++i)
1844 {
1845 m_stepper.step(index);
1846 res = m_reducer->m_merge(res, aggregate_impl(dim + 1, typename O::keep_dims()));
1847 }
1848 }
1849 else
1850 {
1851 res = m_reducer->m_reduce(static_cast<reference>(m_reducer->m_init()), *m_stepper);
1852 for (size_type i = 1; i != size; ++i)
1853 {
1854 m_stepper.step(index);
1855 res = m_reducer->m_reduce(res, *m_stepper);
1856 }
1857 }
1858 m_stepper.reset(index);
1859 }
1860 else
1861 {
1862 if (dim < m_reducer->m_e.dimension())
1863 {
1864 res = aggregate_impl(dim + 1, typename O::keep_dims());
1865 }
1866 }
1867 return res;
1868 }
1869
1870 template <class F, class CT, class X, class O>
1871 inline auto xreducer_stepper<F, CT, X, O>::get_substepper_begin() const -> substepper_type
1872 {
1873 return m_reducer->m_e.stepper_begin(m_reducer->m_e.shape());
1874 }
1875
1876 template <class F, class CT, class X, class O>
1877 inline auto xreducer_stepper<F, CT, X, O>::get_dim(size_type dim) const noexcept -> size_type
1878 {
1879 return m_reducer->m_dim_mapping[dim];
1880 }
1881
1882 template <class F, class CT, class X, class O>
1883 inline auto xreducer_stepper<F, CT, X, O>::shape(size_type i) const noexcept -> size_type
1884 {
1885 return m_reducer->m_e.shape()[i];
1886 }
1887
1888 template <class F, class CT, class X, class O>
1889 inline auto xreducer_stepper<F, CT, X, O>::axis(size_type i) const noexcept -> size_type
1890 {
1891 return m_reducer->m_axes[i];
1892 }
1893}
1894
1895#endif
Fixed shape implementation for compile time defined arrays.
size_type size() const noexcept(noexcept(derived_cast().shape()))
Base class for multidimensional iterable constant expressions.
Definition xiterable.hpp:37
auto end() const noexcept -> const_layout_iterator< L >
Reducing function operating over specified axes.
Definition xreducer.hpp:808
bool broadcast_shape(S &shape, bool reuse_cache=false) const
const xexpression_type & expression() const noexcept
bool has_linear_assign(const S &strides) const noexcept
const inner_shape_type & shape() const noexcept
Returns the shape of the expression.
xreducer(Func &&func, CTA &&e, AX &&axes, OX &&options)
Constructs an xreducer expression applying the specified function to the given expression over the gi...
layout_type layout() const noexcept
auto operator|(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::bitwise_or, E1, E2 >
Bitwise or.
auto diff(const xexpression< T > &a, std::size_t n=1, std::ptrdiff_t axis=-1)
Calculate the n-th discrete difference along the given axis.
Definition xmath.hpp:2908
auto eval(T &&t) -> std::enable_if_t< detail::is_container< std::decay_t< T > >::value, T && >
Force evaluation of xexpression.
Definition xeval.hpp:46
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Definition xstrides.hpp:250
standard mathematical functions for xexpressions
auto reduce(F &&f, E &&e, X &&axes, EVS &&options=EVS())
Returns an xexpression applying the specified reducing function to an expression over the given axes.
Definition xreducer.hpp:993
layout_type
Definition xlayout.hpp:24