10#ifndef XTENSOR_OPERATION_HPP
11#define XTENSOR_OPERATION_HPP
16#include <xtl/xsequence.hpp>
18#include "../containers/xscalar.hpp"
19#include "../core/xfunction.hpp"
20#include "../core/xstrides.hpp"
21#include "../views/xstrided_view.hpp"
30#define UNARY_OPERATOR_FUNCTOR(NAME, OP) \
34 constexpr auto operator()(const A1& arg) const \
39 constexpr auto simd_apply(const B& arg) const \
// Generates three mixed-type overloads of `operator OP` involving std::complex:
//   std::complex<T1> OP std::complex<T2>,  T1 OP std::complex<T2>,  std::complex<T1> OP T2.
// Each overload converts both operands to xtl::promote_type_t of the two argument
// types and then applies OP to the converted values.
// The XTL_REQUIRES(std::negation<std::is_same<T1, T2>>) clause restricts the overloads
// to T1 != T2 (presumably an enable_if-style constraint -- confirm against xtl).
// NOTE: keep comments outside the macro body; a `//` comment on a continued line
// would swallow the rest of the definition.
45#define DEFINE_COMPLEX_OVERLOAD(OP) \
46 template <class T1, class T2, XTL_REQUIRES(std::negation<std::is_same<T1, T2>>)> \
47 constexpr auto operator OP(const std::complex<T1>& arg1, const std::complex<T2>& arg2) \
49 using result_type = typename xtl::promote_type_t<std::complex<T1>, std::complex<T2>>; \
50 return (result_type(arg1) OP result_type(arg2)); \
53 template <class T1, class T2, XTL_REQUIRES(std::negation<std::is_same<T1, T2>>)> \
54 constexpr auto operator OP(const T1& arg1, const std::complex<T2>& arg2) \
56 using result_type = typename xtl::promote_type_t<T1, std::complex<T2>>; \
57 return (result_type(arg1) OP result_type(arg2)); \
60 template <class T1, class T2, XTL_REQUIRES(std::negation<std::is_same<T1, T2>>)> \
61 constexpr auto operator OP(const std::complex<T1>& arg1, const T2& arg2) \
63 using result_type = typename xtl::promote_type_t<std::complex<T1>, T2>; \
64 return (result_type(arg1) OP result_type(arg2)); \
// Generates the binary functor NAME for operator OP (the enclosing struct line is
// not visible in this listing). Two call paths are produced:
//  - operator()(T1&&, T2&&): perfectly forwards both arguments and applies OP. The
//    `using xt::detail::operator OP;` declaration brings the overloads declared in
//    xt::detail (see DEFINE_COMPLEX_OVERLOAD) into scope for the expression.
//  - simd_apply(const B&, const B&): applies OP directly to two values of the same
//    type B (presumably a SIMD batch type -- confirm).
67#define BINARY_OPERATOR_FUNCTOR(NAME, OP) \
70 template <class T1, class T2> \
71 constexpr auto operator()(T1&& arg1, T2&& arg2) const \
73 using xt::detail::operator OP; \
74 return (std::forward<T1>(arg1) OP std::forward<T2>(arg2)); \
77 constexpr auto simd_apply(const B& arg1, const B& arg2) const \
79 return (arg1 OP arg2); \
85 DEFINE_COMPLEX_OVERLOAD(+);
86 DEFINE_COMPLEX_OVERLOAD(-);
87 DEFINE_COMPLEX_OVERLOAD(*);
88 DEFINE_COMPLEX_OVERLOAD(/);
89 DEFINE_COMPLEX_OVERLOAD(%);
90 DEFINE_COMPLEX_OVERLOAD(||);
91 DEFINE_COMPLEX_OVERLOAD(&&);
92 DEFINE_COMPLEX_OVERLOAD(|);
93 DEFINE_COMPLEX_OVERLOAD(&);
94 DEFINE_COMPLEX_OVERLOAD(^);
95 DEFINE_COMPLEX_OVERLOAD(<<);
96 DEFINE_COMPLEX_OVERLOAD(>>);
97 DEFINE_COMPLEX_OVERLOAD(<);
98 DEFINE_COMPLEX_OVERLOAD(<=);
99 DEFINE_COMPLEX_OVERLOAD(>);
100 DEFINE_COMPLEX_OVERLOAD(>=);
101 DEFINE_COMPLEX_OVERLOAD(==);
102 DEFINE_COMPLEX_OVERLOAD(!=);
104 UNARY_OPERATOR_FUNCTOR(identity, +);
105 UNARY_OPERATOR_FUNCTOR(negate, -);
106 BINARY_OPERATOR_FUNCTOR(plus, +);
107 BINARY_OPERATOR_FUNCTOR(minus, -);
108 BINARY_OPERATOR_FUNCTOR(multiplies, *);
109 BINARY_OPERATOR_FUNCTOR(divides, /);
110 BINARY_OPERATOR_FUNCTOR(modulus, %);
111 BINARY_OPERATOR_FUNCTOR(logical_or, ||);
112 BINARY_OPERATOR_FUNCTOR(logical_and, &&);
113 UNARY_OPERATOR_FUNCTOR(logical_not, !);
114 BINARY_OPERATOR_FUNCTOR(bitwise_or, |);
115 BINARY_OPERATOR_FUNCTOR(bitwise_and, &);
116 BINARY_OPERATOR_FUNCTOR(bitwise_xor, ^);
117 UNARY_OPERATOR_FUNCTOR(bitwise_not, ~);
118 BINARY_OPERATOR_FUNCTOR(left_shift, <<);
119 BINARY_OPERATOR_FUNCTOR(right_shift, >>);
120 BINARY_OPERATOR_FUNCTOR(less, <);
121 BINARY_OPERATOR_FUNCTOR(less_equal, <=);
122 BINARY_OPERATOR_FUNCTOR(greater, >);
123 BINARY_OPERATOR_FUNCTOR(greater_equal, >=);
124 BINARY_OPERATOR_FUNCTOR(equal_to, ==);
125 BINARY_OPERATOR_FUNCTOR(not_equal_to, !=);
// Functor passed to make_xfunction by the three-argument where(e1, e2, e3):
// selects between two values according to a condition.
127 struct conditional_ternary
// Boolean batch type associated with the batch type B: B is mapped back through
// revert_simd_traits, then simd_traits<...>::bool_type is taken.
130 using get_batch_bool =
typename xt_simd::simd_traits<typename xt_simd::revert_simd_traits<B>::type>::bool_type;
132 template <
class B,
class A1,
class A2>
// Non-batch path: delegates to xtl::select(cond, v1, v2).
133 constexpr auto operator()(
const B& cond,
const A1& v1,
const A2& v2)
const noexcept
135 return xtl::select(cond, v1, v2);
// Batch path: the condition is get_batch_bool<B>, both alternatives and the
// result are B; delegates to xt_simd::select(t1, t2, t3).
139 constexpr B simd_apply(
const get_batch_bool<B>& t1,
const B& t2,
const B& t3)
const noexcept
141 return xt_simd::select(t1, t2, t3);
150 using result_type = R;
153 constexpr result_type operator()(
const A1&
arg)
const
155 return static_cast<R
>(
arg);
// Maps an expression tag to the function-expression type used for functor F over
// operands E.... The primary template is only declared; a tag needs a
// specialization to be usable.
167 template <
class Tag,
class F,
class... E>
168 struct select_xfunction_expression;
// xtensor_expression_tag -> xfunction<F, E...>
170 template <
class F,
class... E>
171 struct select_xfunction_expression<xtensor_expression_tag, F, E...>
173 using type = xfunction<F, E...>;
// xoptional_expression_tag -> also xfunction<F, E...> (same type as the tensor tag).
176 template <
class F,
class... E>
177 struct select_xfunction_expression<xoptional_expression_tag, F, E...>
179 using type = xfunction<F, E...>;
// Convenience alias for select_xfunction_expression<Tag, F, E...>::type.
182 template <
class Tag,
class F,
class... E>
183 using select_xfunction_expression_t =
typename select_xfunction_expression<Tag, F, E...>::type;
// Computes the expression type produced when functor F is applied to operands E...:
//  - expression_tag: derived from the operand types via xexpression_tag_t<E...>;
//  - functor_type:   F itself;
//  - type:           select_xfunction_expression_t for that tag, with each operand
//                    wrapped as const_xclosure_t<E>.
185 template <
class F,
class... E>
186 struct xfunction_type
188 using expression_tag = xexpression_tag_t<E...>;
189 using functor_type = F;
190 using type = select_xfunction_expression_t<expression_tag, functor_type, const_xclosure_t<E>...>;
// Builds the function expression for functor F over the operands e....
// The result type is xfunction_type<F, E...>::type; it is constructed from a
// default-constructed functor and the perfectly forwarded operands.
// F must be given explicitly by the caller (it cannot be deduced); E... is deduced.
// NOTE(review): declared noexcept, so the `type(...)` construction is assumed not
// to throw -- confirm for operand types whose copy/move may allocate.
193 template <
class F,
class... E>
194 inline auto make_xfunction(E&&... e)
noexcept
196 using function_type = xfunction_type<F, E...>;
197 using functor_type =
typename function_type::functor_type;
198 using type =
typename function_type::type;
199 return type(functor_type(), std::forward<E>(e)...);
// Constrained shortcut for xfunction_type<F, E...>::type: the alias is only
// well-formed when has_xexpression<std::decay_t<E>...>::value is true
// (std::enable_if_t). It is used as the trailing return type of the operator
// overloads in this file, so those overloads are only viable when the condition holds.
// NOTE(review): has_xexpression presumably means "at least one operand is an
// xexpression" -- confirm against its definition.
206 template <
class F,
class... E>
207 using xfunction_type_t =
typename std::
208 enable_if_t<has_xexpression<std::decay_t<E>...>::value, xfunction_type<F, E...>>::type;
211#undef UNARY_OPERATOR_FUNCTOR
212#undef BINARY_OPERATOR_FUNCTOR
232 inline auto operator+(E&& e)
noexcept -> detail::xfunction_type_t<detail::identity, E>
234 return detail::make_xfunction<detail::identity>(std::forward<E>(e));
247 inline auto operator-(E&& e)
noexcept -> detail::xfunction_type_t<detail::negate, E>
249 return detail::make_xfunction<detail::negate>(std::forward<E>(e));
262 template <
class E1,
class E2>
263 inline auto operator+(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::plus, E1, E2>
265 return detail::make_xfunction<detail::plus>(std::forward<E1>(e1), std::forward<E2>(e2));
278 template <
class E1,
class E2>
279 inline auto operator-(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::minus, E1, E2>
281 return detail::make_xfunction<detail::minus>(std::forward<E1>(e1), std::forward<E2>(e2));
294 template <
class E1,
class E2>
295 inline auto operator*(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::multiplies, E1, E2>
297 return detail::make_xfunction<detail::multiplies>(std::forward<E1>(e1), std::forward<E2>(e2));
310 template <
class E1,
class E2>
311 inline auto operator/(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::divides, E1, E2>
313 return detail::make_xfunction<detail::divides>(std::forward<E1>(e1), std::forward<E2>(e2));
326 template <
class E1,
class E2>
327 inline auto operator%(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::modulus, E1, E2>
329 return detail::make_xfunction<detail::modulus>(std::forward<E1>(e1), std::forward<E2>(e2));
346 template <
class E1,
class E2>
347 inline auto operator||(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::logical_or, E1, E2>
349 return detail::make_xfunction<detail::logical_or>(std::forward<E1>(e1), std::forward<E2>(e2));
362 template <
class E1,
class E2>
363 inline auto operator&&(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::logical_and, E1, E2>
365 return detail::make_xfunction<detail::logical_and>(std::forward<E1>(e1), std::forward<E2>(e2));
378 inline auto operator!(E&& e)
noexcept -> detail::xfunction_type_t<detail::logical_not, E>
380 return detail::make_xfunction<detail::logical_not>(std::forward<E>(e));
397 template <
class E1,
class E2>
398 inline auto operator&(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::bitwise_and, E1, E2>
400 return detail::make_xfunction<detail::bitwise_and>(std::forward<E1>(e1), std::forward<E2>(e2));
413 template <
class E1,
class E2>
414 inline auto operator|(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::bitwise_or, E1, E2>
416 return detail::make_xfunction<detail::bitwise_or>(std::forward<E1>(e1), std::forward<E2>(e2));
429 template <
class E1,
class E2>
430 inline auto operator^(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::bitwise_xor, E1, E2>
432 return detail::make_xfunction<detail::bitwise_xor>(std::forward<E1>(e1), std::forward<E2>(e2));
445 inline auto operator~(E&& e)
noexcept -> detail::xfunction_type_t<detail::bitwise_not, E>
447 return detail::make_xfunction<detail::bitwise_not>(std::forward<E>(e));
460 template <
class E1,
class E2>
461 inline auto left_shift(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::left_shift, E1, E2>
463 return detail::make_xfunction<detail::left_shift>(std::forward<E1>(e1), std::forward<E2>(e2));
476 template <
class E1,
class E2>
477 inline auto right_shift(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::right_shift, E1, E2>
479 return detail::make_xfunction<detail::right_shift>(std::forward<E1>(e1), std::forward<E2>(e2));
486 template <
class F,
class E1,
class E2>
487 struct shift_function_getter
489 using type = xfunction_type_t<F, E1, E2>;
492 template <
bool B,
class T>
493 struct eval_enable_if
495 using type =
typename T::type;
499 struct eval_enable_if<false, T>
503 template <
bool B,
class T>
504 using eval_enable_if_t =
typename eval_enable_if<B, T>::type;
// Return type used by operator<< and operator>>.
// The condition is is_xexpression<std::decay_t<E1>>::value, i.e. it looks at the
// LEFT operand only. When it is true, eval_enable_if yields
// shift_function_getter<F, E1, E2>::type, which is xfunction_type_t<F, E1, E2>.
// Going through the getter struct means xfunction_type_t is only named via
// `typename T::type` inside eval_enable_if's primary template.
// NOTE(review): the eval_enable_if<false, T> specialization body is not visible in
// this listing; presumably it defines no `type`, removing these overloads when the
// left operand is not an xexpression (e.g. a stream) -- confirm.
506 template <
class F,
class E1,
class E2>
507 using shift_return_type_t = eval_enable_if_t<
508 is_xexpression<std::decay_t<E1>>::value,
509 shift_function_getter<F, E1, E2>>;
523 template <
class E1,
class E2>
525 -> detail::shift_return_type_t<detail::left_shift, E1, E2>
527 return left_shift(std::forward<E1>(e1), std::forward<E2>(e2));
541 template <
class E1,
class E2>
542 inline auto operator>>(E1&& e1, E2&& e2) -> detail::shift_return_type_t<detail::right_shift, E1, E2>
544 return right_shift(std::forward<E1>(e1), std::forward<E2>(e2));
561 template <
class E1,
class E2>
562 inline auto operator<(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::less, E1, E2>
564 return detail::make_xfunction<detail::less>(std::forward<E1>(e1), std::forward<E2>(e2));
577 template <
class E1,
class E2>
578 inline auto operator<=(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::less_equal, E1, E2>
580 return detail::make_xfunction<detail::less_equal>(std::forward<E1>(e1), std::forward<E2>(e2));
593 template <
class E1,
class E2>
594 inline auto operator>(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::greater, E1, E2>
596 return detail::make_xfunction<detail::greater>(std::forward<E1>(e1), std::forward<E2>(e2));
609 template <
class E1,
class E2>
610 inline auto operator>=(E1&& e1, E2&& e2)
noexcept
611 -> detail::xfunction_type_t<detail::greater_equal, E1, E2>
613 return detail::make_xfunction<detail::greater_equal>(std::forward<E1>(e1), std::forward<E2>(e2));
// Whole-expression comparison that yields a single bool, enabled only when
// xoptional_comparable<E1, E2>::value is true.
// NOTE(review): the line carrying the function name and parameters, and the lines
// that define de1/de2, are missing from this listing; presumably this is the
// xexpression operator== -- confirm.
627 template <
class E1,
class E2>
628 inline std::enable_if_t<xoptional_comparable<E1, E2>::value,
bool>
// Shapes must agree first. The dimension test short-circuits, so std::equal is only
// evaluated when both shapes have the same length.
633 bool res = de1.dimension() == de2.dimension()
634 && std::equal(de1.shape().begin(), de1.shape().end(), de2.shape().begin());
635 auto iter1 = de1.begin();
636 auto iter2 = de2.begin();
637 auto iter_end = de1.end();
// Element-wise scan; stops at the first mismatch, and is skipped entirely when the
// shape check already failed (res == false).
638 while (res && iter1 != iter_end)
640 res = (*iter1++ == *iter2++);
656 template <
class E1,
class E2>
672 template <
class E1,
class E2>
673 inline auto equal(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::equal_to, E1, E2>
675 return detail::make_xfunction<detail::equal_to>(std::forward<E1>(e1), std::forward<E2>(e2));
688 template <
class E1,
class E2>
689 inline auto not_equal(E1&& e1, E2&& e2)
noexcept -> detail::xfunction_type_t<detail::not_equal_to, E1, E2>
691 return detail::make_xfunction<detail::not_equal_to>(std::forward<E1>(e1), std::forward<E2>(e2));
705 template <
class E1,
class E2>
706 inline auto less(E1&& e1, E2&& e2)
noexcept ->
decltype(std::forward<E1>(e1) < std::forward<E2>(e2))
708 return std::forward<E1>(e1) < std::forward<E2>(e2);
722 template <
class E1,
class E2>
723 inline auto less_equal(E1&& e1, E2&& e2)
noexcept ->
decltype(std::forward<E1>(e1) <= std::forward<E2>(e2))
725 return std::forward<E1>(e1) <= std::forward<E2>(e2);
739 template <
class E1,
class E2>
740 inline auto greater(E1&& e1, E2&& e2)
noexcept ->
decltype(std::forward<E1>(e1) > std::forward<E2>(e2))
742 return std::forward<E1>(e1) > std::forward<E2>(e2);
756 template <
class E1,
class E2>
758 ->
decltype(std::forward<E1>(e1) >= std::forward<E2>(e2))
760 return std::forward<E1>(e1) >= std::forward<E2>(e2);
775 template <
class E1,
class E2,
class E3>
776 inline auto where(E1&& e1, E2&& e2, E3&& e3)
noexcept
777 -> detail::xfunction_type_t<detail::conditional_ternary, E1, E2, E3>
779 return detail::make_xfunction<detail::conditional_ternary>(
780 std::forward<E1>(e1),
781 std::forward<E2>(e2),
788 template <layout_type L>
789 struct next_idx_impl;
794 template <
class S,
class I>
795 inline auto operator()(
const S& shape, I& idx)
797 for (std::size_t j = shape.size(); j > 0; --j)
799 std::size_t i = j - 1;
800 if (idx[i] >= shape[i] - 1)
818 template <
class S,
class I>
819 inline auto operator()(
const S& shape, I& idx)
821 for (std::size_t i = 0; i < shape.size(); ++i)
823 if (idx[i] >= shape[i] - 1)
// Layout-dispatching front end for next_idx_impl: instantiates next_idx_impl<L>
// and forwards (shape, idx) to its call operator, returning whatever it returns.
// L defaults to XTENSOR_DEFAULT_TRAVERSAL.
// idx is taken by non-const reference; callers in this file use the call as the
// increment step of a loop over all elements, so it is presumably advanced in
// place to the next position in layout order -- confirm against next_idx_impl.
838 template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL,
class S,
class I>
839 inline auto next_idx(
const S& shape, I& idx)
841 next_idx_impl<L> nii;
842 return nii(shape, idx);
// NOTE(review): the signature line for this body is missing from the listing;
// presumably this is nonzero(const T& arr) -- confirm.
// Collects the coordinates of every truthy element of arr, stored per dimension:
// indices[n] receives the n-th coordinate of each selected element.
857 auto shape = arr.shape();
858 using index_type = xindex_type_t<typename T::shape_type>;
859 using size_type =
typename T::size_type;
// Running multi-index, one entry per dimension, starting at all zeros.
861 auto idx = xtl::make_sequence<index_type>(arr.dimension(), 0);
// One coordinate list per dimension.
862 std::vector<std::vector<size_type>> indices(arr.dimension());
864 size_type total_size = compute_size(shape);
// Visit total_size positions; idx is stepped by detail::next_idx with its default
// layout (XTENSOR_DEFAULT_TRAVERSAL) after each iteration.
865 for (size_type i = 0; i < total_size; i++, detail::next_idx(shape, idx))
867 if (arr.element(std::begin(idx), std::end(idx)))
869 for (std::size_t n = 0; n < indices.size(); ++n)
871 indices.at(n).push_back(idx[n]);
888 inline auto where(
const T& condition)
// NOTE(review): the signature line for this body is missing from the listing;
// presumably this is argwhere(const T& arr) -- confirm.
// Collects the full multi-index of every truthy element of arr, in the traversal
// order selected by L (default XTENSOR_DEFAULT_TRAVERSAL). Unlike the
// per-dimension variant above, each entry of the result is one complete index.
903 template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL,
class T>
906 auto shape = arr.shape();
907 using index_type = xindex_type_t<typename T::shape_type>;
908 using size_type =
typename T::size_type;
// Running multi-index, one entry per dimension, starting at all zeros.
910 auto idx = xtl::make_sequence<index_type>(arr.dimension(), 0);
911 std::vector<index_type> indices;
913 size_type total_size = compute_size(shape);
// Visit total_size positions; idx is stepped by detail::next_idx<L> after each one.
914 for (size_type i = 0; i < total_size; i++, detail::next_idx<L>(shape, idx))
916 if (arr.element(std::begin(idx), std::end(idx)))
// Stores a copy of the current index (idx keeps changing afterwards).
918 indices.push_back(idx);
937 using xtype = std::decay_t<E>;
938 using value_type =
typename xtype::value_type;
942 [](
const value_type& el)
961 using xtype = std::decay_t<E>;
962 using value_type =
typename xtype::value_type;
966 [](
const value_type& el)
// Element-wise static_cast: builds a function expression over e whose functor is
// detail::cast<R>::functor. R (the target value type) must be given explicitly;
// E is deduced and the operand is perfectly forwarded.
// The return type goes through detail::xfunction_type_t, so the overload is only
// viable when that alias's has_xexpression condition holds for E.
988 template <
class R,
class E>
989 inline auto cast(E&& e)
noexcept -> detail::xfunction_type_t<typename detail::cast<R>::functor, E>
991 return detail::make_xfunction<typename detail::cast<R>::functor>(std::forward<E>(e));
Base class for xexpressions.
derived_type & derived_cast() &noexcept
Returns a reference to the actual derived type of the xexpression.
auto operator+(E &&e) noexcept -> detail::xfunction_type_t< detail::identity, E >
Identity.
auto operator/(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::divides, E1, E2 >
Division.
auto operator%(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::modulus, E1, E2 >
Modulus.
auto operator-(E &&e) noexcept -> detail::xfunction_type_t< detail::negate, E >
Opposite.
auto operator*(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::multiplies, E1, E2 >
Multiplication.
auto operator&(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::bitwise_and, E1, E2 >
Bitwise and.
auto left_shift(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::left_shift, E1, E2 >
Bitwise left shift.
auto right_shift(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::right_shift, E1, E2 >
Bitwise right shift.
auto operator~(E &&e) noexcept -> detail::xfunction_type_t< detail::bitwise_not, E >
Bitwise not.
auto operator^(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::bitwise_xor, E1, E2 >
Bitwise xor.
auto operator<<(E1 &&e1, E2 &&e2) noexcept -> detail::shift_return_type_t< detail::left_shift, E1, E2 >
Bitwise left shift.
auto operator|(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::bitwise_or, E1, E2 >
Bitwise or.
auto operator>>(E1 &&e1, E2 &&e2) -> detail::shift_return_type_t< detail::right_shift, E1, E2 >
Bitwise right shift.
auto cast(E &&e) noexcept -> detail::xfunction_type_t< typename detail::cast< R >::functor, E >
Element-wise static_cast.
auto not_equal(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::not_equal_to, E1, E2 >
Element-wise inequality.
auto less(E1 &&e1, E2 &&e2) noexcept -> decltype(std::forward< E1 >(e1)< std::forward< E2 >(e2))
Less than.
auto equal(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::equal_to, E1, E2 >
Element-wise equality.
auto greater_equal(E1 &&e1, E2 &&e2) noexcept -> decltype(std::forward< E1 >(e1) >=std::forward< E2 >(e2))
Greater or equal.
auto greater(E1 &&e1, E2 &&e2) noexcept -> decltype(std::forward< E1 >(e1) > std::forward< E2 >(e2))
Greater than.
auto less_equal(E1 &&e1, E2 &&e2) noexcept -> decltype(std::forward< E1 >(e1)<=std::forward< E2 >(e2))
Less than or equal.
auto operator!(E &&e) noexcept -> detail::xfunction_type_t< detail::logical_not, E >
Not.
auto argwhere(const T &arr)
Returns a vector of indices where arr is not zero.
auto operator&&(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::logical_and, E1, E2 >
And.
auto nonzero(const T &arr)
Returns a vector of indices where arr is not zero.
auto operator||(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< detail::logical_or, E1, E2 >
Or.
auto where(E1 &&e1, E2 &&e2, E3 &&e3) noexcept -> detail::xfunction_type_t< detail::conditional_ternary, E1, E2, E3 >
Ternary selection.
auto arg(E &&e) noexcept
Calculates the phase angle (in radians) elementwise for the complex numbers in e.
standard mathematical functions for xexpressions
auto all() noexcept
Returns a slice representing a full dimension, to be used as an argument of view function.