Fix JIT compilation with TensorFlow >= 2.14.0
The previous check for whether code is currently being compiled no longer works with TensorFlow >= 2.14.0 because the `Tensor` type is now called `SymbolicTensor`.

This change adds a helper function that checks whether code is being compiled for JAX, TensorFlow or PyTorch.
If tf.is_symbolic_tensor() is available, i.e. if the TensorFlow version is 2.13.0 or newer,
we use this function to check if code is being compiled.

To avoid inconsistencies between backends,
the check of the integration domain values is disabled if code is being compiled with PyTorch, even though that check would still work under PyTorch.
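
A minimal sketch of the TensorFlow behavior this works around (assuming TensorFlow >= 2.14 is installed; the function name `show_compilation_check` is only illustrative):

import tensorflow as tf

@tf.function
def show_compilation_check(x):
    # During tf.function tracing, x is a symbolic tensor. On TensorFlow >= 2.14
    # its type name is "SymbolicTensor", so the old
    # `type(x).__name__ == "Tensor"` check no longer detects compilation.
    print(type(x).__name__)
    # tf.is_symbolic_tensor() (added in TensorFlow 2.13) still detects it.
    print(tf.is_symbolic_tensor(x))
    return x

show_compilation_check(tf.constant([1.0, 2.0]))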
FHof committed Dec 25, 2023
1 parent bf01043 commit ce7cf11
Showing 2 changed files with 73 additions and 13 deletions.
48 changes: 35 additions & 13 deletions torchquad/integration/utils.py
@@ -193,20 +193,11 @@ def _check_integration_domain(integration_domain):
raise ValueError("integration_domain.shape[0] needs to be 1 or larger.")
if num_bounds != 2:
raise ValueError("integration_domain must have 2 values per boundary")
# Skip the values check if an integrator.integrate method is JIT
# compiled with JAX
if any(
nam in type(integration_domain).__name__ for nam in ["Jaxpr", "JVPTracer"]
):
# The boundary values check does not work if the code is JIT compiled
# with JAX or TensorFlow.
if _is_compiling(integration_domain):
return dim
boundaries_are_invalid = (
anp.min(integration_domain[:, 1] - integration_domain[:, 0]) < 0.0
)
# Skip the values check if an integrator.integrate method is
# compiled with tensorflow.function
if type(boundaries_are_invalid).__name__ == "Tensor":
return dim
if boundaries_are_invalid:
if anp.min(integration_domain[:, 1] - integration_domain[:, 0]) < 0.0:
raise ValueError("integration_domain has invalid boundary values")
return dim

@@ -263,6 +254,37 @@ def wrap(*args, **kwargs):
return wrap


def _is_compiling(x):
"""
Check if code is currently being compiled with PyTorch, JAX or TensorFlow
Args:
x (backend tensor): A tensor currently used for computations
Returns:
bool: True if code is currently being compiled, False otherwise
"""
backend = infer_backend(x)
if backend == "jax":
return any(nam in type(x).__name__ for nam in ["Jaxpr", "JVPTracer"])
if backend == "torch":
import torch

if hasattr(torch.jit, "is_tracing"):
# We ignore torch.jit.is_scripting() since we do not support
# compilation to TorchScript
return torch.jit.is_tracing()
# torch.jit.is_tracing() is unavailable below PyTorch version 1.11.0
return type(x.shape[0]).__name__ == "Tensor"
if backend == "tensorflow":
import tensorflow as tf

if hasattr(tf, "is_symbolic_tensor"):
return tf.is_symbolic_tensor(x)
# tf.is_symbolic_tensor() is unavailable below TensorFlow version 2.13.0
return type(x).__name__ == "Tensor"
return False


def _torch_trace_without_warnings(*args, **kwargs):
"""Execute `torch.jit.trace` on the passed arguments and hide tracer warnings
38 changes: 38 additions & 0 deletions torchquad/tests/utils_integration_test.py
@@ -12,6 +12,7 @@
_linspace_with_grads,
_add_at_indices,
_setup_integration_domain,
_is_compiling,
)
from utils.set_precision import set_precision
from utils.enable_cuda import enable_cuda
@@ -196,11 +197,48 @@ def test_setup_integration_domain():
_run_tests_with_all_backends(_run_setup_integration_domain_tests)


def _run_is_compiling_tests(dtype_name, backend):
"""
Test _is_compiling with the given dtype and numerical backend
"""
dtype = to_backend_dtype(dtype_name, like=backend)
x = anp.array([[0.0, 1.0], [1.0, 2.0]], dtype=dtype, like=backend)
assert not _is_compiling(
x
), f"_is_compiling has a false positive with backend {backend}"

def check_compiling(x):
assert _is_compiling(
x
), f"_is_compiling has a false negative with backend {backend}"
return x

if backend == "jax":
import jax

jax.jit(check_compiling)(x)
elif backend == "torch":
import torch

torch.jit.trace(check_compiling, (x,), check_trace=False)(x)
elif backend == "tensorflow":
import tensorflow as tf

tf.function(check_compiling, jit_compile=True)(x)
tf.function(check_compiling, jit_compile=False)(x)


def test_is_compiling():
"""Test _is_compiling with all possible configurations"""
_run_tests_with_all_backends(_run_is_compiling_tests)


if __name__ == "__main__":
try:
# used to run this test individually
test_linspace_with_grads()
test_add_at_indices()
test_setup_integration_domain()
test_is_compiling()
except KeyboardInterrupt:
pass
