#ifndef XTENSOR_STRIDES_HPP
#define XTENSOR_STRIDES_HPP

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iterator>
#include <limits>
#include <numeric>
#include <vector>

#include <xtl/xsequence.hpp>

#include "../core/xshape.hpp"
#include "../core/xtensor_config.hpp"
#include "../core/xtensor_forward.hpp"
#include "../utils/xexception.hpp"
28 template <
class shape_type>
29 std::size_t compute_size(
const shape_type& shape)
noexcept;
39 template <
class offset_type,
class S>
40 offset_type data_offset(
const S&
strides)
noexcept;
66 template <
class offset_type,
class S,
class Arg,
class... Args>
67 offset_type data_offset(
const S&
strides, Arg
arg, Args... args)
noexcept;
70 offset_type unchecked_data_offset(
const S&
strides, Args... args)
noexcept;
72 template <
class offset_type,
class S,
class It>
73 offset_type element_offset(
const S&
strides, It first, It last)
noexcept;
88 template <layout_type L = layout_type::dynamic,
class shape_type,
class str
ides_type>
91 template <layout_type L = layout_type::dynamic,
class shape_type,
class str
ides_type,
class backstr
ides_type>
95 template <
class shape_type,
class str
ides_type>
96 void adapt_strides(
const shape_type& shape, strides_type&
strides)
noexcept;
98 template <
class shape_type,
class str
ides_type,
class backstr
ides_type>
99 void adapt_strides(
const shape_type& shape, strides_type&
strides, backstrides_type& backstrides)
noexcept;
112 template <
class S,
class T>
113 std::vector<get_strides_t<S>>
120 template <
class S,
class size_type>
121 S uninitialized_shape(size_type size);
123 template <
class S1,
class S2>
124 bool broadcast_shape(
const S1& input, S2& output);
126 template <
class S1,
class S2>
127 bool broadcastable(
const S1& s1, S2& s2);
133 template <layout_type L>
148 template <
class S,
class... Args>
163 template <
class S,
class... Args>
170 template <
class C,
class It,
class size_type>
171 It strided_data_end(
const C& c, It begin,
layout_type l, size_type offset)
173 using difference_type =
typename std::iterator_traits<It>::difference_type;
174 if (c.dimension() == 0)
180 for (std::size_t i = 0; i != c.dimension(); ++i)
182 begin += c.strides()[i] * difference_type(c.shape()[i] - 1);
186 begin += c.strides().back();
192 begin += c.strides().front();
205 template <
class return_type,
class S,
class T,
class D>
206 inline return_type compute_stride_impl(
layout_type layout,
const S& shape, T axis, D default_stride)
210 return std::accumulate(
211 shape.cbegin() + axis + 1,
213 static_cast<return_type
>(1),
214 std::multiplies<return_type>()
219 return std::accumulate(
221 shape.cbegin() + axis,
222 static_cast<return_type
>(1),
223 std::multiplies<return_type>()
226 return default_stride;
252 using strides_type =
typename E::strides_type;
253 using return_type =
typename strides_type::value_type;
254 strides_type ret = e.strides();
255 auto shape = e.shape();
262 for (std::size_t i = 0; i < ret.size(); ++i)
266 ret[i] = detail::compute_stride_impl<return_type>(e.layout(), shape, i, ret[i]);
272 return_type f =
static_cast<return_type
>(
sizeof(
typename E::value_type));
298 using strides_type =
typename E::strides_type;
299 using return_type =
typename strides_type::value_type;
301 return_type ret = e.strides()[axis];
310 if (e.shape(axis) == 1)
312 ret = detail::compute_stride_impl<return_type>(e.layout(), e.shape(), axis, ret);
318 return_type f =
static_cast<return_type
>(
sizeof(
typename E::value_type));
331 template <
class shape_type>
332 inline std::size_t compute_size_impl(
const shape_type& shape, std::true_type )
334 using size_type = std::decay_t<typename shape_type::value_type>;
335 return static_cast<std::size_t
>(std::abs(
336 std::accumulate(shape.cbegin(), shape.cend(), size_type(1), std::multiplies<size_type>())
340 template <
class shape_type>
341 inline std::size_t compute_size_impl(
const shape_type& shape, std::false_type )
343 using size_type = std::decay_t<typename shape_type::value_type>;
344 return static_cast<std::size_t
>(
345 std::accumulate(shape.cbegin(), shape.cend(), size_type(1), std::multiplies<size_type>())
350 template <
class shape_type>
351 inline std::size_t compute_size(
const shape_type& shape)
noexcept
353 return detail::compute_size_impl(
355 xtl::is_signed<std::decay_t<
typename std::decay_t<shape_type>::value_type>>()
362 template <std::
size_t dim,
class S>
363 inline auto raw_data_offset(
const S&)
noexcept
365 using strides_value_type = std::decay_t<decltype(std::declval<S>()[0])>;
366 return strides_value_type(0);
369 template <std::
size_t dim,
class S>
370 inline auto raw_data_offset(
const S&, missing_type)
noexcept
372 using strides_value_type = std::decay_t<decltype(std::declval<S>()[0])>;
373 return strides_value_type(0);
376 template <std::size_t dim,
class S,
class Arg,
class... Args>
377 inline auto raw_data_offset(
const S&
strides, Arg
arg, Args... args)
noexcept
379 return static_cast<std::ptrdiff_t
>(
arg) *
strides[dim] + raw_data_offset<dim + 1>(
strides, args...);
382 template <layout_type L, std::ptrdiff_t static_dim>
383 struct layout_data_offset
385 template <std::size_t dim,
class S,
class Arg,
class... Args>
386 inline static auto run(
const S&
strides, Arg
arg, Args... args)
noexcept
388 return raw_data_offset<dim>(
strides,
arg, args...);
392 template <std::ptrdiff_t static_dim>
395 using self_type = layout_data_offset<layout_type::row_major, static_dim>;
397 template <std::
size_t dim,
class S,
class Arg>
398 inline static auto run(
const S&
strides, Arg
arg)
noexcept
400 if (std::ptrdiff_t(dim) + 1 == static_dim)
410 template <std::size_t dim,
class S,
class Arg,
class... Args>
411 inline static auto run(
const S&
strides, Arg
arg, Args... args)
noexcept
417 template <std::ptrdiff_t static_dim>
420 using self_type = layout_data_offset<layout_type::column_major, static_dim>;
422 template <std::
size_t dim,
class S,
class Arg>
423 inline static auto run(
const S&
strides, Arg
arg)
noexcept
435 template <std::size_t dim,
class S,
class Arg,
class... Args>
436 inline static auto run(
const S&
strides, Arg
arg, Args... args)
noexcept
440 return arg + self_type::template run<dim + 1>(
strides, args...);
450 template <
class offset_type,
class S>
451 inline offset_type data_offset(
const S&)
noexcept
453 return offset_type(0);
456 template <
class offset_type,
class S,
class Arg,
class... Args>
457 inline offset_type data_offset(
const S&
strides, Arg
arg, Args... args)
noexcept
459 constexpr std::size_t nargs =
sizeof...(Args) + 1;
463 return static_cast<offset_type
>(detail::raw_data_offset<0>(
strides,
arg, args...));
465 else if (nargs >
strides.size())
468 return data_offset<offset_type, S>(
strides, args...);
470 else if (detail::last_type_is_missing<Args...>)
473 return static_cast<offset_type
>(detail::raw_data_offset<0>(
strides,
arg, args...));
479 return static_cast<offset_type
>(detail::raw_data_offset<0>(
view,
arg, args...));
483 template <
class offset_type,
layout_type L,
class S,
class... Args>
484 inline offset_type unchecked_data_offset(
const S&
strides, Args... args)
noexcept
486 return static_cast<offset_type
>(
487 detail::layout_data_offset<L, static_dimension<S>::value>::template run<0>(
strides.cbegin(), args...)
491 template <
class offset_type,
class S,
class It>
492 inline offset_type element_offset(
const S&
strides, It first, It last)
noexcept
494 using difference_type =
typename std::iterator_traits<It>::difference_type;
495 auto size =
static_cast<difference_type
>(
496 (std::min)(
static_cast<typename S::size_type
>(std::distance(first, last)),
strides.size())
498 return std::inner_product(last - size, last,
strides.cend() - size, offset_type(0));
503 template <
class shape_type,
class str
ides_type,
class bs_ptr>
504 inline void adapt_strides(
505 const shape_type& shape,
508 typename strides_type::size_type i
515 (*backstrides)[i] =
strides[i] * std::ptrdiff_t(shape[i] - 1);
518 template <
class shape_type,
class str
ides_type>
519 inline void adapt_strides(
520 const shape_type& shape,
523 typename strides_type::size_type i
532 template <layout_type L,
class shape_type,
class str
ides_type,
class bs_ptr>
534 compute_strides(
const shape_type& shape,
layout_type l, strides_type&
strides, bs_ptr bs)
536 using strides_value_type =
typename std::decay_t<strides_type>::value_type;
537 strides_value_type data_size = 1;
539#if defined(_MSC_VER) && (1931 <= _MSC_VER)
541 if (0 == shape.size())
543 return static_cast<std::size_t
>(data_size);
549 for (std::size_t i = shape.size(); i != 0; --i)
552 data_size =
strides[i - 1] *
static_cast<strides_value_type
>(shape[i - 1]);
553 adapt_strides(shape,
strides, bs, i - 1);
558 for (std::size_t i = 0; i < shape.size(); ++i)
561 data_size =
strides[i] *
static_cast<strides_value_type
>(shape[i]);
562 adapt_strides(shape,
strides, bs, i);
565 return static_cast<std::size_t
>(data_size);
569 template <layout_type L,
class shape_type,
class str
ides_type>
572 return detail::compute_strides<L>(shape, l,
strides,
nullptr);
575 template <layout_type L,
class shape_type,
class str
ides_type,
class backstr
ides_type>
579 return detail::compute_strides<L>(shape, l,
strides, &backstrides);
582 template <
class T1,
class T2>
584 stride_match_condition(
const T1& stride,
const T2& shape,
const T1& data_size,
bool zero_strides)
586 return (shape == T2(1) && stride == T1(0) && zero_strides) || (stride == data_size);
590 template <
class shape_type,
class str
ides_type>
592 do_strides_match(
const shape_type& shape,
const strides_type&
strides,
layout_type l,
bool zero_strides)
594 using value_type =
typename strides_type::value_type;
595 value_type data_size = 1;
598 for (std::size_t i =
strides.size(); i != 0; --i)
600 if (!stride_match_condition(
strides[i - 1], shape[i - 1], data_size, zero_strides))
604 data_size *=
static_cast<value_type
>(shape[i - 1]);
610 for (std::size_t i = 0; i <
strides.size(); ++i)
612 if (!stride_match_condition(
strides[i], shape[i], data_size, zero_strides))
616 data_size *=
static_cast<value_type
>(shape[i]);
626 template <
class shape_type,
class str
ides_type>
627 inline void adapt_strides(
const shape_type& shape, strides_type&
strides)
noexcept
629 for (
typename shape_type::size_type i = 0; i < shape.size(); ++i)
631 detail::adapt_strides(shape,
strides,
nullptr, i);
635 template <
class shape_type,
class str
ides_type,
class backstr
ides_type>
637 adapt_strides(
const shape_type& shape, strides_type&
strides, backstrides_type& backstrides)
noexcept
639 for (
typename shape_type::size_type i = 0; i < shape.size(); ++i)
641 detail::adapt_strides(shape,
strides, &backstrides, i);
648 inline S unravel_noexcept(
typename S::value_type idx,
const S&
strides,
layout_type l)
noexcept
650 using value_type =
typename S::value_type;
651 using size_type =
typename S::size_type;
652 S result = xtl::make_sequence<S>(
strides.size(), 0);
655 for (size_type i = 0; i <
strides.size(); ++i)
658 value_type quot = str != 0 ? idx / str : 0;
659 idx = str != 0 ? idx % str : idx;
665 for (size_type i =
strides.size(); i != 0; --i)
667 value_type str =
strides[i - 1];
668 value_type quot = str != 0 ? idx / str : 0;
669 idx = str != 0 ? idx % str : idx;
670 result[i - 1] = quot;
678 inline S unravel_from_strides(
typename S::value_type index,
const S&
strides,
layout_type l)
682 XTENSOR_THROW(std::runtime_error,
"unravel_index: dynamic layout not supported");
684 return detail::unravel_noexcept(index,
strides, l);
687 template <
class S,
class T>
688 inline get_value_type_t<T> ravel_from_strides(
const T& index,
const S&
strides)
690 return element_offset<get_value_type_t<T>>(
strides, index.begin(), index.end());
694 inline get_strides_t<S> unravel_index(
typename S::value_type index,
const S& shape,
layout_type l)
696 using strides_type = get_strides_t<S>;
697 using strides_value_type =
typename strides_type::value_type;
698 strides_type
strides = xtl::make_sequence<strides_type>(shape.size(), 0);
700 return unravel_from_strides(
static_cast<strides_value_type
>(index),
strides, l);
703 template <
class S,
class T>
704 inline std::vector<get_strides_t<S>> unravel_indices(
const T& idx,
const S& shape,
layout_type l)
706 using strides_type = get_strides_t<S>;
707 using strides_value_type =
typename strides_type::value_type;
708 strides_type
strides = xtl::make_sequence<strides_type>(shape.size(), 0);
710 std::vector<get_strides_t<S>> out(idx.size());
711 auto out_iter = out.begin();
712 auto idx_iter = idx.begin();
713 for (; out_iter != out.end(); ++out_iter, ++idx_iter)
715 *out_iter = unravel_from_strides(
static_cast<strides_value_type
>(*idx_iter),
strides, l);
720 template <
class S,
class T>
721 inline get_value_type_t<T> ravel_index(
const T& index,
const S& shape,
layout_type l)
723 using strides_type = get_strides_t<S>;
724 strides_type
strides = xtl::make_sequence<strides_type>(shape.size(), 0);
726 return ravel_from_strides(index,
strides);
729 template <
class S,
class stype>
730 inline S uninitialized_shape(stype size)
732 using value_type =
typename S::value_type;
733 using size_type =
typename S::size_type;
734 return xtl::make_sequence<S>(
static_cast<size_type
>(size), std::numeric_limits<value_type>::max());
737 template <
class S1,
class S2>
738 inline bool broadcast_shape(
const S1& input, S2& output)
740 bool trivial_broadcast = (input.size() == output.size());
742 using value_type =
typename S2::value_type;
743 auto output_index = output.size();
744 auto input_index = input.size();
746 if (output_index < input_index)
748 throw_broadcast_error(output, input);
750 for (; input_index != 0; --input_index, --output_index)
755 if (output[output_index - 1] == std::numeric_limits<value_type>::max())
757 output[output_index - 1] =
static_cast<value_type
>(input[input_index - 1]);
761 else if (output[output_index - 1] == 1)
763 output[output_index - 1] =
static_cast<value_type
>(input[input_index - 1]);
764 trivial_broadcast = trivial_broadcast && (input[input_index - 1] == 1);
768 else if (input[input_index - 1] == 1)
770 trivial_broadcast =
false;
774 else if (
static_cast<value_type
>(input[input_index - 1]) != output[output_index - 1])
776 throw_broadcast_error(output, input);
779 return trivial_broadcast;
782 template <
class S1,
class S2>
783 inline bool broadcastable(
const S1& src_shape,
const S2& dst_shape)
785 auto src_iter = src_shape.crbegin();
786 auto dst_iter = dst_shape.crbegin();
787 bool res = dst_shape.size() >= src_shape.size();
788 for (; src_iter != src_shape.crend() && res; ++src_iter, ++dst_iter)
790 res = (
static_cast<std::size_t
>(*src_iter) ==
static_cast<std::size_t
>(*dst_iter))
799 template <
class S1,
class S2>
800 static std::size_t get(
const S1& s1,
const S2& s2)
802 using value_type =
typename S1::value_type;
804 auto s1_index = s1.size();
805 auto s2_index = s2.size();
807 for (; s2_index != 0; --s1_index, --s2_index)
809 if (
static_cast<value_type
>(s1[s1_index - 1]) !=
static_cast<value_type
>(s2[s2_index - 1]))
821 template <
class S1,
class S2>
822 static std::size_t get(
const S1& s1,
const S2& s2)
825 using size_type =
typename S1::size_type;
826 using value_type =
typename S1::value_type;
831 if (s1.size() != s2.size())
836 auto size = s2.size();
838 for (; index < size; ++index)
840 if (
static_cast<value_type
>(s1[index]) !=
static_cast<value_type
>(s2[index]))
851 template <
class S, std::
size_t dim>
852 inline bool check_in_bounds_impl(
const S&)
857 template <
class S, std::
size_t dim>
858 inline bool check_in_bounds_impl(
const S&, missing_type)
863 template <
class S, std::size_t dim,
class T,
class... Args>
864 inline bool check_in_bounds_impl(
const S& shape, T&
arg, Args&... args)
866 if (
sizeof...(Args) + 1 > shape.size())
868 return check_in_bounds_impl<S, dim>(shape, args...);
873 && check_in_bounds_impl<S, dim + 1>(shape, args...);
878 template <
class S,
class... Args>
879 inline bool check_in_bounds(
const S& shape, Args&... args)
881 return detail::check_in_bounds_impl<S, 0>(shape, args...);
886 template <
class S, std::
size_t dim>
887 inline void normalize_periodic_impl(
const S&)
891 template <
class S, std::
size_t dim>
892 inline void normalize_periodic_impl(
const S&, missing_type)
896 template <
class S, std::size_t dim,
class T,
class... Args>
897 inline void normalize_periodic_impl(
const S& shape, T&
arg, Args&... args)
899 if (
sizeof...(Args) + 1 > shape.size())
901 normalize_periodic_impl<S, dim>(shape, args...);
905 T n =
static_cast<T
>(shape[dim]);
906 arg = (n + (
arg % n)) % n;
907 normalize_periodic_impl<S, dim + 1>(shape, args...);
912 template <
class S,
class... Args>
915 check_dimension(shape, args...);
916 detail::normalize_periodic_impl<S, 0>(shape, args...);
auto arg(E &&e) noexcept
Calculates the phase angle (in radians) elementwise for the complex numbers in e.
std::size_t compute_strides(const shape_type &shape, layout_type l, strides_type &strides)
Compute the strides given the shape and the layout of an array.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
stride_type
Choose stride type.
void normalize_periodic(const S &shape, Args &... args)
Normalise an index of a periodic array.
@ bytes
Normal stride in bytes.
@ internal
As used internally (with stride(axis) == 0 if shape(axis) == 1)
@ normal
Normal stride corresponding to storage.
standard mathematical functions for xexpressions
bool in_bounds(const S &shape, Args &... args)
Check if the index is within the bounds of the array.
auto view(E &&e, S &&... slices)
Constructs and returns a view on the specified xexpression.