Use more specific Numba fastmath flags everywhere
ricardoV94 committed Jan 3, 2025
1 parent 8f07714 commit d480f5d
Showing 7 changed files with 56 additions and 41 deletions.
8 changes: 4 additions & 4 deletions doc/extending/creating_a_numba_jax_op.rst
@@ -358,13 +358,13 @@ Here's an example for the `CumOp`\ `Op`:
if mode == "add":
if axis is None or ndim == 1:
-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit()
def cumop(x):
return np.cumsum(x)
else:
-@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+@numba_basic.numba_njit(boundscheck=False)
def cumop(x):
out_dtype = x.dtype
if x.shape[axis] < 2:
@@ -382,13 +382,13 @@ Here's an example for the `CumOp`\ `Op`:
else:
if axis is None or ndim == 1:
-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit()
def cumop(x):
return np.cumprod(x)
else:
-@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+@numba_basic.numba_njit(boundscheck=False)
def cumop(x):
out_dtype = x.dtype
if x.shape[axis] < 2:
19 changes: 16 additions & 3 deletions pytensor/link/numba/dispatch/basic.py
@@ -49,10 +49,23 @@ def global_numba_func(func):
return func


-def numba_njit(*args, **kwargs):
+def numba_njit(*args, fastmath=None, **kwargs):
kwargs.setdefault("cache", config.numba__cache)
kwargs.setdefault("no_cpython_wrapper", True)
kwargs.setdefault("no_cfunc_wrapper", True)
+if fastmath is None:
+if config.numba__fastmath:
+# Opinionated default on fastmath flags
+# https://llvm.org/docs/LangRef.html#fast-math-flags
+fastmath = {
+"arcp", # Allow Reciprocal
+"contract", # Allow floating-point contraction
+"afn", # Approximate functions
+"reassoc",
+"nsz", # no-signed zeros
+}
+else:
+fastmath = False

# Suppress cache warning for internal functions
# We have to add an ansi escape code for optional bold text by numba
@@ -68,9 +81,9 @@ def numba_njit(*args, **kwargs):
)

if len(args) > 0 and callable(args[0]):
-return numba.njit(*args[1:], **kwargs)(args[0])
+return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])

-return numba.njit(*args, **kwargs)
+return numba.njit(*args, fastmath=fastmath, **kwargs)


def numba_vectorize(*args, **kwargs):
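For reference, here is a minimal standalone sketch (not part of the commit) of the flag-resolution logic that the new `numba_njit` signature introduces; the helper name `resolve_fastmath` and its `fastmath_enabled` argument are illustrative stand-ins for reading `config.numba__fastmath`:

```python
# Illustrative sketch only -- mirrors the default handling added to numba_njit
# above, without requiring numba or pytensor. All names here are hypothetical.
def resolve_fastmath(fastmath=None, fastmath_enabled=True):
    if fastmath is not None:
        # An explicit fastmath argument from the caller always wins.
        return fastmath
    if fastmath_enabled:
        # Restricted flag set: notably omits "nnan" and "ninf", so NaN/inf
        # semantics (e.g. np.isnan) keep working under fastmath.
        return {"arcp", "contract", "afn", "reassoc", "nsz"}
    return False


print(resolve_fastmath())                        # the restricted flag set
print(resolve_fastmath(fastmath_enabled=False))  # False
print(resolve_fastmath(fastmath=True))           # True, i.e. all LLVM fast-math flags
```

Passing the resolved value straight to `numba.njit(fastmath=...)` works because Numba accepts either a boolean or a set of LLVM flag names for `fastmath`.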
1 change: 0 additions & 1 deletion pytensor/link/numba/dispatch/blockwise.py
@@ -32,7 +32,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
core_op,
node=core_node,
parent_node=node,
-fastmath=_jit_options["fastmath"],
**kwargs,
)
core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)
19 changes: 5 additions & 14 deletions pytensor/link/numba/dispatch/elemwise.py
@@ -6,7 +6,6 @@
from numba.core.extending import overload
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple

-from pytensor import config
from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
@@ -281,7 +280,6 @@ def jit_compile_reducer(
res = numba_basic.numba_njit(
*args,
boundscheck=False,
-fastmath=config.numba__fastmath,
**kwds,
)(fn)

@@ -315,7 +313,6 @@ def numba_funcify_Elemwise(op, node, **kwargs):
op.scalar_op,
node=scalar_node,
parent_node=node,
-fastmath=_jit_options["fastmath"],
**kwargs,
)

@@ -403,13 +400,13 @@ def numba_funcify_Sum(op, node, **kwargs):

if ndim_input == len(axes):
# Slightly faster than `numba_funcify_CAReduce` for this case
-@numba_njit(fastmath=config.numba__fastmath)
+@numba_njit
def impl_sum(array):
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)

elif len(axes) == 0:
# These cases should be removed by rewrites!
-@numba_njit(fastmath=config.numba__fastmath)
+@numba_njit
def impl_sum(array):
return np.asarray(array, dtype=out_dtype)

@@ -568,9 +565,7 @@ def numba_funcify_Softmax(op, node, **kwargs):
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
)

-jit_fn = numba_basic.numba_njit(
-boundscheck=False, fastmath=config.numba__fastmath
-)
+jit_fn = numba_basic.numba_njit(boundscheck=False)
reduce_max = jit_fn(reduce_max_py)
reduce_sum = jit_fn(reduce_sum_py)
else:
@@ -602,9 +597,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
)

-jit_fn = numba_basic.numba_njit(
-boundscheck=False, fastmath=config.numba__fastmath
-)
+jit_fn = numba_basic.numba_njit(boundscheck=False)
reduce_sum = jit_fn(reduce_sum_py)
else:
reduce_sum = np.sum
@@ -642,9 +635,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
)

-jit_fn = numba_basic.numba_njit(
-boundscheck=False, fastmath=config.numba__fastmath
-)
+jit_fn = numba_basic.numba_njit(boundscheck=False)
reduce_max = jit_fn(reduce_max_py)
reduce_sum = jit_fn(reduce_sum_py)
else:
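A side note on usage (not part of the diff): several call sites above switch from `@numba_njit(fastmath=config.numba__fastmath)` to the bare `@numba_njit`. Both spellings work because `numba_njit` checks whether its first positional argument is already a callable, roughly as in this sketch (`njit_like` and `compile_with` are hypothetical names):

```python
# Rough sketch of a decorator that can be used both bare and with options.
def njit_like(*args, **kwargs):
    def compile_with(func):
        # Stand-in for numba.njit(**kwargs)(func).
        func.jit_options = kwargs
        return func

    if len(args) > 0 and callable(args[0]):
        # Used as a bare decorator: @njit_like
        return compile_with(args[0])
    # Used with options: @njit_like(boundscheck=False)
    return compile_with


@njit_like
def double(x):
    return 2 * x


@njit_like(boundscheck=False)
def triple(x):
    return 3 * x
```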
9 changes: 4 additions & 5 deletions pytensor/link/numba/dispatch/extra_ops.py
@@ -4,7 +4,6 @@
import numba
import numpy as np

-from pytensor import config
from pytensor.graph import Apply
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
@@ -50,13 +49,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
if mode == "add":
if axis is None or ndim == 1:

-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
def cumop(x):
return np.cumsum(x)

else:

-@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+@numba_basic.numba_njit(boundscheck=False)
def cumop(x):
out_dtype = x.dtype
if x.shape[axis] < 2:
@@ -74,13 +73,13 @@ def cumop(x):
else:
if axis is None or ndim == 1:

-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
def cumop(x):
return np.cumprod(x)

else:

-@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+@numba_basic.numba_njit(boundscheck=False)
def cumop(x):
out_dtype = x.dtype
if x.shape[axis] < 2:
22 changes: 8 additions & 14 deletions pytensor/link/numba/dispatch/scalar.py
@@ -2,7 +2,6 @@

import numpy as np

-from pytensor import config
from pytensor.compile.ops import ViewOp
from pytensor.graph.basic import Variable
from pytensor.link.numba.dispatch import basic as numba_basic
@@ -137,7 +136,6 @@ def {scalar_op_fn_name}({', '.join(input_names)}):

return numba_basic.numba_njit(
signature,
-fastmath=config.numba__fastmath,
# Functions that call a function pointer can't be cached
cache=False,
)(scalar_op_fn)
@@ -177,19 +175,15 @@ def numba_funcify_Add(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

-return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
-nary_add_fn
-)
+return numba_basic.numba_njit(signature)(nary_add_fn)


@numba_funcify.register(Mul)
def numba_funcify_Mul(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")

-return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
-nary_add_fn
-)
+return numba_basic.numba_njit(signature)(nary_add_fn)


@numba_funcify.register(Cast)
@@ -239,7 +233,7 @@ def numba_funcify_Composite(op, node, **kwargs):

_ = kwargs.pop("storage_map", None)

-composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
+composite_fn = numba_basic.numba_njit(signature)(
numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
)
return composite_fn
@@ -267,7 +261,7 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
return numba_basic.global_numba_func(reciprocal)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
def sigmoid(x):
return 1 / (1 + np.exp(-x))

@@ -277,7 +271,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs):
return numba_basic.global_numba_func(sigmoid)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
def gammaln(x):
return math.lgamma(x)

@@ -287,7 +281,7 @@ def numba_funcify_GammaLn(op, node, **kwargs):
return numba_basic.global_numba_func(gammaln)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
def logp1mexp(x):
if x < np.log(0.5):
return np.log1p(-np.exp(x))
@@ -300,7 +294,7 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
return numba_basic.global_numba_func(logp1mexp)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
def erf(x):
return math.erf(x)

@@ -310,7 +304,7 @@ def numba_funcify_Erf(op, **kwargs):
return numba_basic.global_numba_func(erf)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
def erfc(x):
return math.erfc(x)

19 changes: 19 additions & 0 deletions tests/link/numba/test_scalar.py
@@ -9,6 +9,7 @@
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar.basic import Composite
+from pytensor.tensor import tensor
from pytensor.tensor.elemwise import Elemwise
from tests.link.numba.test_basic import compare_numba_and_py, set_test_value

@@ -140,3 +141,21 @@ def test_reciprocal(v, dtype):
if not isinstance(i, SharedVariable | Constant)
],
)


+@pytest.mark.parametrize("composite", (False, True))
+def test_isnan(composite):
+# Testing with tensor just to make sure Elemwise does not revert the scalar behavior of fastmath
+x = tensor(shape=(2,), dtype="float64")
+
+if composite:
+x_scalar = psb.float64()
+scalar_out = ~psb.isnan(x_scalar)
+out = Elemwise(Composite([x_scalar], [scalar_out]))(x)
+else:
+out = pt.isnan(x)
+
+compare_numba_and_py(
+([x], [out]),
+[np.array([1, 0], dtype="float64")],
+)
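As an aside (not part of the commit), the new test targets exactly the behavior the restricted flag set is meant to preserve: Numba accepts a set of LLVM fast-math flag names, and the blanket `fastmath=True` implies `nnan`, which lets LLVM assume NaNs never occur and can fold `np.isnan` to `False`. A sketch of the difference, assuming Numba is installed (results under `fastmath=True` depend on the LLVM optimizations actually applied):

```python
import numba
import numpy as np


@numba.njit(fastmath={"arcp", "contract", "afn", "reassoc", "nsz"})
def isnan_restricted(x):
    # "nnan"/"ninf" are not enabled, so the NaN check keeps its meaning.
    return np.isnan(x)


@numba.njit(fastmath=True)
def isnan_full(x):
    # With "nnan" implied, LLVM may optimize this check away.
    return np.isnan(x)


print(isnan_restricted(np.nan))  # True
print(isnan_full(np.nan))        # may print False
```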
