xtensor
Loading...
Searching...
No Matches
xbuilder.hpp
/***************************************************************************
 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht          *
 * Copyright (c) QuantStack                                                 *
 *                                                                          *
 * Distributed under the terms of the BSD 3-Clause License.                 *
 *                                                                          *
 * The full license is in the file LICENSE, distributed with this software. *
 ****************************************************************************/
9
14#ifndef XTENSOR_BUILDER_HPP
15#define XTENSOR_BUILDER_HPP
16
#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iterator>
#include <numeric>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include <xtl/xclosure.hpp>
#include <xtl/xsequence.hpp>
#include <xtl/xtype_traits.hpp>

#include "xbroadcast.hpp"
#include "xfunction.hpp"
#include "xgenerator.hpp"
#include "xoperation.hpp"
33
34namespace xt
35{
36
37 /********
38 * ones *
39 ********/
40
/**
 * Returns an xexpression containing ones of the specified shape.
 *
 * The scalar T(1) is handed to broadcast(), so the result is a broadcast
 * expression rather than a container.
 *
 * @tparam T value type of the elements
 * @param shape the shape of the returned expression
 */
template <class T, class S>
inline auto ones(S shape) noexcept
{
    // std::forward (not std::move) keeps the original behaviour if S is given explicitly as a reference type.
    return broadcast(T(1), std::forward<S>(shape));
}
50
/**
 * Overload of ones() taking the shape as a C array, which allows the
 * brace syntax ones<T>({2, 3}).
 *
 * @param shape the shape of the returned expression
 */
template <class T, class I, std::size_t L>
inline auto ones(const I (&shape)[L]) noexcept
{
    // T(1) must stay a prvalue: broadcasting a named local would capture it by reference.
    return broadcast(T(1), shape);
}
56
57 /*********
58 * zeros *
59 *********/
60
/**
 * Returns an xexpression containing zeros of the specified shape.
 *
 * The scalar T(0) is handed to broadcast(), so the result is a broadcast
 * expression rather than a container.
 *
 * @tparam T value type of the elements
 * @param shape the shape of the returned expression
 */
template <class T, class S>
inline auto zeros(S shape) noexcept
{
    // std::forward (not std::move) keeps the original behaviour if S is given explicitly as a reference type.
    return broadcast(T(0), std::forward<S>(shape));
}
70
/**
 * Overload of zeros() taking the shape as a C array, which allows the
 * brace syntax zeros<T>({2, 3}).
 *
 * @param shape the shape of the returned expression
 */
template <class T, class I, std::size_t L>
inline auto zeros(const I (&shape)[L]) noexcept
{
    // T(0) must stay a prvalue: broadcasting a named local would capture it by reference.
    return broadcast(T(0), shape);
}
76
88 template <class T, layout_type L = XTENSOR_DEFAULT_LAYOUT, class S>
89 inline xarray<T, L> empty(const S& shape)
90 {
91 return xarray<T, L>::from_shape(shape);
92 }
93
94 template <class T, layout_type L = XTENSOR_DEFAULT_LAYOUT, class ST, std::size_t N>
95 inline xtensor<T, N, L> empty(const std::array<ST, N>& shape)
96 {
97 using shape_type = typename xtensor<T, N>::shape_type;
98 return xtensor<T, N, L>(xtl::forward_sequence<shape_type, decltype(shape)>(shape));
99 }
100
101 template <class T, layout_type L = XTENSOR_DEFAULT_LAYOUT, class I, std::size_t N>
102 inline xtensor<T, N, L> empty(const I (&shape)[N])
103 {
104 using shape_type = typename xtensor<T, N>::shape_type;
105 return xtensor<T, N, L>(xtl::forward_sequence<shape_type, decltype(shape)>(shape));
106 }
107
108 template <class T, layout_type L = XTENSOR_DEFAULT_LAYOUT, std::size_t... N>
109 inline xtensor_fixed<T, fixed_shape<N...>, L> empty(const fixed_shape<N...>& /*shape*/)
110 {
111 return xtensor_fixed<T, fixed_shape<N...>, L>();
112 }
113
120 template <class E>
121 inline auto empty_like(const xexpression<E>& e)
122 {
124 auto res = xtype::from_shape(e.derived_cast().shape());
125 return res;
126 }
127
135 template <class E>
136 inline auto full_like(const xexpression<E>& e, typename E::value_type fill_value)
137 {
139 auto res = xtype::from_shape(e.derived_cast().shape());
140 res.fill(fill_value);
141 return res;
142 }
143
153 template <class E>
154 inline auto zeros_like(const xexpression<E>& e)
155 {
156 return full_like(e, typename E::value_type(0));
157 }
158
168 template <class E>
169 inline auto ones_like(const xexpression<E>& e)
170 {
171 return full_like(e, typename E::value_type(1));
172 }
173
174 namespace detail
175 {
// Chooses the type an index is converted to before being multiplied by a step.
// By default this is T itself; for a std::chrono::duration step it is the
// duration's representation type, since a duration can be multiplied by its rep.
template <class T, class Step>
struct get_mult_type_impl
{
    using type = T;
};

template <class T, class Rep, class Period>
struct get_mult_type_impl<T, std::chrono::duration<Rep, Period>>
{
    using type = Rep;
};

template <class T, class Step>
using get_mult_type = typename get_mult_type_impl<T, Step>::type;
190
191 // These methods should be private methods of arange_generator, however thi leads
192 // to ICE on VS2015
193 template <class R, class E, class U, class X, XTL_REQUIRES(xtl::is_integral<X>)>
194 inline void arange_assign_to(xexpression<E>& e, U start, U, X step, bool) noexcept
195 {
196 auto& de = e.derived_cast();
197 U value = start;
198
199 for (auto&& el : de.storage())
200 {
201 el = static_cast<R>(value);
202 value += step;
203 }
204 }
205
206 template <class R, class E, class U, class X, XTL_REQUIRES(xtl::negation<xtl::is_integral<X>>)>
207 inline void arange_assign_to(xexpression<E>& e, U start, U stop, X step, bool endpoint) noexcept
208 {
209 auto& buf = e.derived_cast().storage();
210 using size_type = decltype(buf.size());
211 using mult_type = get_mult_type<U, X>;
212 size_type num = buf.size();
213 for (size_type i = 0; i < num; ++i)
214 {
215 buf[i] = static_cast<R>(start + step * mult_type(i));
216 }
217 if (endpoint && num > 1)
218 {
219 buf[num - 1] = static_cast<R>(stop);
220 }
221 }
222
223 template <class T, class R = T, class S = T>
224 class arange_generator
225 {
226 public:
227
228 using value_type = R;
229 using step_type = S;
230
231 arange_generator(T start, T stop, S step, size_t num_steps, bool endpoint = false)
232 : m_start(start)
233 , m_stop(stop)
234 , m_step(step)
235 , m_num_steps(num_steps)
236 , m_endpoint(endpoint)
237 {
238 }
239
240 template <class... Args>
241 inline R operator()(Args... args) const
242 {
243 return access_impl(args...);
244 }
245
246 template <class It>
247 inline R element(It first, It) const
248 {
249 return access_impl(*first);
250 }
251
252 template <class E>
253 inline void assign_to(xexpression<E>& e) const noexcept
254 {
255 arange_assign_to<R>(e, m_start, m_stop, m_step, m_endpoint);
256 }
257
258 private:
259
260 T m_start;
261 T m_stop;
262 step_type m_step;
263 size_t m_num_steps;
264 bool m_endpoint; // true for setting the last element to m_stop
265
266 template <class T1, class... Args>
267 inline R access_impl(T1 t, Args...) const
268 {
269 if (m_endpoint && m_num_steps > 1 && size_t(t) == m_num_steps - 1)
270 {
271 return static_cast<R>(m_stop);
272 }
273 // Avoids warning when T = char (because char + char => int!)
274 using mult_type = get_mult_type<T, S>;
275 return static_cast<R>(m_start + m_step * mult_type(t));
276 }
277
278 inline R access_impl() const
279 {
280 return static_cast<R>(m_start);
281 }
282 };
283
284 template <class T, class S>
285 using both_integer = xtl::conjunction<xtl::is_integral<T>, xtl::is_integral<S>>;
286
287 template <class T, class S>
288 using integer_with_signed_integer = xtl::conjunction<both_integer<T, S>, xtl::is_signed<S>>;
289
290 template <class T, class S>
291 using integer_with_unsigned_integer = xtl::conjunction<both_integer<T, S>, std::is_unsigned<S>>;
292
293 template <class T, class S = T, XTL_REQUIRES(xtl::negation<both_integer<T, S>>)>
294 inline auto arange_impl(T start, T stop, S step = 1) noexcept
295 {
296 std::size_t shape = static_cast<std::size_t>(std::ceil((stop - start) / step));
297 return detail::make_xgenerator(detail::arange_generator<T, T, S>(start, stop, step, shape), {shape});
298 }
299
300 template <class T, class S = T, XTL_REQUIRES(integer_with_signed_integer<T, S>)>
301 inline auto arange_impl(T start, T stop, S step = 1) noexcept
302 {
303 bool empty_cond = (stop - start) / step <= 0;
304 std::size_t shape = 0;
305 if (!empty_cond)
306 {
307 shape = stop > start ? static_cast<std::size_t>((stop - start + step - S(1)) / step)
308 : static_cast<std::size_t>((start - stop - step - S(1)) / -step);
309 }
310 return detail::make_xgenerator(detail::arange_generator<T, T, S>(start, stop, step, shape), {shape});
311 }
312
313 template <class T, class S = T, XTL_REQUIRES(integer_with_unsigned_integer<T, S>)>
314 inline auto arange_impl(T start, T stop, S step = 1) noexcept
315 {
316 bool empty_cond = stop <= start;
317 std::size_t shape = 0;
318 if (!empty_cond)
319 {
320 shape = static_cast<std::size_t>((stop - start + step - S(1)) / step);
321 }
322 return detail::make_xgenerator(detail::arange_generator<T, T, S>(start, stop, step, shape), {shape});
323 }
324
/**
 * Adapts a functor F, callable as F(begin, end) over an index range, to the
 * generator interface: operator()(indices...) and element(first, last).
 *
 * F must expose a value_type.
 */
template <class F>
class fn_impl
{
public:

    using value_type = typename F::value_type;
    using size_type = std::size_t;

    // `f` is an rvalue reference (F is a class template parameter, not deduced),
    // so it is moved from rather than copied.
    fn_impl(F&& f)
        : m_ft(std::move(f))
    {
    }

    // Zero-dimensional access is forwarded as the single index {0}.
    inline value_type operator()() const
    {
        size_type idx[1] = {0};
        return access_impl(std::begin(idx), std::end(idx));
    }

    template <class... Args>
    inline value_type operator()(Args... args) const
    {
        size_type idx[sizeof...(Args)] = {static_cast<size_type>(args)...};
        return access_impl(std::begin(idx), std::end(idx));
    }

    template <class It>
    inline value_type element(It first, It last) const
    {
        return access_impl(first, last);
    }

private:

    F m_ft;

    template <class It>
    inline value_type access_impl(const It& begin, const It& end) const
    {
        return m_ft(begin, end);
    }
};
367
/**
 * Generator functor for eye(): yields T(1) when the last index equals the
 * second-to-last index plus the diagonal offset k, and T(0) otherwise.
 */
template <class T>
class eye_fn
{
public:

    using value_type = T;

    eye_fn(int k)
        : m_k(k)
    {
    }

    // Only the last two indices of [begin, end) are read; at least two are expected.
    template <class It>
    inline T operator()(const It& /*begin*/, const It& end) const
    {
        using lvalue_type = typename std::iterator_traits<It>::value_type;
        const It column = end - 1;
        const It row = end - 2;
        // With unsigned index types a negative k wraps around, and the modular sum
        // still equals the column index exactly on the requested diagonal.
        const bool on_diagonal = *column == *row + static_cast<lvalue_type>(m_k);
        if (on_diagonal)
        {
            return T(1);
        }
        return T(0);
    }

private:

    std::ptrdiff_t m_k;
};
391 }
392
402 template <class T = bool>
403 inline auto eye(const std::vector<std::size_t>& shape, int k = 0)
404 {
405 return detail::make_xgenerator(detail::fn_impl<detail::eye_fn<T>>(detail::eye_fn<T>(k)), shape);
406 }
407
/**
 * Generates an n x n expression with ones on the k-th diagonal and zeros elsewhere.
 *
 * @tparam T value type (bool by default)
 * @param n number of rows and columns
 * @param k diagonal offset: 0 is the main diagonal, positive is above, negative below
 */
template <class T = bool>
inline auto eye(std::size_t n, int k = 0)
{
    const std::vector<std::size_t> square_shape = {n, n};
    return eye<T>(square_shape, k);
}
422
431 template <class T, class S = T>
432 inline auto arange(T start, T stop, S step = 1) noexcept
433 {
434 return detail::arange_impl(start, stop, step);
435 }
436
/**
 * Generates the values 0, 1, ..., stop - 1 in the half-open interval [0, stop).
 *
 * @param stop end of the interval (excluded)
 */
template <class T>
inline auto arange(T stop) noexcept
{
    const T first(0);
    const T unit_step(1);
    return arange<T>(first, stop, unit_step);
}
449
459 template <class T>
460 inline auto linspace(T start, T stop, std::size_t num_samples = 50, bool endpoint = true) noexcept
461 {
462 using fp_type = std::common_type_t<T, double>;
463 fp_type step = fp_type(stop - start) / std::fmax(fp_type(1), fp_type(num_samples - (endpoint ? 1 : 0)));
464 return detail::make_xgenerator(
465 detail::arange_generator<fp_type, T>(fp_type(start), fp_type(stop), step, num_samples, endpoint),
467 );
468 }
469
/**
 * Generates num_samples numbers evenly spaced on a log scale: base raised to
 * each value of linspace(start, stop, num_samples, endpoint).
 *
 * @param start exponent of the first value
 * @param stop exponent of the last value
 * @param num_samples number of values to generate
 * @param base base of the logarithm (10 by default)
 * @param endpoint whether base^stop is the last sample
 */
template <class T>
inline auto logspace(T start, T stop, std::size_t num_samples, T base = 10, bool endpoint = true) noexcept
{
    auto exponents = linspace(start, stop, num_samples, endpoint);
    return pow(std::move(base), std::move(exponents));
}
485
486 namespace detail
487 {
488 template <class... CT>
489 class concatenate_access
490 {
491 public:
492
493 using tuple_type = std::tuple<CT...>;
494 using size_type = std::size_t;
495 using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
496
497 template <class It>
498 inline value_type access(const tuple_type& t, size_type axis, It first, It last) const
499 {
500 // trim off extra indices if provided to match behavior of containers
501 auto dim_offset = std::distance(first, last) - std::get<0>(t).dimension();
502 size_t axis_dim = *(first + axis + dim_offset);
503 auto match = [&](auto& arr)
504 {
505 if (axis_dim >= arr.shape()[axis])
506 {
507 axis_dim -= arr.shape()[axis];
508 return false;
509 }
510 return true;
511 };
512
513 auto get = [&](auto& arr)
514 {
515 size_t offset = 0;
516 const size_t end = arr.dimension();
517 for (size_t i = 0; i < end; i++)
518 {
519 const auto& shape = arr.shape();
520 const size_t stride = std::accumulate(
521 shape.begin() + i + 1,
522 shape.end(),
523 1,
524 std::multiplies<size_t>()
525 );
526 if (i == axis)
527 {
528 offset += axis_dim * stride;
529 }
530 else
531 {
532 const auto len = (*(first + i + dim_offset));
533 offset += len * stride;
534 }
535 }
536 const auto element = arr.begin() + offset;
537 return *element;
538 };
539
540 size_type i = 0;
541 for (; i < sizeof...(CT); ++i)
542 {
543 if (apply<bool>(i, match, t))
544 {
545 break;
546 }
547 }
548 return apply<value_type>(i, get, t);
549 }
550 };
551
552 template <class... CT>
553 class stack_access
554 {
555 public:
556
557 using tuple_type = std::tuple<CT...>;
558 using size_type = std::size_t;
559 using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
560
561 template <class It>
562 inline value_type access(const tuple_type& t, size_type axis, It first, It) const
563 {
564 auto get_item = [&](auto& arr)
565 {
566 size_t offset = 0;
567 const size_t end = arr.dimension();
568 size_t after_axis = 0;
569 for (size_t i = 0; i < end; i++)
570 {
571 if (i == axis)
572 {
573 after_axis = 1;
574 }
575 const auto& shape = arr.shape();
576 const size_t stride = std::accumulate(
577 shape.begin() + i + 1,
578 shape.end(),
579 1,
580 std::multiplies<size_t>()
581 );
582 const auto len = (*(first + i + after_axis));
583 offset += len * stride;
584 }
585 const auto element = arr.begin() + offset;
586 return *element;
587 };
588 size_type i = *(first + axis);
589 return apply<value_type>(i, get_item, t);
590 }
591 };
592
593 template <class... CT>
594 class vstack_access
595 {
596 public:
597
598 using tuple_type = std::tuple<CT...>;
599 using size_type = std::size_t;
600 using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
601
602 template <class It>
603 inline value_type access(const tuple_type& t, size_type axis, It first, It last) const
604 {
605 if (std::get<0>(t).dimension() == 1)
606 {
607 return stack.access(t, axis, first, last);
608 }
609 else
610 {
611 return concatonate.access(t, axis, first, last);
612 }
613 }
614
615 private:
616
617 concatenate_access<CT...> concatonate;
618 stack_access<CT...> stack;
619 };
620
621 template <template <class...> class F, class... CT>
622 class concatenate_invoker
623 {
624 public:
625
626 using tuple_type = std::tuple<CT...>;
627 using size_type = std::size_t;
628 using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
629
630 inline concatenate_invoker(tuple_type&& t, size_type axis)
631 : m_t(std::move(t))
632 , m_axis(axis)
633 {
634 }
635
636 template <class... Args>
637 inline value_type operator()(Args... args) const
638 {
639 // TODO: avoid memory allocation
640 xindex index({static_cast<size_type>(args)...});
641 return access_method.access(m_t, m_axis, index.begin(), index.end());
642 }
643
644 template <class It>
645 inline value_type element(It first, It last) const
646 {
647 return access_method.access(m_t, m_axis, first, last);
648 }
649
650 private:
651
652 F<CT...> access_method;
653 tuple_type m_t;
654 size_type m_axis;
655 };
656
657 template <class... CT>
658 using concatenate_impl = concatenate_invoker<concatenate_access, CT...>;
659
660 template <class... CT>
661 using stack_impl = concatenate_invoker<stack_access, CT...>;
662
663 template <class... CT>
664 using vstack_impl = concatenate_invoker<vstack_access, CT...>;
665
/**
 * Generator functor used by meshgrid(): returns source(i), where i is the
 * index along the stored axis. All other indices are ignored.
 */
template <class CT>
class repeat_impl
{
public:

    using xexpression_type = std::decay_t<CT>;
    using size_type = typename xexpression_type::size_type;
    using value_type = typename xexpression_type::value_type;

    template <class CTA>
    repeat_impl(CTA&& source, size_type axis)
        : m_source(std::forward<CTA>(source))
        , m_axis(axis)
    {
    }

    template <class... Args>
    value_type operator()(Args... args) const
    {
        const std::array<size_type, sizeof...(Args)> coords = {static_cast<size_type>(args)...};
        const size_type selected = coords[m_axis];
        return m_source(selected);
    }

    template <class It>
    inline value_type element(It first, It) const
    {
        const It selected = first + static_cast<std::ptrdiff_t>(m_axis);
        return m_source(*selected);
    }

private:

    CT m_source;
    size_type m_axis;
};
700 }
701
706 template <class... Types>
707 inline auto xtuple(Types&&... args)
708 {
709 return std::tuple<xtl::const_closure_type_t<Types>...>(std::forward<Types>(args)...);
710 }
711
712 namespace detail
713 {
714 template <bool... values>
716
717 template <class X, class Y, std::size_t axis, class AxesSequence>
718 struct concat_fixed_shape_impl;
719
720 template <class X, class Y, std::size_t axis, std::size_t... Is>
721 struct concat_fixed_shape_impl<X, Y, axis, std::index_sequence<Is...>>
722 {
723 static_assert(X::size() == Y::size(), "Concatenation requires equisized shapes");
724 static_assert(axis < X::size(), "Concatenation requires a valid axis");
725 static_assert(
726 all_true<(axis == Is || X::template get<Is>() == Y::template get<Is>())...>::value,
727 "Concatenation requires compatible shapes and axis"
728 );
729
730 using type = fixed_shape<
731 (axis == Is ? X::template get<Is>() + Y::template get<Is>() : X::template get<Is>())...>;
732 };
733
734 template <std::size_t axis, class X, class Y, class... Rest>
735 struct concat_fixed_shape;
736
737 template <std::size_t axis, class X, class Y>
738 struct concat_fixed_shape<axis, X, Y>
739 {
740 using type = typename concat_fixed_shape_impl<X, Y, axis, std::make_index_sequence<X::size()>>::type;
741 };
742
743 template <std::size_t axis, class X, class Y, class... Rest>
744 struct concat_fixed_shape
745 {
746 using type = typename concat_fixed_shape<axis, X, typename concat_fixed_shape<axis, Y, Rest...>::type>::type;
747 };
748
749 template <std::size_t axis, class... Args>
750 using concat_fixed_shape_t = typename concat_fixed_shape<axis, Args...>::type;
751
752 template <class... CT>
753 using all_fixed_shapes = detail::all_fixed<typename std::decay_t<CT>::shape_type...>;
754
755 struct concat_shape_builder_t
756 {
757 template <class Shape, bool = detail::is_fixed<Shape>::value>
758 struct concat_shape;
759
760 template <class Shape>
761 struct concat_shape<Shape, true>
762 {
763 // Convert `fixed_shape` to `static_shape` to allow runtime dimension calculation.
764 using type = static_shape<typename Shape::value_type, Shape::size()>;
765 };
766
767 template <class Shape>
768 struct concat_shape<Shape, false>
769 {
770 using type = Shape;
771 };
772
773 template <class... Args>
774 static auto build(const std::tuple<Args...>& t, std::size_t axis)
775 {
776 using shape_type = promote_shape_t<
777 typename concat_shape<typename std::decay_t<Args>::shape_type>::type...>;
778 using source_shape_type = decltype(std::get<0>(t).shape());
779 shape_type new_shape = xtl::forward_sequence<shape_type, source_shape_type>(
780 std::get<0>(t).shape()
781 );
782
783 auto check_shape = [&axis, &new_shape](auto& arr)
784 {
785 std::size_t s = new_shape.size();
786 bool res = s == arr.dimension();
787 for (std::size_t i = 0; i < s; ++i)
788 {
789 res = res && (i == axis || new_shape[i] == arr.shape(i));
790 }
791 if (!res)
792 {
793 throw_concatenate_error(new_shape, arr.shape());
794 }
795 };
796 for_each(check_shape, t);
797
798 auto shape_at_axis = [&axis](std::size_t prev, auto& arr) -> std::size_t
799 {
800 return prev + arr.shape()[axis];
801 };
802 new_shape[axis] += accumulate(shape_at_axis, std::size_t(0), t) - new_shape[axis];
803
804 return new_shape;
805 }
806 };
807
808 } // namespace detail
809
810 /***************
811 * concatenate *
812 ***************/
813
829 template <class... CT>
830 inline auto concatenate(std::tuple<CT...>&& t, std::size_t axis = 0)
831 {
832 const auto shape = detail::concat_shape_builder_t::build(t, axis);
833 return detail::make_xgenerator(detail::concatenate_impl<CT...>(std::move(t), axis), shape);
834 }
835
836 template <std::size_t axis, class... CT, typename = std::enable_if_t<detail::all_fixed_shapes<CT...>::value>>
837 inline auto concatenate(std::tuple<CT...>&& t)
838 {
839 using shape_type = detail::concat_fixed_shape_t<axis, typename std::decay_t<CT>::shape_type...>;
840 return detail::make_xgenerator(detail::concatenate_impl<CT...>(std::move(t), axis), shape_type{});
841 }
842
843 namespace detail
844 {
/**
 * Returns a copy of shape @p arr with @p value inserted at position @p axis.
 * This overload is for std::array shapes and returns a std::array one element longer.
 * axis must not exceed N.
 */
template <class T, std::size_t N>
inline std::array<T, N + 1> add_axis(std::array<T, N> arr, std::size_t axis, std::size_t value)
{
    std::array<T, N + 1> temp{};
    std::copy(arr.begin(), arr.begin() + axis, temp.begin());
    temp[axis] = static_cast<T>(value);
    std::copy(arr.begin() + axis, arr.end(), temp.begin() + axis + 1);
    return temp;
}

/**
 * Generic overload for resizable shape types (requires insert()).
 * `arr` is already a by-value copy, so it is modified and returned directly
 * instead of being copied a second time.
 */
template <class T>
inline T add_axis(T arr, std::size_t axis, std::size_t value)
{
    arr.insert(arr.begin() + static_cast<std::ptrdiff_t>(axis), value);
    return arr;
}
862 }
863
882 template <class... CT>
883 inline auto stack(std::tuple<CT...>&& t, std::size_t axis = 0)
884 {
886 using source_shape_type = decltype(std::get<0>(t).shape());
887 auto new_shape = detail::add_axis(
888 xtl::forward_sequence<shape_type, source_shape_type>(std::get<0>(t).shape()),
889 axis,
890 sizeof...(CT)
891 );
892 return detail::make_xgenerator(detail::stack_impl<CT...>(std::move(t), axis), new_shape);
893 }
894
/**
 * Stacks xexpressions horizontally (column-wise): concatenation along axis 1,
 * or along axis 0 when the inputs are one-dimensional.
 *
 * @param t tuple of xexpressions (see xtuple)
 */
template <class... CT>
inline auto hstack(std::tuple<CT...>&& t)
{
    const bool is_vector = std::get<0>(t).dimension() <= std::size_t(1);
    const std::size_t axis = is_vector ? 0 : 1;
    return concatenate(std::move(t), axis);
}
910
911 namespace detail
912 {
913 template <class S, class... CT>
914 inline auto vstack_shape(std::tuple<CT...>& t, const S& shape)
915 {
916 using size_type = typename S::value_type;
917 auto res = shape.size() == size_type(1)
918 ? S({sizeof...(CT), shape[0]})
919 : concat_shape_builder_t::build(std::move(t), size_type(0));
920 return res;
921 }
922
923 template <class T, class... CT>
924 inline auto vstack_shape(const std::tuple<CT...>&, std::array<T, 1> shape)
925 {
926 std::array<T, 2> res = {sizeof...(CT), shape[0]};
927 return res;
928 }
929 }
930
939 template <class... CT>
940 inline auto vstack(std::tuple<CT...>&& t)
941 {
943 using source_shape_type = decltype(std::get<0>(t).shape());
944 auto new_shape = detail::vstack_shape(
945 t,
946 xtl::forward_sequence<shape_type, source_shape_type>(std::get<0>(t).shape())
947 );
948 return detail::make_xgenerator(detail::vstack_impl<CT...>(std::move(t), size_t(0)), new_shape);
949 }
950
951 namespace detail
952 {
953
954 template <std::size_t... I, class... E>
955 inline auto meshgrid_impl(std::index_sequence<I...>, E&&... e) noexcept
956 {
957#if defined _MSC_VER
958 const std::array<std::size_t, sizeof...(E)> shape = {e.shape()[0]...};
959 return std::make_tuple(
960 detail::make_xgenerator(detail::repeat_impl<xclosure_t<E>>(std::forward<E>(e), I), shape)...
961 );
962#else
963 return std::make_tuple(detail::make_xgenerator(
964 detail::repeat_impl<xclosure_t<E>>(std::forward<E>(e), I),
965 {e.shape()[0]...}
966 )...);
967#endif
968 }
969 }
970
979 template <class... E>
980 inline auto meshgrid(E&&... e) noexcept
981 {
982 return detail::meshgrid_impl(std::make_index_sequence<sizeof...(E)>(), std::forward<E>(e)...);
983 }
984
985 namespace detail
986 {
987 template <class CT>
988 class diagonal_fn
989 {
990 public:
991
992 using xexpression_type = std::decay_t<CT>;
993 using value_type = typename xexpression_type::value_type;
994
995 template <class CTA>
996 diagonal_fn(CTA&& source, int offset, std::size_t axis_1, std::size_t axis_2)
997 : m_source(std::forward<CTA>(source))
998 , m_offset(offset)
999 , m_axis_1(axis_1)
1000 , m_axis_2(axis_2)
1001 {
1002 }
1003
1004 template <class It>
1005 inline value_type operator()(It begin, It) const
1006 {
1007 xindex idx(m_source.shape().size());
1008
1009 for (std::size_t i = 0; i < idx.size(); i++)
1010 {
1011 if (i != m_axis_1 && i != m_axis_2)
1012 {
1013 idx[i] = static_cast<std::size_t>(*begin++);
1014 }
1015 }
1016 using it_vtype = typename std::iterator_traits<It>::value_type;
1017 it_vtype uoffset = static_cast<it_vtype>(m_offset);
1018 if (m_offset >= 0)
1019 {
1020 idx[m_axis_1] = static_cast<std::size_t>(*(begin));
1021 idx[m_axis_2] = static_cast<std::size_t>(*(begin) + uoffset);
1022 }
1023 else
1024 {
1025 idx[m_axis_1] = static_cast<std::size_t>(*(begin) -uoffset);
1026 idx[m_axis_2] = static_cast<std::size_t>(*(begin));
1027 }
1028 return m_source[idx];
1029 }
1030
1031 private:
1032
1033 CT m_source;
1034 const int m_offset;
1035 const std::size_t m_axis_1;
1036 const std::size_t m_axis_2;
1037 };
1038
/**
 * Generator functor for diag(). Element (row, col) is the source value on the
 * k-th diagonal (col == row + k), and zero elsewhere. The source is indexed by
 * row when k > 0 and by col otherwise.
 */
template <class CT>
class diag_fn
{
public:

    using xexpression_type = std::decay_t<CT>;
    using value_type = typename xexpression_type::value_type;

    template <class CTA>
    diag_fn(CTA&& source, int k)
        : m_source(std::forward<CTA>(source))
        , m_k(k)
    {
    }

    template <class It>
    inline value_type operator()(It begin, It) const
    {
        using it_vtype = typename std::iterator_traits<It>::value_type;
        // For unsigned index types a negative k wraps; the modular sum is still exact.
        const it_vtype shift = static_cast<it_vtype>(m_k);
        const auto row = *begin;
        const auto col = *(begin + 1);
        if (row + shift != col)
        {
            return value_type(0);
        }
        return m_k > 0 ? m_source(row) : m_source(row + shift);
    }

private:

    CT m_source;
    const int m_k;
};
1074
/**
 * Generator functor for tril() / triu(). It keeps the source element at
 * (row, col, ...) when comp(row + k, col) holds, and yields zero otherwise.
 * Only the first two indices take part in the comparison.
 */
template <class CT, class Comp>
class trilu_fn
{
public:

    using xexpression_type = std::decay_t<CT>;
    using value_type = typename xexpression_type::value_type;
    using signed_idx_type = long int;

    template <class CTA>
    trilu_fn(CTA&& source, int k, Comp comp)
        : m_source(std::forward<CTA>(source))
        , m_k(k)
        , m_comp(comp)
    {
    }

    template <class It>
    inline value_type operator()(It begin, It end) const
    {
        // Signed arithmetic: with unsigned indices, row + k for negative k would wrap.
        const auto row = static_cast<signed_idx_type>(*begin);
        const auto col = static_cast<signed_idx_type>(*(begin + 1));
        if (!m_comp(row + m_k, col))
        {
            return value_type(0);
        }
        return m_source.element(begin, end);
    }

private:

    CT m_source;
    const signed_idx_type m_k;
    const Comp m_comp;
};
1107 }
1108
1109 namespace detail
1110 {
// Meta-function returning the shape type of diagonal(): a std::array shape loses
// one dimension; any other shape type is kept as-is.
template <class Shape, class... Extra>
struct diagonal_shape_type
{
    using type = Shape;
};

template <class Index, std::size_t Rank>
struct diagonal_shape_type<std::array<Index, Rank>>
{
    using type = std::array<Index, Rank - 1>;
};
1123 }
1124
1149 template <class E>
1150 inline auto diagonal(E&& arr, int offset = 0, std::size_t axis_1 = 0, std::size_t axis_2 = 1)
1151 {
1152 using CT = xclosure_t<E>;
1153 using shape_type = typename detail::diagonal_shape_type<typename std::decay_t<E>::shape_type>::type;
1154
1155 auto shape = arr.shape();
1156 auto dimension = arr.dimension();
1157
1158 // The following shape calculation code is an almost verbatim adaptation of NumPy:
1159 // https://github.com/numpy/numpy/blob/2aabeafb97bea4e1bfa29d946fbf31e1104e7ae0/numpy/core/src/multiarray/item_selection.c#L1799
1160 auto ret_shape = xtl::make_sequence<shape_type>(dimension - 1, 0);
1161 int dim_1 = static_cast<int>(shape[axis_1]);
1162 int dim_2 = static_cast<int>(shape[axis_2]);
1163
1164 offset >= 0 ? dim_2 -= offset : dim_1 += offset;
1165
1166 auto diag_size = std::size_t(dim_2 < dim_1 ? dim_2 : dim_1);
1167
1168 std::size_t i = 0;
1169 for (std::size_t idim = 0; idim < dimension; ++idim)
1170 {
1171 if (idim != axis_1 && idim != axis_2)
1172 {
1173 ret_shape[i++] = shape[idim];
1174 }
1175 }
1176
1177 ret_shape.back() = diag_size;
1178
1179 return detail::make_xgenerator(
1180 detail::fn_impl<detail::diagonal_fn<CT>>(
1181 detail::diagonal_fn<CT>(std::forward<E>(arr), offset, axis_1, axis_2)
1182 ),
1183 ret_shape
1184 );
1185 }
1186
1201 template <class E>
1202 inline auto diag(E&& arr, int k = 0)
1203 {
1204 using CT = xclosure_t<E>;
1205 std::size_t sk = std::size_t(std::abs(k));
1206 std::size_t s = arr.shape()[0] + sk;
1207 return detail::make_xgenerator(
1208 detail::fn_impl<detail::diag_fn<CT>>(detail::diag_fn<CT>(std::forward<E>(arr), k)),
1209 {s, s}
1210 );
1211 }
1212
1222 template <class E>
1223 inline auto tril(E&& arr, int k = 0)
1224 {
1225 using CT = xclosure_t<E>;
1226 auto shape = arr.shape();
1227 return detail::make_xgenerator(
1228 detail::fn_impl<detail::trilu_fn<CT, std::greater_equal<long int>>>(
1229 detail::trilu_fn<CT, std::greater_equal<long int>>(
1230 std::forward<E>(arr),
1231 k,
1232 std::greater_equal<long int>()
1233 )
1234 ),
1235 shape
1236 );
1237 }
1238
1248 template <class E>
1249 inline auto triu(E&& arr, int k = 0)
1250 {
1251 using CT = xclosure_t<E>;
1252 auto shape = arr.shape();
1253 return detail::make_xgenerator(
1254 detail::fn_impl<detail::trilu_fn<CT, std::less_equal<long int>>>(
1255 detail::trilu_fn<CT, std::less_equal<long int>>(std::forward<E>(arr), k, std::less_equal<long int>())
1256 ),
1257 shape
1258 );
1259 }
1260}
1261#endif
auto pow(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::pow_fun, E1, E2 >
Power function.
Definition xmath.hpp:1015
standard mathematical functions for xexpressions
auto broadcast(E &&e, const S &s)
Returns an xexpression broadcasting the given expression to a specified shape.
auto stack(std::tuple< CT... > &&t, std::size_t axis=0)
Stack xexpressions along axis.
Definition xbuilder.hpp:883
auto arange(T start, T stop, S step=1) noexcept
Generates numbers evenly spaced within the given half-open interval [start, stop).
Definition xbuilder.hpp:432
auto concatenate(std::tuple< CT... > &&t, std::size_t axis=0)
Concatenates xexpressions along axis.
Definition xbuilder.hpp:830
auto eye(const std::vector< std::size_t > &shape, int k=0)
Generates an array with ones on the diagonal.
Definition xbuilder.hpp:403
auto ones_like(const xexpression< E > &e)
Create an xcontainer (xarray, xtensor or xtensor_fixed), filled with ones and of the same shape,...
Definition xbuilder.hpp:169
auto ones(S shape) noexcept
Returns an xexpression containing ones of the specified shape.
Definition xbuilder.hpp:46
auto meshgrid(E &&... e) noexcept
Return coordinate tensors from coordinate vectors.
Definition xbuilder.hpp:980
layout_type
Definition xlayout.hpp:24
auto triu(E &&arr, int k=0)
Extract upper triangular matrix from xexpression.
auto full_like(const xexpression< E > &e, typename E::value_type fill_value)
Create an xcontainer (xarray, xtensor or xtensor_fixed), filled with fill_value and of the same shape,...
Definition xbuilder.hpp:136
auto zeros_like(const xexpression< E > &e)
Create an xcontainer (xarray, xtensor or xtensor_fixed), filled with zeros and of the same shape,...
Definition xbuilder.hpp:154
auto zeros(S shape) noexcept
Returns an xexpression containing zeros of the specified shape.
Definition xbuilder.hpp:66
auto accumulate(F &&f, E &&e, EVS evaluation_strategy=EVS())
Accumulate and flatten array. NOTE: This function is not lazy!
auto linspace(T start, T stop, std::size_t num_samples=50, bool endpoint=true) noexcept
Generates num_samples evenly spaced numbers over given interval.
Definition xbuilder.hpp:460
auto vstack(std::tuple< CT... > &&t)
Stack xexpressions in sequence vertically (row wise).
Definition xbuilder.hpp:940
auto diagonal(E &&arr, int offset=0, std::size_t axis_1=0, std::size_t axis_2=1)
Returns the elements on the diagonal of arr. If arr has more than two dimensions, then the axes specified by axis_1 and axis_2 select the 2-D sub-arrays whose diagonals are returned.
auto hstack(std::tuple< CT... > &&t)
Stack xexpressions in sequence horizontally (column wise).
Definition xbuilder.hpp:904
auto tril(E &&arr, int k=0)
Extract lower triangular matrix from xexpression.
auto empty_like(const xexpression< E > &e)
Create an xcontainer (xarray, xtensor or xtensor_fixed) with uninitialized values of the same shape,...
Definition xbuilder.hpp:121
auto diag(E &&arr, int k=0)
xexpression with the values of arr on the diagonal, zeros otherwise.
auto xtuple(Types &&... args)
Creates tuples from arguments for concatenate and stack.
Definition xbuilder.hpp:707
auto logspace(T start, T stop, std::size_t num_samples, T base=10, bool endpoint=true) noexcept
Generates num_samples numbers evenly spaced on a log scale over given interval.
Definition xbuilder.hpp:481
xfixed_container< T, FSH, L, Sharable > xtensor_fixed
Alias template on xfixed_container with default parameters for layout type.
xarray< T, L > empty(const S &shape)
Create an xcontainer (xarray, xtensor or xtensor_fixed) with uninitialized values, with value_type T...
Definition xbuilder.hpp:89