Skip to content

Commit

Permalink
Merge pull request #386 from bashtage/more-test-broadcast
Browse files Browse the repository at this point in the history
Add test for float 1
  • Loading branch information
bashtage authored Oct 2, 2024
2 parents 6f1cc8d + 3338f62 commit b7a6d9a
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 19 deletions.
24 changes: 23 additions & 1 deletion randomgen/tests/_shim_dist.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,26 @@ double double2_func(bitgen_t *state, double a, double b) { return a + b; }

double double3_func(bitgen_t *state, double a, double b, double c) {
return a + b + c;
}
}

float float_0(bitgen_t *state) { return 3.141592; }

float float_1(bitgen_t *state, float a) { return a; }

int64_t int_0(void *state) { return 3; }

int64_t int_d(void *state, double a) { return (int64_t)(10 * a); };

int64_t int_dd(void *state, double a, double b) {
return (int64_t)(10 * a * b);
};

int64_t int_di(void *state, double a, uint64_t b) {
return (int64_t)2 * a * b;
};

int64_t int_i(void *state, int64_t a) { return a; };

int64_t int_iii(void *state, int64_t a, int64_t b, int64_t c) {
return a + b + c;
};
19 changes: 19 additions & 0 deletions randomgen/tests/_shim_dist.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
#include <inttypes.h>
#include "numpy/random/distributions.h"

extern double double0_func(bitgen_t *state);
extern double double1_func(bitgen_t *state, double a);
extern double double2_func(bitgen_t *state, double a, double b);
extern double double3_func(bitgen_t *state, double a, double b, double c);

extern float float_0(bitgen_t *state);
extern float float_1(bitgen_t *state, float a);

extern int64_t int_0(void *state);
extern int64_t int_d(void *state, double a);
extern int64_t int_dd(void *state, double a, double b);
extern int64_t int_di(void *state, double a, uint64_t b);
extern int64_t int_i(void *state, int64_t a);
extern int64_t int_iii(void *state, int64_t a, int64_t b, int64_t c);

/*
extern uint32_t uint_0_32(bitgen_t *state);
extern uint32_t uint_1_i_32(bitgen_t *state, uint32_t a);
extern int32_t int_2_i_32(bitgen_t *state, int32_t a, int32_t b);
extern int64_t int_2_i(bitgen_t *state, int64_t a, int64_t b);
*/
31 changes: 26 additions & 5 deletions randomgen/tests/_shims.pxd
Original file line number Diff line number Diff line change
@@ -1,10 +1,31 @@
from cpython.pycapsule cimport PyCapsule_GetPointer, PyCapsule_IsValid
from numpy.random cimport bitgen_t

from randomgen.tests.data.compute_hashes import bit_gen
from randomgen.broadcasting cimport constraint_type, cont, cont_f
from randomgen.common cimport (
byteswap_little_endian,
int_to_array,
object_to_int,
view_little_endian,
)

from Cython.Includes.cpython.datetime import noexcept

from libc.stdint cimport int64_t, uint64_t


cdef extern from "_shim_dist.h":
double double0_func(bitgen_t *state);
double double1_func(bitgen_t *state, double a);
double double2_func(bitgen_t *state, double a, double b);
double double3_func(bitgen_t *state, double a, double b, double c);
double double0_func(bitgen_t *state) noexcept nogil
double double1_func(bitgen_t *state, double a) noexcept nogil
double double2_func(bitgen_t *state, double a, double b) noexcept nogil
double double3_func(bitgen_t *state, double a, double b, double c) noexcept nogil

float float_0(bitgen_t *state) noexcept nogil
float float_1(bitgen_t *state, float a) noexcept nogil

int64_t int_0(bitgen_t *state) noexcept nogil
int64_t int_d(bitgen_t *state, double a) noexcept nogil
int64_t int_dd(bitgen_t *state, double a, double b) noexcept nogil
int64_t int_di(bitgen_t *state, double a, uint64_t b) noexcept nogil
int64_t int_i(bitgen_t *state, int64_t a) noexcept nogil
int64_t int_iii(bitgen_t *state, int64_t a, int64_t b, int64_t c) noexcept nogil
17 changes: 4 additions & 13 deletions randomgen/tests/_shims.pyx
Original file line number Diff line number Diff line change
@@ -1,18 +1,5 @@
from numpy.random cimport bitgen_t

from randomgen.common cimport (
byteswap_little_endian,
int_to_array,
object_to_int,
view_little_endian,
)

import numpy as np

from cpython.pycapsule cimport PyCapsule_GetPointer, PyCapsule_IsValid

from randomgen.broadcasting cimport constraint_type, cont


def view_little_endian_shim(arr, dtype):
return view_little_endian(arr, dtype)
Expand Down Expand Up @@ -79,3 +66,7 @@ cdef class ShimGenerator:
b, "b", constraint_type.CONS_POISSON,
c, "c", constraint_type.LEGACY_CONS_POISSON,
out)

def cont_1_float(self, a, size=None, out=None):
return cont_f(&float_1, &self._bitgen, size, self.lock,
a, "a", constraint_type.CONS_NONE, out)
17 changes: 17 additions & 0 deletions randomgen/tests/test_broadcasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,20 @@ def test_cont_3_alt_cons(config):
assert_allclose(res, 3.5 * np.ones_like(res))
else:
assert_allclose(res, 3.5)


@pytest.mark.parametrize("config", CONFIGS[1])
def test_cont_1(config):
if isinstance(config.a, np.ndarray):
a = config.a.astype(np.float32)
else:
a = config.a
out = None
if config.out is not None:
out = np.empty(config.out.shape, dtype=np.float32)

res = generator.cont_1_float(a, size=config.size, out=out)
if isinstance(res, np.ndarray):
assert_allclose(res, 0.5 * np.ones_like(res))
else:
assert_allclose(res, 0.5)

0 comments on commit b7a6d9a

Please sign in to comment.