10#ifndef XTENSOR_MANIPULATION_HPP
11#define XTENSOR_MANIPULATION_HPP
16#include <xtl/xcompare.hpp>
17#include <xtl/xsequence.hpp>
19#include "../core/xtensor_config.hpp"
20#include "../generators/xbuilder.hpp"
21#include "../utils/xexception.hpp"
22#include "../utils/xutils.hpp"
23#include "../views/xrepeat.hpp"
24#include "../views/xstrided_view.hpp"
25#include "xtl_concepts.hpp"
33 namespace check_policy
47 template <
class E,
class S,
class Tag = check_policy::none>
48 auto transpose(E&& e, S&& permutation, Tag check_policy = Tag());
51 auto swapaxes(E&& e, std::ptrdiff_t axis1, std::ptrdiff_t axis2);
53 template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL,
class E>
56 template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL,
class E>
59 template <layout_type L,
class T>
63 auto trim_zeros(E&& e,
const std::string& direction =
"fb");
68 template <
class E, xtl::non_
integral_concept S,
class Tag = check_policy::none>
69 auto squeeze(E&& e, S&& axis, Tag check_policy = Tag());
74 template <std::
size_t N,
class E>
87 auto split(E& e, std::size_t n, std::size_t axis = 0);
90 auto hsplit(E& e, std::size_t n);
93 auto vsplit(E& e, std::size_t n);
99 auto flip(E&& e, std::size_t axis);
101 template <std::ptrdiff_t N = 1,
class E>
102 auto rot90(E&& e,
const std::array<std::ptrdiff_t, 2>& axes = {0, 1});
105 auto roll(E&& e, std::ptrdiff_t shift);
108 auto roll(E&& e, std::ptrdiff_t shift, std::ptrdiff_t axis);
111 auto repeat(E&& e, std::size_t repeats, std::size_t axis);
114 auto repeat(E&& e,
const std::vector<std::size_t>& repeats, std::size_t axis);
117 auto repeat(E&& e, std::vector<std::size_t>&& repeats, std::size_t axis);
143 XTENSOR_THROW(transpose_error,
"cannot compute transposed layout of dynamic layout");
145 return transpose_layout_noexcept(l);
148 template <
class E,
class S>
149 inline auto transpose_impl(E&& e, S&& permutation, check_policy::none)
151 if (std::size(permutation) != e.dimension())
153 XTENSOR_THROW(transpose_error,
"Permutation does not have the same size as shape");
157 using shape_type = xindex_type_t<typename std::decay_t<E>::shape_type>;
158 shape_type temp_shape;
159 resize_container(temp_shape, e.shape().size());
161 using strides_type = get_strides_t<shape_type>;
162 strides_type temp_strides;
163 resize_container(temp_strides, e.strides().size());
165 using size_type =
typename std::decay_t<E>::size_type;
166 for (std::size_t i = 0; i < e.shape().size(); ++i)
168 if (std::size_t(permutation[i]) >= e.dimension())
170 XTENSOR_THROW(transpose_error,
"Permutation contains wrong axis");
172 size_type perm =
static_cast<size_type
>(permutation[i]);
173 temp_shape[i] = e.shape()[perm];
174 temp_strides[i] = e.strides()[perm];
178 if (std::is_sorted(std::begin(permutation), std::end(permutation)))
181 new_layout = e.layout();
183 else if (std::is_sorted(std::begin(permutation), std::end(permutation), std::greater<>()))
185 new_layout = transpose_layout_noexcept(e.layout());
190 std::move(temp_shape),
191 std::move(temp_strides),
192 get_offset<XTENSOR_DEFAULT_LAYOUT>(e),
197 template <
class E,
class S>
198 inline auto transpose_impl(E&& e, S&& permutation, check_policy::full)
201 for (std::size_t i = 0; i < std::size(permutation); ++i)
203 for (std::size_t j = i + 1; j < std::size(permutation); ++j)
205 if (permutation[i] == permutation[j])
207 XTENSOR_THROW(transpose_error,
"Permutation contains axis more than once");
211 return transpose_impl(std::forward<E>(e), std::forward<S>(permutation), check_policy::none());
214 template <
class E,
class S,
class X>
215 inline void compute_transposed_strides(E&& e,
const S& shape, X&
strides)
217 if constexpr (has_data_interface<std::decay_t<E>>::value)
219 std::copy(e.strides().crbegin(), e.strides().crend(),
strides.begin());
226 layout_type l = transpose_layout(XTENSOR_DEFAULT_TRAVERSAL);
227 compute_strides(shape, l,
strides);
241 using shape_type = xindex_type_t<typename std::decay_t<E>::shape_type>;
243 resize_container(shape, e.shape().size());
244 std::copy(e.shape().crbegin(), e.shape().crend(), shape.begin());
246 get_strides_t<shape_type>
strides;
247 resize_container(
strides, e.shape().size());
248 detail::compute_transposed_strides(e, shape,
strides);
250 layout_type new_layout = detail::transpose_layout_noexcept(e.layout());
256 detail::get_offset<XTENSOR_DEFAULT_TRAVERSAL>(e),
270 template <
class E,
class S,
class Tag>
271 inline auto transpose(E&& e, S&& permutation, Tag check_policy)
273 return detail::transpose_impl(std::forward<E>(e), std::forward<S>(permutation), check_policy);
277 template <
class E,
class I, std::
size_t N,
class Tag = check_policy::none>
278 inline auto transpose(E&& e,
const I (&permutation)[N], Tag check_policy = Tag())
280 return detail::transpose_impl(std::forward<E>(e), permutation, check_policy);
292 inline S swapaxes_perm(std::size_t dim, std::ptrdiff_t axis1, std::ptrdiff_t axis2)
294 const std::size_t ax1 = normalize_axis(dim, axis1);
295 const std::size_t ax2 = normalize_axis(dim, axis2);
296 auto perm = xtl::make_sequence<S>(dim, 0);
297 using id_t =
typename S::value_type;
298 std::iota(perm.begin(), perm.end(), id_t(0));
316 inline auto swapaxes(E&& e, std::ptrdiff_t axis1, std::ptrdiff_t axis2)
318 const auto dim = e.dimension();
319 check_axis_in_dim(axis1, dim,
"Parameter axis1");
320 check_axis_in_dim(axis2, dim,
"Parameter axis2");
322 using strides_t = get_strides_t<typename std::decay_t<E>::shape_type>;
323 return transpose(std::forward<E>(e), detail::swapaxes_perm<strides_t>(dim, axis1, axis2));
333 inline S moveaxis_perm(std::size_t dim, std::ptrdiff_t src, std::ptrdiff_t dest)
335 using id_t =
typename S::value_type;
337 const std::size_t src_norm = normalize_axis(dim, src);
338 const std::size_t dest_norm = normalize_axis(dim, dest);
342 auto perm = xtl::make_sequence<S>(dim, src_norm);
344 for (id_t i = 0; xtl::cmp_less(i, dim); ++i)
346 if (xtl::cmp_equal(perm_idx, dest_norm))
348 perm[perm_idx] = src_norm;
351 if (xtl::cmp_not_equal(i, src_norm))
370 inline auto moveaxis(E&& e, std::ptrdiff_t src, std::ptrdiff_t dest)
372 const auto dim = e.dimension();
373 check_axis_in_dim(src, dim,
"Parameter src");
374 check_axis_in_dim(dest, dim,
"Parameter dest");
376 using strides_t = get_strides_t<typename std::decay_t<E>::shape_type>;
377 return xt::transpose(std::forward<E>(e), detail::moveaxis_perm<strides_t>(e.dimension(), src, dest));
386 template <
class E, layout_type L>
387 struct expression_iterator_getter
389 using iterator =
decltype(std::declval<E>().template begin<L>());
390 using const_iterator =
decltype(std::declval<E>().template cbegin<L>());
392 inline static iterator begin(E& e)
394 return e.template begin<L>();
397 inline static const_iterator cbegin(E& e)
399 return e.template cbegin<L>();
402 inline static auto size(E& e)
418 template <layout_type L,
class E>
421 using iterator =
decltype(e.template begin<L>());
422 using iterator_getter = detail::expression_iterator_getter<std::remove_reference_t<E>, L>;
423 auto size = e.size();
424 auto adaptor = make_xiterator_adaptor(std::forward<E>(e), iterator_getter());
426 using type =
xtensor_view<
decltype(adaptor), 1, layout, extension::get_expression_tag_t<E>>;
427 return type(std::move(adaptor), {size});
443 template <layout_type L,
class E>
446 return ravel<L>(std::forward<E>(e));
457 template <layout_type L,
class T>
479 XTENSOR_ASSERT_MSG(e.dimension() == 1,
"Dimension for trim_zeros has to be 1.");
481 std::ptrdiff_t begin = 0, end =
static_cast<std::ptrdiff_t
>(e.size());
483 auto find_fun = [](
const auto& i)
488 if (direction.find(
"f") != std::string::npos)
490 begin = std::find_if(e.cbegin(), e.cend(), find_fun) - e.cbegin();
493 if (direction.find(
"b") != std::string::npos && begin != end)
495 end -= std::find_if(e.crbegin(), e.crend(), find_fun) - e.crbegin();
517 dynamic_shape<std::size_t> new_shape;
518 dynamic_shape<std::ptrdiff_t> new_strides;
522 std::back_inserter(new_shape),
528 decltype(
auto) old_strides = detail::get_strides<XTENSOR_DEFAULT_LAYOUT>(e);
530 old_strides.cbegin(),
532 std::back_inserter(new_strides),
539 return strided_view(std::forward<E>(e), std::move(new_shape), std::move(new_strides), 0, e.layout());
544 template <
class E,
class S>
545 inline auto squeeze_impl(E&& e, S&& axis, check_policy::none)
547 std::size_t new_dim = e.dimension() - axis.size();
548 dynamic_shape<std::size_t> new_shape(new_dim);
549 dynamic_shape<std::ptrdiff_t> new_strides(new_dim);
551 decltype(
auto) old_strides = detail::get_strides<XTENSOR_DEFAULT_LAYOUT>(e);
553 for (std::size_t i = 0, ix = 0; i < e.dimension(); ++i)
555 if (axis.cend() == std::find(axis.cbegin(), axis.cend(), i))
557 new_shape[ix] = e.shape()[i];
558 new_strides[ix++] = old_strides[i];
562 return strided_view(std::forward<E>(e), std::move(new_shape), std::move(new_strides), 0, e.layout());
565 template <
class E,
class S>
566 inline auto squeeze_impl(E&& e, S&& axis, check_policy::full)
570 if (
static_cast<std::size_t
>(ix) > e.dimension())
572 XTENSOR_THROW(std::runtime_error,
"Axis argument to squeeze > dimension of expression");
574 if (e.shape()[
static_cast<std::size_t
>(ix)] != 1)
576 XTENSOR_THROW(std::runtime_error,
"Trying to squeeze axis != 1");
579 return squeeze_impl(std::forward<E>(e), std::forward<S>(axis), check_policy::none());
593 template <
class E, xtl::non_
integral_concept S,
class Tag>
594 inline auto squeeze(E&& e, S&& axis, Tag check_policy)
596 return detail::squeeze_impl(std::forward<E>(e), std::forward<S>(axis), check_policy);
600 template <
class E,
class I, std::
size_t N,
class Tag = check_policy::none>
601 inline auto squeeze(E&& e,
const I (&axis)[N], Tag check_policy = Tag())
603 using arr_t = std::array<I, N>;
604 return detail::squeeze_impl(
606 xtl::forward_sequence<arr_t,
decltype(axis)>(axis),
611 template <
class E,
class Tag = check_policy::none>
612 inline auto squeeze(E&& e, std::size_t axis, Tag check_policy = Tag())
614 return squeeze(std::forward<E>(e), std::array<std::size_t, 1>{axis}, check_policy);
659 template <std::
size_t N,
class E>
663 if (e.dimension() < N)
666 std::size_t end =
static_cast<std::size_t
>(std::round(
double(N - e.dimension()) /
double(N)));
734 inline auto split(E& e, std::size_t n, std::size_t axis)
736 if (axis >= e.dimension())
738 XTENSOR_THROW(std::runtime_error,
"Split along axis > dimension.");
741 std::size_t ax_sz = e.shape()[axis];
743 std::size_t step = ax_sz / n;
744 std::size_t rest = ax_sz % n;
748 XTENSOR_THROW(std::runtime_error,
"Split does not result in equal division.");
752 for (std::size_t i = 0; i < n; ++i)
754 sv[axis] =
range(i * step, (i + 1) * step);
772 return split(e, n, std::size_t(1));
787 return split(e, n, std::size_t(0));
804 using size_type =
typename std::decay_t<E>::size_type;
806 for (size_type d = 1; d < e.dimension(); ++d)
825 inline auto flip(E&& e, std::size_t axis)
827 using shape_type = xindex_type_t<typename std::decay_t<E>::shape_type>;
830 resize_container(shape, e.shape().size());
831 std::copy(e.shape().cbegin(), e.shape().cend(), shape.begin());
833 get_strides_t<shape_type>
strides;
834 decltype(
auto) old_strides = detail::get_strides<XTENSOR_DEFAULT_LAYOUT>(e);
835 resize_container(
strides, old_strides.size());
836 std::copy(old_strides.cbegin(), old_strides.cend(),
strides.begin());
839 std::size_t offset =
static_cast<std::size_t
>(
840 static_cast<std::ptrdiff_t
>(e.data_offset())
841 + old_strides[axis] * (
static_cast<std::ptrdiff_t
>(e.shape()[axis]) - 1)
853 template <std::ptrdiff_t N>
860 inline auto operator()(E&& e,
const std::array<std::size_t, 2>& )
862 return std::forward<E>(e);
870 inline auto operator()(E&& e,
const std::array<std::size_t, 2>& axes)
874 dynamic_shape<std::ptrdiff_t> axes_list(e.shape().size());
875 std::iota(axes_list.begin(), axes_list.end(), 0);
876 swap(axes_list[axes[0]], axes_list[axes[1]]);
878 return transpose(
flip(std::forward<E>(e), axes[1]), std::move(axes_list));
886 inline auto operator()(E&& e,
const std::array<std::size_t, 2>& axes)
888 return flip(
flip(std::forward<E>(e), axes[0]), axes[1]);
896 inline auto operator()(E&& e,
const std::array<std::size_t, 2>& axes)
900 dynamic_shape<std::ptrdiff_t> axes_list(e.shape().size());
901 std::iota(axes_list.begin(), axes_list.end(), 0);
902 swap(axes_list[axes[0]], axes_list[axes[1]]);
904 return flip(
transpose(std::forward<E>(e), std::move(axes_list)), axes[1]);
920 template <std::ptrdiff_t N,
class E>
921 inline auto rot90(E&& e,
const std::array<std::ptrdiff_t, 2>& axes)
923 auto ndim =
static_cast<std::ptrdiff_t
>(e.shape().size());
925 if (axes[0] == axes[1] || std::abs(axes[0] - axes[1]) == ndim)
927 XTENSOR_THROW(std::runtime_error,
"Axes must be different");
930 auto norm_axes = forward_normalize<std::array<std::size_t, 2>>(e, axes);
931 constexpr std::ptrdiff_t n = (4 + (N % 4)) % 4;
933 return detail::rot90_impl<n>()(std::forward<E>(e), norm_axes);
954 inline auto roll(E&& e, std::ptrdiff_t shift)
957 auto flat_size = std::accumulate(
961 std::multiplies<std::size_t>()
969 std::copy(e.begin(), e.end() - shift, std::copy(e.end() - shift, e.end(), cpy.begin()));
980 template <
class To,
class From,
class S>
981 To roll(To to, From from, std::ptrdiff_t shift, std::size_t axis,
const S& shape, std::size_t M)
983 std::ptrdiff_t dim = std::ptrdiff_t(shape[M]);
984 std::ptrdiff_t offset = std::accumulate(
985 shape.begin() + M + 1,
988 std::multiplies<std::ptrdiff_t>()
990 if (shape.size() == M + 1)
994 const auto split = from + (dim - shift) * offset;
995 for (
auto iter =
split, end = from + dim * offset; iter != end; iter += offset, ++to)
999 for (
auto iter = from, end =
split; iter != end; iter += offset, ++to)
1006 for (
auto iter = from, end = from + dim * offset; iter != end; iter += offset, ++to)
1016 const auto split = from + (dim - shift) * offset;
1017 for (
auto iter =
split, end = from + dim * offset; iter != end; iter += offset)
1019 to = roll(to, iter, shift, axis, shape, M + 1);
1021 for (
auto iter = from, end =
split; iter != end; iter += offset)
1023 to = roll(to, iter, shift, axis, shape, M + 1);
1028 for (
auto iter = from, end = from + dim * offset; iter != end; iter += offset)
1030 to = roll(to, iter, shift, axis, shape, M + 1);
1051 inline auto roll(E&& e, std::ptrdiff_t shift, std::ptrdiff_t axis)
1054 const auto& shape = cpy.shape();
1055 std::size_t saxis =
static_cast<std::size_t
>(axis);
1058 axis += std::ptrdiff_t(cpy.dimension());
1061 if (saxis >= cpy.dimension() || axis < 0)
1063 XTENSOR_THROW(std::runtime_error,
"axis is no within shape dimension.");
1066 const auto axis_dim =
static_cast<std::ptrdiff_t
>(shape[saxis]);
1072 detail::roll(cpy.begin(), e.begin(), shift, saxis, shape, 0);
1082 template <
class E,
class R>
1083 inline auto make_xrepeat(E&& e, R&& r,
typename std::decay_t<E>::size_type axis)
1085 const auto casted_axis =
static_cast<typename std::decay_t<E>::size_type
>(axis);
1086 if (r.size() != e.shape(casted_axis))
1088 XTENSOR_THROW(std::invalid_argument,
"repeats must have the same size as the specified axis");
1090 return xrepeat<const_xclosure_t<E>, R>(std::forward<E>(e), std::forward<R>(r), axis);
1105 inline auto repeat(E&& e, std::size_t repeats, std::size_t axis)
1107 const auto casted_axis =
static_cast<typename std::decay_t<E>::size_type
>(axis);
1108 std::vector<std::size_t> broadcasted_repeats(e.shape(casted_axis));
1109 std::fill(broadcasted_repeats.begin(), broadcasted_repeats.end(), repeats);
1110 return repeat(std::forward<E>(e), std::move(broadcasted_repeats), axis);
1125 inline auto repeat(E&& e,
const std::vector<std::size_t>& repeats, std::size_t axis)
1127 return detail::make_xrepeat(std::forward<E>(e), repeats, axis);
1141 inline auto repeat(E&& e, std::vector<std::size_t>&& repeats, std::size_t axis)
1143 return detail::make_xrepeat(std::forward<E>(e), std::move(repeats), axis);
Dense multidimensional container adaptor with view semantics and fixed dimension.
auto nonzero(const T &arr)
Return a vector of indices where the elements of arr are non-zero.
auto flatten(E &&e)
Return a flattened view of the given expression.
auto atleast_1d(E &&e)
Expand to at least 1D.
auto roll(E &&e, std::ptrdiff_t shift)
Roll an expression.
auto squeeze(E &&e)
Returns a squeezed view of the given expression.
auto moveaxis(E &&e, std::ptrdiff_t src, std::ptrdiff_t dest)
Return a new expression with an axis moved to a new position.
auto ravel(E &&e)
Return a flattened view of the given expression.
auto transpose(E &&e) noexcept
Returns a transposed view by reversing the dimensions of the xexpression e.
auto atleast_Nd(E &&e)
Expand the dimensions of an xexpression to at least N.
auto repeat(E &&e, std::size_t repeats, std::size_t axis)
Repeat elements of an expression along a given axis.
auto trim_zeros(E &&e, const std::string &direction="fb")
Trim zeros at beginning, end or both of 1D sequence.
auto split(E &e, std::size_t n, std::size_t axis=0)
Split an xexpression along an axis into subexpressions.
auto rot90(E &&e, const std::array< std::ptrdiff_t, 2 > &axes={0, 1})
Rotate an array by 90 degrees in the plane specified by axes.
auto expand_dims(E &&e, std::size_t axis)
Expand the shape of an xexpression.
auto vsplit(E &e, std::size_t n)
Split an xexpression into subexpressions vertically (row-wise).
auto flip(E &&e)
Reverse the order of elements in an xexpression along every axis.
auto atleast_2d(E &&e)
Expand to at least 2D.
auto swapaxes(E &&e, std::ptrdiff_t axis1, std::ptrdiff_t axis2)
Return a new expression with two axes interchanged.
auto hsplit(E &e, std::size_t n)
Split an xexpression into subexpressions horizontally (column-wise).
auto atleast_3d(E &&e)
Expand to at least 3D.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Standard mathematical functions for xexpressions.
auto range(A start_val, B stop_val)
Select a range from start_val to stop_val (excluded).
std::vector< xstrided_slice< std::ptrdiff_t > > xstrided_slice_vector
Vector of slices used to build an xstrided_view.
auto newaxis() noexcept
Returns a slice representing a new axis of length one, to be used as an argument of view function.
auto empty_like(const xexpression< E > &e)
Create an xcontainer (xarray, xtensor or xtensor_fixed) with uninitialized values of the same shape,...
auto strided_view(E &&e, S &&shape, X &&stride, std::size_t offset=0, layout_type layout=L) noexcept
Construct a strided view from an xexpression, shape, strides and offset.
auto flatnonzero(const T &arr)
Return indices that are non-zero in the flattened version of arr.