Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch ml_dtypes to use the Python limited API #195

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 26 additions & 25 deletions ml_dtypes/_src/custom_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ template <typename T>
Safe_PyObjectPtr PyCustomFloat_FromT(T x) {
PyTypeObject* type =
reinterpret_cast<PyTypeObject*>(TypeDescriptor<T>::type_ptr);
Safe_PyObjectPtr ref = make_safe(type->tp_alloc(type, 0));
Safe_PyObjectPtr ref = make_safe(PyObject_New(PyObject, type));
PyCustomFloat<T>* p = reinterpret_cast<PyCustomFloat<T>*>(ref.get());
if (p) {
p->value = x;
Expand Down Expand Up @@ -213,7 +213,9 @@ PyObject* PyCustomFloat_Add(PyObject* a, PyObject* b) {
if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
return PyCustomFloat_FromT<T>(x + y).release();
}
return PyArray_Type.tp_as_number->nb_add(a, b);
auto array_nb_add =
reinterpret_cast<binaryfunc>(PyType_GetSlot(&PyArray_Type, Py_nb_add));
return array_nb_add(a, b);
}

template <typename T>
Expand All @@ -222,7 +224,9 @@ PyObject* PyCustomFloat_Subtract(PyObject* a, PyObject* b) {
if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
return PyCustomFloat_FromT<T>(x - y).release();
}
return PyArray_Type.tp_as_number->nb_subtract(a, b);
auto array_nb_subtract = reinterpret_cast<binaryfunc>(
PyType_GetSlot(&PyArray_Type, Py_nb_subtract));
return array_nb_subtract(a, b);
}

template <typename T>
Expand All @@ -231,7 +235,9 @@ PyObject* PyCustomFloat_Multiply(PyObject* a, PyObject* b) {
if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
return PyCustomFloat_FromT<T>(x * y).release();
}
return PyArray_Type.tp_as_number->nb_multiply(a, b);
auto array_nb_multiply = reinterpret_cast<binaryfunc>(
PyType_GetSlot(&PyArray_Type, Py_nb_multiply));
return array_nb_multiply(a, b);
}

template <typename T>
Expand All @@ -240,7 +246,9 @@ PyObject* PyCustomFloat_TrueDivide(PyObject* a, PyObject* b) {
if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
return PyCustomFloat_FromT<T>(x / y).release();
}
return PyArray_Type.tp_as_number->nb_true_divide(a, b);
auto array_nb_true_divide = reinterpret_cast<binaryfunc>(
PyType_GetSlot(&PyArray_Type, Py_nb_true_divide));
return array_nb_true_divide(a, b);
}

// Constructs a new PyCustomFloat.
Expand Down Expand Up @@ -281,8 +289,7 @@ PyObject* PyCustomFloat_New(PyTypeObject* type, PyObject* args,
return PyCustomFloat_FromT<T>(value).release();
}
}
PyErr_Format(PyExc_TypeError, "expected number, got %s",
Py_TYPE(arg)->tp_name);
PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(arg));
return nullptr;
}

