xtensor
Loading...
Searching...
No Matches
xfunction.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_FUNCTION_HPP
11#define XTENSOR_FUNCTION_HPP
12
13#include <algorithm>
14#include <cstddef>
15#include <iterator>
16#include <numeric>
17#include <tuple>
18#include <type_traits>
19#include <utility>
20
21#include <xtl/xsequence.hpp>
22#include <xtl/xtype_traits.hpp>
23
24#include "xaccessible.hpp"
25#include "xexpression_traits.hpp"
26#include "xiterable.hpp"
27#include "xiterator.hpp"
28#include "xlayout.hpp"
29#include "xscalar.hpp"
30#include "xshape.hpp"
31#include "xstrides.hpp"
32#include "xtensor_simd.hpp"
33#include "xutils.hpp"
34
35namespace xt
36{
37 namespace detail
38 {
39
40 template <bool... B>
42
43 /************************
44 * xfunction_cache_impl *
45 ************************/
46
47 template <class S, class is_shape_trivial>
48 struct xfunction_cache_impl
49 {
50 S shape;
51 bool is_trivial;
52 bool is_initialized;
53
54 xfunction_cache_impl()
55 : shape(xtl::make_sequence<S>(0, std::size_t(0)))
56 , is_trivial(false)
57 , is_initialized(false)
58 {
59 }
60 };
61
62 template <std::size_t... N, class is_shape_trivial>
63 struct xfunction_cache_impl<fixed_shape<N...>, is_shape_trivial>
64 {
65 XTENSOR_CONSTEXPR_ENHANCED_STATIC fixed_shape<N...> shape = fixed_shape<N...>();
66 XTENSOR_CONSTEXPR_ENHANCED_STATIC bool is_trivial = is_shape_trivial::value;
67 XTENSOR_CONSTEXPR_ENHANCED_STATIC bool is_initialized = true;
68 };
69
70#ifdef XTENSOR_HAS_CONSTEXPR_ENHANCED
71 // Out of line definitions to prevent linker errors prior to C++17
72 template <std::size_t... N, class is_shape_trivial>
73 constexpr fixed_shape<N...> xfunction_cache_impl<fixed_shape<N...>, is_shape_trivial>::shape;
74
75 template <std::size_t... N, class is_shape_trivial>
76 constexpr bool xfunction_cache_impl<fixed_shape<N...>, is_shape_trivial>::is_trivial;
77
78 template <std::size_t... N, class is_shape_trivial>
79 constexpr bool xfunction_cache_impl<fixed_shape<N...>, is_shape_trivial>::is_initialized;
80#endif
81
82 template <class... CT>
83 struct xfunction_bool_load_type
84 {
85 using type = xtl::promote_type_t<typename std::decay_t<CT>::bool_load_type...>;
86 };
87
88 template <class CT>
89 struct xfunction_bool_load_type<CT>
90 {
91 using type = typename std::decay_t<CT>::bool_load_type;
92 };
93
94 template <class... CT>
95 using xfunction_bool_load_type_t = typename xfunction_bool_load_type<CT...>::type;
96 }
97
98 /************************
99 * xfunction extensions *
100 ************************/
101
102 namespace extension
103 {
104
105 template <class Tag, class F, class... CT>
107
108 template <class F, class... CT>
113
114 template <class F, class... CT>
115 struct xfunction_base : xfunction_base_impl<xexpression_tag_t<CT...>, F, CT...>
116 {
117 };
118
119 template <class F, class... CT>
120 using xfunction_base_t = typename xfunction_base<F, CT...>::type;
121 }
122
123 template <class promote>
124 struct xfunction_cache : detail::xfunction_cache_impl<typename promote::type, promote>
125 {
126 };
127
// Forward declarations: the function class and its iterator/stepper refer
// to each other (friendship, member typedefs).
template <class F, class... CT>
class xfunction;

template <class F, class... CT>
class xfunction_iterator;

template <class F, class... CT>
class xfunction_stepper;
136
137 template <class F, class... CT>
144
145 template <class F, class... CT>
147 {
148 // Added indirection for MSVC 2017 bug with the operator value_type()
149 using func_return_type = typename meta_identity<
150 decltype(std::declval<F>()(std::declval<xvalue_type_t<std::decay_t<CT>>>()...))>::type;
151 using value_type = std::decay_t<func_return_type>;
152 using reference = func_return_type;
153 using const_reference = reference;
155 };
156
157 template <class T, class F, class... CT>
159 has_simd_type<T>,
160 has_simd_apply<F, xt_simd::simd_type<T>>,
161 has_simd_interface<std::decay_t<CT>, T>...>
162 {
163 };
164
165 /*************************************
166 * overlapping_memory_checker_traits *
167 *************************************/
168
169 template <class E>
171 E,
172 std::enable_if_t<!has_memory_address<E>::value && is_specialization_of<xfunction, E>::value>>
173 {
174 template <std::size_t I = 0, class... T, std::enable_if_t<(I == sizeof...(T)), int> = 0>
175 static bool check_tuple(const std::tuple<T...>&, const memory_range&)
176 {
177 return false;
178 }
179
180 template <std::size_t I = 0, class... T, std::enable_if_t<(I < sizeof...(T)), int> = 0>
181 static bool check_tuple(const std::tuple<T...>& t, const memory_range& dst_range)
182 {
183 using ChildE = std::decay_t<decltype(std::get<I>(t))>;
186 }
187
188 static bool check_overlap(const E& expr, const memory_range& dst_range)
189 {
190 if (expr.size() == 0)
191 {
192 return false;
193 }
194 else
195 {
196 return check_tuple(expr.arguments(), dst_range);
197 }
198 }
199 };
200
201 /*************
202 * xfunction *
203 *************/
204
216 template <class F, class... CT>
217 class xfunction : private xconst_iterable<xfunction<F, CT...>>,
218 public xsharable_expression<xfunction<F, CT...>>,
219 private xconst_accessible<xfunction<F, CT...>>,
220 public extension::xfunction_base_t<F, CT...>
221 {
222 public:
223
224 using self_type = xfunction<F, CT...>;
226 using extension_base = extension::xfunction_base_t<F, CT...>;
227 using expression_tag = typename extension_base::expression_tag;
228 using only_scalar = all_xscalar<CT...>;
229 using functor_type = typename std::remove_reference<F>::type;
230 using tuple_type = std::tuple<CT...>;
231
233 using value_type = typename inner_types::value_type;
234 using reference = typename inner_types::reference;
235 using const_reference = typename inner_types::const_reference;
236 using pointer = value_type*;
237 using const_pointer = const value_type*;
238 using size_type = typename inner_types::size_type;
240
241 using simd_value_type = xt_simd::simd_type<value_type>;
242
243 // xtl::promote_type_t<typename std::decay_t<CT>::bool_load_type...>;
244 using bool_load_type = detail::xfunction_bool_load_type_t<CT...>;
245
246 template <class requested_type>
247 using simd_return_type = xt_simd::simd_return_type<value_type, requested_type>;
248
250 using inner_shape_type = typename iterable_base::inner_shape_type;
251 using shape_type = inner_shape_type;
252
253 using stepper = typename iterable_base::stepper;
254 using const_stepper = typename iterable_base::const_stepper;
255
256 static constexpr layout_type static_layout = compute_layout(std::decay_t<CT>::static_layout...);
257 static constexpr bool contiguous_layout = static_layout != layout_type::dynamic;
258
259 template <layout_type L>
260 using layout_iterator = typename iterable_base::template layout_iterator<L>;
261 template <layout_type L>
262 using const_layout_iterator = typename iterable_base::template const_layout_iterator<L>;
263 template <layout_type L>
264 using reverse_layout_iterator = typename iterable_base::template reverse_layout_iterator<L>;
265 template <layout_type L>
266 using const_reverse_layout_iterator = typename iterable_base::template const_reverse_layout_iterator<L>;
267
268 template <class S, layout_type L>
269 using broadcast_iterator = typename iterable_base::template broadcast_iterator<S, L>;
270 template <class S, layout_type L>
271 using const_broadcast_iterator = typename iterable_base::template const_broadcast_iterator<S, L>;
272 template <class S, layout_type L>
273 using reverse_broadcast_iterator = typename iterable_base::template reverse_broadcast_iterator<S, L>;
274 template <class S, layout_type L>
275 using const_reverse_broadcast_iterator = typename iterable_base::template const_reverse_broadcast_iterator<S, L>;
276
279 using const_reverse_linear_iterator = std::reverse_iterator<const_linear_iterator>;
280 using reverse_linear_iterator = std::reverse_iterator<linear_iterator>;
281
282 using iterator = typename iterable_base::iterator;
283 using const_iterator = typename iterable_base::const_iterator;
284 using reverse_iterator = typename iterable_base::reverse_iterator;
285 using const_reverse_iterator = typename iterable_base::const_reverse_iterator;
286
287 template <class Func, class... CTA, class U = std::enable_if_t<!std::is_base_of<std::decay_t<Func>, self_type>::value>>
288 xfunction(Func&& f, CTA&&... e) noexcept;
289
290 template <class FA, class... CTA>
292
293 ~xfunction() = default;
294
295 xfunction(const xfunction&) = default;
296 xfunction& operator=(const xfunction&) = default;
297
298 xfunction(xfunction&&) = default;
299 xfunction& operator=(xfunction&&) = default;
300
302 size_type dimension() const noexcept;
303 const inner_shape_type& shape() const;
304 layout_type layout() const noexcept;
305 bool is_contiguous() const noexcept;
307
308 template <class... Args>
309 const_reference operator()(Args... args) const;
310
311 template <class... Args>
312 const_reference unchecked(Args... args) const;
313
314 using accessible_base::at;
315 using accessible_base::operator[];
319 using accessible_base::periodic;
320
321 template <class It>
322 const_reference element(It first, It last) const;
323
324 template <class S>
325 bool broadcast_shape(S& shape, bool reuse_cache = false) const;
326
327 template <class S>
328 bool has_linear_assign(const S& strides) const noexcept;
329
330 using iterable_base::begin;
331 using iterable_base::cbegin;
332 using iterable_base::cend;
333 using iterable_base::crbegin;
334 using iterable_base::crend;
335 using iterable_base::end;
336 using iterable_base::rbegin;
337 using iterable_base::rend;
338
339 const_linear_iterator linear_begin() const noexcept;
340 const_linear_iterator linear_end() const noexcept;
341 const_linear_iterator linear_cbegin() const noexcept;
342 const_linear_iterator linear_cend() const noexcept;
343
344 const_reverse_linear_iterator linear_rbegin() const noexcept;
345 const_reverse_linear_iterator linear_rend() const noexcept;
346 const_reverse_linear_iterator linear_crbegin() const noexcept;
347 const_reverse_linear_iterator linear_crend() const noexcept;
348
349 template <class S>
350 const_stepper stepper_begin(const S& shape) const noexcept;
351 template <class S>
352 const_stepper stepper_end(const S& shape, layout_type l) const noexcept;
353
354 const_reference data_element(size_type i) const;
355
356 const_reference flat(size_type i) const;
357
359 operator value_type() const;
360
362 simd_return_type<requested_type> load_simd(size_type i) const;
363
364 const tuple_type& arguments() const noexcept;
365
366 const functor_type& functor() const noexcept;
367
368 private:
369
370 template <std::size_t... I>
371 layout_type layout_impl(std::index_sequence<I...>) const noexcept;
372
373 template <std::size_t... I, class... Args>
374 const_reference access_impl(std::index_sequence<I...>, Args... args) const;
375
376 template <std::size_t... I, class... Args>
377 const_reference unchecked_impl(std::index_sequence<I...>, Args... args) const;
378
379 template <std::size_t... I, class It>
380 const_reference element_access_impl(std::index_sequence<I...>, It first, It last) const;
381
382 template <std::size_t... I>
383 const_reference data_element_impl(std::index_sequence<I...>, size_type i) const;
384
385 template <class align, class requested_type, std::size_t N, std::size_t... I>
386 auto load_simd_impl(std::index_sequence<I...>, size_type i) const;
387
388 template <class Func, std::size_t... I>
389 const_stepper build_stepper(Func&& f, std::index_sequence<I...>) const noexcept;
390
391 template <class Func, std::size_t... I>
392 auto build_iterator(Func&& f, std::index_sequence<I...>) const noexcept;
393
394 size_type compute_dimension() const noexcept;
395
396 void compute_cached_shape() const;
397
398 tuple_type m_e;
399 functor_type m_f;
401
402 friend class xfunction_iterator<F, CT...>;
403 friend class xfunction_stepper<F, CT...>;
404 friend class xconst_iterable<self_type>;
405 friend class xconst_accessible<self_type>;
406 };
407
408 /**********************
409 * xfunction_iterator *
410 **********************/
411
412 template <class F, class... CT>
413 class xfunction_iterator : public xtl::xrandom_access_iterator_base<
414 xfunction_iterator<F, CT...>,
415 typename xfunction<F, CT...>::value_type,
416 typename xfunction<F, CT...>::difference_type,
417 typename xfunction<F, CT...>::pointer,
418 typename xfunction<F, CT...>::reference>
419 {
420 public:
421
422 using self_type = xfunction_iterator<F, CT...>;
423 using functor_type = typename std::remove_reference<F>::type;
424 using xfunction_type = xfunction<F, CT...>;
425
426 using value_type = typename xfunction_type::value_type;
427 using reference = typename xfunction_type::value_type;
428 using pointer = typename xfunction_type::const_pointer;
429 using difference_type = typename xfunction_type::difference_type;
430 using iterator_category = std::random_access_iterator_tag;
431
432 template <class... It>
433 xfunction_iterator(const xfunction_type* func, It&&... it) noexcept;
434
435 self_type& operator++();
436 self_type& operator--();
437
438 self_type& operator+=(difference_type n);
439 self_type& operator-=(difference_type n);
440
441 difference_type operator-(const self_type& rhs) const;
442
443 reference operator*() const;
444
445 bool equal(const self_type& rhs) const;
446 bool less_than(const self_type& rhs) const;
447
448 private:
449
450 using data_type = std::tuple<decltype(xt::linear_begin(std::declval<const std::decay_t<CT>>()))...>;
451
452 template <std::size_t... I>
453 reference deref_impl(std::index_sequence<I...>) const;
454
455 template <std::size_t... I>
456 difference_type
457 tuple_max_diff(std::index_sequence<I...>, const data_type& lhs, const data_type& rhs) const;
458
459 const xfunction_type* p_f;
460 data_type m_it;
461 };
462
463 template <class F, class... CT>
465
466 template <class F, class... CT>
468
469 /*********************
470 * xfunction_stepper *
471 *********************/
472
473 template <class F, class... CT>
475 {
476 public:
477
478 using self_type = xfunction_stepper<F, CT...>;
479 using functor_type = typename std::remove_reference<F>::type;
480 using xfunction_type = xfunction<F, CT...>;
481
482 using value_type = typename xfunction_type::value_type;
483 using reference = typename xfunction_type::reference;
484 using pointer = typename xfunction_type::const_pointer;
485 using size_type = typename xfunction_type::size_type;
486 using difference_type = typename xfunction_type::difference_type;
487
488 using shape_type = typename xfunction_type::shape_type;
489
490 template <class requested_type>
491 using simd_return_type = xt_simd::simd_return_type<value_type, requested_type>;
492
493 template <class... St>
494 xfunction_stepper(const xfunction_type* func, St&&... st) noexcept;
495
496 void step(size_type dim);
497 void step_back(size_type dim);
498 void step(size_type dim, size_type n);
499 void step_back(size_type dim, size_type n);
500 void reset(size_type dim);
501 void reset_back(size_type dim);
502
503 void to_begin();
504 void to_end(layout_type l);
505
506 reference operator*() const;
507
508 template <class T>
509 simd_return_type<T> step_simd();
510
511 void step_leading();
512
513 private:
514
515 template <std::size_t... I>
516 reference deref_impl(std::index_sequence<I...>) const;
517
518 template <class T, std::size_t... I>
519 simd_return_type<T> step_simd_impl(std::index_sequence<I...>);
520
521 const xfunction_type* p_f;
522 std::tuple<typename std::decay_t<CT>::const_stepper...> m_st;
523 };
524
525 /*********************************
526 * xfunction implementation *
527 *********************************/
528
539 template <class F, class... CT>
540 template <class Func, class... CTA, class U>
542 : m_e(std::forward<CTA>(e)...)
543 , m_f(std::forward<Func>(f))
544 {
545 }
546
552 template <class F, class... CT>
553 template <class FA, class... CTA>
555 : m_e(xf.arguments())
556 , m_f(xf.functor())
557 {
558 }
559
561
569 template <class F, class... CT>
571 {
572 size_type dimension = m_cache.is_initialized ? m_cache.shape.size() : compute_dimension();
573 return dimension;
574 }
575
576 template <class F, class... CT>
578 {
579 static_assert(!detail::is_fixed<shape_type>::value, "Calling compute_cached_shape on fixed!");
580
581 m_cache.shape = uninitialized_shape<xindex_type_t<inner_shape_type>>(compute_dimension());
582 m_cache.is_trivial = broadcast_shape(m_cache.shape, false);
583 m_cache.is_initialized = true;
584 }
585
589 template <class F, class... CT>
590 inline auto xfunction<F, CT...>::shape() const -> const inner_shape_type&
591 {
592 xtl::mpl::static_if<!detail::is_fixed<inner_shape_type>::value>(
593 [&](auto self)
594 {
595 if (!m_cache.is_initialized)
596 {
597 self(this)->compute_cached_shape();
598 }
599 },
600 [](auto /*self*/) {}
601 );
602 return m_cache.shape;
603 }
604
608 template <class F, class... CT>
610 {
611 return layout_impl(std::make_index_sequence<sizeof...(CT)>());
612 }
613
614 template <class F, class... CT>
616 {
617 return layout() != layout_type::dynamic
618 && accumulate(
619 [](bool r, const auto& exp)
620 {
621 return r && exp.is_contiguous();
622 },
623 true,
624 m_e
625 );
626 }
627
629
640 template <class F, class... CT>
641 template <class... Args>
642 inline auto xfunction<F, CT...>::operator()(Args... args) const -> const_reference
643 {
644 // The static cast prevents the compiler from instantiating the template methods with signed integers,
645 // leading to warning about signed/unsigned conversions in the deeper layers of the access methods
646 return access_impl(std::make_index_sequence<sizeof...(CT)>(), static_cast<size_type>(args)...);
647 }
648
658 template <class F, class... CT>
659 inline auto xfunction<F, CT...>::flat(size_type index) const -> const_reference
660 {
661 return data_element_impl(std::make_index_sequence<sizeof...(CT)>(), index);
662 }
663
/**
 * Returns the element at the specified position without bounds or
 * dimension checks: each operand is accessed through its own unchecked().
 */
template <class F, class... CT>
template <class... Args>
inline auto xfunction<F, CT...>::unchecked(Args... args) const -> const_reference
{
    // Cast to size_type so deeper layers never see signed indices
    // (avoids signed/unsigned conversion warnings).
    using operand_indices = std::make_index_sequence<sizeof...(CT)>;
    return unchecked_impl(operand_indices{}, static_cast<size_type>(args)...);
}
691
699 template <class F, class... CT>
700 template <class It>
701 inline auto xfunction<F, CT...>::element(It first, It last) const -> const_reference
702 {
703 return element_access_impl(std::make_index_sequence<sizeof...(CT)>(), first, last);
704 }
705
707
718 template <class F, class... CT>
719 template <class S>
720 inline bool xfunction<F, CT...>::broadcast_shape(S& shape, bool reuse_cache) const
721 {
722 if (m_cache.is_initialized && reuse_cache)
723 {
724 std::copy(m_cache.shape.cbegin(), m_cache.shape.cend(), shape.begin());
725 return m_cache.is_trivial;
726 }
727 else
728 {
729 // e.broadcast_shape must be evaluated even if b is false
730 auto func = [&shape](bool b, auto&& e)
731 {
732 return e.broadcast_shape(shape) && b;
733 };
734 return accumulate(func, true, m_e);
735 }
736 }
737
743 template <class F, class... CT>
744 template <class S>
745 inline bool xfunction<F, CT...>::has_linear_assign(const S& strides) const noexcept
746 {
747 auto func = [&strides](bool b, auto&& e)
748 {
749 return b && e.has_linear_assign(strides);
750 };
751 return accumulate(func, true, m_e);
752 }
753
755
756 template <class F, class... CT>
757 inline auto xfunction<F, CT...>::linear_begin() const noexcept -> const_linear_iterator
758 {
759 return linear_cbegin();
760 }
761
762 template <class F, class... CT>
763 inline auto xfunction<F, CT...>::linear_end() const noexcept -> const_linear_iterator
764 {
765 return linear_cend();
766 }
767
768 template <class F, class... CT>
769 inline auto xfunction<F, CT...>::linear_cbegin() const noexcept -> const_linear_iterator
770 {
771 auto f = [](const auto& e) noexcept
772 {
773 return xt::linear_begin(e);
774 };
775 return build_iterator(f, std::make_index_sequence<sizeof...(CT)>());
776 }
777
778 template <class F, class... CT>
779 inline auto xfunction<F, CT...>::linear_cend() const noexcept -> const_linear_iterator
780 {
781 auto f = [](const auto& e) noexcept
782 {
783 return xt::linear_end(e);
784 };
785 return build_iterator(f, std::make_index_sequence<sizeof...(CT)>());
786 }
787
788 template <class F, class... CT>
789 inline auto xfunction<F, CT...>::linear_rbegin() const noexcept -> const_reverse_linear_iterator
790 {
791 return linear_crbegin();
792 }
793
794 template <class F, class... CT>
795 inline auto xfunction<F, CT...>::linear_rend() const noexcept -> const_reverse_linear_iterator
796 {
797 return linear_crend();
798 }
799
800 template <class F, class... CT>
801 inline auto xfunction<F, CT...>::linear_crbegin() const noexcept -> const_reverse_linear_iterator
802 {
803 return const_reverse_linear_iterator(linear_cend());
804 }
805
806 template <class F, class... CT>
807 inline auto xfunction<F, CT...>::linear_crend() const noexcept -> const_reverse_linear_iterator
808 {
809 return const_reverse_linear_iterator(linear_cbegin());
810 }
811
812 template <class F, class... CT>
813 template <class S>
814 inline auto xfunction<F, CT...>::stepper_begin(const S& shape) const noexcept -> const_stepper
815 {
816 auto f = [&shape](const auto& e) noexcept
817 {
818 return e.stepper_begin(shape);
819 };
820 return build_stepper(f, std::make_index_sequence<sizeof...(CT)>());
821 }
822
823 template <class F, class... CT>
824 template <class S>
825 inline auto xfunction<F, CT...>::stepper_end(const S& shape, layout_type l) const noexcept -> const_stepper
826 {
827 auto f = [&shape, l](const auto& e) noexcept
828 {
829 return e.stepper_end(shape, l);
830 };
831 return build_stepper(f, std::make_index_sequence<sizeof...(CT)>());
832 }
833
834 template <class F, class... CT>
835 inline auto xfunction<F, CT...>::data_element(size_type i) const -> const_reference
836 {
837 return data_element_impl(std::make_index_sequence<sizeof...(CT)>(), i);
838 }
839
840 template <class F, class... CT>
841 template <class UT, class>
842 inline xfunction<F, CT...>::operator value_type() const
843 {
844 return operator()();
845 }
846
847 template <class F, class... CT>
848 template <class align, class requested_type, std::size_t N>
849 inline auto xfunction<F, CT...>::load_simd(size_type i) const -> simd_return_type<requested_type>
850 {
851 return load_simd_impl<align, requested_type, N>(std::make_index_sequence<sizeof...(CT)>(), i);
852 }
853
854 template <class F, class... CT>
855 inline auto xfunction<F, CT...>::arguments() const noexcept -> const tuple_type&
856 {
857 return m_e;
858 }
859
860 template <class F, class... CT>
861 inline auto xfunction<F, CT...>::functor() const noexcept -> const functor_type&
862 {
863 return m_f;
864 }
865
866 template <class F, class... CT>
867 template <std::size_t... I>
868 inline layout_type xfunction<F, CT...>::layout_impl(std::index_sequence<I...>) const noexcept
869 {
870 return compute_layout(std::get<I>(m_e).layout()...);
871 }
872
873 template <class F, class... CT>
874 template <std::size_t... I, class... Args>
875 inline auto xfunction<F, CT...>::access_impl(std::index_sequence<I...>, Args... args) const
876 -> const_reference
877 {
878 XTENSOR_TRY(check_index(shape(), args...));
879 XTENSOR_CHECK_DIMENSION(shape(), args...);
880 return m_f(std::get<I>(m_e)(args...)...);
881 }
882
883 template <class F, class... CT>
884 template <std::size_t... I, class... Args>
885 inline auto xfunction<F, CT...>::unchecked_impl(std::index_sequence<I...>, Args... args) const
886 -> const_reference
887 {
888 return m_f(std::get<I>(m_e).unchecked(args...)...);
889 }
890
891 template <class F, class... CT>
892 template <std::size_t... I, class It>
893 inline auto xfunction<F, CT...>::element_access_impl(std::index_sequence<I...>, It first, It last) const
894 -> const_reference
895 {
896 XTENSOR_TRY(check_element_index(shape(), first, last));
897 return m_f((std::get<I>(m_e).element(first, last))...);
898 }
899
900 template <class F, class... CT>
901 template <std::size_t... I>
902 inline auto xfunction<F, CT...>::data_element_impl(std::index_sequence<I...>, size_type i) const
903 -> const_reference
904 {
905 return m_f((std::get<I>(m_e).data_element(i))...);
906 }
907
908 template <class F, class... CT>
909 template <class align, class requested_type, std::size_t N, std::size_t... I>
910 inline auto xfunction<F, CT...>::load_simd_impl(std::index_sequence<I...>, size_type i) const
911 {
912 return m_f.simd_apply((std::get<I>(m_e).template load_simd<align, requested_type>(i))...);
913 }
914
915 template <class F, class... CT>
916 template <class Func, std::size_t... I>
917 inline auto xfunction<F, CT...>::build_stepper(Func&& f, std::index_sequence<I...>) const noexcept
918 -> const_stepper
919 {
920 return const_stepper(this, f(std::get<I>(m_e))...);
921 }
922
923 template <class F, class... CT>
924 template <class Func, std::size_t... I>
925 inline auto xfunction<F, CT...>::build_iterator(Func&& f, std::index_sequence<I...>) const noexcept
926 {
927 return const_linear_iterator(this, f(std::get<I>(m_e))...);
928 }
929
930 template <class F, class... CT>
931 inline auto xfunction<F, CT...>::compute_dimension() const noexcept -> size_type
932 {
933 auto func = [](size_type d, auto&& e) noexcept
934 {
935 return (std::max)(d, e.dimension());
936 };
937 return accumulate(func, size_type(0), m_e);
938 }
939
940 /*************************************
941 * xfunction_iterator implementation *
942 *************************************/
943
944 template <class F, class... CT>
945 template <class... It>
946 inline xfunction_iterator<F, CT...>::xfunction_iterator(const xfunction_type* func, It&&... it) noexcept
947 : p_f(func)
948 , m_it(std::forward<It>(it)...)
949 {
950 }
951
952 template <class F, class... CT>
953 inline auto xfunction_iterator<F, CT...>::operator++() -> self_type&
954 {
955 auto f = [](auto& it)
956 {
957 ++it;
958 };
959 for_each(f, m_it);
960 return *this;
961 }
962
963 template <class F, class... CT>
964 inline auto xfunction_iterator<F, CT...>::operator--() -> self_type&
965 {
966 auto f = [](auto& it)
967 {
968 return --it;
969 };
970 for_each(f, m_it);
971 return *this;
972 }
973
974 template <class F, class... CT>
975 inline auto xfunction_iterator<F, CT...>::operator+=(difference_type n) -> self_type&
976 {
977 auto f = [n](auto& it)
978 {
979 it += n;
980 };
981 for_each(f, m_it);
982 return *this;
983 }
984
985 template <class F, class... CT>
986 inline auto xfunction_iterator<F, CT...>::operator-=(difference_type n) -> self_type&
987 {
988 auto f = [n](auto& it)
989 {
990 it -= n;
991 };
992 for_each(f, m_it);
993 return *this;
994 }
995
996 template <class F, class... CT>
997 inline auto xfunction_iterator<F, CT...>::operator-(const self_type& rhs) const -> difference_type
998 {
999 return tuple_max_diff(std::make_index_sequence<sizeof...(CT)>(), m_it, rhs.m_it);
1000 }
1001
1002 template <class F, class... CT>
1003 inline auto xfunction_iterator<F, CT...>::operator*() const -> reference
1004 {
1005 return deref_impl(std::make_index_sequence<sizeof...(CT)>());
1006 }
1007
1008 template <class F, class... CT>
1009 inline bool xfunction_iterator<F, CT...>::equal(const self_type& rhs) const
1010 {
1011 // Optimization: no need to compare each subiterator since they all
1012 // are incremented decremented together.
1013 constexpr std::size_t temp = xtl::mpl::find_if<is_not_xdummy_iterator, data_type>::value;
1014 constexpr std::size_t index = (temp == std::tuple_size<data_type>::value) ? 0 : temp;
1015 return std::get<index>(m_it) == std::get<index>(rhs.m_it);
1016 }
1017
1018 template <class F, class... CT>
1019 inline bool xfunction_iterator<F, CT...>::less_than(const self_type& rhs) const
1020 {
1021 // Optimization: no need to compare each subiterator since they all
1022 // are incremented decremented together.
1023 constexpr std::size_t temp = xtl::mpl::find_if<is_not_xdummy_iterator, data_type>::value;
1024 constexpr std::size_t index = (temp == std::tuple_size<data_type>::value) ? 0 : temp;
1025 return std::get<index>(m_it) < std::get<index>(rhs.m_it);
1026 }
1027
1028 template <class F, class... CT>
1029 template <std::size_t... I>
1030 inline auto xfunction_iterator<F, CT...>::deref_impl(std::index_sequence<I...>) const -> reference
1031 {
1032 return (p_f->m_f)(*std::get<I>(m_it)...);
1033 }
1034
1035 template <class F, class... CT>
1036 template <std::size_t... I>
1037 inline auto xfunction_iterator<F, CT...>::tuple_max_diff(
1038 std::index_sequence<I...>,
1039 const data_type& lhs,
1040 const data_type& rhs
1041 ) const -> difference_type
1042 {
1043 auto diff = std::make_tuple((std::get<I>(lhs) - std::get<I>(rhs))...);
1044 auto func = [](difference_type n, auto&& v)
1045 {
1046 return (std::max)(n, v);
1047 };
1048 return accumulate(func, difference_type(0), diff);
1049 }
1050
1051 template <class F, class... CT>
1052 inline bool operator==(const xfunction_iterator<F, CT...>& it1, const xfunction_iterator<F, CT...>& it2)
1053 {
1054 return it1.equal(it2);
1055 }
1056
1057 template <class F, class... CT>
1058 inline bool operator<(const xfunction_iterator<F, CT...>& it1, const xfunction_iterator<F, CT...>& it2)
1059 {
1060 return it1.less_than(it2);
1061 }
1062
1063 /************************************
1064 * xfunction_stepper implementation *
1065 ************************************/
1066
1067 template <class F, class... CT>
1068 template <class... St>
1069 inline xfunction_stepper<F, CT...>::xfunction_stepper(const xfunction_type* func, St&&... st) noexcept
1070 : p_f(func)
1071 , m_st(std::forward<St>(st)...)
1072 {
1073 }
1074
1075 template <class F, class... CT>
1076 inline void xfunction_stepper<F, CT...>::step(size_type dim)
1077 {
1078 auto f = [dim](auto& st)
1079 {
1080 st.step(dim);
1081 };
1082 for_each(f, m_st);
1083 }
1084
1085 template <class F, class... CT>
1086 inline void xfunction_stepper<F, CT...>::step_back(size_type dim)
1087 {
1088 auto f = [dim](auto& st)
1089 {
1090 st.step_back(dim);
1091 };
1092 for_each(f, m_st);
1093 }
1094
1095 template <class F, class... CT>
1096 inline void xfunction_stepper<F, CT...>::step(size_type dim, size_type n)
1097 {
1098 auto f = [dim, n](auto& st)
1099 {
1100 st.step(dim, n);
1101 };
1102 for_each(f, m_st);
1103 }
1104
1105 template <class F, class... CT>
1106 inline void xfunction_stepper<F, CT...>::step_back(size_type dim, size_type n)
1107 {
1108 auto f = [dim, n](auto& st)
1109 {
1110 st.step_back(dim, n);
1111 };
1112 for_each(f, m_st);
1113 }
1114
1115 template <class F, class... CT>
1116 inline void xfunction_stepper<F, CT...>::reset(size_type dim)
1117 {
1118 auto f = [dim](auto& st)
1119 {
1120 st.reset(dim);
1121 };
1122 for_each(f, m_st);
1123 }
1124
1125 template <class F, class... CT>
1126 inline void xfunction_stepper<F, CT...>::reset_back(size_type dim)
1127 {
1128 auto f = [dim](auto& st)
1129 {
1130 st.reset_back(dim);
1131 };
1132 for_each(f, m_st);
1133 }
1134
1135 template <class F, class... CT>
1136 inline void xfunction_stepper<F, CT...>::to_begin()
1137 {
1138 auto f = [](auto& st)
1139 {
1140 st.to_begin();
1141 };
1142 for_each(f, m_st);
1143 }
1144
1145 template <class F, class... CT>
1146 inline void xfunction_stepper<F, CT...>::to_end(layout_type l)
1147 {
1148 auto f = [l](auto& st)
1149 {
1150 st.to_end(l);
1151 };
1152 for_each(f, m_st);
1153 }
1154
1155 template <class F, class... CT>
1156 inline auto xfunction_stepper<F, CT...>::operator*() const -> reference
1157 {
1158 return deref_impl(std::make_index_sequence<sizeof...(CT)>());
1159 }
1160
1161 template <class F, class... CT>
1162 template <std::size_t... I>
1163 inline auto xfunction_stepper<F, CT...>::deref_impl(std::index_sequence<I...>) const -> reference
1164 {
1165 return (p_f->m_f)(*std::get<I>(m_st)...);
1166 }
1167
1168 template <class F, class... CT>
1169 template <class T, std::size_t... I>
1170 inline auto xfunction_stepper<F, CT...>::step_simd_impl(std::index_sequence<I...>) -> simd_return_type<T>
1171 {
1172 return (p_f->m_f.simd_apply)(std::get<I>(m_st).template step_simd<T>()...);
1173 }
1174
1175 template <class F, class... CT>
1176 template <class T>
1177 inline auto xfunction_stepper<F, CT...>::step_simd() -> simd_return_type<T>
1178 {
1179 return step_simd_impl<T>(std::make_index_sequence<sizeof...(CT)>());
1180 }
1181
1182 template <class F, class... CT>
1183 inline void xfunction_stepper<F, CT...>::step_leading()
1184 {
1185 auto step_leading_lambda = [](auto&& st)
1186 {
1187 st.step_leading();
1188 };
1189 for_each(step_leading_lambda, m_st);
1190 }
1191}
1192
1193#endif
Base class for implementation of common expression constant access methods.
const_reference front() const
Returns a constant reference to the first element of the expression.
size_type size() const noexcept
Returns the size of the expression.
bool in_bounds(Args... args) const
Returns true only if the specified position is a valid entry in the expression.
const_reference back() const
Returns a constant reference to the last element of the expression.
size_type shape(size_type index) const
Returns the i-th dimension of the expression.
Base class for multidimensional iterable constant expressions.
Definition xiterable.hpp:37
Multidimensional function operating on xtensor expressions.
bool broadcast_shape(S &shape, bool reuse_cache=false) const
Broadcast the shape of the function to the specified parameter.
layout_type layout() const noexcept
Returns the layout_type of the xfunction.
const inner_shape_type & shape() const
Returns the shape of the xfunction.
size_type dimension() const noexcept
Returns the number of dimensions of the function.
xfunction(Func &&f, CTA &&... e) noexcept
Constructs an xfunction applying the specified function to the given arguments.
bool has_linear_assign(const S &strides) const noexcept
Checks whether the xfunction can be linearly assigned to an expression with the specified strides.
const_reference flat(size_type i) const
Returns a constant reference to the element at the specified position of the underlying contiguous storage.
auto diff(const xexpression< T > &a, std::size_t n=1, std::ptrdiff_t axis=-1)
Calculate the n-th discrete difference along the given axis.
Definition xmath.hpp:2911
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Definition xstrides.hpp:248
standard mathematical functions for xexpressions
constexpr layout_type compute_layout(Args... args) noexcept
Implementation of the following logical table:
Definition xlayout.hpp:88
bool operator==(const xaxis_iterator< CT > &lhs, const xaxis_iterator< CT > &rhs)
Checks equality of the iterators.
layout_type
Definition xlayout.hpp:24
auto accumulate(F &&f, E &&e, EVS evaluation_strategy=EVS())
Accumulate and flatten array. NOTE: this function is not lazy!