xtensor
Loading...
Searching...
No Matches
xreducer.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_REDUCER_HPP
11#define XTENSOR_REDUCER_HPP
12
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdlib>
#include <initializer_list>
#include <iterator>
#include <numeric>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>

#include <xtl/xfunctional.hpp>
#include <xtl/xsequence.hpp>

#include "xaccessible.hpp"
#include "xbuilder.hpp"
#include "xeval.hpp"
#include "xexpression.hpp"
#include "xgenerator.hpp"
#include "xiterable.hpp"
#include "xtensor_config.hpp"
#include "xutils.hpp"
33
34namespace xt
35{
36 template <template <class...> class A, class... AX, class X, XTL_REQUIRES(is_evaluation_strategy<AX>..., is_evaluation_strategy<X>)>
37 auto operator|(const A<AX...>& args, const A<X>& rhs)
38 {
39 return std::tuple_cat(args, rhs);
40 }
41
42 struct keep_dims_type : xt::detail::option_base
43 {
44 };
45
46 constexpr auto keep_dims = std::tuple<keep_dims_type>{};
47
48 template <class T = double>
49 struct xinitial : xt::detail::option_base
50 {
51 constexpr xinitial(T val)
52 : m_val(val)
53 {
54 }
55
56 constexpr T value() const
57 {
58 return m_val;
59 }
60
61 T m_val;
62 };
63
64 template <class T>
65 constexpr auto initial(T val)
66 {
67 return std::make_tuple(xinitial<T>(val));
68 }
69
/**
 * tuple_idx_of_impl<I, T, std::tuple<...>>::value is I plus the position of
 * the first element of the tuple that is exactly T, or -1 if T does not occur.
 */
template <std::ptrdiff_t I, class T, class Tuple>
struct tuple_idx_of_impl;

// Exhausted the tuple without finding T.
template <std::ptrdiff_t I, class T>
struct tuple_idx_of_impl<I, T, std::tuple<>>
{
    static constexpr std::ptrdiff_t value = -1;
};

// Head matches T: current position is the answer.
template <std::ptrdiff_t I, class T, class... Types>
struct tuple_idx_of_impl<I, T, std::tuple<T, Types...>>
{
    static constexpr std::ptrdiff_t value = I;
};

// Head does not match: continue with the tail at the next position.
template <std::ptrdiff_t I, class T, class U, class... Types>
struct tuple_idx_of_impl<I, T, std::tuple<U, Types...>>
{
    static constexpr std::ptrdiff_t value = tuple_idx_of_impl<I + 1, T, std::tuple<Types...>>::value;
};

/// Applies std::decay_t to every template argument of S<X...>.
template <class S, class... X>
struct decay_all;

template <template <class...> class S, class... X>
struct decay_all<S<X...>>
{
    using type = S<std::decay_t<X>...>;
};

/**
 * Index of the first element of Tuple whose decayed type is T, or -1.
 * Element types are decayed first so that cv/ref-qualified entries still match.
 */
template <class T, class Tuple>
struct tuple_idx_of
{
    static constexpr std::ptrdiff_t
        value = tuple_idx_of_impl<0, T, typename decay_all<Tuple>::type>::value;
};
106
107 template <class R, class T>
109 {
110 template <class X>
111 struct initial_tester : std::false_type
112 {
113 };
114
115 template <class X>
116 struct initial_tester<xinitial<X>> : std::true_type
117 {
118 };
119
120 // Workaround for Apple because tuple_cat is buggy!
121 template <class X>
122 struct initial_tester<const xinitial<X>> : std::true_type
123 {
124 };
125
126 using d_t = std::decay_t<T>;
127
128 static constexpr std::size_t initial_val_idx = xtl::mpl::find_if<initial_tester, d_t>::value;
129 reducer_options() = default;
130
131 reducer_options(const T& tpl)
132 {
133 xtl::mpl::static_if<initial_val_idx != std::tuple_size<T>::value>(
134 [this, &tpl](auto no_compile)
135 {
136 // use no_compile to prevent compilation if initial_val_idx is out of bounds!
137 this->initial_value = no_compile(
138 std::get < initial_val_idx != std::tuple_size<T>::value
139 ? initial_val_idx
140 : 0 > (tpl)
141 )
142 .value();
143 },
144 [](auto /*np_compile*/) {}
145 );
146 }
147
148 using evaluation_strategy = std::conditional_t<
149 tuple_idx_of<xt::evaluation_strategy::immediate_type, d_t>::value != -1,
152
153 using keep_dims = std::
154 conditional_t<tuple_idx_of<xt::keep_dims_type, d_t>::value != -1, std::true_type, std::false_type>;
155
156 static constexpr bool has_initial_value = initial_val_idx != std::tuple_size<d_t>::value;
157
158 R initial_value;
159
160 template <class NR>
161 using rebind_t = reducer_options<NR, T>;
162
163 template <class NR>
164 auto rebind(NR initial, const reducer_options<R, T>&) const
165 {
166 reducer_options<NR, T> res;
167 res.initial_value = initial;
168 return res;
169 }
170 };
171
/// True only for std::tuple specializations (the representation of reducer options).
template <class T>
struct is_reducer_options_impl : std::false_type
{
};

template <class... X>
struct is_reducer_options_impl<std::tuple<X...>> : std::true_type
{
};

/**
 * True when T, after std::decay_t, is a std::tuple. The decay is required
 * because reduce() tests forwarding-reference types such as
 * ``const std::tuple<...>&``.
 */
template <class T>
struct is_reducer_options : is_reducer_options_impl<std::decay_t<T>>
{
};
186
187 /**********
188 * reduce *
189 **********/
190
// Default options argument of reduce(): lazy evaluation, no other option.
#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>

/// Shape type of a reduction of shape type ST over axes X (KD: keep_dims flag).
template <class ST, class X, class KD = std::false_type>
struct xreducer_shape_type;

/// Shape type of a reduction of fixed shape S1 over fixed axes S2.
template <class S1, class S2>
struct fixed_xreducer_shape_type;
199 namespace detail
200 {
201 template <class O, class RS, class R, class E, class AX>
202 inline void shape_computation(
204 R& result,
205 E& expr,
206 const AX& axes,
207 std::enable_if_t<!detail::is_fixed<RS>::value, int> = 0
208 )
209 {
210 if (typename O::keep_dims())
211 {
212 resize_container(result_shape, expr.dimension());
213 for (std::size_t i = 0; i < expr.dimension(); ++i)
214 {
215 if (std::find(axes.begin(), axes.end(), i) == axes.end())
216 {
217 // i not in axes!
218 result_shape[i] = expr.shape()[i];
219 }
220 else
221 {
222 result_shape[i] = 1;
223 }
224 }
225 }
226 else
227 {
228 resize_container(result_shape, expr.dimension() - axes.size());
229 for (std::size_t i = 0, idx = 0; i < expr.dimension(); ++i)
230 {
231 if (std::find(axes.begin(), axes.end(), i) == axes.end())
232 {
233 // i not in axes!
234 result_shape[idx] = expr.shape()[i];
235 ++idx;
236 }
237 }
238 }
239 result.resize(result_shape, expr.layout());
240 }
241
242 // skip shape computation if already done at compile time
243 template <class O, class RS, class R, class S, class AX>
244 inline void
245 shape_computation(RS&, R&, const S&, const AX&, std::enable_if_t<detail::is_fixed<RS>::value, int> = 0)
246 {
247 }
248 }
249
250 template <class F, class E, class R, XTL_REQUIRES(std::is_convertible<typename E::value_type, typename R::value_type>)>
251 inline void copy_to_reduced(F&, const E& e, R& result)
252 {
253 if (e.layout() == layout_type::row_major)
254 {
255 std::copy(
256 e.template cbegin<layout_type::row_major>(),
257 e.template cend<layout_type::row_major>(),
258 result.data()
259 );
260 }
261 else
262 {
263 std::copy(
264 e.template cbegin<layout_type::column_major>(),
265 e.template cend<layout_type::column_major>(),
266 result.data()
267 );
268 }
269 }
270
271 template <
272 class F,
273 class E,
274 class R,
275 XTL_REQUIRES(xtl::negation<std::is_convertible<typename E::value_type, typename R::value_type>>)>
276 inline void copy_to_reduced(F& f, const E& e, R& result)
277 {
278 if (e.layout() == layout_type::row_major)
279 {
280 std::transform(
281 e.template cbegin<layout_type::row_major>(),
282 e.template cend<layout_type::row_major>(),
283 result.data(),
284 f
285 );
286 }
287 else
288 {
289 std::transform(
290 e.template cbegin<layout_type::column_major>(),
291 e.template cend<layout_type::column_major>(),
292 result.data(),
293 f
294 );
295 }
296 }
297
298 template <class F, class E, class X, class O>
299 inline auto reduce_immediate(F&& f, E&& e, X&& axes, O&& raw_options)
300 {
301 using reduce_functor_type = typename std::decay_t<F>::reduce_functor_type;
302 using init_functor_type = typename std::decay_t<F>::init_functor_type;
303 using expr_value_type = typename std::decay_t<E>::value_type;
304 using result_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
305 std::declval<init_functor_type>()(),
306 std::declval<expr_value_type>()
307 ))>;
308
309 using options_t = reducer_options<result_type, std::decay_t<O>>;
310 options_t options(raw_options);
311
312 using shape_type = typename xreducer_shape_type<
313 typename std::decay_t<E>::shape_type,
314 std::decay_t<X>,
315 typename options_t::keep_dims>::type;
316 using result_container_type = typename detail::xtype_for_shape<
317 shape_type>::template type<result_type, std::decay_t<E>::static_layout>;
318 result_container_type result;
319
320 // retrieve functors from triple struct
321 auto reduce_fct = xt::get<0>(f);
322 auto init_fct = xt::get<1>(f);
323 auto merge_fct = xt::get<2>(f);
324
325 if (axes.size() == 0)
326 {
327 result.resize(e.shape(), e.layout());
328 auto cpf = [&reduce_fct, &init_fct](const auto& v)
329 {
330 return reduce_fct(static_cast<result_type>(init_fct()), v);
331 };
332 copy_to_reduced(cpf, e, result);
333 return result;
334 }
335
336 shape_type result_shape{};
337 dynamic_shape<std::size_t>
338 iter_shape = xtl::forward_sequence<dynamic_shape<std::size_t>, decltype(e.shape())>(e.shape());
339 dynamic_shape<std::size_t> iter_strides(e.dimension());
340
341 // std::less is used, because as the standard says (24.4.5):
342 // A sequence is sorted with respect to a comparator comp if for any iterator i pointing to the
343 // sequence and any non-negative integer n such that i + n is a valid iterator pointing to an element
344 // of the sequence, comp(*(i + n), *i) == false. Therefore less is required to detect duplicates.
345 if (!std::is_sorted(axes.cbegin(), axes.cend(), std::less<>()))
346 {
347 XTENSOR_THROW(std::runtime_error, "Reducing axes should be sorted.");
348 }
349 if (std::adjacent_find(axes.cbegin(), axes.cend()) != axes.cend())
350 {
351 XTENSOR_THROW(std::runtime_error, "Reducing axes should not contain duplicates.");
352 }
353 if (axes.size() != 0 && axes[axes.size() - 1] > e.dimension() - 1)
354 {
355 XTENSOR_THROW(
356 std::runtime_error,
357 "Axis " + std::to_string(axes[axes.size() - 1]) + " out of bounds for reduction."
358 );
359 }
360
361 detail::shape_computation<options_t>(result_shape, result, e, axes);
362
363 // Fast track for complete reduction
364 if (e.dimension() == axes.size())
365 {
366 result_type tmp = options_t::has_initial_value ? options.initial_value : init_fct();
367 result.data()[0] = std::accumulate(e.storage().begin(), e.storage().end(), tmp, reduce_fct);
368 return result;
369 }
370
371 std::size_t leading_ax = axes[(e.layout() == layout_type::row_major) ? axes.size() - 1 : 0];
372 auto strides_finder = e.strides().begin() + static_cast<std::ptrdiff_t>(leading_ax);
373 // The computed strides contain "0" where the shape is 1 -- therefore find the next none-zero number
374 std::size_t inner_stride = static_cast<std::size_t>(*strides_finder);
375 auto iter_bound = e.layout() == layout_type::row_major ? e.strides().begin() : (e.strides().end() - 1);
376 while (inner_stride == 0 && strides_finder != iter_bound)
377 {
378 (e.layout() == layout_type::row_major) ? --strides_finder : ++strides_finder;
379 inner_stride = static_cast<std::size_t>(*strides_finder);
380 }
381
382 if (inner_stride == 0)
383 {
384 auto cpf = [&reduce_fct, &init_fct](const auto& v)
385 {
386 return reduce_fct(static_cast<result_type>(init_fct()), v);
387 };
388 copy_to_reduced(cpf, e, result);
389 return result;
390 }
391
392 std::size_t inner_loop_size = static_cast<std::size_t>(inner_stride);
393 std::size_t outer_loop_size = e.shape()[leading_ax];
394
395 // The following code merges reduction axes "at the end" (or the beginning for col_major)
396 // together by increasing the size of the outer loop where appropriate
397 auto merge_loops = [&outer_loop_size, &e](auto it, auto end)
398 {
399 auto last_ax = *it;
400 ++it;
401 for (; it != end; ++it)
402 {
403 // note that we check is_sorted, so this condition is valid
404 if (std::abs(std::ptrdiff_t(*it) - std::ptrdiff_t(last_ax)) == 1)
405 {
406 last_ax = *it;
407 outer_loop_size *= e.shape()[last_ax];
408 }
409 }
410 return last_ax;
411 };
412
413 for (std::size_t i = 0, idx = 0; i < e.dimension(); ++i)
414 {
415 if (std::find(axes.begin(), axes.end(), i) == axes.end())
416 {
417 // i not in axes!
418 iter_strides[i] = static_cast<std::size_t>(result.strides(
419 )[typename options_t::keep_dims() ? i : idx]);
420 ++idx;
421 }
422 }
423
424 if (e.layout() == layout_type::row_major)
425 {
426 std::size_t last_ax = merge_loops(axes.rbegin(), axes.rend());
427
428 iter_shape.erase(iter_shape.begin() + std::ptrdiff_t(last_ax), iter_shape.end());
429 iter_strides.erase(iter_strides.begin() + std::ptrdiff_t(last_ax), iter_strides.end());
430 }
431 else if (e.layout() == layout_type::column_major)
432 {
433 // we got column_major here
434 std::size_t last_ax = merge_loops(axes.begin(), axes.end());
435
436 // erasing the front vs the back
437 iter_shape.erase(iter_shape.begin(), iter_shape.begin() + std::ptrdiff_t(last_ax + 1));
438 iter_strides.erase(iter_strides.begin(), iter_strides.begin() + std::ptrdiff_t(last_ax + 1));
439
440 // and reversing, to make it work with the same next_idx function
441 std::reverse(iter_shape.begin(), iter_shape.end());
442 std::reverse(iter_strides.begin(), iter_strides.end());
443 }
444 else
445 {
446 XTENSOR_THROW(std::runtime_error, "Layout not supported in immediate reduction.");
447 }
448
449 xindex temp_idx(iter_shape.size());
450 auto next_idx = [&iter_shape, &iter_strides, &temp_idx]()
451 {
452 std::size_t i = iter_shape.size();
453 for (; i > 0; --i)
454 {
455 if (std::ptrdiff_t(temp_idx[i - 1]) >= std::ptrdiff_t(iter_shape[i - 1]) - 1)
456 {
457 temp_idx[i - 1] = 0;
458 }
459 else
460 {
461 temp_idx[i - 1]++;
462 break;
463 }
464 }
465
466 return std::make_pair(
467 i == 0,
468 std::inner_product(temp_idx.begin(), temp_idx.end(), iter_strides.begin(), std::ptrdiff_t(0))
469 );
470 };
471
472 auto begin = e.data();
473 auto out = result.data();
474 auto out_begin = result.data();
475
476 std::ptrdiff_t next_stride = 0;
477
478 std::pair<bool, std::ptrdiff_t> idx_res(false, 0);
479
480 // Remark: eventually some modifications here to make conditions faster where merge + accumulate is
481 // the same function (e.g. check std::is_same<decltype(merge_fct), decltype(reduce_fct)>::value) ...
482
483 auto merge_border = out;
484 bool merge = false;
485
486 // TODO there could be some performance gain by removing merge checking
487 // when axes.size() == 1 and even next_idx could be removed for something simpler (next_stride
488 // always the same) best way to do this would be to create a function that takes (begin, out,
489 // outer_loop_size, inner_loop_size, next_idx_lambda)
490 // Decide if going about it row-wise or col-wise
491 if (inner_stride == 1)
492 {
493 while (idx_res.first != true)
494 {
495 // for unknown reasons it's much faster to use a temporary variable and
496 // std::accumulate here -- probably some cache behavior
497 result_type tmp = init_fct();
498 tmp = std::accumulate(begin, begin + outer_loop_size, tmp, reduce_fct);
499
500 // use merge function if necessary
501 *out = merge ? merge_fct(*out, tmp) : tmp;
502
503 begin += outer_loop_size;
504
505 idx_res = next_idx();
506 next_stride = idx_res.second;
507 out = out_begin + next_stride;
508
509 if (out > merge_border)
510 {
511 // looped over once
512 merge = false;
513 merge_border = out;
514 }
515 else
516 {
517 merge = true;
518 }
519 };
520 }
521 else
522 {
523 while (idx_res.first != true)
524 {
525 std::transform(
526 out,
527 out + inner_loop_size,
528 begin,
529 out,
530 [merge, &init_fct, &reduce_fct](auto&& v1, auto&& v2)
531 {
532 return merge ? reduce_fct(v1, v2) :
533 // cast because return type of identity function is not upcasted
534 reduce_fct(static_cast<result_type>(init_fct()), v2);
535 }
536 );
537
538 begin += inner_stride;
539 for (std::size_t i = 1; i < outer_loop_size; ++i)
540 {
541 std::transform(out, out + inner_loop_size, begin, out, reduce_fct);
542 begin += inner_stride;
543 }
544
545 idx_res = next_idx();
546 next_stride = idx_res.second;
547 out = out_begin + next_stride;
548
549 if (out > merge_border)
550 {
551 // looped over once
552 merge = false;
553 merge_border = out;
554 }
555 else
556 {
557 merge = true;
558 }
559 };
560 }
561 if (options_t::has_initial_value)
562 {
563 std::transform(
564 result.data(),
565 result.data() + result.size(),
566 result.data(),
567 [&merge_fct, &options](auto&& v)
568 {
569 return merge_fct(v, options.initial_value);
570 }
571 );
572 }
573 return result;
574 }
575
576 /*********************
577 * xreducer functors *
578 *********************/
579
/**
 * Init functor returning a stored constant.
 *
 * This is the default INIT_FUNC of xreducer_functors. operator() returns the
 * stored value; rebind<NT>() produces the same functor for another value type.
 *
 * @tparam T type of the stored value
 */
