diff --git a/src/modint.rs b/src/modint.rs index 022263c..38ffde6 100644 --- a/src/modint.rs +++ b/src/modint.rs @@ -793,20 +793,21 @@ trait InternalImplementations: ModIntBase { #[inline] fn add_impl(lhs: Self, rhs: Self) -> Self { let modulus = Self::modulus(); - let mut val = lhs.val() + rhs.val(); - if val >= modulus { - val -= modulus; - } + let v = u64::from(lhs.val()) + u64::from(rhs.val()); + let val = match v.overflowing_sub(u64::from(modulus)) { + (_, true) => v as u32, + (w, false) => w as u32, + }; Self::raw(val) } #[inline] fn sub_impl(lhs: Self, rhs: Self) -> Self { let modulus = Self::modulus(); - let mut val = lhs.val().wrapping_sub(rhs.val()); - if val >= modulus { - val = val.wrapping_add(modulus) - } + let val = match lhs.val().overflowing_sub(rhs.val()) { + (v, true) => v.wrapping_add(modulus), + (v, false) => v, + }; Self::raw(val) } @@ -1050,6 +1051,8 @@ impl_folding! { #[cfg(test)] mod tests { + #![allow(clippy::unreadable_literal)] + use crate::modint::ModInt; use crate::modint::ModInt1000000007; #[test] @@ -1157,4 +1160,29 @@ mod tests { c /= b; assert_eq!(expected, c); } + + // test `2^31 < modulus < 2^32` case + // https://github.com/rust-lang-ja/ac-library-rs/issues/111 + #[test] + fn dynamic_modint_m32() { + let m = 3221225471; + ModInt::set_modulus(m); + let f = ModInt::new::; + assert_eq!(f(1398188832) + f(3184083880), f(1361047241)); + assert_eq!(f(3013899062) + f(2238406135), f(2031079726)); + assert_eq!(f(2699997885) + f(2745140255), f(2223912669)); + assert_eq!(f(2824399978) + f(2531872141), f(2135046648)); + assert_eq!(f(36496612) - f(2039504668), f(1218217415)); + assert_eq!(f(266176802) - f(1609833977), f(1877568296)); + assert_eq!(f(713535382) - f(2153383999), f(1781376854)); + assert_eq!(f(1249965147) - f(3144251805), f(1326938813)); + assert_eq!(f(2692223381) * f(2935379475), f(2084179397)); + assert_eq!(f(2800462205) * f(2822998916), f(2089431198)); + assert_eq!(f(3061947734) * f(3210920667), f(1962208034)); + assert_eq!(f(3138997926) * f(2994465129), f(1772479317)); + assert_eq!(f(2947552629) / f(576466398), f(2041593039)); + assert_eq!(f(2914694891) / f(399734126), f(1983162347)); + assert_eq!(f(2202862138) / f(1154428799), f(2139936238)); + assert_eq!(f(3037207894) / f(2865447143), f(1894581230)); + } }