10#ifndef XTENSOR_XSHAPE_HPP
11#define XTENSOR_XSHAPE_HPP
16#include <initializer_list>
19#include "../containers/xstorage.hpp"
20#include "../core/xlayout.hpp"
21#include "../core/xtensor_forward.hpp"
28 template <
class T, std::
size_t N>
29 using static_shape = std::array<T, N>;
31 template <std::size_t... X>
34 using xindex = dynamic_shape<std::size_t>;
36 template <
class S1,
class S2>
37 bool same_shape(
const S1& s1,
const S2& s2)
noexcept;
42 template <
class R,
class T>
43 constexpr R shape(T t);
45 template <
class R = std::
size_t,
class T, std::
size_t N>
46 xt::static_shape<R, N> shape(
const T (&aList)[N]);
51 template <layout_type L,
class S>
69 struct sequence_builder;
71 template <std::size_t... I>
72 struct sequence_builder<xt::fixed_shape<I...>>
74 using sequence_type = xt::fixed_shape<I...>;
75 using value_type =
typename sequence_type::value_type;
77 inline static sequence_type make(std::size_t )
79 return sequence_type{};
82 inline static sequence_type make(std::size_t , value_type )
84 return sequence_type{};
108 template <
class S1,
class S2>
111 return s1.size() == s2.size() && std::equal(s1.begin(), s1.end(), s2.begin());
126 template <
class E,
class S>
127 inline bool has_shape(
const E& e, std::initializer_list<S> shape)
noexcept
129 return e.shape().size() == shape.size()
130 && std::equal(e.shape().cbegin(), e.shape().cend(), shape.begin());
141 template <class E, class S, class = typename std::enable_if_t<has_iterator_interface<S>::value>>
144 return e.shape().size() == shape.size()
145 && std::equal(e.shape().cbegin(), e.shape().cend(), shape.begin());
155 struct initializer_depth_impl
157 static constexpr std::size_t value = 0;
161 struct initializer_depth_impl<std::initializer_list<T>>
163 static constexpr std::size_t value = 1 + initializer_depth_impl<T>::value;
170 static constexpr std::size_t value = detail::initializer_depth_impl<U>::value;
179 template <std::
size_t I>
180 struct initializer_shape_impl
183 static constexpr std::size_t value(T t)
185 return t.size() == 0 ? 0 : initializer_shape_impl<I - 1>::value(*t.begin());
190 struct initializer_shape_impl<0>
193 static constexpr std::size_t value(T t)
199 template <
class R,
class U, std::size_t... I>
200 constexpr R initializer_shape(U t, std::index_sequence<I...>)
202 using size_type =
typename R::value_type;
203 return {size_type(initializer_shape_impl<I>::value(t))...};
207 template <
class R,
class T>
208 constexpr R shape(T t)
210 return detail::initializer_shape<R, decltype(t)>(
217 template <
class R,
class T, std::
size_t N>
218 xt::static_shape<R, N> shape(
const T (&list)[N])
220 xt::static_shape<R, N> shape;
221 std::copy(std::begin(list), std::end(list), std::begin(shape));
231 template <
class T,
class E =
void>
232 struct static_dimension_impl
234 static constexpr std::ptrdiff_t value = -1;
238 struct static_dimension_impl<T, void_t<decltype(std::tuple_size<T>::value)>>
240 static constexpr std::ptrdiff_t value =
static_cast<std::ptrdiff_t
>(std::tuple_size<T>::value);
247 static constexpr std::ptrdiff_t value = detail::static_dimension_impl<S>::value;
259 template <layout_type L,
class S>
262 static constexpr std::ptrdiff_t static_dimension = xt::static_dimension<S>::value;
263 static constexpr bool is_any = static_dimension != -1 && static_dimension <= 1
274 template <
class T1,
class T2>
275 constexpr std::common_type_t<T1, T2> imax(
const T1& a,
const T2& b)
277 return a > b ? a : b;
281 template <
class... T>
282 struct max_array_size;
285 struct max_array_size<>
287 static constexpr std::size_t value = 0;
290 template <
class T,
class... Ts>
291 struct max_array_size<T, Ts...>
292 : std::integral_constant<std::size_t, imax(std::tuple_size<T>::value, max_array_size<Ts...>::value)>
297 template <std::size_t IDX, std::size_t... X>
300 static constexpr std::size_t arr[
sizeof...(X)] = {X...};
301 static constexpr std::size_t value = (IDX <
sizeof...(X)) ? arr[IDX] : 0;
304 template <
class S1,
class S2>
305 struct broadcast_fixed_shape;
307 template <
class IX,
class A,
class B>
308 struct broadcast_fixed_shape_impl;
310 template <std::
size_t IX,
class A,
class B>
311 struct broadcast_fixed_shape_cmp_impl;
313 template <std::size_t JX, std::size_t... I, std::size_t... J>
314 struct broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>
318 static constexpr std::size_t IX = JX - (
sizeof...(J) -
sizeof...(I));
321 static constexpr std::size_t I_v = at<IX, I...>::value;
322 static constexpr std::size_t J_v = at<JX, J...>::value;
325 static_assert(!I_v || I_v == 1 || J_v == 1 || J_v == I_v,
"broadcast shapes do not match.");
327 static constexpr std::size_t ordinate = (I_v > J_v) ? I_v : J_v;
328 static constexpr bool value = (I_v == J_v);
331 template <std::size_t... JX, std::size_t... I, std::size_t... J>
332 struct broadcast_fixed_shape_impl<std::index_sequence<JX...>, fixed_shape<I...>, fixed_shape<J...>>
334 static_assert(
sizeof...(J) >=
sizeof...(I),
"broadcast shapes do not match.");
336 using type = xt::fixed_shape<
337 broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>::ordinate...>;
338 static constexpr bool value = std::conjunction<
339 broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>...>::value;
348 template <std::size_t... I, std::size_t... J>
349 struct broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>>
350 : broadcast_fixed_shape_impl<std::make_index_sequence<sizeof...(J)>, fixed_shape<I...>, fixed_shape<J...>>
358 static constexpr bool value =
false;
361 template <
class T, std::
size_t N>
362 struct is_array<std::array<T, N>>
364 static constexpr bool value =
true;
368 struct is_fixed : std::false_type
372 template <std::size_t... N>
373 struct is_fixed<fixed_shape<N...>> : std::true_type
378 struct is_scalar_shape
380 static constexpr bool value =
false;
384 struct is_scalar_shape<std::array<T, 0>>
386 static constexpr bool value =
true;
389 template <
class... S>
390 using only_array = std::conjunction<std::disjunction<is_array<S>, is_fixed<S>>...>;
394 template <
class... S>
395 using only_fixed = std::integral_constant<
397 std::disjunction<is_fixed<S>...>::value
398 && std::conjunction<std::disjunction<is_fixed<S>, is_scalar_shape<S>>...>::value>;
400 template <
class... S>
401 using all_fixed = std::conjunction<is_fixed<S>...>;
407 template <
class... S>
411 array<
typename std::common_type<
typename S::value_type...>::type, max_array_size<S...>::value>;
415 struct promote_array<>
417 using type = std::array<std::size_t, 0>;
427 struct filter_scalar<std::array<T, 0>>
429 using type = fixed_shape<1>;
433 using filter_scalar_t =
typename filter_scalar<S>::type;
435 template <
class... S>
436 struct promote_fixed : promote_fixed<filter_scalar_t<S>...>
440 template <std::size_t... I>
441 struct promote_fixed<fixed_shape<I...>>
443 using type = fixed_shape<I...>;
444 static constexpr bool value =
true;
447 template <std::size_t... I, std::size_t... J,
class... S>
448 struct promote_fixed<fixed_shape<I...>, fixed_shape<J...>, S...>
452 using intermediate = std::conditional_t<
453 (
sizeof...(I) >
sizeof...(J)),
454 broadcast_fixed_shape<fixed_shape<J...>, fixed_shape<I...>>,
455 broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>>>;
456 using result = promote_fixed<
typename intermediate::type, S...>;
460 using type =
typename result::type;
461 static constexpr bool value = std::conjunction<intermediate, result>::value;
464 template <
bool all_index,
bool all_array,
class... S>
465 struct select_promote_index;
467 template <
class... S>
468 struct select_promote_index<true, true, S...> : promote_fixed<S...>
473 struct select_promote_index<true, true>
476 using type = dynamic_shape<std::size_t>;
479 template <
class... S>
480 struct select_promote_index<false, true, S...> : promote_array<S...>
484 template <
class... S>
485 struct select_promote_index<false, false, S...>
487 using type = dynamic_shape<
typename std::common_type<
typename S::value_type...>::type>;
490 template <
class... S>
491 struct promote_index : select_promote_index<only_fixed<S...>::value, only_array<S...>::value, S...>
496 struct index_from_shape_impl
501 template <std::size_t... N>
502 struct index_from_shape_impl<fixed_shape<N...>>
504 using type = std::array<std::size_t,
sizeof...(N)>;
508 template <
typename T>
511 template <
class... S>
514 using type =
typename detail::promote_index<S...>::type;
520 template <
class... S>
523 template <
class... S>
526 using type =
typename detail::promote_index<S...>::type;
532 template <
class... S>
538 using type =
typename detail::index_from_shape_impl<S>::type;
545 using index_from_shape_t =
typename index_from_shape<S>::type;
554 struct filter_fixed_shape_impl
559 template <std::size_t... N>
560 struct filter_fixed_shape_impl<fixed_shape<N...>>
562 using type = std::array<std::size_t,
sizeof...(N)>;
575 using filter_fixed_shape_t =
typename filter_fixed_shape<S>::type;
Fixed shape implementation for compile time defined arrays.
bool same_shape(const S1 &s1, const S2 &s2) noexcept
Check if two objects have the same shape.
bool has_shape(const E &e, std::initializer_list< S > shape) noexcept
Check if an object has a certain shape.
standard mathematical functions for xexpressions
Compute a layout based on a layout and a shape type.