cai_lw's competitive programming library
 
Loading...
Searching...
No Matches
fft.hpp
1#pragma once
2
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <iterator>
#include <vector>

#include "cplib/num/mmint.hpp"
#include "cplib/num/pow.hpp"
#include "cplib/port/bit.hpp"
11
12namespace cplib {
13
14namespace impl {
15
16template <typename T>
17std::vector<T> twiddling_factors(const std::vector<T>& roots) {
18 std::size_t n = 1 << (roots.size() - 2);
19 std::vector<T> w, w_stack;
20 w.reserve(n);
21 w_stack.reserve(roots.size());
22 w.push_back(roots.front());
23 w_stack.push_back(roots.front());
24 // For any give i, the k-th element in the stack is i with only the highest k 1-bits.
25 // Each element is the product of its lowest bit and other bits, and thus is the product of at most log2(N) roots.
26 // Thus for complex numbers the error in twiddling factors is at most log2(N) eps.
27 for (std::size_t i = 1; i < n; i++) {
28 w_stack.resize(port::popcount(i));
29 w_stack.push_back(w_stack.back() * roots[roots.size() - 1 - port::countr_zero(i)]);
30 w.push_back(w_stack.back());
31 }
32 return w;
33}
34
35} // namespace impl
36
/// @brief 2^n-th root of unity for radix-2 FFT.
///
/// Specializations provide a static member `get(n)` returning the 2^n-th root of
/// unity of T used by fft_inplace. `get(0)` must be the multiplicative identity:
/// ifft_inplace uses it as `one`, and the twiddling-factor computation starts from it.
template <typename T>
struct radix2_fft_root;
63template <>
64struct radix2_fft_root<MMInt<998244353>> {
65 using mint = MMInt<998244353>;
66 static mint get(int n) { return pow(mint(3), 119 << (23 - n)); }
67};
68
77template <>
78struct radix2_fft_root<MMInt64<4512606826625236993>> {
79 using mint = MMInt64<4512606826625236993>;
80 static mint get(int n) { return pow(mint(7), 501ull << (53 - n)); }
81};
82
88template <typename Float>
89struct radix2_fft_root<std::complex<Float>> {
90 constexpr static std::complex<Float> get(size_t n) {
91 constexpr long double tau = atanl(1) * 8; // In C++20, std::pi_v<Float> * 2
92 return std::polar<Float>(1, tau / (1ull << n));
93 }
94};
95
107template <typename RandomIt>
108void fft_inplace(RandomIt first, RandomIt last) {
109 using usize = std::size_t;
110 const usize n = std::distance(first, last);
111 assert(port::has_single_bit(n));
112 int log2n = port::countr_zero(n);
113 using T = typename std::iterator_traits<RandomIt>::value_type;
114 std::vector<T> roots(log2n + 1);
115 for (int i = 0; i <= log2n; i++) {
116 roots[i] = radix2_fft_root<T>::get(i);
117 }
118 for (int stage = log2n - 1; stage >= 0; stage--) {
119 usize len = usize(1) << stage;
120 std::vector<T> w = impl::twiddling_factors(roots);
121 for (usize block = 0; block < n; block += len * 2) {
122 for (usize offset = 0; offset < len; offset++) {
123 usize i = block + offset, j = i + len;
124 T tmp = (first[i] - first[j]) * w[offset];
125 first[i] += first[j];
126 first[j] = tmp;
127 }
128 }
129 roots.pop_back();
130 }
131}
132
145template <typename RandomIt>
146void ifft_inplace(RandomIt first, RandomIt last) {
147 using usize = std::size_t;
148 const usize n = std::distance(first, last);
149 assert(port::has_single_bit(n));
150 int log2n = port::countr_zero(n);
151 using T = typename std::iterator_traits<RandomIt>::value_type;
152 T one = radix2_fft_root<T>::get(0);
153 std::vector<T> roots{one};
154 for (int stage = 0; stage < log2n; stage++) {
155 usize len = usize(1) << stage;
156 roots.push_back(one / radix2_fft_root<T>::get(stage + 1));
157 std::vector<T> w = impl::twiddling_factors(roots);
158 for (usize block = 0; block < n; block += len * 2) {
159 for (usize offset = 0; offset < len; offset++) {
160 usize i = block + offset, j = i + len;
161 first[j] *= w[offset];
162 T tmp = first[i] - first[j];
163 first[i] += first[j];
164 first[j] = tmp;
165 }
166 }
167 }
168 T half = one / (one + one);
169 T n_inv = pow(half, log2n);
170 for (usize i = 0; i < n; i++) {
171 first[i] *= n_inv;
172 }
173}
174
/// @brief In-place FFT of an entire vector; forwards to the iterator overload.
/// @param a values to transform in place; a.size() must be a power of two.
template <typename T>
void fft_inplace(std::vector<T>& a) {
    fft_inplace(std::begin(a), std::end(a));
}
188
/// @brief In-place inverse FFT of an entire vector; forwards to the iterator overload.
/// @param a values to transform in place; a.size() must be a power of two.
template <typename T>
void ifft_inplace(std::vector<T>& a) {
    ifft_inplace(std::begin(a), std::end(a));
}
203
204} // namespace cplib
void ifft_inplace(RandomIt first, RandomIt last)
In-place inverse fast Fourier transform (IFFT).
Definition: fft.hpp:146
void fft_inplace(RandomIt first, RandomIt last)
In-place fast Fourier transform (FFT).
Definition: fft.hpp:108
constexpr T pow(T base, uint64_t exp, Op &&op={})
A generic exponentiation by squaring function.
Definition: pow.hpp:14
$2^n$-th root of unity for radix-2 FFT.
Definition: fft.hpp:55