template <class T>
struct const_value
{
    using value_type = T;

    constexpr const_value() = default;

    constexpr const_value(T t)
        : m_value(t)
    {
    }

    constexpr T operator()() const
    {
        return m_value;
    }

    template <class NT>
    using rebind_t = const_value<NT>;

    template <class NT>
    const_value<NT> rebind() const;

    // Value-initialized so that a default-constructed functor returns T{}
    // instead of an indeterminate value.
    T m_value{};
};
605
606 namespace detail
607 {
608 template <class T, bool B>
609 struct evaluated_value_type
610 {
611 using type = T;
612 };
613
614 template <class T>
615 struct evaluated_value_type<T, true>
616 {
617 using type = typename std::decay_t<decltype(xt::eval(std::declval<T>()))>;
618 };
619
620 template <class T, bool B>
621 using evaluated_value_type_t = typename evaluated_value_type<T, B>::type;
622 }
623
624 template <class REDUCE_FUNC, class INIT_FUNC = const_value<long int>, class MERGE_FUNC = REDUCE_FUNC>
625 struct xreducer_functors : public std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>
626 {
628 using base_type = std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
632 using init_value_type = typename init_functor_type::value_type;
633
635 : base_type()
636 {
637 }
638
639 template <class RF>
641 : base_type(std::forward<RF>(reduce_func), INIT_FUNC(), reduce_func)
642 {
643 }
644
645 template <class RF, class IF>
647 : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), reduce_func)
648 {
649 }
650
651 template <class RF, class IF, class MF>
653 : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), std::forward<MF>(merge_func))
654 {
655 }
656
657 reduce_functor_type get_reduce() const
658 {
659 return std::get<0>(upcast());
660 }
661
662 init_functor_type get_init() const
663 {
664 return std::get<1>(upcast());
665 }
666
667 merge_functor_type get_merge() const
668 {
669 return std::get<2>(upcast());
670 }
671
672 template <class NT>
674
675 template <class NT>
676 rebind_t<NT> rebind()
677 {
678 return make_xreducer_functor(get_reduce(), get_init().template rebind<NT>(), get_merge());
679 }
680
681 private:
682
683 // Workaround for clang-cl
684 const base_type& upcast() const
685 {
686 return static_cast<const base_type&>(*this);
687 }
688 };
689
690 template <class RF>
691 auto make_xreducer_functor(RF&& reduce_func)
692 {
694 return reducer_type(std::forward<RF>(reduce_func));
695 }
696
697 template <class RF, class IF>
698 auto make_xreducer_functor(RF&& reduce_func, IF&& init_func)
699 {
700 using reducer_type = xreducer_functors<std::remove_reference_t<RF>, std::remove_reference_t<IF>>;
701 return reducer_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func));
702 }
703
704 template <class RF, class IF, class MF>
705 auto make_xreducer_functor(RF&& reduce_func, IF&& init_func, MF&& merge_func)
706 {
707 using reducer_type = xreducer_functors<
708 std::remove_reference_t<RF>,
709 std::remove_reference_t<IF>,
710 std::remove_reference_t<MF>>;
711 return reducer_type(
712 std::forward<RF>(reduce_func),
713 std::forward<IF>(init_func),
714 std::forward<MF>(merge_func)
715 );
716 }
717
718 /**********************
719 * xreducer extension *
720 **********************/
721
722 namespace extension
723 {
724 template <class Tag, class F, class CT, class X, class O>
726
727 template <class F, class CT, class X, class O>
732
733 template <class F, class CT, class X, class O>
734 struct xreducer_base : xreducer_base_impl<xexpression_tag_t<CT>, F, CT, X, O>
735 {
736 };
737
738 template <class F, class CT, class X, class O>
739 using xreducer_base_t = typename xreducer_base<F, CT, X, O>::type;
740 }
741
742 /************
743 * xreducer *
744 ************/
745
// Forward declarations: the stepper and the reducer refer to each other.
template <class F, class CT, class X, class O>
class xreducer_stepper;

