cai_lw's competitive programming library
 
Loading...
Searching...
No Matches
prime.hpp
1#pragma once
2
3#include "cplib/num/gcd.hpp"
4#include "cplib/num/mmint.hpp"
5#include "cplib/num/pow.hpp"
6#include "cplib/port/bit.hpp"
7
8namespace cplib {
9
10namespace impl {
11
/// One strong-probable-prime (Miller-Rabin) round for n = ModInt::mod().
///
/// The caller supplies the decomposition n - 1 = d * 2^r with d odd.
///
/// @param a_  witness base.
/// @param d   odd part of n - 1.
/// @param r   number of factors of 2 in n - 1 (non-negative).
/// @return 1 if n is a strong probable prime to base a_;
///         gcd(x - 1, n) when an x with x != 1, x != -1 and x * x == 1 (mod n)
///         is encountered -- such an x makes that gcd a non-trivial factor;
///         0 if n is proven composite without exposing a factor.
template <typename ModInt>
typename ModInt::int_type miller_rabin(typename ModInt::int_type a_, typename ModInt::int_type d, int r) {
    const ModInt one(1), neg_one(-1);
    // Start at a^d, then square up to r times looking for -1.
    ModInt cur = pow(ModInt(a_), d);
    if (cur == one || cur == neg_one) {
        return 1;
    }
    for (int step = r; step > 0; --step) {
        const ModInt sq = cur * cur;
        if (sq == one) {
            // cur is a square root of 1 other than +-1: n divides
            // (cur - 1)(cur + 1) but neither factor alone.
            return gcd(cur.val() - 1, ModInt::mod());
        }
        cur = sq;
        if (cur == neg_one) {
            return 1;
        }
    }
    return 0;
}
31
32template <typename ModInt>
33uint32_t miller_rabin_32() {
34 const uint32_t n = ModInt::mod();
35 int r = port::countr_zero(n - 1);
36 uint32_t d = (n - 1) >> r;
37 for (uint32_t a : {2, 7, 61}) {
38 uint32_t ret = miller_rabin<ModInt>(a, d, r);
39 if (ret != 1) {
40 return ret;
41 }
42 }
43 return 1;
44}
45
46template <typename ModInt>
47uint64_t miller_rabin_64() {
48 const uint64_t n = ModInt::mod();
49 int r = port::countr_zero(n - 1);
50 uint64_t d = (n - 1) >> r;
51 for (uint64_t a : {2, 325, 9375, 28178, 450775, 9780504, 1795265022}) {
52 uint64_t ret = miller_rabin<ModInt>(a, d, r);
53 if (ret != 1) {
54 return ret;
55 }
56 }
57 return 1;
58}
59
/// Builds a 64-bit mask in which bit i is set exactly when i is a prime
/// below 64 (bits 2, 3, 5, ..., 61).
constexpr uint64_t small_primes_mask() {
    constexpr int primes_below_64[] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61};
    uint64_t result = 0;
    for (const int p : primes_below_64) {
        // Each prime switches on the bit at its own index.
        result |= uint64_t{1} << p;
    }
    return result;
}
67
68constexpr bool is_prime_lt64(int n) { return (1ull << n) & small_primes_mask(); }
69
70static uint32_t prime_or_factor_32(uint32_t n) {
71 if (n < 64) {
72 return is_prime_lt64(n);
73 }
74 if (n % 2 == 0) {
75 return 2;
76 }
77 constexpr uint32_t small_prod = 3u * 5 * 7 * 11 * 13 * 17 * 19 * 23 * 29;
78 uint32_t g = gcd(n, small_prod);
79 if (g != 1) {
80 return g != n ? g : 0;
81 }
82 return mmint_by_modulus([](auto mint) { return miller_rabin_32<decltype(mint)>(); }, n);
83}
84
85static uint64_t prime_or_factor_64(uint64_t n) {
86 if (n < 64) {
87 return is_prime_lt64(n);
88 }
89 if (n % 2 == 0) {
90 return 2;
91 }
92 constexpr uint64_t small_prod = 3ull * 5 * 7 * 11 * 13 * 17 * 19 * 23 * 29 * 31 * 37 * 41 * 43 * 47 * 53;
93 uint64_t g = gcd(n, small_prod);
94 if (g != 1) {
95 return g != n ? g : 0;
96 }
97 return mmint_by_modulus([](auto mint) { return miller_rabin_64<decltype(mint)>(); }, n);
98}
99
100} // namespace impl
101
122template <typename T, std::enable_if_t<std::is_unsigned_v<T>>* = nullptr>
124 if (n < (1ull << 32)) {
125 return impl::prime_or_factor_32(n);
126 } else {
127 return impl::prime_or_factor_64(n);
128 }
129}
130
/// Primality test: returns true iff n is prime.
template <typename T, std::enable_if_t<std::is_unsigned_v<T>>* = nullptr>
bool is_prime(T n) {
    // prime_or_factor reports "prime" with the value 1; anything else
    // (0 or a non-trivial factor) means n is not prime.
    const auto result = prime_or_factor(n);
    return result == 1;
}
142
143} // namespace cplib
constexpr T pow(T base, uint64_t exp, Op &&op={})
A generic exponentiation-by-squaring function.
Definition: pow.hpp:14
constexpr T gcd(T x, T y)
Greatest common divisor.
Definition: gcd.hpp:21
bool is_prime(T n)
Primality test.
Definition: prime.hpp:139
T prime_or_factor(T n)
Primality test and possibly return a non-trivial factor.
Definition: prime.hpp:123