xtensor
 
Loading...
Searching...
No Matches
xmanipulation.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_MANIPULATION_HPP
11#define XTENSOR_MANIPULATION_HPP
12
13#include <algorithm>
14#include <utility>
15
16#include <xtl/xcompare.hpp>
17#include <xtl/xsequence.hpp>
18
19#include "../core/xtensor_config.hpp"
20#include "../generators/xbuilder.hpp"
21#include "../utils/xexception.hpp"
22#include "../utils/xutils.hpp"
23#include "../views/xrepeat.hpp"
24#include "../views/xstrided_view.hpp"
25#include "xtl_concepts.hpp"
26
27namespace xt
28{
32
namespace check_policy
{
    /// Tag selecting the unchecked code path of a manipulation function:
    /// only the minimal validation done by that path is performed.
    struct none {};

    /// Tag selecting the checked code path: extra argument validation
    /// (e.g. duplicated axes) is performed before delegating to `none`.
    struct full {};
}
43
// Forward declarations of the public manipulation functions.
// Default template/function arguments are specified here only; the
// definitions further down must not repeat them.

// --- axis permutation ---
template <class E>
auto transpose(E&& e) noexcept;

template <class E, class S, class Tag = check_policy::none>
auto transpose(E&& e, S&& permutation, Tag check_policy = Tag());

template <class E>
auto swapaxes(E&& e, std::ptrdiff_t axis1, std::ptrdiff_t axis2);

// --- flattening ---
template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E>
auto ravel(E&& e);

template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E>
auto flatten(E&& e);

template <layout_type L, class T>
auto flatnonzero(const T& arr);

// --- trimming / dimension changes ---
template <class E>
auto trim_zeros(E&& e, const std::string& direction = "fb");

template <class E>
auto squeeze(E&& e);

template <class E, xtl::non_integral_concept S, class Tag = check_policy::none>
auto squeeze(E&& e, S&& axis, Tag check_policy = Tag());

template <class E>
auto expand_dims(E&& e, std::size_t axis);

template <std::size_t N, class E>
auto atleast_Nd(E&& e);

template <class E>
auto atleast_1d(E&& e);

template <class E>
auto atleast_2d(E&& e);

template <class E>
auto atleast_3d(E&& e);

// --- splitting (take an lvalue: the returned views refer to e) ---
template <class E>
auto split(E& e, std::size_t n, std::size_t axis = 0);

template <class E>
auto hsplit(E& e, std::size_t n);

template <class E>
auto vsplit(E& e, std::size_t n);

// --- reordering of elements ---
template <class E>
auto flip(E&& e);

template <class E>
auto flip(E&& e, std::size_t axis);

template <std::ptrdiff_t N = 1, class E>
auto rot90(E&& e, const std::array<std::ptrdiff_t, 2>& axes = {0, 1});

template <class E>
auto roll(E&& e, std::ptrdiff_t shift);

template <class E>
auto roll(E&& e, std::ptrdiff_t shift, std::ptrdiff_t axis);

// --- repetition ---
template <class E>
auto repeat(E&& e, std::size_t repeats, std::size_t axis);

template <class E>
auto repeat(E&& e, const std::vector<std::size_t>& repeats, std::size_t axis);

template <class E>
auto repeat(E&& e, std::vector<std::size_t>&& repeats, std::size_t axis);
118
119 /****************************
120 * transpose implementation *
121 ****************************/
122
123 namespace detail
124 {
125 inline layout_type transpose_layout_noexcept(layout_type l) noexcept
126 {
127 layout_type result = l;
128 if (l == layout_type::row_major)
129 {
131 }
132 else if (l == layout_type::column_major)
133 {
134 result = layout_type::row_major;
135 }
136 return result;
137 }
138
139 inline layout_type transpose_layout(layout_type l)
140 {
142 {
143 XTENSOR_THROW(transpose_error, "cannot compute transposed layout of dynamic layout");
144 }
145 return transpose_layout_noexcept(l);
146 }
147
148 template <class E, class S>
149 inline auto transpose_impl(E&& e, S&& permutation, check_policy::none)
150 {
151 if (std::size(permutation) != e.dimension())
152 {
153 XTENSOR_THROW(transpose_error, "Permutation does not have the same size as shape");
154 }
155
156 // permute stride and shape
157 using shape_type = xindex_type_t<typename std::decay_t<E>::shape_type>;
158 shape_type temp_shape;
159 resize_container(temp_shape, e.shape().size());
160
161 using strides_type = get_strides_t<shape_type>;
162 strides_type temp_strides;
163 resize_container(temp_strides, e.strides().size());
164
165 using size_type = typename std::decay_t<E>::size_type;
166 for (std::size_t i = 0; i < e.shape().size(); ++i)
167 {
168 if (std::size_t(permutation[i]) >= e.dimension())
169 {
170 XTENSOR_THROW(transpose_error, "Permutation contains wrong axis");
171 }
172 size_type perm = static_cast<size_type>(permutation[i]);
173 temp_shape[i] = e.shape()[perm];
174 temp_strides[i] = e.strides()[perm];
175 }
176
178 if (std::is_sorted(std::begin(permutation), std::end(permutation)))
179 {
180 // keep old layout
181 new_layout = e.layout();
182 }
183 else if (std::is_sorted(std::begin(permutation), std::end(permutation), std::greater<>()))
184 {
185 new_layout = transpose_layout_noexcept(e.layout());
186 }
187
188 return strided_view(
189 std::forward<E>(e),
190 std::move(temp_shape),
191 std::move(temp_strides),
192 get_offset<XTENSOR_DEFAULT_LAYOUT>(e),
193 new_layout
194 );
195 }
196
197 template <class E, class S>
198 inline auto transpose_impl(E&& e, S&& permutation, check_policy::full)
199 {
200 // check if axis appears twice in permutation
201 for (std::size_t i = 0; i < std::size(permutation); ++i)
202 {
203 for (std::size_t j = i + 1; j < std::size(permutation); ++j)
204 {
205 if (permutation[i] == permutation[j])
206 {
207 XTENSOR_THROW(transpose_error, "Permutation contains axis more than once");
208 }
209 }
210 }
211 return transpose_impl(std::forward<E>(e), std::forward<S>(permutation), check_policy::none());
212 }
213
214 template <class E, class S, class X>
215 inline void compute_transposed_strides(E&& e, const S& shape, X& strides)
216 {
217 if constexpr (has_data_interface<std::decay_t<E>>::value)
218 {
219 std::copy(e.strides().crbegin(), e.strides().crend(), strides.begin());
220 }
221 else
222 {
223 // In the case where E does not have a data interface, the transposition
224 // makes use of a flat storage adaptor that has layout XTENSOR_DEFAULT_TRAVERSAL
225 // which should be the one inverted.
226 layout_type l = transpose_layout(XTENSOR_DEFAULT_TRAVERSAL);
227 compute_strides(shape, l, strides);
228 }
229 }
230 }
231
238 template <class E>
239 inline auto transpose(E&& e) noexcept
240 {
241 using shape_type = xindex_type_t<typename std::decay_t<E>::shape_type>;
242 shape_type shape;
243 resize_container(shape, e.shape().size());
244 std::copy(e.shape().crbegin(), e.shape().crend(), shape.begin());
245
246 get_strides_t<shape_type> strides;
247 resize_container(strides, e.shape().size());
248 detail::compute_transposed_strides(e, shape, strides);
249
250 layout_type new_layout = detail::transpose_layout_noexcept(e.layout());
251
252 return strided_view(
253 std::forward<E>(e),
254 std::move(shape),
255 std::move(strides),
256 detail::get_offset<XTENSOR_DEFAULT_TRAVERSAL>(e),
257 new_layout
258 );
259 }
260
/**
 * Returns a strided view of @p e whose i-th axis is axis permutation[i] of e.
 *
 * @param permutation   sequence of axis indices, one per dimension of e
 * @param check_policy  check_policy::none (default) or check_policy::full;
 *                      full additionally rejects duplicated axes
 * @throws transpose_error on an invalid permutation
 */
template <class E, class S, class Tag>
inline auto transpose(E&& e, S&& permutation, Tag check_policy)
{
    return detail::transpose_impl(std::forward<E>(e), std::forward<S>(permutation), check_policy);
}
275
/// Overload accepting a braced/C-array permutation, e.g. transpose(a, {1, 0, 2}).
template <class E, class I, std::size_t N, class Tag = check_policy::none>
inline auto transpose(E&& e, const I (&permutation)[N], Tag check_policy = Tag())
{
    return detail::transpose_impl(std::forward<E>(e), permutation, check_policy);
}
282
284
285 /*****************************
286 * swapaxes implementation *
287 *****************************/
288
289 namespace detail
290 {
291 template <class S>
292 inline S swapaxes_perm(std::size_t dim, std::ptrdiff_t axis1, std::ptrdiff_t axis2)
293 {
294 const std::size_t ax1 = normalize_axis(dim, axis1);
295 const std::size_t ax2 = normalize_axis(dim, axis2);
296 auto perm = xtl::make_sequence<S>(dim, 0);
297 using id_t = typename S::value_type;
298 std::iota(perm.begin(), perm.end(), id_t(0));
299 perm[ax1] = ax2;
300 perm[ax2] = ax1;
301 return perm;
302 }
303 }
304
315 template <class E>
316 inline auto swapaxes(E&& e, std::ptrdiff_t axis1, std::ptrdiff_t axis2)
317 {
318 const auto dim = e.dimension();
319 check_axis_in_dim(axis1, dim, "Parameter axis1");
320 check_axis_in_dim(axis2, dim, "Parameter axis2");
321
322 using strides_t = get_strides_t<typename std::decay_t<E>::shape_type>;
323 return transpose(std::forward<E>(e), detail::swapaxes_perm<strides_t>(dim, axis1, axis2));
324 }
325
326 /*****************************
327 * moveaxis implementation *
328 *****************************/
329
330 namespace detail
331 {
/**
 * Builds the permutation that moves axis @p src to position @p dest and
 * keeps the relative order of all other axes.
 * Negative axes count from the end.
 *
 * Example: dim = 4, src = 0, dest = 2 gives {1, 2, 0, 3}.
 */
template <class S>
inline S moveaxis_perm(std::size_t dim, std::ptrdiff_t src, std::ptrdiff_t dest)
{
    using id_t = typename S::value_type;

    const std::size_t src_norm = normalize_axis(dim, src);
    const std::size_t dest_norm = normalize_axis(dim, dest);

    // Every slot starts as src_norm. This covers dest being the last
    // position: the loop ends before perm_idx can equal dest_norm, so that
    // slot is never written and keeps src_norm.
    auto perm = xtl::make_sequence<S>(dim, src_norm);
    id_t perm_idx = 0;
    for (id_t i = 0; xtl::cmp_less(i, dim); ++i)
    {
        // Insert the moved axis once its destination slot is reached.
        if (xtl::cmp_equal(perm_idx, dest_norm))
        {
            perm[perm_idx] = src_norm;
            ++perm_idx;
        }
        // Copy every other axis in order, skipping the moved one.
        if (xtl::cmp_not_equal(i, src_norm))
        {
            perm[perm_idx] = i;
            ++perm_idx;
        }
    }
    return perm;
}
359 }
360
369 template <class E>
370 inline auto moveaxis(E&& e, std::ptrdiff_t src, std::ptrdiff_t dest)
371 {
372 const auto dim = e.dimension();
373 check_axis_in_dim(src, dim, "Parameter src");
374 check_axis_in_dim(dest, dim, "Parameter dest");
375
376 using strides_t = get_strides_t<typename std::decay_t<E>::shape_type>;
377 return xt::transpose(std::forward<E>(e), detail::moveaxis_perm<strides_t>(e.dimension(), src, dest));
378 }
379
380 /************************************
381 * ravel and flatten implementation *
382 ************************************/
383
384 namespace detail
385 {
386 template <class E, layout_type L>
387 struct expression_iterator_getter
388 {
389 using iterator = decltype(std::declval<E>().template begin<L>());
390 using const_iterator = decltype(std::declval<E>().template cbegin<L>());
391
392 inline static iterator begin(E& e)
393 {
394 return e.template begin<L>();
395 }
396
397 inline static const_iterator cbegin(E& e)
398 {
399 return e.template cbegin<L>();
400 }
401
402 inline static auto size(E& e)
403 {
404 return e.size();
405 }
406 };
407 }
408
418 template <layout_type L, class E>
419 inline auto ravel(E&& e)
420 {
421 using iterator = decltype(e.template begin<L>());
422 using iterator_getter = detail::expression_iterator_getter<std::remove_reference_t<E>, L>;
423 auto size = e.size();
424 auto adaptor = make_xiterator_adaptor(std::forward<E>(e), iterator_getter());
425 constexpr layout_type layout = std::is_pointer<iterator>::value ? L : layout_type::dynamic;
426 using type = xtensor_view<decltype(adaptor), 1, layout, extension::get_expression_tag_t<E>>;
427 return type(std::move(adaptor), {size});
428 }
429
/// Alias of ravel<L>: returns a 1-D view of @p e traversed in layout L.
template <layout_type L, class E>
inline auto flatten(E&& e)
{
    return ravel<L>(std::forward<E>(e));
}
448
/// Indices (in the layout-L flattened order) of the non-zero elements of @p arr.
template <layout_type L, class T>
inline auto flatnonzero(const T& arr)
{
    // nonzero on a 1-D view yields one index vector; [0] extracts it.
    return nonzero(ravel<L>(arr))[0];
}
462
463 /*****************************
464 * trim_zeros implementation *
465 *****************************/
466
476 template <class E>
477 inline auto trim_zeros(E&& e, const std::string& direction)
478 {
479 XTENSOR_ASSERT_MSG(e.dimension() == 1, "Dimension for trim_zeros has to be 1.");
480
481 std::ptrdiff_t begin = 0, end = static_cast<std::ptrdiff_t>(e.size());
482
483 auto find_fun = [](const auto& i)
484 {
485 return i != 0;
486 };
487
488 if (direction.find("f") != std::string::npos)
489 {
490 begin = std::find_if(e.cbegin(), e.cend(), find_fun) - e.cbegin();
491 }
492
493 if (direction.find("b") != std::string::npos && begin != end)
494 {
495 end -= std::find_if(e.crbegin(), e.crend(), find_fun) - e.crbegin();
496 }
497
498 return strided_view(std::forward<E>(e), {range(begin, end)});
499 }
500
501 /**************************
502 * squeeze implementation *
503 **************************/
504
514 template <class E>
515 inline auto squeeze(E&& e)
516 {
517 dynamic_shape<std::size_t> new_shape;
518 dynamic_shape<std::ptrdiff_t> new_strides;
519 std::copy_if(
520 e.shape().cbegin(),
521 e.shape().cend(),
522 std::back_inserter(new_shape),
523 [](std::size_t i)
524 {
525 return i != 1;
526 }
527 );
528 decltype(auto) old_strides = detail::get_strides<XTENSOR_DEFAULT_LAYOUT>(e);
529 std::copy_if(
530 old_strides.cbegin(),
531 old_strides.cend(),
532 std::back_inserter(new_strides),
533 [](std::ptrdiff_t i)
534 {
535 return i != 0;
536 }
537 );
538
539 return strided_view(std::forward<E>(e), std::move(new_shape), std::move(new_strides), 0, e.layout());
540 }
541
542 namespace detail
543 {
544 template <class E, class S>
545 inline auto squeeze_impl(E&& e, S&& axis, check_policy::none)
546 {
547 std::size_t new_dim = e.dimension() - axis.size();
548 dynamic_shape<std::size_t> new_shape(new_dim);
549 dynamic_shape<std::ptrdiff_t> new_strides(new_dim);
550
551 decltype(auto) old_strides = detail::get_strides<XTENSOR_DEFAULT_LAYOUT>(e);
552
553 for (std::size_t i = 0, ix = 0; i < e.dimension(); ++i)
554 {
555 if (axis.cend() == std::find(axis.cbegin(), axis.cend(), i))
556 {
557 new_shape[ix] = e.shape()[i];
558 new_strides[ix++] = old_strides[i];
559 }
560 }
561
562 return strided_view(std::forward<E>(e), std::move(new_shape), std::move(new_strides), 0, e.layout());
563 }
564
565 template <class E, class S>
566 inline auto squeeze_impl(E&& e, S&& axis, check_policy::full)
567 {
568 for (auto ix : axis)
569 {
570 if (static_cast<std::size_t>(ix) > e.dimension())
571 {
572 XTENSOR_THROW(std::runtime_error, "Axis argument to squeeze > dimension of expression");
573 }
574 if (e.shape()[static_cast<std::size_t>(ix)] != 1)
575 {
576 XTENSOR_THROW(std::runtime_error, "Trying to squeeze axis != 1");
577 }
578 }
579 return squeeze_impl(std::forward<E>(e), std::forward<S>(axis), check_policy::none());
580 }
581 }
582
/**
 * Returns a view of @p e without the axes listed in @p axis.
 *
 * @param axis          non-integral sequence of axis indices
 * @param check_policy  check_policy::none (default) or check_policy::full;
 *                      full validates range, length-1 and uniqueness
 */
template <class E, xtl::non_integral_concept S, class Tag>
inline auto squeeze(E&& e, S&& axis, Tag check_policy)
{
    return detail::squeeze_impl(std::forward<E>(e), std::forward<S>(axis), check_policy);
}
598
/// Overload accepting a braced/C-array list of axes, e.g. squeeze(a, {0, 2}).
template <class E, class I, std::size_t N, class Tag = check_policy::none>
inline auto squeeze(E&& e, const I (&axis)[N], Tag check_policy = Tag())
{
    // Copy the C array into a std::array, which provides cbegin/cend/size.
    using arr_t = std::array<I, N>;
    return detail::squeeze_impl(
        std::forward<E>(e),
        xtl::forward_sequence<arr_t, decltype(axis)>(axis),
        check_policy
    );
}
610
/// Overload removing the single axis @p axis.
template <class E, class Tag = check_policy::none>
inline auto squeeze(E&& e, std::size_t axis, Tag check_policy = Tag())
{
    return squeeze(std::forward<E>(e), std::array<std::size_t, 1>{axis}, check_policy);
}
616
618
619 /******************************
620 * expand_dims implementation *
621 ******************************/
622
634 template <class E>
635 inline auto expand_dims(E&& e, std::size_t axis)
636 {
637 xstrided_slice_vector sv(e.dimension() + 1, all());
638 sv[axis] = newaxis();
639 return strided_view(std::forward<E>(e), std::move(sv));
640 }
641
642 /*****************************
643 * atleast_Nd implementation *
644 *****************************/
645
/**
 * Returns a view of @p e with at least N dimensions.
 *
 * If e already has N or more dimensions, it is returned as an
 * all-slices view. Otherwise new length-1 axes are added:
 * round((N - dim) / N) of them at the front and the rest at the back.
 * For example, N = 3 turns a 1-D (n) into (1, n, 1) and a 2-D (m, n)
 * into (m, n, 1). N = 2 turns (n) into (1, n), since std::round rounds
 * 0.5 away from zero.
 */
template <std::size_t N, class E>
inline auto atleast_Nd(E&& e)
{
    xstrided_slice_vector sv((std::max)(e.dimension(), N), all());
    if (e.dimension() < N)
    {
        std::size_t i = 0;
        // Number of new axes prepended in front of the existing ones.
        std::size_t end = static_cast<std::size_t>(std::round(double(N - e.dimension()) / double(N)));
        for (; i < end; ++i)
        {
            sv[i] = newaxis();
        }
        // Skip the existing axes (kept as all()), then append the remaining new axes.
        i += e.dimension();
        for (; i < N; ++i)
        {
            sv[i] = newaxis();
        }
    }
    return strided_view(std::forward<E>(e), std::move(sv));
}
679
/// Returns a view of @p e with at least one dimension (see atleast_Nd).
template <class E>
inline auto atleast_1d(E&& e)
{
    return atleast_Nd<1>(std::forward<E>(e));
}
691
/// Returns a view of @p e with at least two dimensions (see atleast_Nd).
template <class E>
inline auto atleast_2d(E&& e)
{
    return atleast_Nd<2>(std::forward<E>(e));
}
703
/// Returns a view of @p e with at least three dimensions (see atleast_Nd).
template <class E>
inline auto atleast_3d(E&& e)
{
    return atleast_Nd<3>(std::forward<E>(e));
}
715
716 /************************
717 * split implementation *
718 ************************/
719
733 template <class E>
734 inline auto split(E& e, std::size_t n, std::size_t axis)
735 {
736 if (axis >= e.dimension())
737 {
738 XTENSOR_THROW(std::runtime_error, "Split along axis > dimension.");
739 }
740
741 std::size_t ax_sz = e.shape()[axis];
742 xstrided_slice_vector sv(e.dimension(), all());
743 std::size_t step = ax_sz / n;
744 std::size_t rest = ax_sz % n;
745
746 if (rest)
747 {
748 XTENSOR_THROW(std::runtime_error, "Split does not result in equal division.");
749 }
750
751 std::vector<decltype(strided_view(e, sv))> result;
752 for (std::size_t i = 0; i < n; ++i)
753 {
754 sv[axis] = range(i * step, (i + 1) * step);
755 result.emplace_back(strided_view(e, sv));
756 }
757 return result;
758 }
759
/// split along axis 1 (column-wise); see split for requirements and exceptions.
template <class E>
inline auto hsplit(E& e, std::size_t n)
{
    return split(e, n, std::size_t(1));
}
774
/// split along axis 0 (row-wise); see split for requirements and exceptions.
template <class E>
inline auto vsplit(E& e, std::size_t n)
{
    return split(e, n, std::size_t(0));
}
789
790 /***********************
791 * flip implementation *
792 ***********************/
793
801 template <class E>
802 inline auto flip(E&& e)
803 {
804 using size_type = typename std::decay_t<E>::size_type;
805 auto r = flip(e, 0);
806 for (size_type d = 1; d < e.dimension(); ++d)
807 {
808 r = flip(r, d);
809 }
810 return r;
811 }
812
824 template <class E>
825 inline auto flip(E&& e, std::size_t axis)
826 {
827 using shape_type = xindex_type_t<typename std::decay_t<E>::shape_type>;
828
829 shape_type shape;
830 resize_container(shape, e.shape().size());
831 std::copy(e.shape().cbegin(), e.shape().cend(), shape.begin());
832
833 get_strides_t<shape_type> strides;
834 decltype(auto) old_strides = detail::get_strides<XTENSOR_DEFAULT_LAYOUT>(e);
835 resize_container(strides, old_strides.size());
836 std::copy(old_strides.cbegin(), old_strides.cend(), strides.begin());
837
838 strides[axis] *= -1;
839 std::size_t offset = static_cast<std::size_t>(
840 static_cast<std::ptrdiff_t>(e.data_offset())
841 + old_strides[axis] * (static_cast<std::ptrdiff_t>(e.shape()[axis]) - 1)
842 );
843
844 return strided_view(std::forward<E>(e), std::move(shape), std::move(strides), offset);
845 }
846
847 /************************
848 * rot90 implementation *
849 ************************/
850
851 namespace detail
852 {
853 template <std::ptrdiff_t N>
854 struct rot90_impl;
855
856 template <>
857 struct rot90_impl<0>
858 {
859 template <class E>
860 inline auto operator()(E&& e, const std::array<std::size_t, 2>& /*axes*/)
861 {
862 return std::forward<E>(e);
863 }
864 };
865
866 template <>
867 struct rot90_impl<1>
868 {
869 template <class E>
870 inline auto operator()(E&& e, const std::array<std::size_t, 2>& axes)
871 {
872 using std::swap;
873
874 dynamic_shape<std::ptrdiff_t> axes_list(e.shape().size());
875 std::iota(axes_list.begin(), axes_list.end(), 0);
876 swap(axes_list[axes[0]], axes_list[axes[1]]);
877
878 return transpose(flip(std::forward<E>(e), axes[1]), std::move(axes_list));
879 }
880 };
881
882 template <>
883 struct rot90_impl<2>
884 {
885 template <class E>
886 inline auto operator()(E&& e, const std::array<std::size_t, 2>& axes)
887 {
888 return flip(flip(std::forward<E>(e), axes[0]), axes[1]);
889 }
890 };
891
892 template <>
893 struct rot90_impl<3>
894 {
895 template <class E>
896 inline auto operator()(E&& e, const std::array<std::size_t, 2>& axes)
897 {
898 using std::swap;
899
900 dynamic_shape<std::ptrdiff_t> axes_list(e.shape().size());
901 std::iota(axes_list.begin(), axes_list.end(), 0);
902 swap(axes_list[axes[0]], axes_list[axes[1]]);
903
904 return flip(transpose(std::forward<E>(e), std::move(axes_list)), axes[1]);
905 }
906 };
907 }
908
920 template <std::ptrdiff_t N, class E>
921 inline auto rot90(E&& e, const std::array<std::ptrdiff_t, 2>& axes)
922 {
923 auto ndim = static_cast<std::ptrdiff_t>(e.shape().size());
924
925 if (axes[0] == axes[1] || std::abs(axes[0] - axes[1]) == ndim)
926 {
927 XTENSOR_THROW(std::runtime_error, "Axes must be different");
928 }
929
930 auto norm_axes = forward_normalize<std::array<std::size_t, 2>>(e, axes);
931 constexpr std::ptrdiff_t n = (4 + (N % 4)) % 4;
932
933 return detail::rot90_impl<n>()(std::forward<E>(e), norm_axes);
934 }
935
936 /***********************
937 * roll implementation *
938 ***********************/
939
/**
 * Returns a new container holding the flattened elements of @p e shifted
 * by @p shift positions. Elements pushed off the end re-enter at the
 * start, and a negative shift rolls towards the front.
 *
 * The shift is reduced with a single modulo; the old repeated-addition
 * loop was O(|shift|). An empty expression returns an empty copy; the old
 * code looped forever and then divided by zero.
 */
template <class E>
inline auto roll(E&& e, std::ptrdiff_t shift)
{
    auto cpy = empty_like(e);
    const auto flat_size = static_cast<std::ptrdiff_t>(cpy.size());
    if (flat_size == 0)
    {
        return cpy;
    }

    shift %= flat_size;
    if (shift < 0)
    {
        shift += flat_size;
    }

    // The last `shift` elements go first, followed by the remaining ones.
    std::copy(e.begin(), e.end() - shift, std::copy(e.end() - shift, e.end(), cpy.begin()));
    return cpy;
}
973
974 namespace detail
975 {
979
/**
 * Recursive helper of roll(e, shift, axis): copies the sub-block of the
 * source starting at @p from into @p to, rotating dimension @p axis by
 * @p shift.
 *
 * @param to     output iterator; the advanced iterator is returned
 * @param from   source iterator at the start of the current sub-block
 * @param shift  rotation amount; the split iterator from + (dim - shift) * offset
 *               stays inside the block only if 0 <= shift <= shape[axis]
 * @param axis   dimension to rotate
 * @param shape  shape of the expression
 * @param M      dimension handled at this recursion level (0 at the top)
 *
 * NOTE(review): stepping by the product of the trailing extents assumes
 * that `from` and `to` traverse the data in row-major order; confirm that
 * this holds for the default traversal of the iterators callers pass.
 */
template <class To, class From, class S>
To roll(To to, From from, std::ptrdiff_t shift, std::size_t axis, const S& shape, std::size_t M)
{
    std::ptrdiff_t dim = std::ptrdiff_t(shape[M]);
    // Number of elements in one slice along dimension M (product of trailing extents).
    std::ptrdiff_t offset = std::accumulate(
        shape.begin() + M + 1,
        shape.end(),
        std::ptrdiff_t(1),
        std::multiplies<std::ptrdiff_t>()
    );
    if (shape.size() == M + 1)
    {
        // Innermost dimension: copy elements directly.
        if (axis == M)
        {
            // Rolled axis: copy the tail [dim - shift, dim) first, then the head.
            const auto split = from + (dim - shift) * offset;
            for (auto iter = split, end = from + dim * offset; iter != end; iter += offset, ++to)
            {
                *to = *iter;
            }
            for (auto iter = from, end = split; iter != end; iter += offset, ++to)
            {
                *to = *iter;
            }
        }
        else
        {
            for (auto iter = from, end = from + dim * offset; iter != end; iter += offset, ++to)
            {
                *to = *iter;
            }
        }
    }
    else
    {
        // Outer dimension: recurse into each slice, reordering the slices if this is the rolled axis.
        if (axis == M)
        {
            const auto split = from + (dim - shift) * offset;
            for (auto iter = split, end = from + dim * offset; iter != end; iter += offset)
            {
                to = roll(to, iter, shift, axis, shape, M + 1);
            }
            for (auto iter = from, end = split; iter != end; iter += offset)
            {
                to = roll(to, iter, shift, axis, shape, M + 1);
            }
        }
        else
        {
            for (auto iter = from, end = from + dim * offset; iter != end; iter += offset)
            {
                to = roll(to, iter, shift, axis, shape, M + 1);
            }
        }
    }
    return to;
}
1036 }
1037
1050 template <class E>
1051 inline auto roll(E&& e, std::ptrdiff_t shift, std::ptrdiff_t axis)
1052 {
1053 auto cpy = empty_like(e);
1054 const auto& shape = cpy.shape();
1055 std::size_t saxis = static_cast<std::size_t>(axis);
1056 if (axis < 0)
1057 {
1058 axis += std::ptrdiff_t(cpy.dimension());
1059 }
1060
1061 if (saxis >= cpy.dimension() || axis < 0)
1062 {
1063 XTENSOR_THROW(std::runtime_error, "axis is no within shape dimension.");
1064 }
1065
1066 const auto axis_dim = static_cast<std::ptrdiff_t>(shape[saxis]);
1067 while (shift < 0)
1068 {
1069 shift += axis_dim;
1070 }
1071
1072 detail::roll(cpy.begin(), e.begin(), shift, saxis, shape, 0);
1073 return cpy;
1074 }
1075
1076 /****************************
1077 * repeat implementation *
1078 ****************************/
1079
1080 namespace detail
1081 {
1082 template <class E, class R>
1083 inline auto make_xrepeat(E&& e, R&& r, typename std::decay_t<E>::size_type axis)
1084 {
1085 const auto casted_axis = static_cast<typename std::decay_t<E>::size_type>(axis);
1086 if (r.size() != e.shape(casted_axis))
1087 {
1088 XTENSOR_THROW(std::invalid_argument, "repeats must have the same size as the specified axis");
1089 }
1090 return xrepeat<const_xclosure_t<E>, R>(std::forward<E>(e), std::forward<R>(r), axis);
1091 }
1092 }
1093
/**
 * Repeats every element of @p e along @p axis the same number of times
 * (@p repeats). The count is broadcast to one entry per element of the
 * axis and forwarded to the vector overload.
 */
template <class E>
inline auto repeat(E&& e, std::size_t repeats, std::size_t axis)
{
    using size_type = typename std::decay_t<E>::size_type;
    const auto axis_length = e.shape(static_cast<size_type>(axis));
    std::vector<std::size_t> broadcasted_repeats(axis_length, repeats);
    return repeat(std::forward<E>(e), std::move(broadcasted_repeats), axis);
}
1112
/**
 * Repeats element i along @p axis repeats[i] times. @p repeats must hold
 * one count per element of the axis (std::invalid_argument otherwise).
 * The counts are passed on as a const reference.
 */
template <class E>
inline auto repeat(E&& e, const std::vector<std::size_t>& repeats, std::size_t axis)
{
    return detail::make_xrepeat(std::forward<E>(e), repeats, axis);
}
1129
/// Same as the const-reference overload, but moves @p repeats into the resulting expression.
template <class E>
inline auto repeat(E&& e, std::vector<std::size_t>&& repeats, std::size_t axis)
{
    return detail::make_xrepeat(std::forward<E>(e), std::move(repeats), axis);
}
1145}
1146
1147#endif
Dense multidimensional container adaptor with view semantics and fixed dimension.
Definition xtensor.hpp:329
bool all(E &&e)
All.
auto nonzero(const T &arr)
return vector of indices where T is not zero
auto flatten(E &&e)
Return a flatten view of the given expression.
auto atleast_1d(E &&e)
Expand to at least 1D.
auto roll(E &&e, std::ptrdiff_t shift)
Roll an expression.
auto squeeze(E &&e)
Returns a squeeze view of the given expression.
auto moveaxis(E &&e, std::ptrdiff_t src, std::ptrdiff_t dest)
Return a new expression with an axis moved to a new position.
auto ravel(E &&e)
Return a flatten view of the given expression.
auto transpose(E &&e) noexcept
Returns a transpose view by reversing the dimensions of xexpression e.
auto atleast_Nd(E &&e)
Expand dimensions of xexpression to at least N
auto repeat(E &&e, std::size_t repeats, std::size_t axis)
Repeat elements of an expression along a given axis.
auto trim_zeros(E &&e, const std::string &direction="fb")
Trim zeros at beginning, end or both of 1D sequence.
auto split(E &e, std::size_t n, std::size_t axis=0)
Split xexpression along axis into subexpressions.
auto rot90(E &&e, const std::array< std::ptrdiff_t, 2 > &axes={0, 1})
Rotate an array by 90 degrees in the plane specified by axes.
auto expand_dims(E &&e, std::size_t axis)
Expand the shape of an xexpression.
auto vsplit(E &e, std::size_t n)
Split an xexpression into subexpressions vertically (row-wise)
auto flip(E &&e)
Reverse the order of elements in an xexpression along every axis.
auto atleast_2d(E &&e)
Expand to at least 2D.
auto swapaxes(E &&e, std::ptrdiff_t axis1, std::ptrdiff_t axis2)
Return a new expression with two axes interchanged.
auto hsplit(E &e, std::size_t n)
Split an xexpression into subexpressions horizontally (column-wise)
auto atleast_3d(E &&e)
Expand to at least 3D.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
Definition xstrides.hpp:248
standard mathematical functions for xexpressions
auto range(A start_val, B stop_val)
Select a range from start_val to stop_val (excluded).
Definition xslice.hpp:818
std::vector< xstrided_slice< std::ptrdiff_t > > xstrided_slice_vector
vector of slices used to build a xstrided_view
auto newaxis() noexcept
Returns a slice representing a new axis of length one, to be used as an argument of view function.
Definition xslice.hpp:300
layout_type
Definition xlayout.hpp:24
auto empty_like(const xexpression< E > &e)
Create a xcontainer (xarray, xtensor or xtensor_fixed) with uninitialized values of the same shape,...
Definition xbuilder.hpp:121
auto strided_view(E &&e, S &&shape, X &&stride, std::size_t offset=0, layout_type layout=L) noexcept
Construct a strided view from an xexpression, shape, strides and offset.
auto flatnonzero(const T &arr)
Return indices that are non-zero in the flattened version of arr.