10#ifndef XTENSOR_PAD_HPP
11#define XTENSOR_PAD_HPP
14#include "xstrided_view.hpp"
18using namespace xt::placeholders;
// detail::check_pad_width: sanity check that xt::pad runs through
// XTENSOR_ASSERT to validate a per-axis pad-width specification against the
// shape of the expression being padded.
//
// pad_width : one entry per axis. pad() reads entry [axis][0] as the number of
//             elements added before the data along that axis, and [axis][1]
//             as the number added after it.
// shape     : shape of the expression being padded.
// Returns   : true when the specification is acceptable (pad() asserts on it).
//
// The first visible check compares the number of pad_width entries with the
// number of dimensions in `shape`; presumably a mismatch yields false.
// NOTE(review): the remainder of the body is not visible in this excerpt;
// presumably it also requires each inner vector to hold exactly two entries
// (before/after), since pad() indexes [0] and [1]. Confirm.
50 template <
class S,
class T>
51 inline bool check_pad_width(
const std::vector<std::vector<S>>& pad_width,
const T& shape)
53 if (pad_width.size() != shape.size())
// xt::pad, per-axis overload: pads expression `e` using one (before, after)
// width pair per axis.
//
// Template parameters:
//   E : expression type.
//   S : pad-width element type (defaults to E's size_type).
//   V : fill-value type (defaults to E's value_type).
// Parameters (the declaration line that names the function and `e` is not
// visible in this excerpt):
//   pad_width : pad_width[axis] = {before, after}; validated in debug builds
//               with detail::check_pad_width.
//   mode, constant_value : padding algorithm, and the fill value used when
//               mode is pad_mode::constant.
//
// Outline of the visible code:
//   1. For each axis, read before (nb) and after (ne), compute the padded
//      extent ns = nb + e.shape(axis) + ne (presumably stored into new_shape),
//      and record in `sv` the range [nb, nb + e.shape(axis)) where the
//      original data sits inside the output.
//   2. Allocate the output `out` (type temporary_type_t<E>): pre-filled with
//      constant_value for pad_mode::constant, otherwise default-constructed.
//   3. For each axis, choose a source range svs[axis] (inside the data region,
//      in output coordinates) and, for the trailing border, a target range
//      svt[axis], according to the mode.
// NOTE(review): the sv/svs/svt declarations, the copies between views, the
// leading-border target ranges and the return statement are not visible in
// this excerpt; presumably the function returns `out`.
73 template <class E, class S = typename std::decay_t<E>::size_type,
class V =
typename std::decay_t<E>::value_type>
76 const std::vector<std::vector<S>>& pad_width,
80 XTENSOR_ASSERT(detail::check_pad_width(pad_width, e.shape()));
82 using size_type =
typename std::decay_t<E>::size_type;
83 using return_type = temporary_type_t<E>;
87 auto new_shape = e.shape();
89 sv.reserve(e.shape().size());
90 for (size_type axis = 0; axis < e.shape().size(); ++axis)
92 size_type nb =
static_cast<size_type
>(pad_width[axis][0]);
93 size_type ne =
static_cast<size_type
>(pad_width[axis][1]);
94 size_type ns = nb + e.shape(axis) + ne;
// The original data occupies [nb, nb + e.shape(axis)) along this axis of the output.
96 sv.push_back(
xt::range(nb, nb + e.shape(axis)));
// Constant mode pre-fills the whole output with constant_value; the per-axis
// border logic below has branches only for the other modes.
99 if (mode == pad_mode::constant)
101 return_type out(new_shape, constant_value);
106 return_type out(new_shape);
114 for (size_type axis = 0; axis < e.shape().size(); ++axis)
116 size_type nb =
static_cast<size_type
>(pad_width[axis][0]);
117 size_type ne =
static_cast<size_type
>(pad_width[axis][1]);
// Leading border (nb elements before the data).
119 if (nb >
static_cast<size_type
>(0))
// wrap/periodic: the leading border is fed from the last nb data elements,
// i.e. output indices [e.shape(axis), nb + e.shape(axis)).
123 if (mode == pad_mode::wrap || mode == pad_mode::periodic)
125 XTENSOR_ASSERT(nb <= e.shape(axis));
126 svs[axis] =
xt::range(e.shape(axis), nb + e.shape(axis));
// symmetric: mirror the first nb data elements, edge element included,
// by walking from index 2*nb-1 down to nb.
129 else if (mode == pad_mode::symmetric)
131 XTENSOR_ASSERT(nb <= e.shape(axis));
132 svs[axis] =
xt::range(2 * nb - 1, nb - 1, -1);
// reflect: mirror without repeating the edge element, so at most
// e.shape(axis) - 1 elements are available.
// NOTE(review): if size_type is unsigned and e.shape(axis) == 0, the
// `- 1` below wraps around. Confirm empty axes are excluded upstream.
135 else if (mode == pad_mode::reflect)
137 XTENSOR_ASSERT(nb <= e.shape(axis) - 1);
141 else if (mode == pad_mode::edge)
// Trailing border: target is the last ne output elements along this axis.
151 if (ne >
static_cast<size_type
>(0))
153 svt[axis] =
xt::range(out.shape(axis) - ne, out.shape(axis));
155 if (mode == pad_mode::wrap || mode == pad_mode::periodic)
157 XTENSOR_ASSERT(ne <= e.shape(axis));
// symmetric: walk backwards from the last data element (index nb + shape - 1).
// When ne == nb + e.shape(axis), which the assert only permits if nb is 0,
// the exclusive stop would be -1, so an open-ended `_` stop is used instead.
161 else if (mode == pad_mode::symmetric)
163 XTENSOR_ASSERT(ne <= e.shape(axis));
164 if (ne == nb + e.shape(axis))
166 svs[axis] =
xt::range(nb + e.shape(axis) - 1, _, -1);
170 svs[axis] =
xt::range(nb + e.shape(axis) - 1, nb + e.shape(axis) - ne - 1, -1);
// reflect: as symmetric, but start one element earlier (skip the edge);
// the `_` stop covers the case where the computed stop would be -1.
174 else if (mode == pad_mode::reflect)
176 XTENSOR_ASSERT(ne <= e.shape(axis) - 1);
177 if (ne == nb + e.shape(axis) - 1)
179 svs[axis] =
xt::range(nb + e.shape(axis) - 2, _, -1);
183 svs[axis] =
xt::range(nb + e.shape(axis) - 2, nb + e.shape(axis) - ne - 2, -1);
// edge: source is the single element just before the trailing border,
// presumably the last data element, since out.shape(axis) = nb + shape + ne.
187 else if (mode == pad_mode::edge)
189 svs[axis] =
xt::range(out.shape(axis) - ne - 1, out.shape(axis) - ne);
215 template <class E, class S = typename std::decay_t<E>::size_type,
class V =
typename std::decay_t<E>::value_type>
217 pad(E&& e,
const std::vector<S>& pad_width,
pad_mode mode = pad_mode::constant, V constant_value = 0)
219 std::vector<std::vector<S>> pw(e.shape().size(), pad_width);
221 return pad(e, pw, mode, constant_value);
234 template <class E, class S = typename std::decay_t<E>::size_type,
class V =
typename std::decay_t<E>::value_type>
235 inline auto pad(E&& e, S pad_width,
pad_mode mode = pad_mode::constant, V constant_value = 0)
237 std::vector<std::vector<S>> pw(e.shape().size(), {pad_width, pad_width});
239 return pad(e, pw, mode, constant_value);
// tile helper: appears to be the detail::tile implementation that the public
// xt::tile overloads forward to.
//
// e    : expression to tile.
// reps : indexable container holding one repetition count per dimension of
//        `e` (checked with XTENSOR_ASSERT); counts are cast to size_type when
//        iterating.
// The output `out` has type temporary_type_t<E> and extent
// e.shape(axis) * reps[axis] along every axis.
//
// Visible strategy: after allocating `out`, for each axis and each repetition
// i >= 1, a target slab svt[axis] = [i * e.shape(axis), (i + 1) * e.shape(axis))
// is selected. Repetition 0 is skipped, presumably because it is filled
// before this loop.
// NOTE(review): the svs/svt declarations, the fill of repetition 0, the copies
// and the return statement are not visible in this excerpt.
245 template <
class E,
class S>
246 inline auto tile(E&& e,
const S& reps)
248 using size_type =
typename std::decay_t<E>::size_type;
250 using return_type = temporary_type_t<E>;
252 XTENSOR_ASSERT(e.shape().size() == reps.size());
254 using new_shape_type =
typename return_type::shape_type;
255 auto new_shape = xtl::make_sequence<new_shape_type>(e.shape().size());
// Tiled extent: each axis grows by its repetition factor.
259 for (size_type axis = 0; axis < reps.size(); ++axis)
261 new_shape[axis] = e.shape(axis) * reps[axis];
264 return_type out(new_shape);
271 for (size_type axis = 0; axis < e.shape().size(); ++axis)
273 for (size_type i = 1; i < static_cast<size_type>(reps[axis]); ++i)
// Target slab for the i-th copy along this axis.
276 svt[axis] =
xt::range(i * e.shape(axis), (i + 1) * e.shape(axis));
295 template <class E, class S = typename std::decay_t<E>::size_type>
296 inline auto tile(E&& e, std::initializer_list<S> reps)
298 return detail::tile(std::forward<E>(e), std::vector<S>{reps});
301 template <
class E,
class C, XTL_REQUIRES(xtl::negation<xtl::is_
integral<C>>)>
302 inline auto tile(E&& e,
const C& reps)
304 return detail::tile(std::forward<E>(e), reps);
314 template <class E, class S = typename std::decay_t<E>::size_type, XTL_REQUIRES(xtl::is_integral<S>)>
315 inline auto tile(E&& e, S reps)
317 std::vector<S> tw(e.shape().size(),
static_cast<S
>(1));
319 return detail::tile(std::forward<E>(e), tw);
Standard mathematical functions for xexpressions.
auto pad(E &&e, const std::vector< std::vector< S > > &pad_width, pad_mode mode=pad_mode::constant, V constant_value=0)
Pad an array.
auto broadcast(E &&e, const S &s)
Returns an xexpression broadcasting the given expression to a specified shape.
auto range(A start_val, B stop_val)
Select a range from start_val to stop_val (excluded).
auto all() noexcept
Returns a slice representing a full dimension, for use as an argument to the view function.
std::vector< xstrided_slice< std::ptrdiff_t > > xstrided_slice_vector
Vector of slices used to build an xstrided_view.
auto tile(E &&e, std::initializer_list< S > reps)
Tile an array.
pad_mode
Defines the padding algorithms that can be used with xt::pad (constant, symmetric, reflect, wrap, periodic, edge).
auto strided_view(E &&e, S &&shape, X &&stride, std::size_t offset=0, layout_type layout=L) noexcept
Construct a strided view from an xexpression, shape, strides and offset.