diff --git a/src/internal_math.rs b/src/internal_math.rs index 515191c..cb76908 100644 --- a/src/internal_math.rs +++ b/src/internal_math.rs @@ -235,6 +235,41 @@ pub(crate) fn primitive_root(m: i32) -> i32 { // omitted // template constexpr int primitive_root = primitive_root_constexpr(m); +/// # Arguments +/// * `n` `n < 2^32` +/// * `m` `1 <= m < 2^32` +/// +/// # Returns +/// `sum_{i=0}^{n-1} floor((ai + b) / m) (mod 2^64)` +/* const */ +#[allow(clippy::many_single_char_names)] +pub(crate) fn floor_sum_unsigned(mut n: u64, mut m: u64, mut a: u64, mut b: u64) -> u64 { + let mut ans = 0; + loop { + if a >= m { + if n > 0 { + ans += n * (n - 1) / 2 * (a / m); + } + a %= m; + } + if b >= m { + ans += n * (b / m); + b %= m; + } + + let y_max = a * n + b; + if y_max < m { + break; + } + // y_max < m * (n + 1) + // floor(y_max / m) <= n + n = y_max / m; + b = y_max % m; + std::mem::swap(&mut m, &mut a); + } + ans +} + #[cfg(test)] mod tests { #![allow(clippy::unreadable_literal)] diff --git a/src/math.rs b/src/math.rs index 61f15d5..7f4bd1f 100644 --- a/src/math.rs +++ b/src/math.rs @@ -185,25 +185,22 @@ pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) { /// /// assert_eq!(math::floor_sum(6, 5, 4, 3), 13); /// ``` +#[allow(clippy::many_single_char_names)] pub fn floor_sum(n: i64, m: i64, mut a: i64, mut b: i64) -> i64 { + assert!((0..1i64 << 32).contains(&n)); + assert!((1..1i64 << 32).contains(&m)); let mut ans = 0; - if a >= m { - ans += (n - 1) * n * (a / m) / 2; - a %= m; - } - if b >= m { - ans += n * (b / m); - b %= m; + if a < 0 { + let a2 = internal_math::safe_mod(a, m); + ans -= n * (n - 1) / 2 * ((a2 - a) / m); + a = a2; } - - let y_max = (a * n + b) / m; - let x_max = y_max * m - b; - if y_max == 0 { - return ans; + if b < 0 { + let b2 = internal_math::safe_mod(b, m); + ans -= n * ((b2 - b) / m); + b = b2; } - ans += (n - (x_max + a - 1) / a) * y_max; - ans += floor_sum(y_max, a, m, (a - x_max % a) % a); - ans + ans + internal_math::floor_sum_unsigned(n as u64, m as u64, a as u64, b as u64) as i64 } #[cfg(test)] @@ -306,5 +303,24 @@ mod tests { 499_999_999_500_000_000 ); assert_eq!(floor_sum(332955, 5590132, 2231, 999423), 22014575); + for n in 0..20 { + for m in 1..20 { + for a in -20..20 { + for b in -20..20 { + assert_eq!(floor_sum(n, m, a, b), floor_sum_naive(n, m, a, b)); + } + } + } + } + } + + #[allow(clippy::many_single_char_names)] + fn floor_sum_naive(n: i64, m: i64, a: i64, b: i64) -> i64 { + let mut ans = 0; + for i in 0..n { + let z = a * i + b; + ans += (z - internal_math::safe_mod(z, m)) / m; + } + ans } }