Skip to content

Commit

Permalink
Merge pull request #71 from qryxip/atomic-barrett-seqcst
Browse files Browse the repository at this point in the history
Make `modint::Barrett` `(AtomicU32, AtomicU64)`
  • Loading branch information
qryxip authored Oct 18, 2020
2 parents 006d353 + b09e6b5 commit a8c306d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 33 deletions.
47 changes: 29 additions & 18 deletions src/internal_math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,36 @@ impl Barrett {
/// a * b % m
#[allow(clippy::many_single_char_names)]
pub(crate) fn mul(&self, a: u32, b: u32) -> u32 {
// [1] m = 1
// a = b = im = 0, so okay

// [2] m >= 2
// im = ceil(2^64 / m)
// -> im * m = 2^64 + r (0 <= r < m)
// let z = a*b = c*m + d (0 <= c, d < m)
// a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im
// c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2
// ((ab * im) >> 64) == c or c + 1
let mut z = a as u64;
z *= b as u64;
let x = (((z as u128) * (self.im as u128)) >> 64) as u64;
let mut v = z.wrapping_sub(x.wrapping_mul(self._m as u64)) as u32;
if self._m <= v {
v = v.wrapping_add(self._m);
}
v
mul_mod(a, b, self._m, self.im)
}
}

/// Returns the product `a * b` reduced modulo `m`, using a precomputed reciprocal of `m`.
///
/// # Parameters
///
/// * `a`: left factor, `0 <= a < m`
/// * `b`: right factor, `0 <= b < m`
/// * `m`: modulus, `1 <= m <= 2^31`
/// * `im`: `ceil(2^64 / m)` (this is `0` when `m = 1`)
#[allow(clippy::many_single_char_names)]
pub(crate) fn mul_mod(a: u32, b: u32, m: u32, im: u64) -> u32 {
    // Case m = 1:
    //   a, b and im are all 0, so every intermediate below is 0 and the result is 0.
    //
    // Case m >= 2:
    //   Write im * m = 2^64 + r with 0 <= r < m, and a*b = c*m + d with 0 <= c, d < m.
    //   Then
    //     a*b * im = c*(im*m) + d*im = c*2^64 + (c*r + d*im)
    //   and the bracketed term is bounded by
    //     c*r + d*im < m*m + m*im < m*m + 2^64 + m <= 2^64 + m*(m + 1) < 2 * 2^64,
    //   so the high 64 bits of a*b*im are either c or c + 1.
    let product = u64::from(a) * u64::from(b);
    let quotient = ((u128::from(product) * u128::from(im)) >> 64) as u64;

    // With quotient == c this is d. With quotient == c + 1 the subtraction wraps, the
    // truncated value is at least m, and adding m back (wrapping) yields d.
    let remainder = product.wrapping_sub(quotient.wrapping_mul(u64::from(m))) as u32;
    if remainder >= m {
        remainder.wrapping_add(m)
    } else {
        remainder
    }
}

/// # Parameters
Expand Down
48 changes: 33 additions & 15 deletions src/modint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ use std::{
marker::PhantomData,
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
str::FromStr,
sync::atomic::{self, AtomicU32, AtomicU64},
thread::LocalKey,
};

Expand Down Expand Up @@ -330,7 +331,7 @@ impl<I: Id> DynamicModInt<I> {
/// ```
#[inline]
pub fn modulus() -> u32 {
I::companion_barrett().with(|bt| bt.borrow().umod())
I::companion_barrett().umod()
}

/// Sets a modulus.
Expand All @@ -354,7 +355,7 @@ impl<I: Id> DynamicModInt<I> {
if modulus == 0 {
panic!("the modulus must not be 0");
}
I::companion_barrett().with(|bt| *bt.borrow_mut() = Barrett::new(modulus))
I::companion_barrett().update(modulus);
}

/// Creates a new `DynamicModInt`.
Expand Down Expand Up @@ -442,47 +443,64 @@ impl<I: Id> ModIntBase for DynamicModInt<I> {
}

pub trait Id: 'static + Copy + Eq {
// TODO: Make `internal_math::Barret` `Copy`.
fn companion_barrett() -> &'static LocalKey<RefCell<Barrett>>;
fn companion_barrett() -> &'static Barrett;
}

