// xpad.hpp — xtensor padding and tiling functions
// (documentation-site navigation residue removed)
/***************************************************************************
 * Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht          *
 * Copyright (c) QuantStack                                                 *
 *                                                                          *
 * Distributed under the terms of the BSD 3-Clause License.                 *
 *                                                                          *
 * The full license is in the file LICENSE, distributed with this software. *
 ****************************************************************************/

#ifndef XTENSOR_PAD_HPP
#define XTENSOR_PAD_HPP

#include <initializer_list>
#include <utility>
#include <vector>

#include "xarray.hpp"
#include "xstrided_view.hpp"
#include "xtensor.hpp"
#include "xview.hpp"

using namespace xt::placeholders;  // to enable _ syntax

namespace xt
{
38 enum class pad_mode
39 {
40 constant,
41 symmetric,
42 reflect,
43 wrap,
44 periodic,
45 edge
46 };
47
48 namespace detail
49 {
50 template <class S, class T>
51 inline bool check_pad_width(const std::vector<std::vector<S>>& pad_width, const T& shape)
52 {
53 if (pad_width.size() != shape.size())
54 {
55 return false;
56 }
57
58 return true;
59 }
60 }
61
73 template <class E, class S = typename std::decay_t<E>::size_type, class V = typename std::decay_t<E>::value_type>
74 inline auto
75 pad(E&& e,
76 const std::vector<std::vector<S>>& pad_width,
77 pad_mode mode = pad_mode::constant,
78 V constant_value = 0)
79 {
80 XTENSOR_ASSERT(detail::check_pad_width(pad_width, e.shape()));
81
82 using size_type = typename std::decay_t<E>::size_type;
83 using return_type = temporary_type_t<E>;
84
85 // place the original array in the center
86
87 auto new_shape = e.shape();
89 sv.reserve(e.shape().size());
90 for (size_type axis = 0; axis < e.shape().size(); ++axis)
91 {
92 size_type nb = static_cast<size_type>(pad_width[axis][0]);
93 size_type ne = static_cast<size_type>(pad_width[axis][1]);
94 size_type ns = nb + e.shape(axis) + ne;
95 new_shape[axis] = ns;
96 sv.push_back(xt::range(nb, nb + e.shape(axis)));
97 }
98
99 if (mode == pad_mode::constant)
100 {
101 return_type out(new_shape, constant_value);
103 return out;
104 }
105
106 return_type out(new_shape);
108
109 // construct padded regions based on original image
110
111 xt::xstrided_slice_vector svs(e.shape().size(), xt::all());
112 xt::xstrided_slice_vector svt(e.shape().size(), xt::all());
113
114 for (size_type axis = 0; axis < e.shape().size(); ++axis)
115 {
116 size_type nb = static_cast<size_type>(pad_width[axis][0]);
117 size_type ne = static_cast<size_type>(pad_width[axis][1]);
118
119 if (nb > static_cast<size_type>(0))
120 {
121 svt[axis] = xt::range(0, nb);
122
123 if (mode == pad_mode::wrap || mode == pad_mode::periodic)
124 {
125 XTENSOR_ASSERT(nb <= e.shape(axis));
126 svs[axis] = xt::range(e.shape(axis), nb + e.shape(axis));
128 }
129 else if (mode == pad_mode::symmetric)
130 {
131 XTENSOR_ASSERT(nb <= e.shape(axis));
132 svs[axis] = xt::range(2 * nb - 1, nb - 1, -1);
134 }
135 else if (mode == pad_mode::reflect)
136 {
137 XTENSOR_ASSERT(nb <= e.shape(axis) - 1);
138 svs[axis] = xt::range(2 * nb, nb, -1);
140 }
141 else if (mode == pad_mode::edge)
142 {
143 svs[axis] = xt::range(nb, nb + 1);
146 xt::strided_view(out, svt).shape()
147 );
148 }
149 }
150
151 if (ne > static_cast<size_type>(0))
152 {
153 svt[axis] = xt::range(out.shape(axis) - ne, out.shape(axis));
154
155 if (mode == pad_mode::wrap || mode == pad_mode::periodic)
156 {
157 XTENSOR_ASSERT(ne <= e.shape(axis));
158 svs[axis] = xt::range(nb, nb + ne);
160 }
161 else if (mode == pad_mode::symmetric)
162 {
163 XTENSOR_ASSERT(ne <= e.shape(axis));
164 if (ne == nb + e.shape(axis))
165 {
166 svs[axis] = xt::range(nb + e.shape(axis) - 1, _, -1);
167 }
168 else
169 {
170 svs[axis] = xt::range(nb + e.shape(axis) - 1, nb + e.shape(axis) - ne - 1, -1);
171 }
173 }
174 else if (mode == pad_mode::reflect)
175 {
176 XTENSOR_ASSERT(ne <= e.shape(axis) - 1);
177 if (ne == nb + e.shape(axis) - 1)
178 {
179 svs[axis] = xt::range(nb + e.shape(axis) - 2, _, -1);
180 }
181 else
182 {
183 svs[axis] = xt::range(nb + e.shape(axis) - 2, nb + e.shape(axis) - ne - 2, -1);
184 }
186 }
187 else if (mode == pad_mode::edge)
188 {
189 svs[axis] = xt::range(out.shape(axis) - ne - 1, out.shape(axis) - ne);
192 xt::strided_view(out, svt).shape()
193 );
194 }
195 }
196
197 svs[axis] = xt::all();
198 svt[axis] = xt::all();
199 }
200
201 return out;
202 }
203
215 template <class E, class S = typename std::decay_t<E>::size_type, class V = typename std::decay_t<E>::value_type>
216 inline auto
217 pad(E&& e, const std::vector<S>& pad_width, pad_mode mode = pad_mode::constant, V constant_value = 0)
218 {
219 std::vector<std::vector<S>> pw(e.shape().size(), pad_width);
220
221 return pad(e, pw, mode, constant_value);
222 }
223
234 template <class E, class S = typename std::decay_t<E>::size_type, class V = typename std::decay_t<E>::value_type>
235 inline auto pad(E&& e, S pad_width, pad_mode mode = pad_mode::constant, V constant_value = 0)
236 {
237 std::vector<std::vector<S>> pw(e.shape().size(), {pad_width, pad_width});
238
239 return pad(e, pw, mode, constant_value);
240 }
241
242 namespace detail
243 {
244
245 template <class E, class S>
246 inline auto tile(E&& e, const S& reps)
247 {
248 using size_type = typename std::decay_t<E>::size_type;
249
250 using return_type = temporary_type_t<E>;
251
252 XTENSOR_ASSERT(e.shape().size() == reps.size());
253
254 using new_shape_type = typename return_type::shape_type;
255 auto new_shape = xtl::make_sequence<new_shape_type>(e.shape().size());
256
258
259 for (size_type axis = 0; axis < reps.size(); ++axis)
260 {
261 new_shape[axis] = e.shape(axis) * reps[axis];
262 sv[axis] = xt::range(0, e.shape(axis));
263 }
264 return_type out(new_shape);
265
267
268 xt::xstrided_slice_vector svs(e.shape().size(), xt::all());
269 xt::xstrided_slice_vector svt(e.shape().size(), xt::all());
270
271 for (size_type axis = 0; axis < e.shape().size(); ++axis)
272 {
273 for (size_type i = 1; i < static_cast<size_type>(reps[axis]); ++i)
274 {
275 svs[axis] = xt::range(0, e.shape(axis));
276 svt[axis] = xt::range(i * e.shape(axis), (i + 1) * e.shape(axis));
278 svs[axis] = xt::all();
279 svt[axis] = xt::all();
280 }
281 }
282
283 return out;
284 }
285 }
286
296 inline auto tile(E&& e, std::initializer_list<S> reps)
297 {
298 return detail::tile(std::forward<E>(e), std::vector<S>{reps});
299 }
300
301 template <class E, class C, XTL_REQUIRES(xtl::negation<xtl::is_integral<C>>)>
302 inline auto tile(E&& e, const C& reps)
303 {
304 return detail::tile(std::forward<E>(e), reps);
305 }
306
315 inline auto tile(E&& e, S reps)
316 {
317 std::vector<S> tw(e.shape().size(), static_cast<S>(1));
318 tw[0] = reps;
319 return detail::tile(std::forward<E>(e), tw);
320 }
}

#endif
/*
 * Doxygen cross-reference residue preserved from the generated page:
 *   - xt::pad(E&&, const std::vector<std::vector<S>>&, pad_mode, V) — Pad an array.
 *   - xt::tile(E&&, std::initializer_list<S>) — Tile an array.
 *   - xt::range(start_val, stop_val) — select a range from start_val to stop_val (excluded).
 *   - xt::all() — slice representing a full dimension, used as a view argument.
 *   - xt::broadcast(E&&, const S&) — broadcast the given expression to a specified shape.
 *   - xt::strided_view(E&&, S&&, X&&, offset, layout) — strided view from shape, strides and offset.
 *   - xt::xstrided_slice_vector — std::vector<xstrided_slice<std::ptrdiff_t>>.
 */