8#include "cplib/num/gcd.hpp"
9#include "cplib/utils/type.hpp"
15template <
typename UInt>
16class MontgomeryReductionBase {
18 using int_type = UInt;
19 using int_double_t = make_double_width_t<int_type>;
21 constexpr explicit MontgomeryReductionBase(int_type mod)
23 mod_neg_inv_(inv_base(-mod)),
24 mbase_((int_double_t(1) << base_width_) % mod),
25 mbase2_(int_double_t(mbase_) * mbase_ % mod),
26 mbase3_(int_double_t(mbase2_) * mbase_ % mod) {}
/// Returns the stored modulus `mod_`.
29 constexpr int_type mod()
const {
return mod_; }
/// Returns `mbase_`, which the constructor initializes to
/// `(int_double_t(1) << base_width_) % mod`, i.e. 2^base_width_ reduced
/// modulo the modulus (base_width_ is the number of digits of int_type).
32 constexpr int_type mbase()
const {
return mbase_; }
/// Returns `mbase2_`, which the constructor initializes to
/// `mbase_ * mbase_ % mod` (computed in the double-width type).
35 constexpr int_type mbase2()
const {
return mbase2_; }
/// Returns `mbase3_`, which the constructor initializes to
/// `mbase2_ * mbase_ % mod` (computed in the double-width type).
38 constexpr int_type mbase3()
const {
return mbase3_; }
41 int_type mod_, mod_neg_inv_, mbase_, mbase2_, mbase3_;
42 static constexpr int base_width_ = std::numeric_limits<int_type>::digits;
46 static constexpr int_type inv_base(int_type x) {
48 for (
int i = 1; i < base_width_; i *= 2) {
49 y *= int_type(2) - x * y;
56template <
typename UInt>
57class MontgomeryReductionLoose :
public MontgomeryReductionBase<UInt> {
59 using Base = MontgomeryReductionBase<UInt>;
61 using typename Base::int_double_t;
62 using typename Base::int_type;
65 constexpr int_type mul(int_type a, int_type b)
const {
66 int_double_t t = int_double_t(a) * b;
67 int_type m = int_type(t) * this->mod_neg_inv_;
68 int_type r = (t + int_double_t(m) * this->mod_) >> this->base_width_;
73 constexpr int_type add(int_type a, int_type b)
const {
75 return r >= this->mod_ * 2 ? r - this->mod_ * 2 : r;
79 constexpr int_type sub(int_type a, int_type b)
const {
81 return r > a ? r + this->mod_ * 2 : r;
/// Subtracts the modulus once when `x >= mod_`; otherwise returns `x` as is.
/// For any `x` below `2 * mod_` the result is therefore below `mod_`.
/// NOTE(review): presumably the loose variant keeps values below 2 * mod_,
/// which is what makes a single subtraction sufficient - confirm.
85 constexpr int_type shrink(int_type x)
const {
return x >= this->mod_ ? x - this->mod_ : x; }
89template <
typename UInt>
90class MontgomeryReductionStrict :
public MontgomeryReductionBase<UInt> {
92 using Base = MontgomeryReductionBase<UInt>;
94 using typename Base::int_double_t;
95 using typename Base::int_type;
103 constexpr int_type mul(int_type a, int_type b)
const {
104 int_double_t t = int_double_t(a) * b;
105 int_type m = int_type(t) * this->mod_neg_inv_;
106 int_double_t s = t - int_double_t(-m) * this->mod_;
107 int_type r = s >> this->base_width_;
108 return s > t ? r + this->mod_ : r;
112 constexpr int_type add(int_type a, int_type b)
const {
113 int_type r = a - (this->mod_ - b);
114 return r > a ? r + this->mod_ : r;
118 constexpr int_type sub(int_type a, int_type b)
const {
120 return r > a ? r + this->mod_ : r;
/// Returns `x` unchanged.
/// NOTE(review): presumably the strict variant already keeps every value
/// below the modulus, so no final reduction is needed - confirm.
124 constexpr int_type shrink(int_type x)
const {
return x; }
127template <
typename UInt, UInt Mod, std::enable_if_t<std::is_
unsigned_v<UInt>>* =
nullptr>
128class StaticMontgomeryReductionContext {
130 using int_type = UInt;
132 std::conditional_t<Mod <= std::numeric_limits<int_type>::max() / 4, impl::MontgomeryReductionLoose<int_type>,
133 impl::MontgomeryReductionStrict<int_type>>;
134 static_assert(Mod % 2 == 1);
/// Returns a reference to the class-wide `reduction_` object, which is
/// constructed from the compile-time modulus `Mod`.
136 static constexpr const mr_type& montgomery_reduction() {
return reduction_; }
139 static constexpr auto reduction_ = mr_type(Mod);
142template <
typename UInt,
bool Loose, std::enable_if_t<std::is_
unsigned_v<UInt>>* =
nullptr>
143class DynamicMontgomeryReductionContext {
145 using int_type = UInt;
147 std::conditional_t<Loose, impl::MontgomeryReductionLoose<int_type>, impl::MontgomeryReductionStrict<int_type>>;
/// Returns a reference to the last element of `reduction_env_`, i.e. the
/// reduction for the most recently pushed modulus.
/// `std::vector::back()` on an empty vector is undefined behaviour, so a
/// modulus must have been pushed before this is called.
149 static constexpr const mr_type& montgomery_reduction() {
return reduction_env_.back(); }
151 static void push_mod(int_type mod) {
152 assert(mod % 2 == 1);
153 if constexpr (Loose) {
154 assert(mod <= std::numeric_limits<int_type>::max() / 4);
156 reduction_env_.emplace_back(mod);
/// Removes the last element of `reduction_env_`.
/// `std::vector::pop_back()` on an empty vector is undefined behaviour, so
/// each call must be matched by an earlier push.
159 static void pop_mod() { reduction_env_.pop_back(); }
162 static inline std::vector<mr_type> reduction_env_;
190template <
typename Context>
196 using int_type =
typename Context::int_type;
197 using mr_type =
typename Context::mr_type;
198 using int_double_t =
typename mr_type::int_double_t;
208 template <
typename T, std::enable_if_t<std::is_
integral_v<T> && std::is_
signed_v<T>>* =
nullptr>
210 auto r = x % impl::make_double_width_t<std::make_signed_t<int_type>>(mr().mod());
214 val_ = mr().mul(mr().mbase2(), r);
218 template <
typename T, std::enable_if_t<std::is_
unsigned_v<T>>* =
nullptr>
220 val_ = mr().mul(mr().mbase2(), x % mr().
mod());
235 Context::push_mod(
mod);
/// Converts back to a plain integer: multiplies the stored value by 1 via
/// the reduction's `mul()` and then applies `shrink()` to the product.
240 int_type
val()
const {
return mr().shrink(mr().mul(1, val_)); }
/// Returns `shrink(val_)`: the stored value after the reduction's final
/// shrink step, without converting it through `mul()`.
/// Used by `operator==` to compare two values.
247 int_type
residue()
const {
return mr().shrink(val_); }
/// Returns the modulus of the reduction currently supplied by the context.
250 static constexpr int_type
mod() {
return mr().mod(); }
253 val_ = mr().add(val_, mr().mbase());
257 mint operator++(
int) {
/// Unary plus: returns a copy of this value.
263 mint operator+()
const {
return *
this; }
/// Addition: passes both stored values to the reduction's `add()` and wraps
/// the result with `from_raw()`.
265 mint operator+(
const mint& rhs)
const {
return from_raw(mr().add(val_, rhs.val_)); }
/// Compound addition: assigns `*this + rhs` to `*this` and returns `*this`.
267 mint& operator+=(
const mint& rhs) {
return *
this = *
this + rhs; }
270 val_ = mr().sub(val_, mr().mbase());
274 mint operator--(
int) {
/// Negation: computes `sub(0, val_)` with the reduction and wraps the result
/// with `from_raw()`.
280 mint operator-()
const {
return from_raw(mr().sub(0, val_)); }
/// Subtraction: passes both stored values to the reduction's `sub()` and
/// wraps the result with `from_raw()`.
282 mint operator-(
const mint& rhs)
const {
return from_raw(mr().sub(val_, rhs.val_)); }
/// Compound subtraction: assigns `*this - rhs` to `*this` and returns `*this`.
284 mint& operator-=(
const mint& rhs) {
return *
this = *
this - rhs; }
/// Multiplication: passes both stored values to the reduction's `mul()` and
/// wraps the result with `from_raw()`.
286 mint operator*(
const mint& rhs)
const {
return from_raw(mr().mul(val_, rhs.val_)); }
/// Compound multiplication: assigns `*this * rhs` to `*this` and returns
/// `*this`.
288 mint& operator*=(
const mint& rhs) {
return *
this = *
this * rhs; }
/// Division: multiplies this value by `rhs.inv()`, the modular inverse of
/// `rhs`.
297 mint operator/(
const mint& rhs)
const {
return *
this * rhs.inv(); }
/// Compound division: multiplies `*this` in place by `rhs.inv()` and returns
/// `*this`.
299 mint& operator/=(
const mint& rhs) {
return *
this *= rhs.inv(); }
/// Equality: compares the `residue()` of both operands rather than the raw
/// stored values.
301 bool operator==(
const mint& rhs)
const {
return residue() == rhs.residue(); }
/// Inequality: the logical negation of `operator==`.
303 bool operator!=(
const mint& rhs)
const {
return !(*
this == rhs); }
/// On destruction, calls `Context::pop_mod()`.
/// NOTE(review): presumably this undoes the push made when the guard was
/// created by set_mod_guard() - confirm against the missing lines.
309 ~Guard() { Context::pop_mod(); }
312 static constexpr mint from_raw(int_type x) {
/// Shorthand for the reduction object currently supplied by `Context`.
318 static constexpr const mr_type& mr() {
return Context::montgomery_reduction(); }
326template <u
int32_t Mod>
334template <u
int64_t Mod>
370template <
typename Visitor,
typename UInt,
typename... Args>
372 if (mod <= std::numeric_limits<UInt>::max() / 4) {
374 auto _guard = mint::set_mod_guard(mod);
375 if constexpr (
sizeof...(args) == 0) {
376 return visitor(
mint());
378 return visitor(
mint(args)...);
382 auto _guard = mint::set_mod_guard(mod);
383 if constexpr (
sizeof...(args) == 0) {
384 return visitor(
mint());
386 return visitor(
mint(args)...);
Modular integer stored in Montgomery form.
Definition: mmint.hpp:191
auto mmint_by_modulus(Visitor &&visitor, UInt mod, Args &&... args)
Given a modulus, calls a callable (visitor) with a dynamically selected fastest MontgomeryModInt type...
Definition: mmint.hpp:371
static constexpr int_type mod()
Returns the modulus.
Definition: mmint.hpp:250
int_type val() const
Converts back to a plain integer in the range $[0, \text{mod})$.
Definition: mmint.hpp:240
MontgomeryModInt(T x)
Converts a plain integer to a Montgomery modular integer.
Definition: mmint.hpp:209
mint inv() const
Returns the modular inverse.
Definition: mmint.hpp:295
static Guard set_mod_guard(int_type mod)
Sets the dynamic modint's modulus.
Definition: mmint.hpp:234
int_type residue() const
Returns a number that is the same for the same residue class modulo the modulus.
Definition: mmint.hpp:247
constexpr T mod_inverse(T x, T m)
Modular inverse.
Definition: gcd.hpp:102