14#ifndef XTENSOR_BUILDER_HPP
15#define XTENSOR_BUILDER_HPP
25#include <xtl/xclosure.hpp>
26#include <xtl/xsequence.hpp>
27#include <xtl/xtype_traits.hpp>
29#include "xbroadcast.hpp"
30#include "xfunction.hpp"
31#include "xgenerator.hpp"
32#include "xoperation.hpp"
// NOTE(review): this excerpt looks like a lossy extraction of the header.
// Braces and many statements are missing, and several lines begin with a
// fused original line number (e.g. "45 template <"). The comments added
// here describe only what the visible lines show.

// ones<T>(shape): returns broadcast(T(1), shape), i.e. an xexpression of
// the requested shape whose elements are all T(1). Declared noexcept.
// `shape` is taken by value, so S is a non-reference type and
// std::forward<S>(shape) moves the local copy into broadcast().
45 template <
class T,
class S>
46 inline auto ones(
S shape)
noexcept
48 return broadcast(T(1), std::forward<S>(shape));
// Overload taking the shape as a C array of extents (const I (&)[L]).
// Its body is not visible in this excerpt.
51 template <
class T,
class I, std::
size_t L>
52 inline auto ones(
const I (&shape)[L])
noexcept
// zeros<T>(shape): returns broadcast(T(0), shape), an xexpression of the
// requested shape whose elements are all T(0). The signature line of this
// overload is not visible in this excerpt.
65 template <
class T,
class S>
68 return broadcast(T(0), std::forward<S>(shape));
// Overload taking the shape as a C array of extents (const I (&)[L]).
// Its body is not visible in this excerpt.
71 template <
class T,
class I, std::
size_t L>
72 inline auto zeros(
const I (&shape)[L])
noexcept
// empty<T, L>(shape): creates a container of value_type T and layout L whose
// values are left uninitialized.
// Generic-shape overload; its signature and body are not visible here.
88 template <
class T, layout_type L = XTENSOR_DEFAULT_LAYOUT,
class S>
// std::array shape overload: the rank N is known at compile time, so the
// result is an xtensor<T, N, L>. The incoming shape is converted to the
// xtensor shape_type with xtl::forward_sequence before construction.
94 template <
class T, layout_type L = XTENSOR_DEFAULT_LAYOUT,
class ST, std::
size_t N>
95 inline xtensor<T, N, L>
empty(
const std::array<ST, N>& shape)
97 using shape_type =
typename xtensor<T, N>::shape_type;
98 return xtensor<T, N, L>(xtl::forward_sequence<shape_type,
decltype(shape)>(shape));
// C-array shape overload (const I (&)[N]): same conversion as the
// std::array overload, returning an xtensor<T, N, L>.
101 template <
class T, layout_type L = XTENSOR_DEFAULT_LAYOUT,
class I, std::
size_t N>
102 inline xtensor<T, N, L>
empty(
const I (&shape)[N])
104 using shape_type =
typename xtensor<T, N>::shape_type;
105 return xtensor<T, N, L>(xtl::forward_sequence<shape_type,
decltype(shape)>(shape));
// Overload parameterized on a pack of extents std::size_t... N. This is
// presumably the fixed-shape (xtensor_fixed) variant; its signature and
// body are not visible here, so confirm against the full file.
108 template <
class T,
layout_type L = XTENSOR_DEFAULT_LAYOUT, std::size_t... N>
// Fragments of the *_like builders (their signatures are not visible).
// These two allocate a container of type xtype via
// xtype::from_shape(e.derived_cast().shape()), i.e. with the same shape
// as the argument expression. They are presumably empty_like and full_like.
124 auto res = xtype::from_shape(
e.derived_cast().shape());
139 auto res = xtype::from_shape(
e.derived_cast().shape());
// Presumably zeros_like: delegates to full_like with value_type(0).
156 return full_like(
e,
typename E::value_type(0));
// Presumably ones_like: delegates to full_like with value_type(1).
171 return full_like(
e,
typename E::value_type(1));
// get_mult_type<T, S>: the type an element index i is converted to before
// it is multiplied by the step (see `mult_type(i)` in the non-integral
// arange_assign_to overload and in arange_generator::access_impl).
// The primary template's body is not visible here.
176 template <
class T,
class S>
177 struct get_mult_type_impl
// Specialization for std::chrono::duration<R, P> steps. It presumably
// selects the duration's representation type R so that
// `step * mult_type(i)` is a valid duration expression. The body is not
// visible; confirm against the full file.
182 template <
class T,
class R,
class P>
183 struct get_mult_type_impl<T, std::chrono::duration<R, P>>
// Convenience alias for get_mult_type_impl<T, S>::type.
188 template <
class T,
class S>
189 using get_mult_type =
typename get_mult_type_impl<T, S>::type;
// arange_assign_to<R>(e, start, stop, step, endpoint): fills the storage of
// `e` with an arithmetic progression, converting each value to R.
//
// Overload selected when the step type X is integral. The stop and endpoint
// parameters are unnamed and therefore ignored here. The lines that
// initialize and advance `value` are not visible in this excerpt;
// presumably `value` starts at `start` and is incremented by `step` for
// each element.
193 template <
class R,
class E,
class U,
class X, XTL_REQUIRES(xtl::is_
integral<X>)>
194 inline void arange_assign_to(xexpression<E>& e, U start, U, X step,
bool)
noexcept
196 auto& de = e.derived_cast();
199 for (
auto&& el : de.storage())
201 el =
static_cast<R
>(value);
// Overload selected when the step type X is not integral (e.g. floating
// point). Element i is computed directly as start + step * i instead of by
// repeated addition, presumably to avoid accumulating rounding error.
206 template <
class R,
class E,
class U,
class X, XTL_REQUIRES(xtl::negation<xtl::is_
integral<X>>)>
207 inline void arange_assign_to(xexpression<E>& e, U start, U stop, X step,
bool endpoint)
noexcept
209 auto& buf = e.derived_cast().storage();
210 using size_type =
decltype(buf.size());
211 using mult_type = get_mult_type<U, X>;
212 size_type num = buf.size();
213 for (size_type i = 0; i < num; ++i)
215 buf[i] =
static_cast<R
>(start + step * mult_type(i));
// With endpoint set (and more than one element), the last slot is
// overwritten with `stop` exactly, replacing start + step * (num - 1).
217 if (endpoint && num > 1)
219 buf[num - 1] =
static_cast<R
>(stop);
// arange_generator<T, R, S>: element functor for arange-style generators.
// The element at index t is R(start + step * t). When `endpoint` is set and
// there is more than one step, the last index yields R(stop) instead.
//   T: type of start/stop
//   R: returned value_type (defaults to T)
//   S: type of step (defaults to T)
223 template <
class T,
class R = T,
class S = T>
224 class arange_generator
228 using value_type = R;
// num_steps is the number of elements (arange_impl passes the generator's
// 1-D extent here). Only the m_num_steps and m_endpoint initializers are
// visible in this excerpt.
231 arange_generator(T start, T stop, S step,
size_t num_steps,
bool endpoint =
false)
235 , m_num_steps(num_steps)
236 , m_endpoint(endpoint)
// Index access. access_impl uses only its first argument; any further
// indices are ignored.
240 template <
class... Args>
241 inline R operator()(Args... args)
const
243 return access_impl(args...);
// Iterator-based access: as with operator(), only *first is used.
247 inline R element(It first, It)
const
249 return access_impl(*first);
// Bulk evaluation: writes every element of `e` via arange_assign_to.
253 inline void assign_to(xexpression<E>& e)
const noexcept
255 arange_assign_to<R>(e, m_start, m_stop, m_step, m_endpoint);
// Computes element t. With endpoint set, the last index returns `stop`
// exactly rather than the computed start + step * t.
266 template <
class T1,
class... Args>
267 inline R access_impl(T1 t, Args...)
const
269 if (m_endpoint && m_num_steps > 1 &&
size_t(t) == m_num_steps - 1)
271 return static_cast<R
>(m_stop);
274 using mult_type = get_mult_type<T, S>;
275 return static_cast<R
>(m_start + m_step * mult_type(t));
// No index (0-D access): returns the start value.
278 inline R access_impl()
const
280 return static_cast<R
>(m_start);
// Three trait templates over <T, S> whose bodies are not visible here. They
// are presumably both_integer, integer_with_signed_integer and
// integer_with_unsigned_integer, which select the arange_impl overloads
// below; confirm against the full file.
284 template <
class T,
class S>
287 template <
class T,
class S>
290 template <
class T,
class S>
// arange_impl when start/stop and step are not both integers. The element
// count is ceil((stop - start) / step).
// NOTE(review): if the quotient is floating and ceil() yields -1 or less
// (e.g. stop < start with a positive step), the static_cast to std::size_t
// is undefined behavior, because the value is not representable in the
// target type. A zero step gives an infinite quotient with the same problem
// for IEEE floating types. Consider clamping to 0 before the cast.
293 template <
class T,
class S = T, XTL_REQUIRES(xtl::negation<both_
integer<T, S>>)>
294 inline auto arange_impl(T start, T stop, S step = 1) noexcept
296 std::size_t shape =
static_cast<std::size_t
>(std::ceil((stop - start) / step));
297 return detail::make_xgenerator(detail::arange_generator<T, T, S>(start, stop, step, shape), {shape});
// arange_impl when the step is a signed integer. For a non-empty range the
// element count is ceil(|stop - start| / |step|), computed with integer
// arithmetic for both ascending and descending ranges. The
// `if (!empty_cond)` guard around the assignment is not visible in this
// excerpt and is presumably present.
// NOTE(review): empty_cond uses truncating integer division, so a span
// shorter than |step| is treated as empty. For example, start=0, stop=1,
// step=2 gives (1 - 0) / 2 == 0 and therefore shape 0, although the
// half-open range [0, 1) with step 2 contains 0. Comparing the sign of
// (stop - start) with the sign of step would avoid this.
300 template <
class T,
class S = T, XTL_REQUIRES(
integer_with_
signed_
integer<T, S>)>
301 inline auto arange_impl(T start, T stop, S step = 1) noexcept
303 bool empty_cond = (stop - start) / step <= 0;
304 std::size_t shape = 0;
307 shape = stop > start ?
static_cast<std::size_t
>((stop - start + step - S(1)) / step)
308 : static_cast<std::size_t>((start - stop - step - S(1)) / -step);
310 return detail::make_xgenerator(detail::arange_generator<T, T, S>(start, stop, step, shape), {shape});
// arange_impl when the step is an unsigned integer. The range is empty when
// stop <= start; otherwise the count is ceil((stop - start) / step). The
// guard around the assignment is not visible here.
313 template <
class T,
class S = T, XTL_REQUIRES(
integer_with_
unsigned_
integer<T, S>)>
314 inline auto arange_impl(T start, T stop, S step = 1) noexcept
316 bool empty_cond = stop <= start;
317 std::size_t shape = 0;
320 shape =
static_cast<std::size_t
>((stop - start + step - S(1)) / step);
322 return detail::make_xgenerator(detail::arange_generator<T, T, S>(start, stop, step, shape), {shape});
// fn_impl<F>: adapts a functor F that is called as m_ft(begin, end) over an
// index range to the generator interface: operator() with integral indices
// and element() with an iterator pair. The class header is not visible.
330 using value_type =
typename F::value_type;
331 using size_type = std::size_t;
// Zero-argument access is forwarded as the single index {0}. A separate
// overload is presumably needed because the variadic version below would
// otherwise declare a zero-length array, which is ill-formed in standard C++.
338 inline value_type operator()()
const
340 size_type idx[1] = {0ul};
341 return access_impl(std::begin(idx), std::end(idx));
// Packs the indices into a local array and forwards them as an iterator
// range.
344 template <
class... Args>
345 inline value_type operator()(Args... args)
const
347 size_type idx[
sizeof...(Args)] = {
static_cast<size_type
>(args)...};
348 return access_impl(std::begin(idx), std::end(idx));
352 inline value_type element(It first, It last)
const
354 return access_impl(first, last);
// Every access path ends here and calls the wrapped functor.
362 inline value_type access_impl(
const It& begin,
const It& end)
const
364 return m_ft(begin, end);
// eye_fn<T>: functor used by eye(). Returns T(1) when the last index equals
// the second-to-last index plus the diagonal offset k, and T(0) otherwise.
// It reads *(end - 2), so it assumes at least two indices.
373 using value_type = T;
381 inline T operator()(
const It& ,
const It& end)
const
// m_k is converted to the index value type. For an unsigned index type a
// negative k wraps, and the modular addition still compares correctly.
383 using lvalue_type =
typename std::iterator_traits<It>::value_type;
384 return *(end - 1) == *(end - 2) +
static_cast<lvalue_type
>(m_k) ? T(1) : T(0);
// eye<T>(shape, k): generator of the given shape that holds T(1) where the
// last index equals the second-to-last index plus k (see detail::eye_fn),
// and T(0) elsewhere. T defaults to bool.
402 template <
class T =
bool>
403 inline auto eye(
const std::vector<std::size_t>& shape,
int k = 0)
405 return detail::make_xgenerator(detail::fn_impl<detail::eye_fn<T>>(detail::eye_fn<T>(
k)), shape);
// eye<T>(n, k): square variant taking a single extent. Its body is not
// visible in this excerpt; presumably it forwards a {n, n} shape to the
// overload above.
417 template <
class T =
bool>
418 inline auto eye(std::size_t
n,
int k = 0)
// arange(start, stop, step): thin wrapper that dispatches to the
// detail::arange_impl overload matching the start/stop/step types. The
// signature line is not visible in this excerpt.
431 template <
class T,
class S = T>
434 return detail::arange_impl(start, stop, step);
// Fragment of (presumably) linspace/logspace; the signature is not visible.
// Computations use std::common_type_t<T, double>, so an integral T is
// promoted to a floating type of at least double precision.
462 using fp_type = std::common_type_t<T, double>;
464 return detail::make_xgenerator(
// concatenate_access<CT...>: maps an index of the concatenated result to an
// element of one of the operands held in a tuple.
488 template <
class... CT>
489 class concatenate_access
493 using tuple_type = std::tuple<CT...>;
494 using size_type = std::size_t;
// The element type is the promoted type of all the operands' value_types.
495 using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
// access(t, axis, first, last): [first, last) is the index into the result.
498 inline value_type access(
const tuple_type& t, size_type axis, It first, It last)
const
// dim_offset is the number of extra leading indices beyond the first
// operand's dimension. Those leading indices are skipped when reading
// positions below.
501 auto dim_offset = std::distance(first, last) - std::get<0>(t).dimension();
502 size_t axis_dim = *(first + axis + dim_offset);
// match: if axis_dim lies past this operand along `axis`, subtract the
// operand's extent so that axis_dim becomes relative to the next operand.
// The return statements are not visible here; presumably the lambda returns
// false in that case and true otherwise.
503 auto match = [&](
auto& arr)
505 if (axis_dim >= arr.shape()[axis])
507 axis_dim -= arr.shape()[axis];
// get: computes a linear offset into arr.begin() using row-major strides
// (the product of the trailing extents). It uses the relative axis_dim on
// `axis` and the requested index on every other dimension.
// NOTE(review): the stride is recomputed for every dimension (O(d^2) per
// element). The code also presumably assumes arr.begin() traverses the
// operand in row-major order; confirm.
513 auto get = [&](
auto& arr)
516 const size_t end = arr.dimension();
517 for (
size_t i = 0; i < end; i++)
519 const auto& shape = arr.shape();
520 const size_t stride = std::accumulate(
521 shape.begin() + i + 1,
524 std::multiplies<size_t>()
528 offset += axis_dim * stride;
532 const auto len = (*(first + i + dim_offset));
533 offset += len * stride;
536 const auto element = arr.begin() + offset;
// Find the first operand whose extent along `axis` contains axis_dim, then
// read the element from that operand.
541 for (; i <
sizeof...(CT); ++i)
543 if (apply<bool>(i, match, t))
548 return apply<value_type>(i, get, t);
// stack_access<CT...>: maps an index of the stacked result to an element of
// one operand. The index at position `axis` selects the operand, and the
// remaining indices address the element inside it. The class header line
// is not visible.
552 template <
class... CT>
557 using tuple_type = std::tuple<CT...>;
558 using size_type = std::size_t;
559 using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
562 inline value_type access(
const tuple_type& t, size_type axis, It first, It)
const
// get_item: computes a row-major linear offset into arr.begin(). The lines
// that update `after_axis` are not visible; presumably it becomes 1 once i
// reaches `axis`, so that the stacking index is skipped.
564 auto get_item = [&](
auto& arr)
567 const size_t end = arr.dimension();
568 size_t after_axis = 0;
569 for (
size_t i = 0; i < end; i++)
575 const auto& shape = arr.shape();
576 const size_t stride = std::accumulate(
577 shape.begin() + i + 1,
580 std::multiplies<size_t>()
582 const auto len = (*(first + i + after_axis));
583 offset += len * stride;
585 const auto element = arr.begin() + offset;
// Operand selector: the result index along the stacking axis.
588 size_type i = *(first + axis);
589 return apply<value_type>(i, get_item, t);
// vstack_access<CT...>: if the first operand is 1-D, the lookup is handled
// by stack_access; otherwise it is handled by concatenate_access along the
// given axis. Only the first operand's dimension is inspected.
593 template <
class... CT>
598 using tuple_type = std::tuple<CT...>;
599 using size_type = std::size_t;
600 using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
603 inline value_type access(
const tuple_type& t, size_type axis, It first, It last)
const
605 if (std::get<0>(t).dimension() == 1)
607 return stack.access(t, axis, first, last);
611 return concatonate.access(t, axis, first, last);
// NOTE(review): "concatonate" is a misspelling of "concatenate". In the
// visible excerpt it is referenced only inside this class.
617 concatenate_access<CT...> concatonate;
618 stack_access<CT...>
stack;
// concatenate_invoker<F, CT...>: generator functor that holds the operand
// tuple and the axis, and delegates index lookup to an access policy
// F<CT...> (concatenate_access, stack_access or vstack_access below).
621 template <
template <
class...>
class F,
class... CT>
622 class concatenate_invoker
626 using tuple_type = std::tuple<CT...>;
627 using size_type = std::size_t;
628 using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
630 inline concatenate_invoker(tuple_type&& t, size_type axis)
// Packs the integral indices into an xindex and forwards them as a range.
636 template <
class... Args>
637 inline value_type operator()(Args... args)
const
640 xindex index({
static_cast<size_type
>(args)...});
641 return access_method.access(m_t, m_axis, index.begin(), index.end());
// Iterator-range access is forwarded directly to the policy.
645 inline value_type element(It first, It last)
const
647 return access_method.access(m_t, m_axis, first, last);
652 F<CT...> access_method;
// Aliases binding each access policy to the invoker.
657 template <
class... CT>
658 using concatenate_impl = concatenate_invoker<concatenate_access, CT...>;
660 template <
class... CT>
661 using stack_impl = concatenate_invoker<stack_access, CT...>;
663 template <
class... CT>
664 using vstack_impl = concatenate_invoker<vstack_access, CT...>;
// repeat_impl<CT>: generator functor whose value depends only on the index
// along m_axis: element (i0, i1, ...) is m_source(i_axis). meshgrid uses it
// to repeat each coordinate vector along its own axis. The class header is
// not visible.
671 using xexpression_type = std::decay_t<CT>;
672 using size_type =
typename xexpression_type::size_type;
673 using value_type =
typename xexpression_type::value_type;
676 repeat_impl(CTA&& source, size_type axis)
677 : m_source(std::forward<CTA>(source))
// NOTE(review): args_arr[m_axis] is not bounds-checked; the caller must pass
// at least m_axis + 1 indices.
682 template <
class... Args>
683 value_type operator()(Args... args)
const
685 std::array<size_type,
sizeof...(Args)> args_arr = {
static_cast<size_type
>(args)...};
686 return m_source(args_arr[m_axis]);
// Iterator-based access: reads only the index at position m_axis.
690 inline value_type element(It first, It)
const
692 return m_source(*(first +
static_cast<std::ptrdiff_t
>(m_axis)));
// xtuple(args...): builds the tuple of operands expected by
// concatenate/stack. Each element type is xtl::const_closure_type_t<Types>,
// i.e. the arguments are stored as const closures (see xtl's closure-type
// documentation for how lvalues and rvalues are held).
706 template <
class... Types>
709 return std::tuple<xtl::const_closure_type_t<Types>...>(std::forward<Types>(
args)...);
// all_true<values...>: used below as a compile-time conjunction of a bool
// pack; its body is not visible here.
714 template <
bool... values>
// concat_fixed_shape_impl<X, Y, axis, index_sequence<Is...>>: compile-time
// shape obtained by concatenating fixed shapes X and Y along `axis`.
// It requires equal rank, a valid axis, and equal extents on every other
// axis. The resulting extent on `axis` is X's extent plus Y's.
717 template <
class X,
class Y, std::
size_t axis,
class AxesSequence>
718 struct concat_fixed_shape_impl;
720 template <
class X,
class Y, std::size_t axis, std::size_t... Is>
721 struct concat_fixed_shape_impl<X, Y, axis, std::index_sequence<Is...>>
723 static_assert(X::size() == Y::size(),
"Concatenation requires equisized shapes");
724 static_assert(axis < X::size(),
"Concatenation requires a valid axis");
726 all_true<(axis == Is || X::template get<Is>() == Y::template get<Is>())...>::value,
727 "Concatenation requires compatible shapes and axis"
730 using type = fixed_shape<
731 (axis == Is ? X::template get<Is>() + Y::template get<Is>() : X::template get<Is>())...>;
// concat_fixed_shape<axis, X, Y, Rest...>: extends concat_fixed_shape_impl
// to any number of shapes. The two-shape case uses the impl directly; more
// shapes are folded as concat(X, concat(Y, Rest...)).
734 template <std::size_t axis,
class X,
class Y,
class... Rest>
735 struct concat_fixed_shape;
737 template <std::
size_t axis,
class X,
class Y>
738 struct concat_fixed_shape<axis, X, Y>
740 using type =
typename concat_fixed_shape_impl<X, Y, axis, std::make_index_sequence<X::size()>>::type;
743 template <std::size_t axis,
class X,
class Y,
class... Rest>
744 struct concat_fixed_shape
746 using type =
typename concat_fixed_shape<axis, X,
typename concat_fixed_shape<axis, Y, Rest...>::type>::type;
749 template <std::size_t axis,
class... Args>
750 using concat_fixed_shape_t =
typename concat_fixed_shape<axis, Args...>::type;
// all_fixed_shapes<CT...>: whether every operand's shape_type is fixed
// (via detail::all_fixed). This enables the compile-time concatenate
// overload.
752 template <
class... CT>
753 using all_fixed_shapes = detail::all_fixed<typename std::decay_t<CT>::shape_type...>;
// concat_shape_builder_t::build(t, axis): runtime shape of concatenating the
// tuple's operands along `axis`. It checks that every operand has the same
// rank and the same extents on all other axes (reporting a mismatch through
// throw_concatenate_error). It then sets the extent on `axis` to the sum of
// the operands' extents along that axis.
755 struct concat_shape_builder_t
// concat_shape<Shape>: fixed shapes are mapped to
// static_shape<value_type, rank>. The non-fixed specialization's body is
// not visible here.
757 template <class Shape, bool = detail::is_fixed<Shape>::value>
760 template <
class Shape>
761 struct concat_shape<Shape, true>
764 using type = static_shape<
typename Shape::value_type, Shape::size()>;
767 template <
class Shape>
768 struct concat_shape<Shape, false>
773 template <
class... Args>
774 static auto build(
const std::tuple<Args...>& t, std::size_t axis)
776 using shape_type = promote_shape_t<
777 typename concat_shape<typename std::decay_t<Args>::shape_type>::type...>;
// Start from a copy of the first operand's shape.
778 using source_shape_type =
decltype(std::get<0>(t).shape());
779 shape_type new_shape = xtl::forward_sequence<shape_type, source_shape_type>(
780 std::get<0>(t).shape()
// check_shape: same rank, and equal extents on every axis except `axis`.
// The `&&` short-circuits, so arr.shape(i) is not read once the ranks
// differ. The guard before the throw is not visible; presumably it throws
// only when res is false.
783 auto check_shape = [&axis, &new_shape](
auto& arr)
785 std::size_t s = new_shape.size();
786 bool res = s == arr.dimension();
787 for (std::size_t i = 0; i < s; ++i)
789 res = res && (i == axis || new_shape[i] == arr.shape(i));
793 throw_concatenate_error(new_shape, arr.shape());
796 for_each(check_shape, t);
// Sum the operands' extents along `axis`.
798 auto shape_at_axis = [&axis](std::size_t prev,
auto& arr) -> std::size_t
800 return prev + arr.shape()[axis];
// Net effect: new_shape[axis] becomes the sum computed above.
802 new_shape[axis] +=
accumulate(shape_at_axis, std::size_t(0), t) - new_shape[axis];
// concatenate(t, axis): runtime-shape overload. The result shape is
// computed and validated by concat_shape_builder_t::build. The signature is
// not visible here.
829 template <
class... CT>
832 const auto shape = detail::concat_shape_builder_t::build(
t, axis);
// Compile-time overload, enabled when every operand has a fixed shape. The
// axis is a template parameter and the result shape type comes from
// concat_fixed_shape_t, so incompatible shapes are rejected at compile time
// by its static_asserts.
836 template <std::size_t axis,
class... CT,
typename = std::enable_if_t<detail::all_fixed_shapes<CT...>::value>>
839 using shape_type = detail::concat_fixed_shape_t<axis, typename std::decay_t<CT>::shape_type...>;
840 return detail::make_xgenerator(detail::concatenate_impl<CT...>(std::move(t), axis), shape_type{});
// add_axis(arr, axis, value): returns the shape `arr` with `value` inserted
// at position `axis`, so the rank grows by one.
// std::array overload: copies [0, axis) and [axis, N) around the new slot.
// The line writing the new slot is not visible; presumably
// temp[axis] = value.
// NOTE(review): axis is not checked against N; axis > N is out of bounds.
845 template <
class T, std::
size_t N>
846 inline std::array<T, N + 1> add_axis(std::array<T, N> arr, std::size_t axis, std::size_t value)
848 std::array<T, N + 1> temp;
849 std::copy(arr.begin(), arr.begin() + axis, temp.begin());
851 std::copy(arr.begin() + axis, arr.end(), temp.begin() + axis + 1);
// Generic overload for resizable shapes: inserts `value` at `axis` with
// insert(). The template header and the line creating `temp` (presumably a
// copy of arr) are not visible.
856 inline T add_axis(T arr, std::size_t axis, std::size_t value)
859 temp.insert(temp.begin() + std::ptrdiff_t(axis), value);
// stack(t, axis): the result shape is the first operand's shape with a new
// axis inserted at position `axis` by detail::add_axis. The extent argument
// passed to add_axis is not visible here; presumably it is sizeof...(CT).
882 template <
class... CT>
883 inline auto stack(std::tuple<CT...>&&
t, std::size_t axis = 0)
887 auto new_shape = detail::add_axis(
888 xtl::forward_sequence<shape_type, source_shape_type>(std::get<0>(
t).shape()),
// hstack(t): chooses axis 1 when the first operand has more than one
// dimension and axis 0 for 1-D operands. The call that uses this axis is
// not visible.
903 template <
class... CT>
906 auto dim = std::get<0>(
t).dimension();
907 std::size_t axis =
dim > std::size_t(1) ? 1 : 0;
// vstack_shape(t, shape): result shape of vstack. A 1-D operand shape
// becomes {number of operands, shape[0]}; otherwise the shapes are
// concatenated along axis 0 by concat_shape_builder_t::build.
// NOTE(review): the visible build() overload takes the tuple by const
// reference, so std::move(t) below has no effect.
913 template <
class S,
class... CT>
914 inline auto vstack_shape(std::tuple<CT...>& t,
const S& shape)
916 using size_type =
typename S::value_type;
917 auto res = shape.size() == size_type(1)
918 ? S({
sizeof...(CT), shape[0]})
919 : concat_shape_builder_t::build(std::move(t), size_type(0));
// std::array<T, 1> overload: the result rank is known to be 2, so a
// std::array<T, 2> is returned.
923 template <
class T,
class... CT>
924 inline auto vstack_shape(
const std::tuple<CT...>&, std::array<T, 1> shape)
926 std::array<T, 2> res = {
sizeof...(CT), shape[0]};
// vstack(t): only partially visible. It computes the new shape with
// detail::vstack_shape from the first operand's shape.
939 template <
class... CT>
944 auto new_shape = detail::vstack_shape(
946 xtl::forward_sequence<shape_type, source_shape_type>(std::get<0>(
t).shape())
// meshgrid_impl(index_sequence<I...>, e...): builds one generator per input.
// All generators share the shape {e0.shape()[0], e1.shape()[0], ...}, and
// the I-th one repeats input I along axis I (detail::repeat_impl).
// Two return statements are visible; the lines that select between them
// are not.
954 template <std::size_t... I,
class... E>
955 inline auto meshgrid_impl(std::index_sequence<I...>, E&&... e)
noexcept
958 const std::array<std::size_t,
sizeof...(E)> shape = {e.shape()[0]...};
959 return std::make_tuple(
960 detail::make_xgenerator(detail::repeat_impl<xclosure_t<E>>(std::forward<E>(e), I), shape)...
963 return std::make_tuple(detail::make_xgenerator(
964 detail::repeat_impl<xclosure_t<E>>(std::forward<E>(e), I),
// meshgrid(e...): public entry point; generates the axis indices 0..N-1.
979 template <
class... E>
982 return detail::meshgrid_impl(std::make_index_sequence<
sizeof...(E)>(), std::forward<E>(
e)...);
// diagonal_fn<CT>: element functor for diagonal(). From an index of the
// result it rebuilds the full source index. The axes other than m_axis_1
// and m_axis_2 take the leading result indices in order, and the next
// result index is the position p along the diagonal.
992 using xexpression_type = std::decay_t<CT>;
993 using value_type =
typename xexpression_type::value_type;
996 diagonal_fn(CTA&& source,
int offset, std::size_t axis_1, std::size_t axis_2)
997 : m_source(std::forward<CTA>(source))
1005 inline value_type operator()(It begin, It)
const
1007 xindex idx(m_source.shape().size());
1009 for (std::size_t i = 0; i < idx.size(); i++)
1011 if (i != m_axis_1 && i != m_axis_2)
1013 idx[i] =
static_cast<std::size_t
>(*begin++);
// m_offset is converted to the index value type. If that type is unsigned
// and the offset is negative, the conversion wraps, and `p - uoffset` below
// wraps back to p + |offset| under modular arithmetic.
1016 using it_vtype =
typename std::iterator_traits<It>::value_type;
1017 it_vtype uoffset =
static_cast<it_vtype
>(m_offset);
// Presumably the offset >= 0 branch (condition not visible): the source
// index on (axis_1, axis_2) is (p, p + offset).
1020 idx[m_axis_1] =
static_cast<std::size_t
>(*(begin));
1021 idx[m_axis_2] =
static_cast<std::size_t
>(*(begin) + uoffset);
// Presumably the offset < 0 branch: the source index on (axis_1, axis_2)
// is (p - offset, p).
1025 idx[m_axis_1] =
static_cast<std::size_t
>(*(begin) -uoffset);
1026 idx[m_axis_2] =
static_cast<std::size_t
>(*(begin));
1028 return m_source[idx];
1035 const std::size_t m_axis_1;
1036 const std::size_t m_axis_2;
// diag_fn<CT>: element functor for diag(). For a result index (i, j) it
// yields a value from the source when j == i + k, and value_type(0)
// otherwise. m_k is converted to the index value type; for a negative k
// with an unsigned index type the conversion wraps, and the modular
// comparison still holds.
1044 using xexpression_type = std::decay_t<CT>;
1045 using value_type =
typename xexpression_type::value_type;
1047 template <
class CTA>
1048 diag_fn(CTA&& source,
int k)
1049 : m_source(std::forward<CTA>(source))
1055 inline value_type operator()(It begin, It)
const
1057 using it_vtype =
typename std::iterator_traits<It>::value_type;
1058 it_vtype umk =
static_cast<it_vtype
>(m_k);
// Presumably the k >= 0 branch (condition not visible): value source(i).
1061 return *begin + umk == *(begin + 1) ? m_source(*begin) : value_type(0);
// Presumably the k < 0 branch: value source(i + k), which equals source(j)
// on the diagonal.
1065 return *begin + umk == *(begin + 1) ? m_source(*begin + umk) : value_type(0);
// trilu_fn<CT, Comp>: element functor shared by tril() and triu(). For an
// index whose first two entries are (i, j), it keeps the source element
// when comp(i + k, j) holds. The alternative branch is not visible;
// presumably it yields zero.
1075 template <
class CT,
class Comp>
1080 using xexpression_type = std::decay_t<CT>;
1081 using value_type =
typename xexpression_type::value_type;
// Indices are converted to a signed type so that i + k can be negative.
1082 using signed_idx_type =
long int;
1084 template <
class CTA>
1085 trilu_fn(CTA&& source,
int k, Comp comp)
1086 : m_source(std::forward<CTA>(source))
1093 inline value_type operator()(It begin, It end)
const
1096 return m_comp(signed_idx_type(*begin) + m_k, signed_idx_type(*(begin + 1)))
1097 ? m_source.element(begin, end)
1104 const signed_idx_type m_k;
// diagonal_shape_type<ST>: shape type of diagonal()'s result. The primary
// template's body is not visible here.
1112 template <
class ST,
class... S>
1113 struct diagonal_shape_type
// For std::array<I, L> shapes the rank drops by one: std::array<I, L - 1>.
1118 template <
class I, std::
size_t L>
1119 struct diagonal_shape_type<std::array<I, L>>
1121 using type = std::array<I, L - 1>;
// Fragments of diagonal(), diag(), tril() and triu(); their signatures are
// not visible in this excerpt.
// diagonal(): the result has rank dimension - 1, and ret_shape starts
// zero-filled.
1153 using shape_type =
typename detail::diagonal_shape_type<typename std::decay_t<E>::shape_type>::type;
1155 auto shape = arr.shape();
1156 auto dimension = arr.dimension();
1160 auto ret_shape = xtl::make_sequence<shape_type>(dimension - 1, 0);
1179 return detail::make_xgenerator(
1180 detail::fn_impl<detail::diagonal_fn<CT>>(
// diag(): s = shape[0] + |k|. This is presumably the side length of the
// square result, which is passed on a line not visible here.
1205 std::size_t
sk = std::size_t(std::abs(
k));
1206 std::size_t
s = arr.shape()[0] +
sk;
1207 return detail::make_xgenerator(
1208 detail::fn_impl<detail::diag_fn<CT>>(detail::diag_fn<CT>(std::forward<E>(arr),
k)),
// tril(): keeps the elements with i + k >= j (std::greater_equal).
1226 auto shape = arr.shape();
1227 return detail::make_xgenerator(
1228 detail::fn_impl<detail::trilu_fn<
CT, std::greater_equal<long int>>>(
1229 detail::trilu_fn<
CT, std::greater_equal<long int>>(
1230 std::forward<E>(arr),
1232 std::greater_equal<long int>()
// triu(): keeps the elements with i + k <= j (std::less_equal).
1252 auto shape = arr.shape();
1253 return detail::make_xgenerator(
1254 detail::fn_impl<detail::trilu_fn<
CT, std::less_equal<long int>>>(
1255 detail::trilu_fn<
CT, std::less_equal<long int>>(std::forward<E>(arr),
k, std::less_equal<long int>())
auto pow(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::pow_fun, E1, E2 >
Power function.
Standard mathematical functions for xexpressions.
auto broadcast(E &&e, const S &s)
Returns an xexpression broadcasting the given expression to a specified shape.
auto stack(std::tuple< CT... > &&t, std::size_t axis=0)
Stack xexpressions along axis.
auto arange(T start, T stop, S step=1) noexcept
Generates numbers evenly spaced within the given half-open interval [start, stop).
auto concatenate(std::tuple< CT... > &&t, std::size_t axis=0)
Concatenates xexpressions along axis.
auto eye(const std::vector< std::size_t > &shape, int k=0)
Generates an array with ones on the diagonal.
auto ones_like(const xexpression< E > &e)
Create an xcontainer (xarray, xtensor or xtensor_fixed), filled with ones and of the same shape,...
auto ones(S shape) noexcept
Returns an xexpression containing ones of the specified shape.
auto meshgrid(E &&... e) noexcept
Return coordinate tensors from coordinate vectors.
auto triu(E &&arr, int k=0)
Extract an upper triangular matrix from an xexpression.
auto full_like(const xexpression< E > &e, typename E::value_type fill_value)
Create an xcontainer (xarray, xtensor or xtensor_fixed), filled with fill_value and of the same shape,...
auto zeros_like(const xexpression< E > &e)
Create an xcontainer (xarray, xtensor or xtensor_fixed), filled with zeros and of the same shape,...
auto zeros(S shape) noexcept
Returns an xexpression containing zeros of the specified shape.
auto accumulate(F &&f, E &&e, EVS evaluation_strategy=EVS())
Accumulate and flatten an array. NOTE: this function is not lazy!
auto linspace(T start, T stop, std::size_t num_samples=50, bool endpoint=true) noexcept
Generates num_samples evenly spaced numbers over the given interval.
auto vstack(std::tuple< CT... > &&t)
Stack xexpressions in sequence vertically (row wise).
auto diagonal(E &&arr, int offset=0, std::size_t axis_1=0, std::size_t axis_2=1)
Returns the elements on the diagonal of arr. If arr has more than two dimensions, then the axes specif...
auto hstack(std::tuple< CT... > &&t)
Stack xexpressions in sequence horizontally (column wise).
auto tril(E &&arr, int k=0)
Extract a lower triangular matrix from an xexpression.
auto empty_like(const xexpression< E > &e)
Create an xcontainer (xarray, xtensor or xtensor_fixed) with uninitialized values of the same shape,...
auto diag(E &&arr, int k=0)
Returns an xexpression with the values of arr on the diagonal and zeros elsewhere.
auto xtuple(Types &&... args)
Creates tuples from arguments for concatenate and stack.
auto logspace(T start, T stop, std::size_t num_samples, T base=10, bool endpoint=true) noexcept
Generates num_samples numbers evenly spaced on a log scale over the given interval.
xfixed_container< T, FSH, L, Sharable > xtensor_fixed
Alias template on xfixed_container with default parameters for layout type.
xarray< T, L > empty(const S &shape)
Create an xcontainer (xarray, xtensor or xtensor_fixed) with uninitialized values, with value_type T...