xtensor
 
Loading...
Searching...
No Matches
xshape.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_XSHAPE_HPP
11#define XTENSOR_XSHAPE_HPP
12
13#include <algorithm>
14#include <cstddef>
15#include <cstdlib>
16#include <initializer_list>
17#include <iterator>
18
19#include "../containers/xstorage.hpp"
20#include "../core/xlayout.hpp"
21#include "../core/xtensor_forward.hpp"
22
23namespace xt
24{
25 template <class T>
26 using dynamic_shape = svector<T, 4>;
27
28 template <class T, std::size_t N>
29 using static_shape = std::array<T, N>;
30
31 template <std::size_t... X>
32 class fixed_shape;
33
34 using xindex = dynamic_shape<std::size_t>;
35
36 template <class S1, class S2>
37 bool same_shape(const S1& s1, const S2& s2) noexcept;
38
39 template <class U>
41
42 template <class R, class T>
43 constexpr R shape(T t);
44
45 template <class R = std::size_t, class T, std::size_t N>
46 xt::static_shape<R, N> shape(const T (&aList)[N]);
47
48 template <class S>
49 struct static_dimension;
50
51 template <layout_type L, class S>
52 struct select_layout;
53
54 template <class... S>
55 struct promote_shape;
56
57 template <class... S>
58 struct promote_strides;
59
60 template <class S>
61 struct index_from_shape;
62}
63
64namespace xtl
65{
66 namespace detail
67 {
68 template <class S>
69 struct sequence_builder;
70
71 template <std::size_t... I>
72 struct sequence_builder<xt::fixed_shape<I...>>
73 {
74 using sequence_type = xt::fixed_shape<I...>;
75 using value_type = typename sequence_type::value_type;
76
77 inline static sequence_type make(std::size_t /*size*/)
78 {
79 return sequence_type{};
80 }
81
82 inline static sequence_type make(std::size_t /*size*/, value_type /*v*/)
83 {
84 return sequence_type{};
85 }
86 };
87 }
88}
89
90namespace xt
91{
95
/**************
 * same_shape *
 **************/

/**
 * Check if two objects have the same shape.
 *
 * @param s1 first shape container
 * @param s2 second shape container
 * @return true when s1 and s2 hold the same number of extents and the extents
 *         compare equal position by position
 */
template <class S1, class S2>
inline bool same_shape(const S1& s1, const S2& s2) noexcept
{
    // Lengths must be checked first: std::equal only walks the first range.
    if (s1.size() != s2.size())
    {
        return false;
    }
    return std::equal(s1.begin(), s1.end(), s2.begin());
}
113
/*************
 * has_shape *
 *************/

/**
 * Check if an object has a certain shape.
 *
 * @param e an object exposing shape()
 * @param shape the expected extents
 * @return true when e.shape() holds as many extents as shape and they compare
 *         equal position by position
 */
template <class E, class S>
inline bool has_shape(const E& e, std::initializer_list<S> shape) noexcept
{
    // Query shape() exactly once. If it returns by value, iterators taken from
    // two separate calls would belong to two different temporaries. Binding the
    // result to a const reference extends such a temporary's lifetime.
    const auto& actual = e.shape();
    return actual.size() == shape.size() && std::equal(actual.cbegin(), actual.cend(), shape.begin());
}
132
141 template <class E, class S, class = typename std::enable_if_t<has_iterator_interface<S>::value>>
142 inline bool has_shape(const E& e, const S& shape)
143 {
144 return e.shape().size() == shape.size()
145 && std::equal(e.shape().cbegin(), e.shape().cend(), shape.begin());
146 }
147
/*************************
 * initializer_dimension *
 *************************/

namespace detail
{
// Nesting depth of std::initializer_list in U: 0 for any other type.
template <class U>
struct initializer_depth_impl
{
    static constexpr std::size_t value = 0;
};

// Each std::initializer_list level adds one to the depth of its element type.
template <class T>
struct initializer_depth_impl<std::initializer_list<T>>
{
    static constexpr std::size_t value = 1 + initializer_depth_impl<T>::value;
};
}  // namespace detail

/**
 * Number of nested std::initializer_list levels in U, e.g. 2 for
 * std::initializer_list<std::initializer_list<double>> and 0 for double.
 */
template <class U>
struct initializer_dimension
{
    static constexpr std::size_t value = detail::initializer_depth_impl<U>::value;
};
172
173 /*********************
174 * initializer_shape *
175 *********************/
176
177 namespace detail
178 {
179 template <std::size_t I>
180 struct initializer_shape_impl
181 {
182 template <class T>
183 static constexpr std::size_t value(T t)
184 {
185 return t.size() == 0 ? 0 : initializer_shape_impl<I - 1>::value(*t.begin());
186 }
187 };
188
189 template <>
190 struct initializer_shape_impl<0>
191 {
192 template <class T>
193 static constexpr std::size_t value(T t)
194 {
195 return t.size();
196 }
197 };
198
199 template <class R, class U, std::size_t... I>
200 constexpr R initializer_shape(U t, std::index_sequence<I...>)
201 {
202 using size_type = typename R::value_type;
203 return {size_type(initializer_shape_impl<I>::value(t))...};
204 }
205 }
206
207 template <class R, class T>
208 constexpr R shape(T t)
209 {
210 return detail::initializer_shape<R, decltype(t)>(
211 t,
212 std::make_index_sequence<initializer_dimension<decltype(t)>::value>()
213 );
214 }
215
217 template <class R, class T, std::size_t N>
218 xt::static_shape<R, N> shape(const T (&list)[N])
219 {
220 xt::static_shape<R, N> shape;
221 std::copy(std::begin(list), std::end(list), std::begin(shape));
222 return shape;
223 }
224
225 /********************
226 * static_dimension *
227 ********************/
228
229 namespace detail
230 {
231 template <class T, class E = void>
232 struct static_dimension_impl
233 {
234 static constexpr std::ptrdiff_t value = -1;
235 };
236
237 template <class T>
238 struct static_dimension_impl<T, void_t<decltype(std::tuple_size<T>::value)>>
239 {
240 static constexpr std::ptrdiff_t value = static_cast<std::ptrdiff_t>(std::tuple_size<T>::value);
241 };
242 }
243
244 template <class S>
246 {
247 static constexpr std::ptrdiff_t value = detail::static_dimension_impl<S>::value;
248 };
249
259 template <layout_type L, class S>
261 {
262 static constexpr std::ptrdiff_t static_dimension = xt::static_dimension<S>::value;
263 static constexpr bool is_any = static_dimension != -1 && static_dimension <= 1
264 && L != layout_type::dynamic;
265 static constexpr layout_type value = is_any ? layout_type::any : L;
266 };
267
268 /*************************************
269 * promote_shape and promote_strides *
270 *************************************/
271
272 namespace detail
273 {
274 template <class T1, class T2>
275 constexpr std::common_type_t<T1, T2> imax(const T1& a, const T2& b)
276 {
277 return a > b ? a : b;
278 }
279
280 // Variadic meta-function returning the maximal size of std::arrays.
281 template <class... T>
282 struct max_array_size;
283
284 template <>
285 struct max_array_size<>
286 {
287 static constexpr std::size_t value = 0;
288 };
289
290 template <class T, class... Ts>
291 struct max_array_size<T, Ts...>
292 : std::integral_constant<std::size_t, imax(std::tuple_size<T>::value, max_array_size<Ts...>::value)>
293 {
294 };
295
296 // Broadcasting for fixed shapes
297 template <std::size_t IDX, std::size_t... X>
298 struct at
299 {
300 static constexpr std::size_t arr[sizeof...(X)] = {X...};
301 static constexpr std::size_t value = (IDX < sizeof...(X)) ? arr[IDX] : 0;
302 };
303
304 template <class S1, class S2>
305 struct broadcast_fixed_shape;
306
307 template <class IX, class A, class B>
308 struct broadcast_fixed_shape_impl;
309
310 template <std::size_t IX, class A, class B>
311 struct broadcast_fixed_shape_cmp_impl;
312
313 template <std::size_t JX, std::size_t... I, std::size_t... J>
314 struct broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>
315 {
316 // We line the shapes up from the last index
317 // IX may underflow, thus being a very large number
318 static constexpr std::size_t IX = JX - (sizeof...(J) - sizeof...(I));
319
320 // Out of bounds access gives value 0
321 static constexpr std::size_t I_v = at<IX, I...>::value;
322 static constexpr std::size_t J_v = at<JX, J...>::value;
323
324 // we're statically checking if the broadcast shapes are either one on either of them or equal
325 static_assert(!I_v || I_v == 1 || J_v == 1 || J_v == I_v, "broadcast shapes do not match.");
326
327 static constexpr std::size_t ordinate = (I_v > J_v) ? I_v : J_v;
328 static constexpr bool value = (I_v == J_v);
329 };
330
331 template <std::size_t... JX, std::size_t... I, std::size_t... J>
332 struct broadcast_fixed_shape_impl<std::index_sequence<JX...>, fixed_shape<I...>, fixed_shape<J...>>
333 {
334 static_assert(sizeof...(J) >= sizeof...(I), "broadcast shapes do not match.");
335
336 using type = xt::fixed_shape<
337 broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>::ordinate...>;
338 static constexpr bool value = std::conjunction<
339 broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>...>::value;
340 };
341
342 /* broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>>
343 * Just like a call to broadcast_shape(cont S1& input, S2& output),
344 * except that the result shape is alised as type, and the returned
345 * bool is the member value. Asserts on an illegal broadcast, including
346 * the case where pack I is strictly longer than pack J. */
347
348 template <std::size_t... I, std::size_t... J>
349 struct broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>>
350 : broadcast_fixed_shape_impl<std::make_index_sequence<sizeof...(J)>, fixed_shape<I...>, fixed_shape<J...>>
351 {
352 };
353
354 // Simple is_array and only_array meta-functions
355 template <class S>
356 struct is_array
357 {
358 static constexpr bool value = false;
359 };
360
361 template <class T, std::size_t N>
362 struct is_array<std::array<T, N>>
363 {
364 static constexpr bool value = true;
365 };
366
367 template <class S>
368 struct is_fixed : std::false_type
369 {
370 };
371
372 template <std::size_t... N>
373 struct is_fixed<fixed_shape<N...>> : std::true_type
374 {
375 };
376
377 template <class S>
378 struct is_scalar_shape
379 {
380 static constexpr bool value = false;
381 };
382
383 template <class T>
384 struct is_scalar_shape<std::array<T, 0>>
385 {
386 static constexpr bool value = true;
387 };
388
389 template <class... S>
390 using only_array = std::conjunction<std::disjunction<is_array<S>, is_fixed<S>>...>;
391
392 // test that at least one argument is a fixed shape. If yes, then either argument has to be fixed or
393 // scalar
394 template <class... S>
395 using only_fixed = std::integral_constant<
396 bool,
397 std::disjunction<is_fixed<S>...>::value
398 && std::conjunction<std::disjunction<is_fixed<S>, is_scalar_shape<S>>...>::value>;
399
400 template <class... S>
401 using all_fixed = std::conjunction<is_fixed<S>...>;
402
403 // The promote_index meta-function returns std::vector<promoted_value_type> in the
404 // general case and an array of the promoted value type and maximal size if all
405 // arguments are of type std::array
406
407 template <class... S>
408 struct promote_array
409 {
410 using type = std::
411 array<typename std::common_type<typename S::value_type...>::type, max_array_size<S...>::value>;
412 };
413
414 template <>
415 struct promote_array<>
416 {
417 using type = std::array<std::size_t, 0>;
418 };
419
420 template <class S>
421 struct filter_scalar
422 {
423 using type = S;
424 };
425
426 template <class T>
427 struct filter_scalar<std::array<T, 0>>
428 {
429 using type = fixed_shape<1>;
430 };
431
432 template <class S>
433 using filter_scalar_t = typename filter_scalar<S>::type;
434
435 template <class... S>
436 struct promote_fixed : promote_fixed<filter_scalar_t<S>...>
437 {
438 };
439
440 template <std::size_t... I>
441 struct promote_fixed<fixed_shape<I...>>
442 {
443 using type = fixed_shape<I...>;
444 static constexpr bool value = true;
445 };
446
447 template <std::size_t... I, std::size_t... J, class... S>
448 struct promote_fixed<fixed_shape<I...>, fixed_shape<J...>, S...>
449 {
450 private:
451
452 using intermediate = std::conditional_t<
453 (sizeof...(I) > sizeof...(J)),
454 broadcast_fixed_shape<fixed_shape<J...>, fixed_shape<I...>>,
455 broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>>>;
456 using result = promote_fixed<typename intermediate::type, S...>;
457
458 public:
459
460 using type = typename result::type;
461 static constexpr bool value = std::conjunction<intermediate, result>::value;
462 };
463
464 template <bool all_index, bool all_array, class... S>
465 struct select_promote_index;
466
467 template <class... S>
468 struct select_promote_index<true, true, S...> : promote_fixed<S...>
469 {
470 };
471
472 template <>
473 struct select_promote_index<true, true>
474 {
475 // todo correct? used in xvectorize
476 using type = dynamic_shape<std::size_t>;
477 };
478
479 template <class... S>
480 struct select_promote_index<false, true, S...> : promote_array<S...>
481 {
482 };
483
484 template <class... S>
485 struct select_promote_index<false, false, S...>
486 {
487 using type = dynamic_shape<typename std::common_type<typename S::value_type...>::type>;
488 };
489
490 template <class... S>
491 struct promote_index : select_promote_index<only_fixed<S...>::value, only_array<S...>::value, S...>
492 {
493 };
494
495 template <class T>
496 struct index_from_shape_impl
497 {
498 using type = T;
499 };
500
501 template <std::size_t... N>
502 struct index_from_shape_impl<fixed_shape<N...>>
503 {
504 using type = std::array<std::size_t, sizeof...(N)>;
505 };
506 }
507
508 template <typename T>
509 concept fixed_shape_container_concept = detail::is_fixed<typename std::decay_t<T>::shape_type>::value;
510
511 template <class... S>
513 {
514 using type = typename detail::promote_index<S...>::type;
515 };
516
520 template <class... S>
521 using promote_shape_t = typename promote_shape<S...>::type;
522
523 template <class... S>
525 {
526 using type = typename detail::promote_index<S...>::type;
527 };
528
532 template <class... S>
533 using promote_strides_t = typename promote_strides<S...>::type;
534
535 template <class S>
537 {
538 using type = typename detail::index_from_shape_impl<S>::type;
539 };
540
544 template <class S>
545 using index_from_shape_t = typename index_from_shape<S>::type;
546
547 /**********************
548 * filter_fixed_shape *
549 **********************/
550
551 namespace detail
552 {
553 template <class S>
554 struct filter_fixed_shape_impl
555 {
556 using type = S;
557 };
558
559 template <std::size_t... N>
560 struct filter_fixed_shape_impl<fixed_shape<N...>>
561 {
562 using type = std::array<std::size_t, sizeof...(N)>;
563 };
564 }
565
566 template <class S>
567 struct filter_fixed_shape : detail::filter_fixed_shape_impl<S>
568 {
569 };
570
574 template <class S>
575 using filter_fixed_shape_t = typename filter_fixed_shape<S>::type;
576}
577
578#endif
Fixed shape implementation for compile time defined arrays.
bool same_shape(const S1 &s1, const S2 &s2) noexcept
Check if two objects have the same shape.
Definition xshape.hpp:109
bool has_shape(const E &e, std::initializer_list< S > shape) noexcept
Check if an object has a certain shape.
Definition xshape.hpp:127
standard mathematical functions for xexpressions
layout_type
Definition xlayout.hpp:24
Compute a layout based on a layout and a shape type.
Definition xshape.hpp:261