template <class F, class CT, class X, class O>
class xreducer;
751
752 template <class F, class CT, class X, class O>
754 {
755 using xexpression_type = std::decay_t<CT>;
756 using inner_shape_type = typename xreducer_shape_type<
757 typename xexpression_type::shape_type,
758 std::decay_t<X>,
759 typename O::keep_dims>::type;
761 using stepper = const_stepper;
762 };
763
764 template <class F, class CT, class X, class O>
766 {
767 using xexpression_type = std::decay_t<CT>;
768 using reduce_functor_type = typename std::decay_t<F>::reduce_functor_type;
769 using init_functor_type = typename std::decay_t<F>::init_functor_type;
770 using merge_functor_type = typename std::decay_t<F>::merge_functor_type;
771 using substepper_type = typename xexpression_type::const_stepper;
772 using raw_value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
773 std::declval<init_functor_type>()(),
774 *std::declval<substepper_type>()
775 ))>;
776 using value_type = typename detail::evaluated_value_type_t<raw_value_type, is_xexpression<raw_value_type>::value>;
777
778 using reference = value_type;
779 using const_reference = value_type;
780 using size_type = typename xexpression_type::size_type;
781 };
782
783 template <class T>
785 {
786 using type = T;
787 };
788
789 template <std::size_t... I>
791 {
792 using type = std::array<std::size_t, sizeof...(I)>;
793 };
794
812 template <class F, class CT, class X, class O>
813 class xreducer : public xsharable_expression<xreducer<F, CT, X, O>>,
814 public xconst_iterable<xreducer<F, CT, X, O>>,
815 public xaccessible<xreducer<F, CT, X, O>>,
816 public extension::xreducer_base_t<F, CT, X, O>
817 {
818 public:
819
822
823 using reduce_functor_type = typename inner_types::reduce_functor_type;
824 using init_functor_type = typename inner_types::init_functor_type;
825 using merge_functor_type = typename inner_types::merge_functor_type;
827
828 using xexpression_type = typename inner_types::xexpression_type;
829 using axes_type = X;
830
831 using extension_base = extension::xreducer_base_t<F, CT, X, O>;
832 using expression_tag = typename extension_base::expression_tag;
833
834 using substepper_type = typename inner_types::substepper_type;
835 using value_type = typename inner_types::value_type;
836 using reference = typename inner_types::reference;
837 using const_reference = typename inner_types::const_reference;
838 using pointer = value_type*;
839 using const_pointer = const value_type*;
840
841 using size_type = typename inner_types::size_type;
842 using difference_type = typename xexpression_type::difference_type;
843
845 using inner_shape_type = typename iterable_base::inner_shape_type;
846 using shape_type = inner_shape_type;
847
848 using dim_mapping_type = typename select_dim_mapping_type<inner_shape_type>::type;
849
850 using stepper = typename iterable_base::stepper;
851 using const_stepper = typename iterable_base::const_stepper;
852 using bool_load_type = typename xexpression_type::bool_load_type;
853
854 static constexpr layout_type static_layout = layout_type::dynamic;
855 static constexpr bool contiguous_layout = false;
856
857 template <class Func, class CTA, class AX, class OX>
858 xreducer(Func&& func, CTA&& e, AX&& axes, OX&& options);
859
860 const inner_shape_type& shape() const noexcept;
861 layout_type layout() const noexcept;
862 bool is_contiguous() const noexcept;
863
864 template <class... Args>
865 const_reference operator()(Args... args) const;
866 template <class... Args>
867 const_reference unchecked(Args... args) const;
868
869 template <class It>
870 const_reference element(It first, It last) const;
871
872 const xexpression_type& expression() const noexcept;
873
874 template <class S>
875 bool broadcast_shape(S& shape, bool reuse_cache = false) const;
876
877 template <class S>
878 bool has_linear_assign(const S& strides) const noexcept;
879
880 template <class S>
881 const_stepper stepper_begin(const S& shape) const noexcept;
882 template <class S>
883 const_stepper stepper_end(const S& shape, layout_type) const noexcept;
884
885 template <class E, class Func = F, class Opts = O>
887
888 template <class E>
889 rebind_t<E> build_reducer(E&& e) const;
890
891 template <class E, class Func, class Opts>
892 rebind_t<E, Func, Opts> build_reducer(E&& e, Func&& func, Opts&& opts) const;
893
894 xreducer_functors_type functors() const
895 {
896 return xreducer_functors_type(m_reduce, m_init, m_merge); // TODO: understand why
897 // make_xreducer_functor is throwing an
898 // error
899 }
900
901 const O& options() const
902 {
903 return m_options;
904 }
905
906 private:
907
908 CT m_e;
909 reduce_functor_type m_reduce;
910 init_functor_type m_init;
911 merge_functor_type m_merge;
912 axes_type m_axes;
913 inner_shape_type m_shape;
914 dim_mapping_type m_dim_mapping;
915 O m_options;
916
917 friend class xreducer_stepper<F, CT, X, O>;
918 };
919
920 /*************************
921 * reduce implementation *
922 *************************/
923
924 namespace detail
925 {
926 template <class F, class E, class X, class O>
927 inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::lazy_type, O&& options)
928 {
929 decltype(auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
930
931 using reduce_functor_type = typename std::decay_t<F>::reduce_functor_type;
932 using init_functor_type = typename std::decay_t<F>::init_functor_type;
933 using value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
934 std::declval<init_functor_type>()(),
935 *std::declval<typename std::decay_t<E>::const_stepper>()
936 ))>;
938
939 using reducer_type = xreducer<
940 F,
942 xtl::const_closure_type_t<decltype(normalized_axes)>,
943 reducer_options<evaluated_value_type, std::decay_t<O>>>;
944 return reducer_type(
945 std::forward<F>(f),
946 std::forward<E>(e),
947 std::forward<decltype(normalized_axes)>(normalized_axes),
948 std::forward<O>(options)
949 );
950 }
951
952 template <class F, class E, class X, class O>
953 inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::immediate_type, O&& options)
954 {
955 decltype(auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
956 return reduce_immediate(
957 std::forward<F>(f),
958 eval(std::forward<E>(e)),
959 std::forward<decltype(normalized_axes)>(normalized_axes),
960 std::forward<O>(options)
961 );
962 }
963 }
964
965#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
966
967 namespace detail
968 {
969 template <class T>
970 struct is_xreducer_functors_impl : std::false_type
971 {
972 };
973
974 template <class RF, class IF, class MF>
975 struct is_xreducer_functors_impl<xreducer_functors<RF, IF, MF>> : std::true_type
976 {
977 };
978
979 template <class T>
980 using is_xreducer_functors = is_xreducer_functors_impl<std::decay_t<T>>;
981 }
982
996 template <
997 class F,
998 class E,
999 class X,
1000 class EVS = DEFAULT_STRATEGY_REDUCERS,
1001 XTL_REQUIRES(xtl::negation<is_reducer_options<X>>, detail::is_xreducer_functors<F>)>
1002 inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
1003 {
1004 return detail::reduce_impl(
1005 std::forward<F>(f),
1006 std::forward<E>(e),
1007 std::forward<X>(axes),
1008 typename reducer_options<int, EVS>::evaluation_strategy{},
1009 std::forward<EVS>(options)
1010 );
1011 }
1012
1013 template <
1014 class F,
1015 class E,
1016 class X,
1017 class EVS = DEFAULT_STRATEGY_REDUCERS,
1018 XTL_REQUIRES(xtl::negation<is_reducer_options<X>>, xtl::negation<detail::is_xreducer_functors<F>>)>
1019 inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
1020 {
1021 return reduce(
1022 make_xreducer_functor(std::forward<F>(f)),
1023 std::forward<E>(e),
1024 std::forward<X>(axes),
1025 std::forward<EVS>(options)
1026 );
1027 }
1028
1029 template <
1030 class F,
1031 class E,
1032 class EVS = DEFAULT_STRATEGY_REDUCERS,
1033 XTL_REQUIRES(is_reducer_options<EVS>, detail::is_xreducer_functors<F>)>
1034 inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
1035 {
1036 xindex_type_t<typename std::decay_t<E>::shape_type> ar;
1037 resize_container(ar, e.dimension());
1038 std::iota(ar.begin(), ar.end(), 0);
1039 return detail::reduce_impl(
1040 std::forward<F>(f),
1041 std::forward<E>(e),
1042 std::move(ar),
1043 typename reducer_options<int, std::decay_t<EVS>>::evaluation_strategy{},
1044 std::forward<EVS>(options)
1045 );
1046 }
1047
1048 template <
1049 class F,
1050 class E,
1051 class EVS = DEFAULT_STRATEGY_REDUCERS,
1052 XTL_REQUIRES(is_reducer_options<EVS>, xtl::negation<detail::is_xreducer_functors<F>>)>
1053 inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
1054 {
1055 return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), std::forward<EVS>(options));
1056 }
1057
1058 template <
1059 class F,
1060 class E,
1061 class I,
1062 std::size_t N,
1063 class EVS = DEFAULT_STRATEGY_REDUCERS,
1064 XTL_REQUIRES(detail::is_xreducer_functors<F>)>
1065 inline auto reduce(F&& f, E&& e, const I (&axes)[N], EVS options = EVS())
1066 {
1067 using axes_type = std::array<std::size_t, N>;
1068 auto ax = xt::forward_normalize<axes_type>(e, axes);
1069 return detail::reduce_impl(
1070 std::forward<F>(f),
1071 std::forward<E>(e),
1072 std::move(ax),
1073 typename reducer_options<int, EVS>::evaluation_strategy{},
1074 options
1075 );
1076 }
1077
1078 template <
1079 class F,
1080 class E,
1081 class I,
1082 std::size_t N,
1083 class EVS = DEFAULT_STRATEGY_REDUCERS,
1084 XTL_REQUIRES(xtl::negation<detail::is_xreducer_functors<F>>)>
1085 inline auto reduce(F&& f, E&& e, const I (&axes)[N], EVS options = EVS())
1086 {
1087 return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), axes, options);
1088 }
1089
1090 /********************
1091 * xreducer_stepper *
1092 ********************/
1093
1094 template <class F, class CT, class X, class O>
1096 {
1097 public:
1098
1101
1102 using value_type = typename xreducer_type::value_type;
1103 using reference = typename xreducer_type::value_type;
1104 using pointer = typename xreducer_type::const_pointer;
1105 using size_type = typename xreducer_type::size_type;
1106 using difference_type = typename xreducer_type::difference_type;
1107
1108 using xexpression_type = typename xreducer_type::xexpression_type;
1109 using substepper_type = typename xexpression_type::const_stepper;
1110 using shape_type = typename xreducer_type::shape_type;
1111
1113 const xreducer_type& red,
1114 size_type offset,
1115 bool end = false,
1116 layout_type l = default_assignable_layout(xexpression_type::static_layout)
1117 );
1118
1119 reference operator*() const;
1120
1121 void step(size_type dim);
1122 void step_back(size_type dim);
1123 void step(size_type dim, size_type n);
1124 void step_back(size_type dim, size_type n);
1125 void reset(size_type dim);
1126 void reset_back(size_type dim);
1127
1128 void to_begin();
1129 void to_end(layout_type l);
1130
1131 private:
1132
1133 reference initial_value() const;
1134 reference aggregate(size_type dim) const;
1135 reference aggregate_impl(size_type dim, /*keep_dims=*/std::false_type) const;
1136 reference aggregate_impl(size_type dim, /*keep_dims=*/std::true_type) const;
1137
1138 substepper_type get_substepper_begin() const;
1139 size_type get_dim(size_type dim) const noexcept;
1140 size_type shape(size_type i) const noexcept;
1141 size_type axis(size_type i) const noexcept;
1142
1143 const xreducer_type* m_reducer;
1144 size_type m_offset;
1145 mutable substepper_type m_stepper;
1146 };
1147
1148 /******************
1149 * xreducer utils *
1150 ******************/
1151
1152 namespace detail
1153 {
1154 template <std::size_t X, std::size_t... I>
1155 struct in
1156 {
1157 static constexpr bool value = xtl::disjunction<std::integral_constant<bool, X == I>...>::value;
1158 };
1159
1160 template <std::size_t Z, class S1, class S2, class R>
1161 struct fixed_xreducer_shape_type_impl;
1162
1163 template <std::size_t Z, std::size_t... I, std::size_t... J, std::size_t... R>
1164 struct fixed_xreducer_shape_type_impl<Z, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
1165 {
1166 using type = std::conditional_t<
1167 in<Z, J...>::value,
1168 typename fixed_xreducer_shape_type_impl<Z - 1, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>::type,
1169 typename fixed_xreducer_shape_type_impl<
1170 Z - 1,
1171 fixed_shape<I...>,
1172 fixed_shape<J...>,
1173 fixed_shape<detail::at<Z, I...>::value, R...>>::type>;
1174 };
1175
1176 template <std::size_t... I, std::size_t... J, std::size_t... R>
1177 struct fixed_xreducer_shape_type_impl<0, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
1178 {
1179 using type = std::
1180 conditional_t<in<0, J...>::value, fixed_shape<R...>, fixed_shape<detail::at<0, I...>::value, R...>>;
1181 };
1182
1183 /***************************
1184 * helper for return types *
1185 ***************************/
1186
1187 template <class T>
1188 struct xreducer_size_type
1189 {
1190 using type = std::size_t;
1191 };
1192
1193 template <class T>
1194 using xreducer_size_type_t = typename xreducer_size_type<T>::type;
1195
1196 template <class T>
1197 struct xreducer_temporary_type
1198 {
1199 using type = T;
1200 };
1201
1202 template <class T>
1203 using xreducer_temporary_type_t = typename xreducer_temporary_type<T>::type;
1204
1205 /********************************
1206 * Default const_value rebinder *
1207 ********************************/
1208
1209 template <class T, class U>
1210 struct const_value_rebinder
1211 {
1212 static const_value<U> run(const const_value<T>& t)
1213 {
1214 return const_value<U>(t.m_value);
1215 }
1216 };
1217 }
1218
1219 /*******************************************
1220 * Init functor const_value implementation *
1221 *******************************************/
1222
1223 template <class T>
1224 template <class NT>
1225 const_value<NT> const_value<T>::rebind() const
1226 {
1227 return detail::const_value_rebinder<T, NT>::run(*this);
1228 }
1229
1230 /*****************************
1231 * fixed_xreducer_shape_type *
1232 *****************************/
1233
1234 template <class S1, class S2>
1235 struct fixed_xreducer_shape_type;
1236
1237 template <std::size_t... I, std::size_t... J>
1239 {
1240 using type = typename detail::
1241 fixed_xreducer_shape_type_impl<sizeof...(I) - 1, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<>>::type;
1242 };
1243
1244 // meta-function returning the shape type for an xreducer
1245 template <class ST, class X, class O>
1250
1251 template <class I1, std::size_t N1, class I2, std::size_t N2>
1252 struct xreducer_shape_type<std::array<I1, N1>, std::array<I2, N2>, std::true_type>
1253 {
1254 using type = std::array<I2, N1>;
1255 };
1256
1257 template <class I1, std::size_t N1, class I2, std::size_t N2>
1258 struct xreducer_shape_type<std::array<I1, N1>, std::array<I2, N2>, std::false_type>
1259 {
1260 using type = std::array<I2, N1 - N2>;
1261 };
1262
1263 template <std::size_t... I, class I2, std::size_t N2>
1264 struct xreducer_shape_type<fixed_shape<I...>, std::array<I2, N2>, std::false_type>
1265 {
1266 using type = std::conditional_t<sizeof...(I) == N2, fixed_shape<>, std::array<I2, sizeof...(I) - N2>>;
1267 };
1268
namespace detail
{
    // Concatenates two std::integer_sequence types that share the value type T.
    template <class S1, class S2>
    struct ixconcat;

