10#ifndef XTENSOR_OPERATION_HPP
11#define XTENSOR_OPERATION_HPP
17#include <xtl/xsequence.hpp>
19#include "xfunction.hpp"
21#include "xstrided_view.hpp"
22#include "xstrides.hpp"
// Helper macros used to generate element-wise operator functors.
// NOTE(review): this is an extraction of the header; several lines of each
// macro body (braces, template headers, struct names) are missing here.
//
// UNARY_OPERATOR_FUNCTOR(NAME, OP): the visible lines declare a constexpr
//   operator()(const A1& arg) const and a constexpr simd_apply(const B& arg)
//   const overload.
//
// DEFINE_COMPLEX_OVERLOAD(OP): defines operator OP for three mixed-type
//   pairings: std::complex<T1> with std::complex<T2>, T1 with
//   std::complex<T2>, and std::complex<T1> with T2. Each is constrained with
//   XTL_REQUIRES to T1 != T2, and both operands are converted to
//   xtl::promote_type_t of the two operand types before OP is applied.
//
// BINARY_OPERATOR_FUNCTOR(NAME, OP): the visible lines declare a constexpr
//   operator()(T1&&, T2&&) const. It first brings xt::detail::operator OP into
//   scope (presumably so the mixed-type std::complex overloads above are
//   found) and then applies OP to the perfectly-forwarded arguments. It also
//   declares simd_apply(const B&, const B&) const, which returns (arg1 OP arg2).
31#define UNARY_OPERATOR_FUNCTOR(NAME, OP) \
35 constexpr auto operator()(const A1& arg) const \
40 constexpr auto simd_apply(const B& arg) const \
46#define DEFINE_COMPLEX_OVERLOAD(OP) \
47 template <class T1, class T2, XTL_REQUIRES(xtl::negation<std::is_same<T1, T2>>)> \
48 constexpr auto operator OP(const std::complex<T1>& arg1, const std::complex<T2>& arg2) \
50 using result_type = typename xtl::promote_type_t<std::complex<T1>, std::complex<T2>>; \
51 return (result_type(arg1) OP result_type(arg2)); \
54 template <class T1, class T2, XTL_REQUIRES(xtl::negation<std::is_same<T1, T2>>)> \
55 constexpr auto operator OP(const T1& arg1, const std::complex<T2>& arg2) \
57 using result_type = typename xtl::promote_type_t<T1, std::complex<T2>>; \
58 return (result_type(arg1) OP result_type(arg2)); \
61 template <class T1, class T2, XTL_REQUIRES(xtl::negation<std::is_same<T1, T2>>)> \
62 constexpr auto operator OP(const std::complex<T1>& arg1, const T2& arg2) \
64 using result_type = typename xtl::promote_type_t<std::complex<T1>, T2>; \
65 return (result_type(arg1) OP result_type(arg2)); \
68#define BINARY_OPERATOR_FUNCTOR(NAME, OP) \
71 template <class T1, class T2> \
72 constexpr auto operator()(T1&& arg1, T2&& arg2) const \
74 using xt::detail::operator OP; \
75 return (std::forward<T1>(arg1) OP std::forward<T2>(arg2)); \
78 constexpr auto simd_apply(const B& arg1, const B& arg2) const \
80 return (arg1 OP arg2); \
// Instantiate the mixed-type std::complex overloads for every operator that
// the functors below may apply.
86 DEFINE_COMPLEX_OVERLOAD(+);
87 DEFINE_COMPLEX_OVERLOAD(-);
88 DEFINE_COMPLEX_OVERLOAD(*);
89 DEFINE_COMPLEX_OVERLOAD(/);
90 DEFINE_COMPLEX_OVERLOAD(%);
91 DEFINE_COMPLEX_OVERLOAD(||);
92 DEFINE_COMPLEX_OVERLOAD(&&);
93 DEFINE_COMPLEX_OVERLOAD(|);
94 DEFINE_COMPLEX_OVERLOAD(&);
95 DEFINE_COMPLEX_OVERLOAD(^);
96 DEFINE_COMPLEX_OVERLOAD(<<);
97 DEFINE_COMPLEX_OVERLOAD(>>);
98 DEFINE_COMPLEX_OVERLOAD(<);
99 DEFINE_COMPLEX_OVERLOAD(<=);
100 DEFINE_COMPLEX_OVERLOAD(>);
101 DEFINE_COMPLEX_OVERLOAD(>=);
102 DEFINE_COMPLEX_OVERLOAD(==);
103 DEFINE_COMPLEX_OVERLOAD(!=);
// Element-wise functor types. Each NAME is referenced as detail::NAME by the
// xexpression operator overloads later in this file.
// NOTE(review): detail::less_equal, greater_equal, left_shift and right_shift
// are used later, but their instantiations are not among these lines.
// They are presumably on the original lines 119-124 that this extraction
// dropped; confirm.
105 UNARY_OPERATOR_FUNCTOR(identity, +);
106 UNARY_OPERATOR_FUNCTOR(negate, -);
107 BINARY_OPERATOR_FUNCTOR(plus, +);
108 BINARY_OPERATOR_FUNCTOR(minus, -);
109 BINARY_OPERATOR_FUNCTOR(multiplies, *);
110 BINARY_OPERATOR_FUNCTOR(divides, /);
111 BINARY_OPERATOR_FUNCTOR(modulus, %);
112 BINARY_OPERATOR_FUNCTOR(logical_or, ||);
113 BINARY_OPERATOR_FUNCTOR(logical_and, &&);
114 UNARY_OPERATOR_FUNCTOR(logical_not, !);
115 BINARY_OPERATOR_FUNCTOR(bitwise_or, |);
116 BINARY_OPERATOR_FUNCTOR(bitwise_and, &);
117 BINARY_OPERATOR_FUNCTOR(bitwise_xor, ^);
118 UNARY_OPERATOR_FUNCTOR(bitwise_not, ~);
121 BINARY_OPERATOR_FUNCTOR(
less, <);
123 BINARY_OPERATOR_FUNCTOR(
greater, >);
125 BINARY_OPERATOR_FUNCTOR(equal_to, ==);
126 BINARY_OPERATOR_FUNCTOR(not_equal_to, !=);
// Functor behind element-wise ternary selection.
// - operator()(cond, v1, v2) returns xtl::select(cond, v1, v2) on scalar values.
// - simd_apply takes a batch mask (get_batch_bool<B>) and two batches of type
//   B, and returns xt_simd::select(t1, t2, t3).
// Both overloads are declared noexcept. (The braces are missing from this extraction.)
128 struct conditional_ternary
133 template <
class B,
class A1,
class A2>
134 constexpr auto operator()(
const B& cond,
const A1& v1,
const A2& v2)
const noexcept
136 return xtl::select(cond, v1, v2);
140 constexpr B simd_apply(
const get_batch_bool<B>& t1,
const B& t2,
const B& t3)
const noexcept
142 return xt_simd::select(t1, t2, t3);
// Fragment of the element-wise cast functor: result_type is R, and
// operator()(arg) returns static_cast<R>(arg).
// NOTE(review): the enclosing struct/template headers are not visible; this is
// presumably detail::cast<R>::functor, as used by cast<R>() later. Confirm.
151 using result_type = R;
154 constexpr result_type operator()(
const A1& arg)
const
156 return static_cast<R
>(
arg);
// Maps an expression tag to the lazy-expression type to build.
// Both visible specializations, xtensor_expression_tag and
// xoptional_expression_tag, yield xfunction<F, E...>.
168 template <
class Tag,
class F,
class... E>
169 struct select_xfunction_expression;
171 template <
class F,
class... E>
172 struct select_xfunction_expression<xtensor_expression_tag, F, E...>
174 using type = xfunction<F, E...>;
177 template <
class F,
class... E>
178 struct select_xfunction_expression<xoptional_expression_tag, F, E...>
180 using type = xfunction<F, E...>;
183 template <
class Tag,
class F,
class... E>
184 using select_xfunction_expression_t =
typename select_xfunction_expression<Tag, F, E...>::type;
// Computes three things for functor F applied to operands E...:
// - expression_tag: the combined tag of the operands (xexpression_tag_t);
// - functor_type: F itself;
// - type: the xfunction type, which stores each operand as const_xclosure_t<E>.
186 template <
class F,
class... E>
187 struct xfunction_type
189 using expression_tag = xexpression_tag_t<E...>;
190 using functor_type = F;
191 using type = select_xfunction_expression_t<expression_tag, functor_type, const_xclosure_t<E>...>;
// Builds the xfunction for functor F. The functor is default-constructed and
// the operands are perfectly forwarded into the xfunction constructor.
// Nothing is computed element-wise here; only the expression object is built.
194 template <
class F,
class... E>
195 inline auto make_xfunction(E&&... e)
noexcept
197 using function_type = xfunction_type<F, E...>;
198 using functor_type =
typename function_type::functor_type;
199 using type =
typename function_type::type;
200 return type(functor_type(), std::forward<E>(e)...);
// SFINAE-constrained alias: it names a type only when at least one decayed
// operand is an xexpression (has_xexpression). Used as the trailing return
// type of the operators below, it removes them from overload resolution
// for plain scalar operands.
207 template <
class F,
class... E>
208 using xfunction_type_t =
typename std::
209 enable_if_t<has_xexpression<std::decay_t<E>...>::value, xfunction_type<F, E...>>::type;
// Keep the functor-generating macros local to this header.
// NOTE(review): DEFINE_COMPLEX_OVERLOAD is not #undef'd in the visible lines.
212#undef UNARY_OPERATOR_FUNCTOR
213#undef BINARY_OPERATOR_FUNCTOR
// Unary element-wise operators on an xexpression:
// - operator+ returns an xfunction applying detail::identity;
// - operator- returns an xfunction applying detail::negate.
// Both are enabled only when E is an xexpression (via xfunction_type_t).
// (Their template headers are missing from this extraction.)
233 inline auto operator+(E&&
e)
noexcept -> detail::xfunction_type_t<detail::identity, E>
235 return detail::make_xfunction<detail::identity>(std::forward<E>(
e));
248 inline auto operator-(E&&
e)
noexcept -> detail::xfunction_type_t<detail::negate, E>
250 return detail::make_xfunction<detail::negate>(std::forward<E>(
e));
// Binary element-wise arithmetic operators: +, -, *, / and %.
// Each forwards both operands to make_xfunction with the matching functor
// (plus, minus, multiplies, divides, modulus) and returns the resulting
// xfunction. No element is computed by these calls.
// The trailing return type xfunction_type_t removes each overload unless at
// least one operand is an xexpression.
263 template <
class E1,
class E2>
264 inline auto operator+(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::plus, E1, E2>
266 return detail::make_xfunction<detail::plus>(std::forward<E1>(
e1), std::forward<E2>(
e2));
279 template <
class E1,
class E2>
280 inline auto operator-(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::minus, E1, E2>
282 return detail::make_xfunction<detail::minus>(std::forward<E1>(
e1), std::forward<E2>(
e2));
295 template <
class E1,
class E2>
296 inline auto operator*(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::multiplies, E1, E2>
298 return detail::make_xfunction<detail::multiplies>(std::forward<E1>(
e1), std::forward<E2>(
e2));
311 template <
class E1,
class E2>
312 inline auto operator/(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::divides, E1, E2>
314 return detail::make_xfunction<detail::divides>(std::forward<E1>(
e1), std::forward<E2>(
e2));
327 template <
class E1,
class E2>
328 inline auto operator%(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::modulus, E1, E2>
330 return detail::make_xfunction<detail::modulus>(std::forward<E1>(
e1), std::forward<E2>(
e2));
// Element-wise logical operators ||, && and ! on xexpressions. They return
// xfunctions applying detail::logical_or, logical_and and logical_not.
// Because these are overloaded operators (ordinary function calls), both
// operand expressions are always evaluated; overloaded || and && never
// short-circuit.
347 template <
class E1,
class E2>
348 inline auto operator||(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::logical_or, E1, E2>
350 return detail::make_xfunction<detail::logical_or>(std::forward<E1>(
e1), std::forward<E2>(
e2));
363 template <
class E1,
class E2>
364 inline auto operator&&(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::logical_and, E1, E2>
366 return detail::make_xfunction<detail::logical_and>(std::forward<E1>(
e1), std::forward<E2>(
e2));
// (The template header of operator! is missing from this extraction.)
379 inline auto operator!(E&&
e)
noexcept -> detail::xfunction_type_t<detail::logical_not, E>
381 return detail::make_xfunction<detail::logical_not>(std::forward<E>(
e));
// Element-wise bitwise operators &, | and ^, plus unary ~. Each returns an
// xfunction applying the functor of the same name (bitwise_and, bitwise_or,
// bitwise_xor, bitwise_not).
398 template <
class E1,
class E2>
399 inline auto operator&(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::bitwise_and, E1, E2>
401 return detail::make_xfunction<detail::bitwise_and>(std::forward<E1>(
e1), std::forward<E2>(
e2));
414 template <
class E1,
class E2>
415 inline auto operator|(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::bitwise_or, E1, E2>
417 return detail::make_xfunction<detail::bitwise_or>(std::forward<E1>(
e1), std::forward<E2>(
e2));
430 template <
class E1,
class E2>
431 inline auto operator^(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::bitwise_xor, E1, E2>
433 return detail::make_xfunction<detail::bitwise_xor>(std::forward<E1>(
e1), std::forward<E2>(
e2));
// (The template header of operator~ is missing from this extraction.)
446 inline auto operator~(E&&
e)
noexcept -> detail::xfunction_type_t<detail::bitwise_not, E>
448 return detail::make_xfunction<detail::bitwise_not>(std::forward<E>(
e));
// Named element-wise shifts. left_shift applies detail::left_shift and
// right_shift applies detail::right_shift. Like the operators above, they are
// enabled when at least one operand is an xexpression.
461 template <
class E1,
class E2>
462 inline auto left_shift(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::left_shift, E1, E2>
464 return detail::make_xfunction<detail::left_shift>(std::forward<E1>(
e1), std::forward<E2>(
e2));
477 template <
class E1,
class E2>
478 inline auto right_shift(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::right_shift, E1, E2>
480 return detail::make_xfunction<detail::right_shift>(std::forward<E1>(
e1), std::forward<E2>(
e2));
// Return-type helpers for the shift operators (operator<< / operator>>).
// - shift_function_getter<F, E1, E2>::type is xfunction_type_t<F, E1, E2>.
// - eval_enable_if<B, T>::type is T::type in the primary template. The body of
//   the <false, T> specialization is not visible; presumably it has no `type`
//   member, which makes the substitution fail.
// - shift_return_type_t therefore names a type only when the decayed LEFT
//   operand E1 is an xexpression. Unlike xfunction_type_t, a scalar or stream
//   on the left does not enable these overloads. This is presumably to keep
//   `std::ostream << expr` resolving to the stream-insertion operator.
//   NOTE(review): that rationale is inferred; confirm.
487 template <
class F,
class E1,
class E2>
488 struct shift_function_getter
490 using type = xfunction_type_t<F, E1, E2>;
493 template <
bool B,
class T>
494 struct eval_enable_if
496 using type =
typename T::type;
500 struct eval_enable_if<false, T>
504 template <
bool B,
class T>
505 using eval_enable_if_t =
typename eval_enable_if<B, T>::type;
507 template <
class F,
class E1,
class E2>
508 using shift_return_type_t = eval_enable_if_t<
509 is_xexpression<std::decay_t<E1>>::value,
510 shift_function_getter<F, E1, E2>>;
// The two template headers below belong to the operator<< / operator>>
// overloads. Their declarations and bodies are missing from this extraction.
524 template <
class E1,
class E2>
542 template <
class E1,
class E2>
// Element-wise ordering comparisons <, <=, > and >=. Each returns an
// xfunction applying detail::less, less_equal, greater or greater_equal.
// They are enabled only when at least one operand is an xexpression.
// (The declarator line of operator>= is missing from this extraction.)
562 template <
class E1,
class E2>
563 inline auto operator<(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::less, E1, E2>
565 return detail::make_xfunction<detail::less>(std::forward<E1>(
e1), std::forward<E2>(
e2));
578 template <
class E1,
class E2>
579 inline auto operator<=(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::less_equal, E1, E2>
581 return detail::make_xfunction<detail::less_equal>(std::forward<E1>(
e1), std::forward<E2>(
e2));
594 template <
class E1,
class E2>
595 inline auto operator>(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::greater, E1, E2>
597 return detail::make_xfunction<detail::greater>(std::forward<E1>(
e1), std::forward<E2>(
e2));
610 template <
class E1,
class E2>
612 -> detail::xfunction_type_t<detail::greater_equal, E1, E2>
614 return detail::make_xfunction<detail::greater_equal>(std::forward<E1>(
e1), std::forward<E2>(
e2));
// Fragment of a whole-expression comparison that returns a single bool. It is
// enabled by xoptional_comparable<E1, E2>; its name is not visible, so it is
// presumably operator== (confirm).
// The visible part first requires equal dimension and identical shapes.
// std::equal with three iterators is safe here: && short-circuits, so it only
// runs when the dimensions already match.
// The element-by-element comparison is not visible in this extraction.
628 template <
class E1,
class E2>
629 inline std::enable_if_t<xoptional_comparable<E1, E2>::value,
bool>
632 const E1&
de1 =
e1.derived_cast();
633 const E2&
de2 =
e2.derived_cast();
634 bool res =
de1.dimension() ==
de2.dimension()
635 && std::equal(
de1.shape().begin(),
de1.shape().end(),
de2.shape().begin());
// Template header of a further overload whose body is missing (presumably
// operator!=; confirm).
657 template <
class E1,
class E2>
// Named element-wise equality tests. equal and not_equal return an xfunction
// applying detail::equal_to / detail::not_equal_to to each pair of elements.
// They do not return a single bool.
673 template <
class E1,
class E2>
674 inline auto equal(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::equal_to, E1, E2>
676 return detail::make_xfunction<detail::equal_to>(std::forward<E1>(
e1), std::forward<E2>(
e2));
689 template <
class E1,
class E2>
690 inline auto not_equal(
E1&&
e1,
E2&&
e2)
noexcept -> detail::xfunction_type_t<detail::not_equal_to, E1, E2>
692 return detail::make_xfunction<detail::not_equal_to>(std::forward<E1>(
e1), std::forward<E2>(
e2));
// Named aliases for the ordering operators. Each returns
// std::forward<E1>(e1) OP std::forward<E2>(e2), so overload resolution, and
// therefore the result type, matches the corresponding operator.
// (The declarator lines of less_equal, greater and greater_equal are missing
// from this extraction.)
706 template <
class E1,
class E2>
707 inline auto less(
E1&&
e1,
E2&&
e2)
noexcept ->
decltype(std::forward<E1>(
e1) < std::forward<E2>(
e2))
709 return std::forward<E1>(
e1) < std::forward<E2>(
e2);
723 template <
class E1,
class E2>
726 return std::forward<E1>(
e1) <= std::forward<E2>(
e2);
740 template <
class E1,
class E2>
743 return std::forward<E1>(
e1) > std::forward<E2>(
e2);
757 template <
class E1,
class E2>
759 ->
decltype(std::forward<E1>(
e1) >= std::forward<E2>(
e2))
761 return std::forward<E1>(
e1) >= std::forward<E2>(
e2);
// Fragment of the ternary selection function. It builds an xfunction applying
// detail::conditional_ternary to the forwarded condition and branch operands.
// The declarator and the remainder of the call (including the third operand)
// are missing from this extraction.
776 template <
class E1,
class E2,
class E3>
778 -> detail::xfunction_type_t<detail::conditional_ternary, E1, E2, E3>
780 return detail::make_xfunction<detail::conditional_ternary>(
781 std::forward<E1>(
e1),
782 std::forward<E2>(
e2),
// next_idx_impl<L> advances the multi-index `idx` within `shape` for
// traversal layout L.
// - The first visible operator() scans dimensions from last to first.
// - The second scans from first to last.
// NOTE(review): the specialization headers are not visible. Mapping the
// first to row-major and the second to column-major is inferred from the
// loop direction; confirm.
// In both loops, a component already at shape[i] - 1 presumably wraps to 0
// and carries into the next dimension scanned. That code is not visible.
789 template <layout_type L>
790 struct next_idx_impl;
795 template <
class S,
class I>
796 inline auto operator()(
const S& shape, I& idx)
798 for (std::size_t j = shape.size(); j > 0; --j)
800 std::size_t i = j - 1;
801 if (idx[i] >= shape[i] - 1)
819 template <
class S,
class I>
820 inline auto operator()(
const S& shape, I& idx)
822 for (std::size_t i = 0; i < shape.size(); ++i)
824 if (idx[i] >= shape[i] - 1)
// Convenience wrapper: dispatches to next_idx_impl<L>. L defaults to
// XTENSOR_DEFAULT_TRAVERSAL.
839 template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL,
class S,
class I>
840 inline auto next_idx(
const S& shape, I& idx)
842 next_idx_impl<L> nii;
843 return nii(shape, idx);
// Fragment of nonzero(arr).
// It visits total_size elements, advancing idx with detail::next_idx at the
// default traversal layout. For each element whose value converts to true,
// it loops over every dimension n; presumably it appends idx[n] to
// indices[n] (the loop body is not visible).
// The result holds one index vector per dimension (indices has
// arr.dimension() entries).
858 auto shape = arr.shape();
860 using size_type =
typename T::size_type;
862 auto idx = xtl::make_sequence<index_type>(arr.dimension(), 0);
863 std::vector<std::vector<size_type>>
indices(arr.dimension());
866 for (size_type
i = 0;
i <
total_size;
i++, detail::next_idx(shape, idx))
868 if (arr.element(std::begin(idx), std::end(idx)))
870 for (std::size_t
n = 0;
n <
indices.size(); ++
n)
// Fragment of argwhere(arr), parameterized on traversal layout L (default
// XTENSOR_DEFAULT_TRAVERSAL). It tests each element's value for truth at
// multi-index idx. It presumably appends matching whole multi-indices
// (index_type) to a single vector `indices`.
// The traversal loop and the push are not visible in this extraction.
904 template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL,
class T>
907 auto shape = arr.shape();
909 using size_type =
typename T::size_type;
911 auto idx = xtl::make_sequence<index_type>(arr.dimension(), 0);
912 std::vector<index_type>
indices;
917 if (arr.element(std::begin(idx), std::end(idx)))
// Fragments of two functions over an expression of type E. Each derives
// value_type from std::decay_t<E> and passes a lambda taking
// `const value_type& el`. The function names and lambda bodies are not
// visible; these are presumably any() / all() reductions. Confirm.
938 using xtype = std::decay_t<E>;
939 using value_type =
typename xtype::value_type;
943 [](
const value_type&
el)
// Second function: same value_type derivation and lambda parameter.
962 using xtype = std::decay_t<E>;
963 using value_type =
typename xtype::value_type;
967 [](
const value_type&
el)
// Element-wise conversion to R. It returns an xfunction applying
// detail::cast<R>::functor to each element of e; that functor presumably
// performs static_cast<R>.
// R must be given explicitly; E is deduced. Enabled only for xexpression
// operands (via xfunction_type_t).
989 template <
class R,
class E>
990 inline auto cast(E&&
e)
noexcept -> detail::xfunction_type_t<typename detail::cast<R>::functor, E>
992 return detail::make_xfunction<typename detail::cast<R>::functor>(std::forward<E>(
e));
auto operator+(E &&e) noexcept -> detail::xfunction_type_t< detail::identity, E >
Identity.
auto operator/(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::divides, E1, E2 >
Division.
auto operator%(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::modulus, E1, E2 >
Modulus.
auto operator-(E &&e) noexcept -> detail::xfunction_type_t< detail::negate, E >
Opposite.
auto operator*(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::multiplies, E1, E2 >
Multiplication.
auto operator&(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::bitwise_and, E1, E2 >
Bitwise and.
auto left_shift(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::left_shift, E1, E2 >
Bitwise left shift.
auto right_shift(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::right_shift, E1, E2 >
Bitwise right shift.
auto operator~(E &&e) noexcept -> detail::xfunction_type_t< detail::bitwise_not, E >
Bitwise not.
auto operator^(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::bitwise_xor, E1, E2 >
Bitwise xor.
auto operator|(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::bitwise_or, E1, E2 >
Bitwise or.
auto operator>>(E1 &&e1, E2 &&e2) -> detail::shift_return_type_t< detail::right_shift, E1, E2 >
Bitwise right shift.
auto cast(E &&e) noexcept -> detail::xfunction_type_t< typename detail::cast< R >::functor, E >
Element-wise static_cast.
auto not_equal(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::not_equal_to, E1, E2 >
Element-wise inequality.
auto less(E1 &&e1, E2 &&e2) noexcept -> decltype(std::forward< E1 >(e1)< std::forward< E2 >(e2))
Less than.
auto equal(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::equal_to, E1, E2 >
Element-wise equality.
auto greater_equal(E1 &&e1, E2 &&e2) noexcept -> decltype(std::forward< E1 >(e1) >=std::forward< E2 >(e2))
Greater than or equal.
auto greater(E1 &&e1, E2 &&e2) noexcept -> decltype(std::forward< E1 >(e1) > std::forward< E2 >(e2))
Greater than.
auto less_equal(E1 &&e1, E2 &&e2) noexcept -> decltype(std::forward< E1 >(e1)<=std::forward< E2 >(e2))
Less than or equal.
auto operator!(E &&e) noexcept -> detail::xfunction_type_t< detail::logical_not, E >
Not.
auto argwhere(const T &arr)
Returns a vector of the multi-dimensional indices at which arr is nonzero.
auto operator&&(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::logical_and, E1, E2 >
And.
auto nonzero(const T &arr)
Returns, for each dimension of arr, a vector of the indices at which arr is nonzero.
auto operator||(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::logical_or, E1, E2 >
Or.
auto where(E1 &&e1, E2 &&e2, E3 &&e3) noexcept -> detail::xfunction_type_t< detail::conditional_ternary, E1, E2, E3 >
Ternary selection.
auto arg(E &&e) noexcept
Calculates the phase angle (in radians) elementwise for the complex numbers in e.
Standard mathematical functions for xexpressions.
auto all() noexcept
Returns a slice representing a full dimension, to be used as an argument of the view function.
bool operator==(const xaxis_iterator< CT > &lhs, const xaxis_iterator< CT > &rhs)
Checks equality of the iterators.
bool operator!=(const xaxis_iterator< CT > &lhs, const xaxis_iterator< CT > &rhs)
Checks inequality of the iterators.