Skip to content

Commit

Permalink
REF: Fix extended tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bashtage committed Sep 18, 2024
1 parent c42b361 commit 1c6b1b6
Showing 1 changed file with 47 additions and 47 deletions.
94 changes: 47 additions & 47 deletions randomgen/tests/test_extended_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def mv_seed():

@pytest.fixture(scope="function")
def extended_gen():
pcg = PCG64(0)
pcg = PCG64(0, mode="sequence")
return ExtendedGenerator(pcg)


Expand Down Expand Up @@ -155,7 +155,7 @@ def test_multivariate_normal_method(seed, method):

@pytest.mark.parametrize("method", ["svd", "eigh", "cholesky"])
def test_multivariate_normal_basic_stats(seed, method):
random = ExtendedGenerator(MT19937(seed))
random = ExtendedGenerator(MT19937(seed, mode="sequence"))
n_s = 1000
mean = np.array([1, 2])
cov = np.array([[2, 1], [1, 2]])
Expand Down Expand Up @@ -184,24 +184,24 @@ def test_multivariate_normal_bad_size(mean, size):


def test_multivariate_normal(seed):
random.bit_generator.seed(seed)
random = ExtendedGenerator(MT19937(seed, mode="sequence"))
mean = (0.123456789, 10)
cov = [[1, 0], [0, 1]]
size = (3, 2)
actual = random.multivariate_normal(mean, cov, size)
desired = np.array(
[
[
[-3.34929721161096100, 9.891061435770858],
[-0.12250896439641100, 9.295898449738300],
[0.8032359104382388, 9.020376052272443],
[-0.549265600883053, 7.859449222248867],
],
[
[0.48355927611635563, 10.127832101772366],
[3.11093021424924300, 10.283109168794352],
[0.4159548290083927, 10.569454865174936],
[-1.5353359555824029, 9.318305704280354],
],
[
[-0.20332082341774727, 9.868532121697195],
[-1.33806889550667330, 9.813657233804179],
[0.950123388279074, 10.055774500309264],
[0.13457049786723752, 9.803457241505443],
],
]
)
Expand All @@ -210,7 +210,7 @@ def test_multivariate_normal(seed):

# Check for default size, was raising deprecation warning
actual = random.multivariate_normal(mean, cov)
desired = np.array([-1.097443117192574, 10.535787051184261])
desired = np.array([0.3668305896561763, 10.345231541085731])
assert_array_almost_equal(actual, desired, decimal=15)

# Check that non positive-semidefinite covariance warns with
Expand Down Expand Up @@ -244,21 +244,21 @@ def test_multivariate_normal(seed):


def test_complex_normal(seed):
random.bit_generator.seed(seed)
random = ExtendedGenerator(MT19937(seed, mode="sequence"))
actual = random.complex_normal(loc=1.0, gamma=1.0, relation=0.5, size=(3, 2))
desired = np.array(
[
[
-2.007493185623132 - 0.05446928211457126j,
0.7869874090977291 - 0.35205077513085050j,
1.5887059881277819 - 0.4898119738637788j,
0.41740532066669644 - 1.0702753888755665j,
],
[
1.3118579018087224 + 0.06391605088618339j,
3.5872278793967554 + 0.14155458439717636j,
1.253310733204425 + 0.2847274325874684j,
-0.4365566564216725 - 0.3408471478598232j,
],
[
0.7170022862582056 - 0.06573393915140235j,
-0.26571837106621987 - 0.0931713830979103j,
1.7159142754357688 + 0.027887250154631615j,
1.009624754209292 - 0.09827137924727837j,
],
]
)
Expand Down Expand Up @@ -321,7 +321,7 @@ def test_set_get_state(seed):


def test_complex_normal_size(mv_seed):
random = ExtendedGenerator(MT19937(mv_seed))
random = ExtendedGenerator(MT19937(mv_seed, mode="sequence"))
state = random.state
loc = np.ones((1, 2))
gamma = np.ones((3, 1))
Expand All @@ -330,16 +330,16 @@ def test_complex_normal_size(mv_seed):
desired = np.array(
[
[
1.393937478212015 - 0.31374589731830593j,
0.9474905694736895 - 0.16424530802218726j,
1.2093667943035076 + 0.13793154570353136j,
1.5769647969244578 + 0.6158732529417138j,
],
[
1.119247463119766 + 0.023956373851168843j,
0.8776366291514774 + 0.2865220655803411j,
0.5900374596315505 + 0.5238612206671325j,
1.2024921591940638 - 0.5370509227680859j,
],
[
0.5515508326417458 - 0.15986016780453596j,
-0.6803993941303332 + 1.1782711493556892j,
1.088126576683737 + 0.5544754121023469j,
0.8485995861109008 - 0.047090051743345455j,
],
]
)
Expand Down Expand Up @@ -368,7 +368,7 @@ def test_default_pcg64():
@pytest.mark.parametrize("dim", [2, 5, 10])
@pytest.mark.parametrize("size", [None, 5, (3, 7)])
def test_standard_wishart_reproduce(df, dim, size):
pcg = PCG64(0)
pcg = PCG64(0, mode="sequence")
eg = ExtendedGenerator(pcg)
w = eg.standard_wishart(df, dim, size)
if size is not None:
Expand All @@ -378,7 +378,7 @@ def test_standard_wishart_reproduce(df, dim, size):
assert w.ndim == 2
assert w.shape[-2:] == (dim, dim)

