diff --git a/quaddtype/meson.build b/quaddtype/meson.build index 82fa4bc..ea82675 100644 --- a/quaddtype/meson.build +++ b/quaddtype/meson.build @@ -175,6 +175,10 @@ srcs = [ 'numpy_quaddtype/src/umath/matmul.h', 'numpy_quaddtype/src/umath/matmul.cpp', 'numpy_quaddtype/src/constants.hpp', + 'numpy_quaddtype/src/lock.h', + 'numpy_quaddtype/src/lock.c', + 'numpy_quaddtype/src/utilities.h', + 'numpy_quaddtype/src/utilities.c', ] py.install_sources( diff --git a/quaddtype/numpy_quaddtype/_quaddtype_main.pyi b/quaddtype/numpy_quaddtype/_quaddtype_main.pyi index 831c073..3d0d60a 100644 --- a/quaddtype/numpy_quaddtype/_quaddtype_main.pyi +++ b/quaddtype/numpy_quaddtype/_quaddtype_main.pyi @@ -9,6 +9,7 @@ _IntoQuad: TypeAlias = ( QuadPrecision | float | str + | bytes | np.floating[Any] | np.integer[Any] | np.bool_ diff --git a/quaddtype/numpy_quaddtype/src/casts.cpp b/quaddtype/numpy_quaddtype/src/casts.cpp index c43b89f..6299dcd 100644 --- a/quaddtype/numpy_quaddtype/src/casts.cpp +++ b/quaddtype/numpy_quaddtype/src/casts.cpp @@ -827,7 +827,7 @@ init_casts_internal(void) add_spec(quad2quad_spec); PyArray_DTypeMeta **void_dtypes = new PyArray_DTypeMeta *[2]{&PyArray_VoidDType, &QuadPrecDType}; - PyType_Slot *void_slots = new PyType_Slot[]{ + PyType_Slot *void_slots = new PyType_Slot[4]{ {NPY_METH_resolve_descriptors, (void *)&void_to_quad_resolve_descriptors}, {NPY_METH_strided_loop, (void *)&void_to_quad_strided_loop}, {NPY_METH_unaligned_strided_loop, (void *)&void_to_quad_strided_loop}, diff --git a/quaddtype/numpy_quaddtype/src/dtype.c b/quaddtype/numpy_quaddtype/src/dtype.c index a03e855..16e130b 100644 --- a/quaddtype/numpy_quaddtype/src/dtype.c +++ b/quaddtype/numpy_quaddtype/src/dtype.c @@ -13,11 +13,13 @@ #include "numpy/ndarraytypes.h" #include "numpy/dtype_api.h" +#include "quad_common.h" #include "scalar.h" #include "casts.h" #include "dtype.h" #include "dragon4.h" #include "constants.hpp" +#include "utilities.h" static inline int quad_load(void *x, char *data_ptr, QuadBackendType backend) @@ -353,19 +355,16 @@ quadprec_scanfunc(FILE *fp, void *dptr, char *ignore, PyArray_Descr *descr_gener /* Convert string to quad precision */ char *endptr; + quad_value val; + int err = cstring_to_quad(buffer, descr->backend, &val, &endptr, true); + if (err < 0) { + return 0; /* Return 0 on parse error (no items read) */ + } if (descr->backend == BACKEND_SLEEF) { - Sleef_quad val = Sleef_strtoq(buffer, &endptr); - if (endptr == buffer) { - return 0; /* Return 0 on parse error (no items read) */ - } - *(Sleef_quad *)dptr = val; + *(Sleef_quad *)dptr = val.sleef_value; } else { - long double val = strtold(buffer, &endptr); - if (endptr == buffer) { - return 0; /* Return 0 on parse error (no items read) */ - } - *(long double *)dptr = val; + *(long double *)dptr = val.longdouble_value; } return 1; /* Return 1 on success (1 item read) */ @@ -375,22 +374,17 @@ static int quadprec_fromstr(char *s, void *dptr, char **endptr, PyArray_Descr *descr_generic) { QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)descr_generic; - - if (descr->backend == BACKEND_SLEEF) { - Sleef_quad val = Sleef_strtoq(s, endptr); - if (*endptr == s) { - return -1; - } - *(Sleef_quad *)dptr = val; + quad_value val; + int err = cstring_to_quad(s, descr->backend, &val, endptr, false); + if (err < 0) { + return -1; + } + if(descr->backend == BACKEND_SLEEF) { + *(Sleef_quad *)dptr = val.sleef_value; } else { - long double val = strtold(s, endptr); - if (*endptr == s) { - return -1; - } - *(long double *)dptr = val; + *(long double *)dptr = val.longdouble_value; } - return 0; } diff --git a/quaddtype/numpy_quaddtype/src/lock.c b/quaddtype/numpy_quaddtype/src/lock.c new file mode 100644 index 0000000..929966b --- /dev/null +++ b/quaddtype/numpy_quaddtype/src/lock.c @@ -0,0 +1,17 @@ +#include "lock.h" + +#if PY_VERSION_HEX < 0x30d00b3 +PyThread_type_lock sleef_lock = NULL; +#else +PyMutex sleef_lock = {0}; +#endif + +void init_sleef_locks(void) +{ +#if PY_VERSION_HEX < 0x30d00b3 + sleef_lock = PyThread_allocate_lock(); + if (!sleef_lock) { + PyErr_NoMemory(); + } +#endif +} \ No newline at end of file diff --git a/quaddtype/numpy_quaddtype/src/lock.h b/quaddtype/numpy_quaddtype/src/lock.h new file mode 100644 index 0000000..6c2a970 --- /dev/null +++ b/quaddtype/numpy_quaddtype/src/lock.h @@ -0,0 +1,18 @@ +#ifndef _QUADDTYPE_LOCK_H +#define _QUADDTYPE_LOCK_H + +#include + +#if PY_VERSION_HEX < 0x30d00b3 +extern PyThread_type_lock sleef_lock; +#define LOCK_SLEEF PyThread_acquire_lock(sleef_lock, WAIT_LOCK) +#define UNLOCK_SLEEF PyThread_release_lock(sleef_lock) +#else +extern PyMutex sleef_lock; +#define LOCK_SLEEF PyMutex_Lock(&sleef_lock) +#define UNLOCK_SLEEF PyMutex_Unlock(&sleef_lock) +#endif + +void init_sleef_locks(void); + +#endif // _QUADDTYPE_LOCK_H \ No newline at end of file diff --git a/quaddtype/numpy_quaddtype/src/quad_common.h b/quaddtype/numpy_quaddtype/src/quad_common.h index bc578a2..e17ee0d 100644 --- a/quaddtype/numpy_quaddtype/src/quad_common.h +++ b/quaddtype/numpy_quaddtype/src/quad_common.h @@ -5,12 +5,20 @@ extern "C" { #endif +#include +#include + typedef enum { BACKEND_INVALID = -1, BACKEND_SLEEF, BACKEND_LONGDOUBLE } QuadBackendType; +typedef union { + Sleef_quad sleef_value; + long double longdouble_value; +} quad_value; + #ifdef __cplusplus } #endif diff --git a/quaddtype/numpy_quaddtype/src/quaddtype_main.c b/quaddtype/numpy_quaddtype/src/quaddtype_main.c index 0cbb652..b268077 100644 --- a/quaddtype/numpy_quaddtype/src/quaddtype_main.c +++ b/quaddtype/numpy_quaddtype/src/quaddtype_main.c @@ -12,6 +12,7 @@ #include "numpy/dtype_api.h" #include "numpy/ufuncobject.h" +#include "lock.h" #include "scalar.h" #include "dtype.h" #include "umath/umath.h" @@ -96,6 +97,8 @@ PyInit__quaddtype_main(void) PyUnstable_Module_SetGIL(m, Py_MOD_GIL_NOT_USED); #endif + init_sleef_locks(); + if (init_quadprecision_scalar() < 0) goto error; diff --git a/quaddtype/numpy_quaddtype/src/scalar.c b/quaddtype/numpy_quaddtype/src/scalar.c index 9b42b71..5b51d3b 100644 --- a/quaddtype/numpy_quaddtype/src/scalar.c +++ b/quaddtype/numpy_quaddtype/src/scalar.c @@ -15,6 +15,8 @@ #include "scalar_ops.h" #include "dragon4.h" #include "dtype.h" +#include "lock.h" +#include "utilities.h" // For IEEE 754 binary128 (quad precision), we need 36 decimal digits // to guarantee round-trip conversion (string -> parse -> equals original value) @@ -22,18 +24,6 @@ // src: https://en.wikipedia.org/wiki/Quadruple-precision_floating-point_format #define SLEEF_QUAD_DECIMAL_DIG 36 -#if PY_VERSION_HEX < 0x30d00b3 -static PyThread_type_lock sleef_lock; -#define LOCK_SLEEF PyThread_acquire_lock(sleef_lock, WAIT_LOCK) -#define UNLOCK_SLEEF PyThread_release_lock(sleef_lock) -#else -static PyMutex sleef_lock = {0}; -#define LOCK_SLEEF PyMutex_Lock(&sleef_lock) -#define UNLOCK_SLEEF PyMutex_Unlock(&sleef_lock) -#endif - - - QuadPrecisionObject * QuadPrecision_raw_new(QuadBackendType backend) @@ -207,14 +197,23 @@ QuadPrecision_from_object(PyObject *value, QuadBackendType backend) else if (PyUnicode_Check(value)) { const char *s = PyUnicode_AsUTF8(value); char *endptr = NULL; - if (backend == BACKEND_SLEEF) { - self->value.sleef_value = Sleef_strtoq(s, &endptr); + int err = cstring_to_quad(s, backend, &self->value, &endptr, true); + if (err < 0) { + PyErr_SetString(PyExc_ValueError, "Unable to parse string to QuadPrecision"); + Py_DECREF(self); + return NULL; } - else { - self->value.longdouble_value = strtold(s, &endptr); + } + else if (PyBytes_Check(value)) { + const char *s = PyBytes_AsString(value); + if (s == NULL) { + Py_DECREF(self); + return NULL; } - if (*endptr != '\0' || endptr == s) { - PyErr_SetString(PyExc_ValueError, "Unable to parse string to QuadPrecision"); + char *endptr = NULL; + int err = cstring_to_quad(s, backend, &self->value, &endptr, true); + if (err < 0) { + PyErr_SetString(PyExc_ValueError, "Unable to parse bytes to QuadPrecision"); Py_DECREF(self); return NULL; } @@ -242,21 +241,21 @@ QuadPrecision_from_object(PyObject *value, QuadBackendType backend) const char *type_cstr = PyUnicode_AsUTF8(type_str); if (type_cstr != NULL) { PyErr_Format(PyExc_TypeError, - "QuadPrecision value must be a quad, float, int, string, array or sequence, but got %s " + "QuadPrecision value must be a quad, float, int, string, bytes, array or sequence, but got %s " "instead", type_cstr); } else { PyErr_SetString( PyExc_TypeError, - "QuadPrecision value must be a quad, float, int, string, array or sequence, but got an " + "QuadPrecision value must be a quad, float, int, string, bytes, array or sequence, but got an " "unknown type instead"); } Py_DECREF(type_str); } else { PyErr_SetString(PyExc_TypeError, - "QuadPrecision value must be a quad, float, int, string, array or sequence, but got an " + "QuadPrecision value must be a quad, float, int, string, bytes, array or sequence, but got an " "unknown type instead"); } Py_DECREF(self); @@ -636,13 +635,6 @@ PyTypeObject QuadPrecision_Type = { int init_quadprecision_scalar(void) { -#if PY_VERSION_HEX < 0x30d00b3 - sleef_lock = PyThread_allocate_lock(); - if (sleef_lock == NULL) { - PyErr_NoMemory(); - return -1; - } -#endif QuadPrecision_Type.tp_base = &PyFloatingArrType_Type; return PyType_Ready(&QuadPrecision_Type); } \ No newline at end of file diff --git a/quaddtype/numpy_quaddtype/src/scalar.h b/quaddtype/numpy_quaddtype/src/scalar.h index 7499b1a..4afd725 100644 --- a/quaddtype/numpy_quaddtype/src/scalar.h +++ b/quaddtype/numpy_quaddtype/src/scalar.h @@ -9,11 +9,6 @@ extern "C" { #include #include "quad_common.h" -typedef union { - Sleef_quad sleef_value; - long double longdouble_value; -} quad_value; - typedef struct { PyObject_HEAD quad_value value; diff --git a/quaddtype/numpy_quaddtype/src/utilities.c b/quaddtype/numpy_quaddtype/src/utilities.c new file mode 100644 index 0000000..aa00d30 --- /dev/null +++ b/quaddtype/numpy_quaddtype/src/utilities.c @@ -0,0 +1,20 @@ +#include "utilities.h" +#include + +int cstring_to_quad(const char *str, QuadBackendType backend, quad_value *out_value, +char **endptr, bool require_full_parse) +{ + if(backend == BACKEND_SLEEF) { + out_value->sleef_value = Sleef_strtoq(str, endptr); + } else { + out_value->longdouble_value = strtold(str, endptr); + } + if(*endptr == str) + return -1; // parse error - nothing was parsed + + // If full parse is required + if(require_full_parse && **endptr != '\0') + return -1; // parse error - characters remain to be converted + + return 0; // success +} \ No newline at end of file diff --git a/quaddtype/numpy_quaddtype/src/utilities.h b/quaddtype/numpy_quaddtype/src/utilities.h new file mode 100644 index 0000000..1925046 --- /dev/null +++ b/quaddtype/numpy_quaddtype/src/utilities.h @@ -0,0 +1,11 @@ +#ifndef QUAD_UTILITIES_H +#define QUAD_UTILITIES_H + +#include "quad_common.h" +#include +#include +#include + +int cstring_to_quad(const char *str, QuadBackendType backend, quad_value *out_value, char **endptr, bool require_full_parse); + +#endif \ No newline at end of file diff --git a/quaddtype/tests/test_quaddtype.py b/quaddtype/tests/test_quaddtype.py index e967b8c..a8d297b 100644 --- a/quaddtype/tests/test_quaddtype.py +++ b/quaddtype/tests/test_quaddtype.py @@ -308,6 +308,134 @@ def test_string_roundtrip(): ) +class TestBytesSupport: + """Test suite for QuadPrecision bytes input support.""" + + @pytest.mark.parametrize("original", [ + QuadPrecision("0.417022004702574000667425480060047"), # Random value + QuadPrecision("1.23456789012345678901234567890123456789"), # Many digits + pytest.param(numpy_quaddtype.pi, id="pi"), # Mathematical constant + pytest.param(numpy_quaddtype.e, id="e"), + QuadPrecision("1e-100"), # Very small + QuadPrecision("1e100"), # Very large + QuadPrecision("-3.14159265358979323846264338327950288419"), # Negative pi + QuadPrecision("0.0"), # Zero + QuadPrecision("-0.0"), # Negative zero + QuadPrecision("1.0"), # One + QuadPrecision("-1.0"), # Negative one + ]) + def test_bytes_roundtrip(self, original): + """Test that bytes representations of quad precision values roundtrip correctly.""" + string_repr = str(original) + bytes_repr = string_repr.encode("ascii") + reconstructed = QuadPrecision(bytes_repr) + + # Values should be exactly equal (bit-for-bit identical) + assert reconstructed == original, ( + f"Bytes round-trip failed for {repr(original)}:\n" + f" Original: {repr(original)}\n" + f" Bytes: {bytes_repr}\n" + f" Reconstructed: {repr(reconstructed)}" + ) + + @pytest.mark.parametrize("bytes_val,expected_str", [ + # Simple numeric values + (b"1.0", "1.0"), + (b"-1.0", "-1.0"), + (b"0.0", "0.0"), + (b"3.14159", "3.14159"), + # Scientific notation + (b"1e10", "1e10"), + (b"1e-10", "1e-10"), + (b"2.5e100", "2.5e100"), + (b"-3.7e-50", "-3.7e-50"), + ]) + def test_bytes_creation_basic(self, bytes_val, expected_str): + """Test basic creation of QuadPrecision from bytes objects.""" + assert QuadPrecision(bytes_val) == QuadPrecision(expected_str) + + @pytest.mark.parametrize("bytes_val,check_func", [ + # Very large and very small numbers + (b"1e308", lambda x: x == QuadPrecision("1e308")), + (b"1e-308", lambda x: x == QuadPrecision("1e-308")), + # Special values + (b"inf", lambda x: np.isinf(x)), + (b"-inf", lambda x: np.isinf(x) and x < 0), + (b"nan", lambda x: np.isnan(x)), + ]) + def test_bytes_creation_edge_cases(self, bytes_val, check_func): + """Test edge cases for QuadPrecision creation from bytes.""" + val = QuadPrecision(bytes_val) + assert check_func(val) + + @pytest.mark.parametrize("invalid_bytes", [ + b"", # Empty bytes + b"not_a_number", # Invalid format + b"1.23abc", # Trailing garbage + b"abc1.23", # Leading garbage + ]) + def test_bytes_invalid_input(self, invalid_bytes): + """Test that invalid bytes input raises appropriate errors.""" + with pytest.raises(ValueError, match="Unable to parse bytes to QuadPrecision"): + QuadPrecision(invalid_bytes) + + @pytest.mark.parametrize("backend", ["sleef", "longdouble"]) + @pytest.mark.parametrize("bytes_val", [ + b"1.0", + b"-1.0", + b"3.141592653589793238462643383279502884197", + b"1e100", + b"1e-100", + b"0.0", + ]) + def test_bytes_backend_consistency(self, backend, bytes_val): + """Test that bytes parsing works consistently across backends.""" + quad_val = QuadPrecision(bytes_val, backend=backend) + str_val = QuadPrecision(bytes_val.decode("ascii"), backend=backend) + + # Bytes and string should produce identical results + assert quad_val == str_val, ( + f"Backend {backend}: bytes and string parsing differ for {bytes_val}\n" + f" From bytes: {repr(quad_val)}\n" + f" From string: {repr(str_val)}" + ) + + @pytest.mark.parametrize("bytes_val,expected_str", [ + # Leading whitespace is OK (consumed by parser) + (b" 1.0", "1.0"), + (b" 3.14", "3.14"), + ]) + def test_bytes_whitespace_valid(self, bytes_val, expected_str): + """Test handling of valid whitespace in bytes input.""" + assert QuadPrecision(bytes_val) == QuadPrecision(expected_str) + + @pytest.mark.parametrize("invalid_bytes", [ + b"1.0 ", # Trailing whitespace + b"1.0 ", # Multiple trailing spaces + b"1 .0", # Internal whitespace + b"1. 0", # Internal whitespace + ]) + def test_bytes_whitespace_invalid(self, invalid_bytes): + """Test that invalid whitespace in bytes input raises errors.""" + with pytest.raises(ValueError, match="Unable to parse bytes to QuadPrecision"): + QuadPrecision(invalid_bytes) + + @pytest.mark.parametrize("test_str", [ + "1.0", + "-3.14159265358979323846264338327950288419", + "1e100", + "2.71828182845904523536028747135266249775", + ]) + def test_bytes_encoding_compatibility(self, test_str): + """Test that bytes created from different encodings work correctly.""" + from_string = QuadPrecision(test_str) + from_bytes = QuadPrecision(test_str.encode("ascii")) + from_bytes_utf8 = QuadPrecision(test_str.encode("utf-8")) + + assert from_string == from_bytes + assert from_string == from_bytes_utf8 + + def test_string_subclass_parsing(): """Test that QuadPrecision handles string subclasses correctly.