xtensor
 
Loading...
Searching...
No Matches
xfft.hpp
#ifdef XTENSOR_USE_TBB
#include <oneapi/tbb.h>
#endif
#include <complex>
#include <cstddef>
#include <stdexcept>
#include <type_traits>

#include <xtl/xcomplex.hpp>

#include "../containers/xarray.hpp"
#include "../core/xmath.hpp"
#include "../core/xnoalias.hpp"
#include "../generators/xbuilder.hpp"
#include "../misc/xcomplex.hpp"
#include "../views/xaxis_slice_iterator.hpp"
#include "../views/xview.hpp"
#include "./xtl_concepts.hpp"
16
17namespace xt
18{
19 namespace fft
20 {
21 namespace detail
22 {
23 template <xtl::complex_concept E>
24 inline auto radix2(E&& e)
25 {
26 using namespace xt::placeholders;
27 using namespace std::complex_literals;
28 using value_type = typename std::decay_t<E>::value_type;
29 using precision = typename value_type::value_type;
30 auto N = e.size();
31 const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
32 // check for power of 2
33 if (!powerOfTwo || N == 0)
34 {
35 // TODO: Replace implementation with dft
36 XTENSOR_THROW(std::runtime_error, "FFT Implementation requires power of 2");
37 }
38 auto pi = xt::numeric_constants<precision>::PI;
40 if (N <= 1)
41 {
42 return ev;
43 }
44 else
45 {
46#ifdef XTENSOR_USE_TBB
49 oneapi::tbb::parallel_invoke(
50 [&]
51 {
52 even = radix2(xt::view(ev, xt::range(0, _, 2)));
53 },
54 [&]
55 {
56 odd = radix2(xt::view(ev, xt::range(1, _, 2)));
57 }
58 );
59#else
60 auto even = radix2(xt::view(ev, xt::range(0, _, 2)));
61 auto odd = radix2(xt::view(ev, xt::range(1, _, 2)));
62#endif
63
64 auto range = xt::arange<double>(N / 2);
65 auto exp = xt::exp(static_cast<value_type>(-2i) * pi * range / N);
66 auto t = exp * odd;
67 auto first_half = even + t;
68 auto second_half = even - t;
69 // TODO: should be a call to stack if performance was improved
70 auto spectrum = xt::xtensor<value_type, 1>::from_shape({N});
71 xt::view(spectrum, xt::range(0, N / 2)) = first_half;
72 xt::view(spectrum, xt::range(N / 2, N)) = second_half;
73 return spectrum;
74 }
75 }
76
77 template <typename E>
78 auto transform_bluestein(E&& data)
79 {
80 using value_type = typename std::decay_t<E>::value_type;
81 using precision = typename value_type::value_type;
82
83 // Find a power-of-2 convolution length m such that m >= n * 2 + 1
84 const std::size_t n = data.size();
85 size_t m = std::ceil(std::log2(n * 2 + 1));
86 m = std::pow(2, m);
87
88 // Trignometric table
89 auto exp_table = xt::xtensor<std::complex<precision>, 1>::from_shape({n});
90 xt::xtensor<std::size_t, 1> i = xt::pow(xt::linspace<std::size_t>(0, n - 1, n), 2);
91 i %= (n * 2);
92
93 auto angles = xt::eval(precision{3.141592653589793238463} * i / n);
94 auto j = std::complex<precision>(0, 1);
95 exp_table = xt::exp(-angles * j);
96
97 // Temporary vectors and preprocessing
99 xt::view(av, xt::range(0, n)) = data * exp_table;
100
101
102 auto bv = xt::empty<std::complex<precision>>({m});
103 xt::view(bv, xt::range(0, n)) = ::xt::conj(exp_table);
104 xt::view(bv, xt::range(-n + 1, xt::placeholders::_)) = xt::view(
105 ::xt::conj(xt::flip(exp_table)),
106 xt::range(xt::placeholders::_, -1)
107 );
108
109 // Convolution
110 auto xv = radix2(av);
111 auto yv = radix2(bv);
112 auto spectrum_k = xv * yv;
113 auto complex_args = xt::conj(spectrum_k);
114 auto fft_res = radix2(complex_args);
115 auto cv = xt::conj(fft_res) / m;
116
117 return xt::eval(xt::view(cv, xt::range(0, n)) * exp_table);
118 }
119 } // namespace detail
120
127 template <class E>
128 inline auto fft(E&& e, std::ptrdiff_t axis = -1)
129 {
130 using value_type = typename std::decay<E>::type::value_type;
131 if constexpr (xtl::is_complex<typename std::decay<E>::type::value_type>::value)
132 {
133 using precision = typename value_type::value_type;
134 const auto saxis = xt::normalize_axis(e.dimension(), axis);
135 const size_t N = e.shape(saxis);
136 const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
138 auto begin = xt::axis_slice_begin(out, saxis);
139 auto end = xt::axis_slice_end(out, saxis);
140 for (auto iter = begin; iter != end; iter++)
141 {
142 if (powerOfTwo)
143 {
144 xt::noalias(*iter) = detail::radix2(*iter);
145 }
146 else
147 {
148 xt::noalias(*iter) = detail::transform_bluestein(*iter);
149 }
150 }
151 return out;
152 }
153 else
154 {
155 return fft(xt::cast<std::complex<value_type>>(e), axis);
156 }
157 }
158
159 template <class E>
160 inline auto ifft(E&& e, std::ptrdiff_t axis = -1)
161 {
162 if constexpr (xtl::is_complex<typename std::decay<E>::type::value_type>::value)
163 {
164 // check the length of the data on that axis
165 const std::size_t n = e.shape(axis);
166 if (n == 0)
167 {
168 XTENSOR_THROW(std::runtime_error, "Cannot take the iFFT along an empty dimention");
169 }
170 auto complex_args = xt::conj(e);
171 auto fft_res = xt::fft::fft(complex_args, axis);
172 fft_res = xt::conj(fft_res);
173 return fft_res;
174 }
175 else
176 {
177 using value_type = typename std::decay<E>::type::value_type;
178 return ifft(xt::cast<std::complex<value_type>>(e), axis);
179 }
180 }
181
182 /*
183 * @brief performs a circular fft convolution xvec and yvec must
184 * be the same shape.
185 * @param xvec first array of the convolution
186 * @param yvec second array of the convolution
187 * @param axis axis along which to perform the convolution
188 */
189 template <typename E1, typename E2>
190 auto convolve(E1&& xvec, E2&& yvec, std::ptrdiff_t axis = -1)
191 {
192 // we could broadcast but that could get complicated???
193 if (xvec.dimension() != yvec.dimension())
194 {
195 XTENSOR_THROW(std::runtime_error, "Mismatched dimentions");
196 }
197
198 auto saxis = xt::normalize_axis(xvec.dimension(), axis);
199 if (xvec.shape(saxis) != yvec.shape(saxis))
200 {
201 XTENSOR_THROW(std::runtime_error, "Mismatched lengths along slice axis");
202 }
203
204 const std::size_t n = xvec.shape(saxis);
205
206 auto xv = fft(xvec, axis);
207 auto yv = fft(yvec, axis);
208
209 auto begin_x = xt::axis_slice_begin(xv, saxis);
210 auto end_x = xt::axis_slice_end(xv, saxis);
211 auto iter_y = xt::axis_slice_begin(yv, saxis);
212
213 for (auto iter = begin_x; iter != end_x; iter++)
214 {
215 (*iter) = (*iter_y++) * (*iter);
216 }
217
218 auto outvec = ifft(xv, axis);
219
220 // Scaling (because this FFT implementation omits it)
221 outvec = outvec / n;
222
223 return outvec;
224 }
225
226 }
227} // namespace xt::fft
size_type size() const noexcept
Returns the number of elements in the container.
auto cast(E &&e) noexcept -> detail::xfunction_type_t< typename detail::cast< R >::functor, E >
Element-wise static_cast.
auto exp(E &&e) noexcept -> detail::xfunction_type_t< math::exp_fun, E >
Natural exponential function.
Definition xmath.hpp:900
auto pow(E1 &&e1, E2 &&e2) noexcept -> detail::xfunction_type_t< math::pow_fun, E1, E2 >
Power function.
Definition xmath.hpp:1015
auto conj(E &&e) noexcept
Return an xt::xfunction evaluating to the complex conjugate of the given expression.
Definition xcomplex.hpp:207
auto eval(T &&t) -> std::enable_if_t< detail::is_container< std::decay_t< T > >::value, T && >
Force evaluation of xexpression.
Definition xeval.hpp:46
auto flip(E &&e)
Reverse the order of elements in an xexpression along every axis.
standard mathematical functions for xexpressions
auto range(A start_val, B stop_val)
Select a range from start_val to stop_val (excluded).
Definition xslice.hpp:818
auto arange(T start, T stop, S step=1) noexcept
Generates numbers evenly spaced within given half-open interval [start, stop).
Definition xbuilder.hpp:432
xarray_container< uvector< T, A >, L, xt::svector< typename uvector< T, A >::size_type, 4, SA, true > > xarray
Alias template on xarray_container with default parameters for data container type and shape / stride...
xtensor_container< uvector< T, A >, N, L > xtensor
Alias template on xtensor_container with default parameters for data container type.
auto axis_slice_begin(E &&e)
Returns an iterator to the first element of the expression for axis 0.
auto view(E &&e, S &&... slices)
Constructs and returns a view on the specified xexpression.
Definition xview.hpp:1821
auto axis_slice_end(E &&e)
Returns an iterator to the element following the last element of the expression for axis 0.
xarray< T, L > empty(const S &shape)
Create an xcontainer (xarray, xtensor or xtensor_fixed) with uninitialized values of value_type T...
Definition xbuilder.hpp:89