Skip to content

Commit

Permalink
TST: Add tests for broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
bashtage committed Oct 2, 2024
1 parent 883e85c commit ef39ce8
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 0 deletions.
66 changes: 66 additions & 0 deletions randomgen/tests/_shims.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,16 @@ from randomgen.common cimport (
view_little_endian,
)

from randomgen.chacha import ChaCha

from numpy.random cimport bitgen_t

import numpy as np

from cpython.pycapsule cimport PyCapsule_GetPointer, PyCapsule_IsValid

from randomgen.broadcasting cimport constraint_type, cont


def view_little_endian_shim(arr, dtype):
return view_little_endian(arr, dtype)
Expand All @@ -21,3 +29,61 @@ def byteswap_little_endian_shim(arr):

def object_to_int_shim(val, bits, name, default_bits=64, allowed_sizes=(64,)):
return object_to_int(val, bits, name, default_bits, allowed_sizes)

cdef double double0_func(void *state):
return 3.141592

Check warning on line 34 in randomgen/tests/_shims.pyx

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/_shims.pyx#L34

Added line #L34 was not covered by tests

cdef double double1_func(void *state, double a):
return a

Check warning on line 37 in randomgen/tests/_shims.pyx

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/_shims.pyx#L37

Added line #L37 was not covered by tests

cdef double double2_func(void *state, double a, double b):
return a+b

Check warning on line 40 in randomgen/tests/_shims.pyx

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/_shims.pyx#L40

Added line #L40 was not covered by tests

cdef double double3_func(void *state, double a, double b, double c):
return a+b+c

Check warning on line 43 in randomgen/tests/_shims.pyx

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/_shims.pyx#L43

Added line #L43 was not covered by tests

cdef bitgen_t _bitgen
cdef const char *name = "BitGenerator"

chacha = ChaCha()
capsule = chacha.capsule
_bitgen = (<bitgen_t *> PyCapsule_GetPointer(capsule, name))[0]

def cont_0(size=None, out=None):
return cont(&double0_func, &_bitgen ,size, chacha.lock, 0,

Check warning on line 53 in randomgen/tests/_shims.pyx

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/_shims.pyx#L53

Added line #L53 was not covered by tests
0.0, "", constraint_type.CONS_NONE,
0.0, "", constraint_type.CONS_NONE,
0.0, "", constraint_type.CONS_NONE,
out)

Check warning on line 57 in randomgen/tests/_shims.pyx

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/_shims.pyx#L57

Added line #L57 was not covered by tests
def cont_1(a, size=None, out=None):
return cont(&double1_func, &_bitgen ,size, chacha.lock, 1,

Check warning on line 59 in randomgen/tests/_shims.pyx

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/_shims.pyx#L59

Added line #L59 was not covered by tests
a, "a", constraint_type.CONS_NON_NEGATIVE,
0.0, "", constraint_type.CONS_NONE,
0.0, "", constraint_type.CONS_NONE,
out)

Check warning on line 63 in randomgen/tests/_shims.pyx

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/_shims.pyx#L63

Added line #L63 was not covered by tests

def cont_2(a, b, size=None, out=None):
return cont(&double2_func, &_bitgen ,size, chacha.lock, 2,

Check warning on line 66 in randomgen/tests/_shims.pyx

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/_shims.pyx#L66

Added line #L66 was not covered by tests
a, "a", constraint_type.CONS_POSITIVE,
b, "b", constraint_type.CONS_POSITIVE_NOT_NAN,
0.0, "", constraint_type.CONS_NONE,
out)

Check warning on line 70 in randomgen/tests/_shims.pyx

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/_shims.pyx#L70

Added line #L70 was not covered by tests

def cont_3(a, b, c, size=None, out=None):
return cont(&double3_func, &_bitgen ,size, chacha.lock, 3,

Check warning on line 73 in randomgen/tests/_shims.pyx

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/_shims.pyx#L73

Added line #L73 was not covered by tests
a, "a", constraint_type.CONS_BOUNDED_0_1,
b, "b", constraint_type.CONS_BOUNDED_GT_0_1,
c, "c", constraint_type.CONS_GT_1,
out)

Check warning on line 77 in randomgen/tests/_shims.pyx

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/_shims.pyx#L77

Added line #L77 was not covered by tests

def cont_3_alt_cons(a, b, c, size=None, out=None):
return cont(&double3_func, &_bitgen, size, chacha.lock, 3,

Check warning on line 80 in randomgen/tests/_shims.pyx

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/_shims.pyx#L80

Added line #L80 was not covered by tests
a, "a", constraint_type.CONS_GTE_1,
b, "b", constraint_type.CONS_POISSON,
c, "c", constraint_type.LEGACY_CONS_POISSON,
out)

Check warning on line 84 in randomgen/tests/_shims.pyx

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/_shims.pyx#L84

Added line #L84 was not covered by tests

# Iterations
# Scalar or vector parameters (must be broadcastable)
# size
# out, if used with size, and parameters, must be compatible
124 changes: 124 additions & 0 deletions randomgen/tests/test_broadcasting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from collections import defaultdict
from typing import NamedTuple

import numpy as np
from numpy.testing import assert_allclose
import pytest

from randomgen.tests._shims import cont_0, cont_1, cont_2, cont_3, cont_3_alt_cons


class Config(NamedTuple):
a: float | np.ndarray | None
b: float | np.ndarray | None
c: float | np.ndarray | None
size: tuple[int, ...] | None
out: np.ndarray | None