pcg = PCG64(0)
pcg = PCG64(0, mode="sequence")
eg = ExtendedGenerator(pcg)
w2 = eg.standard_wishart(df, dim, size)
assert_allclose(w, w2)
Expand All @@ -388,7 +388,7 @@ def test_standard_wishart_reproduce(df, dim, size):
@pytest.mark.parametrize("df", [8, [10], [[5], [6]]])
def test_wishart_broadcast(df, scale_dim):
dim = 5
pcg = PCG64(0)
pcg = PCG64(0, mode="sequence")
eg = ExtendedGenerator(pcg)
scale = np.eye(dim)
for _ in range(scale_dim):
Expand All @@ -402,7 +402,7 @@ def test_wishart_broadcast(df, scale_dim):
assert w.shape[:-2] == np.broadcast(df, z).shape

size = w.shape[:-2]
pcg = PCG64(0)
pcg = PCG64(0, mode="sequence")
eg = ExtendedGenerator(pcg)
w2 = eg.wishart(df, scale, size=size)
assert_allclose(w, w2)
Expand All @@ -417,7 +417,7 @@ def test_wishart_broadcast(df, scale_dim):
def test_wishart_reduced_rank(method):
scale = np.eye(3)
scale[0, 1] = scale[1, 0] = 1.0
pcg = PCG64(0)
pcg = PCG64(0, mode="sequence")
eg = ExtendedGenerator(pcg)
w = eg.wishart(10, scale, method=method, rank=2)
assert w.shape == (3, 3)
Expand All @@ -428,7 +428,7 @@ def test_wishart_reduced_rank(method):
def test_missing_scipy_exception():
scale = np.eye(3)
scale[0, 1] = scale[1, 0] = 1.0
pcg = PCG64(0)
pcg = PCG64(0, mode="sequence")
eg = ExtendedGenerator(pcg)
with pytest.raises(ImportError):
eg.wishart(10, scale, method="cholesky", rank=2)
Expand Down Expand Up @@ -490,7 +490,7 @@ def test_broadcast_both_paths():


def test_factor_wishart():
pcg = PCG64(0)
pcg = PCG64(0, mode="sequence")
eg = ExtendedGenerator(pcg)
w = eg.wishart([3, 5], 2 * np.eye(4), size=(10000, 2), method="factor")
assert_allclose(np.diag((w[:, 0] / 3).mean(0)).mean(), 4, rtol=1e-2)
Expand Down Expand Up @@ -648,36 +648,36 @@ def test_random_other_type(extended_gen):
extended_gen.random(dtype=f16)


def test_random(extended_gen_legacy):
extended_gen_legacy.bit_generator.seed(SEED)
actual = extended_gen_legacy.random((3, 2))
def test_random():
random = ExtendedGenerator(MT19937(0, mode="sequence"))
actual = random.random((3, 2))
desired = np.array(
[
[0.61879477158567997, 0.59162362775974664],
[0.88868358904449662, 0.89165480011560816],
[0.4575674820298663, 0.7781880808593471],
[0.840045643478751, 0.526812305612705],
[0.390476663670696, 0.373221178199718],
[0.144160402211855, 0.255532529630851],
]
)
assert_array_almost_equal(actual, desired, decimal=15)

extended_gen_legacy.bit_generator.seed(SEED)
actual = extended_gen_legacy.random()
random = ExtendedGenerator(MT19937(0, mode="sequence"))
actual = random.random()
assert_array_almost_equal(actual, desired[0, 0], decimal=15)


def test_random_float(extended_gen_legacy):
extended_gen_legacy.bit_generator.seed(SEED)
actual = extended_gen_legacy.random((3, 2))
def test_random_float(seed):
random = ExtendedGenerator(MT19937(seed, mode="sequence"))
actual = random.random((3, 2))
desired = np.array(
[[0.6187948, 0.5916236], [0.8886836, 0.8916548], [0.4575675, 0.7781881]]
[[0.7165936, 0.6035045], [0.4473828, 0.359537], [0.7954794, 0.1942982]]
)
assert_array_almost_equal(actual, desired, decimal=7)


def test_random_float_scalar(extended_gen_legacy):
extended_gen_legacy.bit_generator.seed(SEED)
actual = extended_gen_legacy.random(dtype=np.float32)
desired = 0.6187948
def test_random_float_scalar(seed):
random = ExtendedGenerator(MT19937(seed, mode="sequence"))
actual = random.random(dtype=np.float32)
desired = 0.7165936
assert_array_almost_equal(actual, desired, decimal=7)


Expand Down

0 comments on commit 1c6b1b6

Please sign in to comment.