Expand All @@ -291,7 +298,9 @@ template <typename T>
PyObject* PyCustomFloat_RichCompare(PyObject* a, PyObject* b, int op) {
T x, y;
if (!SafeCastToCustomFloat<T>(a, &x) || !SafeCastToCustomFloat<T>(b, &y)) {
return PyGenericArrType_Type.tp_richcompare(a, b, op);
auto generic_tp_richcompare = reinterpret_cast<richcmpfunc>(
PyType_GetSlot(&PyGenericArrType_Type, Py_tp_richcompare));
return generic_tp_richcompare(a, b, op);
}
bool result;
switch (op) {
Expand Down Expand Up @@ -340,25 +349,18 @@ PyObject* PyCustomFloat_Str(PyObject* self) {
return PyUnicode_FromString(s.str().c_str());
}

// _Py_HashDouble changed its prototype for Python 3.10 so we use an overload to
// handle the two possibilities.
// NOLINTNEXTLINE(clang-diagnostic-unused-function)
inline Py_hash_t HashImpl(Py_hash_t (*hash_double)(PyObject*, double),
PyObject* self, double value) {
return hash_double(self, value);
}

// NOLINTNEXTLINE(clang-diagnostic-unused-function)
inline Py_hash_t HashImpl(Py_hash_t (*hash_double)(double), PyObject* self,
double value) {
return hash_double(value);
}

// Hash function for PyCustomFloat.
template <typename T>
Py_hash_t PyCustomFloat_Hash(PyObject* self) {
T x = reinterpret_cast<PyCustomFloat<T>*>(self)->value;
return HashImpl(&_Py_HashDouble, self, static_cast<double>(x));
if (std::isnan(x)) {
// NaNs hash as the pointer hash of the object.
auto f = reinterpret_cast<hashfunc>(
PyType_GetSlot(&PyBaseObject_Type, Py_tp_hash));
return f(self);
}
Safe_PyObjectPtr f(PyFloat_FromDouble(static_cast<double>(x)));
return PyObject_Hash(f.get());
}

template <typename T>
Expand Down Expand Up @@ -428,8 +430,7 @@ template <typename T>
int NPyCustomFloat_SetItem(PyObject* item, void* data, void* arr) {
T x;
if (!CastToCustomFloat<T>(item, &x)) {
PyErr_Format(PyExc_TypeError, "expected number, got %s",
Py_TYPE(item)->tp_name);
PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(item));
return -1;
}
memcpy(data, &x, sizeof(T));
Expand Down
42 changes: 29 additions & 13 deletions ml_dtypes/_src/intn_numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ template <typename T>
Safe_PyObjectPtr PyIntN_FromValue(T x) {
PyTypeObject* type =
reinterpret_cast<PyTypeObject*>(TypeDescriptor<T>::type_ptr);
Safe_PyObjectPtr ref = make_safe(type->tp_alloc(type, 0));
Safe_PyObjectPtr ref = make_safe(PyObject_New(PyObject, type));
PyIntN<T>* p = reinterpret_cast<PyIntN<T>*>(ref.get());
if (p) {
p->value = x;
Expand Down Expand Up @@ -214,16 +214,21 @@ PyObject* PyIntN_tp_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
}
} else if (PyUnicode_Check(arg) || PyBytes_Check(arg)) {
// Parse float from string, then cast to T.
PyObject* f = PyLong_FromUnicodeObject(arg, /*base=*/0);
if (PyErr_Occurred()) {
Safe_PyObjectPtr bytes(PyUnicode_AsUTF8String(arg));
if (!bytes) {
return nullptr;
}
PyObject* f =
PyLong_FromString(PyBytes_AsString(bytes.get()), /*end=*/nullptr,
/*base=*/0);
if (!f) {
return nullptr;
}
if (CastToIntN<T>(f, &value)) {
return PyIntN_FromValue<T>(value).release();
}
}
PyErr_Format(PyExc_TypeError, "expected number, got %s",
Py_TYPE(arg)->tp_name);
PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(arg));
return nullptr;
}

Expand Down Expand Up @@ -257,7 +262,9 @@ PyObject* PyIntN_nb_add(PyObject* a, PyObject* b) {
if (PyIntN_Value<T>(a, &x) && PyIntN_Value<T>(b, &y)) {
return PyIntN_FromValue<T>(x + y).release();
}
return PyArray_Type.tp_as_number->nb_add(a, b);
auto array_nb_add =
reinterpret_cast<binaryfunc>(PyType_GetSlot(&PyArray_Type, Py_nb_add));
return array_nb_add(a, b);
}

template <typename T>
Expand All @@ -266,7 +273,9 @@ PyObject* PyIntN_nb_subtract(PyObject* a, PyObject* b) {
if (PyIntN_Value<T>(a, &x) && PyIntN_Value<T>(b, &y)) {
return PyIntN_FromValue<T>(x - y).release();
}
return PyArray_Type.tp_as_number->nb_subtract(a, b);
auto array_nb_subtract = reinterpret_cast<binaryfunc>(
PyType_GetSlot(&PyArray_Type, Py_nb_subtract));
return array_nb_subtract(a, b);
}

template <typename T>
Expand All @@ -275,7 +284,9 @@ PyObject* PyIntN_nb_multiply(PyObject* a, PyObject* b) {
if (PyIntN_Value<T>(a, &x) && PyIntN_Value<T>(b, &y)) {
return PyIntN_FromValue<T>(x * y).release();
}
return PyArray_Type.tp_as_number->nb_multiply(a, b);
auto array_nb_multiply = reinterpret_cast<binaryfunc>(
PyType_GetSlot(&PyArray_Type, Py_nb_multiply));
return array_nb_multiply(a, b);
}

template <typename T>
Expand All @@ -292,7 +303,9 @@ PyObject* PyIntN_nb_remainder(PyObject* a, PyObject* b) {
}
return PyIntN_FromValue<T>(v).release();
}
return PyArray_Type.tp_as_number->nb_remainder(a, b);
auto array_nb_remainder = reinterpret_cast<binaryfunc>(
PyType_GetSlot(&PyArray_Type, Py_nb_remainder));
return array_nb_remainder(a, b);
}

