8#include "cplib/num/gcd.hpp"
9#include "cplib/port/bit.hpp"
10#include "cplib/utils/type.hpp"
16template <
typename UInt>
17constexpr UInt barrett_mulh(UInt a, UInt b) {
18 return make_double_width_t<UInt>(a) * b >> std::numeric_limits<UInt>::digits;
#ifdef __SIZEOF_INT128__
/// 128-bit overload of barrett_mulh: returns the high 128 bits of the 256-bit
/// product a * b, i.e. floor(a * b / 2^128).
///
/// No 256-bit type exists, so the product is assembled from four
/// 64x64 -> 128-bit partial products. The two cross terms are NOT summed
/// directly: together they can exceed 2^128 and would silently drop a carry.
/// Instead, their low halves are accumulated separately in `mid`.
inline constexpr unsigned __int128 barrett_mulh(unsigned __int128 a, unsigned __int128 b) {
  using u128 = unsigned __int128;
  const uint64_t al = static_cast<uint64_t>(a);
  const uint64_t ah = static_cast<uint64_t>(a >> 64);
  const uint64_t bl = static_cast<uint64_t>(b);
  const uint64_t bh = static_cast<uint64_t>(b >> 64);
  const u128 ll = u128(al) * bl;  // weight 2^0
  const u128 lh = u128(al) * bh;  // weight 2^64
  const u128 hl = u128(ah) * bl;  // weight 2^64
  const u128 hh = u128(ah) * bh;  // weight 2^128
  // All bits of weight 2^64 that are not already in a high half: three terms
  // each < 2^64, so the sum is < 3 * 2^64 and cannot overflow.
  const u128 mid = (ll >> 64) + static_cast<uint64_t>(lh) + static_cast<uint64_t>(hl);
  // The true result is < 2^128, so this final sum cannot overflow either.
  return hh + (lh >> 64) + (hl >> 64) + (mid >> 64);
}
#endif
32template <
typename UInt>
33class BarrettReduction {
35 using int_type = UInt;
36 using int_double_t = make_double_width_t<int_type>;
38 constexpr explicit BarrettReduction(int_type mod)
39 : mod_(mod), red_(int_double_t(-1) / mod + port::has_single_bit(mod)) {}
41 constexpr int_type mod()
const {
return mod_; }
43 constexpr int_type mul(int_type a, int_type b)
const {
44 int_double_t m = int_double_t(a) * b;
45 int_double_t r = m - barrett_mulh(m, red_) * mod_;
46 return r >= mod_ ? r - mod_ : r;
49 constexpr int_type add(int_type a, int_type b)
const {
50 int_type r = a - (mod_ - b);
51 return r > a ? r + mod_ : r;
54 constexpr int_type sub(int_type a, int_type b)
const {
56 return r > a ? r + mod_ : r;
64template <
typename UInt, UInt Mod, std::enable_if_t<std::is_
unsigned_v<UInt>>* =
nullptr>
65class StaticBarrettReductionContext {
67 using int_type = UInt;
68 using br_type = BarrettReduction<int_type>;
69 static_assert(Mod > 1);
71 static constexpr const br_type& barrett_reduction() {
return reduction_; }
74 static constexpr auto reduction_ = br_type(Mod);
77template <
typename UInt, std::enable_if_t<std::is_
unsigned_v<UInt>>* =
nullptr>
78class DynamicBarrettReductionContext {
80 using int_type = UInt;
81 using br_type = BarrettReduction<int_type>;
83 static constexpr const br_type& barrett_reduction() {
return reduction_env_.back(); }
85 static void push_mod(int_type mod) {
87 reduction_env_.emplace_back(mod);
90 static void pop_mod() { reduction_env_.pop_back(); }
93 static inline std::vector<br_type> reduction_env_;
// NOTE(review): fragment of the BarrettModInt<Context> class template as
// rendered by a Doxygen source view. The class declaration line, several
// constructor signatures and many member bodies are missing from this
// listing, so the comments below describe only the visible statements.
// Context is expected to expose int_type, br_type and a static
// barrett_reduction() returning the current BarrettReduction (used below).
110template <
typename Context>
// Types forwarded from Context: the stored integer type, the reduction type,
// and the reduction's double-width intermediate type.
116 using int_type =
typename Context::int_type;
117 using br_type =
typename Context::br_type;
118 using int_double_t =
typename br_type::int_double_t;
// Constructor from a signed integer T (BarrettModInt(T x), bmint.hpp:129).
// It takes x % mod() in the signed double-width counterpart of int_type,
// presumably so that mod() fits in a signed type. The remainder can be
// negative; its normalization is not visible here (TODO confirm).
128 template <
typename T, std::enable_if_t<std::is_
integral_v<T> && std::is_
signed_v<T>>* =
nullptr>
130 auto r = x % impl::make_double_width_t<std::make_signed_t<int_type>>(br().mod());
// Constructor from an unsigned integer T: stores x reduced modulo mod().
138 template <
typename T, std::enable_if_t<std::is_
unsigned_v<T>>* =
nullptr>
140 val_ = x % br().mod();
// Presumably the body of set_mod_guard (indexed at bmint.hpp:154): pushes
// `mod` onto the context's modulus stack; ~Guard() below pops it again.
155 Context::push_mod(
mod);
// val(): returns the stored representative val_.
160 int_type
val()
const {
return val_; }
// mod(): the current modulus, read from the reduction object.
170 static constexpr int_type
mod() {
return br().mod(); }
// NOTE(review): fragment, presumably of prefix operator++; checks whether
// val_ has reached mod() (wrap-around). Surrounding lines are not visible.
174 if (val_ == br().
mod()) {
180 mint operator++(
int) {
// Unary plus returns a copy; addition delegates to BarrettReduction::add.
186 mint operator+()
const {
return *
this; }
188 mint operator+(
const mint& rhs)
const {
return from_raw(br().add(val_, rhs.val_)); }
190 mint& operator+=(
const mint& rhs) {
return *
this = *
this + rhs; }
// NOTE(review): fragment, presumably of prefix operator--; sets val_ to
// mod() - 1, i.e. the wrap-around value below zero.
194 val_ = br().mod() - 1;
201 mint operator--(
int) {
// Negation and subtraction delegate to BarrettReduction::sub.
207 mint operator-()
const {
return from_raw(br().sub(0, val_)); }
209 mint operator-(
const mint& rhs)
const {
return from_raw(br().sub(val_, rhs.val_)); }
211 mint& operator-=(
const mint& rhs) {
return *
this = *
this - rhs; }
// Multiplication delegates to BarrettReduction::mul.
213 mint operator*(
const mint& rhs)
const {
return from_raw(br().mul(val_, rhs.val_)); }
215 mint& operator*=(
const mint& rhs) {
return *
this = *
this * rhs; }
// Division multiplies by the modular inverse rhs.inv().
224 mint operator/(
const mint& rhs)
const {
return *
this * rhs.inv(); }
226 mint& operator/=(
const mint& rhs) {
return *
this *= rhs.inv(); }
// Equality compares residue() values rather than val_ directly.
228 bool operator==(
const mint& rhs)
const {
return residue() == rhs.residue(); }
230 bool operator!=(
const mint& rhs)
const {
return !(*
this == rhs); }
// Guard destructor: pops the context's modulus stack via Context::pop_mod().
236 ~Guard() { Context::pop_mod(); }
// from_raw: builds a mint from x. The body is not visible; presumably it stores
// x without reducing it, since callers pass results of add/sub/mul.
239 static constexpr mint from_raw(int_type x) {
// br(): the context's current BarrettReduction.
245 static constexpr const br_type& br() {
return Context::barrett_reduction(); }
// NOTE(review): only the template headers of two alias templates survive in
// this listing: one parameterized on a uint32_t modulus, one on a uint64_t
// modulus. The alias bodies are not visible.
253template <u
int32_t Mod>
261template <u
int64_t Mod>
Modular integer using Barrett reduction.
Definition: bmint.hpp:111
int_type val() const
Converts back to a plain integer in the range $[0, \mathrm{mod})$.
Definition: bmint.hpp:160
mint inv() const
Returns the modular inverse.
Definition: bmint.hpp:222
static Guard set_mod_guard(int_type mod)
Sets the dynamic modint's modulus; the previous modulus is restored when the returned Guard is destroyed.
Definition: bmint.hpp:154
static constexpr int_type mod()
Returns the modulus.
Definition: bmint.hpp:170
int_type residue() const
Returns a number that is the same for the same residue class modulo the modulus.
Definition: bmint.hpp:167
BarrettModInt(T x)
Converts a plain integer to a Barrett modular integer.
Definition: bmint.hpp:129
constexpr T mod_inverse(T x, T m)
Modular inverse.
Definition: gcd.hpp:102