cai_lw's competitive programming library
 
Loading...
Searching...
No Matches
conv.hpp
1#pragma once
2
3#include <vector>
4
5#include "cplib/conv/fft.hpp"
6#include "cplib/port/bit.hpp"
7
8namespace cplib {
9
10namespace impl {
11
/// Decides which convolution algorithm to use for input lengths `n` and `m`.
///
/// Returns true when the shorter of the two lengths is at most 32 (and hence
/// also whenever either length is 0). Callers use this to pick the naive
/// quadratic loop instead of the FFT path.
inline bool conv_naive_is_efficient(std::size_t n, std::size_t m) {
  // Equivalent to std::min(n, m) <= kThreshold, written without std::min so
  // the header does not depend on <algorithm> being included transitively.
  constexpr std::size_t kThreshold = 32;
  return n <= kThreshold || m <= kThreshold;
}
13
/// Naive O(a.size() * b.size()) convolution, written back into `a`.
///
/// On return, `a[k] = sum over i + j == k of a_old[i] * b[j]`, and
/// `a.size() == a_old.size() + b.size() - 1`. If either input is empty, `a`
/// is cleared. `b` is never modified.
///
/// The product is built in place. Output indices are filled from the highest
/// down to 1, so every `a[i - j]` (j >= 1) read while computing `a[i]` still
/// holds its original value. `a[0]` is finished last.
///
/// `a` and `b` may be the same vector (self-convolution). In that case `b` is
/// copied first, because the in-place update would otherwise overwrite
/// coefficients of `b` before they are read.
template <typename T>
void conv_naive_inplace(std::vector<T>& a, const std::vector<T>& b) {
  if (&a == &b) {
    const std::vector<T> b_copy = b;
    conv_naive_inplace(a, b_copy);
    return;
  }
  if (a.empty() || b.empty()) {
    a.clear();
    return;
  }
  using usize = std::size_t;
  const usize a_deg = a.size() - 1;
  const usize b_deg = b.size() - 1;
  a.resize(a_deg + b_deg + 1, T(0));
  for (usize i = a_deg + b_deg; i > 0; i--) {
    // j == 0 term: only indices that held an original coefficient of `a`
    // contribute; entries above a_deg are the zero padding added by resize.
    if (i <= a_deg) {
      a[i] *= b[0];
    }
    // Remaining terms j >= 1, restricted to 0 <= i - j <= a_deg and j <= b_deg.
    const usize j_low = i <= a_deg ? 1 : i - a_deg;
    const usize j_high = i <= b_deg ? i : b_deg;
    for (usize j = j_low; j <= j_high; j++) {
      a[i] += a[i - j] * b[j];
    }
  }
  a[0] *= b[0];
}
35
36template <typename T>
37void conv_fft_inplace2(std::vector<T>& a, std::vector<T>& b) {
38 using usize = std::size_t;
39 usize out_size = a.size() + b.size() - 1;
40 usize padded_out_size = port::bit_ceil(out_size);
41 a.resize(padded_out_size, T(0));
42 b.resize(padded_out_size, T(0));
43 fft_inplace(a);
44 fft_inplace(b);
45 for (size_t i = 0; i < padded_out_size; i++) {
46 a[i] *= b[i];
47 }
48 ifft_inplace(a);
49 a.resize(out_size);
50}
51
52} // namespace impl
53
67template <typename T>
68void convolve_inplace2(std::vector<T>& a, std::vector<T>& b) {
69 if (impl::conv_naive_is_efficient(a.size(), b.size())) {
70 impl::conv_naive_inplace(a, b);
71 } else {
72 impl::conv_fft_inplace2(a, b);
73 }
74}
75
84template <typename T>
85void convolve_inplace(std::vector<T>& a, const std::vector<T>& b) {
86 if (impl::conv_naive_is_efficient(a.size(), b.size())) {
87 impl::conv_naive_inplace(a, b);
88 } else {
89 auto b_copy = b;
90 impl::conv_fft_inplace2(a, b_copy);
91 }
92}
93
/// Returns the convolution of two arrays.
///
/// Neither input is modified: a copy of `a` is convolved in place with `b`
/// and returned. The result has `a.size() + b.size() - 1` elements, or is
/// empty if either input is empty.
template <typename T>
std::vector<T> convolve(const std::vector<T>& a, const std::vector<T>& b) {
  std::vector<T> result(a);
  convolve_inplace(result, b);
  return result;
}
105
106} // namespace cplib
void convolve_inplace2(std::vector< T > &a, std::vector< T > &b)
In-place convolution where both arrays are modified.
Definition: conv.hpp:68
void ifft_inplace(RandomIt first, RandomIt last)
In-place inverse fast Fourier transform (IFFT).
Definition: fft.hpp:146
void convolve_inplace(std::vector< T > &a, const std::vector< T > &b)
In-place convolution where one array is modified.
Definition: conv.hpp:85
std::vector< T > convolve(const std::vector< T > &a, const std::vector< T > &b)
Returns the convolution of two arrays.
Definition: conv.hpp:100
void fft_inplace(RandomIt first, RandomIt last)
In-place fast Fourier transform (FFT).
Definition: fft.hpp:108