diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst index 23faea9465..8be08b4953 100644 --- a/doc/extending/creating_a_numba_jax_op.rst +++ b/doc/extending/creating_a_numba_jax_op.rst @@ -358,13 +358,13 @@ Here's an example for the `CumOp`\ `Op`: if mode == "add": if axis is None or ndim == 1: - @numba_basic.numba_njit(fastmath=config.numba__fastmath) + @numba_basic.numba_njit() def cumop(x): return np.cumsum(x) else: - @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + @numba_basic.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: @@ -382,13 +382,13 @@ Here's an example for the `CumOp`\ `Op`: else: if axis is None or ndim == 1: - @numba_basic.numba_njit(fastmath=config.numba__fastmath) + @numba_basic.numba_njit() def cumop(x): return np.cumprod(x) else: - @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + @numba_basic.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 8bf827b52f..12a5f9d62a 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -49,10 +49,22 @@ def global_numba_func(func): return func -def numba_njit(*args, **kwargs): +def numba_njit(*args, fastmath=None, **kwargs): kwargs.setdefault("cache", config.numba__cache) kwargs.setdefault("no_cpython_wrapper", True) kwargs.setdefault("no_cfunc_wrapper", True) + if fastmath is None and config.numba__fastmath: + # Opinionated default on fastmath flags + # https://llvm.org/docs/LangRef.html#fast-math-flags + fastmath = { + "arcp", # Allow Reciprocal + "contract", # Allow floating-point contraction + "afn", # Approximate functions + "reassoc", + "nsz", # no-signed zeros + } + else: + fastmath = False # Suppress cache warning for internal functions # We have to add an ansi escape code for optional bold text by numba @@ -68,9 +80,9 @@ def numba_njit(*args, **kwargs): ) if len(args) > 0 and callable(args[0]): - return numba.njit(*args[1:], **kwargs)(args[0]) + return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0]) - return numba.njit(*args, **kwargs) + return numba.njit(*args, fastmath=fastmath, **kwargs) def numba_vectorize(*args, **kwargs): diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py index 131788e843..b7481bd5a3 100644 --- a/pytensor/link/numba/dispatch/blockwise.py +++ b/pytensor/link/numba/dispatch/blockwise.py @@ -32,7 +32,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs): core_op, node=core_node, parent_node=node, - fastmath=_jit_options["fastmath"], **kwargs, ) core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 2759422bf6..ae5ef3dcb1 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -1,15 +1,11 @@ -from collections.abc import Callable from functools import singledispatch from textwrap import dedent, indent -from typing import Any import numba import numpy as np from numba.core.extending import overload from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple -from pytensor import config -from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( @@ -124,42 +120,6 @@ def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr): """ -def create_vectorize_func( - scalar_op_fn: Callable, - node: Apply, - use_signature: bool = False, - identity: Any | None = None, - **kwargs, -) -> Callable: - r"""Create a vectorized Numba function from a `Apply`\s Python function.""" - - if len(node.outputs) > 1: - raise NotImplementedError( - "Multi-output Elemwise Ops are not supported by the Numba backend" - ) - - if use_signature: - signature = [create_numba_signature(node, force_scalar=True)] - else: - signature = [] - - target = ( - getattr(node.tag, "numba__vectorize_target", None) - or config.numba__vectorize_target - ) - - numba_vectorized_fn = numba_basic.numba_vectorize( - signature, identity=identity, target=target, fastmath=config.numba__fastmath - ) - - py_scalar_func = getattr(scalar_op_fn, "py_func", scalar_op_fn) - - elemwise_fn = numba_vectorized_fn(scalar_op_fn) - elemwise_fn.py_scalar_func = py_scalar_func - - return elemwise_fn - - def create_multiaxis_reducer( scalar_op, identity, @@ -320,7 +280,6 @@ def jit_compile_reducer( res = numba_basic.numba_njit( *args, boundscheck=False, - fastmath=config.numba__fastmath, **kwds, )(fn) @@ -354,7 +313,6 @@ def numba_funcify_Elemwise(op, node, **kwargs): op.scalar_op, node=scalar_node, parent_node=node, - fastmath=_jit_options["fastmath"], **kwargs, ) @@ -442,13 +400,13 @@ def numba_funcify_Sum(op, node, **kwargs): if ndim_input == len(axes): # Slightly faster than `numba_funcify_CAReduce` for this case - @numba_njit(fastmath=config.numba__fastmath) + @numba_njit def impl_sum(array): return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype) elif len(axes) == 0: # These cases should be removed by rewrites! - @numba_njit(fastmath=config.numba__fastmath) + @numba_njit def impl_sum(array): return np.asarray(array, dtype=out_dtype) @@ -607,9 +565,7 @@ def numba_funcify_Softmax(op, node, **kwargs): add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True ) - jit_fn = numba_basic.numba_njit( - boundscheck=False, fastmath=config.numba__fastmath - ) + jit_fn = numba_basic.numba_njit(boundscheck=False) reduce_max = jit_fn(reduce_max_py) reduce_sum = jit_fn(reduce_sum_py) else: @@ -641,9 +597,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs): add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True ) - jit_fn = numba_basic.numba_njit( - boundscheck=False, fastmath=config.numba__fastmath - ) + jit_fn = numba_basic.numba_njit(boundscheck=False) reduce_sum = jit_fn(reduce_sum_py) else: reduce_sum = np.sum @@ -681,9 +635,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True ) - jit_fn = numba_basic.numba_njit( - boundscheck=False, fastmath=config.numba__fastmath - ) + jit_fn = numba_basic.numba_njit(boundscheck=False) reduce_max = jit_fn(reduce_max_py) reduce_sum = jit_fn(reduce_sum_py) else: diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index 3629b0e44c..1f0a33e595 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -4,7 +4,6 @@ import numba import numpy as np -from pytensor import config from pytensor.graph import Apply from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify @@ -50,13 +49,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs): if mode == "add": if axis is None or ndim == 1: - @numba_basic.numba_njit(fastmath=config.numba__fastmath) + @numba_basic.numba_njit def cumop(x): return np.cumsum(x) else: - @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + @numba_basic.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: @@ -74,13 +73,13 @@ def cumop(x): else: if axis is None or ndim == 1: - @numba_basic.numba_njit(fastmath=config.numba__fastmath) + @numba_basic.numba_njit def cumop(x): return np.cumprod(x) else: - @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + @numba_basic.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index d9342d5694..e9b637b00f 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -2,7 +2,6 @@ import numpy as np -from pytensor import config from pytensor.compile.ops import ViewOp from pytensor.graph.basic import Variable from pytensor.link.numba.dispatch import basic as numba_basic @@ -23,7 +22,6 @@ Clip, Composite, Identity, - IsNan, Mul, Reciprocal, ScalarOp, @@ -138,8 +136,6 @@ def {scalar_op_fn_name}({', '.join(input_names)}): return numba_basic.numba_njit( signature, - # numba always returns False if fastmath=True # https://github.com/numba/numba/issues/9383 - fastmath=False if isinstance(op, IsNan) else config.numba__fastmath, # Functions that call a function pointer can't be cached cache=False, )(scalar_op_fn) @@ -179,9 +175,7 @@ def numba_funcify_Add(op, node, **kwargs): signature = create_numba_signature(node, force_scalar=True) nary_add_fn = binary_to_nary_func(node.inputs, "add", "+") - return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)( - nary_add_fn - ) + return numba_basic.numba_njit(signature)(nary_add_fn) @numba_funcify.register(Mul) @@ -189,9 +183,7 @@ def numba_funcify_Mul(op, node, **kwargs): signature = create_numba_signature(node, force_scalar=True) nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*") - return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)( - nary_add_fn - ) + return numba_basic.numba_njit(signature)(nary_add_fn) @numba_funcify.register(Cast) @@ -241,7 +233,7 @@ def numba_funcify_Composite(op, node, **kwargs): _ = kwargs.pop("storage_map", None) - composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)( + composite_fn = numba_basic.numba_njit(signature)( numba_funcify(op.fgraph, squeeze_output=True, **kwargs) ) return composite_fn @@ -269,7 +261,7 @@ def numba_funcify_Reciprocal(op, node, **kwargs): return numba_basic.global_numba_func(reciprocal) -@numba_basic.numba_njit(fastmath=config.numba__fastmath) +@numba_basic.numba_njit def sigmoid(x): return 1 / (1 + np.exp(-x)) @@ -279,7 +271,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs): return numba_basic.global_numba_func(sigmoid) -@numba_basic.numba_njit(fastmath=config.numba__fastmath) +@numba_basic.numba_njit def gammaln(x): return math.lgamma(x) @@ -289,7 +281,7 @@ def numba_funcify_GammaLn(op, node, **kwargs): return numba_basic.global_numba_func(gammaln) -@numba_basic.numba_njit(fastmath=config.numba__fastmath) +@numba_basic.numba_njit def logp1mexp(x): if x < np.log(0.5): return np.log1p(-np.exp(x)) @@ -302,7 +294,7 @@ def numba_funcify_Log1mexp(op, node, **kwargs): return numba_basic.global_numba_func(logp1mexp) -@numba_basic.numba_njit(fastmath=config.numba__fastmath) +@numba_basic.numba_njit def erf(x): return math.erf(x) @@ -312,7 +304,7 @@ def numba_funcify_Erf(op, **kwargs): return numba_basic.global_numba_func(erf) -@numba_basic.numba_njit(fastmath=config.numba__fastmath) +@numba_basic.numba_njit def erfc(x): return math.erfc(x) diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py index d9a85ee0e3..655e507da6 100644 --- a/tests/link/numba/test_scalar.py +++ b/tests/link/numba/test_scalar.py @@ -143,12 +143,19 @@ def test_reciprocal(v, dtype): ) -@pytest.mark.parametrize("dtype", ("complex64", "float64", "float32")) -def test_isnan(dtype): +@pytest.mark.parametrize("composite", (False, True)) +def test_isnan(composite): # Testing with tensor just to make sure Elemwise does not revert the scalar behavior of fastmath - x = tensor(shape=(2,), dtype=dtype) - out = pt.isnan(x) + x = tensor(shape=(2,), dtype="float64") + + if composite: + x_scalar = psb.float64() + scalar_out = ~psb.isnan(x_scalar) + out = Elemwise(Composite([x_scalar], [scalar_out]))(x) + else: + out = pt.isnan(x) + compare_numba_and_py( ([x], [out]), - [np.array([1, 0], dtype=dtype)], + [np.array([1, 0], dtype="float64")], )