cai_lw's competitive programming library
 
Loading...
Searching...
No Matches
anymod.hpp
1#pragma once
2
3#include <vector>
4
5#include "cplib/conv/conv.hpp"
6
7namespace cplib {
8
9namespace impl {
10
/**
 * @brief Re-encodes two modint arrays as modint type `Out` and convolves them in `Out`.
 *
 * Every element is rebuilt from its `val()` representative, so the values are
 * reinterpreted modulo `Out`'s modulus (any reduction is presumably done by
 * `Out`'s constructor — confirm against the modint implementation).
 *
 * @tparam In  Source element type; must provide `val()`.
 * @tparam Out Target element type; must be constructible from `In::val()` and
 *             usable with `convolve_inplace2`.
 * @param a First input array (not modified).
 * @param b Second input array (not modified).
 * @return The convolution of the converted arrays, as produced by `convolve_inplace2`.
 */
template <typename In, typename Out>
std::vector<Out> convolve_modint(const std::vector<In>& a, const std::vector<In>& b) {
  // Rebuild a whole array in the target modint type, element by element.
  auto to_out = [](const std::vector<In>& src) {
    std::vector<Out> dst;
    dst.reserve(src.size());
    for (const In& x : src) dst.emplace_back(x.val());
    return dst;
  };
  std::vector<Out> lhs = to_out(a);
  std::vector<Out> rhs = to_out(b);
  // Both buffers are scratch copies, so the variant that clobbers both is fine.
  convolve_inplace2(lhs, rhs);
  return lhs;
}
25
26template <typename In, typename Out, typename ModInt1, typename ModInt2>
27std::vector<Out> convolve_with_two_modints(const std::vector<In>& a, const std::vector<In>& b) {
28 std::vector<ModInt1> m1 = convolve_modint<In, ModInt1>(a, b);
29 std::vector<ModInt2> m2 = convolve_modint<In, ModInt2>(a, b);
30 std::vector<Out> ret;
31 ret.reserve(m1.size());
32 ModInt2 p1_inv = ModInt2(ModInt1::mod()).inv();
33 Out p1_out(ModInt1::mod());
34 for (std::size_t i = 0; i < m1.size(); i++) {
35 // r1+k*p1=r2 (mod p2) => k=(r2-r1)*p1^{-1} (mod p2)
36 auto r1 = m1[i].val();
37 auto k = ((m2[i] - ModInt2(r1)) * p1_inv).val();
38 ret.push_back(Out(r1) + Out(k) * p1_out);
39 }
40 return ret;
41}
42
43} // namespace impl
44
45template <>
46struct radix2_fft_root<MMInt64<4242390848983007233>> {
47 using mint = MMInt64<4242390848983007233>;
48 static mint get(int n) { return pow(mint(11), 471ull << (53 - n)); }
49};
50
68template <typename ModInt>
69void convolve_any_modint_inplace(std::vector<ModInt>& a, const std::vector<ModInt>& b) {
70 if (impl::conv_naive_is_efficient(a.size(), b.size())) {
71 impl::conv_naive_inplace(a, b);
72 return;
73 }
74 using mint1 = MMInt64<4512606826625236993>;
75 using mint2 = MMInt64<4242390848983007233>;
76 using u128 = unsigned __int128;
77 u128 max_prod = u128(ModInt::mod() - 1) * (ModInt::mod() - 1);
78 u128 limit = u128(mint1::mod()) * mint2::mod() - 1;
79 assert(max_prod <= limit / std::min(a.size(), b.size()));
80 a = impl::convolve_with_two_modints<ModInt, ModInt, mint1, mint2>(a, b);
81}
82
/**
 * @brief Returns the convolution of `a` and `b` modulo an arbitrary `ModInt::mod()`.
 *
 * Copying wrapper around convolve_any_modint_inplace: neither input is modified,
 * and the same size/modulus precondition asserted there applies here.
 *
 * @param a First operand (not modified).
 * @param b Second operand (not modified).
 * @return The convolution of `a` and `b`.
 */
template <typename ModInt>
std::vector<ModInt> convolve_any_modint(const std::vector<ModInt>& a, const std::vector<ModInt>& b) {
  std::vector<ModInt> result = a;
  // Without this call the function would just return a copy of `a`.
  convolve_any_modint_inplace(result, b);
  return result;
}
94
95} // namespace cplib
void convolve_inplace2(std::vector< T > &a, std::vector< T > &b)
In-place convolution where both arrays are modified.
Definition: conv.hpp:68
void convolve_any_modint_inplace(std::vector< ModInt > &a, const std::vector< ModInt > &b)
In-place convolution with arbitrary modulus.
Definition: anymod.hpp:69
std::vector< ModInt > convolve_any_modint(const std::vector< ModInt > &a, const std::vector< ModInt > &b)
Returns the convolution of two arrays modulo an arbitrary integer.
Definition: anymod.hpp:89
constexpr T pow(T base, uint64_t exp, Op &&op={})
A generic exponentiation by squaring function.
Definition: pow.hpp:14
$2^n$-th root of unity for radix-2 FFT.
Definition: fft.hpp:55