cai_lw's competitive programming library
 
Loading...
Searching...
No Matches
bmint.hpp
1#pragma once
2
3#include <cstdint>
4#include <limits>
5#include <type_traits>
6#include <vector>
7
8#include "cplib/num/gcd.hpp"
9#include "cplib/port/bit.hpp"
10#include "cplib/utils/type.hpp"
11
12namespace cplib {
13
14namespace impl {
15
16template <typename UInt>
17constexpr UInt barrett_mulh(UInt a, UInt b) {
18 return make_double_width_t<UInt>(a) * b >> std::numeric_limits<UInt>::digits;
19}
20
21template <>
22inline constexpr unsigned __int128 barrett_mulh(unsigned __int128 a, unsigned __int128 b) {
23 uint64_t ah = a >> 64, al = a << 64 >> 64, bh = b >> 64, bl = b << 64 >> 64;
24 using u128 = unsigned __int128;
25 u128 l = u128(al) * bl;
26 u128 h = u128(ah) * bh;
27 u128 m = u128(al) * bh + u128(bl) * ah;
28 h += m >> 64;
29 return l + (m << 64) < l ? h + 1 : h;
30}
31
32template <typename UInt>
33class BarrettReduction {
34 public:
35 using int_type = UInt;
36 using int_double_t = make_double_width_t<int_type>;
37
38 constexpr explicit BarrettReduction(int_type mod)
39 : mod_(mod), red_(int_double_t(-1) / mod + port::has_single_bit(mod)) {}
40
41 constexpr int_type mod() const { return mod_; }
42
43 constexpr int_type mul(int_type a, int_type b) const {
44 int_double_t m = int_double_t(a) * b;
45 int_double_t r = m - barrett_mulh(m, red_) * mod_;
46 return r >= mod_ ? r - mod_ : r;
47 }
48
49 constexpr int_type add(int_type a, int_type b) const {
50 int_type r = a - (mod_ - b);
51 return r > a ? r + mod_ : r;
52 }
53
54 constexpr int_type sub(int_type a, int_type b) const {
55 int_type r = a - b;
56 return r > a ? r + mod_ : r;
57 }
58
59 private:
60 int_type mod_;
61 int_double_t red_;
62};
63
64template <typename UInt, UInt Mod, std::enable_if_t<std::is_unsigned_v<UInt>>* = nullptr>
65class StaticBarrettReductionContext {
66 public:
67 using int_type = UInt;
68 using br_type = BarrettReduction<int_type>;
69 static_assert(Mod > 1);
70
71 static constexpr const br_type& barrett_reduction() { return reduction_; }
72
73 private:
74 static constexpr auto reduction_ = br_type(Mod);
75};
76
77template <typename UInt, std::enable_if_t<std::is_unsigned_v<UInt>>* = nullptr>
78class DynamicBarrettReductionContext {
79 public:
80 using int_type = UInt;
81 using br_type = BarrettReduction<int_type>;
82
83 static constexpr const br_type& barrett_reduction() { return reduction_env_.back(); }
84
85 static void push_mod(int_type mod) {
86 assert(mod > 1);
87 reduction_env_.emplace_back(mod);
88 }
89
90 static void pop_mod() { reduction_env_.pop_back(); }
91
92 private:
93 static inline std::vector<br_type> reduction_env_;
94};
95
96} // namespace impl
97
110template <typename Context>
112 struct Guard;
113
114 public:
/// Shorthand for this modint type, used in the operator signatures below.
115 using mint = BarrettModInt;
/// Underlying unsigned integer type, taken from the context.
116 using int_type = typename Context::int_type;
/// Barrett reduction type supplied by the context.
117 using br_type = typename Context::br_type;
/// Double-width integer type of that reduction.
118 using int_double_t = typename br_type::int_double_t;
119
120 BarrettModInt() : val_(0) {}
121
128 template <typename T, std::enable_if_t<std::is_integral_v<T> && std::is_signed_v<T>>* = nullptr>
129 explicit BarrettModInt(T x) {
130 auto r = x % impl::make_double_width_t<std::make_signed_t<int_type>>(br().mod());
131 if (r < 0) {
132 r += br().mod();
133 }
134 val_ = r;
135 }
136
138 template <typename T, std::enable_if_t<std::is_unsigned_v<T>>* = nullptr>
139 explicit BarrettModInt(T x) {
140 val_ = x % br().mod();
141 }
142
154 [[nodiscard]] static Guard set_mod_guard(int_type mod) {
155 Context::push_mod(mod);
156 return Guard();
157 }
158
160 int_type val() const { return val_; }
161
167 int_type residue() const { return val_; }
168
170 static constexpr int_type mod() { return br().mod(); }
171
172 mint& operator++() {
173 val_++;
174 if (val_ == br().mod()) {
175 val_ = 0;
176 }
177 return *this;
178 }
179
180 mint operator++(int) {
181 mint ret = *this;
182 ++(*this);
183 return ret;
184 }
185
186 mint operator+() const { return *this; }
187
188 mint operator+(const mint& rhs) const { return from_raw(br().add(val_, rhs.val_)); }
189
190 mint& operator+=(const mint& rhs) { return *this = *this + rhs; }
191
192 mint& operator--() {
193 if (val_ == 0) {
194 val_ = br().mod() - 1;
195 } else {
196 val_--;
197 }
198 return *this;
199 }
200
201 mint operator--(int) {
202 mint ret = *this;
203 --(*this);
204 return ret;
205 }
206
207 mint operator-() const { return from_raw(br().sub(0, val_)); }
208
209 mint operator-(const mint& rhs) const { return from_raw(br().sub(val_, rhs.val_)); }
210
211 mint& operator-=(const mint& rhs) { return *this = *this - rhs; }
212
213 mint operator*(const mint& rhs) const { return from_raw(br().mul(val_, rhs.val_)); }
214
215 mint& operator*=(const mint& rhs) { return *this = *this * rhs; }
216
222 mint inv() const { return from_raw(mod_inverse(val_, br().mod())); }
223
224 mint operator/(const mint& rhs) const { return *this * rhs.inv(); }
225
226 mint& operator/=(const mint& rhs) { return *this *= rhs.inv(); }
227
228 bool operator==(const mint& rhs) const { return residue() == rhs.residue(); }
229
230 bool operator!=(const mint& rhs) const { return !(*this == rhs); }
231
232 private:
/// Stored residue, expected to lie in [0, mod()); from_raw() stores its
/// argument without reducing it.
233 int_type val_;
234
235 struct Guard {
236 ~Guard() { Context::pop_mod(); }
237 };
238
239 static constexpr mint from_raw(int_type x) {
240 mint ret;
241 ret.val_ = x;
242 return ret;
243 }
244
245 static constexpr const br_type& br() { return Context::barrett_reduction(); }
246};
247
253template <uint32_t Mod>
255
261template <uint64_t Mod>
263
269
275
276} // namespace cplib
Modular integer using Barrett reduction.
Definition: bmint.hpp:111
int_type val() const
Converts back to a plain integer in the range $[0, \mathrm{mod})$.
Definition: bmint.hpp:160
mint inv() const
Returns the modular inverse.
Definition: bmint.hpp:222
static Guard set_mod_guard(int_type mod)
Set the dynamic modint's modulus.
Definition: bmint.hpp:154
static constexpr int_type mod()
Returns the modulus.
Definition: bmint.hpp:170
int_type residue() const
Returns a number that is the same for the same residue class modulo the modulus.
Definition: bmint.hpp:167
BarrettModInt(T x)
Converts a plain integer to a Barrett modular integer.
Definition: bmint.hpp:129
constexpr T mod_inverse(T x, T m)
Modular inverse.
Definition: gcd.hpp:102