xtensor
Loading...
Searching...
No Matches
xmanipulation.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_MANIPULATION_HPP
11#define XTENSOR_MANIPULATION_HPP
12
13#include <algorithm>
14#include <utility>
15
16#include <xtl/xcompare.hpp>
17#include <xtl/xsequence.hpp>
18
19#include "xbuilder.hpp"
20#include "xexception.hpp"
21#include "xrepeat.hpp"
22#include "xstrided_view.hpp"
23#include "xtensor_config.hpp"
24#include "xutils.hpp"
25
26namespace xt
27{
/**
 * Tag types used to select, at compile time, how much argument validation
 * the view-building functions of this header perform.
 */
namespace check_policy
{
    /** Tag selecting the implementation without the extra validation pass. */
    struct none {};

    /** Tag selecting the implementation that validates its arguments first. */
    struct full {};
}
42
43 template <class E>
44 auto transpose(E&& e) noexcept;
45
46 template <class E, class S, class Tag = check_policy::none>
47 auto transpose(E&& e, S&& permutation, Tag check_policy = Tag());
48
49 template <class E>
50 auto swapaxes(E&& e, std::ptrdiff_t axis1, std::ptrdiff_t axis2);
51
52 template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E>
53 auto ravel(E&& e);
54
55 template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E>
56 auto flatten(E&& e);
57
58 template <layout_type L, class T>
59 auto flatnonzero(const T& arr);
60
61 template <class E>
62 auto trim_zeros(E&& e, const std::string& direction = "fb");
63
64 template <class E>
65 auto squeeze(E&& e);
66
68 auto squeeze(E&& e, S&& axis, Tag check_policy = Tag());
69
70 template <class E>
71 auto expand_dims(E&& e, std::size_t axis);
72
73 template <std::size_t N, class E>
74 auto atleast_Nd(E&& e);
75
76 template <class E>
77 auto atleast_1d(E&& e);
78
79 template <class E>
80 auto atleast_2d(E&& e);
81
82 template <class E>
83 auto atleast_3d(E&& e);
84
85 template <class E>
86 auto split(E& e, std::size_t n, std::size_t axis = 0);
87
88 template <class E>
89 auto hsplit(E& e, std::size_t n);
90
91 template <class E>
92 auto vsplit(E& e, std::size_t n);
93
94 template <class E>
95 auto flip(E&& e);
96
97 template <class E>
98 auto flip(E&& e, std::size_t axis);
99
100 template <std::ptrdiff_t N = 1, class E>
101 auto rot90(E&& e, const std::array<std::ptrdiff_t, 2>& axes = {0, 1});
102
103 template <class E>
104 auto roll(E&& e, std::ptrdiff_t shift);
105
106 template <class E>
107 auto roll(E&& e, std::ptrdiff_t shift, std::ptrdiff_t axis);
108
109 template <class E>
110 auto repeat(E&& e, std::size_t repeats, std::size_t axis);
111
112 template <class E>
113 auto repeat(E&& e, const std::vector<std::size_t>& repeats, std::size_t axis);
114
115 template <class E>
116 auto repeat(E&& e, std::vector<std::size_t>&& repeats, std::size_t axis);
117
118 /****************************
119 * transpose implementation *
120 ****************************/
121
122 namespace detail
123 {
124 inline layout_type transpose_layout_noexcept(layout_type l) noexcept
125 {
126 layout_type result = l;
127 if (l == layout_type::row_major)
128 {
130 }
131 else if (l == layout_type::column_major)
132 {
133 result = layout_type::row_major;
134 }
135 return result;
136 }
137
138 inline layout_type transpose_layout(layout_type l)
139 {
141 {
142 XTENSOR_THROW(transpose_error, "cannot compute transposed layout of dynamic layout");
143 }
144 return transpose_layout_noexcept(l);
145 }
146
147 template <class E, class S>
148 inline auto transpose_impl(E&& e, S&& permutation, check_policy::none)
149 {
150 if (sequence_size(permutation) != e.dimension())
151 {
152 XTENSOR_THROW(transpose_error, "Permutation does not have the same size as shape");
153 }
154
155 // permute stride and shape
156 using shape_type = xindex_type_t<typename std::decay_t<E>::shape_type>;
157 shape_type temp_shape;
158 resize_container(temp_shape, e.shape().size());
159
160 using strides_type = get_strides_t<shape_type>;
161 strides_type temp_strides;
162 resize_container(temp_strides, e.strides().size());
163
164 using size_type = typename std::decay_t<E>::size_type;
165 for (std::size_t i = 0; i < e.shape().size(); ++i)
166 {
167 if (std::size_t(permutation[i]) >= e.dimension())
168 {
169 XTENSOR_THROW(transpose_error, "Permutation contains wrong axis");
170 }
171 size_type perm = static_cast<size_type>(permutation[i]);
172 temp_shape[i] = e.shape()[perm];
173 temp_strides[i] = e.strides()[perm];
174 }
175
177 if (std::is_sorted(std::begin(permutation), std::end(permutation)))
178 {
179 // keep old layout
180 new_layout = e.layout();
181 }
182 else if (std::is_sorted(std::begin(permutation), std::end(permutation), std::greater<>()))
183 {
184 new_layout = transpose_layout_noexcept(e.layout());
185 }
186
187 return strided_view(
188 std::forward<E>(e),
189 std::move(temp_shape),
190 std::move(temp_strides),
191 get_offset<XTENSOR_DEFAULT_LAYOUT>(e),
192 new_layout
193 );
194 }
195
196 template <class E, class S>
197 inline auto transpose_impl(E&& e, S&& permutation, check_policy::full)
198 {
199 // check if axis appears twice in permutation
200 for (std::size_t i = 0; i < sequence_size(permutation); ++i)
201 {
202 for (std::size_t j = i + 1; j < sequence_size(permutation); ++j)
203 {
204 if (permutation[i] == permutation[j])
205 {
206 XTENSOR_THROW(transpose_error, "Permutation contains axis more than once");
207 }
208 }
209 }
210 return transpose_impl(std::forward<E>(e), std::forward<S>(permutation), check_policy::none());
211 }
212
213 template <class E, class S, class X, std::enable_if_t<has_data_interface<std::decay_t<E>>::value>* = nullptr>
214 inline void compute_transposed_strides(E&& e, const S&, X& strides)
215 {
216 std::copy(e.strides().crbegin(), e.strides().crend(), strides.begin());
217 }
218
219 template <class E, class S, class X, std::enable_if_t<!has_data_interface<std::decay_t<E>>::value>* = nullptr>
220 inline void compute_transposed_strides(E&&, const S& shape, X& strides)
221 {
222 // In the case where E does not have a data interface, the transposition
223 // makes use of a flat storage adaptor that has layout XTENSOR_DEFAULT_TRAVERSAL
224 // which should be the one inverted.
225 layout_type l = transpose_layout(XTENSOR_DEFAULT_TRAVERSAL);
226 compute_strides(shape, l, strides);
227 }
228 }
229
236 template <class E>
237 inline auto transpose(E&& e) noexcept
238 {
240 shape_type shape;
241 resize_container(shape, e.shape().size());
242 std::copy(e.shape().crbegin(), e.shape().crend(), shape.begin());
243
245 resize_container(strides, e.shape().size());
246 detail::compute_transposed_strides(e, shape, strides);
247
248 layout_type new_layout = detail::transpose_layout_noexcept(e.layout());
249
250 return strided_view(
251 std::forward<E>(e),
252 std::move(shape),
253 std::move(strides),
254 detail::get_offset<XTENSOR_DEFAULT_TRAVERSAL>(e),
255 new_layout
256 );
257 }
258
268 template <class E, class S, class Tag>
269 inline auto transpose(E&& e, S&& permutation, Tag check_policy)
270 {
271 return detail::transpose_impl(std::forward<E>(e), std::forward<S>(permutation), check_policy);
272 }
273
275 template <class E, class I, std::size_t N, class Tag = check_policy::none>
276 inline auto transpose(E&& e, const I (&permutation)[N], Tag check_policy = Tag())
277 {
278 return detail::transpose_impl(std::forward<E>(e), permutation, check_policy);
279 }
280
282
283 /*****************************
284 * swapaxes implementation *
285 *****************************/
286
287 namespace detail
288 {
289 template <class S>
290 inline S swapaxes_perm(std::size_t dim, std::ptrdiff_t axis1, std::ptrdiff_t axis2)
291 {
292 const std::size_t ax1 = normalize_axis(dim, axis1);
293 const std::size_t ax2 = normalize_axis(dim, axis2);
294 auto perm = xtl::make_sequence<S>(dim, 0);
295 using id_t = typename S::value_type;
296 std::iota(perm.begin(), perm.end(), id_t(0));
297 perm[ax1] = ax2;
298 perm[ax2] = ax1;
299 return perm;
300 }
301 }
302
313 template <class E>
314 inline auto swapaxes(E&& e, std::ptrdiff_t axis1, std::ptrdiff_t axis2)
315 {
316 const auto dim = e.dimension();
317 check_axis_in_dim(axis1, dim, "Parameter axis1");
318 check_axis_in_dim(axis2, dim, "Parameter axis2");
319
321 return transpose(std::forward<E>(e), detail::swapaxes_perm<strides_t>(dim, axis1, axis2));
322 }
323
324 /*****************************
325 * moveaxis implementation *
326 *****************************/
327
328 namespace detail
329 {
330 template <class S>
331 inline S moveaxis_perm(std::size_t dim, std::ptrdiff_t src, std::ptrdiff_t dest)
332 {
333 using id_t = typename S::value_type;
334
335 const std::size_t src_norm = normalize_axis(dim, src);
336 const std::size_t dest_norm = normalize_axis(dim, dest);
337
338 // Initializing to src_norm handles case where `dest == -1` and the loop
339 // does not go check `perm_idx == dest_norm` a `dim+1`th time.
340 auto perm = xtl::make_sequence<S>(dim, src_norm);
341 id_t perm_idx = 0;
342 for (id_t i = 0; xtl::cmp_less(i, dim); ++i)
343 {
344 if (xtl::cmp_equal(perm_idx, dest_norm))
345 {
346 perm[perm_idx] = src_norm;
347 ++perm_idx;
348 }
349 if (xtl::cmp_not_equal(i, src_norm))
350 {
351 perm[perm_idx] = i;
352 ++perm_idx;
353 }
354 }
355 return perm;
356 }
357 }
358
367 template <class E>
368 inline auto moveaxis(E&& e, std::ptrdiff_t src, std::ptrdiff_t dest)
369 {
370 const auto dim = e.dimension();
371 check_axis_in_dim(src, dim, "Parameter src");
372 check_axis_in_dim(dest, dim, "Parameter dest");
373
375 return xt::transpose(std::forward<E>(e), detail::moveaxis_perm<strides_t>(e.dimension(), src, dest));
376 }
377
378 /************************************
379 * ravel and flatten implementation *
380 ************************************/
381
382 namespace detail
383 {
384 template <class E, layout_type L>
385 struct expression_iterator_getter
386 {
387 using iterator = decltype(std::declval<E>().template begin<L>());
388 using const_iterator = decltype(std::declval<E>().template cbegin<L>());
389
390 inline static iterator begin(E& e)
391 {
392 return e.template begin<L>();
393 }
394
395 inline static const_iterator cbegin(E& e)
396 {
397 return e.template cbegin<L>();
398 }
399
400 inline static auto size(E& e)
401 {
402 return e.size();
403 }
404 };
405 }
406
416 template <layout_type L, class E>
417 inline auto ravel(E&& e)
418 {
419 using iterator = decltype(e.template begin<L>());
420 using iterator_getter = detail::expression_iterator_getter<std::remove_reference_t<E>, L>;
421 auto size = e.size();
422 auto adaptor = make_xiterator_adaptor(std::forward<E>(e), iterator_getter());
423 constexpr layout_type layout = std::is_pointer<iterator>::value ? L : layout_type::dynamic;
424 using type = xtensor_view<decltype(adaptor), 1, layout, extension::get_expression_tag_t<E>>;
425 return type(std::move(adaptor), {size});
426 }
427
441 template <layout_type L, class E>
442 inline auto flatten(E&& e)
443 {
444 return ravel<L>(std::forward<E>(e));
445 }
446
455 template <layout_type L, class T>
456 inline auto flatnonzero(const T& arr)
457 {
458 return nonzero(ravel<L>(arr))[0];
459 }
460
461 /*****************************
462 * trim_zeros implementation *
463 *****************************/
464
474 template <class E>
475 inline auto trim_zeros(E&& e, const std::string& direction)
476 {
477 XTENSOR_ASSERT_MSG(e.dimension() == 1, "Dimension for trim_zeros has to be 1.");
478
479 std::ptrdiff_t begin = 0, end = static_cast<std::ptrdiff_t>(e.size());
480
481 auto find_fun = [](const auto& i)
482 {
483 return i != 0;
484 };
485
486 if (direction.find("f") != std::string::npos)
487 {
488 begin = std::find_if(e.cbegin(), e.cend(), find_fun) - e.cbegin();
489 }
490
491 if (direction.find("b") != std::string::npos && begin != end)
492 {
493 end -= std::find_if(e.crbegin(), e.crend(), find_fun) - e.crbegin();
494 }
495
496 return strided_view(std::forward<E>(e), {range(begin, end)});
497 }
498
499 /**************************
500 * squeeze implementation *
501 **************************/
502
512 template <class E>
513 inline auto squeeze(E&& e)
514 {
517 std::copy_if(
518 e.shape().cbegin(),
519 e.shape().cend(),
520 std::back_inserter(new_shape),
521 [](std::size_t i)
522 {
523 return i != 1;
524 }
525 );
526 decltype(auto) old_strides = detail::get_strides<XTENSOR_DEFAULT_LAYOUT>(e);
527 std::copy_if(
528 old_strides.cbegin(),
529 old_strides.cend(),
530 std::back_inserter(new_strides),
531 [](std::ptrdiff_t i)
532 {
533 return i != 0;
534 }
535 );
536
537 return strided_view(std::forward<E>(e), std::move(new_shape), std::move(new_strides), 0, e.layout());
538 }
539
540 namespace detail
541 {
542 template <class E, class S>
543 inline auto squeeze_impl(E&& e, S&& axis, check_policy::none)
544 {
545 std::size_t new_dim = e.dimension() - axis.size();
546 dynamic_shape<std::size_t> new_shape(new_dim);
547 dynamic_shape<std::ptrdiff_t> new_strides(new_dim);
548
549 decltype(auto) old_strides = detail::get_strides<XTENSOR_DEFAULT_LAYOUT>(e);
550
551 for (std::size_t i = 0, ix = 0; i < e.dimension(); ++i)
552 {
553 if (axis.cend() == std::find(axis.cbegin(), axis.cend(), i))
554 {
555 new_shape[ix] = e.shape()[i];
556 new_strides[ix++] = old_strides[i];
557 }
558 }
559
560 return strided_view(std::forward<E>(e), std::move(new_shape), std::move(new_strides), 0, e.layout());
561 }
562
563 template <class E, class S>
564 inline auto squeeze_impl(E&& e, S&& axis, check_policy::full)
565 {
566 for (auto ix : axis)
567 {
568 if (static_cast<std::size_t>(ix) > e.dimension())
569 {
570 XTENSOR_THROW(std::runtime_error, "Axis argument to squeeze > dimension of expression");
571 }
572 if (e.shape()[static_cast<std::size_t>(ix)] != 1)
573 {
574 XTENSOR_THROW(std::runtime_error, "Trying to squeeze axis != 1");
575 }
576 }
577 return squeeze_impl(std::forward<E>(e), std::forward<S>(axis), check_policy::none());
578 }
579 }
580
591 template <class E, class S, class Tag, std::enable_if_t<!xtl::is_integral<S>::value, int>>
592 inline auto squeeze(E&& e, S&& axis, Tag check_policy)
593 {
594 return detail::squeeze_impl(std::forward<E>(e), std::forward<S>(axis), check_policy);
595 }
596
598 template <class E, class I, std::size_t N, class Tag = check_policy::none>
599 inline auto squeeze(E&& e, const I (&axis)[N], Tag check_policy = Tag())
600 {
601 using arr_t = std::array<I, N>;
602 return detail::squeeze_impl(
603 std::forward<E>(e),
604 xtl::forward_sequence<arr_t, decltype(axis)>(axis),
605 check_policy
606 );
607 }
608
609 template <class E, class Tag = check_policy::none>
610 inline auto squeeze(E&& e, std::size_t axis, Tag check_policy = Tag())
611 {
612 return squeeze(std::forward<E>(e), std::array<std::size_t, 1>{axis}, check_policy);
613 }
614
616
617 /******************************
618 * expand_dims implementation *
619 ******************************/
620
632 template <class E>
633 inline auto expand_dims(E&& e, std::size_t axis)
634 {
635 xstrided_slice_vector sv(e.dimension() + 1, all());
636 sv[axis] = newaxis();
637 return strided_view(std::forward<E>(e), std::move(sv));
638 }
639
640 /*****************************
641 * atleast_Nd implementation *
642 *****************************/
643
657 template <std::size_t N, class E>
658 inline auto atleast_Nd(E&& e)
659 {
660 xstrided_slice_vector sv((std::max)(e.dimension(), N), all());
661 if (e.dimension() < N)
662 {
663 std::size_t i = 0;
664 std::size_t end = static_cast<std::size_t>(std::round(double(N - e.dimension()) / double(N)));
665 for (; i < end; ++i)
666 {
667 sv[i] = newaxis();
668 }
669 i += e.dimension();
670 for (; i < N; ++i)
671 {
672 sv[i] = newaxis();
673 }
674 }
675 return strided_view(std::forward<E>(e), std::move(sv));
676 }
677
684 template <class E>
685 inline auto atleast_1d(E&& e)
686 {
687 return atleast_Nd<1>(std::forward<E>(e));
688 }
689
696 template <class E>
697 inline auto atleast_2d(E&& e)
698 {
699 return atleast_Nd<2>(std::forward<E>(e));
700 }
701
708 template <class E>
709 inline auto atleast_3d(E&& e)
710 {
711 return atleast_Nd<3>(std::forward<E>(e));
712 }
713
714 /************************
715 * split implementation *
716 ************************/
717
731 template <class E>
732 inline auto split(E& e, std::size_t n, std::size_t axis)
733 {
734 if (axis >= e.dimension())
735 {
736 XTENSOR_THROW(std::runtime_error, "Split along axis > dimension.");
737 }
738
739 std::size_t ax_sz = e.shape()[axis];
740 xstrided_slice_vector sv(e.dimension(), all());
741 std::size_t step = ax_sz / n;
742 std::size_t rest = ax_sz % n;
743
744 if (rest)
745 {
746 XTENSOR_THROW(std::runtime_error, "Split does not result in equal division.");
747 }
748
749 std::vector<decltype(strided_view(e, sv))> result;
750 for (std::size_t i = 0; i < n; ++i)
751 {
752 sv[axis] = range(i * step, (i + 1) * step);
753 result.emplace_back(strided_view(e, sv));
754 }
755 return result;
756 }
757
/**
 * Split an expression into @p n views along axis 1.
 *
 * @param e the input expression
 * @param n number of pieces
 */
template <class E>
inline auto hsplit(E& e, std::size_t n)
{
    const std::size_t column_axis = std::size_t(1);
    return split(e, n, column_axis);
}
772
/**
 * Split an expression into @p n views along axis 0.
 *
 * @param e the input expression
 * @param n number of pieces
 */
template <class E>
inline auto vsplit(E& e, std::size_t n)
{
    const std::size_t row_axis = std::size_t(0);
    return split(e, n, row_axis);
}
787
788 /***********************
789 * flip implementation *
790 ***********************/
791
/**
 * Reverse the order of elements of an expression along every axis.
 *
 * Applies the single-axis flip to axis 0 and then to each remaining axis in
 * turn.
 *
 * @param e the input expression
 * @return the result of flipping axis 0, updated for each further axis
 */
template <class E>
inline auto flip(E&& e)
{
    using size_type = typename std::decay_t<E>::size_type;

    // NOTE(review): axis 0 is flipped unconditionally, so this presumably
    // requires e to have at least one dimension - confirm.
    auto r = flip(e, 0);

    // NOTE(review): r keeps the type returned by the first flip; whether the
    // assignment below rebinds the view or assigns through it depends on the
    // view's assignment operator, which is not visible here - confirm.
    size_type d = 1;
    while (d < e.dimension())
    {
        r = flip(r, d);
        ++d;
    }
    return r;
}
810
822 template <class E>
823 inline auto flip(E&& e, std::size_t axis)
824 {
826
827 shape_type shape;
828 resize_container(shape, e.shape().size());
829 std::copy(e.shape().cbegin(), e.shape().cend(), shape.begin());
830
832 decltype(auto) old_strides = detail::get_strides<XTENSOR_DEFAULT_LAYOUT>(e);
833 resize_container(strides, old_strides.size());
834 std::copy(old_strides.cbegin(), old_strides.cend(), strides.begin());
835
836 strides[axis] *= -1;
837 std::size_t offset = static_cast<std::size_t>(
838 static_cast<std::ptrdiff_t>(e.data_offset())
839 + old_strides[axis] * (static_cast<std::ptrdiff_t>(e.shape()[axis]) - 1)
840 );
841
842 return strided_view(std::forward<E>(e), std::move(shape), std::move(strides), offset);
843 }
844
845 /************************
846 * rot90 implementation *
847 ************************/
848
849 namespace detail
850 {
851 template <std::ptrdiff_t N>
852 struct rot90_impl;
853
854 template <>
855 struct rot90_impl<0>
856 {
857 template <class E>
858 inline auto operator()(E&& e, const std::array<std::size_t, 2>& /*axes*/)
859 {
860 return std::forward<E>(e);
861 }
862 };
863
864 template <>
865 struct rot90_impl<1>
866 {
867 template <class E>
868 inline auto operator()(E&& e, const std::array<std::size_t, 2>& axes)
869 {
870 using std::swap;
871
872 dynamic_shape<std::ptrdiff_t> axes_list(e.shape().size());
873 std::iota(axes_list.begin(), axes_list.end(), 0);
874 swap(axes_list[axes[0]], axes_list[axes[1]]);
875
876 return transpose(flip(std::forward<E>(e), axes[1]), std::move(axes_list));
877 }
878 };
879
880 template <>
881 struct rot90_impl<2>
882 {
883 template <class E>
884 inline auto operator()(E&& e, const std::array<std::size_t, 2>& axes)
885 {
886 return flip(flip(std::forward<E>(e), axes[0]), axes[1]);
887 }
888 };
889
890 template <>
891 struct rot90_impl<3>
892 {
893 template <class E>
894 inline auto operator()(E&& e, const std::array<std::size_t, 2>& axes)
895 {
896 using std::swap;
897
898 dynamic_shape<std::ptrdiff_t> axes_list(e.shape().size());
899 std::iota(axes_list.begin(), axes_list.end(), 0);
900 swap(axes_list[axes[0]], axes_list[axes[1]]);
901
902 return flip(transpose(std::forward<E>(e), std::move(axes_list)), axes[1]);
903 }
904 };
905 }
906
918 template <std::ptrdiff_t N, class E>
919 inline auto rot90(E&& e, const std::array<std::ptrdiff_t, 2>& axes)
920 {
921 auto ndim = static_cast<std::ptrdiff_t>(e.shape().size());
922
923 if (axes[0] == axes[1] || std::abs(axes[0] - axes[1]) == ndim)
924 {
925 XTENSOR_THROW(std::runtime_error, "Axes must be different");
926 }
927
929 constexpr std::ptrdiff_t n = (4 + (N % 4)) % 4;
930
931 return detail::rot90_impl<n>()(std::forward<E>(e), norm_axes);
932 }
933
934 /***********************
935 * roll implementation *
936 ***********************/
937
/**
 * Roll an expression: the elements of the flattened expression are shifted
 * by @p shift positions, wrapping around at the end.
 *
 * An empty expression used to loop forever for a negative shift and to
 * compute a remainder by zero otherwise; it now returns the (empty) copy
 * directly. The shift is normalized with a remainder instead of repeated
 * addition, and the element count is accumulated in std::ptrdiff_t rather
 * than starting from a `long` literal.
 *
 * @param e the input expression
 * @param shift number of positions to shift by; may be negative
 * @return a new container shaped like @p e holding the rolled elements
 */
template <class E>
inline auto roll(E&& e, std::ptrdiff_t shift)
{
    auto cpy = empty_like(e);
    const std::ptrdiff_t flat_size = std::accumulate(
        cpy.shape().begin(),
        cpy.shape().end(),
        std::ptrdiff_t(1),
        std::multiplies<std::ptrdiff_t>()
    );
    if (flat_size == 0)
    {
        return cpy;
    }

    // Bring shift into [0, flat_size).
    shift %= flat_size;
    if (shift < 0)
    {
        shift += flat_size;
    }

    // The last `shift` elements go first, followed by the leading ones.
    std::copy(e.begin(), e.end() - shift, std::copy(e.end() - shift, e.end(), cpy.begin()));

    return cpy;
}
971
972 namespace detail
973 {
/**
 * Recursive worker of the axis-wise roll.
 *
 * Handles dimension M of a block of elements starting at @p from and writes
 * the result sequentially through @p to. Along the rolled axis the slabs
 * starting at index (extent - shift) are emitted first, followed by the
 * leading ones; along every other axis the slabs are emitted in order.
 *
 * @param to output iterator, advanced once per element written
 * @param from random-access iterator to the first element of the block
 * @param shift number of positions to roll by along @p axis
 * @param axis the axis being rolled
 * @param shape shape of the whole expression
 * @param M the dimension processed by this call
 * @return @p to advanced past the last element written
 */
template <class To, class From, class S>
To roll(To to, From from, std::ptrdiff_t shift, std::size_t axis, const S& shape, std::size_t M)
{
    const std::ptrdiff_t extent = std::ptrdiff_t(shape[M]);
    // Distance between two consecutive slabs along dimension M.
    const std::ptrdiff_t stride = std::accumulate(
        shape.begin() + M + 1,
        shape.end(),
        std::ptrdiff_t(1),
        std::multiplies<std::ptrdiff_t>()
    );
    const bool innermost = (shape.size() == M + 1);

    // Emits the slabs in [first, last): copies single elements on the last
    // dimension, recurses into the next dimension otherwise.
    auto emit = [&](auto first, auto last)
    {
        for (; first != last; first += stride)
        {
            if (innermost)
            {
                *to = *first;
                ++to;
            }
            else
            {
                to = roll(to, first, shift, axis, shape, M + 1);
            }
        }
    };

    if (axis == M)
    {
        const auto pivot = from + (extent - shift) * stride;
        emit(pivot, from + extent * stride);
        emit(from, pivot);
    }
    else
    {
        emit(from, from + extent * stride);
    }
    return to;
}
1034 }
1035
1048 template <class E>
1049 inline auto roll(E&& e, std::ptrdiff_t shift, std::ptrdiff_t axis)
1050 {
1051 auto cpy = empty_like(e);
1052 const auto& shape = cpy.shape();
1053 std::size_t saxis = static_cast<std::size_t>(axis);
1054 if (axis < 0)
1055 {
1056 axis += std::ptrdiff_t(cpy.dimension());
1057 }
1058
1059 if (saxis >= cpy.dimension() || axis < 0)
1060 {
1061 XTENSOR_THROW(std::runtime_error, "axis is no within shape dimension.");
1062 }
1063
1064 const auto axis_dim = static_cast<std::ptrdiff_t>(shape[saxis]);
1065 while (shift < 0)
1066 {
1067 shift += axis_dim;
1068 }
1069
1070 detail::roll(cpy.begin(), e.begin(), shift, saxis, shape, 0);
1071 return cpy;
1072 }
1073
1074 /****************************
1075 * repeat implementation *
1076 ****************************/
1077
1078 namespace detail
1079 {
1080 template <class E, class R>
1081 inline auto make_xrepeat(E&& e, R&& r, typename std::decay_t<E>::size_type axis)
1082 {
1083 const auto casted_axis = static_cast<typename std::decay_t<E>::size_type>(axis);
1084 if (r.size() != e.shape(casted_axis))
1085 {
1086 XTENSOR_THROW(std::invalid_argument, "repeats must have the same size as the specified axis");
1087 }
1088 return xrepeat<const_xclosure_t<E>, R>(std::forward<E>(e), std::forward<R>(r), axis);
1089 }
1090 }
1091
1102 template <class E>
1103 inline auto repeat(E&& e, std::size_t repeats, std::size_t axis)
1104 {
1105 const auto casted_axis = static_cast<typename std::decay_t<E>::size_type>(axis);
1106 std::vector<std::size_t> broadcasted_repeats(e.shape(casted_axis));
1107 std::fill(broadcasted_repeats.begin(), broadcasted_repeats.end(), repeats);
1108 return repeat(std::forward<E>(e), std::move(broadcasted_repeats), axis);
1109 }
1110
1122 template <class E>
1123 inline auto repeat(E&& e, const std::vector<std::size_t>& repeats, std::size_t axis)
1124 {
1125 return detail::make_xrepeat(std::forward<E>(e), repeats, axis);
1126 }
1127
1138 template <class E>
1139 inline auto repeat(E&& e, std::vector<std::size_t>&& repeats, std::size_t axis)
1140 {
1141 return detail::make_xrepeat(std::forward<E>(e), std::move(repeats), axis);
1142 }
1143}
1144
1145#endif
Dense multidimensional container adaptor with view semantics and fixed dimension.
Definition xtensor.hpp:329
auto nonzero(const T &arr)
return vector of indices where T is not zero
auto flatten(E &&e)
Return a flattened view of the given expression.
auto atleast_1d(E &&e)
Expand to at least 1D.
auto roll(E &&e, std::ptrdiff_t shift)
Roll an expression.
auto squeeze(E &&e)
Returns a squeezed view of the given expression.
auto moveaxis(E &&e, std::ptrdiff_t src, std::ptrdiff_t dest)
Return a new expression with an axis moved to a new position.
auto ravel(E &&e)
Return a flattened view of the given expression.
auto transpose(E &&e) noexcept
Returns a transpose view by reversing the dimensions of xexpression e.
auto atleast_Nd(E &&e)
Expand the dimensions of an xexpression to at least N.
auto repeat(E &&e, std::size_t repeats, std::size_t axis)
Repeat elements of an expression along a given axis.
auto trim_zeros(E &&e, const std::string &direction="fb")
Trim zeros at beginning, end or both of 1D sequence.
auto split(E &e, std::size_t n, std::size_t axis=0)
Split xexpression along axis into subexpressions.
auto rot90(E &&e, const std::array< std::ptrdiff_t, 2 > &axes={0, 1})
Rotate an array by 90 degrees in the plane specified by axes.
auto expand_dims(E &&e, std::size_t axis)
Expand the shape of an xexpression.
auto vsplit(E &e, std::size_t n)
Split an xexpression into subexpressions vertically (row-wise)
auto flip(E &&e)
Reverse the order of elements in an xexpression along every axis.
auto atleast_2d(E &&e)
Expand to at least 2D.
auto swapaxes(E &&e, std::ptrdiff_t axis1, std::ptrdiff_t axis2)
Return a new expression with two axes interchanged.
auto hsplit(E &e, std::size_t n)
Split an xexpression into subexpressions horizontally (column-wise)
auto atleast_3d(E &&e)
Expand to at least 3D.
std::size_t compute_strides(const shape_type &shape, layout_type l, strides_type &strides)
Compute the strides given the shape and the layout of an array.
Definition xstrides.hpp:566
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Definition xstrides.hpp:248
standard mathematical functions for xexpressions
auto range(A start_val, B stop_val)
Select a range from start_val to stop_val (excluded).
Definition xslice.hpp:818
auto all() noexcept
Returns a slice representing a full dimension, to be used as an argument of view function.
Definition xslice.hpp:234
auto newaxis() noexcept
Returns a slice representing a new axis of length one, to be used as an argument of view function.
Definition xslice.hpp:300
layout_type
Definition xlayout.hpp:24
std::vector< xstrided_slice< std::ptrdiff_t > > xstrided_slice_vector
vector of slices used to build a xstrided_view
auto empty_like(const xexpression< E > &e)
Create a xcontainer (xarray, xtensor or xtensor_fixed) with uninitialized values of the same shape,...
Definition xbuilder.hpp:121
auto strided_view(E &&e, S &&shape, X &&stride, std::size_t offset=0, layout_type layout=L) noexcept
Construct a strided view from an xexpression, shape, strides and offset.
auto flatnonzero(const T &arr)
Return indices that are non-zero in the flattened version of arr.