xtensor
 
xreducer.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_REDUCER_HPP
11#define XTENSOR_REDUCER_HPP
12
13#include <algorithm>
14#include <cstddef>
15#include <initializer_list>
16#include <iterator>
17#include <stdexcept>
18#include <tuple>
19#include <type_traits>
20#include <utility>
21
22#include <xtl/xfunctional.hpp>
23#include <xtl/xsequence.hpp>
24
25#include "../core/xaccessible.hpp"
26#include "../core/xeval.hpp"
27#include "../core/xexpression.hpp"
28#include "../core/xiterable.hpp"
29#include "../core/xtensor_config.hpp"
30#include "../generators/xbuilder.hpp"
31#include "../generators/xgenerator.hpp"
32#include "../utils/xutils.hpp"
33
34namespace xt
35{
36 template <template <class...> class A, class... AX, class X, XTL_REQUIRES(is_evaluation_strategy<AX>..., is_evaluation_strategy<X>)>
37 auto operator|(const A<AX...>& args, const A<X>& rhs)
38 {
39 return std::tuple_cat(args, rhs);
40 }
41
42 struct keep_dims_type : xt::detail::option_base
43 {
44 };
45
46 constexpr auto keep_dims = std::tuple<keep_dims_type>{};
47
48 template <class T = double>
49 struct xinitial : xt::detail::option_base
50 {
51 constexpr xinitial(T val)
52 : m_val(val)
53 {
54 }
55
56 constexpr T value() const
57 {
58 return m_val;
59 }
60
61 T m_val;
62 };
63
64 template <class T>
65 constexpr auto initial(T val)
66 {
67 return std::make_tuple(xinitial<T>(val));
68 }
69
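// [Editorial sketch, not part of the original header] keep_dims and initial()
// produce plain std::tuple option values that the reduce() overloads defined
// further down accept as their trailing argument. Assuming an xarray header is
// included, usage looks like:
//
//     xt::xarray<double> a = {{1., 2., 3.}, {4., 5., 6.}};
//     auto k = xt::reduce(std::plus<double>(), a, {1}, xt::keep_dims);     // shape {2, 1}
//     auto i = xt::reduce(std::plus<double>(), a, {1}, xt::initial(10.));  // values 16, 25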
70 template <std::ptrdiff_t I, class T, class Tuple>
71 struct tuple_idx_of_impl;
72
73 template <std::ptrdiff_t I, class T>
74 struct tuple_idx_of_impl<I, T, std::tuple<>>
75 {
76 static constexpr std::ptrdiff_t value = -1;
77 };
78
79 template <std::ptrdiff_t I, class T, class... Types>
80 struct tuple_idx_of_impl<I, T, std::tuple<T, Types...>>
81 {
82 static constexpr std::ptrdiff_t value = I;
83 };
84
85 template <std::ptrdiff_t I, class T, class U, class... Types>
86 struct tuple_idx_of_impl<I, T, std::tuple<U, Types...>>
87 {
88 static constexpr std::ptrdiff_t value = tuple_idx_of_impl<I + 1, T, std::tuple<Types...>>::value;
89 };
90
91 template <class S, class... X>
92 struct decay_all;
93
94 template <template <class...> class S, class... X>
95 struct decay_all<S<X...>>
96 {
97 using type = S<std::decay_t<X>...>;
98 };
99
100 template <class T, class Tuple>
101 struct tuple_idx_of
102 {
103 static constexpr std::ptrdiff_t
104 value = tuple_idx_of_impl<0, std::decay_t<T>, typename decay_all<Tuple>::type>::value;
105 };
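// [Editorial sketch] tuple_idx_of yields the zero-based position of a type in
// a tuple, or -1 when the type is absent, e.g.:
//
//     static_assert(xt::tuple_idx_of<int, std::tuple<double, int>>::value == 1, "");
//     static_assert(xt::tuple_idx_of<char, std::tuple<double, int>>::value == -1, "");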
106
107 template <class R, class T>
108 struct reducer_options
109 {
110 template <class X>
111 struct initial_tester : std::false_type
112 {
113 };
114
115 template <class X>
116 struct initial_tester<xinitial<X>> : std::true_type
117 {
118 };
119
120 // Workaround for Apple because tuple_cat is buggy!
121 template <class X>
122 struct initial_tester<const xinitial<X>> : std::true_type
123 {
124 };
125
126 using d_t = std::decay_t<T>;
127
128 static constexpr std::size_t initial_val_idx = xtl::mpl::find_if<initial_tester, d_t>::value;
129 reducer_options() = default;
130
131 reducer_options(const T& tpl)
132 {
133 if constexpr (initial_val_idx != std::tuple_size<T>::value)
134 {
135 initial_value = std::get < initial_val_idx != std::tuple_size<T>::value ? initial_val_idx
136 : 0 > (tpl).value();
137 }
138 }
139
140 using evaluation_strategy = std::conditional_t<
141 tuple_idx_of<xt::evaluation_strategy::immediate_type, d_t>::value != -1,
142 xt::evaluation_strategy::immediate_type,
143 xt::evaluation_strategy::lazy_type>;
144
145 using keep_dims = std::
146 conditional_t<tuple_idx_of<xt::keep_dims_type, d_t>::value != -1, std::true_type, std::false_type>;
147
148 static constexpr bool has_initial_value = initial_val_idx != std::tuple_size<d_t>::value;
149
150 R initial_value;
151
152 template <class NR>
153 using rebind_t = reducer_options<NR, T>;
154
155 template <class NR>
156 auto rebind(NR initial, const reducer_options<R, T>&) const
157 {
158 reducer_options<NR, T> res;
159 res.initial_value = initial;
160 return res;
161 }
162 };
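// [Editorial sketch] reducer_options inspects the option tuple at compile time;
// for a tuple holding keep_dims and an initial value one would expect:
//
//     using opts_t = std::tuple<xt::keep_dims_type, xt::xinitial<double>>;
//     using ro_t = xt::reducer_options<double, opts_t>;
//     static_assert(ro_t::keep_dims::value, "");
//     static_assert(ro_t::has_initial_value, "");
//     // ro_t::evaluation_strategy falls back to the lazy strategy because no
//     // immediate_type is present in the tuple.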
163
164 template <class T>
165 struct is_reducer_options_impl : std::false_type
166 {
167 };
168
169 template <class... X>
170 struct is_reducer_options_impl<std::tuple<X...>> : std::true_type
171 {
172 };
173
174 template <class T>
175 struct is_reducer_options : is_reducer_options_impl<std::decay_t<T>>
176 {
177 };
178
179 /**********
180 * reduce *
181 **********/
182
183#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
184
185 template <class ST, class X, class KD = std::false_type>
186 struct xreducer_shape_type;
187
188 template <class S1, class S2>
189 struct fixed_xreducer_shape_type;
190
191 namespace detail
192 {
193 template <class O, class RS, class R, class E, class AX>
194 inline void shape_computation(
195 RS& result_shape,
196 R& result,
197 E& expr,
198 const AX& axes,
199 std::enable_if_t<!detail::is_fixed<RS>::value, int> = 0
200 )
201 {
202 if (typename O::keep_dims())
203 {
204 resize_container(result_shape, expr.dimension());
205 for (std::size_t i = 0; i < expr.dimension(); ++i)
206 {
207 if (std::find(axes.begin(), axes.end(), i) == axes.end())
208 {
209 // i not in axes!
210 result_shape[i] = expr.shape()[i];
211 }
212 else
213 {
214 result_shape[i] = 1;
215 }
216 }
217 }
218 else
219 {
220 resize_container(result_shape, expr.dimension() - axes.size());
221 for (std::size_t i = 0, idx = 0; i < expr.dimension(); ++i)
222 {
223 if (std::find(axes.begin(), axes.end(), i) == axes.end())
224 {
225 // i not in axes!
226 result_shape[idx] = expr.shape()[i];
227 ++idx;
228 }
229 }
230 }
231 result.resize(result_shape, expr.layout());
232 }
233
234 // skip shape computation if already done at compile time
235 template <class O, class RS, class R, class S, class AX>
236 inline void
237 shape_computation(RS&, R&, const S&, const AX&, std::enable_if_t<detail::is_fixed<RS>::value, int> = 0)
238 {
239 }
240 }
241
242 template <class F, class E, class R, XTL_REQUIRES(std::is_convertible<typename E::value_type, typename R::value_type>)>
243 inline void copy_to_reduced(F&, const E& e, R& result)
244 {
245 if (e.layout() == layout_type::row_major)
246 {
247 std::copy(
248 e.template cbegin<layout_type::row_major>(),
249 e.template cend<layout_type::row_major>(),
250 result.data()
251 );
252 }
253 else
254 {
255 std::copy(
256 e.template cbegin<layout_type::column_major>(),
257 e.template cend<layout_type::column_major>(),
258 result.data()
259 );
260 }
261 }
262
263 template <
264 class F,
265 class E,
266 class R,
267 XTL_REQUIRES(std::negation<std::is_convertible<typename E::value_type, typename R::value_type>>)>
268 inline void copy_to_reduced(F& f, const E& e, R& result)
269 {
270 if (e.layout() == layout_type::row_major)
271 {
272 std::transform(
273 e.template cbegin<layout_type::row_major>(),
274 e.template cend<layout_type::row_major>(),
275 result.data(),
276 f
277 );
278 }
279 else
280 {
281 std::transform(
282 e.template cbegin<layout_type::column_major>(),
283 e.template cend<layout_type::column_major>(),
284 result.data(),
285 f
286 );
287 }
288 }
289
290 template <class F, class E, class X, class O>
291 inline auto reduce_immediate(F&& f, E&& e, X&& axes, O&& raw_options)
292 {
293 using reduce_functor_type = typename std::decay_t<F>::reduce_functor_type;
294 using init_functor_type = typename std::decay_t<F>::init_functor_type;
295 using expr_value_type = typename std::decay_t<E>::value_type;
296 using result_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
297 std::declval<init_functor_type>()(),
298 std::declval<expr_value_type>()
299 ))>;
300
301 using options_t = reducer_options<result_type, std::decay_t<O>>;
302 options_t options(raw_options);
303
304 using shape_type = typename xreducer_shape_type<
305 typename std::decay_t<E>::shape_type,
306 std::decay_t<X>,
307 typename options_t::keep_dims>::type;
308 using result_container_type = typename detail::xtype_for_shape<
309 shape_type>::template type<result_type, std::decay_t<E>::static_layout>;
310 result_container_type result;
311
312 // retrieve functors from triple struct
313 auto reduce_fct = xt::get<0>(f);
314 auto init_fct = xt::get<1>(f);
315 auto merge_fct = xt::get<2>(f);
316
317 if (axes.size() == 0)
318 {
319 result.resize(e.shape(), e.layout());
320 auto cpf = [&reduce_fct, &init_fct](const auto& v)
321 {
322 return reduce_fct(static_cast<result_type>(init_fct()), v);
323 };
324 copy_to_reduced(cpf, e, result);
325 return result;
326 }
327
328 shape_type result_shape{};
329 dynamic_shape<std::size_t>
330 iter_shape = xtl::forward_sequence<dynamic_shape<std::size_t>, decltype(e.shape())>(e.shape());
331 dynamic_shape<std::size_t> iter_strides(e.dimension());
332
333 // std::less is used, because as the standard says (24.4.5):
334 // A sequence is sorted with respect to a comparator comp if for any iterator i pointing to the
335 // sequence and any non-negative integer n such that i + n is a valid iterator pointing to an element
336 // of the sequence, comp(*(i + n), *i) == false. Therefore less is required to detect duplicates.
337 if (!std::is_sorted(axes.cbegin(), axes.cend(), std::less<>()))
338 {
339 XTENSOR_THROW(std::runtime_error, "Reducing axes should be sorted.");
340 }
341 if (std::adjacent_find(axes.cbegin(), axes.cend()) != axes.cend())
342 {
343 XTENSOR_THROW(std::runtime_error, "Reducing axes should not contain duplicates.");
344 }
345 if (axes.size() != 0 && axes[axes.size() - 1] > e.dimension() - 1)
346 {
347 XTENSOR_THROW(
348 std::runtime_error,
349 "Axis " + std::to_string(axes[axes.size() - 1]) + " out of bounds for reduction."
350 );
351 }
352
353 detail::shape_computation<options_t>(result_shape, result, e, axes);
354
355 // Fast track for complete reduction
356 if (e.dimension() == axes.size())
357 {
358 result_type tmp = options_t::has_initial_value ? options.initial_value : init_fct();
359 result.data()[0] = std::accumulate(e.storage().begin(), e.storage().end(), tmp, reduce_fct);
360 return result;
361 }
362
363 std::size_t leading_ax = axes[(e.layout() == layout_type::row_major) ? axes.size() - 1 : 0];
364 auto strides_finder = e.strides().begin() + static_cast<std::ptrdiff_t>(leading_ax);
365 // The computed strides contain "0" where the shape is 1 -- therefore find the next non-zero number
366 std::size_t inner_stride = static_cast<std::size_t>(*strides_finder);
367 auto iter_bound = e.layout() == layout_type::row_major ? e.strides().begin() : (e.strides().end() - 1);
368 while (inner_stride == 0 && strides_finder != iter_bound)
369 {
370 (e.layout() == layout_type::row_major) ? --strides_finder : ++strides_finder;
371 inner_stride = static_cast<std::size_t>(*strides_finder);
372 }
373
374 if (inner_stride == 0)
375 {
376 auto cpf = [&reduce_fct, &init_fct](const auto& v)
377 {
378 return reduce_fct(static_cast<result_type>(init_fct()), v);
379 };
380 copy_to_reduced(cpf, e, result);
381 return result;
382 }
383
384 std::size_t inner_loop_size = static_cast<std::size_t>(inner_stride);
385 std::size_t outer_loop_size = e.shape()[leading_ax];
386
387 // The following code merges reduction axes "at the end" (or the beginning for col_major)
388 // together by increasing the size of the outer loop where appropriate
389 auto merge_loops = [&outer_loop_size, &e](auto it, auto end)
390 {
391 auto last_ax = *it;
392 ++it;
393 for (; it != end; ++it)
394 {
395 // note that we check is_sorted, so this condition is valid
396 if (std::abs(std::ptrdiff_t(*it) - std::ptrdiff_t(last_ax)) == 1)
397 {
398 last_ax = *it;
399 outer_loop_size *= e.shape()[last_ax];
400 }
401 }
402 return last_ax;
403 };
404
405 for (std::size_t i = 0, idx = 0; i < e.dimension(); ++i)
406 {
407 if (std::find(axes.begin(), axes.end(), i) == axes.end())
408 {
409 // i not in axes!
410 iter_strides[i] = static_cast<std::size_t>(result.strides(
411 )[typename options_t::keep_dims() ? i : idx]);
412 ++idx;
413 }
414 }
415
416 if (e.layout() == layout_type::row_major)
417 {
418 std::size_t last_ax = merge_loops(axes.rbegin(), axes.rend());
419
420 iter_shape.erase(iter_shape.begin() + std::ptrdiff_t(last_ax), iter_shape.end());
421 iter_strides.erase(iter_strides.begin() + std::ptrdiff_t(last_ax), iter_strides.end());
422 }
423 else if (e.layout() == layout_type::column_major)
424 {
425 // we got column_major here
426 std::size_t last_ax = merge_loops(axes.begin(), axes.end());
427
428 // erasing the front vs the back
429 iter_shape.erase(iter_shape.begin(), iter_shape.begin() + std::ptrdiff_t(last_ax + 1));
430 iter_strides.erase(iter_strides.begin(), iter_strides.begin() + std::ptrdiff_t(last_ax + 1));
431
432 // and reversing, to make it work with the same next_idx function
433 std::reverse(iter_shape.begin(), iter_shape.end());
434 std::reverse(iter_strides.begin(), iter_strides.end());
435 }
436 else
437 {
438 XTENSOR_THROW(std::runtime_error, "Layout not supported in immediate reduction.");
439 }
440
441 xindex temp_idx(iter_shape.size());
442 auto next_idx = [&iter_shape, &iter_strides, &temp_idx]()
443 {
444 std::size_t i = iter_shape.size();
445 for (; i > 0; --i)
446 {
447 if (std::ptrdiff_t(temp_idx[i - 1]) >= std::ptrdiff_t(iter_shape[i - 1]) - 1)
448 {
449 temp_idx[i - 1] = 0;
450 }
451 else
452 {
453 temp_idx[i - 1]++;
454 break;
455 }
456 }
457
458 return std::make_pair(
459 i == 0,
460 std::inner_product(temp_idx.begin(), temp_idx.end(), iter_strides.begin(), std::ptrdiff_t(0))
461 );
462 };
463
464 auto begin = e.data();
465 auto out = result.data();
466 auto out_begin = result.data();
467
468 std::ptrdiff_t next_stride = 0;
469
470 std::pair<bool, std::ptrdiff_t> idx_res(false, 0);
471
472 // Remark: eventually some modifications here to make conditions faster where merge + accumulate is
473 // the same function (e.g. check std::is_same<decltype(merge_fct), decltype(reduce_fct)>::value) ...
474
475 auto merge_border = out;
476 bool merge = false;
477
478 // TODO there could be some performance gain by removing merge checking
479 // when axes.size() == 1 and even next_idx could be removed for something simpler (next_stride
480 // always the same) best way to do this would be to create a function that takes (begin, out,
481 // outer_loop_size, inner_loop_size, next_idx_lambda)
482 // Decide if going about it row-wise or col-wise
483 if (inner_stride == 1)
484 {
485 while (idx_res.first != true)
486 {
487 // for unknown reasons it's much faster to use a temporary variable and
488 // std::accumulate here -- probably some cache behavior
489 result_type tmp = init_fct();
490 tmp = std::accumulate(begin, begin + outer_loop_size, tmp, reduce_fct);
491
492 // use merge function if necessary
493 *out = merge ? merge_fct(*out, tmp) : tmp;
494
495 begin += outer_loop_size;
496
497 idx_res = next_idx();
498 next_stride = idx_res.second;
499 out = out_begin + next_stride;
500
501 if (out > merge_border)
502 {
503 // looped over once
504 merge = false;
505 merge_border = out;
506 }
507 else
508 {
509 merge = true;
510 }
511 };
512 }
513 else
514 {
515 while (idx_res.first != true)
516 {
517 std::transform(
518 out,
519 out + inner_loop_size,
520 begin,
521 out,
522 [merge, &init_fct, &reduce_fct](auto&& v1, auto&& v2)
523 {
524 return merge ? reduce_fct(v1, v2) :
525 // cast because return type of identity function is not upcasted
526 reduce_fct(static_cast<result_type>(init_fct()), v2);
527 }
528 );
529
530 begin += inner_stride;
531 for (std::size_t i = 1; i < outer_loop_size; ++i)
532 {
533 std::transform(out, out + inner_loop_size, begin, out, reduce_fct);
534 begin += inner_stride;
535 }
536
537 idx_res = next_idx();
538 next_stride = idx_res.second;
539 out = out_begin + next_stride;
540
541 if (out > merge_border)
542 {
543 // looped over once
544 merge = false;
545 merge_border = out;
546 }
547 else
548 {
549 merge = true;
550 }
551 };
552 }
553 if (options_t::has_initial_value)
554 {
555 std::transform(
556 result.data(),
557 result.data() + result.size(),
558 result.data(),
559 [&merge_fct, &options](auto&& v)
560 {
561 return merge_fct(v, options.initial_value);
562 }
563 );
564 }
565 return result;
566 }
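// [Editorial sketch] reduce_immediate is the eager path that reduce() dispatches
// to for evaluation_strategy::immediate. Called directly it expects an
// xreducer_functors triple, an evaluated (container) expression, sorted unique
// axes, and a raw options tuple:
//
//     xt::xtensor<double, 2> t = {{1., 2.}, {3., 4.}};
//     std::array<std::size_t, 1> ax = {1};
//     auto r = xt::reduce_immediate(
//         xt::make_xreducer_functor(std::plus<double>()), t, ax, xt::keep_dims);
//     // r is a container with shape {2, 1} holding {{3.}, {7.}}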
567
568 /*********************
569 * xreducer functors *
570 *********************/
571
572 template <class T>
573 struct const_value
574 {
575 using value_type = T;
576
577 constexpr const_value() = default;
578
579 constexpr const_value(T t)
580 : m_value(t)
581 {
582 }
583
584 constexpr T operator()() const
585 {
586 return m_value;
587 }
588
589 template <class NT>
590 using rebind_t = const_value<NT>;
591
592 template <class NT>
593 const_value<NT> rebind() const;
594
595 T m_value;
596 };
597
598 namespace detail
599 {
600 template <class T, bool B>
601 struct evaluated_value_type
602 {
603 using type = T;
604 };
605
606 template <class T>
607 struct evaluated_value_type<T, true>
608 {
609 using type = typename std::decay_t<decltype(xt::eval(std::declval<T>()))>;
610 };
611
612 template <class T, bool B>
613 using evaluated_value_type_t = typename evaluated_value_type<T, B>::type;
614 }
615
616 template <class REDUCE_FUNC, class INIT_FUNC = const_value<long int>, class MERGE_FUNC = REDUCE_FUNC>
617 struct xreducer_functors : public std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>
618 {
619 using self_type = xreducer_functors<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
620 using base_type = std::tuple<REDUCE_FUNC, INIT_FUNC, MERGE_FUNC>;
621 using reduce_functor_type = REDUCE_FUNC;
622 using init_functor_type = INIT_FUNC;
623 using merge_functor_type = MERGE_FUNC;
624 using init_value_type = typename init_functor_type::value_type;
625
626 xreducer_functors()
627 : base_type()
628 {
629 }
630
631 template <class RF>
632 xreducer_functors(RF&& reduce_func)
633 : base_type(std::forward<RF>(reduce_func), INIT_FUNC(), reduce_func)
634 {
635 }
636
637 template <class RF, class IF>
638 xreducer_functors(RF&& reduce_func, IF&& init_func)
639 : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), reduce_func)
640 {
641 }
642
643 template <class RF, class IF, class MF>
644 xreducer_functors(RF&& reduce_func, IF&& init_func, MF&& merge_func)
645 : base_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func), std::forward<MF>(merge_func))
646 {
647 }
648
649 reduce_functor_type get_reduce() const
650 {
651 return std::get<0>(upcast());
652 }
653
654 init_functor_type get_init() const
655 {
656 return std::get<1>(upcast());
657 }
658
659 merge_functor_type get_merge() const
660 {
661 return std::get<2>(upcast());
662 }
663
664 template <class NT>
665 using rebind_t = xreducer_functors<REDUCE_FUNC, const_value<NT>, MERGE_FUNC>;
666
667 template <class NT>
668 rebind_t<NT> rebind()
669 {
670 return make_xreducer_functor(get_reduce(), get_init().template rebind<NT>(), get_merge());
671 }
672
673 private:
674
675 // Workaround for clang-cl
676 const base_type& upcast() const
677 {
678 return static_cast<const base_type&>(*this);
679 }
680 };
681
682 template <class RF>
683 auto make_xreducer_functor(RF&& reduce_func)
684 {
686 return reducer_type(std::forward<RF>(reduce_func));
687 }
688
689 template <class RF, class IF>
690 auto make_xreducer_functor(RF&& reduce_func, IF&& init_func)
691 {
692 using reducer_type = xreducer_functors<std::remove_reference_t<RF>, std::remove_reference_t<IF>>;
693 return reducer_type(std::forward<RF>(reduce_func), std::forward<IF>(init_func));
694 }
695
696 template <class RF, class IF, class MF>
697 auto make_xreducer_functor(RF&& reduce_func, IF&& init_func, MF&& merge_func)
698 {
699 using reducer_type = xreducer_functors<
700 std::remove_reference_t<RF>,
701 std::remove_reference_t<IF>,
702 std::remove_reference_t<MF>>;
703 return reducer_type(
704 std::forward<RF>(reduce_func),
705 std::forward<IF>(init_func),
706 std::forward<MF>(merge_func)
707 );
708 }
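// [Editorial sketch] the overloads above bundle (reduce, init, merge) functors;
// a product reduction seeded with 1 could be built as:
//
//     auto prod_f = xt::make_xreducer_functor(std::multiplies<double>(),
//                                             xt::const_value<double>(1.0));
//     // prod_f can then be passed as the first argument of xt::reduce
//     // or xt::reduce_immediate.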
709
710 /**********************
711 * xreducer extension *
712 **********************/
713
714 namespace extension
715 {
716 template <class Tag, class F, class CT, class X, class O>
717 struct xreducer_base_impl;
718
719 template <class F, class CT, class X, class O>
720 struct xreducer_base_impl<xtensor_expression_tag, F, CT, X, O>
721 {
722 using type = xtensor_empty_base;
723 };
724
725 template <class F, class CT, class X, class O>
726 struct xreducer_base : xreducer_base_impl<xexpression_tag_t<CT>, F, CT, X, O>
727 {
728 };
729
730 template <class F, class CT, class X, class O>
731 using xreducer_base_t = typename xreducer_base<F, CT, X, O>::type;
732 }
733
734 /************
735 * xreducer *
736 ************/
737
738 template <class F, class CT, class X, class O>
739 class xreducer;
740
741 template <class F, class CT, class X, class O>
742 class xreducer_stepper;
743
744 template <class F, class CT, class X, class O>
745 struct xiterable_inner_types<xreducer<F, CT, X, O>>
746 {
747 using xexpression_type = std::decay_t<CT>;
748 using inner_shape_type = typename xreducer_shape_type<
749 typename xexpression_type::shape_type,
750 std::decay_t<X>,
751 typename O::keep_dims>::type;
752 using const_stepper = xreducer_stepper<F, CT, X, O>;
753 using stepper = const_stepper;
754 };
755
756 template <class F, class CT, class X, class O>
757 struct xcontainer_inner_types<xreducer<F, CT, X, O>>
758 {
759 using xexpression_type = std::decay_t<CT>;
760 using reduce_functor_type = typename std::decay_t<F>::reduce_functor_type;
761 using init_functor_type = typename std::decay_t<F>::init_functor_type;
762 using merge_functor_type = typename std::decay_t<F>::merge_functor_type;
763 using substepper_type = typename xexpression_type::const_stepper;
764 using raw_value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
765 std::declval<init_functor_type>()(),
766 *std::declval<substepper_type>()
767 ))>;
768 using value_type = typename detail::evaluated_value_type_t<raw_value_type, is_xexpression<raw_value_type>::value>;
769
770 using reference = value_type;
771 using const_reference = value_type;
772 using size_type = typename xexpression_type::size_type;
773 };
774
775 template <class T>
776 struct select_dim_mapping_type
777 {
778 using type = T;
779 };
780
781 template <std::size_t... I>
782 struct select_dim_mapping_type<fixed_shape<I...>>
783 {
784 using type = std::array<std::size_t, sizeof...(I)>;
785 };
786
804 template <class F, class CT, class X, class O>
805 class xreducer : public xsharable_expression<xreducer<F, CT, X, O>>,
806 public xconst_iterable<xreducer<F, CT, X, O>>,
807 public xaccessible<xreducer<F, CT, X, O>>,
808 public extension::xreducer_base_t<F, CT, X, O>
809 {
810 public:
811
812 using self_type = xreducer<F, CT, X, O>;
813 using inner_types = xcontainer_inner_types<self_type>;
814
815 using reduce_functor_type = typename inner_types::reduce_functor_type;
816 using init_functor_type = typename inner_types::init_functor_type;
817 using merge_functor_type = typename inner_types::merge_functor_type;
818 using xreducer_functors_type = xreducer_functors<reduce_functor_type, init_functor_type, merge_functor_type>;
819
820 using xexpression_type = typename inner_types::xexpression_type;
821 using axes_type = X;
822
823 using extension_base = extension::xreducer_base_t<F, CT, X, O>;
824 using expression_tag = typename extension_base::expression_tag;
825
826 using substepper_type = typename inner_types::substepper_type;
827 using value_type = typename inner_types::value_type;
828 using reference = typename inner_types::reference;
829 using const_reference = typename inner_types::const_reference;
830 using pointer = value_type*;
831 using const_pointer = const value_type*;
832
833 using size_type = typename inner_types::size_type;
834 using difference_type = typename xexpression_type::difference_type;
835
836 using iterable_base = xconst_iterable<self_type>;
837 using inner_shape_type = typename iterable_base::inner_shape_type;
838 using shape_type = inner_shape_type;
839
840 using dim_mapping_type = typename select_dim_mapping_type<inner_shape_type>::type;
841
842 using stepper = typename iterable_base::stepper;
843 using const_stepper = typename iterable_base::const_stepper;
844 using bool_load_type = typename xexpression_type::bool_load_type;
845
846 static constexpr layout_type static_layout = layout_type::dynamic;
847 static constexpr bool contiguous_layout = false;
848
849 template <class Func, class CTA, class AX, class OX>
850 xreducer(Func&& func, CTA&& e, AX&& axes, OX&& options);
851
852 const inner_shape_type& shape() const noexcept;
853 layout_type layout() const noexcept;
854 bool is_contiguous() const noexcept;
855
856 template <class... Args>
857 const_reference operator()(Args... args) const;
858 template <class... Args>
859 const_reference unchecked(Args... args) const;
860
861 template <class It>
862 const_reference element(It first, It last) const;
863
864 const xexpression_type& expression() const noexcept;
865
866 template <class S>
867 bool broadcast_shape(S& shape, bool reuse_cache = false) const;
868
869 template <class S>
870 bool has_linear_assign(const S& strides) const noexcept;
871
872 template <class S>
873 const_stepper stepper_begin(const S& shape) const noexcept;
874 template <class S>
875 const_stepper stepper_end(const S& shape, layout_type) const noexcept;
876
877 template <class E, class Func = F, class Opts = O>
878 using rebind_t = xreducer<Func, E, X, Opts>;
879
880 template <class E>
881 rebind_t<E> build_reducer(E&& e) const;
882
883 template <class E, class Func, class Opts>
884 rebind_t<E, Func, Opts> build_reducer(E&& e, Func&& func, Opts&& opts) const;
885
886 xreducer_functors_type functors() const
887 {
888 return xreducer_functors_type(m_reduce, m_init, m_merge); // TODO: understand why
889 // make_xreducer_functor is throwing an
890 // error
891 }
892
893 const O& options() const
894 {
895 return m_options;
896 }
897
898 private:
899
900 CT m_e;
901 reduce_functor_type m_reduce;
902 init_functor_type m_init;
903 merge_functor_type m_merge;
904 axes_type m_axes;
905 inner_shape_type m_shape;
906 dim_mapping_type m_dim_mapping;
907 O m_options;
908
909 friend class xreducer_stepper<F, CT, X, O>;
910 };
911
912 /*************************
913 * reduce implementation *
914 *************************/
915
916 namespace detail
917 {
918 template <class F, class E, class X, class O>
919 inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::lazy_type, O&& options)
920 {
921 decltype(auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
922
923 using reduce_functor_type = typename std::decay_t<F>::reduce_functor_type;
924 using init_functor_type = typename std::decay_t<F>::init_functor_type;
925 using value_type = std::decay_t<decltype(std::declval<reduce_functor_type>()(
926 std::declval<init_functor_type>()(),
927 *std::declval<typename std::decay_t<E>::const_stepper>()
928 ))>;
929 using evaluated_value_type = evaluated_value_type_t<value_type, is_xexpression<value_type>::value>;
930
931 using reducer_type = xreducer<
932 F,
933 const_xclosure_t<E>,
934 xtl::const_closure_type_t<decltype(normalized_axes)>,
935 reducer_options<evaluated_value_type, std::decay_t<O>>>;
936 return reducer_type(
937 std::forward<F>(f),
938 std::forward<E>(e),
939 std::forward<decltype(normalized_axes)>(normalized_axes),
940 std::forward<O>(options)
941 );
942 }
943
944 template <class F, class E, class X, class O>
945 inline auto reduce_impl(F&& f, E&& e, X&& axes, evaluation_strategy::immediate_type, O&& options)
946 {
947 decltype(auto) normalized_axes = normalize_axis(e, std::forward<X>(axes));
948 return reduce_immediate(
949 std::forward<F>(f),
950 eval(std::forward<E>(e)),
951 std::forward<decltype(normalized_axes)>(normalized_axes),
952 std::forward<O>(options)
953 );
954 }
955 }
956
957#define DEFAULT_STRATEGY_REDUCERS std::tuple<evaluation_strategy::lazy_type>
958
959 namespace detail
960 {
961 template <class T>
962 struct is_xreducer_functors_impl : std::false_type
963 {
964 };
965
966 template <class RF, class IF, class MF>
967 struct is_xreducer_functors_impl<xreducer_functors<RF, IF, MF>> : std::true_type
968 {
969 };
970
971 template <class T>
972 using is_xreducer_functors = is_xreducer_functors_impl<std::decay_t<T>>;
973 }
974
987
988 template <
989 class F,
990 class E,
991 class X,
992 class EVS = DEFAULT_STRATEGY_REDUCERS,
993 XTL_REQUIRES(std::negation<is_reducer_options<X>>, detail::is_xreducer_functors<F>)>
994 inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
995 {
996 return detail::reduce_impl(
997 std::forward<F>(f),
998 std::forward<E>(e),
999 std::forward<X>(axes),
1000 typename reducer_options<int, EVS>::evaluation_strategy{},
1001 std::forward<EVS>(options)
1002 );
1003 }
1004
1005 template <
1006 class F,
1007 class E,
1008 class X,
1009 class EVS = DEFAULT_STRATEGY_REDUCERS,
1010 XTL_REQUIRES(std::negation<is_reducer_options<X>>, std::negation<detail::is_xreducer_functors<F>>)>
1011 inline auto reduce(F&& f, E&& e, X&& axes, EVS&& options = EVS())
1012 {
1013 return reduce(
1014 make_xreducer_functor(std::forward<F>(f)),
1015 std::forward<E>(e),
1016 std::forward<X>(axes),
1017 std::forward<EVS>(options)
1018 );
1019 }
1020
1021 template <
1022 class F,
1023 class E,
1024 class EVS = DEFAULT_STRATEGY_REDUCERS,
1025 XTL_REQUIRES(is_reducer_options<EVS>, detail::is_xreducer_functors<F>)>
1026 inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
1027 {
1028 xindex_type_t<typename std::decay_t<E>::shape_type> ar;
1029 resize_container(ar, e.dimension());
1030 std::iota(ar.begin(), ar.end(), 0);
1031 return detail::reduce_impl(
1032 std::forward<F>(f),
1033 std::forward<E>(e),
1034 std::move(ar),
1035 typename reducer_options<int, std::decay_t<EVS>>::evaluation_strategy{},
1036 std::forward<EVS>(options)
1037 );
1038 }
1039
1040 template <
1041 class F,
1042 class E,
1043 class EVS = DEFAULT_STRATEGY_REDUCERS,
1044 XTL_REQUIRES(is_reducer_options<EVS>, std::negation<detail::is_xreducer_functors<F>>)>
1045 inline auto reduce(F&& f, E&& e, EVS&& options = EVS())
1046 {
1047 return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), std::forward<EVS>(options));
1048 }
1049
1050 template <
1051 class F,
1052 class E,
1053 class I,
1054 std::size_t N,
1055 class EVS = DEFAULT_STRATEGY_REDUCERS,
1056 XTL_REQUIRES(detail::is_xreducer_functors<F>)>
1057 inline auto reduce(F&& f, E&& e, const I (&axes)[N], EVS options = EVS())
1058 {
1059 using axes_type = std::array<std::size_t, N>;
1060 auto ax = xt::forward_normalize<axes_type>(e, axes);
1061 return detail::reduce_impl(
1062 std::forward<F>(f),
1063 std::forward<E>(e),
1064 std::move(ax),
1065 typename reducer_options<int, EVS>::evaluation_strategy{},
1066 options
1067 );
1068 }
1069
1070 template <
1071 class F,
1072 class E,
1073 class I,
1074 std::size_t N,
1075 class EVS = DEFAULT_STRATEGY_REDUCERS,
1076 XTL_REQUIRES(std::negation<detail::is_xreducer_functors<F>>)>
1077 inline auto reduce(F&& f, E&& e, const I (&axes)[N], EVS options = EVS())
1078 {
1079 return reduce(make_xreducer_functor(std::forward<F>(f)), std::forward<E>(e), axes, options);
1080 }
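// [Editorial sketch] putting the overloads above together (assuming an xarray
// header is included):
//
//     xt::xarray<double> a = {{1., 2., 3.}, {4., 5., 6.}};
//     auto lazy  = xt::reduce(std::plus<double>(), a, {1});               // lazy xreducer expression
//     auto eager = xt::reduce(std::plus<double>(), a, {1},
//                             xt::evaluation_strategy::immediate);        // evaluated container
//     // lazy(0) == 6.0, eager(1) == 15.0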
1081
1082 /********************
1083 * xreducer_stepper *
1084 ********************/
1085
1086 template <class F, class CT, class X, class O>
1087 class xreducer_stepper
1088 {
1089 public:
1090
1091 using self_type = xreducer_stepper<F, CT, X, O>;
1092 using xreducer_type = xreducer<F, CT, X, O>;
1093
1094 using value_type = typename xreducer_type::value_type;
1095 using reference = typename xreducer_type::value_type;
1096 using pointer = typename xreducer_type::const_pointer;
1097 using size_type = typename xreducer_type::size_type;
1098 using difference_type = typename xreducer_type::difference_type;
1099
1100 using xexpression_type = typename xreducer_type::xexpression_type;
1101 using substepper_type = typename xexpression_type::const_stepper;
1102 using shape_type = typename xreducer_type::shape_type;
1103
1104 xreducer_stepper(
1105 const xreducer_type& red,
1106 size_type offset,
1107 bool end = false,
1108 layout_type l = default_assignable_layout(xexpression_type::static_layout)
1109 );
1110
1111 reference operator*() const;
1112
1113 void step(size_type dim);
1114 void step_back(size_type dim);
1115 void step(size_type dim, size_type n);
1116 void step_back(size_type dim, size_type n);
1117 void reset(size_type dim);
1118 void reset_back(size_type dim);
1119
1120 void to_begin();
1121 void to_end(layout_type l);
1122
1123 private:
1124
1125 reference initial_value() const;
1126 reference aggregate(size_type dim) const;
1127 reference aggregate_impl(size_type dim, /*keep_dims=*/std::false_type) const;
1128 reference aggregate_impl(size_type dim, /*keep_dims=*/std::true_type) const;
1129
1130 substepper_type get_substepper_begin() const;
1131 size_type get_dim(size_type dim) const noexcept;
1132 size_type shape(size_type i) const noexcept;
1133 size_type axis(size_type i) const noexcept;
1134
1135 const xreducer_type* m_reducer;
1136 size_type m_offset;
1137 mutable substepper_type m_stepper;
1138 };
1139
1140 /******************
1141 * xreducer utils *
1142 ******************/
1143
1144 namespace detail
1145 {
1146 template <std::size_t X, std::size_t... I>
1147 struct in
1148 {
1149 static constexpr bool value = std::disjunction<std::integral_constant<bool, X == I>...>::value;
1150 };
1151
1152 template <std::size_t Z, class S1, class S2, class R>
1153 struct fixed_xreducer_shape_type_impl;
1154
1155 template <std::size_t Z, std::size_t... I, std::size_t... J, std::size_t... R>
1156 struct fixed_xreducer_shape_type_impl<Z, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
1157 {
1158 using type = std::conditional_t<
1159 in<Z, J...>::value,
1160 typename fixed_xreducer_shape_type_impl<Z - 1, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>::type,
1161 typename fixed_xreducer_shape_type_impl<
1162 Z - 1,
1163 fixed_shape<I...>,
1164 fixed_shape<J...>,
1165 fixed_shape<detail::at<Z, I...>::value, R...>>::type>;
1166 };
1167
1168 template <std::size_t... I, std::size_t... J, std::size_t... R>
1169 struct fixed_xreducer_shape_type_impl<0, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<R...>>
1170 {
1171 using type = std::
1172 conditional_t<in<0, J...>::value, fixed_shape<R...>, fixed_shape<detail::at<0, I...>::value, R...>>;
1173 };
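// [Editorial note] the recursion above walks Z from the last axis down to 0 and
// keeps the entries of the first shape whose index does not appear in the axes
// pack J; for example
//     fixed_xreducer_shape_type_impl<2, fixed_shape<2, 3, 4>, fixed_shape<1>, fixed_shape<>>::type
// is fixed_shape<2, 4>.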
1174
1175 /***************************
1176 * helper for return types *
1177 ***************************/
1178
1179 template <class T>
1180 struct xreducer_size_type
1181 {
1182 using type = std::size_t;
1183 };
1184
1185 template <class T>
1186 using xreducer_size_type_t = typename xreducer_size_type<T>::type;
1187
1188 template <class T>
1189 struct xreducer_temporary_type
1190 {
1191 using type = T;
1192 };
1193
1194 template <class T>
1195 using xreducer_temporary_type_t = typename xreducer_temporary_type<T>::type;
1196
1197 /********************************
1198 * Default const_value rebinder *
1199 ********************************/
1200
1201 template <class T, class U>
1202 struct const_value_rebinder
1203 {
1204 static const_value<U> run(const const_value<T>& t)
1205 {
1206 return const_value<U>(t.m_value);
1207 }
1208 };
1209 }
1210
1211 /*******************************************
1212 * Init functor const_value implementation *
1213 *******************************************/
1214
1215 template <class T>
1216 template <class NT>
1217 const_value<NT> const_value<T>::rebind() const
1218 {
1219 return detail::const_value_rebinder<T, NT>::run(*this);
1220 }
1221
1222 /*****************************
1223 * fixed_xreducer_shape_type *
1224 *****************************/
1225
1226 template <class S1, class S2>
1227 struct fixed_xreducer_shape_type;
1228
1229 template <std::size_t... I, std::size_t... J>
1230 struct fixed_xreducer_shape_type<fixed_shape<I...>, fixed_shape<J...>>
1231 {
1232 using type = typename detail::
1233 fixed_xreducer_shape_type_impl<sizeof...(I) - 1, fixed_shape<I...>, fixed_shape<J...>, fixed_shape<>>::type;
1234 };
1235
1236 // meta-function returning the shape type for an xreducer
1237 template <class ST, class X, class O>
1238 struct xreducer_shape_type
1239 {
1240 using type = promote_shape_t<ST, std::decay_t<X>>;
1241 };
1242
1243 template <class I1, std::size_t N1, class I2, std::size_t N2>
1244 struct xreducer_shape_type<std::array<I1, N1>, std::array<I2, N2>, std::true_type>
1245 {
1246 using type = std::array<I2, N1>;
1247 };
1248
1249 template <class I1, std::size_t N1, class I2, std::size_t N2>
1250 struct xreducer_shape_type<std::array<I1, N1>, std::array<I2, N2>, std::false_type>
1251 {
1252 using type = std::array<I2, N1 - N2>;
1253 };
1254
1255 template <std::size_t... I, class I2, std::size_t N2>
1256 struct xreducer_shape_type<fixed_shape<I...>, std::array<I2, N2>, std::false_type>
1257 {
1258 using type = std::conditional_t<sizeof...(I) == N2, fixed_shape<>, std::array<I2, sizeof...(I) - N2>>;
1259 };
1260
1261 namespace detail
1262 {
1263 template <class S1, class S2>
1264 struct ixconcat;
1265
1266 template <class T, T... I1, T... I2>
1267 struct ixconcat<std::integer_sequence<T, I1...>, std::integer_sequence<T, I2...>>
1268 {
1269 using type = std::integer_sequence<T, I1..., I2...>;
1270 };
1271
1272 template <class T, T X, std::size_t N>
1273 struct repeat_integer_sequence
1274 {
1275 using type = typename ixconcat<
1276 std::integer_sequence<T, X>,
1277 typename repeat_integer_sequence<T, X, N - 1>::type>::type;
1278 };
1279
1280 template <class T, T X>
1281 struct repeat_integer_sequence<T, X, 0>
1282 {
1283 using type = std::integer_sequence<T>;
1284 };
1285
1286 template <class T, T X>
1287 struct repeat_integer_sequence<T, X, 2>
1288 {
1289 using type = std::integer_sequence<T, X, X>;
1290 };
1291
1292 template <class T, T X>
1293 struct repeat_integer_sequence<T, X, 1>
1294 {
1295 using type = std::integer_sequence<T, X>;
1296 };
1297 }
1298
1299 template <std::size_t... I, class I2, std::size_t N2>
1300 struct xreducer_shape_type<fixed_shape<I...>, std::array<I2, N2>, std::true_type>
1301 {
1302 template <std::size_t... X>
1303 static constexpr auto get_type(std::index_sequence<X...>)
1304 {
1305 return fixed_shape<X...>{};
1306 }
1307
1308 // if all axes reduced
1309 using type = std::conditional_t<
1310 sizeof...(I) == N2,
1311 decltype(get_type(typename detail::repeat_integer_sequence<std::size_t, std::size_t(1), N2>::type{})),
1312 std::array<I2, sizeof...(I)>>;
1313 };
1314
1315 // Note adding "A" to prevent compilation in case nothing else matches
1316 template <std::size_t... I, std::size_t... J, class O>
1317 struct xreducer_shape_type<fixed_shape<I...>, fixed_shape<J...>, O>
1318 {
1319 using type = typename fixed_xreducer_shape_type<fixed_shape<I...>, fixed_shape<J...>>::type;
1320 };
1321
1322 namespace detail
1323 {
1324 template <class S, class E, class X, class M>
1325 inline void shape_and_mapping_computation(S& shape, E& e, const X& axes, M& mapping, std::false_type)
1326 {
1327 auto first = e.shape().begin();
1328 auto last = e.shape().end();
1329 auto exclude_it = axes.begin();
1330
1331 using value_type = typename S::value_type;
1332 using difference_type = typename S::difference_type;
1333 auto d_first = shape.begin();
1334 auto map_first = mapping.begin();
1335
1336 auto iter = first;
1337 while (iter != last && exclude_it != axes.end())
1338 {
1339 auto diff = std::distance(first, iter);
1340 if (diff != difference_type(*exclude_it))
1341 {
1342 *d_first++ = *iter++;
1343 *map_first++ = value_type(diff);
1344 }
1345 else
1346 {
1347 ++iter;
1348 ++exclude_it;
1349 }
1350 }
1351
1352 auto diff = std::distance(first, iter);
1353 auto end = std::distance(iter, last);
1354 std::iota(map_first, map_first + end, diff);
1355 std::copy(iter, last, d_first);
1356 }
1357
1358 template <class S, class E, class X, class M>
1359 inline void
1360 shape_and_mapping_computation_keep_dim(S& shape, E& e, const X& axes, M& mapping, std::false_type)
1361 {
1362 for (std::size_t i = 0; i < e.dimension(); ++i)
1363 {
1364 if (std::find(axes.cbegin(), axes.cend(), i) == axes.cend())
1365 {
1366 // i not in axes!
1367 shape[i] = e.shape()[i];
1368 }
1369 else
1370 {
1371 shape[i] = 1;
1372 }
1373 }
1374 std::iota(mapping.begin(), mapping.end(), 0);
1375 }
1376
1377 template <class S, class E, class X, class M>
1378 inline void shape_and_mapping_computation(S&, E&, const X&, M&, std::true_type)
1379 {
1380 }
1381
1382 template <class S, class E, class X, class M>
1383 inline void shape_and_mapping_computation_keep_dim(S&, E&, const X&, M&, std::true_type)
1384 {
1385 }
1386 }
1387
1388 /***************************
1389 * xreducer implementation *
1390 ***************************/
1391
1396
1404 template <class F, class CT, class X, class O>
1405 template <class Func, class CTA, class AX, class OX>
1406 inline xreducer<F, CT, X, O>::xreducer(Func&& func, CTA&& e, AX&& axes, OX&& options)
1407 : m_e(std::forward<CTA>(e))
1408 , m_reduce(xt::get<0>(func))
1409 , m_init(xt::get<1>(func))
1410 , m_merge(xt::get<2>(func))
1411 , m_axes(std::forward<AX>(axes))
1412 , m_shape(xtl::make_sequence<inner_shape_type>(
1413 typename O::keep_dims() ? m_e.dimension() : m_e.dimension() - m_axes.size(),
1414 0
1415 ))
1416 , m_dim_mapping(xtl::make_sequence<dim_mapping_type>(
1417 typename O::keep_dims() ? m_e.dimension() : m_e.dimension() - m_axes.size(),
1418 0
1419 ))
1420 , m_options(std::forward<OX>(options))
1421 {
1422 // std::less is used, because as the standard says (24.4.5):
1423 // A sequence is sorted with respect to a comparator comp if for any iterator i pointing to the
1424 // sequence and any non-negative integer n such that i + n is a valid iterator pointing to an element
1425 // of the sequence, comp(*(i + n), *i) == false. Therefore less is required to detect duplicates.
1426 if (!std::is_sorted(m_axes.cbegin(), m_axes.cend(), std::less<>()))
1427 {
1428 XTENSOR_THROW(std::runtime_error, "Reducing axes should be sorted.");
1429 }
1430 if (std::adjacent_find(m_axes.cbegin(), m_axes.cend()) != m_axes.cend())
1431 {
1432 XTENSOR_THROW(std::runtime_error, "Reducing axes should not contain duplicates.");
1433 }
1434 if (m_axes.size() != 0 && m_axes[m_axes.size() - 1] > m_e.dimension() - 1)
1435 {
1436 XTENSOR_THROW(
1437 std::runtime_error,
1438 "Axis " + std::to_string(m_axes[m_axes.size() - 1]) + " out of bounds for reduction."
1439 );
1440 }
1441
1442 if (!typename O::keep_dims())
1443 {
1444 detail::shape_and_mapping_computation(
1445 m_shape,
1446 m_e,
1447 m_axes,
1448 m_dim_mapping,
1449 detail::is_fixed<shape_type>{}
1450 );
1451 }
1452 else
1453 {
1454 detail::shape_and_mapping_computation_keep_dim(
1455 m_shape,
1456 m_e,
1457 m_axes,
1458 m_dim_mapping,
1459 detail::is_fixed<shape_type>{}
1460 );
1461 }
1462 }
1463
1465
1469
1473 template <class F, class CT, class X, class O>
1474 inline auto xreducer<F, CT, X, O>::shape() const noexcept -> const inner_shape_type&
1475 {
1476 return m_shape;
1477 }
1478
1482 template <class F, class CT, class X, class O>
1483 inline layout_type xreducer<F, CT, X, O>::layout() const noexcept
1484 {
1485 return static_layout;
1486 }
1487
1488 template <class F, class CT, class X, class O>
1489 inline bool xreducer<F, CT, X, O>::is_contiguous() const noexcept
1490 {
1491 return false;
1492 }
1493
1495
1499
1506 template <class F, class CT, class X, class O>
1507 template <class... Args>
1508 inline auto xreducer<F, CT, X, O>::operator()(Args... args) const -> const_reference
1509 {
1510 XTENSOR_TRY(check_index(shape(), args...));
1511 XTENSOR_CHECK_DIMENSION(shape(), args...);
1512 std::array<std::size_t, sizeof...(Args)> arg_array = {{static_cast<std::size_t>(args)...}};
1513 return element(arg_array.cbegin(), arg_array.cend());
1514 }
1515
1535 template <class F, class CT, class X, class O>
1536 template <class... Args>
1537 inline auto xreducer<F, CT, X, O>::unchecked(Args... args) const -> const_reference
1538 {
1539 std::array<std::size_t, sizeof...(Args)> arg_array = {{static_cast<std::size_t>(args)...}};
1540 return element(arg_array.cbegin(), arg_array.cend());
1541 }
1542
1550 template <class F, class CT, class X, class O>
1551 template <class It>
1552 inline auto xreducer<F, CT, X, O>::element(It first, It last) const -> const_reference
1553 {
1554 XTENSOR_TRY(check_element_index(shape(), first, last));
1555 auto stepper = const_stepper(*this, 0);
1556 if (first != last)
1557 {
1558 size_type dim = 0;
1559 // drop left most elements
1560 auto size = std::ptrdiff_t(this->dimension()) - std::distance(first, last);
1561 auto begin = first - size;
1562 while (begin != last)
1563 {
1564 if (begin < first)
1565 {
1566 stepper.step(dim++, std::size_t(0));
1567 begin++;
1568 }
1569 else
1570 {
1571 stepper.step(dim++, std::size_t(*begin++));
1572 }
1573 }
1574 }
1575 return *stepper;
1576 }
1577
1581 template <class F, class CT, class X, class O>
1582 inline auto xreducer<F, CT, X, O>::expression() const noexcept -> const xexpression_type&
1583 {
1584 return m_e;
1585 }
1586
1588
1593
1599 template <class F, class CT, class X, class O>
1600 template <class S>
1601 inline bool xreducer<F, CT, X, O>::broadcast_shape(S& shape, bool) const
1602 {
1603 return xt::broadcast_shape(m_shape, shape);
1604 }
1605
1611 template <class F, class CT, class X, class O>
1612 template <class S>
1613 inline bool xreducer<F, CT, X, O>::has_linear_assign(const S& /*strides*/) const noexcept
1614 {
1615 return false;
1616 }
1617
1619
1620 template <class F, class CT, class X, class O>
1621 template <class S>
1622 inline auto xreducer<F, CT, X, O>::stepper_begin(const S& shape) const noexcept -> const_stepper
1623 {
1624 size_type offset = shape.size() - this->dimension();
1625 return const_stepper(*this, offset);
1626 }
1627
1628 template <class F, class CT, class X, class O>
1629 template <class S>
1630 inline auto xreducer<F, CT, X, O>::stepper_end(const S& shape, layout_type l) const noexcept
1631 -> const_stepper
1632 {
1633 size_type offset = shape.size() - this->dimension();
1634 return const_stepper(*this, offset, true, l);
1635 }
1636
1637 template <class F, class CT, class X, class O>
1638 template <class E>
1639 inline auto xreducer<F, CT, X, O>::build_reducer(E&& e) const -> rebind_t<E>
1640 {
1641 return rebind_t<E>(
1642 std::make_tuple(m_reduce, m_init, m_merge),
1643 std::forward<E>(e),
1644 axes_type(m_axes),
1645 m_options
1646 );
1647 }
1648
1649 template <class F, class CT, class X, class O>
1650 template <class E, class Func, class Opts>
1651 inline auto xreducer<F, CT, X, O>::build_reducer(E&& e, Func&& func, Opts&& opts) const
1652 -> rebind_t<E, Func, Opts>
1653 {
1654 return rebind_t<E, Func, Opts>(
1655 std::forward<Func>(func),
1656 std::forward<E>(e),
1657 axes_type(m_axes),
1658 std::forward<Opts>(opts)
1659 );
1660 }
1661
1662 /***********************************
1663 * xreducer_stepper implementation *
1664 ***********************************/
1665
1666 template <class F, class CT, class X, class O>
1667 inline xreducer_stepper<F, CT, X, O>::xreducer_stepper(
1668 const xreducer_type& red,
1669 size_type offset,
1670 bool end,
1671 layout_type l
1672 )
1673 : m_reducer(&red)
1674 , m_offset(offset)
1675 , m_stepper(get_substepper_begin())
1676 {
1677 if (end)
1678 {
1679 to_end(l);
1680 }
1681 }
1682
1683 template <class F, class CT, class X, class O>
1684 inline auto xreducer_stepper<F, CT, X, O>::operator*() const -> reference
1685 {
1686 reference r = aggregate(0);
1687 return r;
1688 }
1689
1690 template <class F, class CT, class X, class O>
1691 inline void xreducer_stepper<F, CT, X, O>::step(size_type dim)
1692 {
1693 if (dim >= m_offset)
1694 {
1695 m_stepper.step(get_dim(dim - m_offset));
1696 }
1697 }
1698
1699 template <class F, class CT, class X, class O>
1700 inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim)
1701 {
1702 if (dim >= m_offset)
1703 {
1704 m_stepper.step_back(get_dim(dim - m_offset));
1705 }
1706 }
1707
1708 template <class F, class CT, class X, class O>
1709 inline void xreducer_stepper<F, CT, X, O>::step(size_type dim, size_type n)
1710 {
1711 if (dim >= m_offset)
1712 {
1713 m_stepper.step(get_dim(dim - m_offset), n);
1714 }
1715 }
1716
1717 template <class F, class CT, class X, class O>
1718 inline void xreducer_stepper<F, CT, X, O>::step_back(size_type dim, size_type n)
1719 {
1720 if (dim >= m_offset)
1721 {
1722 m_stepper.step_back(get_dim(dim - m_offset), n);
1723 }
1724 }
1725
1726 template <class F, class CT, class X, class O>
1727 inline void xreducer_stepper<F, CT, X, O>::reset(size_type dim)
1728 {
1729 if (dim >= m_offset)
1730 {
1731 // Because the reducer uses `reset` to reset the non-reducing axes,
1732 // we need to prevent that here for the keep_dims case, where the reduced axes are kept in the result.
1733 if (typename O::keep_dims()
1734 && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
1735 {
1736 // If keep dim activated, and dim is in the axes, do nothing!
1737 return;
1738 }
1739 m_stepper.reset(get_dim(dim - m_offset));
1740 }
1741 }
1742
1743 template <class F, class CT, class X, class O>
1744 inline void xreducer_stepper<F, CT, X, O>::reset_back(size_type dim)
1745 {
1746 if (dim >= m_offset)
1747 {
1748 // Note that for *not* KD this is not going to do anything
1749 if (typename O::keep_dims()
1750 && std::binary_search(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim - m_offset))
1751 {
1752 // If keep dim activated, and dim is in the axes, do nothing!
1753 return;
1754 }
1755 m_stepper.reset_back(get_dim(dim - m_offset));
1756 }
1757 }
1758
1759 template <class F, class CT, class X, class O>
1760 inline void xreducer_stepper<F, CT, X, O>::to_begin()
1761 {
1762 m_stepper.to_begin();
1763 }
1764
1765 template <class F, class CT, class X, class O>
1766 inline void xreducer_stepper<F, CT, X, O>::to_end(layout_type l)
1767 {
1768 m_stepper.to_end(l);
1769 }
1770
1771 template <class F, class CT, class X, class O>
1772 inline auto xreducer_stepper<F, CT, X, O>::initial_value() const -> reference
1773 {
1774 return O::has_initial_value ? m_reducer->m_options.initial_value
1775 : static_cast<reference>(m_reducer->m_init());
1776 }
1777
1778 template <class F, class CT, class X, class O>
1779 inline auto xreducer_stepper<F, CT, X, O>::aggregate(size_type dim) const -> reference
1780 {
1781 reference res;
1782 if (m_reducer->m_e.size() == size_type(0))
1783 {
1784 res = initial_value();
1785 }
1786 else if (m_reducer->m_e.shape().empty() || m_reducer->m_axes.size() == 0)
1787 {
1788 res = m_reducer->m_reduce(initial_value(), *m_stepper);
1789 }
1790 else
1791 {
1792 res = aggregate_impl(dim, typename O::keep_dims());
1793 if (O::has_initial_value && dim == 0)
1794 {
1795 res = m_reducer->m_merge(m_reducer->m_options.initial_value, res);
1796 }
1797 }
1798 return res;
1799 }
1800
1801 template <class F, class CT, class X, class O>
1802 inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::false_type) const -> reference
1803 {
1804 // reference can be std::array, hence the {} initializer
1805 reference res = {};
1806 size_type index = axis(dim);
1807 size_type size = shape(index);
1808 if (dim != m_reducer->m_axes.size() - 1)
1809 {
1810 res = aggregate_impl(dim + 1, typename O::keep_dims());
1811 for (size_type i = 1; i != size; ++i)
1812 {
1813 m_stepper.step(index);
1814 res = m_reducer->m_merge(res, aggregate_impl(dim + 1, typename O::keep_dims()));
1815 }
1816 }
1817 else
1818 {
1819 res = m_reducer->m_reduce(static_cast<reference>(m_reducer->m_init()), *m_stepper);
1820 for (size_type i = 1; i != size; ++i)
1821 {
1822 m_stepper.step(index);
1823 res = m_reducer->m_reduce(res, *m_stepper);
1824 }
1825 }
1826 m_stepper.reset(index);
1827 return res;
1828 }
1829
1830 template <class F, class CT, class X, class O>
1831 inline auto xreducer_stepper<F, CT, X, O>::aggregate_impl(size_type dim, std::true_type) const -> reference
1832 {
1833 // reference can be std::array, hence the {} initializer
1834 reference res = {};
1835 auto ax_it = std::find(m_reducer->m_axes.begin(), m_reducer->m_axes.end(), dim);
1836 if (ax_it != m_reducer->m_axes.end())
1837 {
1838 size_type index = dim;
1839 size_type size = m_reducer->m_e.shape()[index];
1840 if (ax_it != m_reducer->m_axes.end() - 1 && size != 0)
1841 {
1842 res = aggregate_impl(dim + 1, typename O::keep_dims());
1843 for (size_type i = 1; i != size; ++i)
1844 {
1845 m_stepper.step(index);
1846 res = m_reducer->m_merge(res, aggregate_impl(dim + 1, typename O::keep_dims()));
1847 }
1848 }
1849 else
1850 {
1851 res = m_reducer->m_reduce(static_cast<reference>(m_reducer->m_init()), *m_stepper);
1852 for (size_type i = 1; i != size; ++i)
1853 {
1854 m_stepper.step(index);
1855 res = m_reducer->m_reduce(res, *m_stepper);
1856 }
1857 }
1858 m_stepper.reset(index);
1859 }
1860 else
1861 {
1862 if (dim < m_reducer->m_e.dimension())
1863 {
1864 res = aggregate_impl(dim + 1, typename O::keep_dims());
1865 }
1866 }
1867 return res;
1868 }
1869
1870 template <class F, class CT, class X, class O>
1871 inline auto xreducer_stepper<F, CT, X, O>::get_substepper_begin() const -> substepper_type
1872 {
1873 return m_reducer->m_e.stepper_begin(m_reducer->m_e.shape());
1874 }
1875
1876 template <class F, class CT, class X, class O>
1877 inline auto xreducer_stepper<F, CT, X, O>::get_dim(size_type dim) const noexcept -> size_type
1878 {
1879 return m_reducer->m_dim_mapping[dim];
1880 }
1881
1882 template <class F, class CT, class X, class O>
1883 inline auto xreducer_stepper<F, CT, X, O>::shape(size_type i) const noexcept -> size_type
1884 {
1885 return m_reducer->m_e.shape()[i];
1886 }
1887
1888 template <class F, class CT, class X, class O>
1889 inline auto xreducer_stepper<F, CT, X, O>::axis(size_type i) const noexcept -> size_type
1890 {
1891 return m_reducer->m_axes[i];
1892 }
1893}
1894
1895#endif