template <typename T>
Expand All @@ -309,7 +322,9 @@ PyObject* PyIntN_nb_floor_divide(PyObject* a, PyObject* b) {
}
return PyIntN_FromValue<T>(v).release();
}
return PyArray_Type.tp_as_number->nb_floor_divide(a, b);
auto array_nb_floor_divide = reinterpret_cast<binaryfunc>(
PyType_GetSlot(&PyArray_Type, Py_nb_floor_divide));
return array_nb_floor_divide(a, b);
}

// Implementation of repr() for PyIntN.
Expand Down Expand Up @@ -342,7 +357,9 @@ template <typename T>
PyObject* PyIntN_RichCompare(PyObject* a, PyObject* b, int op) {
T x, y;
if (!PyIntN_Value<T>(a, &x) || !PyIntN_Value<T>(b, &y)) {
return PyGenericArrType_Type.tp_richcompare(a, b, op);
auto generic_tp_richcompare = reinterpret_cast<richcmpfunc>(
PyType_GetSlot(&PyGenericArrType_Type, Py_tp_richcompare));
return generic_tp_richcompare(a, b, op);
}
bool result;
switch (op) {
Expand Down Expand Up @@ -440,8 +457,7 @@ template <typename T>
int NPyIntN_SetItem(PyObject* item, void* data, void* arr) {
T x;
if (!CastToIntN<T>(item, &x)) {
PyErr_Format(PyExc_TypeError, "expected number, got %s",
Py_TYPE(item)->tp_name);
PyErr_Format(PyExc_TypeError, "expected number, got %R", Py_TYPE(item));
return -1;
}
memcpy(data, &x, sizeof(T));
Expand Down
59 changes: 33 additions & 26 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,50 @@

import fnmatch
import platform
import sysconfig

import numpy as np
from setuptools import Extension
from setuptools import setup
from setuptools.command.build_py import build_py as build_py_orig

free_threading = sysconfig.get_config_var("Py_GIL_DISABLED")

if platform.system() == "Windows":
COMPILE_ARGS = [
"/std:c++17",
"/DEIGEN_MPL2_ONLY",
"/EHsc",
"/bigobj",
]
COMPILE_ARGS = [
"/std:c++17",
"/DEIGEN_MPL2_ONLY",
"/EHsc",
"/bigobj",
]
else:
COMPILE_ARGS = [
"-std=c++17",
"-DEIGEN_MPL2_ONLY",
"-fvisibility=hidden",
# -ftrapping-math is necessary because NumPy looks at floating point
# exception state to determine whether to emit, e.g., invalid value
# warnings. Without this setting, on Mac ARM we see spurious "invalid
# value" warnings when running the tests.
"-ftrapping-math",
]
COMPILE_ARGS = [
"-std=c++17",
"-DEIGEN_MPL2_ONLY",
"-fvisibility=hidden",
# -ftrapping-math is necessary because NumPy looks at floating point
# exception state to determine whether to emit, e.g., invalid value
# warnings. Without this setting, on Mac ARM we see spurious "invalid
# value" warnings when running the tests.
"-ftrapping-math",
]
if not free_threading:
COMPILE_ARGS.append("-DPy_LIMITED_API=0x03090000")

exclude = ["third_party*"]


class build_py(build_py_orig): # pylint: disable=invalid-name

def find_package_modules(self, package, package_dir):
modules = super().find_package_modules(package, package_dir)
return [ # pylint: disable=g-complex-comprehension
(pkg, mod, file)
for (pkg, mod, file) in modules
if not any(
fnmatch.fnmatchcase(pkg + "." + mod, pat=pattern)
for pattern in exclude
)
]
def find_package_modules(self, package, package_dir):
modules = super().find_package_modules(package, package_dir)
return [ # pylint: disable=g-complex-comprehension
(pkg, mod, file)
for (pkg, mod, file) in modules
if not any(
fnmatch.fnmatchcase(pkg + "." + mod, pat=pattern) for pattern in exclude
)
]


setup(
Expand All @@ -71,7 +76,9 @@ def find_package_modules(self, package, package_dir):
np.get_include(),
],
extra_compile_args=COMPILE_ARGS,
py_limited_api=not free_threading,
)
],
cmdclass={"build_py": build_py},
options={} if free_threading else {"bdist_wheel": {"py_limited_api": "cp39"}},
)
Loading