cai_lw's competitive programming library
 
Loading...
Searching...
No Matches
mmint.hpp
1#pragma once
2
#include <cassert>
#include <cstdint>
#include <limits>
#include <type_traits>
#include <vector>

#include "cplib/num/gcd.hpp"
#include "cplib/utils/type.hpp"
10
11namespace cplib {
12
13namespace impl {
14
15template <typename UInt>
16class MontgomeryReductionBase {
17 public:
18 using int_type = UInt;
19 using int_double_t = make_double_width_t<int_type>;
20
21 constexpr explicit MontgomeryReductionBase(int_type mod)
22 : mod_(mod),
23 mod_neg_inv_(inv_base(-mod)),
24 mbase_((int_double_t(1) << base_width_) % mod),
25 mbase2_(int_double_t(mbase_) * mbase_ % mod),
26 mbase3_(int_double_t(mbase2_) * mbase_ % mod) {}
27
28 // N
29 constexpr int_type mod() const { return mod_; }
30
31 // R%N
32 constexpr int_type mbase() const { return mbase_; }
33
34 // R^2%N
35 constexpr int_type mbase2() const { return mbase2_; }
36
37 // R^3%N
38 constexpr int_type mbase3() const { return mbase3_; }
39
40 protected:
41 int_type mod_, mod_neg_inv_, mbase_, mbase2_, mbase3_;
42 static constexpr int base_width_ = std::numeric_limits<int_type>::digits;
43
44 private:
45 // Modular inverse modulo 2^2^k by Hensel lifting.
46 static constexpr int_type inv_base(int_type x) {
47 int_type y = 1;
48 for (int i = 1; i < base_width_; i *= 2) {
49 y *= int_type(2) - x * y;
50 }
51 return y;
52 }
53};
54
55// Value in [0,2N), only works if N<R/4
56template <typename UInt>
57class MontgomeryReductionLoose : public MontgomeryReductionBase<UInt> {
58 public:
59 using Base = MontgomeryReductionBase<UInt>;
60 using Base::Base;
61 using typename Base::int_double_t;
62 using typename Base::int_type;
63
64 // a*b*(R^-1)%N. Result <2N if input <2N.
65 constexpr int_type mul(int_type a, int_type b) const {
66 int_double_t t = int_double_t(a) * b;
67 int_type m = int_type(t) * this->mod_neg_inv_;
68 int_type r = (t + int_double_t(m) * this->mod_) >> this->base_width_;
69 return r;
70 }
71
72 // (a+b)%N. Result <2N if input <2N.
73 constexpr int_type add(int_type a, int_type b) const {
74 int_type r = a + b;
75 return r >= this->mod_ * 2 ? r - this->mod_ * 2 : r;
76 }
77
78 // (a-b)%N. Result <2N if input <2N.
79 constexpr int_type sub(int_type a, int_type b) const {
80 int_type r = a - b;
81 return r > a ? r + this->mod_ * 2 : r;
82 }
83
84 // Shrink value from [0,2N) into [0,N)
85 constexpr int_type shrink(int_type x) const { return x >= this->mod_ ? x - this->mod_ : x; }
86};
87
88// Value in [0,N), works for all N<R
89template <typename UInt>
90class MontgomeryReductionStrict : public MontgomeryReductionBase<UInt> {
91 public:
92 using Base = MontgomeryReductionBase<UInt>;
93 using Base::Base;
94 using typename Base::int_double_t;
95 using typename Base::int_type;
96
97 // We use the same technique for the following functions where the result R is in [0,2N) where 2N may overflow.
98 // If the last step is R=X+Y where X,Y are in [0,N), we make it R'=X-(N-Y) so that the result is in [-N,N)
99 // The "true" value of R' is negative iff the last subtraction underflows, iff R'>X, and that's exactly when we
100 // add N to R' to bring the value back to [0,N).
101
102 // a*b*(R^-1)%N
103 constexpr int_type mul(int_type a, int_type b) const {
104 int_double_t t = int_double_t(a) * b;
105 int_type m = int_type(t) * this->mod_neg_inv_;
106 int_double_t s = t - int_double_t(-m) * this->mod_;
107 int_type r = s >> this->base_width_;
108 return s > t ? r + this->mod_ : r;
109 }
110
111 // (a+b)%N
112 constexpr int_type add(int_type a, int_type b) const {
113 int_type r = a - (this->mod_ - b);
114 return r > a ? r + this->mod_ : r;
115 }
116
117 // (a-b)%N
118 constexpr int_type sub(int_type a, int_type b) const {
119 int_type r = a - b;
120 return r > a ? r + this->mod_ : r;
121 }
122
123 // No-op
124 constexpr int_type shrink(int_type x) const { return x; }
125};
126
127template <typename UInt, UInt Mod, std::enable_if_t<std::is_unsigned_v<UInt>>* = nullptr>
128class StaticMontgomeryReductionContext {
129 public:
130 using int_type = UInt;
131 using mr_type =
132 std::conditional_t<Mod <= std::numeric_limits<int_type>::max() / 4, impl::MontgomeryReductionLoose<int_type>,
133 impl::MontgomeryReductionStrict<int_type>>;
134 static_assert(Mod % 2 == 1);
135
136 static constexpr const mr_type& montgomery_reduction() { return reduction_; }
137
138 private:
139 static constexpr auto reduction_ = mr_type(Mod);
140};
141
142template <typename UInt, bool Loose, std::enable_if_t<std::is_unsigned_v<UInt>>* = nullptr>
143class DynamicMontgomeryReductionContext {
144 public:
145 using int_type = UInt;
146 using mr_type =
147 std::conditional_t<Loose, impl::MontgomeryReductionLoose<int_type>, impl::MontgomeryReductionStrict<int_type>>;
148
149 static constexpr const mr_type& montgomery_reduction() { return reduction_env_.back(); }
150
151 static void push_mod(int_type mod) {
152 assert(mod % 2 == 1);
153 if constexpr (Loose) {
154 assert(mod <= std::numeric_limits<int_type>::max() / 4);
155 }
156 reduction_env_.emplace_back(mod);
157 }
158
159 static void pop_mod() { reduction_env_.pop_back(); }
160
161 private:
162 static inline std::vector<mr_type> reduction_env_;
163};
164
165} // namespace impl
166
190template <typename Context>
192 struct Guard;
193
194 public:
195 using mint = MontgomeryModInt;
196 using int_type = typename Context::int_type;
197 using mr_type = typename Context::mr_type;
198 using int_double_t = typename mr_type::int_double_t;
199
200 MontgomeryModInt() : val_(0) {}
201
208 template <typename T, std::enable_if_t<std::is_integral_v<T> && std::is_signed_v<T>>* = nullptr>
209 explicit MontgomeryModInt(T x) {
210 auto r = x % impl::make_double_width_t<std::make_signed_t<int_type>>(mr().mod());
211 if (r < 0) {
212 r += mr().mod();
213 }
214 val_ = mr().mul(mr().mbase2(), r);
215 }
216
218 template <typename T, std::enable_if_t<std::is_unsigned_v<T>>* = nullptr>
219 explicit MontgomeryModInt(T x) {
220 val_ = mr().mul(mr().mbase2(), x % mr().mod());
221 }
222
234 [[nodiscard]] static Guard set_mod_guard(int_type mod) {
235 Context::push_mod(mod);
236 return Guard();
237 }
238
240 int_type val() const { return mr().shrink(mr().mul(1, val_)); }
241
247 int_type residue() const { return mr().shrink(val_); }
248
250 static constexpr int_type mod() { return mr().mod(); }
251
252 mint& operator++() {
253 val_ = mr().add(val_, mr().mbase());
254 return *this;
255 }
256
257 mint operator++(int) {
258 mint ret = *this;
259 ++(*this);
260 return ret;
261 }
262
263 mint operator+() const { return *this; }
264
265 mint operator+(const mint& rhs) const { return from_raw(mr().add(val_, rhs.val_)); }
266
267 mint& operator+=(const mint& rhs) { return *this = *this + rhs; }
268
269 mint& operator--() {
270 val_ = mr().sub(val_, mr().mbase());
271 return *this;
272 }
273
274 mint operator--(int) {
275 mint ret = *this;
276 --(*this);
277 return ret;
278 }
279
280 mint operator-() const { return from_raw(mr().sub(0, val_)); }
281
282 mint operator-(const mint& rhs) const { return from_raw(mr().sub(val_, rhs.val_)); }
283
284 mint& operator-=(const mint& rhs) { return *this = *this - rhs; }
285
286 mint operator*(const mint& rhs) const { return from_raw(mr().mul(val_, rhs.val_)); }
287
288 mint& operator*=(const mint& rhs) { return *this = *this * rhs; }
289
295 mint inv() const { return from_raw(mr().mul(mr().mbase3(), mod_inverse(val_, mr().mod()))); }
296
297 mint operator/(const mint& rhs) const { return *this * rhs.inv(); }
298
299 mint& operator/=(const mint& rhs) { return *this *= rhs.inv(); }
300
301 bool operator==(const mint& rhs) const { return residue() == rhs.residue(); }
302
303 bool operator!=(const mint& rhs) const { return !(*this == rhs); }
304
305 private:
306 int_type val_;
307
308 struct Guard {
309 ~Guard() { Context::pop_mod(); }
310 };
311
312 static constexpr mint from_raw(int_type x) {
313 mint ret;
314 ret.val_ = x;
315 return ret;
316 }
317
318 static constexpr const mr_type& mr() { return Context::montgomery_reduction(); }
319};
320
326template <uint32_t Mod>
328
334template <uint64_t Mod>
336
342
348
354
360
370template <typename Visitor, typename UInt, typename... Args>
371auto mmint_by_modulus(Visitor&& visitor, UInt mod, Args&&... args) {
372 if (mod <= std::numeric_limits<UInt>::max() / 4) {
374 auto _guard = mint::set_mod_guard(mod);
375 if constexpr (sizeof...(args) == 0) {
376 return visitor(mint());
377 } else {
378 return visitor(mint(args)...);
379 }
380 } else {
382 auto _guard = mint::set_mod_guard(mod);
383 if constexpr (sizeof...(args) == 0) {
384 return visitor(mint());
385 } else {
386 return visitor(mint(args)...);
387 }
388 }
389};
390
391} // namespace cplib
Modular integer stored in Montgomery form.
Definition: mmint.hpp:191
auto mmint_by_modulus(Visitor &&visitor, UInt mod, Args &&... args)
Given a modulus, calls a callable (visitor) with the fastest MontgomeryModInt type that supports that modulus, selected at run time.
Definition: mmint.hpp:371
static constexpr int_type mod()
Returns the modulus.
Definition: mmint.hpp:250
int_type val() const
Converts back to a plain integer in the range [0, N), where N is the modulus.
Definition: mmint.hpp:240
MontgomeryModInt(T x)
Converts a plain integer to a Montgomery modular integer.
Definition: mmint.hpp:209
mint inv() const
Returns the modular inverse.
Definition: mmint.hpp:295
static Guard set_mod_guard(int_type mod)
Sets the dynamic modint's modulus; the previous modulus is restored when the returned guard is destroyed.
Definition: mmint.hpp:234
int_type residue() const
Returns a number that is the same for the same residue class modulo the modulus.
Definition: mmint.hpp:247
constexpr T mod_inverse(T x, T m)
Modular inverse.
Definition: gcd.hpp:102