diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst
index 23faea9465..8be08b4953 100644
--- a/doc/extending/creating_a_numba_jax_op.rst
+++ b/doc/extending/creating_a_numba_jax_op.rst
@@ -358,13 +358,13 @@ Here's an example for the `CumOp`\ `Op`:
     if mode == "add":
         if axis is None or ndim == 1:
 
-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit()
             def cumop(x):
                 return np.cumsum(x)
 
         else:
 
-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:
@@ -382,13 +382,13 @@ Here's an example for the `CumOp`\ `Op`:
     else:
         if axis is None or ndim == 1:
 
-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit()
             def cumop(x):
                 return np.cumprod(x)
 
         else:
 
-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:
diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py
index 8bf827b52f..843a4dbf1f 100644
--- a/pytensor/link/numba/dispatch/basic.py
+++ b/pytensor/link/numba/dispatch/basic.py
@@ -49,10 +49,23 @@ def global_numba_func(func):
     return func
 
 
-def numba_njit(*args, **kwargs):
+def numba_njit(*args, fastmath=None, **kwargs):
     kwargs.setdefault("cache", config.numba__cache)
     kwargs.setdefault("no_cpython_wrapper", True)
     kwargs.setdefault("no_cfunc_wrapper", True)
+    if fastmath is None:
+        if config.numba__fastmath:
+            # Opinionated default on fastmath flags
+            # https://llvm.org/docs/LangRef.html#fast-math-flags
+            fastmath = {
+                "arcp",  # Allow Reciprocal
+                "contract",  # Allow floating-point contraction
+                "afn",  # Approximate functions
+                "reassoc",
+                "nsz",  # no-signed zeros
+            }
+        else:
+            fastmath = False
 
     # Suppress cache warning for internal functions
     # We have to add an ansi escape code for optional bold text by numba
@@ -68,9 +81,9 @@
         )
 
     if len(args) > 0 and callable(args[0]):
-        return numba.njit(*args[1:], **kwargs)(args[0])
+        return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])
 
-    return numba.njit(*args, **kwargs)
+    return numba.njit(*args, fastmath=fastmath, **kwargs)
 
 
 def numba_vectorize(*args, **kwargs):
diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py
index 131788e843..b7481bd5a3 100644
--- a/pytensor/link/numba/dispatch/blockwise.py
+++ b/pytensor/link/numba/dispatch/blockwise.py
@@ -32,7 +32,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
         core_op,
         node=core_node,
         parent_node=node,
-        fastmath=_jit_options["fastmath"],
         **kwargs,
     )
     core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)
diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py
index 3559117d8a..ae5ef3dcb1 100644
--- a/pytensor/link/numba/dispatch/elemwise.py
+++ b/pytensor/link/numba/dispatch/elemwise.py
@@ -6,7 +6,6 @@
 from numba.core.extending import overload
 from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
 
-from pytensor import config
 from pytensor.graph.op import Op
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
@@ -281,7 +280,6 @@ def jit_compile_reducer(
     res = numba_basic.numba_njit(
         *args,
         boundscheck=False,
-        fastmath=config.numba__fastmath,
         **kwds,
     )(fn)
 
@@ -315,7 +313,6 @@ def numba_funcify_Elemwise(op, node, **kwargs):
         op.scalar_op,
         node=scalar_node,
         parent_node=node,
-        fastmath=_jit_options["fastmath"],
         **kwargs,
     )
 
@@ -403,13 +400,13 @@ def numba_funcify_Sum(op, node, **kwargs):
 
     if ndim_input == len(axes):
         # Slightly faster than `numba_funcify_CAReduce` for this case
-        @numba_njit(fastmath=config.numba__fastmath)
+        @numba_njit
         def impl_sum(array):
             return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
 
     elif len(axes) == 0:
         # These cases should be removed by rewrites!
-        @numba_njit(fastmath=config.numba__fastmath)
+        @numba_njit
         def impl_sum(array):
             return np.asarray(array, dtype=out_dtype)
 
@@ -568,9 +565,7 @@ def numba_funcify_Softmax(op, node, **kwargs):
             add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
         )
 
-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_max = jit_fn(reduce_max_py)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
@@ -602,9 +597,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
             add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
         )
 
-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
         reduce_sum = np.sum
@@ -642,9 +635,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
            add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
         )
 
-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_max = jit_fn(reduce_max_py)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py
index 3629b0e44c..1f0a33e595 100644
--- a/pytensor/link/numba/dispatch/extra_ops.py
+++ b/pytensor/link/numba/dispatch/extra_ops.py
@@ -4,7 +4,6 @@
 import numba
 import numpy as np
 
-from pytensor import config
 from pytensor.graph import Apply
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
@@ -50,13 +49,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
     if mode == "add":
         if axis is None or ndim == 1:
 
-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit
             def cumop(x):
                 return np.cumsum(x)
 
         else:
 
-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:
@@ -74,13 +73,13 @@ def cumop(x):
     else:
         if axis is None or ndim == 1:
 
-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit
             def cumop(x):
                 return np.cumprod(x)
 
         else:
 
-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:
diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py
index 82ee380029..e9b637b00f 100644
--- a/pytensor/link/numba/dispatch/scalar.py
+++ b/pytensor/link/numba/dispatch/scalar.py
@@ -2,7 +2,6 @@
 
 import numpy as np
 
-from pytensor import config
 from pytensor.compile.ops import ViewOp
 from pytensor.graph.basic import Variable
 from pytensor.link.numba.dispatch import basic as numba_basic
@@ -137,7 +136,6 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
 
     return numba_basic.numba_njit(
         signature,
-        fastmath=config.numba__fastmath,
         # Functions that call a function pointer can't be cached
         cache=False,
     )(scalar_op_fn)
@@ -177,9 +175,7 @@ def numba_funcify_Add(op, node, **kwargs):
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")
 
-    return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
-        nary_add_fn
-    )
+    return numba_basic.numba_njit(signature)(nary_add_fn)
 
 
 @numba_funcify.register(Mul)
@@ -187,9 +183,7 @@ def numba_funcify_Mul(op, node, **kwargs):
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")
 
-    return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
-        nary_add_fn
-    )
+    return numba_basic.numba_njit(signature)(nary_add_fn)
 
 
 @numba_funcify.register(Cast)
@@ -239,7 +233,7 @@ def numba_funcify_Composite(op, node, **kwargs):
 
     _ = kwargs.pop("storage_map", None)
 
-    composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
+    composite_fn = numba_basic.numba_njit(signature)(
         numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
     )
     return composite_fn
@@ -267,7 +261,7 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
     return numba_basic.global_numba_func(reciprocal)
 
 
-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def sigmoid(x):
     return 1 / (1 + np.exp(-x))
 
@@ -277,7 +271,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs):
     return numba_basic.global_numba_func(sigmoid)
 
 
-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def gammaln(x):
     return math.lgamma(x)
 
@@ -287,7 +281,7 @@ def numba_funcify_GammaLn(op, node, **kwargs):
     return numba_basic.global_numba_func(gammaln)
 
 
-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def logp1mexp(x):
     if x < np.log(0.5):
         return np.log1p(-np.exp(x))
@@ -300,7 +294,7 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
     return numba_basic.global_numba_func(logp1mexp)
 
 
-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def erf(x):
     return math.erf(x)
 
@@ -310,7 +304,7 @@ def numba_funcify_Erf(op, **kwargs):
     return numba_basic.global_numba_func(erf)
 
 
-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def erfc(x):
     return math.erfc(x)
 
diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py
index 437956bdc0..655e507da6 100644
--- a/tests/link/numba/test_scalar.py
+++ b/tests/link/numba/test_scalar.py
@@ -9,6 +9,7 @@
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.scalar.basic import Composite
+from pytensor.tensor import tensor
 from pytensor.tensor.elemwise import Elemwise
 from tests.link.numba.test_basic import compare_numba_and_py, set_test_value
 
@@ -140,3 +141,21 @@ def test_reciprocal(v, dtype):
             if not isinstance(i, SharedVariable | Constant)
         ],
     )
+
+
+@pytest.mark.parametrize("composite", (False, True))
+def test_isnan(composite):
+    # Testing with tensor just to make sure Elemwise does not revert the scalar behavior of fastmath
+    x = tensor(shape=(2,), dtype="float64")
+
+    if composite:
+        x_scalar = psb.float64()
+        scalar_out = ~psb.isnan(x_scalar)
+        out = Elemwise(Composite([x_scalar], [scalar_out]))(x)
+    else:
+        out = pt.isnan(x)
+
+    compare_numba_and_py(
+        ([x], [out]),
+        [np.array([1, np.nan], dtype="float64")],
+    )
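
Note on the change above: `numba_njit` now resolves the `fastmath` default in one place. When `config.numba__fastmath` is enabled it passes Numba a restricted set of LLVM fast-math flags instead of a blanket `fastmath=True`, deliberately leaving out `nnan` and `ninf` so that NaN/inf checks such as `isnan` keep their semantics. The standalone sketch below is not part of the patch; the function names are illustrative and it only assumes `numba` and `numpy` are importable. It shows the behaviour difference the new default and the `test_isnan` regression test are guarding against:

    import numba
    import numpy as np

    # The opinionated subset used by the patch: no "nnan"/"ninf" assumptions.
    RESTRICTED_FLAGS = {"arcp", "contract", "afn", "reassoc", "nsz"}


    @numba.njit(fastmath=True)  # enables every flag, including "nnan"
    def isnan_full_fastmath(x):
        return np.isnan(x)


    @numba.njit(fastmath=RESTRICTED_FLAGS)  # Numba accepts a set of flag names
    def isnan_restricted_fastmath(x):
        return np.isnan(x)


    nan = np.float64("nan")
    print(isnan_restricted_fastmath(nan))  # True: NaN semantics preserved
    print(isnan_full_fastmath(nan))        # may be False: "nnan" lets LLVM assume no NaNs

The restricted set keeps the cheap wins (reciprocal, contraction, approximate functions, reassociation, no signed zeros) while dropping the value-changing assumptions, which is why the call sites no longer need to pass `fastmath` explicitly.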