xtensor
Loading...
Searching...
No Matches
xfft.hpp
1#ifdef XTENSOR_USE_TBB
2#include <oneapi/tbb.h>
3#endif
4#include <stdexcept>
5
6#include <xtl/xcomplex.hpp>
7
8#include <xtensor/xarray.hpp>
9#include <xtensor/xaxis_slice_iterator.hpp>
10#include <xtensor/xbuilder.hpp>
11#include <xtensor/xcomplex.hpp>
12#include <xtensor/xmath.hpp>
13#include <xtensor/xnoalias.hpp>
14#include <xtensor/xview.hpp>
15
16namespace xt
17{
18 namespace fft
19 {
20 namespace detail
21 {
// Recursive radix-2 Cooley-Tukey FFT of a complex-valued expression
// (presumably 1-D: it is split with a single strided xt::range).
// The samples are divided into even- and odd-indexed halves, each half is
// transformed by a recursive radix2() call (concurrently through
// oneapi::tbb::parallel_invoke when XTENSOR_USE_TBB is defined), and the
// halves are combined with the twiddle factors exp(-2*pi*i*k/N).
// Throws std::runtime_error unless e.size() is a non-zero power of two.
// NOTE(review): this listing omits a few lines of the definition (where `ev`,
// `pi`, the TBB-path `even`/`odd` and `spectrum` are declared); confirm the
// details below against the full header.
22 template <
23 class E,
24 typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
25 inline auto radix2(E&& e)
26 {
27 using namespace xt::placeholders;
28 using namespace std::complex_literals;
29 using value_type = typename std::decay_t<E>::value_type;
30 using precision = typename value_type::value_type;
31 auto N = e.size();
// A non-zero N is a power of two exactly when it has a single bit set,
// i.e. when N & (N - 1) == 0.
32 const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
33 // Reject any length that is not a power of two (powerOfTwo is already false for N == 0).
34 if (!powerOfTwo || N == 0)
35 {
36 // TODO: Replace implementation with dft
37 XTENSOR_THROW(std::runtime_error, "FFT Implementation requires power of 2");
38 }
// Base case of the recursion: after the check above only N == 1 reaches this
// branch, and `ev` (declared on a line omitted from this listing) is returned.
41 if (N <= 1)
42 {
43 return ev;
44 }
45 else
46 {
47#ifdef XTENSOR_USE_TBB
// Transform the even- and odd-indexed halves in parallel.
50 oneapi::tbb::parallel_invoke(
51 [&]
52 {
53 even = radix2(xt::view(ev, xt::range(0, _, 2)));
54 },
55 [&]
56 {
57 odd = radix2(xt::view(ev, xt::range(1, _, 2)));
58 }
59 );
60#else
// Transform the even- and odd-indexed halves one after the other.
61 auto even = radix2(xt::view(ev, xt::range(0, _, 2)));
62 auto odd = radix2(xt::view(ev, xt::range(1, _, 2)));
63#endif
64
// Twiddle factors exp(-2*pi*i*k/N) for k = 0 .. N/2 - 1.
// NOTE(review): `range` is arange<double> regardless of `precision`; check
// that it combines as intended when precision is not double.
65 auto range = xt::arange<double>(N / 2);
66 auto exp = xt::exp(static_cast<value_type>(-2i) * pi * range / N);
// Butterfly: first_half = E[k] + W^k * O[k], second_half = E[k] - W^k * O[k];
// presumably written to the lower and upper halves of `spectrum` on the
// lines omitted from this listing.
67 auto t = exp * odd;
68 auto first_half = even + t;
69 auto second_half = even - t;
70 // TODO: should be a call to stack if performance was improved
74 return spectrum;
75 }
76 }
77
// Bluestein's algorithm: computes the DFT of a length-n complex expression
// (n need not be a power of two) by expressing it as a convolution, which is
// evaluated with radix2() FFTs of the power-of-two length m.
// NOTE(review): several lines of this definition are omitted from this
// listing (e.g. where `i`, `av`, `bv` and `complex_args` are created);
// comments below describe only the visible lines.
78 template <typename E>
79 auto transform_bluestein(E&& data)
80 {
81 using value_type = typename std::decay_t<E>::value_type;
82 using precision = typename value_type::value_type;
83
84 // Find a power-of-2 convolution length m such that m >= n * 2 + 1
85 const std::size_t n = data.size();
// NOTE(review): m is derived through floating-point log2/ceil/pow and then
// implicitly converted back to size_t; an integer bit-width computation would
// not depend on exact floating-point results.
86 size_t m = std::ceil(std::log2(n * 2 + 1));
87 m = std::pow(2, m);
88
89 // Trigonometric table: n complex factors (filled on a line omitted from this listing)
90 auto exp_table = xt::xtensor<std::complex<precision>, 1>::from_shape({n});
// `i` is defined on an omitted line (presumably k*k for k = 0 .. n-1);
// reducing it modulo 2n keeps pi * i / n within [0, 2*pi).
92 i %= (n * 2);
93
// NOTE(review): with an integral `precision`, this braced initialization is a
// narrowing conversion from a double literal and does not compile.
94 auto angles = xt::eval(precision{3.141592653589793238463} * i / n);
95 auto j = std::complex<precision>(0, 1);
97
98 // Temporary vectors and preprocessing
100 xt::view(av, xt::range(0, n)) = data * exp_table;
101
102
// NOTE(review): `-n + 1` is evaluated in unsigned (std::size_t) arithmetic;
// it only denotes the start index -(n - 1) if the wrapped value is
// reinterpreted as signed downstream. Consider a signed expression.
105 xt::view(bv, xt::range(-n + 1, xt::placeholders::_)) = xt::view(
107 xt::range(xt::placeholders::_, -1)
108 );
109
110 // Convolution
111 auto xv = radix2(av);
112 auto yv = radix2(bv);
113 auto spectrum_k = xv * yv;
// Inverse FFT of the product via conj(fft(conj(x))) / m; `complex_args` is
// built on an omitted line (presumably conj(spectrum_k) - confirm).
115 auto fft_res = radix2(complex_args);
116 auto cv = xt::conj(fft_res) / m;
117
// Keep the first n samples and multiply by the table once more.
118 return xt::eval(xt::view(cv, xt::range(0, n)) * exp_table);
119 }
120 } // namespace detail
121
// Computes the FFT of a complex-valued expression along `axis`; negative axis
// values are normalized with xt::normalize_axis. Every 1-D slice of `out`
// along that axis is replaced by its transform: radix2() when the axis length
// N is a power of two, transform_bluestein() otherwise.
// NOTE(review): the line creating `out` is omitted from this listing
// (presumably an evaluated copy of `e` - confirm).
128 template <
129 class E,
130 typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
131 inline auto fft(E&& e, std::ptrdiff_t axis = -1)
132 {
// NOTE(review): value_type and precision are not used by the visible lines.
133 using value_type = typename std::decay_t<E>::value_type;
134 using precision = typename value_type::value_type;
135 const auto saxis = xt::normalize_axis(e.dimension(), axis);
136 const size_t N = e.shape(saxis);
// Decided once for all slices, since they all have length N.
// NOTE(review): N == 0 is routed to transform_bluestein(); confirm empty axes
// are handled there.
137 const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
139 auto begin = xt::axis_slice_begin(out, saxis);
140 auto end = xt::axis_slice_end(out, saxis);
141 for (auto iter = begin; iter != end; iter++)
142 {
143 if (powerOfTwo)
144 {
145 xt::noalias(*iter) = detail::radix2(*iter);
146 }
147 else
148 {
149 xt::noalias(*iter) = detail::transform_bluestein(*iter);
150 }
151 }
152 return out;
153 }
154
161 template <
162 class E,
163 typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
164 inline auto fft(E&& e, std::ptrdiff_t axis = -1)
165 {
166 using value_type = typename std::decay<E>::type::value_type;
167 return fft(xt::cast<std::complex<value_type>>(e), axis);
168 }
169
// Inverse FFT of a complex-valued expression along `axis`, computed through
// the forward fft() of the conjugated input (conjugation identity).
// Throws std::runtime_error if the length along `axis` is zero.
// NOTE(review): one line between the fft() call and the return is omitted
// from this listing (presumably conjugating `fft_res` back - confirm). No 1/n
// scaling is visible here; convolve() divides by n explicitly.
170 template <
171 class E,
172 typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
173 auto ifft(E&& e, std::ptrdiff_t axis = -1)
174 {
175 // check the length of the data on that axis
// NOTE(review): unlike fft(), `axis` is passed to e.shape() without
// xt::normalize_axis. If shape() takes an unsigned index, the default
// axis = -1 becomes a huge out-of-range index; consider normalizing first.
176 const std::size_t n = e.shape(axis);
177 if (n == 0)
178 {
// NOTE(review): the message misspells "dimension".
179 XTENSOR_THROW(std::runtime_error, "Cannot take the iFFT along an empty dimention");
180 }
181 auto complex_args = xt::conj(e);
182 auto fft_res = xt::fft::fft(complex_args, axis);
184 return fft_res;
185 }
186
187 template <
188 class E,
189 typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
190 inline auto ifft(E&& e, std::ptrdiff_t axis = -1)
191 {
192 using value_type = typename std::decay<E>::type::value_type;
193 return ifft(xt::cast<std::complex<value_type>>(e), axis);
194 }
195
196 /**
 * @brief Performs a circular convolution of xvec and yvec using the FFT.
 *
 * xvec and yvec are expected to have the same shape; the checks below only
 * enforce the same number of dimensions and the same length along @p axis.
 * @param xvec first array of the convolution
 * @param yvec second array of the convolution
 * @param axis axis along which to perform the convolution (negative values
 *             are normalized, so the default -1 selects the last axis)
 * @return the inverse transform of the product of the spectra, divided by n,
 *         the length along @p axis
 * @throws std::runtime_error if the numbers of dimensions differ or the
 *         lengths along @p axis differ
 */
203 template <typename E1, typename E2>
204 auto convolve(E1&& xvec, E2&& yvec, std::ptrdiff_t axis = -1)
205 {
206 // Broadcasting between xvec and yvec is not supported; the inputs must match.
207 if (xvec.dimension() != yvec.dimension())
208 {
// NOTE(review): the message misspells "dimensions".
209 XTENSOR_THROW(std::runtime_error, "Mismatched dimentions");
210 }
211
212 auto saxis = xt::normalize_axis(xvec.dimension(), axis);
213 if (xvec.shape(saxis) != yvec.shape(saxis))
214 {
215 XTENSOR_THROW(std::runtime_error, "Mismatched lengths along slice axis");
216 }
217
218 const std::size_t n = xvec.shape(saxis);
219
220 auto xv = fft(xvec, axis);
221 auto yv = fft(yvec, axis);
222
226
// Convolution theorem: multiply the two spectra element-wise, slice by slice.
// begin_x, end_x and iter_y are declared on lines omitted from this listing
// (presumably axis-slice iterators over xv and yv - confirm).
227 for (auto iter = begin_x; iter != end_x; iter++)
228 {
229 (*iter) = (*iter_y++) * (*iter);
230 }
231
// NOTE(review): the raw `axis` (not `saxis`) is passed on; the complex ifft()
// overload reads e.shape(axis) without normalizing it, so passing saxis
// would be safer for the default axis = -1.
232 auto outvec = ifft(xv, axis);
233
234 // Scaling (because this FFT implementation omits it)
235 outvec = outvec / n;
236
237 return outvec;
238 }
239
240 }
241} // namespace xt::fft
auto cast(E &&e) noexcept -> detail::xfunction_type_t< typename detail::cast< R >::functor, E >
Element-wise static_cast.
auto exp(E &&e) noexcept -> detail::xfunction_type_t< math::exp_fun, E >
Natural exponential function.
Definition xmath.hpp:900
auto pow(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::pow_fun, E1, E2 >
Power function.
Definition xmath.hpp:1015
auto conj(E &&e) noexcept
Return an xt::xfunction evaluating to the complex conjugate of the given expression.
Definition xcomplex.hpp:207
auto eval(T &&t) -> std::enable_if_t< detail::is_container< std::decay_t< T > >::value, T && >
Force evaluation of xexpression.
Definition xeval.hpp:46
auto flip(E &&e)
Reverse the order of elements in an xexpression along every axis.
standard mathematical functions for xexpressions
auto range(A start_val, B stop_val)
Select a range from start_val to stop_val (excluded).
Definition xslice.hpp:818
auto axis_slice_begin(E &&e)
Returns an iterator to the first element of the expression for axis 0.
auto view(E &&e, S &&... slices)
Constructs and returns a view on the specified xexpression.
Definition xview.hpp:1834
auto axis_slice_end(E &&e)
Returns an iterator to the element following the last element of the expression for axis 0.