Use more specific Numba fastmath flags everywhere
ricardoV94 committed Jan 3, 2025
1 parent c59ba2e commit 2ea42ab
Showing 7 changed files with 48 additions and 87 deletions.
8 changes: 4 additions & 4 deletions doc/extending/creating_a_numba_jax_op.rst
@@ -358,13 +358,13 @@ Here's an example for the `CumOp`\ `Op`:
     if mode == "add":
         if axis is None or ndim == 1:

-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit()
             def cumop(x):
                 return np.cumsum(x)

         else:

-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:
@@ -382,13 +382,13 @@ Here's an example for the `CumOp`\ `Op`:
     else:
         if axis is None or ndim == 1:

-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit()
             def cumop(x):
                 return np.cumprod(x)

         else:

-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:
18 changes: 15 additions & 3 deletions pytensor/link/numba/dispatch/basic.py
@@ -49,10 +49,22 @@ def global_numba_func(func):
     return func


-def numba_njit(*args, **kwargs):
+def numba_njit(*args, fastmath=None, **kwargs):
     kwargs.setdefault("cache", config.numba__cache)
     kwargs.setdefault("no_cpython_wrapper", True)
     kwargs.setdefault("no_cfunc_wrapper", True)
+    if fastmath is None and config.numba__fastmath:
+        # Opinionated default on fastmath flags
+        # https://llvm.org/docs/LangRef.html#fast-math-flags
+        fastmath = {
+            "arcp",  # Allow Reciprocal
+            "contract",  # Allow floating-point contraction
+            "afn",  # Approximate functions
+            "reassoc",
+            "nsz",  # no-signed zeros
+        }
+    else:
+        fastmath = False

     # Suppress cache warning for internal functions
     # We have to add an ansi escape code for optional bold text by numba
@@ -68,9 +80,9 @@ def numba_njit(*args, **kwargs):
     )

     if len(args) > 0 and callable(args[0]):
-        return numba.njit(*args[1:], **kwargs)(args[0])
+        return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])

-    return numba.njit(*args, **kwargs)
+    return numba.njit(*args, fastmath=fastmath, **kwargs)


 def numba_vectorize(*args, **kwargs):
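Note: Numba's `njit` accepts `fastmath` either as a boolean or as a set of LLVM fast-math flag names, which is what the new default above passes through. A minimal sketch of the idea (the function name and sample data are illustrative, not part of this commit):

import numba
import numpy as np

# Same opinionated subset as the new default; it notably omits "nnan" and "ninf",
# so NaN/inf handling stays intact while reassociation and contraction are allowed.
SAFE_FASTMATH = {"arcp", "contract", "afn", "reassoc", "nsz"}

@numba.njit(fastmath=SAFE_FASTMATH)
def normalized(x):
    # "reassoc" and "contract" let LLVM reorder and fuse the reduction and division
    return x / x.sum()

print(normalized(np.arange(1.0, 5.0)))

With `fastmath=True`, every flag including `nnan` and `ninf` would be enabled, which is what the per-`Op` overrides removed elsewhere in this commit used to guard against.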
1 change: 0 additions & 1 deletion pytensor/link/numba/dispatch/blockwise.py
@@ -32,7 +32,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
         core_op,
         node=core_node,
         parent_node=node,
-        fastmath=_jit_options["fastmath"],
         **kwargs,
     )
     core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)
58 changes: 5 additions & 53 deletions pytensor/link/numba/dispatch/elemwise.py
@@ -1,15 +1,11 @@
-from collections.abc import Callable
 from functools import singledispatch
 from textwrap import dedent, indent
-from typing import Any

 import numba
 import numpy as np
 from numba.core.extending import overload
 from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple

-from pytensor import config
-from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
@@ -124,42 +120,6 @@ def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr):
     """


-def create_vectorize_func(
-    scalar_op_fn: Callable,
-    node: Apply,
-    use_signature: bool = False,
-    identity: Any | None = None,
-    **kwargs,
-) -> Callable:
-    r"""Create a vectorized Numba function from a `Apply`\s Python function."""
-
-    if len(node.outputs) > 1:
-        raise NotImplementedError(
-            "Multi-output Elemwise Ops are not supported by the Numba backend"
-        )
-
-    if use_signature:
-        signature = [create_numba_signature(node, force_scalar=True)]
-    else:
-        signature = []
-
-    target = (
-        getattr(node.tag, "numba__vectorize_target", None)
-        or config.numba__vectorize_target
-    )
-
-    numba_vectorized_fn = numba_basic.numba_vectorize(
-        signature, identity=identity, target=target, fastmath=config.numba__fastmath
-    )
-
-    py_scalar_func = getattr(scalar_op_fn, "py_func", scalar_op_fn)
-
-    elemwise_fn = numba_vectorized_fn(scalar_op_fn)
-    elemwise_fn.py_scalar_func = py_scalar_func
-
-    return elemwise_fn
-
-
 def create_multiaxis_reducer(
     scalar_op,
     identity,
@@ -320,7 +280,6 @@ def jit_compile_reducer(
     res = numba_basic.numba_njit(
         *args,
         boundscheck=False,
-        fastmath=config.numba__fastmath,
         **kwds,
     )(fn)

@@ -354,7 +313,6 @@ def numba_funcify_Elemwise(op, node, **kwargs):
         op.scalar_op,
         node=scalar_node,
         parent_node=node,
-        fastmath=_jit_options["fastmath"],
         **kwargs,
     )

@@ -442,13 +400,13 @@ def numba_funcify_Sum(op, node, **kwargs):

     if ndim_input == len(axes):
         # Slightly faster than `numba_funcify_CAReduce` for this case
-        @numba_njit(fastmath=config.numba__fastmath)
+        @numba_njit
         def impl_sum(array):
             return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)

     elif len(axes) == 0:
         # These cases should be removed by rewrites!
-        @numba_njit(fastmath=config.numba__fastmath)
+        @numba_njit
         def impl_sum(array):
             return np.asarray(array, dtype=out_dtype)

@@ -607,9 +565,7 @@ def numba_funcify_Softmax(op, node, **kwargs):
             add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
         )

-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_max = jit_fn(reduce_max_py)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
@@ -641,9 +597,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
             add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
         )

-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
         reduce_sum = np.sum
@@ -681,9 +635,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
             add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
         )

-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_max = jit_fn(reduce_max_py)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
9 changes: 4 additions & 5 deletions pytensor/link/numba/dispatch/extra_ops.py
@@ -4,7 +4,6 @@
 import numba
 import numpy as np

-from pytensor import config
 from pytensor.graph import Apply
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
@@ -50,13 +49,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
     if mode == "add":
         if axis is None or ndim == 1:

-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit
             def cumop(x):
                 return np.cumsum(x)

         else:

-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:
@@ -74,13 +73,13 @@ def cumop(x):
     else:
         if axis is None or ndim == 1:

-            @numba_basic.numba_njit(fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit
             def cumop(x):
                 return np.cumprod(x)

         else:

-            @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+            @numba_basic.numba_njit(boundscheck=False)
             def cumop(x):
                 out_dtype = x.dtype
                 if x.shape[axis] < 2:
24 changes: 8 additions & 16 deletions pytensor/link/numba/dispatch/scalar.py
@@ -2,7 +2,6 @@

 import numpy as np

-from pytensor import config
 from pytensor.compile.ops import ViewOp
 from pytensor.graph.basic import Variable
 from pytensor.link.numba.dispatch import basic as numba_basic
@@ -23,7 +22,6 @@
     Clip,
     Composite,
     Identity,
-    IsNan,
     Mul,
     Reciprocal,
     ScalarOp,
@@ -138,8 +136,6 @@ def {scalar_op_fn_name}({', '.join(input_names)}):

     return numba_basic.numba_njit(
         signature,
-        # numba always returns False if fastmath=True # https://github.com/numba/numba/issues/9383
-        fastmath=False if isinstance(op, IsNan) else config.numba__fastmath,
         # Functions that call a function pointer can't be cached
         cache=False,
     )(scalar_op_fn)
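The special case removed above existed because `fastmath=True` enables the `nnan` flag, under which Numba may treat `np.isnan` as always False (numba/numba#9383, cited in the removed comment). Since the new default flag set omits `nnan` and `ninf`, `IsNan` no longer needs to opt out. A hedged sketch of the difference (illustrative only, not part of the commit):

import numba
import numpy as np

@numba.njit(fastmath=True)  # implies "nnan": the compiler may assume no NaNs
def isnan_full_fastmath(x):
    return np.isnan(x)

@numba.njit(fastmath={"arcp", "contract", "afn", "reassoc", "nsz"})  # no "nnan"
def isnan_specific_fastmath(x):
    return np.isnan(x)

print(isnan_full_fastmath(np.nan))      # may wrongly print False
print(isnan_specific_fastmath(np.nan))  # True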
@@ -179,19 +175,15 @@ def numba_funcify_Add(op, node, **kwargs):
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

-    return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
-        nary_add_fn
-    )
+    return numba_basic.numba_njit(signature)(nary_add_fn)


 @numba_funcify.register(Mul)
 def numba_funcify_Mul(op, node, **kwargs):
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")

-    return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
-        nary_add_fn
-    )
+    return numba_basic.numba_njit(signature)(nary_add_fn)


 @numba_funcify.register(Cast)
@@ -241,7 +233,7 @@ def numba_funcify_Composite(op, node, **kwargs):

     _ = kwargs.pop("storage_map", None)

-    composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
+    composite_fn = numba_basic.numba_njit(signature)(
         numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
     )
     return composite_fn
@@ -269,7 +261,7 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
     return numba_basic.global_numba_func(reciprocal)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def sigmoid(x):
     return 1 / (1 + np.exp(-x))

@@ -279,7 +271,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs):
     return numba_basic.global_numba_func(sigmoid)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def gammaln(x):
     return math.lgamma(x)

@@ -289,7 +281,7 @@ def numba_funcify_GammaLn(op, node, **kwargs):
     return numba_basic.global_numba_func(gammaln)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def logp1mexp(x):
     if x < np.log(0.5):
         return np.log1p(-np.exp(x))
@@ -302,7 +294,7 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
     return numba_basic.global_numba_func(logp1mexp)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def erf(x):
     return math.erf(x)

@@ -312,7 +304,7 @@ def numba_funcify_Erf(op, **kwargs):
     return numba_basic.global_numba_func(erf)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
 def erfc(x):
     return math.erfc(x)

17 changes: 12 additions & 5 deletions tests/link/numba/test_scalar.py
@@ -143,12 +143,19 @@ def test_reciprocal(v, dtype):
     )


-@pytest.mark.parametrize("dtype", ("complex64", "float64", "float32"))
-def test_isnan(dtype):
+@pytest.mark.parametrize("composite", (False, True))
+def test_isnan(composite):
     # Testing with tensor just to make sure Elemwise does not revert the scalar behavior of fastmath
-    x = tensor(shape=(2,), dtype=dtype)
-    out = pt.isnan(x)
+    x = tensor(shape=(2,), dtype="float64")
+
+    if composite:
+        x_scalar = psb.float64()
+        scalar_out = ~psb.isnan(x_scalar)
+        out = Elemwise(Composite([x_scalar], [scalar_out]))(x)
+    else:
+        out = pt.isnan(x)
+
     compare_numba_and_py(
         ([x], [out]),
-        [np.array([1, 0], dtype=dtype)],
+        [np.array([1, 0], dtype="float64")],
     )
