xtensor
Loading...
Searching...
No Matches
xshape.hpp
1/***************************************************************************
2 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3 * Copyright (c) QuantStack *
4 * *
5 * Distributed under the terms of the BSD 3-Clause License. *
6 * *
7 * The full license is in the file LICENSE, distributed with this software. *
8 ****************************************************************************/
9
10#ifndef XTENSOR_XSHAPE_HPP
11#define XTENSOR_XSHAPE_HPP
12
13#include <algorithm>
14#include <cassert>
15#include <cstddef>
16#include <cstdlib>
17#include <cstring>
18#include <initializer_list>
19#include <iterator>
20#include <memory>
21
22#include "xlayout.hpp"
23#include "xstorage.hpp"
24#include "xtensor_forward.hpp"
25
26namespace xt
27{
28 template <class T>
29 using dynamic_shape = svector<T, 4>;
30
31 template <class T, std::size_t N>
32 using static_shape = std::array<T, N>;
33
34 template <std::size_t... X>
35 class fixed_shape;
36
37 using xindex = dynamic_shape<std::size_t>;
38
39 template <class S1, class S2>
40 bool same_shape(const S1& s1, const S2& s2) noexcept;
41
42 template <class U>
43 struct initializer_dimension;
44
45 template <class R, class T>
46 constexpr R shape(T t);
47
48 template <class R = std::size_t, class T, std::size_t N>
49 xt::static_shape<R, N> shape(const T (&aList)[N]);
50
51 template <class S>
52 struct static_dimension;
53
54 template <layout_type L, class S>
55 struct select_layout;
56
57 template <class... S>
58 struct promote_shape;
59
60 template <class... S>
61 struct promote_strides;
62
63 template <class S>
64 struct index_from_shape;
65}
66
67namespace xtl
68{
69 namespace detail
70 {
71 template <class S>
72 struct sequence_builder;
73
74 template <std::size_t... I>
75 struct sequence_builder<xt::fixed_shape<I...>>
76 {
77 using sequence_type = xt::fixed_shape<I...>;
78 using value_type = typename sequence_type::value_type;
79
80 inline static sequence_type make(std::size_t /*size*/)
81 {
82 return sequence_type{};
83 }
84
85 inline static sequence_type make(std::size_t /*size*/, value_type /*v*/)
86 {
87 return sequence_type{};
88 }
89 };
90 }
91}
92
93namespace xt
94{
99 /**************
100 * same_shape *
101 **************/
102
/**
 * Check if two objects have the same shape.
 *
 * @param s1 first shape (any container with size(), begin(), end())
 * @param s2 second shape (any container with size() and begin())
 * @return true when both shapes have the same length and equal elements
 */
template <class S1, class S2>
inline bool same_shape(const S1& s1, const S2& s2) noexcept
{
    // Shapes of different rank can never match.
    if (s1.size() != s2.size())
    {
        return false;
    }
    return std::equal(s1.begin(), s1.end(), s2.begin());
}
116
117 /*************
118 * has_shape *
119 *************/
120
/**
 * Check if an object has a certain shape.
 *
 * @param e an object exposing a shape() member
 * @param shape the expected shape, as a braced list
 * @return true when e.shape() has the same length and elements as shape
 */
template <class E, class S>
inline bool has_shape(const E& e, std::initializer_list<S> shape) noexcept
{
    // Query the shape once: building a range from cbegin()/cend() of two
    // separate shape() calls is undefined behaviour when shape() returns by value.
    // Binding to a const reference also extends such a temporary's lifetime.
    const auto& expr_shape = e.shape();
    return expr_shape.size() == shape.size()
           && std::equal(expr_shape.cbegin(), expr_shape.cend(), shape.begin());
}
135
144 template <class E, class S, class = typename std::enable_if_t<has_iterator_interface<S>::value>>
145 inline bool has_shape(const E& e, const S& shape)
146 {
147 return e.shape().size() == shape.size()
148 && std::equal(e.shape().cbegin(), e.shape().cend(), shape.begin());
149 }
150
151 /*************************
152 * initializer_dimension *
153 *************************/
154
namespace detail
{
    // Nesting depth of std::initializer_list in U: 0 for any non-list type.
    template <class U>
    struct initializer_depth_impl
    {
        static constexpr std::size_t value = 0;
    };

    // Each initializer_list layer adds one to the depth of its element type.
    template <class T>
    struct initializer_depth_impl<std::initializer_list<T>>
    {
        static constexpr std::size_t value = 1 + initializer_depth_impl<T>::value;
    };
}

/**
 * Number of dimensions described by a (nested) initializer list type U,
 * i.e. how many std::initializer_list layers wrap the element type.
 */
template <class U>
struct initializer_dimension
{
    static constexpr std::size_t value = detail::initializer_depth_impl<U>::value;
};
175
176 /*********************
177 * initializer_shape *
178 *********************/
179
namespace detail
{
    // initializer_shape_impl<I>::value(list) gives the extent of a nested
    // initializer list along axis I, descending through the first element of
    // each level; an empty level yields 0.
    template <std::size_t I>
    struct initializer_shape_impl
    {
        template <class T>
        static constexpr std::size_t value(T list)
        {
            // An empty list has no first element to descend into.
            if (list.size() == 0)
            {
                return 0;
            }
            return initializer_shape_impl<I - 1>::value(*list.begin());
        }
    };

    // Axis 0: the length of the outermost list.
    template <>
    struct initializer_shape_impl<0>
    {
        template <class T>
        static constexpr std::size_t value(T list)
        {
            return list.size();
        }
    };

    // Builds a shape container R holding, for each axis in I..., the extent of t.
    template <class R, class U, std::size_t... I>
    constexpr R initializer_shape(U t, std::index_sequence<I...>)
    {
        using extent_type = typename R::value_type;
        return {static_cast<extent_type>(initializer_shape_impl<I>::value(t))...};
    }
}
209
210 template <class R, class T>
211 constexpr R shape(T t)
212 {
213 return detail::initializer_shape<R, decltype(t)>(
214 t,
215 std::make_index_sequence<initializer_dimension<decltype(t)>::value>()
216 );
217 }
218
220 template <class R, class T, std::size_t N>
221 xt::static_shape<R, N> shape(const T (&list)[N])
222 {
224 std::copy(std::begin(list), std::end(list), std::begin(shape));
225 return shape;
226 }
227
228 /********************
229 * static_dimension *
230 ********************/
231
232 namespace detail
233 {
234 template <class T, class E = void>
235 struct static_dimension_impl
236 {
237 static constexpr std::ptrdiff_t value = -1;
238 };
239
240 template <class T>
241 struct static_dimension_impl<T, void_t<decltype(std::tuple_size<T>::value)>>
242 {
243 static constexpr std::ptrdiff_t value = static_cast<std::ptrdiff_t>(std::tuple_size<T>::value);
244 };
245 }
246
247 template <class S>
249 {
250 static constexpr std::ptrdiff_t value = detail::static_dimension_impl<S>::value;
251 };
252
262 template <layout_type L, class S>
264 {
265 static constexpr std::ptrdiff_t static_dimension = xt::static_dimension<S>::value;
266 static constexpr bool is_any = static_dimension != -1 && static_dimension <= 1
268 static constexpr layout_type value = is_any ? layout_type::any : L;
269 };
270
271 /*************************************
272 * promote_shape and promote_strides *
273 *************************************/
274
275 namespace detail
276 {
// Larger of a and b, converted to their common type (b is returned on ties).
template <class T1, class T2>
constexpr std::common_type_t<T1, T2> imax(const T1& a, const T2& b)
{
    if (a > b)
    {
        return a;
    }
    return b;
}

// Variadic meta-function: largest std::tuple_size among the arguments (0 if none).
template <class... T>
struct max_array_size;

template <>
struct max_array_size<> : std::integral_constant<std::size_t, 0>
{
};

template <class T, class... Ts>
struct max_array_size<T, Ts...>
    : std::integral_constant<std::size_t, imax(std::tuple_size<T>::value, max_array_size<Ts...>::value)>
{
};
298
// Broadcasting for fixed shapes
// at<IDX, X...>::value is the IDX-th element of the pack X..., or 0 when IDX
// is out of range (including when the pack is empty).
template <std::size_t IDX, std::size_t... X>
struct at
{
    // A trailing sentinel keeps the array non-empty for an empty pack
    // (e.g. fixed_shape<>): ISO C++ forbids zero-length arrays.
    static constexpr std::size_t arr[sizeof...(X) + 1] = {X..., 0};
    static constexpr std::size_t value = (IDX < sizeof...(X)) ? arr[IDX] : 0;
};
306
307 template <class S1, class S2>
308 struct broadcast_fixed_shape;
309
310 template <class IX, class A, class B>
311 struct broadcast_fixed_shape_impl;
312
313 template <std::size_t IX, class A, class B>
314 struct broadcast_fixed_shape_cmp_impl;
315
316 template <std::size_t JX, std::size_t... I, std::size_t... J>
317 struct broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>
318 {
319 // We line the shapes up from the last index
320 // IX may underflow, thus being a very large number
321 static constexpr std::size_t IX = JX - (sizeof...(J) - sizeof...(I));
322
323 // Out of bounds access gives value 0
324 static constexpr std::size_t I_v = at<IX, I...>::value;
325 static constexpr std::size_t J_v = at<JX, J...>::value;
326
327 // we're statically checking if the broadcast shapes are either one on either of them or equal
328 static_assert(!I_v || I_v == 1 || J_v == 1 || J_v == I_v, "broadcast shapes do not match.");
329
330 static constexpr std::size_t ordinate = (I_v > J_v) ? I_v : J_v;
331 static constexpr bool value = (I_v == J_v);
332 };
333
334 template <std::size_t... JX, std::size_t... I, std::size_t... J>
335 struct broadcast_fixed_shape_impl<std::index_sequence<JX...>, fixed_shape<I...>, fixed_shape<J...>>
336 {
337 static_assert(sizeof...(J) >= sizeof...(I), "broadcast shapes do not match.");
338
339 using type = xt::fixed_shape<
340 broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>::ordinate...>;
341 static constexpr bool value = xtl::conjunction<
342 broadcast_fixed_shape_cmp_impl<JX, fixed_shape<I...>, fixed_shape<J...>>...>::value;
343 };
344
345 /* broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>>
346 * Just like a call to broadcast_shape(cont S1& input, S2& output),
347 * except that the result shape is alised as type, and the returned
348 * bool is the member value. Asserts on an illegal broadcast, including
349 * the case where pack I is strictly longer than pack J. */
350
351 template <std::size_t... I, std::size_t... J>
352 struct broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>>
353 : broadcast_fixed_shape_impl<std::make_index_sequence<sizeof...(J)>, fixed_shape<I...>, fixed_shape<J...>>
354 {
355 };
356
357 // Simple is_array and only_array meta-functions
358 template <class S>
359 struct is_array
360 {
361 static constexpr bool value = false;
362 };
363
364 template <class T, std::size_t N>
365 struct is_array<std::array<T, N>>
366 {
367 static constexpr bool value = true;
368 };
369
370 template <class S>
371 struct is_fixed : std::false_type
372 {
373 };
374
375 template <std::size_t... N>
376 struct is_fixed<fixed_shape<N...>> : std::true_type
377 {
378 };
379
380 template <class S>
381 struct is_scalar_shape
382 {
383 static constexpr bool value = false;
384 };
385
386 template <class T>
387 struct is_scalar_shape<std::array<T, 0>>
388 {
389 static constexpr bool value = true;
390 };
391
392 template <class... S>
393 using only_array = xtl::conjunction<xtl::disjunction<is_array<S>, is_fixed<S>>...>;
394
395 // test that at least one argument is a fixed shape. If yes, then either argument has to be fixed or
396 // scalar
397 template <class... S>
398 using only_fixed = std::integral_constant<
399 bool,
400 xtl::disjunction<is_fixed<S>...>::value
401 && xtl::conjunction<xtl::disjunction<is_fixed<S>, is_scalar_shape<S>>...>::value>;
402
403 template <class... S>
404 using all_fixed = xtl::conjunction<is_fixed<S>...>;
405
406 // The promote_index meta-function returns std::vector<promoted_value_type> in the
407 // general case and an array of the promoted value type and maximal size if all
408 // arguments are of type std::array
409
410 template <class... S>
411 struct promote_array
412 {
413 using type = std::
414 array<typename std::common_type<typename S::value_type...>::type, max_array_size<S...>::value>;
415 };
416
417 template <>
418 struct promote_array<>
419 {
420 using type = std::array<std::size_t, 0>;
421 };
422
423 template <class S>
424 struct filter_scalar
425 {
426 using type = S;
427 };
428
429 template <class T>
430 struct filter_scalar<std::array<T, 0>>
431 {
432 using type = fixed_shape<1>;
433 };
434
435 template <class S>
436 using filter_scalar_t = typename filter_scalar<S>::type;
437
438 template <class... S>
439 struct promote_fixed : promote_fixed<filter_scalar_t<S>...>
440 {
441 };
442
443 template <std::size_t... I>
444 struct promote_fixed<fixed_shape<I...>>
445 {
446 using type = fixed_shape<I...>;
447 static constexpr bool value = true;
448 };
449
450 template <std::size_t... I, std::size_t... J, class... S>
451 struct promote_fixed<fixed_shape<I...>, fixed_shape<J...>, S...>
452 {
453 private:
454
455 using intermediate = std::conditional_t<
456 (sizeof...(I) > sizeof...(J)),
457 broadcast_fixed_shape<fixed_shape<J...>, fixed_shape<I...>>,
458 broadcast_fixed_shape<fixed_shape<I...>, fixed_shape<J...>>>;
459 using result = promote_fixed<typename intermediate::type, S...>;
460
461 public:
462
463 using type = typename result::type;
464 static constexpr bool value = xtl::conjunction<intermediate, result>::value;
465 };
466
467 template <bool all_index, bool all_array, class... S>
468 struct select_promote_index;
469
470 template <class... S>
471 struct select_promote_index<true, true, S...> : promote_fixed<S...>
472 {
473 };
474
475 template <>
476 struct select_promote_index<true, true>
477 {
478 // todo correct? used in xvectorize
479 using type = dynamic_shape<std::size_t>;
480 };
481
482 template <class... S>
483 struct select_promote_index<false, true, S...> : promote_array<S...>
484 {
485 };
486
487 template <class... S>
488 struct select_promote_index<false, false, S...>
489 {
490 using type = dynamic_shape<typename std::common_type<typename S::value_type...>::type>;
491 };
492
493 template <class... S>
494 struct promote_index : select_promote_index<only_fixed<S...>::value, only_array<S...>::value, S...>
495 {
496 };
497
498 template <class T>
499 struct index_from_shape_impl
500 {
501 using type = T;
502 };
503
504 template <std::size_t... N>
505 struct index_from_shape_impl<fixed_shape<N...>>
506 {
507 using type = std::array<std::size_t, sizeof...(N)>;
508 };
509 }
510
511 template <class... S>
513 {
514 using type = typename detail::promote_index<S...>::type;
515 };
516
520 template <class... S>
521 using promote_shape_t = typename promote_shape<S...>::type;
522
523 template <class... S>
525 {
526 using type = typename detail::promote_index<S...>::type;
527 };
528
532 template <class... S>
533 using promote_strides_t = typename promote_strides<S...>::type;
534
535 template <class S>
537 {
538 using type = typename detail::index_from_shape_impl<S>::type;
539 };
540
544 template <class S>
545 using index_from_shape_t = typename index_from_shape<S>::type;
546
547 /**********************
548 * filter_fixed_shape *
549 **********************/
550
551 namespace detail
552 {
553 template <class S>
554 struct filter_fixed_shape_impl
555 {
556 using type = S;
557 };
558
559 template <std::size_t... N>
560 struct filter_fixed_shape_impl<fixed_shape<N...>>
561 {
562 using type = std::array<std::size_t, sizeof...(N)>;
563 };
564 }
565
566 template <class S>
567 struct filter_fixed_shape : detail::filter_fixed_shape_impl<S>
568 {
569 };
570
574 template <class S>
575 using filter_fixed_shape_t = typename filter_fixed_shape<S>::type;
576}
577
578#endif
Fixed shape implementation for compile time defined arrays.
bool same_shape(const S1 &s1, const S2 &s2) noexcept
Check if two objects have the same shape.
Definition xshape.hpp:112
bool has_shape(const E &e, std::initializer_list< S > shape) noexcept
Check if an object has a certain shape.
Definition xshape.hpp:130
standard mathematical functions for xexpressions
layout_type
Definition xlayout.hpp:24
Compute a layout based on a layout and a shape type.
Definition xshape.hpp:264