/// Marker type with no variants, so no value of it can be constructed; it is used
/// only at the type level through its `Id` implementation.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub enum DefaultId {}

impl Id for DefaultId {
fn companion_barrett() -> &'static LocalKey<RefCell<Barrett>> {
thread_local! {
static BARRETT: RefCell<Barrett> = RefCell::default();
}
fn companion_barrett() -> &'static Barrett {
static BARRETT: Barrett = Barrett::default();
&BARRETT
}
}

/// Pair of _m_ and _ceil(2⁶⁴/m)_.
pub struct Barrett(internal_math::Barrett);
pub struct Barrett {
// Current modulus `m`: written by `new` and `update`, read by `umod` and `mul`.
m: AtomicU32,
// `(u64::MAX / m).wrapping_add(1)` for that modulus, as computed in `new` and `update`;
// `mul` passes it as the `im` argument of `internal_math::mul_mod`.
im: AtomicU64,
}

impl Barrett {
/// Creates a new `Barrett`.
#[inline]
pub fn new(m: u32) -> Self {
Self(internal_math::Barrett::new(m))
pub const fn new(m: u32) -> Self {
Self {
m: AtomicU32::new(m),
im: AtomicU64::new((-1i64 as u64 / m as u64).wrapping_add(1)),
}
}

/// Creates a `Barrett` for the modulus `998_244_353` by delegating to `Self::new`.
///
/// Being a `const fn`, it can be called where a constant expression is required.
#[inline]
const fn default() -> Self {
Self::new(998_244_353)
}

/// Replaces the stored modulus with `m` and stores the reciprocal that goes with it.
///
/// `m` must be non-zero: the integer division below panics for `m == 0`.
///
/// NOTE(review): `m` and `im` are written by two separate atomic stores, so a
/// concurrent `mul` could presumably load the new `m` together with the old `im`.
/// Confirm whether the modulus is ever changed while other threads are multiplying.
#[inline]
fn update(&self, m: u32) {
// `-1i64 as u64` is `u64::MAX`, so this is `floor((2^64 - 1) / m) + 1`, i.e.
// `ceil(2^64 / m)`; the wrapping add turns the `m == 1` case into 0 instead of overflowing.
let im = (-1i64 as u64 / m as u64).wrapping_add(1);
// The modulus is stored first, then its reciprocal, both sequentially consistent.
self.m.store(m, atomic::Ordering::SeqCst);
self.im.store(im, atomic::Ordering::SeqCst);
}

#[inline]
fn umod(&self) -> u32 {
self.0.umod()
self.m.load(atomic::Ordering::SeqCst)
}

#[inline]
fn mul(&self, a: u32, b: u32) -> u32 {
self.0.mul(a, b)
let m = self.m.load(atomic::Ordering::SeqCst);
let im = self.im.load(atomic::Ordering::SeqCst);
internal_math::mul_mod(a, b, m, im)
}
}

impl Default for Barrett {
#[inline]
fn default() -> Self {
Self(internal_math::Barrett::new(998_244_353))
Self::default()
}
}

Expand Down Expand Up @@ -810,7 +828,7 @@ impl<M: Modulus> InternalImplementations for StaticModInt<M> {
impl<I: Id> InternalImplementations for DynamicModInt<I> {
#[inline]
fn mul_impl(lhs: Self, rhs: Self) -> Self {
I::companion_barrett().with(|bt| Self::raw(bt.borrow().mul(lhs.val, rhs.val)))
Self::raw(I::companion_barrett().mul(lhs.val, rhs.val))
}
}

Expand Down

0 comments on commit a8c306d

Please sign in to comment.