diff --git a/aesara/link/numba/dispatch/basic.py b/aesara/link/numba/dispatch/basic.py index b173ab2765..7a75494aa5 100644 --- a/aesara/link/numba/dispatch/basic.py +++ b/aesara/link/numba/dispatch/basic.py @@ -356,34 +356,36 @@ def numba_const_convert(data, dtype=None, **kwargs): def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable: """Convert `obj` to a Numba-JITable object.""" - return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs) + return cast( + Callable, _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs) + ) numba_cache_index = pathlib.PurePath(config.compiledir, "numba_cache_index") numba_db = shelve.open(numba_cache_index.as_posix()) -def make_node_key(node: "Apply") -> Optional[str]: +def make_node_key(node): """Create a cache key for `node`. TODO: Currently this works only with Apply Node """ if not isinstance(node, Apply): return None key = (node.op,) - key += tuple(inp.type for inp in node.inputs) + key = tuple(inp.type for inp in node.inputs) key += tuple(inp.type for inp in node.outputs) - key = hashlib.sha256(pickle.dumps(key)).hexdigest() + hash_key = hashlib.sha256(pickle.dumps(key)).hexdigest() - return key + return hash_key -def check_cache(node_key: str): +def check_cache(node_key): """Check disk-backed cache.""" return numba_db.get(node_key) -def add_to_cache(node_key: str, numba_py_fn) -> Callable: +def add_to_cache(node_key: str, numba_py_fn: Callable) -> Callable: """Add the numba generated function to the cache.""" module_file_base = ( pathlib.PurePath(config.compiledir, node_key).with_suffix(".py").as_posix() @@ -391,7 +393,7 @@ def add_to_cache(node_key: str, numba_py_fn) -> Callable: cache_module = ModuleType(node_key) # Create a temporary module for the generated source - cache_module.source = numba_py_fn + cache_module.source = numba_py_fn # type: ignore dill.dump_module(module_file_base, module=cache_module) # Load the function from the persisted module @@ -403,7 +405,7 @@ def add_to_cache(node_key: str, numba_py_fn) -> Callable: return numba_py_fn -def persist_py_code(func) -> Callable: +def persist_py_code(func: Callable) -> Callable: """Persist a Numba JIT-able Python function. Parameters ========== @@ -413,12 +415,11 @@ def persist_py_code(func) -> Callable: """ @wraps(func) - def _func(obj, node, **kwargs): + def _func(obj, node, **kwargs) -> Callable: node_key = make_node_key(node) numba_py_fn = None if node_key: numba_py_fn = check_cache(node_key) - if node_key is None or numba_py_fn is None: # We could only ever return the function source in our dispatch # implementations. That way, we can compile directly to the on-disk @@ -434,9 +435,9 @@ def _func(obj, node, **kwargs): # TODO: Presently numba_py_fn is already jitted. # numba_fn = numba_njit(numba_py_fn) - return numba_py_fn + return cast(Callable, numba_py_fn) - return _func + return cast(Callable, _func) @persist_py_code