    template <class T, T... Lhs, T... Rhs>
    struct ixconcat<std::integer_sequence<T, Lhs...>, std::integer_sequence<T, Rhs...>>
    {
        using type = std::integer_sequence<T, Lhs..., Rhs...>;
    };

    // Maps each element of an index_sequence to X, yielding as many copies of X
    // as the index_sequence has elements.
    template <class T, T X, class Indices>
    struct repeat_integer_sequence_expand;

    template <class T, T X, std::size_t... Idx>
    struct repeat_integer_sequence_expand<T, X, std::index_sequence<Idx...>>
    {
        using type = std::integer_sequence<T, (static_cast<void>(Idx), X)...>;
    };

    // std::integer_sequence<T, X, ..., X> containing X exactly N times
    // (an empty sequence when N == 0).
    template <class T, T X, std::size_t N>
    struct repeat_integer_sequence
    {
        using type = typename repeat_integer_sequence_expand<T, X, std::make_index_sequence<N>>::type;
    };
}
1306
1307 template <std::size_t... I, class I2, std::size_t N2>
1308 struct xreducer_shape_type<fixed_shape<I...>, std::array<I2, N2>, std::true_type>
1309 {
1310 template <std::size_t... X>
1311 static constexpr auto get_type(std::index_sequence<X...>)
1312 {
1313 return fixed_shape<X...>{};
1314 }
1315
1316 // if all axes reduced
1317 using type = std::conditional_t<
1318 sizeof...(I) == N2,
1319 decltype(get_type(typename detail::repeat_integer_sequence<std::size_t, std::size_t(1), N2>::type{})),
1320 std::array<I2, sizeof...(I)>>;
1321 };
1322
1323 // Note adding "A" to prevent compilation in case nothing else matches
1324 template <std::size_t... I, std::size_t... J, class O>
1326 {
1327 using type = typename fixed_xreducer_shape_type<fixed_shape<I...>, fixed_shape<J...>>::type;
1328 };
1329
namespace detail
{
    /**
     * Runtime-shape case without keep_dims.
     *
     * Copies into ``shape`` every extent of ``e.shape()`` whose axis is not listed in
     * ``axes`` (assumed sorted, as the reducer constructor enforces). It also records,
     * in ``mapping``, the input axis that each output axis comes from. ``shape`` and
     * ``mapping`` must already have ``e.dimension() - axes.size()`` elements.
     */
    template <class S, class E, class X, class M>
    inline void shape_and_mapping_computation(S& shape, E& e, const X& axes, M& mapping, std::false_type)
    {
        using value_type = typename S::value_type;
        using difference_type = typename S::difference_type;

        const auto src_begin = e.shape().begin();
        const auto src_end = e.shape().end();
        auto src = src_begin;
        auto next_reduced = axes.begin();
        auto shape_out = shape.begin();
        auto mapping_out = mapping.begin();

        // While reduced axes remain, skip each one and keep everything else.
        for (; src != src_end && next_reduced != axes.end(); ++src)
        {
            const auto position = std::distance(src_begin, src);
            if (position == difference_type(*next_reduced))
            {
                ++next_reduced;
                continue;
            }
            *shape_out++ = *src;
            *mapping_out++ = value_type(position);
        }

        // Every axis after the last reduced one is kept as-is.
        const auto tail_start = std::distance(src_begin, src);
        const auto tail_length = std::distance(src, src_end);
        std::iota(mapping_out, mapping_out + tail_length, tail_start);
        std::copy(src, src_end, shape_out);
    }