CONFIGS = defaultdict(list)


def all_scalar(*args):
return all([(arg is None or np.isscalar(arg)) for arg in args])


def get_broadcastable_size(a, b, c):

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note test

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
if c is not None:
return (a + b + c).shape
if b is not None:
return (a + b).shape
if a is not None:
return a.shape


def count_params(a, b, c):
return sum((v is not None) for v in (a, b, c))


for a in (None, 0.5, 0.5 * np.ones((1, 2)), 0.5 * np.ones((3, 2))):
for b in (
None,
0.5,
0.5 * np.ones((1, 2)),
0.5 * np.ones((3, 1)),
0.5 * np.ones((3, 2)),
):
if a is None and b is not None:
continue
for c in (
None,
1.5,
1.5 * np.ones((1, 2)),
1.5 * np.ones((3, 1)),
1.5 * np.ones((3, 2)),
):
if b is None and c is not None:
continue
for size in (True, False):
for out in (True, False):
if size:
if all_scalar(a, b, c):
_size = (7, 5)
else:
_size = get_broadcastable_size(a, b, c)
else:
_size = None
if out:
if size:
_out = np.empty(_size)
elif all_scalar(a, b, c):
_out = np.empty((11, 7))
else:
_out = np.empty(get_broadcastable_size(a, b, c))
else:
_out = None
print(_size, _out.shape if isinstance(_out, np.ndarray) else _out)
CONFIGS[count_params(a, b, c)].append(Config(a, b, c, _size, _out))


@pytest.mark.parametrize("config", CONFIGS[0])
def test_cont_0(config):
res = cont_0(size=config.size, out=config.out)

Check warning on line 82 in randomgen/tests/test_broadcasting.py

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/test_broadcasting.py#L82

Added line #L82 was not covered by tests
if isinstance(res, np.ndarray):
assert_allclose(res, 3.141592 * np.ones_like(res))

Check warning on line 84 in randomgen/tests/test_broadcasting.py

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/test_broadcasting.py#L84

Added line #L84 was not covered by tests
else:
assert_allclose(res, 3.141592)

Check warning on line 86 in randomgen/tests/test_broadcasting.py

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/test_broadcasting.py#L86

Added line #L86 was not covered by tests


@pytest.mark.parametrize("config", CONFIGS[1])
def test_cont_1(config):
res = cont_1(config.a, size=config.size, out=config.out)

Check warning on line 91 in randomgen/tests/test_broadcasting.py

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/test_broadcasting.py#L91

Added line #L91 was not covered by tests
if isinstance(res, np.ndarray):
assert_allclose(res, 0.5 * np.ones_like(res))

Check warning on line 93 in randomgen/tests/test_broadcasting.py

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/test_broadcasting.py#L93

Added line #L93 was not covered by tests
else:
assert_allclose(res, 0.5)

Check warning on line 95 in randomgen/tests/test_broadcasting.py

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/test_broadcasting.py#L95

Added line #L95 was not covered by tests


@pytest.mark.parametrize("config", CONFIGS[2])
def test_cont_2(config):
res = cont_2(config.a, config.b, size=config.size, out=config.out)

Check warning on line 100 in randomgen/tests/test_broadcasting.py

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/test_broadcasting.py#L100

Added line #L100 was not covered by tests
if isinstance(res, np.ndarray):
assert_allclose(res, np.ones_like(res))

Check warning on line 102 in randomgen/tests/test_broadcasting.py

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/test_broadcasting.py#L102

Added line #L102 was not covered by tests
else:
assert_allclose(res, 1.0)

Check warning on line 104 in randomgen/tests/test_broadcasting.py

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/test_broadcasting.py#L104

Added line #L104 was not covered by tests


@pytest.mark.parametrize("config", CONFIGS[3])
def test_cont_3(config):
res = cont_3(config.a, config.b, config.c, size=config.size, out=config.out)

Check warning on line 109 in randomgen/tests/test_broadcasting.py

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/test_broadcasting.py#L109

Added line #L109 was not covered by tests
if isinstance(res, np.ndarray):
assert_allclose(res, 2.5 * np.ones_like(res))

Check warning on line 111 in randomgen/tests/test_broadcasting.py

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/test_broadcasting.py#L111

Added line #L111 was not covered by tests
else:
assert_allclose(res, 2.5)

Check warning on line 113 in randomgen/tests/test_broadcasting.py

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/test_broadcasting.py#L113

Added line #L113 was not covered by tests


@pytest.mark.parametrize("config", CONFIGS[3])
def test_cont_3_alt_cons(config):
res = cont_3_alt_cons(

Check warning on line 118 in randomgen/tests/test_broadcasting.py

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/test_broadcasting.py#L118

Added line #L118 was not covered by tests
1.0 + config.a, config.b, config.c, size=config.size, out=config.out
)
if isinstance(res, np.ndarray):
assert_allclose(res, 3.5 * np.ones_like(res))

Check warning on line 122 in randomgen/tests/test_broadcasting.py

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/test_broadcasting.py#L122

Added line #L122 was not covered by tests
else:
assert_allclose(res, 3.5)

Check warning on line 124 in randomgen/tests/test_broadcasting.py

View check run for this annotation

Codecov / codecov/patch

randomgen/tests/test_broadcasting.py#L124

Added line #L124 was not covered by tests

0 comments on commit ef39ce8

Please sign in to comment.