Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mul_mod for 2^31 < m < 2^32 #111

Open
mizar opened this issue Jan 14, 2023 · 2 comments · May be fixed by #112
Open

mul_mod for 2^31 < m < 2^32 #111

mizar opened this issue Jan 14, 2023 · 2 comments · May be fixed by #112

Comments

@mizar
Copy link
Collaborator

mizar commented Jan 14, 2023

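mul_mod can easily be adapted to the case 2^31 < m < 2^32 simply by improving the borrow check of the final subtraction: detect the borrow of the 64-bit subtraction itself instead of comparing the truncated 32-bit result against m.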

(current code: ac-library)

https://github.com/atcoder/ac-library/blob/6c88a70c8f95fef575af354900d107fbd0db1a12/atcoder/internal_math.hpp#L22-L62

(current code: ac-library-rs)

/// Fast modular multiplication by Barrett reduction.
/// Reference: https://en.wikipedia.org/wiki/Barrett_reduction
/// NOTE: reconsider after Ice Lake
pub(crate) struct Barrett {
    pub(crate) _m: u32,
    pub(crate) im: u64,
}
impl Barrett {
    /// # Arguments
    /// * `m` `1 <= m < 2^32`
    ///
    /// (Note: earlier versions of [`mul_mod`] additionally required `m <= 2^31`, which was
    /// undocumented in the original library; see the
    /// [pull request comment](https://github.com/rust-lang-ja/ac-library-rs/pull/3#discussion_r484661007).
    /// [`mul_mod`] now checks the borrow of its final 64-bit subtraction directly, which
    /// lifts that restriction.)
    ///
    /// # Panics
    /// Panics if `m == 0` (integer division by zero).
    pub(crate) fn new(m: u32) -> Barrett {
        Barrett {
            _m: m,
            // im = ceil(2^64 / m) = floor((2^64 - 1) / m) + 1, wrapped to 0 when m = 1.
            im: (u64::MAX / m as u64).wrapping_add(1),
        }
    }
    /// # Returns
    /// `m`
    pub(crate) fn umod(&self) -> u32 {
        self._m
    }
    /// # Parameters
    /// * `a` `0 <= a < m`
    /// * `b` `0 <= b < m`
    ///
    /// # Returns
    /// a * b % m
    #[allow(clippy::many_single_char_names)]
    pub(crate) fn mul(&self, a: u32, b: u32) -> u32 {
        mul_mod(a, b, self._m, self.im)
    }
}
/// Calculates `a * b % m`.
///
/// * `a` `0 <= a < m`
/// * `b` `0 <= b < m`
/// * `m` `1 <= m < 2^32`
/// * `im` = ceil(2^64 / `m`) = floor((2^64 - 1) / `m`) + 1, taken mod 2^64 (so `0` when `m = 1`)
#[allow(clippy::many_single_char_names)]
pub(crate) fn mul_mod(a: u32, b: u32, m: u32, im: u64) -> u32 {
    // [1] m = 1
    // a = b = im = 0, so okay
    // [2] m >= 2
    // im = ceil(2^64 / m)
    // -> im * m = 2^64 + r (0 <= r < m)
    // let z = a*b = c*m + d (0 <= c, d < m)
    // a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im
    // c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2
    //   (the last step holds for every m < 2^32, because m * (m + 1) <= 2^64 - 2^32)
    // ((ab * im) >> 64) == c or c + 1
    let z = (a as u64) * (b as u64);
    let x = (((z as u128) * (im as u128)) >> 64) as u64;
    // x * m <= (c + 1) * m <= m^2 < 2^64, so the product itself never overflows.
    // z - x*m is d (x = c, no borrow) or d - m (x = c + 1, borrow).
    // Read the borrow of the 64-bit subtraction itself: the former `m <= v` test
    // on the truncated u32 value only detected the borrow when m <= 2^31.
    let (v, borrow) = z.overflowing_sub(x.wrapping_mul(m as u64));
    let v = v as u32;
    if borrow {
        // v == (d - m) mod 2^32, so adding m back (mod 2^32) yields d.
        v.wrapping_add(m)
    } else {
        v
    }
}

  • $m\in\text{自然数(natural number)}\mathbb{N},\quad 1\le m\lt 2^{32}$
  • $\lbrace a,b\rbrace\in\text{整数(integer)}\mathbb{Z},\quad 0\le \lbrace a, b\rbrace\lt m$
  • $\displaystyle\bar{m'} = \left(\left\lfloor\frac{2^{64}-1}{m}\right\rfloor+1\right)\bmod 2^{64}=\left\lceil\frac{2^{64}}{m}\right\rceil\bmod 2^{64}$
  • $\displaystyle x=\left\lfloor\frac{ab\bar{m'}}{2^{64}}\right\rfloor$
  • $ab\mod m=ab-xm\quad(ab\ge xm)$
  • $ab\mod m=ab-xm+m\quad(ab\lt xm)$

(proof)

  1. when $m=1$, $a=b=\bar{m'}=0$, so okay
  2. when $2\le m\lt 2^{32}$,
    • $2^{32}+2=\left\lceil\frac{2^{64}}{2^{32}-1}\right\rceil\le\bar{m'}=\left\lceil\frac{2^{64}}{m}\right\rceil\le \left\lceil\frac{2^{64}}{2}\right\rceil=2^{63}$
    • $\bar{m'}\hspace{.1em}m=2^{64}+r\quad(0\le r\lt m)$
    • $z = ab = cm + d\quad(0\le\lbrace c,d\rbrace\lt m)$
    • $z\hspace{.1em}\bar{m'}=ab\hspace{.1em}\bar{m'}=(cm+d)\hspace{.1em}\bar{m'}=c(\bar{m'}\hspace{.1em}m)+d\hspace{.1em}\bar{m'}=2^{64}c+c\hspace{.1em}r+d\hspace{.1em}\bar{m'}$
    • $2^{64}c\le z\hspace{.1em}\bar{m'}\lt 2^{64}(c+2)$
      • $z\hspace{.1em}\bar{m'}=2^{64}c+c\hspace{.1em}r+d\hspace{.1em}\bar{m'}$
      • $0\le c\hspace{.1em}r\le (m-1)^2\le(2^{32}-2)^2=2^{64}-2^{34}+4$
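      • $0\le d\hspace{.1em}\bar{m'}\le\bar{m'}\hspace{.1em}(m-1)=2^{64}+r-\bar{m'}\le 2^{64}+(2^{32}-2)-(2^{32}+2)=2^{64}-4$
      • hence $0\le c\hspace{.1em}r+d\hspace{.1em}\bar{m'}\le 2^{65}-2^{34}\lt 2^{65}$, which gives the upper bound $z\hspace{.1em}\bar{m'}\lt 2^{64}(c+2)$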
    • $x=\left\lfloor\frac{ab\hspace{.1em}\bar{m'}}{2^{64}}\right\rfloor=\lbrace c$ or $(c+1)\rbrace$
    • $z-xm=ab-\left\lfloor\frac{ab\hspace{.1em}\bar{m'}}{2^{64}}\right\rfloor m=\lbrace d$ or $(d-m)\rbrace$

(C++: $1\le m\lt 2^{32}$ draft code)

https://godbolt.org/z/9Gz1oGrTa

#ifdef _MSC_VER
#include <intrin.h>
#endif

// Barrett-reduced modular product (pre-fix variant).
// @param a   `0 <= a < _m`
// @param b   `0 <= b < _m`
// @param _m  `1 <= _m <= 2^31`: the final correction looks only at the low
//            32 bits, which cannot always distinguish `d` from `d - _m` above that
// @param im  `ceil(2^64 / _m) mod 2^64` (0 when `_m == 1`)
// @return `a * b % _m`
unsigned int barrett_mul_before(unsigned int a, unsigned int b, unsigned int _m, unsigned long long im) {
    // _m == 1: a == b == im == 0, so the result is 0.
    // _m >= 2: write im*_m = 2^64 + r (0 <= r < _m) and a*b = c*_m + d (0 <= c, d < _m).
    // Then a*b*im = c*2^64 + c*r + d*im with c*r + d*im < 2^65, so the high
    // 64 bits of a*b*im are c or c + 1.
    unsigned long long prod = (unsigned long long)a * b;
#ifdef _MSC_VER
    unsigned long long quot;
    _umul128(prod, im, &quot);
#else
    unsigned long long quot = (unsigned long long)(((unsigned __int128)prod * im) >> 64);
#endif
    unsigned int rem = (unsigned int)(prod - quot * _m);
    // A wrapped (d - _m) shows up as rem >= _m; adding _m back restores d mod 2^32.
    return (rem >= _m) ? rem + _m : rem;
}

// Barrett-reduced modular product, valid for every 32-bit modulus.
// @param a   `0 <= a < _m`
// @param b   `0 <= b < _m`
// @param _m  `1 <= _m < 2^32`
// @param im  `ceil(2^64 / _m) mod 2^64` = `floor((2^64 - 1) / _m) + 1` (0 when `_m == 1`)
// @return `a * b % _m`
unsigned int barrett_mul_after(unsigned int a, unsigned int b, unsigned int _m, unsigned long long im) {
    // _m == 1: a == b == im == 0, so the result is 0.
    // _m >= 2: write im*_m = 2^64 + r (0 <= r < _m) and a*b = c*_m + d (0 <= c, d < _m).
    // a*b*im = c*2^64 + c*r + d*im and c*r + d*im < 2^64 + _m*(_m + 1) < 2^65,
    // so the high word `hi` of a*b*im is c or c + 1.
    const unsigned long long z = (unsigned long long)a * (unsigned long long)b;
#ifdef _MSC_VER
    unsigned long long hi;
    _umul128(z, im, &hi);
#else
    const unsigned long long hi = (unsigned long long)(((unsigned __int128)z * im) >> 64);
#endif
    const unsigned long long hi_times_m = hi * _m;
    // Compare in full 64 bits: z < hi*_m exactly when hi == c + 1 (difference d - _m).
    const unsigned int correction = (z < hi_times_m) ? _m : 0u;
    return (unsigned int)(z - hi_times_m + correction);
}

(Rust: $1\le m\lt 2^{32}$ draft code)

https://rust.godbolt.org/z/7P5rjahMn

/// Returns `a * b % m` using a precomputed Barrett reciprocal (pre-fix variant).
///
/// * `a` `0 <= a < m`
/// * `b` `0 <= b < m`
/// * `m` `1 <= m <= 2^31` (the final correction inspects only the low 32 bits)
/// * `im` = ceil(2^64 / `m`)
#[allow(clippy::many_single_char_names)]
pub fn mul_mod_before(a: u32, b: u32, m: u32, im: u64) -> u32 {
    // m == 1: a == b == im == 0, so the result is 0.
    // m >= 2: im * m = 2^64 + r (0 <= r < m); a*b = c*m + d (0 <= c, d < m).
    // a*b*im = c*2^64 + c*r + d*im with c*r + d*im < 2^64 * 2,
    // hence (a*b*im) >> 64 is c or c + 1.
    let product = u64::from(a) * u64::from(b);
    let quotient = ((u128::from(product) * u128::from(im)) >> 64) as u64;
    let rem = product.wrapping_sub(quotient.wrapping_mul(u64::from(m))) as u32;
    // A wrapped (d - m) lands at or above m when m <= 2^31; add m back in that case.
    if rem >= m {
        rem.wrapping_add(m)
    } else {
        rem
    }
}

/// Returns `a * b % m` using a precomputed Barrett reciprocal; valid for any 32-bit modulus.
///
/// * `a` `0 <= a < m`
/// * `b` `0 <= b < m`
/// * `m` `1 <= m < 2^32`
/// * `im` = ceil(2^64 / `m`) = floor((2^64 - 1) / `m`) + 1
#[allow(clippy::many_single_char_names)]
pub fn mul_mod_after(a: u32, b: u32, m: u32, im: u64) -> u32 {
    // m == 1: a == b == im == 0, so the result is 0.
    // m >= 2: im * m = 2^64 + r (0 <= r < m); a*b = c*m + d (0 <= c, d < m).
    // a*b*im = c*2^64 + c*r + d*im with c*r + d*im < 2^64 + m*(m + 1) < 2^64 * 2,
    // hence (a*b*im) >> 64 is c or c + 1, and the subtraction below yields d or d - m.
    let product = u64::from(a) * u64::from(b);
    let quotient = ((u128::from(product) * u128::from(im)) >> 64) as u64;
    let (diff, borrowed) = product.overflowing_sub(quotient.wrapping_mul(u64::from(m)));
    let low = diff as u32;
    // The 64-bit borrow flag says whether we got d - m; adding m back mod 2^32 restores d.
    if borrowed {
        low.wrapping_add(m)
    } else {
        low
    }
}
@mizar
Copy link
Collaborator Author

mizar commented Jan 14, 2023

The following addition overflow checks and subtraction borrow checks should also be reconsidered.

For example, use overflowing_add and overflowing_sub for addition and subtraction as well.

ac-library-rs/src/modint.rs

Lines 793 to 811 in b09e6b5

/// Returns `(lhs + rhs) mod modulus`, assuming both operands are already reduced.
///
/// Works for any `modulus < 2^32`: the true sum is below `2 * modulus`, but it can
/// exceed `u32::MAX` when `modulus > 2^31`. The carry is therefore tracked explicitly
/// instead of letting the `u32` addition overflow.
#[inline]
fn add_impl(lhs: Self, rhs: Self) -> Self {
    let modulus = Self::modulus();
    let (val, carry) = lhs.val().overflowing_add(rhs.val());
    // A carry means the true sum is >= 2^32 > modulus, so one subtraction is still
    // required; wrapping_sub yields (sum - modulus) exactly, since that is < modulus.
    let val = if carry || val >= modulus {
        val.wrapping_sub(modulus)
    } else {
        val
    };
    Self::raw(val)
}
/// Returns `(lhs - rhs) mod modulus`, assuming both operands are already reduced.
///
/// The borrow of the `u32` subtraction decides whether `modulus` must be added back.
/// The former `val >= modulus` test on the wrapped difference missed the borrow
/// whenever `modulus > 2^31`.
#[inline]
fn sub_impl(lhs: Self, rhs: Self) -> Self {
    let modulus = Self::modulus();
    let (val, borrow) = lhs.val().overflowing_sub(rhs.val());
    // On borrow, val == 2^32 + lhs - rhs; adding modulus mod 2^32 gives modulus + lhs - rhs.
    let val = if borrow { val.wrapping_add(modulus) } else { val };
    Self::raw(val)
}

@mizar
Copy link
Collaborator Author

mizar commented Jan 18, 2023

Example of a subtraction borrow check using compiler built-ins/intrinsics (GCC/MSVC):

https://godbolt.org/z/P8749355T

#ifdef _MSC_VER
#include <intrin.h>
#endif

// Computes `a * b % _m` with a precomputed Barrett reciprocal (pre-fix variant).
// @param a   `0 <= a < _m`
// @param b   `0 <= b < _m`
// @param _m  `1 <= _m <= 2^31`: the final correction inspects only the low
//            32 bits, which cannot always distinguish `d` from `d - _m` above that
// @param im  `ceil(2^64 / _m) mod 2^64` (0 when `_m == 1`)
// @return `a * b % _m`
unsigned int barrett_mul_before(unsigned int a, unsigned int b, unsigned int _m, unsigned long long im) {
    // Let z = a*b = c*_m + d (0 <= c, d < _m) and im*_m = 2^64 + r (0 <= r < _m).
    // Then z*im = c*2^64 + (c*r + d*im) with c*r + d*im < 2^65, so the high
    // word of z*im is c or c + 1, and z - high*_m is d or d - _m.
    // (For _m == 1, a == b == im == 0 and everything is 0.)
    const unsigned long long z = (unsigned long long)a * (unsigned long long)b;
#ifdef _MSC_VER
    unsigned long long hi;
    _umul128(z, im, &hi);
#else
    const unsigned long long hi = (unsigned long long)(((unsigned __int128)z * im) >> 64);
#endif
    const unsigned int low = (unsigned int)(z - hi * _m);
    if (low < _m) return low;
    return low + _m;  // low held (d - _m) mod 2^32; adding _m restores d
}

// Computes `a * b % _m` with a precomputed Barrett reciprocal. The borrow of the
// final 64-bit subtraction is read from a compiler built-in when one is available.
// @param a   `0 <= a < _m`
// @param b   `0 <= b < _m`
// @param _m  `1 <= _m < 2^32`
// @param im  `ceil(2^64 / _m) mod 2^64` (0 when `_m == 1`)
// @return `a * b % _m`
unsigned int barrett_mul_after(unsigned int a, unsigned int b, unsigned int _m, unsigned long long im) {
    // _m == 1: a == b == im == 0, so the result is 0.
    // _m >= 2: im*_m = 2^64 + r (0 <= r < _m); a*b = c*_m + d (0 <= c, d < _m).
    // a*b*im = c*2^64 + c*r + d*im with c*r + d*im < 2^65, so the high word of
    // a*b*im is c or c + 1 and prod - quot*_m is d (no borrow) or d - _m (borrow).
    const unsigned long long prod = (unsigned long long)a * b;
#ifdef _MSC_VER
    // NOTE(review): _umul128 may only be available on x64 MSVC targets — confirm the target.
    unsigned long long quot;
    _umul128(prod, im, &quot);
#else
    const unsigned long long quot = (unsigned long long)(((unsigned __int128)prod * im) >> 64);
#endif
    const unsigned long long approx = quot * _m;
    unsigned long long diff;
    unsigned int fix;
#ifdef __GNUC__
    // https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html
    fix = __builtin_usubll_overflow(prod, approx, &diff) ? _m : 0u;
#elif defined(_MSC_VER) && defined(_M_AMD64)
    // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_subborrow_u64&ig_expand=7252
    fix = _subborrow_u64(0, prod, approx, &diff) ? _m : 0u;
#else
    diff = prod - approx;
    fix = (prod < approx) ? _m : 0u;
#endif
    return (unsigned int)(diff + fix);
}

@mizar mizar linked a pull request Jan 20, 2023 that will close this issue
@mizar mizar linked a pull request Mar 27, 2023 that will close this issue
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging a pull request may close this issue.

2 participants