From ae70c7c882fd7c88a4b8735e7e53c04626bff63f Mon Sep 17 00:00:00 2001 From: Liam Gray Date: Tue, 17 Dec 2024 16:37:36 -0800 Subject: [PATCH 1/2] fix(truncate): use absolute error in bit_truncate_relative --- caput/truncate.pyx | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/caput/truncate.pyx b/caput/truncate.pyx index 618173f9..02f624fd 100644 --- a/caput/truncate.pyx +++ b/caput/truncate.pyx @@ -8,6 +8,8 @@ from cython.parallel import prange import numpy as np cimport numpy as cnp +from libc.math cimport fabs + cdef extern from "truncate.hpp": inline int bit_truncate(int val, int err) nogil @@ -229,7 +231,7 @@ def bit_truncate_relative_float(float[:] val, float prec): cdef Py_ssize_t i = 0 for i in prange(n, nogil=True): - val[i] = _bit_truncate_float(val[i], prec * val[i]) + val[i] = _bit_truncate_float(val[i], fabs(prec * val[i])) return np.asarray(val) @@ -258,7 +260,7 @@ def bit_truncate_relative_double(cnp.float64_t[:] val, cnp.float64_t prec): cdef Py_ssize_t i = 0 for i in prange(n, nogil=True): - val[i] = _bit_truncate_double(val[i], prec * val[i]) + val[i] = _bit_truncate_double(val[i], fabs(prec * val[i])) return np.asarray(val, dtype=np.float64) From e2edd4a19cb7923db59ea53a0cc1a0a296016531 Mon Sep 17 00:00:00 2001 From: Liam Gray Date: Tue, 17 Dec 2024 16:38:02 -0800 Subject: [PATCH 2/2] fix(test_truncate): add test to make sure negative values are correctly truncated with relative precision --- tests/test_truncate.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_truncate.py b/tests/test_truncate.py index d92f025c..a8e6faec 100644 --- a/tests/test_truncate.py +++ b/tests/test_truncate.py @@ -126,3 +126,19 @@ def test_truncate_relative(): ) == np.asarray([32, 32], dtype=np.float64) ).all() + + # Check the case where values are negative + assert ( + truncate.bit_truncate_relative( + np.asarray([-32.121, 32.5], dtype=np.float32), + 0.1, + ) + == np.asarray([-32, 32], dtype=np.float32) + ).all() + assert ( + truncate.bit_truncate_relative( + np.asarray([-32.121, 32.5], dtype=np.float64), + 0.1, + ) + == np.asarray([-32, 32], dtype=np.float64) + ).all()