    /**
     * Runtime-shape case with keep_dims.
     *
     * Reduced axes get extent 1 and the others keep the extent of ``e``. The mapping is
     * the identity because the rank is preserved.
     */
    template <class S, class E, class X, class M>
    inline void
    shape_and_mapping_computation_keep_dim(S& shape, E& e, const X& axes, M& mapping, std::false_type)
    {
        const auto ndim = e.dimension();
        for (std::size_t d = 0; d < ndim; ++d)
        {
            const bool reduced = std::find(axes.cbegin(), axes.cend(), d) != axes.cend();
            if (reduced)
            {
                shape[d] = 1;
            }
            else
            {
                shape[d] = e.shape()[d];
            }
        }
        std::iota(mapping.begin(), mapping.end(), 0);
    }

    // Fixed (compile-time) shapes: nothing to compute at runtime.
    template <class S, class E, class X, class M>
    inline void shape_and_mapping_computation(S&, E&, const X&, M&, std::true_type)
    {
    }

    template <class S, class E, class X, class M>
    inline void shape_and_mapping_computation_keep_dim(S&, E&, const X&, M&, std::true_type)
    {
    }
}
1395
1396 /***************************
1397 * xreducer implementation *
1398 ***************************/
1399
1412 template <class F, class CT, class X, class O>
1413 template <class Func, class CTA, class AX, class OX>
1414 inline xreducer<F, CT, X, O>::xreducer(Func&& func, CTA&& e, AX&& axes, OX&& options)
1415 : m_e(std::forward<CTA>(e))
1416 , m_reduce(xt::get<0>(func))
1417 , m_init(xt::get<1>(func))
1418 , m_merge(xt::get<2>(func))
1419 , m_axes(std::forward<AX>(axes))
1420 , m_shape(xtl::make_sequence<inner_shape_type>(
1421 typename O::keep_dims() ? m_e.dimension() : m_e.dimension() - m_axes.size(),
1422 0
1423 ))
1424 , m_dim_mapping(xtl::make_sequence<dim_mapping_type>(
1425 typename O::keep_dims() ? m_e.dimension() : m_e.dimension() - m_axes.size(),
1426 0
1427 ))
1428 , m_options(std::forward<OX>(options))
1429 {
1430 // std::less is used, because as the standard says (24.4.5):
1431 // A sequence is sorted with respect to a comparator comp if for any iterator i pointing to the
1432 // sequence and any non-negative integer n such that i + n is a valid iterator pointing to an element
1433 // of the sequence, comp(*(i + n), *i) == false. Therefore less is required to detect duplicates.
1434 if (!std::is_sorted(m_axes.cbegin(), m_axes.cend(), std::less<>()))
1435 {
1436 XTENSOR_THROW(std::runtime_error, "Reducing axes should be sorted.");
1437 }
1438 if (std::adjacent_find(m_axes.cbegin(), m_axes.cend()) != m_axes.cend())
1439 {
1440 XTENSOR_THROW(std::runtime_error, "Reducing axes should not contain duplicates.");
1441 }
1442 if (m_axes.size() != 0 && m_axes[m_axes.size() - 1] > m_e.dimension() - 1)
1443 {
1444 XTENSOR_THROW(
1445 std::runtime_error,
1446 "Axis " + std::to_string(m_axes[m_axes.size() - 1]) + " out of bounds for reduction."
1447 );
1448 }
1449
1450 if (!typename O::keep_dims())
1451 {
1452 detail::shape_and_mapping_computation(
1453 m_shape,
1454 m_e,
1455 m_axes,
1456 m_dim_mapping,
1457 detail::is_fixed<shape_type>{}
1458 );
1459 }
1460 else
1461 {
1462 detail::shape_and_mapping_computation_keep_dim(
1463 m_shape,
1464 m_e,
1465 m_axes,
1466 m_dim_mapping,
1467 detail::is_fixed<shape_type>{}
1468 );
1469 }
1470 }
1471
1473
1481 template <class F, class CT, class X, class O>
1482 inline auto xreducer<F, CT, X, O>::shape() const noexcept -> const inner_shape_type&
1483 {
1484 return m_shape;
1485 }
1486
1490 template <class F, class CT, class X, class O>
1492 {
1493 return static_layout;
1494 }
1495
1496 template <class F, class CT, class X, class O>
1498 {
1499 return false;
1500 }
1501
1503
1514 template <class F, class CT, class X, class O>
1515 template <class... Args>
1516 inline auto xreducer<F, CT, X, O>::operator()(Args... args) const -> const_reference
1517 {
1518 XTENSOR_TRY(check_index(shape(), args...));
1519 XTENSOR_CHECK_DIMENSION(shape(), args...);
1520 std::array<std::size_t, sizeof...(Args)> arg_array = {{static_cast<std::size_t>(args)...}};
1521 return element(arg_array.cbegin(), arg_array.cend());
1522 }
1523
1543 template <class F, class CT, class X, class O>
1544 template <class... Args>
1545 inline auto xreducer<F, CT, X, O>::unchecked(Args... args) const -> const_reference
1546 {
1547 std::array<std::size_t, sizeof...(Args)> arg_array = {{static_cast<std::size_t>(args)...}};
1548 return element(arg_array.cbegin(), arg_array.cend());
1549 }
1550
1558 template <class F, class CT, class X, class O>
1559 template <class It>
1560 inline auto xreducer<F, CT, X, O>::element(It first, It last) const -> const_reference
1561 {
1562 XTENSOR_TRY(check_element_index(shape(), first, last));
1563 auto stepper = const_stepper(*this, 0);
1564 if (first != last)
1565 {
1566 size_type dim = 0;
1567 // drop left most elements
1568 auto size = std::ptrdiff_t(this->dimension()) - std::distance(first, last);
1569 auto begin = first - size;
1570 while (begin != last)
1571 {
1572 if (begin < first)
1573 {
1574 stepper.step(dim++, std::size_t(0));
1575 begin++;
1576 }
1577 else
1578 {
1579 stepper.step(dim++, std::size_t(*begin++));
1580 }
1581 }
1582 }
1583 return *stepper;
1584 }
1585
1589 template <class F, class CT, class X, class O>
1590 inline auto xreducer<F, CT, X, O>::expression() const noexcept -> const xexpression_type&
1591 {
1592 return m_e;
1593 }
1594
1596
1607 template <class F, class CT, class X, class O>
1608 template <class S>
1609 inline bool xreducer<F, CT, X, O>::broadcast_shape(S& shape, bool) const
1610 {
1611 return xt::broadcast_shape(m_shape, shape);
1612 }
1613
1619 template <class F, class CT, class X, class O>
1620 template <class S>
1621 inline bool xreducer<F, CT, X, O>::has_linear_assign(const S& /*strides*/) const noexcept
1622 {
1623 return false;
1624 }
1625
1627
1628 template <class F, class CT, class X, class O>
1629 template <class S>
1630 inline auto xreducer<F, CT, X, O>::stepper_begin(const S& shape) const noexcept -> const_stepper
1631 {
1632 size_type offset = shape.size() - this->dimension();
1633 return const_stepper(*this, offset);
1634 }
1635
1636 template <class F, class CT, class X, class O>
1637 template <class S>
1638 inline auto xreducer<F, CT, X, O>::stepper_end(const S& shape, layout_type l) const noexcept
1639 -> const_stepper
1640 {
1641 size_type offset = shape.size() - this->dimension();
1642 return const_stepper(*this, offset, true, l);
1643 }
1644
1645 template <class F, class CT, class X, class O>
1646 template <class E>
1647 inline auto xreducer<F, CT, X, O>::build_reducer(E&& e) const -> rebind_t<E>
1648 {
1649 return rebind_t<E>(
1650 std::make_tuple(m_reduce, m_init, m_merge),
1651 std::forward<E>(e),
1652 axes_type(m_axes),
1653 m_options
1654 );
1655 }
1656
1657 template <class F, class CT, class X, class O>
1658 template <class E, class Func, class Opts>
1659 inline auto xreducer<F, CT, X, O>::build_reducer(E&& e, Func&& func, Opts&& opts) const
1660 -> rebind_t<E, Func, Opts>
1661 {
1662 return rebind_t<E, Func, Opts>(
1663 std::forward<Func>(func),
1664 std::forward<E>(e),
1665 axes_type(m_axes),
1666 std::forward<Opts>(opts)
1667 );
1668 }
1669
1670 /***********************************
1671 * xreducer_stepper implementation *
1672 ***********************************/
1673
1674 template <class F, class CT, class X, class O>
1675 inline xreducer_stepper<F, CT, X, O>::xreducer_stepper(
1676 const xreducer_type& red,
1677 size_type offset,
1678 bool end,
1679 layout_type l
1680 )
1681 : m_reducer(&red)
1682 , m_offset(offset)
1683 , m_stepper(get_substepper_begin())
1684 {
1685 if (end)
1686 {
1687 to_end(l);
1688 }
1689 }
1690
1691 template <class F, class CT, class X, class O>
1692 inline auto xreducer_stepper<F, CT, X, O>::operator*() const -> reference
1693 {
1694 reference r = aggregate(0);
1695 return r;
1696 }
1697
1698 template <class F, class CT, class X, class O>
1699 inline void xreducer_stepper<F, CT, X, O>::step(size_type dim)
1700 {
1701 if (dim >= m_offset)
1702 {
1703 m_stepper.step(get_dim(dim - m_offset));
1704 }
1705 }
1706
1707 template <class F, class CT, class X, class O>
1708 inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim)
1709 {
1710 if (dim >= m_offset)
1711 {
1712 m_stepper.step_back(get_dim(dim - m_offset));
1713 }
1714 }
1715
1716 template <class F, class CT, class X, class O>
1717 inline void xreducer_stepper<F, CT, X, O>::step(size_type dim, size_type n)
1718 {
1719 if (dim >= m_offset)
1720 {
1721 m_stepper.step(get_dim(dim - m_offset), n);
1722 }
1723 }
1724
1725 template <class F, class CT, class X, class O>
1726 inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim, size_type n)
1727 {
1728 if (dim >= m_offset)
1729 {
1730 m_stepper.step_back(get_dim(dim - m_offset), n);
1731 }
1732 }
1733
1734 template <class F, class CT, class X, class O>
1735 inline void xreducer_stepper<F, CT, X, O>::reset(size_type dim)
1736 {
1737 if (dim >= m_offset)
1738 {
1739 // Because the reducer uses `reset` to reset the non-reducing axes,
1740 // we need to prevent that here for the KD case where.
1741 if (typename O::keep_dims()
1742 && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
1743 {
1744 // If keep dim activated, and dim is in the axes, do nothing!
1745 return;
1746 }
1747 m_stepper.reset(get_dim(dim - m_offset));
1748 }
1749 }
1750
1751 template <class F, class CT, class X, class O>
1752 inline void xreducer_stepper<F, CT, X, O>::reset_back(size_type dim)
1753 {
1754 if (dim >= m_offset)
1755 {
1756 // Note that for *not* KD this is not going to do anything
1757 if (typename O::keep_dims()
1758 && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
1759 {
1760 // If keep dim activated, and dim is in the axes, do nothing!
1761 return;
1762 }
1763 m_stepper.reset_back(get_dim(dim - m_offset));
1764 }
1765 }
1766
1767 template <class F, class CT, class X, class O>
1768 inline void xreducer_stepper<F, CT, X, O>::to_begin()
1769 {
1770 m_stepper.to_begin();
1771 }
1772
1773 template <class F, class CT, class X, class O>
1774 inline void xreducer_stepper<F, CT, X, O>::to_end(layout_type l)
1775 {
1776 m_stepper.to_end(l);
1777 }
1778
1779 template <class F, class CT, class X, class O>
1780 inline auto xreducer_stepper<F, CT, X, O>::initial_value() const -> reference
1781 {
1782 return O::has_initial_value ? m_reducer->m_options.initial_value
1783 : static_cast<reference>(m_reducer->m_init());
1784 }
1785
1786 template <class F, class CT, class X, class O>
1787 inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim) const -> reference
1788 {
1789 reference res;
1790 if (m_reducer->m_e.size() == size_type(0))
1791 {
1792 res = initial_value();
1793 }
1794 else if (m_reducer->m_e.shape().empty() || m_reducer->m_axes.size() == 0)
1795 {
1796 res = m_reducer->m_reduce(initial_value(), *m_stepper);
1797 }
1798 else
1799 {
1800 res = aggregate_impl(dim, typename O::keep_dims());
1801 if (O::has_initial_value && dim == 0)
1802 {
1803 res = m_reducer->m_merge(m_reducer->m_options.initial_value, res);
1804 }
1805 }
1806 return res;
1807 }
1808
1809 template <class F, class CT, class X, class O>
1810 inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::false_type) const -> reference
1811 {
1812 // reference can be std::array, hence the {} initializer
1813 reference res = {};
1814 size_type index = axis(dim);
1815 size_type size = shape(index);
1816 if (dim != m_reducer->m_axes.size() - 1)
1817 {
1818 res = aggregate_impl(dim + 1, typename O::keep_dims());
1819 for (size_type i = 1; i != size; ++i)
1820 {
1821 m_stepper.step(index);
1822 res = m_reducer->m_merge(res, aggregate_impl(dim + 1, typename O::keep_dims()));
1823 }
1824 }
1825 else
1826 {
1827 res = m_reducer->m_reduce(static_cast<reference>(m_reducer->m_init()), *m_stepper);
1828 for (size_type i = 1; i != size; ++i)
1829 {
1830 m_stepper.step(index);
1831 res = m_reducer->m_reduce(res, *m_stepper);
1832 }
1833 }
1834 m_stepper.reset(index);
1835 return res;
1836 }
1837
1838 template <class F, class CT, class X, class O>
1839 inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::true_type) const -> reference
1840 {
1841 // reference can be std::array, hence the {} initializer
1842 reference res = {};
1843 auto ax_it = std::find(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim);
1844 if (ax_it != m_reducer->m_axes.end())
1845 {
1846 size_type index = dim;
1847 size_type size = m_reducer->m_e.shape()[index];
1848 if (ax_it != m_reducer->m_axes.end() - 1 && size != 0)
1849 {
1850 res = aggregate_impl(dim + 1, typename O::keep_dims());
1851 for (size_type i = 1; i != size; ++i)
1852 {
1853 m_stepper.step(index);
1854 res = m_reducer->m_merge(res, aggregate_impl(dim + 1, typename O::keep_dims()));
1855 }
1856 }
1857 else
1858 {
1859 res = m_reducer->m_reduce(static_cast<reference>(m_reducer->m_init()), *m_stepper);
1860 for (size_type i = 1; i != size; ++i)
1861 {
1862 m_stepper.step(index);
1863 res = m_reducer->m_reduce(res, *m_stepper);
1864 }
1865 }
1866 m_stepper.reset(index);
1867 }
1868 else
1869 {
1870 if (dim < m_reducer->m_e.dimension())
1871 {
1872 res = aggregate_impl(dim + 1, typename O::keep_dims());
1873 }
1874 }
1875 return res;
1876 }
1877
1878 template <class F, class CT, class X, class O>
1879 inline auto xreducer_stepper<F, CT, X, O>::get_substepper_begin() const -> substepper_type
1880 {
1881 return m_reducer->m_e.stepper_begin(m_reducer->m_e.shape());
1882 }
1883
1884 template <class F, class CT, class X, class O>
1885 inline auto xreducer_stepper<F, CT, X, O>::get_dim(size_type dim) const noexcept -> size_type
1886 {
1887 return m_reducer->m_dim_mapping[dim];
1888 }
1889
1890 template <class F, class CT, class X, class O>
1891 inline auto xreducer_stepper<F, CT, X, O>::shape(size_type i) const noexcept -> size_type
1892 {
1893 return m_reducer->m_e.shape()[i];
1894 }
1895
1896 template <class F, class CT, class X, class O>
1897 inline auto xreducer_stepper<F, CT, X, O>::axis(size_type i) const noexcept -> size_type
1898 {
1899 return m_reducer->m_axes[i];
1900 }
1901}
1902
1903#endif
Fixed shape implementation for compile time defined arrays.
Base class for implementation of common expression access methods.
Base class for multidimensional iterable constant expressions.
Definition xiterable.hpp:37
Reducing function operating over specified axes.
Definition xreducer.hpp:817
bool broadcast_shape(S &shape, bool reuse_cache=false) const
Broadcast the shape of the reducer to the specified parameter.
const xexpression_type & expression() const noexcept
Returns a constant reference to the underlying expression of the reducer.
bool has_linear_assign(const S &strides) const noexcept
Checks whether the xreducer can be linearly assigned to an expression with the specified strides.
const inner_shape_type & shape() const noexcept
Returns the shape of the expression.
xreducer(Func &&func, CTA &&e, AX &&axes, OX &&options)
Constructs an xreducer expression applying the specified function to the given expression over the gi...
layout_type layout() const noexcept
Returns the layout of the expression.
auto operator|(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::bitwise_or, E1, E2 >
Bitwise or.
auto diff(const xexpression< T > &a, std::size_t n=1, std::ptrdiff_t axis=-1)
Calculate the n-th discrete difference along the given axis.
Definition xmath.hpp:2911
auto eval(T &&t) -> std::enable_if_t< detail::is_container< std::decay_t< T > >::value, T && >
Force evaluation of xexpression.
Definition xeval.hpp:46
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Definition xstrides.hpp:248
standard mathematical functions for xexpressions
auto reduce(F &&f, E &&e, X &&axes, EVS &&options=EVS())
Returns an xexpression applying the specified reducing function to an expression over the given axes.
layout_type
Definition xlayout.hpp:24