Skip to content

Commit

Permalink
Fix some mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Smit-create committed Mar 1, 2023
1 parent 3b5ef75 commit ae4f989
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions aesara/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,42 +356,44 @@ def numba_const_convert(data, dtype=None, **kwargs):

def numba_funcify(obj, node=None, storage_map=None, **kwargs) -> Callable:
"""Convert `obj` to a Numba-JITable object."""
return _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs)
return cast(
Callable, _numba_funcify(obj, node=node, storage_map=storage_map, **kwargs)
)


numba_cache_index = pathlib.PurePath(config.compiledir, "numba_cache_index")
numba_db = shelve.open(numba_cache_index.as_posix())


def make_node_key(node: "Apply") -> Optional[str]:
def make_node_key(node):
"""Create a cache key for `node`.
TODO: Currently this works only with Apply Node
"""
if not isinstance(node, Apply):
return None
key = (node.op,)
key += tuple(inp.type for inp in node.inputs)
key = tuple(inp.type for inp in node.inputs)
key += tuple(inp.type for inp in node.outputs)

key = hashlib.sha256(pickle.dumps(key)).hexdigest()
hash_key = hashlib.sha256(pickle.dumps(key)).hexdigest()

return key
return hash_key


def check_cache(node_key: str):
def check_cache(node_key):
"""Check disk-backed cache."""
return numba_db.get(node_key)


def add_to_cache(node_key: str, numba_py_fn) -> Callable:
def add_to_cache(node_key: str, numba_py_fn: Callable) -> Callable:
"""Add the numba generated function to the cache."""
module_file_base = (
pathlib.PurePath(config.compiledir, node_key).with_suffix(".py").as_posix()
)
cache_module = ModuleType(node_key)

# Create a temporary module for the generated source
cache_module.source = numba_py_fn
cache_module.source = numba_py_fn # type: ignore
dill.dump_module(module_file_base, module=cache_module)

# Load the function from the persisted module
Expand All @@ -403,7 +405,7 @@ def add_to_cache(node_key: str, numba_py_fn) -> Callable:
return numba_py_fn


def persist_py_code(func) -> Callable:
def persist_py_code(func: Callable) -> Callable:
"""Persist a Numba JIT-able Python function.
Parameters
==========
Expand All @@ -413,12 +415,11 @@ def persist_py_code(func) -> Callable:
"""

@wraps(func)
def _func(obj, node, **kwargs):
def _func(obj, node, **kwargs) -> Callable:
node_key = make_node_key(node)
numba_py_fn = None
if node_key:
numba_py_fn = check_cache(node_key)

if node_key is None or numba_py_fn is None:
# We could only ever return the function source in our dispatch
# implementations. That way, we can compile directly to the on-disk
Expand All @@ -434,9 +435,9 @@ def _func(obj, node, **kwargs):

# TODO: Presently numba_py_fn is already jitted.
# numba_fn = numba_njit(numba_py_fn)
return numba_py_fn
return cast(Callable, numba_py_fn)

return _func
return cast(Callable, _func)


@persist_py_code
Expand Down

0 comments on commit ae4f989

Please sign in to comment.