10#ifndef XTENSOR_MANIPULATION_HPP
11#define XTENSOR_MANIPULATION_HPP
16#include <xtl/xcompare.hpp>
17#include <xtl/xsequence.hpp>
19#include "xbuilder.hpp"
20#include "xexception.hpp"
22#include "xstrided_view.hpp"
23#include "xtensor_config.hpp"
32 namespace check_policy
46 template <
class E,
class S,
class Tag = check_policy::none>
47 auto transpose(E&& e, S&& permutation, Tag check_policy = Tag());
50 auto swapaxes(E&& e, std::ptrdiff_t axis1, std::ptrdiff_t axis2);
52 template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL,
class E>
55 template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL,
class E>
58 template <layout_type L,
class T>
62 auto trim_zeros(E&& e,
const std::string& direction =
"fb");
67 template <class E, class S, class Tag = check_policy::none, std::enable_if_t<!xtl::is_integral<S>::value,
int> = 0>
68 auto squeeze(E&& e, S&& axis, Tag check_policy = Tag());
73 template <std::
size_t N,
class E>
86 auto split(E& e, std::size_t n, std::size_t axis = 0);
89 auto hsplit(E& e, std::size_t n);
92 auto vsplit(E& e, std::size_t n);
98 auto flip(E&& e, std::size_t axis);
100 template <std::ptrdiff_t N = 1,
class E>
101 auto rot90(E&& e,
const std::array<std::ptrdiff_t, 2>& axes = {0, 1});
104 auto roll(E&& e, std::ptrdiff_t shift);
107 auto roll(E&& e, std::ptrdiff_t shift, std::ptrdiff_t axis);
110 auto repeat(E&& e, std::size_t repeats, std::size_t axis);
113 auto repeat(E&& e,
const std::vector<std::size_t>& repeats, std::size_t axis);
116 auto repeat(E&& e, std::vector<std::size_t>&& repeats, std::size_t axis);
142 XTENSOR_THROW(transpose_error,
"cannot compute transposed layout of dynamic layout");
144 return transpose_layout_noexcept(l);
147 template <
class E,
class S>
148 inline auto transpose_impl(E&& e, S&& permutation, check_policy::none)
150 if (sequence_size(permutation) != e.dimension())
152 XTENSOR_THROW(transpose_error,
"Permutation does not have the same size as shape");
156 using shape_type = xindex_type_t<typename std::decay_t<E>::shape_type>;
157 shape_type temp_shape;
158 resize_container(temp_shape, e.shape().size());
160 using strides_type = get_strides_t<shape_type>;
161 strides_type temp_strides;
162 resize_container(temp_strides, e.strides().size());
164 using size_type =
typename std::decay_t<E>::size_type;
165 for (std::size_t i = 0; i < e.shape().size(); ++i)
167 if (std::size_t(permutation[i]) >= e.dimension())
169 XTENSOR_THROW(transpose_error,
"Permutation contains wrong axis");
171 size_type perm =
static_cast<size_type
>(permutation[i]);
172 temp_shape[i] = e.shape()[perm];
173 temp_strides[i] = e.strides()[perm];
177 if (std::is_sorted(std::begin(permutation), std::end(permutation)))
180 new_layout = e.layout();
182 else if (std::is_sorted(std::begin(permutation), std::end(permutation), std::greater<>()))
184 new_layout = transpose_layout_noexcept(e.layout());
189 std::move(temp_shape),
190 std::move(temp_strides),
191 get_offset<XTENSOR_DEFAULT_LAYOUT>(e),
196 template <
class E,
class S>
197 inline auto transpose_impl(E&& e, S&& permutation, check_policy::full)
200 for (std::size_t i = 0; i < sequence_size(permutation); ++i)
202 for (std::size_t j = i + 1; j < sequence_size(permutation); ++j)
204 if (permutation[i] == permutation[j])
206 XTENSOR_THROW(transpose_error,
"Permutation contains axis more than once");
210 return transpose_impl(std::forward<E>(e), std::forward<S>(permutation), check_policy::none());
213 template <
class E,
class S,
class X, std::enable_if_t<has_data_
interface<std::decay_t<E>>::value>* =
nullptr>
214 inline void compute_transposed_strides(E&& e,
const S&, X&
strides)
216 std::copy(e.strides().crbegin(), e.strides().crend(),
strides.begin());
219 template <
class E,
class S,
class X, std::enable_if_t<!has_data_
interface<std::decay_t<E>>::value>* =
nullptr>
220 inline void compute_transposed_strides(E&&,
const S& shape, X&
strides)
225 layout_type l = transpose_layout(XTENSOR_DEFAULT_TRAVERSAL);
226 compute_strides(shape, l,
strides);
239 using shape_type = xindex_type_t<typename std::decay_t<E>::shape_type>;
241 resize_container(shape, e.shape().size());
242 std::copy(e.shape().crbegin(), e.shape().crend(), shape.begin());
244 get_strides_t<shape_type>
strides;
245 resize_container(
strides, e.shape().size());
246 detail::compute_transposed_strides(e, shape,
strides);
248 layout_type new_layout = detail::transpose_layout_noexcept(e.layout());
254 detail::get_offset<XTENSOR_DEFAULT_TRAVERSAL>(e),
268 template <
class E,
class S,
class Tag>
269 inline auto transpose(E&& e, S&& permutation, Tag check_policy)
271 return detail::transpose_impl(std::forward<E>(e), std::forward<S>(permutation), check_policy);
275 template <
class E,
class I, std::
size_t N,
class Tag = check_policy::none>
276 inline auto transpose(E&& e,
const I (&permutation)[N], Tag check_policy = Tag())
278 return detail::transpose_impl(std::forward<E>(e), permutation, check_policy);
290 inline S swapaxes_perm(std::size_t dim, std::ptrdiff_t axis1, std::ptrdiff_t axis2)
292 const std::size_t ax1 = normalize_axis(dim, axis1);
293 const std::size_t ax2 = normalize_axis(dim, axis2);
294 auto perm = xtl::make_sequence<S>(dim, 0);
295 using id_t =
typename S::value_type;
296 std::iota(perm.begin(), perm.end(), id_t(0));
314 inline auto swapaxes(E&& e, std::ptrdiff_t axis1, std::ptrdiff_t axis2)
316 const auto dim = e.dimension();
317 check_axis_in_dim(axis1, dim,
"Parameter axis1");
318 check_axis_in_dim(axis2, dim,
"Parameter axis2");
320 using strides_t = get_strides_t<typename std::decay_t<E>::shape_type>;
321 return transpose(std::forward<E>(e), detail::swapaxes_perm<strides_t>(dim, axis1, axis2));
331 inline S moveaxis_perm(std::size_t dim, std::ptrdiff_t src, std::ptrdiff_t dest)
333 using id_t =
typename S::value_type;
335 const std::size_t src_norm = normalize_axis(dim, src);
336 const std::size_t dest_norm = normalize_axis(dim, dest);
340 auto perm = xtl::make_sequence<S>(dim, src_norm);
342 for (id_t i = 0; xtl::cmp_less(i, dim); ++i)
344 if (xtl::cmp_equal(perm_idx, dest_norm))
346 perm[perm_idx] = src_norm;
349 if (xtl::cmp_not_equal(i, src_norm))
368 inline auto moveaxis(E&& e, std::ptrdiff_t src, std::ptrdiff_t dest)
370 const auto dim = e.dimension();
371 check_axis_in_dim(src, dim,
"Parameter src");
372 check_axis_in_dim(dest, dim,
"Parameter dest");
374 using strides_t = get_strides_t<typename std::decay_t<E>::shape_type>;
375 return xt::transpose(std::forward<E>(e), detail::moveaxis_perm<strides_t>(e.dimension(), src, dest));
384 template <
class E, layout_type L>
385 struct expression_iterator_getter
387 using iterator =
decltype(std::declval<E>().template begin<L>());
388 using const_iterator =
decltype(std::declval<E>().template cbegin<L>());
390 inline static iterator begin(E& e)
392 return e.template begin<L>();
395 inline static const_iterator cbegin(E& e)
397 return e.template cbegin<L>();
400 inline static auto size(E& e)
416 template <layout_type L,
class E>
419 using iterator =
decltype(e.template begin<L>());
420 using iterator_getter = detail::expression_iterator_getter<std::remove_reference_t<E>, L>;
421 auto size = e.size();
422 auto adaptor = make_xiterator_adaptor(std::forward<E>(e), iterator_getter());
424 using type =
xtensor_view<
decltype(adaptor), 1, layout, extension::get_expression_tag_t<E>>;
425 return type(std::move(adaptor), {size});
441 template <layout_type L,
class E>
444 return ravel<L>(std::forward<E>(e));
455 template <layout_type L,
class T>
477 XTENSOR_ASSERT_MSG(e.dimension() == 1,
"Dimension for trim_zeros has to be 1.");
479 std::ptrdiff_t begin = 0, end =
static_cast<std::ptrdiff_t
>(e.size());
481 auto find_fun = [](
const auto& i)
486 if (direction.find(
"f") != std::string::npos)
488 begin = std::find_if(e.cbegin(), e.cend(), find_fun) - e.cbegin();
491 if (direction.find(
"b") != std::string::npos && begin != end)
493 end -= std::find_if(e.crbegin(), e.crend(), find_fun) - e.crbegin();
515 dynamic_shape<std::size_t> new_shape;
516 dynamic_shape<std::ptrdiff_t> new_strides;
520 std::back_inserter(new_shape),
526 decltype(
auto) old_strides = detail::get_strides<XTENSOR_DEFAULT_LAYOUT>(e);
528 old_strides.cbegin(),
530 std::back_inserter(new_strides),
537 return strided_view(std::forward<E>(e), std::move(new_shape), std::move(new_strides), 0, e.layout());
542 template <
class E,
class S>
543 inline auto squeeze_impl(E&& e, S&& axis, check_policy::none)
545 std::size_t new_dim = e.dimension() - axis.size();
546 dynamic_shape<std::size_t> new_shape(new_dim);
547 dynamic_shape<std::ptrdiff_t> new_strides(new_dim);
549 decltype(
auto) old_strides = detail::get_strides<XTENSOR_DEFAULT_LAYOUT>(e);
551 for (std::size_t i = 0, ix = 0; i < e.dimension(); ++i)
553 if (axis.cend() == std::find(axis.cbegin(), axis.cend(), i))
555 new_shape[ix] = e.shape()[i];
556 new_strides[ix++] = old_strides[i];
560 return strided_view(std::forward<E>(e), std::move(new_shape), std::move(new_strides), 0, e.layout());
// Validating overload of squeeze_impl (check_policy::full). For each axis index
// `ix`, it throws std::runtime_error in two cases: when `ix` is out of range, or
// when the extent of `e` along `ix` is not 1. It then forwards to the
// check_policy::none overload.
// NOTE(review): the loop that introduces `ix` is not visible in this excerpt;
// presumably `for (auto ix : axis)`. Confirm against the full source.
563 template <
class E,
class S>
564 inline auto squeeze_impl(E&& e, S&& axis, check_policy::full)
// NOTE(review): off-by-one. This test uses `>`, so ix == e.dimension() is
// accepted, and the e.shape()[ix] lookup just below then reads one past the
// last axis. The condition should be `>=`.
568 if (
static_cast<std::size_t
>(ix) > e.dimension())
570 XTENSOR_THROW(std::runtime_error,
"Axis argument to squeeze > dimension of expression");
// Only axes of extent 1 can be removed without losing data.
572 if (e.shape()[
static_cast<std::size_t
>(ix)] != 1)
574 XTENSOR_THROW(std::runtime_error,
"Trying to squeeze axis != 1");
// All axes were validated; delegate the actual view construction.
577 return squeeze_impl(std::forward<E>(e), std::forward<S>(axis), check_policy::none());
591 template <class E, class S, class Tag, std::enable_if_t<!xtl::is_integral<S>::value,
int>>
592 inline auto squeeze(E&& e, S&& axis, Tag check_policy)
594 return detail::squeeze_impl(std::forward<E>(e), std::forward<S>(axis), check_policy);
598 template <
class E,
class I, std::
size_t N,
class Tag = check_policy::none>
599 inline auto squeeze(E&& e,
const I (&axis)[N], Tag check_policy = Tag())
601 using arr_t = std::array<I, N>;
602 return detail::squeeze_impl(
604 xtl::forward_sequence<arr_t,
decltype(axis)>(axis),
609 template <
class E,
class Tag = check_policy::none>
610 inline auto squeeze(E&& e, std::size_t axis, Tag check_policy = Tag())
612 return squeeze(std::forward<E>(e), std::array<std::size_t, 1>{axis}, check_policy);
657 template <std::
size_t N,
class E>
661 if (e.dimension() < N)
664 std::size_t end =
static_cast<std::size_t
>(std::round(
double(N - e.dimension()) /
double(N)));
732 inline auto split(E& e, std::size_t n, std::size_t axis)
734 if (axis >= e.dimension())
736 XTENSOR_THROW(std::runtime_error,
"Split along axis > dimension.");
739 std::size_t ax_sz = e.shape()[axis];
741 std::size_t step = ax_sz / n;
742 std::size_t rest = ax_sz % n;
746 XTENSOR_THROW(std::runtime_error,
"Split does not result in equal division.");
750 for (std::size_t i = 0; i < n; ++i)
752 sv[axis] =
range(i * step, (i + 1) * step);
770 return split(e, n, std::size_t(1));
785 return split(e, n, std::size_t(0));
802 using size_type =
typename std::decay_t<E>::size_type;
804 for (size_type d = 1; d < e.dimension(); ++d)
823 inline auto flip(E&& e, std::size_t axis)
825 using shape_type = xindex_type_t<typename std::decay_t<E>::shape_type>;
828 resize_container(shape, e.shape().size());
829 std::copy(e.shape().cbegin(), e.shape().cend(), shape.begin());
831 get_strides_t<shape_type>
strides;
832 decltype(
auto) old_strides = detail::get_strides<XTENSOR_DEFAULT_LAYOUT>(e);
833 resize_container(
strides, old_strides.size());
834 std::copy(old_strides.cbegin(), old_strides.cend(),
strides.begin());
837 std::size_t offset =
static_cast<std::size_t
>(
838 static_cast<std::ptrdiff_t
>(e.data_offset())
839 + old_strides[axis] * (
static_cast<std::ptrdiff_t
>(e.shape()[axis]) - 1)
851 template <std::ptrdiff_t N>
858 inline auto operator()(E&& e,
const std::array<std::size_t, 2>& )
860 return std::forward<E>(e);
868 inline auto operator()(E&& e,
const std::array<std::size_t, 2>& axes)
872 dynamic_shape<std::ptrdiff_t> axes_list(e.shape().size());
873 std::iota(axes_list.begin(), axes_list.end(), 0);
874 swap(axes_list[axes[0]], axes_list[axes[1]]);
876 return transpose(
flip(std::forward<E>(e), axes[1]), std::move(axes_list));
884 inline auto operator()(E&& e,
const std::array<std::size_t, 2>& axes)
886 return flip(
flip(std::forward<E>(e), axes[0]), axes[1]);
894 inline auto operator()(E&& e,
const std::array<std::size_t, 2>& axes)
898 dynamic_shape<std::ptrdiff_t> axes_list(e.shape().size());
899 std::iota(axes_list.begin(), axes_list.end(), 0);
900 swap(axes_list[axes[0]], axes_list[axes[1]]);
902 return flip(
transpose(std::forward<E>(e), std::move(axes_list)), axes[1]);
918 template <std::ptrdiff_t N,
class E>
919 inline auto rot90(E&& e,
const std::array<std::ptrdiff_t, 2>& axes)
921 auto ndim =
static_cast<std::ptrdiff_t
>(e.shape().size());
923 if (axes[0] == axes[1] || std::abs(axes[0] - axes[1]) == ndim)
925 XTENSOR_THROW(std::runtime_error,
"Axes must be different");
928 auto norm_axes = forward_normalize<std::array<std::size_t, 2>>(e, axes);
929 constexpr std::ptrdiff_t n = (4 + (N % 4)) % 4;
931 return detail::rot90_impl<n>()(std::forward<E>(e), norm_axes);
952 inline auto roll(E&& e, std::ptrdiff_t shift)
955 auto flat_size = std::accumulate(
959 std::multiplies<std::size_t>()
967 std::copy(e.begin(), e.end() - shift, std::copy(e.end() - shift, e.end(), cpy.begin()));
978 template <
class To,
class From,
class S>
979 To roll(To to, From from, std::ptrdiff_t shift, std::size_t axis,
const S& shape, std::size_t M)
981 std::ptrdiff_t dim = std::ptrdiff_t(shape[M]);
982 std::ptrdiff_t offset = std::accumulate(
983 shape.begin() + M + 1,
986 std::multiplies<std::ptrdiff_t>()
988 if (shape.size() == M + 1)
992 const auto split = from + (dim - shift) * offset;
993 for (
auto iter =
split, end = from + dim * offset; iter != end; iter += offset, ++to)
997 for (
auto iter = from, end =
split; iter != end; iter += offset, ++to)
1004 for (
auto iter = from, end = from + dim * offset; iter != end; iter += offset, ++to)
1014 const auto split = from + (dim - shift) * offset;
1015 for (
auto iter =
split, end = from + dim * offset; iter != end; iter += offset)
1017 to = roll(to, iter, shift, axis, shape, M + 1);
1019 for (
auto iter = from, end =
split; iter != end; iter += offset)
1021 to = roll(to, iter, shift, axis, shape, M + 1);
1026 for (
auto iter = from, end = from + dim * offset; iter != end; iter += offset)
1028 to = roll(to, iter, shift, axis, shape, M + 1);
// Roll the elements of `e` by `shift` positions along `axis`. The result is
// written into `cpy` through detail::roll.
// NOTE(review): the declaration of `cpy` is not visible in this excerpt;
// presumably a fresh container shaped like `e`. Confirm.
1049 inline auto roll(E&& e, std::ptrdiff_t shift, std::ptrdiff_t axis)
1052 const auto& shape = cpy.shape();
// NOTE(review): saxis is taken from `axis` *before* the negative-axis
// adjustment below. A negative `axis` therefore converts to a huge size_t,
// `saxis >= cpy.dimension()` is true, and the call throws regardless of the
// `axis += dimension()` adjustment. Compute saxis after the adjustment.
1053 std::size_t saxis =
static_cast<std::size_t
>(axis);
1056 axis += std::ptrdiff_t(cpy.dimension());
// NOTE(review): the message below has a typo ("is no within" -> "is not within").
1059 if (saxis >= cpy.dimension() || axis < 0)
1061 XTENSOR_THROW(std::runtime_error,
"axis is no within shape dimension.");
// Extent of the rolled axis, read from the destination's shape.
1064 const auto axis_dim =
static_cast<std::ptrdiff_t
>(shape[saxis]);
// Copy from e into cpy with the shift applied along saxis, starting the
// recursion at dimension 0.
1070 detail::roll(cpy.begin(), e.begin(), shift, saxis, shape, 0);
1080 template <
class E,
class R>
1081 inline auto make_xrepeat(E&& e, R&& r,
typename std::decay_t<E>::size_type axis)
1083 const auto casted_axis =
static_cast<typename std::decay_t<E>::size_type
>(axis);
1084 if (r.size() != e.shape(casted_axis))
1086 XTENSOR_THROW(std::invalid_argument,
"repeats must have the same size as the specified axis");
1088 return xrepeat<const_xclosure_t<E>, R>(std::forward<E>(e), std::forward<R>(r), axis);
1103 inline auto repeat(E&& e, std::size_t repeats, std::size_t axis)
1105 const auto casted_axis =
static_cast<typename std::decay_t<E>::size_type
>(axis);
1106 std::vector<std::size_t> broadcasted_repeats(e.shape(casted_axis));
1107 std::fill(broadcasted_repeats.begin(), broadcasted_repeats.end(), repeats);
1108 return repeat(std::forward<E>(e), std::move(broadcasted_repeats), axis);
1123 inline auto repeat(E&& e,
const std::vector<std::size_t>& repeats, std::size_t axis)
1125 return detail::make_xrepeat(std::forward<E>(e), repeats, axis);
1139 inline auto repeat(E&& e, std::vector<std::size_t>&& repeats, std::size_t axis)
1141 return detail::make_xrepeat(std::forward<E>(e), std::move(repeats), axis);
Dense multidimensional container adaptor with view semantics and fixed dimension.
auto nonzero(const T &arr)
Return a vector of indices where arr is non-zero.
auto flatten(E &&e)
Return a flattened view of the given expression.
auto atleast_1d(E &&e)
Expand to at least 1D.
auto roll(E &&e, std::ptrdiff_t shift)
Roll an expression.
auto squeeze(E &&e)
Returns a squeezed view of the given expression.
auto moveaxis(E &&e, std::ptrdiff_t src, std::ptrdiff_t dest)
Return a new expression with an axis moved to a new position.
auto ravel(E &&e)
Return a flattened view of the given expression.
auto transpose(E &&e) noexcept
Returns a transpose view by reversing the dimensions of xexpression e.
auto atleast_Nd(E &&e)
Expand the dimensions of an xexpression to at least N.
auto repeat(E &&e, std::size_t repeats, std::size_t axis)
Repeat elements of an expression along a given axis.
auto trim_zeros(E &&e, const std::string &direction="fb")
Trim zeros at the beginning, the end, or both ends of a 1D sequence.
auto split(E &e, std::size_t n, std::size_t axis=0)
Split an xexpression along axis into subexpressions.
auto rot90(E &&e, const std::array< std::ptrdiff_t, 2 > &axes={0, 1})
Rotate an array by 90 degrees in the plane specified by axes.
auto expand_dims(E &&e, std::size_t axis)
Expand the shape of an xexpression.
auto vsplit(E &e, std::size_t n)
Split an xexpression into subexpressions vertically (row-wise).
auto flip(E &&e)
Reverse the order of elements in an xexpression along every axis.
auto atleast_2d(E &&e)
Expand to at least 2D.
auto swapaxes(E &&e, std::ptrdiff_t axis1, std::ptrdiff_t axis2)
Return a new expression with two axes interchanged.
auto hsplit(E &e, std::size_t n)
Split an xexpression into subexpressions horizontally (column-wise).
auto atleast_3d(E &&e)
Expand to at least 3D.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Standard mathematical functions for xexpressions.
auto range(A start_val, B stop_val)
Select a range from start_val to stop_val (excluded).
std::vector< xstrided_slice< std::ptrdiff_t > > xstrided_slice_vector
Vector of slices used to build an xstrided_view.
auto newaxis() noexcept
Returns a slice representing a new axis of length one, to be used as an argument of the view function.
auto empty_like(const xexpression< E > &e)
Create an xcontainer (xarray, xtensor or xtensor_fixed) with uninitialized values of the same shape,...
auto strided_view(E &&e, S &&shape, X &&stride, std::size_t offset=0, layout_type layout=L) noexcept
Construct a strided view from an xexpression, shape, strides and offset.
auto flatnonzero(const T &arr)
Return indices that are non-zero in the flattened version of arr.