Skip to content

Commit

Permalink
mypy topology and find (#542)
Browse files Browse the repository at this point in the history
* Don't overload typed variable

Signed-off-by: liamhuber <[email protected]>

* Add (and more specific) return hint(s)

To the one function missing one

Signed-off-by: liamhuber <[email protected]>

* Add module docstring

Signed-off-by: liamhuber <[email protected]>

* Catch module spec failures

Signed-off-by: liamhuber <[email protected]>

* Force mypy to accept the design feature

That we _want_ callers to be able to get abstract classes if they request them

Signed-off-by: liamhuber <[email protected]>

* Black

Signed-off-by: liamhuber <[email protected]>

* Ruff import sort

Signed-off-by: liamhuber <[email protected]>

---------

Signed-off-by: liamhuber <[email protected]>
  • Loading branch information
liamhuber authored Jan 11, 2025
1 parent ff6a984 commit 3577158
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
17 changes: 14 additions & 3 deletions pyiron_workflow/find.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,38 @@
"""
A utility for finding public `pyiron_workflow.node.Node` objects.
Supports the idea of node developers writing independent node packages.
"""

from __future__ import annotations

import importlib.util
import inspect
import sys
from pathlib import Path
from types import ModuleType
from typing import TypeVar, cast

from pyiron_workflow.node import Node

NodeType = TypeVar("NodeType", bound=Node)


def _get_subclasses(
source: str | Path | ModuleType,
base_class: type,
base_class: type[NodeType],
get_private: bool = False,
get_abstract: bool = False,
get_imports_too: bool = False,
):
) -> list[type[NodeType]]:
if isinstance(source, str | Path):
source = Path(source)
if source.is_file():
# Load the module from the file
module_name = source.stem
spec = importlib.util.spec_from_file_location(module_name, str(source))
if spec is None or spec.loader is None:
raise ImportError(f"Could not create a ModuleSpec for {source}")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
Expand Down Expand Up @@ -54,4 +65,4 @@ def find_nodes(source: str | Path | ModuleType) -> list[type[Node]]:
"""
Get a list of all public, non-abstract nodes defined in the source.
"""
return _get_subclasses(source, Node)
return cast(list[type[Node]], _get_subclasses(source, Node))
14 changes: 7 additions & 7 deletions pyiron_workflow/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from toposort import CircularDependencyError, toposort, toposort_flatten

if TYPE_CHECKING:
from pyiron_workflow.channels import SignalChannel
from pyiron_workflow.channels import InputSignal, OutputSignal
from pyiron_workflow.node import Node


Expand Down Expand Up @@ -75,8 +75,8 @@ def nodes_to_data_digraph(nodes: dict[str, Node]) -> dict[str, set[str]]:
)
locally_scoped_dependencies.append(upstream.owner.label)
node_dependencies.extend(locally_scoped_dependencies)
node_dependencies = set(node_dependencies)
if node.label in node_dependencies:
node_dependencies_set = set(node_dependencies)
if node.label in node_dependencies_set:
# the toposort library has a
# [known issue](https://gitlab.com/ericvsmith/toposort/-/issues/3)
# That self-dependency isn't caught, so we catch it manually here.
Expand All @@ -85,14 +85,14 @@ def nodes_to_data_digraph(nodes: dict[str, Node]) -> dict[str, set[str]]:
f"the execution of non-DAGs: {node.full_label} appears in its own "
f"input."
)
digraph[node.label] = node_dependencies
digraph[node.label] = node_dependencies_set

return digraph


def _set_new_run_connections_with_fallback_recovery(
connection_creator: Callable[[dict[str, Node]], list[Node]], nodes: dict[str, Node]
):
) -> tuple[list[tuple[InputSignal, OutputSignal]], list[Node]]:
"""
Given a function that takes a dictionary of unconnected nodes, connects their
execution graph, and returns the new starting nodes, this wrapper makes sure that
Expand Down Expand Up @@ -144,7 +144,7 @@ def _set_run_connections_according_to_linear_dag(nodes: dict[str, Node]) -> list

def set_run_connections_according_to_linear_dag(
nodes: dict[str, Node],
) -> tuple[list[tuple[SignalChannel, SignalChannel]], list[Node]]:
) -> tuple[list[tuple[InputSignal, OutputSignal]], list[Node]]:
"""
Given a set of nodes that all have the same parent, have no upstream data
connections outside the nodes provided, and have acyclic data flow, disconnects all
Expand Down Expand Up @@ -196,7 +196,7 @@ def _set_run_connections_according_to_dag(nodes: dict[str, Node]) -> list[Node]:

def set_run_connections_according_to_dag(
nodes: dict[str, Node],
) -> tuple[list[tuple[SignalChannel, SignalChannel]], list[Node]]:
) -> tuple[list[tuple[InputSignal, OutputSignal]], list[Node]]:
"""
Given a set of nodes that all have the same parent, have no upstream data
connections outside the nodes provided, and have acyclic data flow, disconnects all
Expand Down

0 comments on commit 3577158

Please sign in to comment.