From 6b4b2f65b089e768ea3553045698399693aa804f Mon Sep 17 00:00:00 2001 From: Jonas Rembser Date: Thu, 7 Nov 2024 09:29:57 +0100 Subject: [PATCH] [PyROOT] Skip numba test also if numba is not available So far, the test was only skipped if the numba version was too low, but it might also happen that numba is not installed at all. For example if one uses Python 3.13, which doesn't support numba yet. --- python/numba/PyROOT_numbatests.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/python/numba/PyROOT_numbatests.py b/python/numba/PyROOT_numbatests.py index 960170c625..26afaaf1f4 100644 --- a/python/numba/PyROOT_numbatests.py +++ b/python/numba/PyROOT_numbatests.py @@ -1,17 +1,25 @@ -import sys +import importlib import os import pytest +import sys import time import ROOT -import numba sys.path.append(os.path.dirname(os.path.dirname(__file__))) +def has_required_numba(): + """Check if numba is available and that the version matches requirement.""" + if not importlib.util.find_spec("numba"): + return False + import numba + + # With fallback in case it's an older numba version that doesn't have the + # version_info attribute yet: + return getattr(numba, "version_info", (0, 0)) >= (0, 54) -@pytest.mark.skipif( - not hasattr(numba, 'version_info') or numba.version_info < (0, 54), - reason="Numba version 0.54 or more required") + +@pytest.mark.skipif(not has_required_numba(), reason="Numba version >=0.54 required") class TestClasNumba: """Tests numba support for PyROOT""" @@ -35,6 +43,7 @@ def test01_simple_free_func(self): import ROOT.NumbaExt import math import numpy as np + import numba def go_slow(a): trace = 0.0 @@ -57,6 +66,7 @@ def go_fast(a): def test02_member_function(self): import ROOT.NumbaExt import math + import numba # Obtain a vector of ROOT::Math::LorentzVector from the sample # .root file @@ -78,8 +88,9 @@ def numba_calc_pt_vec(vec_lv): def test03_inheritance(self): """This test shows one of the limitations of the current support""" - + import numba from numba.core.errors import TypingError + errtyp = TypingError if numba.version_info < (0, 60) else KeyError ROOT.gInterpreter.Declare("""