From ff3226bfbe25819c9784e4dbf57233649647fccb Mon Sep 17 00:00:00 2001 From: liamhuber Date: Tue, 7 Jan 2025 11:56:27 -0800 Subject: [PATCH 01/19] Use typing.Callable instead of callable Signed-off-by: liamhuber --- pyiron_workflow/executors/cloudpickleprocesspool.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyiron_workflow/executors/cloudpickleprocesspool.py b/pyiron_workflow/executors/cloudpickleprocesspool.py index 038c4c45..cd11b072 100644 --- a/pyiron_workflow/executors/cloudpickleprocesspool.py +++ b/pyiron_workflow/executors/cloudpickleprocesspool.py @@ -1,6 +1,7 @@ from concurrent.futures import Future, ProcessPoolExecutor from concurrent.futures.process import BrokenProcessPool, _global_shutdown, _WorkItem from sys import version_info +from typing import Callable import cloudpickle @@ -14,7 +15,7 @@ def result(self, timeout=None): class _CloudPickledCallable: - def __init__(self, fnc: callable): + def __init__(self, fnc: Callable): self.fnc_serial = cloudpickle.dumps(fnc) def __call__(self, /, dumped_args, dumped_kwargs): From 5b7e9c7f3eeff046dd620ac16b5da448198e4c90 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Tue, 7 Jan 2025 12:07:53 -0800 Subject: [PATCH 02/19] Ignore erroneous error typing._UnionGenericAlias definitively _does_ exist. Signed-off-by: liamhuber --- pyiron_workflow/type_hinting.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyiron_workflow/type_hinting.py b/pyiron_workflow/type_hinting.py index ccf15be2..563bfe2f 100644 --- a/pyiron_workflow/type_hinting.py +++ b/pyiron_workflow/type_hinting.py @@ -28,7 +28,11 @@ def valid_value(value, type_hint) -> bool: def type_hint_to_tuple(type_hint) -> tuple: - if isinstance(type_hint, types.UnionType | typing._UnionGenericAlias): + if isinstance( + type_hint, types.UnionType | typing._UnionGenericAlias # type: ignore + # mypy complains because it thinks typing._UnionGenericAlias doesn't exist + # It definitely does, and we may be able to remove this once mypy catches up + ): return typing.get_args(type_hint) else: return (type_hint,) From aa3c143b1df131bcf085697338e0546f650ab828 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Tue, 7 Jan 2025 14:54:02 -0800 Subject: [PATCH 03/19] Hint a tuple, don't return one Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index 04e9e9f2..7dbe7be3 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -464,7 +464,9 @@ def _valid_connection(self, other: DataChannel) -> bool: def _both_typed(self, other: DataChannel) -> bool: return self._has_hint and other._has_hint - def _figure_out_who_is_who(self, other: DataChannel) -> (OutputData, InputData): + def _figure_out_who_is_who( + self, other: DataChannel + ) -> tuple[OutputData, InputData]: return (self, other) if isinstance(self, OutputData) else (other, self) def __str__(self): From 534a2c68d7c9a58333c127f89eb7b4dafc0614e4 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Tue, 7 Jan 2025 15:03:27 -0800 Subject: [PATCH 04/19] Hint typing.Callable instead of callable Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 6 +++--- pyiron_workflow/mixin/preview.py | 7 ++++--- pyiron_workflow/nodes/composite.py | 8 ++++---- pyiron_workflow/nodes/function.py | 10 +++++----- pyiron_workflow/nodes/macro.py | 10 +++++----- pyiron_workflow/nodes/standard.py | 3 ++- pyiron_workflow/nodes/transform.py | 4 ++-- pyiron_workflow/topology.py | 4 
++-- 8 files changed, 27 insertions(+), 25 deletions(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index 7dbe7be3..7c0d3e90 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -569,7 +569,7 @@ def __init__( self, label: str, owner: HasIO, - callback: callable, + callback: typing.Callable, ): """ Make a new input signal channel. @@ -616,7 +616,7 @@ def _has_required_args(func): ) @property - def callback(self) -> callable: + def callback(self) -> typing.Callable: return getattr(self.owner, self._callback) def __call__(self, other: OutputSignal | None = None) -> None: @@ -639,7 +639,7 @@ def __init__( self, label: str, owner: HasIO, - callback: callable, + callback: typing.Callable, ): super().__init__(label=label, owner=owner, callback=callback) self.received_signals: set[str] = set() diff --git a/pyiron_workflow/mixin/preview.py b/pyiron_workflow/mixin/preview.py index 21463d4c..bec7dbe5 100644 --- a/pyiron_workflow/mixin/preview.py +++ b/pyiron_workflow/mixin/preview.py @@ -18,6 +18,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, ClassVar, get_args, get_type_hints, @@ -81,7 +82,7 @@ def preview_io(cls) -> DotDict[str, dict]: ) -def builds_class_io(subclass_factory: callable[..., type[HasIOPreview]]): +def builds_class_io(subclass_factory: Callable[..., type[HasIOPreview]]): """ A decorator for factories producing subclasses of `HasIOPreview` to invoke :meth:`preview_io` after the class is created, thus ensuring the IO has been @@ -129,7 +130,7 @@ class ScrapesIO(HasIOPreview, ABC): @classmethod @abstractmethod - def _io_defining_function(cls) -> callable: + def _io_defining_function(cls) -> Callable: """Must return a static method.""" _output_labels: ClassVar[tuple[str] | None] = None # None: scrape them @@ -287,7 +288,7 @@ def _validate_return_count(cls): ) from type_error @staticmethod - def _io_defining_documentation(io_defining_function: callable, title: str): + def _io_defining_documentation(io_defining_function: Callable, title: str): """ A helper method for building a docstring for classes that have their IO defined by some function. diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index 7f745e9b..74f4abb8 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -7,7 +7,7 @@ from abc import ABC from time import sleep -from typing import TYPE_CHECKING, Literal +from typing import Callable, Literal, TYPE_CHECKING from pyiron_snippets.colors import SeabornColors from pyiron_snippets.dotdict import DotDict @@ -450,7 +450,7 @@ def graph_as_dict(self) -> dict: return _get_graph_as_dict(self) def _get_connections_as_strings( - self, panel_getter: callable + self, panel_getter: Callable ) -> list[tuple[tuple[str, str], tuple[str, str]]]: """ Connections between children in string representation based on labels. @@ -520,8 +520,8 @@ def __setstate__(self, state): def _restore_connections_from_strings( nodes: dict[str, Node] | DotDict[str, Node], connections: list[tuple[tuple[str, str], tuple[str, str]]], - input_panel_getter: callable, - output_panel_getter: callable, + input_panel_getter: Callable, + output_panel_getter: Callable, ) -> None: """ Set connections among a dictionary of nodes. 
diff --git a/pyiron_workflow/nodes/function.py b/pyiron_workflow/nodes/function.py index cd3d9f31..484509a2 100644 --- a/pyiron_workflow/nodes/function.py +++ b/pyiron_workflow/nodes/function.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from inspect import getsource -from typing import Any +from typing import Any, Callable from pyiron_snippets.colors import SeabornColors from pyiron_snippets.factory import classfactory @@ -300,11 +300,11 @@ class Function(StaticNode, ScrapesIO, ABC): @staticmethod @abstractmethod - def node_function(**kwargs) -> callable: + def node_function(**kwargs) -> Callable: """What the node _does_.""" @classmethod - def _io_defining_function(cls) -> callable: + def _io_defining_function(cls) -> Callable: return cls.node_function @classmethod @@ -351,7 +351,7 @@ def _extra_info(cls) -> str: @classfactory def function_node_factory( - node_function: callable, + node_function: Callable, validate_output_labels: bool, use_cache: bool = True, /, @@ -429,7 +429,7 @@ def decorator(node_function): def function_node( - node_function: callable, + node_function: Callable, *node_args, output_labels: str | tuple[str, ...] | None = None, validate_output_labels: bool = True, diff --git a/pyiron_workflow/nodes/macro.py b/pyiron_workflow/nodes/macro.py index 566605a5..b85d88ba 100644 --- a/pyiron_workflow/nodes/macro.py +++ b/pyiron_workflow/nodes/macro.py @@ -8,7 +8,7 @@ import re from abc import ABC, abstractmethod from inspect import getsource -from typing import TYPE_CHECKING +from typing import Callable, TYPE_CHECKING from pyiron_snippets.factory import classfactory @@ -271,11 +271,11 @@ def _setup_node(self) -> None: @staticmethod @abstractmethod - def graph_creator(self, *args, **kwargs) -> callable: + def graph_creator(self, *args, **kwargs) -> Callable: """Build the graph the node will run.""" @classmethod - def _io_defining_function(cls) -> callable: + def _io_defining_function(cls) -> Callable: return cls.graph_creator _io_defining_function_uses_self = True @@ -466,7 +466,7 @@ def _extra_info(cls) -> str: @classfactory def macro_node_factory( - graph_creator: callable, + graph_creator: Callable, validate_output_labels: bool, use_cache: bool = True, /, @@ -536,7 +536,7 @@ def decorator(graph_creator): def macro_node( - graph_creator: callable, + graph_creator: Callable, *node_args, output_labels: str | tuple[str, ...] 
| None = None, validate_output_labels: bool = True, diff --git a/pyiron_workflow/nodes/standard.py b/pyiron_workflow/nodes/standard.py index 8c119944..f753a402 100644 --- a/pyiron_workflow/nodes/standard.py +++ b/pyiron_workflow/nodes/standard.py @@ -9,6 +9,7 @@ import shutil from pathlib import Path from time import sleep +from typing import Callable from pyiron_workflow.channels import NOT_DATA, OutputSignal from pyiron_workflow.nodes.function import Function, as_function_node @@ -167,7 +168,7 @@ def ChangeDirectory( @as_function_node -def PureCall(fnc: callable): +def PureCall(fnc: Callable): """ Return a call without any arguments diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index 97befbbb..a1710416 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from dataclasses import MISSING from dataclasses import dataclass as as_dataclass -from typing import Any, ClassVar +from typing import Any, Callable, ClassVar from pandas import DataFrame from pyiron_snippets.colors import SeabornColors @@ -65,7 +65,7 @@ class ToManyOutputs(Transformer, ABC): # Must be commensurate with the dictionary returned by transform_to_output @abstractmethod - def _on_run(self, input_object) -> callable[..., Any | tuple]: + def _on_run(self, input_object) -> Callable[..., Any | tuple]: """Must take the single object to be transformed""" @property diff --git a/pyiron_workflow/topology.py b/pyiron_workflow/topology.py index c60c9131..08db9139 100644 --- a/pyiron_workflow/topology.py +++ b/pyiron_workflow/topology.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Callable, TYPE_CHECKING from toposort import CircularDependencyError, toposort, toposort_flatten @@ -90,7 +90,7 @@ def nodes_to_data_digraph(nodes: dict[str, Node]) -> dict[str, set[str]]: def _set_new_run_connections_with_fallback_recovery( - connection_creator: callable[[dict[str, Node]], list[Node]], nodes: dict[str, Node] + connection_creator: Callable[[dict[str, Node]], list[Node]], nodes: dict[str, Node] ): """ Given a function that takes a dictionary of unconnected nodes, connects their From 85f95d28865e7ddddb748819b66aca486bdb6349 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Wed, 8 Jan 2025 10:42:44 -0800 Subject: [PATCH 05/19] Expose the Self typing tool for all versions Signed-off-by: liamhuber --- pyiron_workflow/compatibility.py | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 pyiron_workflow/compatibility.py diff --git a/pyiron_workflow/compatibility.py b/pyiron_workflow/compatibility.py new file mode 100644 index 00000000..28b7f773 --- /dev/null +++ b/pyiron_workflow/compatibility.py @@ -0,0 +1,6 @@ +from sys import version_info + +if version_info.minor < 11: + from typing_extensions import Self as Self +else: + from typing import Self as Self From 9895187e1ce3bd6f59cbd65ed46fbe602cd1ba84 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Wed, 8 Jan 2025 12:24:06 -0800 Subject: [PATCH 06/19] Add a mypy job Based on @jan-janssen's jobs for other pyiron repos Signed-off-by: liamhuber --- .github/workflows/push-pull.yml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/.github/workflows/push-pull.yml b/.github/workflows/push-pull.yml index 9178f1da..b1d47dae 100644 --- a/.github/workflows/push-pull.yml +++ b/.github/workflows/push-pull.yml @@ -19,3 +19,18 @@ jobs: alternate-tests-env-files: .ci_support/lower_bound.yml 
alternate-tests-python-version: '3.10' alternate-tests-dir: tests/unit + + mypy: + runs-on: ubuntu-latest + steps: + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + architecture: x64 + - name: Checkout + uses: actions/checkout@v4 + - name: Install mypy + run: pip install mypy + - name: Test + run: mypy --ignore-missing-imports ${{ github.event.repository.name }} From c5947607d53c006b2c3f18f9d4487e5ac81934e2 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Wed, 8 Jan 2025 12:50:51 -0800 Subject: [PATCH 07/19] `mypy` channels (#534) * Leverage generics for connection partners Signed-off-by: liamhuber * Break apart connection error message So we only reference type hints when they're there Signed-off-by: liamhuber * Hint connections type more specifically Signed-off-by: liamhuber * Hint disconnect more specifically Signed-off-by: liamhuber * Use Self in disconnection hints Signed-off-by: liamhuber * Use Self to hint value_receiver Signed-off-by: liamhuber * Devolve responsibility for connection validity Otherwise mypy has trouble telling that data channels really are operating on a connection partner, since the `super()` call could wind up pointing anywhere. Signed-off-by: liamhuber * Fix typing in channel tests Signed-off-by: liamhuber * :bug: Return the message Signed-off-by: liamhuber * Fix typing in figuring out who is I/O Signed-off-by: liamhuber * Recast connection parters as class method mypy complained about the class-level attribute access I was using to get around circular references. This is a bit more verbose, but otherwise a fine alternative. Signed-off-by: liamhuber * Match Accumulating input signal call to parent Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 167 +++++++++++++++++++++++------------- tests/unit/test_channels.py | 63 +++++++++----- 2 files changed, 148 insertions(+), 82 deletions(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index 7c0d3e90..a95f3806 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -14,6 +14,7 @@ from pyiron_snippets.singleton import Singleton +from pyiron_workflow.compatibility import Self from pyiron_workflow.mixin.display_state import HasStateDisplay from pyiron_workflow.mixin.has_interface_mixins import HasChannel, HasLabel from pyiron_workflow.type_hinting import ( @@ -25,11 +26,24 @@ from pyiron_workflow.io import HasIO -class ChannelConnectionError(Exception): +class ChannelError(Exception): pass -class Channel(HasChannel, HasLabel, HasStateDisplay, ABC): +class ChannelConnectionError(ChannelError): + pass + + +ConnectionPartner = typing.TypeVar("ConnectionPartner", bound="Channel") + + +class Channel( + HasChannel, + HasLabel, + HasStateDisplay, + typing.Generic[ConnectionPartner], + ABC +): """ Channels facilitate the flow of information (data or control signals) into and out of :class:`HasIO` objects (namely nodes). @@ -37,12 +51,9 @@ class Channel(HasChannel, HasLabel, HasStateDisplay, ABC): They must have an identifier (`label: str`) and belong to an `owner: pyiron_workflow.io.HasIO`. - Non-abstract channel classes should come in input/output pairs and specify the - a necessary ancestor for instances they can connect to - (`connection_partner_type: type[Channel]`). Channels may form (:meth:`connect`/:meth:`disconnect`) and store - (:attr:`connections: list[Channel]`) connections with other channels. + (:attr:`connections`) connections with other channels. 
This connection information is reflexive, and is duplicated to be stored on _both_ channels in the form of a reference to their counterpart in the connection. @@ -51,10 +62,10 @@ class Channel(HasChannel, HasLabel, HasStateDisplay, ABC): these (dis)connections is guaranteed to be handled, and new connections are subjected to a validity test. - In this abstract class the only requirement is that the connecting channels form a - "conjugate pair" of classes, i.e. they are children of each other's partner class - (:attr:`connection_partner_type: type[Channel]`) -- input/output connects to - output/input. + In this abstract class the only requirements are that the connecting channels form + a "conjugate pair" of classes, i.e. they are children of each other's partner class + and thus have the same "flavor", but are an input/output pair; and that they define + a string representation. Iterating over channels yields their connections. @@ -80,7 +91,7 @@ def __init__( """ self._label = label self.owner: HasIO = owner - self.connections: list[Channel] = [] + self.connections: list[ConnectionPartner] = [] @property def label(self) -> str: @@ -90,12 +101,12 @@ def label(self) -> str: def __str__(self): pass - @property + @classmethod @abstractmethod - def connection_partner_type(self) -> type[Channel]: + def connection_partner_type(cls) -> type[ConnectionPartner]: """ - Input and output class pairs must specify a parent class for their valid - connection partners. + The class forming a conjugate pair with this channel class -- i.e. the same + "flavor" of channel, but opposite in I/O. """ @property @@ -108,21 +119,18 @@ def full_label(self) -> str: """A label combining the channel's usual label and its owner's semantic path""" return f"{self.owner.full_label}.{self.label}" - def _valid_connection(self, other: Channel) -> bool: + @abstractmethod + def _valid_connection(self, other: object) -> bool: """ Logic for determining if a connection is valid. - - Connections only allowed to instances with the right parent type -- i.e. - connection pairs should be an input/output. """ - return isinstance(other, self.connection_partner_type) - def connect(self, *others: Channel) -> None: + def connect(self, *others: ConnectionPartner) -> None: """ Form a connection between this and one or more other channels. Connections are reflexive, and should only occur between input and output channels, i.e. they are instances of each others - :attr:`connection_partner_type`. + :meth:`connection_partner_type()`. New connections get _prepended_ to the connection lists, so they appear first when searching over connections. @@ -145,24 +153,28 @@ def connect(self, *others: Channel) -> None: self.connections.insert(0, other) other.connections.insert(0, self) else: - if isinstance(other, self.connection_partner_type): + if isinstance(other, self.connection_partner_type()): raise ChannelConnectionError( - f"The channel {other.full_label} ({other.__class__.__name__}" - f") has the correct type " - f"({self.connection_partner_type.__name__}) to connect with " - f"{self.full_label} ({self.__class__.__name__}), but is not " - f"a valid connection. Please check type hints, etc." 
- f"{other.full_label}.type_hint = {other.type_hint}; " - f"{self.full_label}.type_hint = {self.type_hint}" + self._connection_partner_failure_message(other) ) from None else: raise TypeError( - f"Can only connect to {self.connection_partner_type.__name__} " - f"objects, but {self.full_label} ({self.__class__.__name__}) " + f"Can only connect to {self.connection_partner_type()} " + f"objects, but {self.full_label} ({self.__class__}) " f"got {other} ({type(other)})" ) - def disconnect(self, *others: Channel) -> list[tuple[Channel, Channel]]: + def _connection_partner_failure_message(self, other: ConnectionPartner) -> str: + return ( + f"The channel {other.full_label} ({other.__class__}) has the " + f"correct type ({self.connection_partner_type()}) to connect with " + f"{self.full_label} ({self.__class__}), but is not a valid " + f"connection." + ) + + def disconnect( + self, *others: ConnectionPartner + ) -> list[tuple[Self, ConnectionPartner]]: """ If currently connected to any others, removes this and the other from eachothers respective connections lists. @@ -182,7 +194,9 @@ def disconnect(self, *others: Channel) -> list[tuple[Channel, Channel]]: destroyed_connections.append((self, other)) return destroyed_connections - def disconnect_all(self) -> list[tuple[Channel, Channel]]: + def disconnect_all( + self + ) -> list[tuple[Self, ConnectionPartner]]: """ Disconnect from all other channels currently in the connections list. """ @@ -257,8 +271,9 @@ def __bool__(self): NOT_DATA = NotData() +DataConnectionPartner = typing.TypeVar("DataConnectionPartner", bound="DataChannel") -class DataChannel(Channel, ABC): +class DataChannel(Channel[DataConnectionPartner], ABC): """ Data channels control the flow of data on the graph. @@ -331,7 +346,7 @@ class DataChannel(Channel, ABC): when this channel is a value receiver. This can potentially be expensive, so consider deactivating strict hints everywhere for production runs. (Default is True, raise exceptions when type hints get violated.) - value_receiver (pyiron_workflow.channel.DataChannel|None): Another channel of + value_receiver (pyiron_workflow.compatibility.Self|None): Another channel of the same class whose value will always get updated when this channel's value gets updated. """ @@ -343,7 +358,7 @@ def __init__( default: typing.Any | None = NOT_DATA, type_hint: typing.Any | None = None, strict_hints: bool = True, - value_receiver: InputData | None = None, + value_receiver: Self | None = None, ): super().__init__(label=label, owner=owner) self._value = NOT_DATA @@ -352,7 +367,7 @@ def __init__( self.strict_hints = strict_hints self.default = default self.value = default # Implicitly type check your default by assignment - self.value_receiver = value_receiver + self.value_receiver: Self = value_receiver @property def value(self): @@ -379,7 +394,7 @@ def _type_check_new_value(self, new_value): ) @property - def value_receiver(self) -> InputData | OutputData | None: + def value_receiver(self) -> Self | None: """ Another data channel of the same type to whom new values are always pushed (without type checking of any sort, not even when forming the couple!) 
@@ -390,7 +405,7 @@ def value_receiver(self) -> InputData | OutputData | None: return self._value_receiver @value_receiver.setter - def value_receiver(self, new_partner: InputData | OutputData | None): + def value_receiver(self, new_partner: Self | None): if new_partner is not None: if not isinstance(new_partner, self.__class__): raise TypeError( @@ -445,8 +460,8 @@ def _value_is_data(self) -> bool: def _has_hint(self) -> bool: return self.type_hint is not None - def _valid_connection(self, other: DataChannel) -> bool: - if super()._valid_connection(other): + def _valid_connection(self, other: object) -> bool: + if isinstance(other, self.connection_partner_type()): if self._both_typed(other): out, inp = self._figure_out_who_is_who(other) if not inp.strict_hints: @@ -461,13 +476,32 @@ def _valid_connection(self, other: DataChannel) -> bool: else: return False - def _both_typed(self, other: DataChannel) -> bool: + def _connection_partner_failure_message(self, other: DataConnectionPartner) -> str: + msg = super()._connection_partner_failure_message(other) + msg += ( + f"Please check type hints, etc. {other.full_label}.type_hint = " + f"{other.type_hint}; {self.full_label}.type_hint = {self.type_hint}" + ) + return msg + + def _both_typed(self, other: DataConnectionPartner | Self) -> bool: return self._has_hint and other._has_hint def _figure_out_who_is_who( - self, other: DataChannel + self, other: DataConnectionPartner ) -> tuple[OutputData, InputData]: - return (self, other) if isinstance(self, OutputData) else (other, self) + if isinstance(self, InputData) and isinstance(other, OutputData): + return other, self + elif isinstance(self, OutputData) and isinstance(other, InputData): + return self, other + else: + raise ChannelError( + f"This should be unreachable; data channel conjugate pairs should " + f"always be input/output, but got {type(self)} for {self.full_label} " + f"and {type(other)} for {other.full_label}. If you don't believe you " + f"are responsible for this error, please contact the maintainers via " + f"GitHub." + ) def __str__(self): return str(self.value) @@ -491,9 +525,10 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) -class InputData(DataChannel): - @property - def connection_partner_type(self): +class InputData(DataChannel["OutputData"]): + + @classmethod + def connection_partner_type(cls) -> type[OutputData]: return OutputData def fetch(self) -> None: @@ -530,13 +565,17 @@ def value(self, new_value): self._value = new_value -class OutputData(DataChannel): - @property - def connection_partner_type(self): +class OutputData(DataChannel["InputData"]): + @classmethod + def connection_partner_type(cls) -> type[InputData]: return InputData -class SignalChannel(Channel, ABC): +SignalConnectionPartner = typing.TypeVar( + "SignalConnectionPartner", bound="SignalChannel" +) + +class SignalChannel(Channel[SignalConnectionPartner], ABC): """ Signal channels give the option control execution flow by triggering callback functions when the channel is called. 
@@ -555,15 +594,15 @@ class SignalChannel(Channel, ABC): def __call__(self) -> None: pass + def _valid_connection(self, other: object) -> bool: + return isinstance(other, self.connection_partner_type()) + class BadCallbackError(ValueError): pass -class InputSignal(SignalChannel): - @property - def connection_partner_type(self): - return OutputSignal +class InputSignal(SignalChannel["OutputSignal"]): def __init__( self, @@ -591,6 +630,10 @@ def __init__( f"all args are optional: {self._all_args_arg_optional(callback)} " ) + @classmethod + def connection_partner_type(cls) -> type[OutputSignal]: + return OutputSignal + def _is_method_on_owner(self, callback): try: return callback == getattr(self.owner, callback.__name__) @@ -644,14 +687,15 @@ def __init__( super().__init__(label=label, owner=owner, callback=callback) self.received_signals: set[str] = set() - def __call__(self, other: OutputSignal) -> None: + def __call__(self, other: OutputSignal | None = None) -> None: """ Fire callback iff you have received at least one signal from each of your current connections. Resets the collection of received signals when firing. """ - self.received_signals.update([other.scoped_label]) + if isinstance(other, OutputSignal): + self.received_signals.update([other.scoped_label]) if ( len( set(c.scoped_label for c in self.connections).difference( @@ -675,9 +719,10 @@ def __lshift__(self, others): other._connect_accumulating_input_signal(self) -class OutputSignal(SignalChannel): - @property - def connection_partner_type(self): +class OutputSignal(SignalChannel["InputSignal"]): + + @classmethod + def connection_partner_type(cls) -> type[InputSignal]: return InputSignal def __call__(self) -> None: diff --git a/tests/unit/test_channels.py b/tests/unit/test_channels.py index eaeb4a85..151a97fa 100644 --- a/tests/unit/test_channels.py +++ b/tests/unit/test_channels.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from pyiron_workflow.channels import ( @@ -6,6 +8,7 @@ BadCallbackError, Channel, ChannelConnectionError, + ConnectionPartner, InputData, InputSignal, OutputData, @@ -30,25 +33,24 @@ def data_input_locked(self): return self.locked -class InputChannel(Channel): +class DummyChannel(Channel[ConnectionPartner]): """Just to de-abstract the base class""" - def __str__(self): return "non-abstract input" - @property - def connection_partner_type(self) -> type[Channel]: - return OutputChannel + def _valid_connection(self, other: object) -> bool: + return isinstance(other, self.connection_partner_type()) -class OutputChannel(Channel): - """Just to de-abstract the base class""" +class InputChannel(DummyChannel["OutputChannel"]): + @classmethod + def connection_partner_type(cls) -> type[OutputChannel]: + return OutputChannel - def __str__(self): - return "non-abstract output" - @property - def connection_partner_type(self) -> type[Channel]: +class OutputChannel(DummyChannel["InputChannel"]): + @classmethod + def connection_partner_type(cls) -> type[InputChannel]: return InputChannel @@ -389,26 +391,44 @@ def test_aggregating_call(self): owner = DummyOwner() agg = AccumulatingInputSignal(label="agg", owner=owner, callback=owner.update) - with self.assertRaises( - TypeError, - msg="For an aggregating input signal, it _matters_ who called it, so " - "receiving an output signal is not optional", - ): - agg() - out2 = OutputSignal(label="out2", owner=DummyOwner()) agg.connect(self.out, out2) + out_unrelated = OutputSignal(label="out_unrelated", owner=DummyOwner()) + + signals_sent = 0 
self.assertEqual( 2, len(agg.connections), msg="Sanity check on initial conditions" ) self.assertEqual( - 0, len(agg.received_signals), msg="Sanity check on initial conditions" + signals_sent, + len(agg.received_signals), + msg="Sanity check on initial conditions" ) self.assertListEqual([0], owner.foo, msg="Sanity check on initial conditions") + agg() + signals_sent += 0 + self.assertListEqual( + [0], + owner.foo, + msg="Aggregating calls should only matter when they come from a connection" + ) + agg(out_unrelated) + signals_sent += 1 + self.assertListEqual( + [0], + owner.foo, + msg="Aggregating calls should only matter when they come from a connection" + ) + self.out() - self.assertEqual(1, len(agg.received_signals), msg="Signal should be received") + signals_sent += 1 + self.assertEqual( + signals_sent, + len(agg.received_signals), + msg="Signals from other channels should be received" + ) self.assertListEqual( [0], owner.foo, @@ -416,8 +436,9 @@ def test_aggregating_call(self): ) self.out() + signals_sent += 0 self.assertEqual( - 1, + signals_sent, len(agg.received_signals), msg="Repeatedly receiving the same signal should have no effect", ) From 6279797c391b4631f1bfa374008f68c8ee62aad9 Mon Sep 17 00:00:00 2001 From: liamhuber Date: Wed, 8 Jan 2025 12:58:21 -0800 Subject: [PATCH 08/19] Move Ruff jobs into the main push-pull script This is just a little QoL thing; the current script runs the jobs twice every time I push, and it's annoying me. Signed-off-by: liamhuber --- .github/workflows/push-pull.yml | 16 ++++++++++++++++ .github/workflows/ruff.yml | 17 ----------------- 2 files changed, 16 insertions(+), 17 deletions(-) delete mode 100644 .github/workflows/ruff.yml diff --git a/.github/workflows/push-pull.yml b/.github/workflows/push-pull.yml index b1d47dae..6418e835 100644 --- a/.github/workflows/push-pull.yml +++ b/.github/workflows/push-pull.yml @@ -34,3 +34,19 @@ jobs: run: pip install mypy - name: Test run: mypy --ignore-missing-imports ${{ github.event.repository.name }} + + ruff-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v1 + with: + args: check + + ruff-sort-imports: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v1 + with: + args: check --select I --fix --diff \ No newline at end of file diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml deleted file mode 100644 index 68c0ec0d..00000000 --- a/.github/workflows/ruff.yml +++ /dev/null @@ -1,17 +0,0 @@ -name: Ruff -on: [ push, pull_request ] -jobs: - ruff-check: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: astral-sh/ruff-action@v1 - with: - args: check - ruff-sort-imports: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: astral-sh/ruff-action@v1 - with: - args: check --select I --fix --diff From e13caf5e0683fb977a9d58036fa40de3d2e8ca0c Mon Sep 17 00:00:00 2001 From: liamhuber Date: Wed, 8 Jan 2025 13:00:30 -0800 Subject: [PATCH 09/19] Ruff: import Callable from collections.abc Signed-off-by: liamhuber --- pyiron_workflow/executors/cloudpickleprocesspool.py | 2 +- pyiron_workflow/mixin/preview.py | 2 +- pyiron_workflow/nodes/composite.py | 3 ++- pyiron_workflow/nodes/function.py | 3 ++- pyiron_workflow/nodes/macro.py | 3 ++- pyiron_workflow/nodes/standard.py | 2 +- pyiron_workflow/nodes/transform.py | 3 ++- pyiron_workflow/topology.py | 3 ++- 8 files changed, 13 insertions(+), 8 deletions(-) diff --git 
a/pyiron_workflow/executors/cloudpickleprocesspool.py b/pyiron_workflow/executors/cloudpickleprocesspool.py index cd11b072..983dc525 100644 --- a/pyiron_workflow/executors/cloudpickleprocesspool.py +++ b/pyiron_workflow/executors/cloudpickleprocesspool.py @@ -1,7 +1,7 @@ +from collections.abc import Callable from concurrent.futures import Future, ProcessPoolExecutor from concurrent.futures.process import BrokenProcessPool, _global_shutdown, _WorkItem from sys import version_info -from typing import Callable import cloudpickle diff --git a/pyiron_workflow/mixin/preview.py b/pyiron_workflow/mixin/preview.py index bec7dbe5..556af8cd 100644 --- a/pyiron_workflow/mixin/preview.py +++ b/pyiron_workflow/mixin/preview.py @@ -14,11 +14,11 @@ import inspect from abc import ABC, abstractmethod +from collections.abc import Callable from functools import lru_cache, wraps from typing import ( TYPE_CHECKING, Any, - Callable, ClassVar, get_args, get_type_hints, diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index 74f4abb8..11d50583 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -6,8 +6,9 @@ from __future__ import annotations from abc import ABC +from collections.abc import Callable from time import sleep -from typing import Callable, Literal, TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from pyiron_snippets.colors import SeabornColors from pyiron_snippets.dotdict import DotDict diff --git a/pyiron_workflow/nodes/function.py b/pyiron_workflow/nodes/function.py index 484509a2..8000a6d9 100644 --- a/pyiron_workflow/nodes/function.py +++ b/pyiron_workflow/nodes/function.py @@ -1,8 +1,9 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable from inspect import getsource -from typing import Any, Callable +from typing import Any from pyiron_snippets.colors import SeabornColors from pyiron_snippets.factory import classfactory diff --git a/pyiron_workflow/nodes/macro.py b/pyiron_workflow/nodes/macro.py index b85d88ba..527bd5de 100644 --- a/pyiron_workflow/nodes/macro.py +++ b/pyiron_workflow/nodes/macro.py @@ -7,8 +7,9 @@ import re from abc import ABC, abstractmethod +from collections.abc import Callable from inspect import getsource -from typing import Callable, TYPE_CHECKING +from typing import TYPE_CHECKING from pyiron_snippets.factory import classfactory diff --git a/pyiron_workflow/nodes/standard.py b/pyiron_workflow/nodes/standard.py index f753a402..e9b4c683 100644 --- a/pyiron_workflow/nodes/standard.py +++ b/pyiron_workflow/nodes/standard.py @@ -7,9 +7,9 @@ import os import random import shutil +from collections.abc import Callable from pathlib import Path from time import sleep -from typing import Callable from pyiron_workflow.channels import NOT_DATA, OutputSignal from pyiron_workflow.nodes.function import Function, as_function_node diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index a1710416..8852b426 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -6,9 +6,10 @@ import itertools from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import MISSING from dataclasses import dataclass as as_dataclass -from typing import Any, Callable, ClassVar +from typing import Any, ClassVar from pandas import DataFrame from pyiron_snippets.colors import SeabornColors diff --git a/pyiron_workflow/topology.py b/pyiron_workflow/topology.py index 
08db9139..bffd590f 100644 --- a/pyiron_workflow/topology.py +++ b/pyiron_workflow/topology.py @@ -6,7 +6,8 @@ from __future__ import annotations -from typing import Callable, TYPE_CHECKING +from collections.abc import Callable +from typing import TYPE_CHECKING from toposort import CircularDependencyError, toposort, toposort_flatten From 4d242b6fd8554b026917623eb0222a40a4d8fa8d Mon Sep 17 00:00:00 2001 From: liamhuber Date: Wed, 8 Jan 2025 13:03:44 -0800 Subject: [PATCH 10/19] black Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 16 ++++++---------- pyiron_workflow/type_hinting.py | 3 ++- tests/unit/test_channels.py | 9 +++++---- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index a95f3806..a7883326 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -38,11 +38,7 @@ class ChannelConnectionError(ChannelError): class Channel( - HasChannel, - HasLabel, - HasStateDisplay, - typing.Generic[ConnectionPartner], - ABC + HasChannel, HasLabel, HasStateDisplay, typing.Generic[ConnectionPartner], ABC ): """ Channels facilitate the flow of information (data or control signals) into and @@ -173,7 +169,7 @@ def _connection_partner_failure_message(self, other: ConnectionPartner) -> str: ) def disconnect( - self, *others: ConnectionPartner + self, *others: ConnectionPartner ) -> list[tuple[Self, ConnectionPartner]]: """ If currently connected to any others, removes this and the other from eachothers @@ -194,9 +190,7 @@ def disconnect( destroyed_connections.append((self, other)) return destroyed_connections - def disconnect_all( - self - ) -> list[tuple[Self, ConnectionPartner]]: + def disconnect_all(self) -> list[tuple[Self, ConnectionPartner]]: """ Disconnect from all other channels currently in the connections list. """ @@ -273,6 +267,7 @@ def __bool__(self): DataConnectionPartner = typing.TypeVar("DataConnectionPartner", bound="DataChannel") + class DataChannel(Channel[DataConnectionPartner], ABC): """ Data channels control the flow of data on the graph. 
@@ -488,7 +483,7 @@ def _both_typed(self, other: DataConnectionPartner | Self) -> bool: return self._has_hint and other._has_hint def _figure_out_who_is_who( - self, other: DataConnectionPartner + self, other: DataConnectionPartner ) -> tuple[OutputData, InputData]: if isinstance(self, InputData) and isinstance(other, OutputData): return other, self @@ -575,6 +570,7 @@ def connection_partner_type(cls) -> type[InputData]: "SignalConnectionPartner", bound="SignalChannel" ) + class SignalChannel(Channel[SignalConnectionPartner], ABC): """ Signal channels give the option control execution flow by triggering callback diff --git a/pyiron_workflow/type_hinting.py b/pyiron_workflow/type_hinting.py index 563bfe2f..28af408b 100644 --- a/pyiron_workflow/type_hinting.py +++ b/pyiron_workflow/type_hinting.py @@ -29,7 +29,8 @@ def valid_value(value, type_hint) -> bool: def type_hint_to_tuple(type_hint) -> tuple: if isinstance( - type_hint, types.UnionType | typing._UnionGenericAlias # type: ignore + type_hint, + types.UnionType | typing._UnionGenericAlias, # type: ignore # mypy complains because it thinks typing._UnionGenericAlias doesn't exist # It definitely does, and we may be able to remove this once mypy catches up ): diff --git a/tests/unit/test_channels.py b/tests/unit/test_channels.py index 151a97fa..dce71e3c 100644 --- a/tests/unit/test_channels.py +++ b/tests/unit/test_channels.py @@ -35,6 +35,7 @@ def data_input_locked(self): class DummyChannel(Channel[ConnectionPartner]): """Just to de-abstract the base class""" + def __str__(self): return "non-abstract input" @@ -403,7 +404,7 @@ def test_aggregating_call(self): self.assertEqual( signals_sent, len(agg.received_signals), - msg="Sanity check on initial conditions" + msg="Sanity check on initial conditions", ) self.assertListEqual([0], owner.foo, msg="Sanity check on initial conditions") @@ -412,14 +413,14 @@ def test_aggregating_call(self): self.assertListEqual( [0], owner.foo, - msg="Aggregating calls should only matter when they come from a connection" + msg="Aggregating calls should only matter when they come from a connection", ) agg(out_unrelated) signals_sent += 1 self.assertListEqual( [0], owner.foo, - msg="Aggregating calls should only matter when they come from a connection" + msg="Aggregating calls should only matter when they come from a connection", ) self.out() @@ -427,7 +428,7 @@ def test_aggregating_call(self): self.assertEqual( signals_sent, len(agg.received_signals), - msg="Signals from other channels should be received" + msg="Signals from other channels should be received", ) self.assertListEqual( [0], From 71c46da0e0f65f601e847c8a00076c5e284f722d Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Thu, 9 Jan 2025 09:42:14 -0800 Subject: [PATCH 11/19] Drop the private type hint (#535) It was necessary for python<3.10, but we dropped support for that, so we can get rid of the ugly, non-public hint. 
Signed-off-by: liamhuber --- pyiron_workflow/type_hinting.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/pyiron_workflow/type_hinting.py b/pyiron_workflow/type_hinting.py index 28af408b..66ae0ce2 100644 --- a/pyiron_workflow/type_hinting.py +++ b/pyiron_workflow/type_hinting.py @@ -28,15 +28,9 @@ def valid_value(value, type_hint) -> bool: def type_hint_to_tuple(type_hint) -> tuple: - if isinstance( - type_hint, - types.UnionType | typing._UnionGenericAlias, # type: ignore - # mypy complains because it thinks typing._UnionGenericAlias doesn't exist - # It definitely does, and we may be able to remove this once mypy catches up - ): + if isinstance(type_hint, types.UnionType): return typing.get_args(type_hint) - else: - return (type_hint,) + return (type_hint,) def type_hint_is_as_or_more_specific_than(hint, other) -> bool: From 214c6e2bdd43ffaf3bfb3245f0d8806afdfba1f3 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 10 Jan 2025 09:55:49 -0800 Subject: [PATCH 12/19] `mypy` channels redux (#536) * Refactor: rename Move from "partner" language to "conjugate" language Signed-off-by: liamhuber * Explicitly decompose conjugate behaviour Into flavor and IO components Signed-off-by: liamhuber * Tidying Signed-off-by: liamhuber * Narrow hint on connection copying Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 93 ++++++++++++++++++++----------------- tests/unit/test_channels.py | 10 ++-- 2 files changed, 56 insertions(+), 47 deletions(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index a7883326..af602a31 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -34,11 +34,14 @@ class ChannelConnectionError(ChannelError): pass -ConnectionPartner = typing.TypeVar("ConnectionPartner", bound="Channel") +ConjugateType = typing.TypeVar("ConjugateType", bound="Channel") +InputType = typing.TypeVar("InputType", bound="InputChannel") +OutputType = typing.TypeVar("OutputType", bound="OutputChannel") +FlavorType = typing.TypeVar("FlavorType", bound="FlavorChannel") class Channel( - HasChannel, HasLabel, HasStateDisplay, typing.Generic[ConnectionPartner], ABC + HasChannel, HasLabel, HasStateDisplay, typing.Generic[ConjugateType], ABC ): """ Channels facilitate the flow of information (data or control signals) into and @@ -58,10 +61,10 @@ class Channel( these (dis)connections is guaranteed to be handled, and new connections are subjected to a validity test. - In this abstract class the only requirements are that the connecting channels form - a "conjugate pair" of classes, i.e. they are children of each other's partner class - and thus have the same "flavor", but are an input/output pair; and that they define - a string representation. + Child classes must specify a conjugate class in order to enforce connection + conjugate pairs which have the same "flavor" (e.g. "data" or "signal"), and + opposite "direction" ("input" vs "output"). And they must define a string + representation. Iterating over channels yields their connections. 
@@ -87,7 +90,7 @@ def __init__( """ self._label = label self.owner: HasIO = owner - self.connections: list[ConnectionPartner] = [] + self.connections: list[ConjugateType] = [] @property def label(self) -> str: @@ -99,7 +102,7 @@ def __str__(self): @classmethod @abstractmethod - def connection_partner_type(cls) -> type[ConnectionPartner]: + def connection_conjugate(cls) -> type[ConjugateType]: """ The class forming a conjugate pair with this channel class -- i.e. the same "flavor" of channel, but opposite in I/O. @@ -121,12 +124,12 @@ def _valid_connection(self, other: object) -> bool: Logic for determining if a connection is valid. """ - def connect(self, *others: ConnectionPartner) -> None: + def connect(self, *others: ConjugateType) -> None: """ Form a connection between this and one or more other channels. Connections are reflexive, and should only occur between input and output channels, i.e. they are instances of each others - :meth:`connection_partner_type()`. + :meth:`connection_conjugate()`. New connections get _prepended_ to the connection lists, so they appear first when searching over connections. @@ -149,28 +152,26 @@ def connect(self, *others: ConnectionPartner) -> None: self.connections.insert(0, other) other.connections.insert(0, self) else: - if isinstance(other, self.connection_partner_type()): + if isinstance(other, self.connection_conjugate()): raise ChannelConnectionError( - self._connection_partner_failure_message(other) + self._connection_conjugate_failure_message(other) ) from None else: raise TypeError( - f"Can only connect to {self.connection_partner_type()} " + f"Can only connect to {self.connection_conjugate()} " f"objects, but {self.full_label} ({self.__class__}) " f"got {other} ({type(other)})" ) - def _connection_partner_failure_message(self, other: ConnectionPartner) -> str: + def _connection_conjugate_failure_message(self, other: ConjugateType) -> str: return ( f"The channel {other.full_label} ({other.__class__}) has the " - f"correct type ({self.connection_partner_type()}) to connect with " + f"correct type ({self.connection_conjugate()}) to connect with " f"{self.full_label} ({self.__class__}), but is not a valid " f"connection." ) - def disconnect( - self, *others: ConnectionPartner - ) -> list[tuple[Self, ConnectionPartner]]: + def disconnect(self, *others: ConjugateType) -> list[tuple[Self, ConjugateType]]: """ If currently connected to any others, removes this and the other from eachothers respective connections lists. @@ -190,7 +191,7 @@ def disconnect( destroyed_connections.append((self, other)) return destroyed_connections - def disconnect_all(self) -> list[tuple[Self, ConnectionPartner]]: + def disconnect_all(self) -> list[tuple[Self, ConjugateType]]: """ Disconnect from all other channels currently in the connections list. """ @@ -207,10 +208,10 @@ def __iter__(self): return self.connections.__iter__() @property - def channel(self) -> Channel: + def channel(self) -> Self: return self - def copy_connections(self, other: Channel) -> None: + def copy_connections(self, other: Self) -> None: """ Adds all the connections in another channel to this channel's connections. 
@@ -243,6 +244,18 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) +class FlavorChannel(Channel[FlavorType], ABC): + """Abstract base for all flavor-specific channels.""" + + +class InputChannel(Channel[OutputType], ABC): + """Mixin for input channels.""" + + +class OutputChannel(Channel[InputType], ABC): + """Mixin for output channels.""" + + class NotData(metaclass=Singleton): """ This class exists purely to initialize data channel values where no default value @@ -265,10 +278,8 @@ def __bool__(self): NOT_DATA = NotData() -DataConnectionPartner = typing.TypeVar("DataConnectionPartner", bound="DataChannel") - -class DataChannel(Channel[DataConnectionPartner], ABC): +class DataChannel(FlavorChannel["DataChannel"], ABC): """ Data channels control the flow of data on the graph. @@ -456,7 +467,7 @@ def _has_hint(self) -> bool: return self.type_hint is not None def _valid_connection(self, other: object) -> bool: - if isinstance(other, self.connection_partner_type()): + if isinstance(other, self.connection_conjugate()): if self._both_typed(other): out, inp = self._figure_out_who_is_who(other) if not inp.strict_hints: @@ -471,19 +482,19 @@ def _valid_connection(self, other: object) -> bool: else: return False - def _connection_partner_failure_message(self, other: DataConnectionPartner) -> str: - msg = super()._connection_partner_failure_message(other) + def _connection_conjugate_failure_message(self, other: DataChannel) -> str: + msg = super()._connection_conjugate_failure_message(other) msg += ( f"Please check type hints, etc. {other.full_label}.type_hint = " f"{other.type_hint}; {self.full_label}.type_hint = {self.type_hint}" ) return msg - def _both_typed(self, other: DataConnectionPartner | Self) -> bool: + def _both_typed(self, other: DataChannel) -> bool: return self._has_hint and other._has_hint def _figure_out_who_is_who( - self, other: DataConnectionPartner + self, other: DataChannel ) -> tuple[OutputData, InputData]: if isinstance(self, InputData) and isinstance(other, OutputData): return other, self @@ -520,10 +531,10 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) -class InputData(DataChannel["OutputData"]): +class InputData(DataChannel, InputChannel["OutputData"]): @classmethod - def connection_partner_type(cls) -> type[OutputData]: + def connection_conjugate(cls) -> type[OutputData]: return OutputData def fetch(self) -> None: @@ -560,18 +571,16 @@ def value(self, new_value): self._value = new_value -class OutputData(DataChannel["InputData"]): +class OutputData(DataChannel, OutputChannel["InputData"]): @classmethod - def connection_partner_type(cls) -> type[InputData]: + def connection_conjugate(cls) -> type[InputData]: return InputData -SignalConnectionPartner = typing.TypeVar( - "SignalConnectionPartner", bound="SignalChannel" -) +SignalType = typing.TypeVar("SignalType", bound="SignalChannel") -class SignalChannel(Channel[SignalConnectionPartner], ABC): +class SignalChannel(FlavorChannel[SignalType], ABC): """ Signal channels give the option control execution flow by triggering callback functions when the channel is called. 
@@ -591,14 +600,14 @@ def __call__(self) -> None: pass def _valid_connection(self, other: object) -> bool: - return isinstance(other, self.connection_partner_type()) + return isinstance(other, self.connection_conjugate()) class BadCallbackError(ValueError): pass -class InputSignal(SignalChannel["OutputSignal"]): +class InputSignal(SignalChannel["OutputSignal"], InputChannel["OutputSignal"]): def __init__( self, @@ -627,7 +636,7 @@ def __init__( ) @classmethod - def connection_partner_type(cls) -> type[OutputSignal]: + def connection_conjugate(cls) -> type[OutputSignal]: return OutputSignal def _is_method_on_owner(self, callback): @@ -715,10 +724,10 @@ def __lshift__(self, others): other._connect_accumulating_input_signal(self) -class OutputSignal(SignalChannel["InputSignal"]): +class OutputSignal(SignalChannel["InputSignal"], OutputChannel["InputSignal"]): @classmethod - def connection_partner_type(cls) -> type[InputSignal]: + def connection_conjugate(cls) -> type[InputSignal]: return InputSignal def __call__(self) -> None: diff --git a/tests/unit/test_channels.py b/tests/unit/test_channels.py index dce71e3c..bb6c4690 100644 --- a/tests/unit/test_channels.py +++ b/tests/unit/test_channels.py @@ -8,7 +8,7 @@ BadCallbackError, Channel, ChannelConnectionError, - ConnectionPartner, + ConjugateType, InputData, InputSignal, OutputData, @@ -33,25 +33,25 @@ def data_input_locked(self): return self.locked -class DummyChannel(Channel[ConnectionPartner]): +class DummyChannel(Channel[ConjugateType]): """Just to de-abstract the base class""" def __str__(self): return "non-abstract input" def _valid_connection(self, other: object) -> bool: - return isinstance(other, self.connection_partner_type()) + return isinstance(other, self.connection_conjugate()) class InputChannel(DummyChannel["OutputChannel"]): @classmethod - def connection_partner_type(cls) -> type[OutputChannel]: + def connection_conjugate(cls) -> type[OutputChannel]: return OutputChannel class OutputChannel(DummyChannel["InputChannel"]): @classmethod - def connection_partner_type(cls) -> type[InputChannel]: + def connection_conjugate(cls) -> type[InputChannel]: return InputChannel From fc41dfa296f5f0046c302a810be0afd4baf018a0 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 10 Jan 2025 09:56:12 -0800 Subject: [PATCH 13/19] Apply hints to IO panels (#537) * Refactor: rename Move from "partner" language to "conjugate" language Signed-off-by: liamhuber * Explicitly decompose conjugate behaviour Into flavor and IO components Signed-off-by: liamhuber * Tidying Signed-off-by: liamhuber * Narrow hint on connection copying Signed-off-by: liamhuber * Apply hints to IO panels Signed-off-by: liamhuber * Narrow type Signed-off-by: liamhuber * Don't reuse variable Signed-off-by: liamhuber * Ruff: sort imports Signed-off-by: liamhuber * :bug: fix type hint Signed-off-by: liamhuber * Add more hints Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/io.py | 73 ++++++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 29 deletions(-) diff --git a/pyiron_workflow/io.py b/pyiron_workflow/io.py index 0293bb6c..d1cb442d 100644 --- a/pyiron_workflow/io.py +++ b/pyiron_workflow/io.py @@ -9,7 +9,7 @@ import contextlib from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Generic, TypeVar from pyiron_snippets.dotdict import DotDict @@ -20,8 +20,10 @@ DataChannel, InputData, InputSignal, + InputType, OutputData, OutputSignal, + OutputType, SignalChannel, ) from pyiron_workflow.logging 
import logger @@ -32,8 +34,11 @@ HasRun, ) +OwnedType = TypeVar("OwnedType", bound=Channel) +OwnedConjugate = TypeVar("OwnedConjugate", bound=Channel) -class IO(HasStateDisplay, ABC): + +class IO(HasStateDisplay, Generic[OwnedType, OwnedConjugate], ABC): """ IO is a convenience layer for holding and accessing multiple input/output channels. It allows key and dot-based access to the underlying channels. @@ -52,7 +57,9 @@ class IO(HasStateDisplay, ABC): be assigned with a simple `=`. """ - def __init__(self, *channels: Channel): + channel_dict: DotDict[str, OwnedType] + + def __init__(self, *channels: OwnedType): self.__dict__["channel_dict"] = DotDict( { channel.label: channel @@ -63,15 +70,15 @@ def __init__(self, *channels: Channel): @property @abstractmethod - def _channel_class(self) -> type(Channel): + def _channel_class(self) -> type[OwnedType]: pass @abstractmethod - def _assign_a_non_channel_value(self, channel: Channel, value) -> None: + def _assign_a_non_channel_value(self, channel: OwnedType, value) -> None: """What to do when some non-channel value gets assigned to a channel""" pass - def __getattr__(self, item) -> Channel: + def __getattr__(self, item) -> OwnedType: try: return self.channel_dict[item] except KeyError as key_error: @@ -97,20 +104,20 @@ def __setattr__(self, key, value): f"attribute {key} got assigned {value} of type {type(value)}" ) - def _assign_value_to_existing_channel(self, channel: Channel, value) -> None: + def _assign_value_to_existing_channel(self, channel: OwnedType, value) -> None: if isinstance(value, HasChannel): channel.connect(value.channel) else: self._assign_a_non_channel_value(channel, value) - def __getitem__(self, item) -> Channel: + def __getitem__(self, item) -> OwnedType: return self.__getattr__(item) def __setitem__(self, key, value): self.__setattr__(key, value) @property - def connections(self) -> list[Channel]: + def connections(self) -> list[OwnedConjugate]: """All the unique connections across all channels""" return list( set([connection for channel in self for connection in channel.connections]) @@ -124,7 +131,7 @@ def connected(self): def fully_connected(self): return all([c.connected for c in self]) - def disconnect(self) -> list[tuple[Channel, Channel]]: + def disconnect(self) -> list[tuple[OwnedType, OwnedConjugate]]: """ Disconnect all connections that owned channels have. 
@@ -173,7 +180,15 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) -class DataIO(IO, ABC): +class InputsIO(IO[InputType, OutputType], ABC): + pass + + +class OutputsIO(IO[OutputType, InputType], ABC): + pass + + +class DataIO(IO[DataChannel, DataChannel], ABC): def _assign_a_non_channel_value(self, channel: DataChannel, value) -> None: channel.value = value @@ -195,9 +210,9 @@ def deactivate_strict_hints(self): [c.deactivate_strict_hints() for c in self] -class Inputs(DataIO): +class Inputs(InputsIO, DataIO): @property - def _channel_class(self) -> type(InputData): + def _channel_class(self) -> type[InputData]: return InputData def fetch(self): @@ -205,13 +220,13 @@ def fetch(self): c.fetch() -class Outputs(DataIO): +class Outputs(OutputsIO, DataIO): @property - def _channel_class(self) -> type(OutputData): + def _channel_class(self) -> type[OutputData]: return OutputData -class SignalIO(IO, ABC): +class SignalIO(IO[SignalChannel, SignalChannel], ABC): def _assign_a_non_channel_value(self, channel: SignalChannel, value) -> None: raise TypeError( f"Tried to assign {value} ({type(value)} to the {channel.full_label}, " @@ -220,12 +235,12 @@ def _assign_a_non_channel_value(self, channel: SignalChannel, value) -> None: ) -class InputSignals(SignalIO): +class InputSignals(InputsIO, SignalIO): @property - def _channel_class(self) -> type(InputSignal): + def _channel_class(self) -> type[InputSignal]: return InputSignal - def disconnect_run(self) -> list[tuple[Channel, Channel]]: + def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]: """Disconnect all `run` and `accumulate_and_run` signals, if they exist.""" disconnected = [] with contextlib.suppress(AttributeError): @@ -235,9 +250,9 @@ def disconnect_run(self) -> list[tuple[Channel, Channel]]: return disconnected -class OutputSignals(SignalIO): +class OutputSignals(OutputsIO, SignalIO): @property - def _channel_class(self) -> type(OutputSignal): + def _channel_class(self) -> type[OutputSignal]: return OutputSignal @@ -254,7 +269,7 @@ def __init__(self): self.input = InputSignals() self.output = OutputSignals() - def disconnect(self) -> list[tuple[Channel, Channel]]: + def disconnect(self) -> list[tuple[SignalChannel, SignalChannel]]: """ Disconnect all connections in input and output signals. @@ -264,7 +279,7 @@ def disconnect(self) -> list[tuple[Channel, Channel]]: """ return self.input.disconnect() + self.output.disconnect() - def disconnect_run(self) -> list[tuple[Channel, Channel]]: + def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]: return self.input.disconnect_run() @property @@ -326,14 +341,14 @@ def connected(self) -> bool: return self.inputs.connected or self.outputs.connected or self.signals.connected @property - def fully_connected(self): + def fully_connected(self) -> bool: return ( self.inputs.fully_connected and self.outputs.fully_connected and self.signals.fully_connected ) - def disconnect(self): + def disconnect(self) -> list[tuple[Channel, Channel]]: """ Disconnect all connections belonging to inputs, outputs, and signals channels. @@ -360,7 +375,7 @@ def deactivate_strict_hints(self): def _connect_output_signal(self, signal: OutputSignal): self.signals.input.run.connect(signal) - def __rshift__(self, other: InputSignal | HasIO): + def __rshift__(self, other: InputSignal | HasIO) -> InputSignal | HasIO: """ Allows users to connect run and ran signals like: `first >> second`. 
""" @@ -458,8 +473,8 @@ def copy_io( try: self._copy_values(other, fail_hard=values_fail_hard) except Exception as e: - for this, other in new_connections: - this.disconnect(other) + for owned, conjugate in new_connections: + owned.disconnect(conjugate) raise e def _copy_connections( @@ -522,7 +537,7 @@ def _copy_values( self, other: HasIO, fail_hard: bool = False, - ) -> list[tuple[Channel, Any]]: + ) -> list[tuple[DataChannel, Any]]: """ Copies all data from input and output channels in the other object onto this one. From c77bcbd6483504b8278dde035edc8547d014854e Mon Sep 17 00:00:00 2001 From: liamhuber Date: Fri, 10 Jan 2025 10:04:00 -0800 Subject: [PATCH 14/19] Refactor connection validity The instance check to see if a connection candidate has the correct (conjugate) type now occurs only _once_ in the parent `Channel` class. `Channel._valid_connection` is the repurposed to check for validity inside the scope of the classes already lining up, and defaults to simply returning `True` in the base class. `DataChannel` overrides it to do the type hint comparison. Changes inspired by [conversation](https://github.com/pyiron/pyiron_workflow/pull/533#discussion_r1908526844) with @XzzX. Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 66 +++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index af602a31..f96bb48c 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -118,12 +118,6 @@ def full_label(self) -> str: """A label combining the channel's usual label and its owner's semantic path""" return f"{self.owner.full_label}.{self.label}" - @abstractmethod - def _valid_connection(self, other: object) -> bool: - """ - Logic for determining if a connection is valid. - """ - def connect(self, *others: ConjugateType) -> None: """ Form a connection between this and one or more other channels. @@ -146,22 +140,30 @@ def connect(self, *others: ConjugateType) -> None: for other in others: if other in self.connections: continue - elif self._valid_connection(other): - # Prepend new connections - # so that connection searches run newest to oldest - self.connections.insert(0, other) - other.connections.insert(0, self) - else: - if isinstance(other, self.connection_conjugate()): + elif isinstance(other, self.connection_conjugate()): + if self._valid_connection(other): + # Prepend new connections + # so that connection searches run newest to oldest + self.connections.insert(0, other) + other.connections.insert(0, self) + else: raise ChannelConnectionError( self._connection_conjugate_failure_message(other) ) from None - else: - raise TypeError( - f"Can only connect to {self.connection_conjugate()} " - f"objects, but {self.full_label} ({self.__class__}) " - f"got {other} ({type(other)})" - ) + else: + raise TypeError( + f"Can only connect to {self.connection_conjugate()} " + f"objects, but {self.full_label} ({self.__class__}) " + f"got {other} ({type(other)})" + ) + + def _valid_connection(self, other: ConjugateType) -> bool: + """ + Logic for determining if a connection to a conjugate partner is valid. + + Override in child classes as necessary. 
+ """ + return True def _connection_conjugate_failure_message(self, other: ConjugateType) -> str: return ( @@ -466,21 +468,18 @@ def _value_is_data(self) -> bool: def _has_hint(self) -> bool: return self.type_hint is not None - def _valid_connection(self, other: object) -> bool: - if isinstance(other, self.connection_conjugate()): - if self._both_typed(other): - out, inp = self._figure_out_who_is_who(other) - if not inp.strict_hints: - return True - else: - return type_hint_is_as_or_more_specific_than( - out.type_hint, inp.type_hint - ) - else: - # If either is untyped, don't do type checking + def _valid_connection(self, other: DataChannel) -> bool: + if self._both_typed(other): + out, inp = self._figure_out_who_is_who(other) + if not inp.strict_hints: return True + else: + return type_hint_is_as_or_more_specific_than( + out.type_hint, inp.type_hint + ) else: - return False + # If either is untyped, don't do type checking + return True def _connection_conjugate_failure_message(self, other: DataChannel) -> str: msg = super()._connection_conjugate_failure_message(other) @@ -599,9 +598,6 @@ class SignalChannel(FlavorChannel[SignalType], ABC): def __call__(self) -> None: pass - def _valid_connection(self, other: object) -> bool: - return isinstance(other, self.connection_conjugate()) - class BadCallbackError(ValueError): pass From ff6a984cfec3301729c898f85dc6dbd1679335a3 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 10 Jan 2025 16:06:53 -0800 Subject: [PATCH 15/19] `mypy` run (#541) * Hint init properties Signed-off-by: liamhuber * Hint local function Signed-off-by: liamhuber * Add stricter return and hint Signed-off-by: liamhuber * :bug: Hint tuple[] not () Signed-off-by: liamhuber * Black Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/mixin/run.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/pyiron_workflow/mixin/run.py b/pyiron_workflow/mixin/run.py index b704abc7..b7b11c7b 100644 --- a/pyiron_workflow/mixin/run.py +++ b/pyiron_workflow/mixin/run.py @@ -7,6 +7,7 @@ import contextlib from abc import ABC, abstractmethod +from collections.abc import Callable from concurrent.futures import Executor as StdLibExecutor from concurrent.futures import Future, ThreadPoolExecutor from functools import partial @@ -51,14 +52,14 @@ class Runnable(UsesState, HasLabel, HasRun, ABC): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.running = False - self.failed = False - self.executor = None - # We call it an executor, but it's just whether to use one. - # This is a simply stop-gap as we work out more sophisticated ways to reference - # (or create) an executor process without ever trying to pickle a `_thread.lock` + self.running: bool = False + self.failed: bool = False + self.executor: ( + StdLibExecutor | tuple[Callable[..., StdLibExecutor], tuple, dict] | None + ) = None + # We call it an executor, but it can also be instructions on making one self.future: None | Future = None - self._thread_pool_sleep_time = 1e-6 + self._thread_pool_sleep_time: float = 1e-6 @abstractmethod def on_run(self, *args, **kwargs) -> Any: # callable[..., Any | tuple]: @@ -135,7 +136,7 @@ def run( :attr:`running`. (Default is True.) 
""" - def _none_to_dict(inp): + def _none_to_dict(inp: dict | None) -> dict: return {} if inp is None else inp before_run_kwargs = _none_to_dict(before_run_kwargs) @@ -275,7 +276,7 @@ def _finish_run( run_exception_kwargs: dict, run_finally_kwargs: dict, **kwargs, - ) -> Any | tuple: + ) -> Any | tuple | None: """ Switch the status, then process and return the run result. """ @@ -288,6 +289,7 @@ def _finish_run( self._run_exception(**run_exception_kwargs) if raise_run_exceptions: raise e + return None finally: self._run_finally(**run_finally_kwargs) @@ -308,7 +310,7 @@ def _readiness_error_message(self) -> str: @staticmethod def _parse_executor( - executor: StdLibExecutor | (callable[..., StdLibExecutor], tuple, dict), + executor: StdLibExecutor | tuple[Callable[..., StdLibExecutor], tuple, dict], ) -> StdLibExecutor: """ If you've already got an executor, you're done. But if you get callable and From 3577158224af0707f2a6433e780482be027e6a14 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Fri, 10 Jan 2025 16:12:16 -0800 Subject: [PATCH 16/19] `mypy` topology and find (#542) * Don't overload typed variable Signed-off-by: liamhuber * Add (and more specific) return hint(s) To the one function missing one Signed-off-by: liamhuber * Add module docstring Signed-off-by: liamhuber * Catch module spec failures Signed-off-by: liamhuber * Force mypy to accept the design feature That we _want_ callers to be able to get abstract classes if they request them Signed-off-by: liamhuber * Black Signed-off-by: liamhuber * Ruff import sort Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/find.py | 17 ++++++++++++++--- pyiron_workflow/topology.py | 14 +++++++------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/pyiron_workflow/find.py b/pyiron_workflow/find.py index eea058a3..7111ef47 100644 --- a/pyiron_workflow/find.py +++ b/pyiron_workflow/find.py @@ -1,3 +1,9 @@ +""" +A utility for finding public `pyiron_workflow.node.Node` objects. + +Supports the idea of node developers writing independent node packages. +""" + from __future__ import annotations import importlib.util @@ -5,23 +11,28 @@ import sys from pathlib import Path from types import ModuleType +from typing import TypeVar, cast from pyiron_workflow.node import Node +NodeType = TypeVar("NodeType", bound=Node) + def _get_subclasses( source: str | Path | ModuleType, - base_class: type, + base_class: type[NodeType], get_private: bool = False, get_abstract: bool = False, get_imports_too: bool = False, -): +) -> list[type[NodeType]]: if isinstance(source, str | Path): source = Path(source) if source.is_file(): # Load the module from the file module_name = source.stem spec = importlib.util.spec_from_file_location(module_name, str(source)) + if spec is None or spec.loader is None: + raise ImportError(f"Could not create a ModuleSpec for {source}") module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module spec.loader.exec_module(module) @@ -54,4 +65,4 @@ def find_nodes(source: str | Path | ModuleType) -> list[type[Node]]: """ Get a list of all public, non-abstract nodes defined in the source. 
""" - return _get_subclasses(source, Node) + return cast(list[type[Node]], _get_subclasses(source, Node)) diff --git a/pyiron_workflow/topology.py b/pyiron_workflow/topology.py index bffd590f..a621cc20 100644 --- a/pyiron_workflow/topology.py +++ b/pyiron_workflow/topology.py @@ -12,7 +12,7 @@ from toposort import CircularDependencyError, toposort, toposort_flatten if TYPE_CHECKING: - from pyiron_workflow.channels import SignalChannel + from pyiron_workflow.channels import InputSignal, OutputSignal from pyiron_workflow.node import Node @@ -75,8 +75,8 @@ def nodes_to_data_digraph(nodes: dict[str, Node]) -> dict[str, set[str]]: ) locally_scoped_dependencies.append(upstream.owner.label) node_dependencies.extend(locally_scoped_dependencies) - node_dependencies = set(node_dependencies) - if node.label in node_dependencies: + node_dependencies_set = set(node_dependencies) + if node.label in node_dependencies_set: # the toposort library has a # [known issue](https://gitlab.com/ericvsmith/toposort/-/issues/3) # That self-dependency isn't caught, so we catch it manually here. @@ -85,14 +85,14 @@ def nodes_to_data_digraph(nodes: dict[str, Node]) -> dict[str, set[str]]: f"the execution of non-DAGs: {node.full_label} appears in its own " f"input." ) - digraph[node.label] = node_dependencies + digraph[node.label] = node_dependencies_set return digraph def _set_new_run_connections_with_fallback_recovery( connection_creator: Callable[[dict[str, Node]], list[Node]], nodes: dict[str, Node] -): +) -> tuple[list[tuple[InputSignal, OutputSignal]], list[Node]]: """ Given a function that takes a dictionary of unconnected nodes, connects their execution graph, and returns the new starting nodes, this wrapper makes sure that @@ -144,7 +144,7 @@ def _set_run_connections_according_to_linear_dag(nodes: dict[str, Node]) -> list def set_run_connections_according_to_linear_dag( nodes: dict[str, Node], -) -> tuple[list[tuple[SignalChannel, SignalChannel]], list[Node]]: +) -> tuple[list[tuple[InputSignal, OutputSignal]], list[Node]]: """ Given a set of nodes that all have the same parent, have no upstream data connections outside the nodes provided, and have acyclic data flow, disconnects all @@ -196,7 +196,7 @@ def _set_run_connections_according_to_dag(nodes: dict[str, Node]) -> list[Node]: def set_run_connections_according_to_dag( nodes: dict[str, Node], -) -> tuple[list[tuple[SignalChannel, SignalChannel]], list[Node]]: +) -> tuple[list[tuple[InputSignal, OutputSignal]], list[Node]]: """ Given a set of nodes that all have the same parent, have no upstream data connections outside the nodes provided, and have acyclic data flow, disconnects all From 9c260ddd91168ccf137f0e70d35cdf1b4be7bf31 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Thu, 16 Jan 2025 09:40:31 -0800 Subject: [PATCH 17/19] `mypy` semantics (#538) * Initialize _label to a string Signed-off-by: liamhuber * Hint the delimiter Signed-off-by: liamhuber * Make SemanticParent a Generic Signed-off-by: liamhuber * Purge `ParentMost` If subclasses of `Semantic` want to limit their `parent` attribute beyond the standard requirement that it be a `SemanticParent`, they can handle that by overriding the `parent` setter and getter. The only place this was used was in `Workflow`, and so such handling is now exactly the case. 
Signed-off-by: liamhuber * Update comment Signed-off-by: liamhuber * Use generic type Signed-off-by: liamhuber * Don't use generic in static method Signed-off-by: liamhuber * Jump through mypy hoops It doesn't recognize the __set__ for fset methods on the property, so my usual routes for super'ing the setter are failing. This is annoying, but I don't see it being particularly harmful as the method is private. Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Add dev note Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/mixin/semantics.py | 87 ++++++++++++++---------------- pyiron_workflow/nodes/composite.py | 8 +-- pyiron_workflow/workflow.py | 21 +++++++- tests/unit/mixin/test_semantics.py | 25 ++++----- tests/unit/test_workflow.py | 5 +- 5 files changed, 74 insertions(+), 72 deletions(-) diff --git a/pyiron_workflow/mixin/semantics.py b/pyiron_workflow/mixin/semantics.py index de083b87..e207ab92 100644 --- a/pyiron_workflow/mixin/semantics.py +++ b/pyiron_workflow/mixin/semantics.py @@ -13,9 +13,10 @@ from __future__ import annotations -from abc import ABC +from abc import ABC, abstractmethod from difflib import get_close_matches from pathlib import Path +from typing import Generic, TypeVar from bidict import bidict @@ -31,12 +32,12 @@ class Semantic(UsesState, HasLabel, HasParent, ABC): accessible. """ - semantic_delimiter = "/" + semantic_delimiter: str = "/" def __init__( self, label: str, *args, parent: SemanticParent | None = None, **kwargs ): - self._label = None + self._label = "" self._parent = None self._detached_parent_path = None self.label = label @@ -61,6 +62,13 @@ def parent(self) -> SemanticParent | None: @parent.setter def parent(self, new_parent: SemanticParent | None) -> None: + self._set_parent(new_parent) + + def _set_parent(self, new_parent: SemanticParent | None): + """ + mypy is uncooperative with super calls for setters, so we pull the behaviour + out. + """ if new_parent is self._parent: # Exit early if nothing is changing return @@ -157,7 +165,10 @@ class CyclicPathError(ValueError): """ -class SemanticParent(Semantic, ABC): +ChildType = TypeVar("ChildType", bound=Semantic) + + +class SemanticParent(Semantic, Generic[ChildType], ABC): """ A semantic object with a collection of uniquely-named semantic children. @@ -182,19 +193,29 @@ def __init__( strict_naming: bool = True, **kwargs, ): - self._children = bidict() + self._children: bidict[str, ChildType] = bidict() self.strict_naming = strict_naming super().__init__(*args, label=label, parent=parent, **kwargs) + @classmethod + @abstractmethod + def child_type(cls) -> type[ChildType]: + # Dev note: In principle, this could be a regular attribute + # However, in other situations this is precluded (e.g. in channels) + # since it would result in circular references. 
+ # Here we favour consistency over brevity, + # and maintain the X_type() class method pattern + pass + @property - def children(self) -> bidict[str, Semantic]: + def children(self) -> bidict[str, ChildType]: return self._children @property def child_labels(self) -> tuple[str]: return tuple(child.label for child in self) - def __getattr__(self, key): + def __getattr__(self, key) -> ChildType: try: return self._children[key] except KeyError as key_error: @@ -210,7 +231,7 @@ def __getattr__(self, key): def __iter__(self): return self.children.values().__iter__() - def __len__(self): + def __len__(self) -> int: return len(self.children) def __dir__(self): @@ -218,15 +239,15 @@ def __dir__(self): def add_child( self, - child: Semantic, + child: ChildType, label: str | None = None, strict_naming: bool | None = None, - ) -> Semantic: + ) -> ChildType: """ Add a child, optionally assigning it a new label in the process. Args: - child (Semantic): The child to add. + child (ChildType): The child to add. label (str|None): A (potentially) new label to assign the child. (Default is None, leave the child's label alone.) strict_naming (bool|None): Whether to append a suffix to the label if @@ -234,7 +255,7 @@ def add_child( use the class-level flag.) Returns: - (Semantic): The child being added. + (ChildType): The child being added. Raises: TypeError: When the child is not of an allowed class. @@ -244,18 +265,12 @@ def add_child( `strict_naming` is true. """ - if not isinstance(child, Semantic): + if not isinstance(child, self.child_type()): raise TypeError( - f"{self.label} expected a new child of type {Semantic.__name__} " + f"{self.label} expected a new child of type {self.child_type()} " f"but got {child}" ) - if isinstance(child, ParentMost): - raise ParentMostError( - f"{child.label} is {ParentMost.__name__} and may only take None as a " - f"parent but was added as a child to {self.label}" - ) - self._ensure_path_is_not_cyclic(self, child) self._ensure_child_has_no_other_parent(child) @@ -339,15 +354,15 @@ def _add_suffix_to_label(self, label): ) return new_label - def remove_child(self, child: Semantic | str) -> Semantic: + def remove_child(self, child: ChildType | str) -> ChildType: if isinstance(child, str): child = self.children.pop(child) - elif isinstance(child, Semantic): + elif isinstance(child, self.child_type()): self.children.inv.pop(child) else: raise TypeError( f"{self.label} expected to remove a child of type str or " - f"{Semantic.__name__} but got {child}" + f"{self.child_type()} but got {child}" ) child.parent = None @@ -361,7 +376,7 @@ def parent(self) -> SemanticParent | None: @parent.setter def parent(self, new_parent: SemanticParent | None) -> None: self._ensure_path_is_not_cyclic(new_parent, self) - super(SemanticParent, type(self)).parent.__set__(self, new_parent) + self._set_parent(new_parent) def __getstate__(self): state = super().__getstate__() @@ -396,27 +411,3 @@ def __setstate__(self, state): # children). So, now return their parent to them: for child in self: child.parent = self - - -class ParentMostError(TypeError): - """ - To be raised when assigning a parent to a parent-most object - """ - - -class ParentMost(SemanticParent, ABC): - """ - A semantic parent that cannot have any other parent. 
- """ - - @property - def parent(self) -> None: - return None - - @parent.setter - def parent(self, new_parent: None): - if new_parent is not None: - raise ParentMostError( - f"{self.label} is {ParentMost.__name__} and may only take None as a " - f"parent but got {type(new_parent)}" - ) diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index 11d50583..e3a06cab 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -54,7 +54,7 @@ class FailedChildError(RuntimeError): """Raise when one or more child nodes raise exceptions.""" -class Composite(SemanticParent, HasCreator, Node, ABC): +class Composite(SemanticParent[Node], HasCreator, Node, ABC): """ A base class for nodes that have internal graph structure -- i.e. they hold a collection of child nodes and their computation is to execute that graph. @@ -154,6 +154,10 @@ def __init__( **kwargs, ) + @classmethod + def child_type(cls) -> type[Node]: + return Node + def activate_strict_hints(self): super().activate_strict_hints() for node in self: @@ -420,8 +424,6 @@ def executor_shutdown(self, wait=True, *, cancel_futures=False): def __setattr__(self, key: str, node: Node): if isinstance(node, Composite) and key in ["_parent", "parent"]: # This is an edge case for assigning a node to an attribute - # We either defer to the setter with super, or directly assign the private - # variable (as requested in the setter) super().__setattr__(key, node) elif isinstance(node, Node): self.add_child(node, label=key) diff --git a/pyiron_workflow/workflow.py b/pyiron_workflow/workflow.py index 791e17c8..8a5ddb29 100644 --- a/pyiron_workflow/workflow.py +++ b/pyiron_workflow/workflow.py @@ -11,7 +11,6 @@ from bidict import bidict from pyiron_workflow.io import Inputs, Outputs -from pyiron_workflow.mixin.semantics import ParentMost from pyiron_workflow.nodes.composite import Composite if TYPE_CHECKING: @@ -20,7 +19,13 @@ from pyiron_workflow.storage import StorageInterface -class Workflow(ParentMost, Composite): +class ParentMostError(TypeError): + """ + To be raised when assigning a parent to a parent-most object + """ + + +class Workflow(Composite): """ Workflows are a dynamic composite node -- i.e. 
they hold and run a collection of nodes (a subgraph) which can be dynamically modified (adding and removing nodes, @@ -495,3 +500,15 @@ def replace_child( raise e return owned_node + + @property + def parent(self) -> None: + return None + + @parent.setter + def parent(self, new_parent: None): + if new_parent is not None: + raise ParentMostError( + f"{self.label} is a {self.__class__} and may only take None as a " + f"parent but got {type(new_parent)}" + ) diff --git a/tests/unit/mixin/test_semantics.py b/tests/unit/mixin/test_semantics.py index dd40c5f0..0b63b94f 100644 --- a/tests/unit/mixin/test_semantics.py +++ b/tests/unit/mixin/test_semantics.py @@ -3,18 +3,23 @@ from pyiron_workflow.mixin.semantics import ( CyclicPathError, - ParentMost, Semantic, SemanticParent, ) +class ConcreteParent(SemanticParent[Semantic]): + @classmethod + def child_type(cls) -> type[Semantic]: + return Semantic + + class TestSemantics(unittest.TestCase): def setUp(self): - self.root = ParentMost("root") + self.root = ConcreteParent("root") self.child1 = Semantic("child1", parent=self.root) - self.middle1 = SemanticParent("middle", parent=self.root) - self.middle2 = SemanticParent("middle_sub", parent=self.middle1) + self.middle1 = ConcreteParent("middle", parent=self.root) + self.middle2 = ConcreteParent("middle_sub", parent=self.middle1) self.child2 = Semantic("child2", parent=self.middle2) def test_getattr(self): @@ -58,18 +63,6 @@ def test_parent(self): self.assertEqual(self.child1.parent, self.root) self.assertEqual(self.root.parent, None) - with self.subTest(f"{ParentMost.__name__} exceptions"): - with self.assertRaises( - TypeError, msg=f"{ParentMost.__name__} instances can't have parent" - ): - self.root.parent = SemanticParent(label="foo") - - with self.assertRaises( - TypeError, msg=f"{ParentMost.__name__} instances can't be children" - ): - some_parent = SemanticParent(label="bar") - some_parent.add_child(self.root) - with self.subTest("Cyclicity exceptions"): with self.assertRaises(CyclicPathError): self.middle1.parent = self.middle2 diff --git a/tests/unit/test_workflow.py b/tests/unit/test_workflow.py index 174013d3..f19032b7 100644 --- a/tests/unit/test_workflow.py +++ b/tests/unit/test_workflow.py @@ -9,9 +9,8 @@ from pyiron_workflow._tests import ensure_tests_in_python_path from pyiron_workflow.channels import NOT_DATA -from pyiron_workflow.mixin.semantics import ParentMostError from pyiron_workflow.storage import TypeNotFoundError, available_backends -from pyiron_workflow.workflow import Workflow +from pyiron_workflow.workflow import ParentMostError, Workflow ensure_tests_in_python_path() @@ -155,7 +154,7 @@ def test_io_map_bijectivity(self): self.assertEqual(3, len(wf.inputs_map), msg="All entries should be stored") self.assertEqual(0, len(wf.inputs), msg="No IO should be left exposed") - def test_is_parentmost(self): + def test_takes_no_parent(self): wf = Workflow("wf") wf2 = Workflow("wf2") From acc8739047721b0484e3a850fc9c7246b5744876 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Thu, 16 Jan 2025 09:58:59 -0800 Subject: [PATCH 18/19] Semantics generic parent (#544) * Make SemanticParent a Generic Signed-off-by: liamhuber * Don't use generic in static method Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Remove HasParent An interface class guaranteeing the (Any-typed) attribute is too vague to be super useful, and redundant when it's _only_ used in `Semantic`. Having a `parent` will just be a direct feature of being semantic. 
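A compact way to picture "having a `parent` will just be a direct feature of being semantic" is a base class that both declares a generically-typed parent attribute and validates it itself, so no separate Any-typed interface mixin is needed. The sketch below uses invented Item/Container/Box/Widget names and a parent_type() hook purely for illustration; it is not the actual Semantic code.

    from __future__ import annotations

    from abc import ABC, abstractmethod
    from typing import Generic, TypeVar

    ParentT = TypeVar("ParentT", bound="Container")


    class Container:
        """Stand-in for whatever is allowed to own children."""

        def __init__(self, label: str):
            self.label = label


    class Item(Generic[ParentT], ABC):
        """The attribute and its validation both live here -- no extra mixin needed."""

        def __init__(self, label: str, parent: ParentT | None = None):
            self.label = label
            self._parent: ParentT | None = None
            self.parent = parent

        @classmethod
        @abstractmethod
        def parent_type(cls) -> type[ParentT]:
            """Concrete subclasses say which container class they accept."""

        @property
        def parent(self) -> ParentT | None:
            return self._parent

        @parent.setter
        def parent(self, new_parent: ParentT | None) -> None:
            if new_parent is not None and not isinstance(new_parent, self.parent_type()):
                raise TypeError(
                    f"{self.label} expected a {self.parent_type().__name__} parent, "
                    f"got {type(new_parent).__name__}"
                )
            self._parent = new_parent


    class Box(Container):
        pass


    class Widget(Item[Box]):
        @classmethod
        def parent_type(cls) -> type[Box]:
            return Box


    w = Widget("w", parent=Box("b"))  # accepted
    try:
        Widget("bad", parent=Container("c"))  # wrong container type is rejected
    except TypeError as e:
        print(e)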
Signed-off-by: liamhuber * Pull out static method Signed-off-by: liamhuber * Pull cyclicity check up to Semantic Signed-off-by: liamhuber * De-parent SemanticParent from Semantic Because of the label arg vs kwarg problem, there is still a vestigial label arg in the SemanticParent init signature. Signed-off-by: liamhuber * Remove redundant type check This is handled in the super class Signed-off-by: liamhuber * Give Semantic a generic parent type Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Black Signed-off-by: liamhuber * Ruff sort imports Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Update docstrings Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/mixin/has_interface_mixins.py | 13 +-- pyiron_workflow/mixin/semantics.py | 85 ++++++++++--------- pyiron_workflow/node.py | 8 +- pyiron_workflow/nodes/composite.py | 5 -- tests/unit/mixin/test_semantics.py | 29 ++++--- 5 files changed, 70 insertions(+), 70 deletions(-) diff --git a/pyiron_workflow/mixin/has_interface_mixins.py b/pyiron_workflow/mixin/has_interface_mixins.py index 2828ce7e..72943176 100644 --- a/pyiron_workflow/mixin/has_interface_mixins.py +++ b/pyiron_workflow/mixin/has_interface_mixins.py @@ -11,7 +11,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING if TYPE_CHECKING: from pyiron_workflow.channels import Channel @@ -53,17 +53,6 @@ def full_label(self) -> str: return self.label -class HasParent(ABC): - """ - A mixin to guarantee the parent interface exists. - """ - - @property - @abstractmethod - def parent(self) -> Any: - """A parent for the object.""" - - class HasChannel(ABC): """ A mix-in class for use with the :class:`Channel` class. diff --git a/pyiron_workflow/mixin/semantics.py b/pyiron_workflow/mixin/semantics.py index e207ab92..8c1ef6d4 100644 --- a/pyiron_workflow/mixin/semantics.py +++ b/pyiron_workflow/mixin/semantics.py @@ -2,10 +2,11 @@ Classes for "semantic" reasoning. The motivation here is to be able to provide the object with a unique identifier -in the context of other semantic objects. Each object may have exactly one parent -and an arbitrary number of children, and each child's name must be unique in the -scope of that parent. In this way, the path from the parent-most object to any -child is completely unique. The typical filesystem on a computer is an excellent +in the context of other semantic objects. Each object may have at most one parent, +while semantic parents may have an arbitrary number of children, and each child's name +must be unique in the scope of that parent. In this way, when semantic parents are also +themselves semantic, we can build a path from the parent-most object to any child that +is completely unique. The typical filesystem on a computer is an excellent example and fulfills our requirements, the only reason we depart from it is so that we are free to have objects stored in different locations (possibly even on totally different drives or machines) belong to the same semantic group. 
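The filesystem analogy in the docstring above boils down to: join labels along the parent chain with a delimiter, and the result is unique as long as siblings are uniquely named. A toy version, with a hypothetical PathItem class standing in for the semantic objects:

    class PathItem:
        delimiter = "/"

        def __init__(self, label, parent=None):
            self.label = label
            self.parent = parent

        @property
        def path(self):
            # Roots start the path; everything else prefixes its parent's path
            if self.parent is None:
                return self.delimiter + self.label
            return self.parent.path + self.delimiter + self.label


    root = PathItem("root")
    branch = PathItem("branch", parent=root)
    leaf = PathItem("leaf", parent=branch)
    print(leaf.path)  # /root/branch/leaf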
@@ -21,10 +22,12 @@ from bidict import bidict from pyiron_workflow.logging import logger -from pyiron_workflow.mixin.has_interface_mixins import HasLabel, HasParent, UsesState +from pyiron_workflow.mixin.has_interface_mixins import HasLabel, UsesState +ParentType = TypeVar("ParentType", bound="SemanticParent") -class Semantic(UsesState, HasLabel, HasParent, ABC): + +class Semantic(UsesState, HasLabel, Generic[ParentType], ABC): """ An object with a unique semantic path. @@ -34,9 +37,7 @@ class Semantic(UsesState, HasLabel, HasParent, ABC): semantic_delimiter: str = "/" - def __init__( - self, label: str, *args, parent: SemanticParent | None = None, **kwargs - ): + def __init__(self, label: str, *args, parent: ParentType | None = None, **kwargs): self._label = "" self._parent = None self._detached_parent_path = None @@ -44,6 +45,11 @@ def __init__( self.parent = parent super().__init__(*args, **kwargs) + @classmethod + @abstractmethod + def parent_type(cls) -> type[ParentType]: + pass + @property def label(self) -> str: return self._label @@ -57,14 +63,14 @@ def label(self, new_label: str) -> None: self._label = new_label @property - def parent(self) -> SemanticParent | None: + def parent(self) -> ParentType | None: return self._parent @parent.setter - def parent(self, new_parent: SemanticParent | None) -> None: + def parent(self, new_parent: ParentType | None) -> None: self._set_parent(new_parent) - def _set_parent(self, new_parent: SemanticParent | None): + def _set_parent(self, new_parent: ParentType | None): """ mypy is uncooperative with super calls for setters, so we pull the behaviour out. @@ -73,12 +79,14 @@ def _set_parent(self, new_parent: SemanticParent | None): # Exit early if nothing is changing return - if new_parent is not None and not isinstance(new_parent, SemanticParent): + if new_parent is not None and not isinstance(new_parent, self.parent_type()): raise ValueError( - f"Expected None or a {SemanticParent.__name__} for the parent of " + f"Expected None or a {self.parent_type()} for the parent of " f"{self.label}, but got {new_parent}" ) + _ensure_path_is_not_cyclic(new_parent, self) + if ( self._parent is not None and new_parent is not self._parent @@ -134,7 +142,10 @@ def full_label(self) -> str: @property def semantic_root(self) -> Semantic: """The parent-most object in this semantic path; may be self.""" - return self.parent.semantic_root if isinstance(self.parent, Semantic) else self + if isinstance(self.parent, Semantic): + return self.parent.semantic_root + else: + return self def as_path(self, root: Path | str | None = None) -> Path: """ @@ -168,9 +179,9 @@ class CyclicPathError(ValueError): ChildType = TypeVar("ChildType", bound=Semantic) -class SemanticParent(Semantic, Generic[ChildType], ABC): +class SemanticParent(Generic[ChildType], ABC): """ - A semantic object with a collection of uniquely-named semantic children. + An with a collection of uniquely-named semantic children. 
Children should be added or removed via the :meth:`add_child` and :meth:`remove_child` methods and _not_ by direct manipulation of the @@ -187,15 +198,14 @@ class SemanticParent(Semantic, Generic[ChildType], ABC): def __init__( self, - label: str, + label: str | None, # Vestigial while the label order is broken *args, - parent: SemanticParent | None = None, strict_naming: bool = True, **kwargs, ): self._children: bidict[str, ChildType] = bidict() self.strict_naming = strict_naming - super().__init__(*args, label=label, parent=parent, **kwargs) + super().__init__(*args, label=label, **kwargs) @classmethod @abstractmethod @@ -271,7 +281,7 @@ def add_child( f"but got {child}" ) - self._ensure_path_is_not_cyclic(self, child) + _ensure_path_is_not_cyclic(self, child) self._ensure_child_has_no_other_parent(child) @@ -292,18 +302,6 @@ def add_child( child.parent = self return child - @staticmethod - def _ensure_path_is_not_cyclic(parent: SemanticParent | None, child: Semantic): - if parent is not None and parent.semantic_path.startswith( - child.semantic_path + child.semantic_delimiter - ): - raise CyclicPathError( - f"{parent.label} cannot be the parent of {child.label}, because its " - f"semantic path is already in {child.label}'s path and cyclic paths " - f"are not allowed. (i.e. {child.semantic_path} is in " - f"{parent.semantic_path})" - ) - def _ensure_child_has_no_other_parent(self, child: Semantic): if child.parent is not None and child.parent is not self: raise ValueError( @@ -369,15 +367,6 @@ def remove_child(self, child: ChildType | str) -> ChildType: return child - @property - def parent(self) -> SemanticParent | None: - return self._parent - - @parent.setter - def parent(self, new_parent: SemanticParent | None) -> None: - self._ensure_path_is_not_cyclic(new_parent, self) - self._set_parent(new_parent) - def __getstate__(self): state = super().__getstate__() @@ -411,3 +400,15 @@ def __setstate__(self, state): # children). So, now return their parent to them: for child in self: child.parent = self + + +def _ensure_path_is_not_cyclic(parent, child: Semantic): + if isinstance(parent, Semantic) and parent.semantic_path.startswith( + child.semantic_path + child.semantic_delimiter + ): + raise CyclicPathError( + f"{parent.label} cannot be the parent of {child.label}, because its " + f"semantic path is already in {child.label}'s path and cyclic paths " + f"are not allowed. (i.e. {child.semantic_path} is in " + f"{parent.semantic_path})" + ) diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index 3b86a5e4..6e19a704 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -41,7 +41,7 @@ class Node( HasIOWithInjection, - Semantic, + Semantic["Composite"], Runnable, ExploitsSingleOutput, ABC, @@ -319,6 +319,12 @@ def __init__( **kwargs, ) + @classmethod + def parent_type(cls) -> type[Composite]: + from pyiron_workflow.nodes.composite import Composite + + return Composite + def _setup_node(self) -> None: """ Called _before_ :meth:`Node.__init__` finishes. diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index e3a06cab..e5e05ed4 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -304,11 +304,6 @@ def add_child( label: str | None = None, strict_naming: bool | None = None, ) -> Node: - if not isinstance(child, Node): - raise TypeError( - f"Only new {Node.__name__} instances may be added, but got " - f"{type(child)}." 
- ) self._cached_inputs = None # Reset cache after graph change return super().add_child(child, label=label, strict_naming=strict_naming) diff --git a/tests/unit/mixin/test_semantics.py b/tests/unit/mixin/test_semantics.py index 0b63b94f..874928f7 100644 --- a/tests/unit/mixin/test_semantics.py +++ b/tests/unit/mixin/test_semantics.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from pathlib import Path @@ -8,19 +10,25 @@ ) -class ConcreteParent(SemanticParent[Semantic]): +class ConcreteSemantic(Semantic["ConcreteParent"]): + @classmethod + def parent_type(cls) -> type[ConcreteParent]: + return ConcreteParent + + +class ConcreteParent(SemanticParent[ConcreteSemantic], ConcreteSemantic): @classmethod - def child_type(cls) -> type[Semantic]: - return Semantic + def child_type(cls) -> type[ConcreteSemantic]: + return ConcreteSemantic class TestSemantics(unittest.TestCase): def setUp(self): self.root = ConcreteParent("root") - self.child1 = Semantic("child1", parent=self.root) + self.child1 = ConcreteSemantic("child1", parent=self.root) self.middle1 = ConcreteParent("middle", parent=self.root) self.middle2 = ConcreteParent("middle_sub", parent=self.middle1) - self.child2 = Semantic("child2", parent=self.middle2) + self.child2 = ConcreteSemantic("child2", parent=self.middle2) def test_getattr(self): with self.assertRaises(AttributeError) as context: @@ -40,18 +48,19 @@ def test_getattr(self): def test_label_validity(self): with self.assertRaises(TypeError, msg="Label must be a string"): - Semantic(label=123) + ConcreteSemantic(label=123) def test_label_delimiter(self): with self.assertRaises( - ValueError, msg=f"Delimiter '{Semantic.semantic_delimiter}' not allowed" + ValueError, + msg=f"Delimiter '{ConcreteSemantic.semantic_delimiter}' not allowed", ): - Semantic(f"invalid{Semantic.semantic_delimiter}label") + ConcreteSemantic(f"invalid{ConcreteSemantic.semantic_delimiter}label") def test_semantic_delimiter(self): self.assertEqual( "/", - Semantic.semantic_delimiter, + ConcreteSemantic.semantic_delimiter, msg="This is just a hard-code to the current value, update it freely so " "the test passes; if it fails it's just a reminder that your change is " "not backwards compatible, and the next release number should reflect " @@ -105,7 +114,7 @@ def test_as_path(self): ) def test_detached_parent_path(self): - orphan = Semantic("orphan") + orphan = ConcreteSemantic("orphan") orphan.__setstate__(self.child2.__getstate__()) self.assertIsNone( orphan.parent, msg="We still should not explicitly have a parent" From 794291076302c0d04e58f1bfb563791eb57450e5 Mon Sep 17 00:00:00 2001 From: Liam Huber Date: Thu, 16 Jan 2025 10:06:57 -0800 Subject: [PATCH 19/19] Improvements to semantic labeling (#547) * Initialize _label to a string Signed-off-by: liamhuber * Hint the delimiter Signed-off-by: liamhuber * Make SemanticParent a Generic Signed-off-by: liamhuber * Purge `ParentMost` If subclasses of `Semantic` want to limit their `parent` attribute beyond the standard requirement that it be a `SemanticParent`, they can handle that by overriding the `parent` setter and getter. The only place this was used was in `Workflow`, and so such handling is now exactly the case. Signed-off-by: liamhuber * Update comment Signed-off-by: liamhuber * Use generic type Signed-off-by: liamhuber * Don't use generic in static method Signed-off-by: liamhuber * Jump through mypy hoops It doesn't recognize the __set__ for fset methods on the property, so my usual routes for super'ing the setter are failing. 
This is annoying, but I don't see it being particularly harmful as the method is private. Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Add dev note Signed-off-by: liamhuber * Remove HasParent An interface class guaranteeing the (Any-typed) attribute is too vague to be super useful, and redundant when it's _only_ used in `Semantic`. Having a `parent` will just be a direct feature of being semantic. Signed-off-by: liamhuber * Pull out static method Signed-off-by: liamhuber * Pull cyclicity check up to Semantic Signed-off-by: liamhuber * De-parent SemanticParent from Semantic Because of the label arg vs kwarg problem, there is still a vestigial label arg in the SemanticParent init signature. Signed-off-by: liamhuber * Remove redundant type check This is handled in the super class Signed-off-by: liamhuber * Give Semantic a generic parent type Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Black Signed-off-by: liamhuber * Ruff sort imports Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Update docstrings Signed-off-by: liamhuber * Guarantee that semantic parents have a label Signed-off-by: liamhuber * :bug: don't assume parents have semantic_path But we can now safely assume they have a label Signed-off-by: liamhuber * Pull label default up into Semantic This way it is allowed to be a keyword argument everywhere, except for Workflow which makes it positional and adjusts its `super().__init__` call accordingly. Signed-off-by: liamhuber * Refactor: label validity check Pull it up from semantic into an extensible method on the mixin class Signed-off-by: liamhuber * Refactor: rename class Signed-off-by: liamhuber * Add label restrictions To semantic parent based on its child type's semantic delimiter Signed-off-by: liamhuber * Improve error messages Signed-off-by: liamhuber * Make SemanticParent a Generic Signed-off-by: liamhuber * Don't use generic in static method Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Remove HasParent An interface class guaranteeing the (Any-typed) attribute is too vague to be super useful, and redundant when it's _only_ used in `Semantic`. Having a `parent` will just be a direct feature of being semantic. Signed-off-by: liamhuber * Pull out static method Signed-off-by: liamhuber * Pull cyclicity check up to Semantic Signed-off-by: liamhuber * De-parent SemanticParent from Semantic Because of the label arg vs kwarg problem, there is still a vestigial label arg in the SemanticParent init signature. 
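The label-validity refactor described in the bullets above (pulling the check up into an extensible method on the mixin class, then restricting parent labels based on the child type's semantic delimiter) is the standard cooperative-validation idiom: one hook on the base class, subclasses call super() and add their own rule. A minimal sketch with hypothetical HasName/PathSegment classes rather than the real mixins:

    class HasName:
        _name = ""

        @property
        def name(self) -> str:
            return self._name

        @name.setter
        def name(self, new_name: str) -> None:
            self._check_name(new_name)  # single validation hook
            self._name = new_name

        def _check_name(self, new_name: str) -> None:
            if not isinstance(new_name, str):
                raise TypeError(f"Expected a string name, got {new_name!r}")


    class PathSegment(HasName):
        delimiter = "/"

        def _check_name(self, new_name: str) -> None:
            super()._check_name(new_name)  # keep the base check
            if self.delimiter in new_name:
                raise ValueError(f"{self.delimiter!r} is not allowed in {new_name!r}")


    seg = PathSegment()
    seg.name = "ok_label"
    try:
        seg.name = "not/ok"
    except ValueError as e:
        print(e)

Because the setter only ever calls the single hook, every subclass constraint is enforced no matter where the name is assigned.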
Signed-off-by: liamhuber * Remove redundant type check This is handled in the super class Signed-off-by: liamhuber * Give Semantic a generic parent type Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Black Signed-off-by: liamhuber * Ruff sort imports Signed-off-by: liamhuber * Remove unused import Signed-off-by: liamhuber * Update docstrings Signed-off-by: liamhuber * Annotate some extra returns (#548) Signed-off-by: liamhuber --------- Signed-off-by: liamhuber --- pyiron_workflow/channels.py | 4 -- pyiron_workflow/mixin/has_interface_mixins.py | 16 ++++- pyiron_workflow/mixin/semantics.py | 68 +++++++++++-------- pyiron_workflow/node.py | 5 +- pyiron_workflow/nodes/composite.py | 2 +- tests/unit/mixin/test_run.py | 4 +- tests/unit/mixin/test_semantics.py | 33 ++++++--- 7 files changed, 82 insertions(+), 50 deletions(-) diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index 82970878..153932fd 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -92,10 +92,6 @@ def __init__( self.owner: HasIO = owner self.connections: list[ConjugateType] = [] - @property - def label(self) -> str: - return self._label - @abstractmethod def __str__(self): pass diff --git a/pyiron_workflow/mixin/has_interface_mixins.py b/pyiron_workflow/mixin/has_interface_mixins.py index 72943176..5183c83e 100644 --- a/pyiron_workflow/mixin/has_interface_mixins.py +++ b/pyiron_workflow/mixin/has_interface_mixins.py @@ -39,10 +39,24 @@ class HasLabel(ABC): A mixin to guarantee the label interface exists. """ + _label: str + @property - @abstractmethod def label(self) -> str: """A label for the object.""" + return self._label + + @label.setter + def label(self, new_label: str): + self._check_label(new_label) + self._label = new_label + + def _check_label(self, new_label: str) -> None: + """ + Extensible checking routine for label validity. + """ + if not isinstance(new_label, str): + raise TypeError(f"Expected a string label but got {new_label}") @property def full_label(self) -> str: diff --git a/pyiron_workflow/mixin/semantics.py b/pyiron_workflow/mixin/semantics.py index 8c1ef6d4..1ecfd8fb 100644 --- a/pyiron_workflow/mixin/semantics.py +++ b/pyiron_workflow/mixin/semantics.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from difflib import get_close_matches from pathlib import Path -from typing import Generic, TypeVar +from typing import ClassVar, Generic, TypeVar from bidict import bidict @@ -35,13 +35,19 @@ class Semantic(UsesState, HasLabel, Generic[ParentType], ABC): accessible. 
""" - semantic_delimiter: str = "/" + semantic_delimiter: ClassVar[str] = "/" - def __init__(self, label: str, *args, parent: ParentType | None = None, **kwargs): + def __init__( + self, + *args, + label: str | None = None, + parent: ParentType | None = None, + **kwargs, + ): self._label = "" self._parent = None self._detached_parent_path = None - self.label = label + self.label = self.__class__.__name__ if label is None else label self.parent = parent super().__init__(*args, **kwargs) @@ -50,17 +56,13 @@ def __init__(self, label: str, *args, parent: ParentType | None = None, **kwargs def parent_type(cls) -> type[ParentType]: pass - @property - def label(self) -> str: - return self._label - - @label.setter - def label(self, new_label: str) -> None: - if not isinstance(new_label, str): - raise TypeError(f"Expected a string label but got {new_label}") + def _check_label(self, new_label: str) -> None: + super()._check_label(new_label) if self.semantic_delimiter in new_label: - raise ValueError(f"{self.semantic_delimiter} cannot be in the label") - self._label = new_label + raise ValueError( + f"Semantic delimiter {self.semantic_delimiter} cannot be in new label " + f"{new_label}" + ) @property def parent(self) -> ParentType | None: @@ -104,18 +106,22 @@ def semantic_path(self) -> str: The path of node labels from the graph root (parent-most node) down to this node. """ + prefix: str if self.parent is None and self.detached_parent_path is None: prefix = "" elif self.parent is None and self.detached_parent_path is not None: prefix = self.detached_parent_path elif self.parent is not None and self.detached_parent_path is None: - prefix = self.parent.semantic_path + if isinstance(self.parent, Semantic): + prefix = self.parent.semantic_path + else: + prefix = self.semantic_delimiter + self.parent.label else: raise ValueError( f"The parent and detached path should not be able to take non-None " f"values simultaneously, but got {self.parent} and " - f"{self.detached_parent_path}, respectively. Please raise an issue on GitHub " - f"outlining how your reached this state." + f"{self.detached_parent_path}, respectively. Please raise an issue on " + f"GitHub outlining how your reached this state." ) return prefix + self.semantic_delimiter + self.label @@ -179,9 +185,9 @@ class CyclicPathError(ValueError): ChildType = TypeVar("ChildType", bound=Semantic) -class SemanticParent(Generic[ChildType], ABC): +class SemanticParent(HasLabel, Generic[ChildType], ABC): """ - An with a collection of uniquely-named semantic children. + A labeled object with a collection of uniquely-named semantic children. 
Children should be added or removed via the :meth:`add_child` and :meth:`remove_child` methods and _not_ by direct manipulation of the @@ -198,14 +204,13 @@ class SemanticParent(Generic[ChildType], ABC): def __init__( self, - label: str | None, # Vestigial while the label order is broken *args, strict_naming: bool = True, **kwargs, ): self._children: bidict[str, ChildType] = bidict() self.strict_naming = strict_naming - super().__init__(*args, label=label, **kwargs) + super().__init__(*args, **kwargs) @classmethod @abstractmethod @@ -225,6 +230,15 @@ def children(self) -> bidict[str, ChildType]: def child_labels(self) -> tuple[str]: return tuple(child.label for child in self) + def _check_label(self, new_label: str) -> None: + super()._check_label(new_label) + if self.child_type().semantic_delimiter in new_label: + raise ValueError( + f"Child type ({self.child_type()}) semantic delimiter " + f"{self.child_type().semantic_delimiter} cannot be in new label " + f"{new_label}" + ) + def __getattr__(self, key) -> ChildType: try: return self._children[key] @@ -302,7 +316,7 @@ def add_child( child.parent = self return child - def _ensure_child_has_no_other_parent(self, child: Semantic): + def _ensure_child_has_no_other_parent(self, child: Semantic) -> None: if child.parent is not None and child.parent is not self: raise ValueError( f"The child ({child.label}) already belongs to the parent " @@ -310,17 +324,17 @@ def _ensure_child_has_no_other_parent(self, child: Semantic): f"add it to this parent ({self.label})." ) - def _this_child_is_already_at_this_label(self, child: Semantic, label: str): + def _this_child_is_already_at_this_label(self, child: Semantic, label: str) -> bool: return ( label == child.label and label in self.child_labels and self.children[label] is child ) - def _this_child_is_already_at_a_different_label(self, child, label): + def _this_child_is_already_at_a_different_label(self, child, label) -> bool: return child.parent is self and label != child.label - def _get_unique_label(self, label: str, strict_naming: bool): + def _get_unique_label(self, label: str, strict_naming: bool) -> str: if label in self.__dir__(): if label in self.child_labels: if strict_naming: @@ -337,7 +351,7 @@ def _get_unique_label(self, label: str, strict_naming: bool): ) return label - def _add_suffix_to_label(self, label): + def _add_suffix_to_label(self, label: str) -> str: i = 0 new_label = label while new_label in self.__dir__(): @@ -402,7 +416,7 @@ def __setstate__(self, state): child.parent = self -def _ensure_path_is_not_cyclic(parent, child: Semantic): +def _ensure_path_is_not_cyclic(parent, child: Semantic) -> None: if isinstance(parent, Semantic) and parent.semantic_path.startswith( child.semantic_path + child.semantic_delimiter ): diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index 6e19a704..ead473f7 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -297,10 +297,7 @@ def __init__( **kwargs: Interpreted as node input data, with keys corresponding to channel labels. 
""" - super().__init__( - label=self.__class__.__name__ if label is None else label, - parent=parent, - ) + super().__init__(label=label, parent=parent) self.checkpoint = checkpoint self.recovery: Literal["pickle"] | StorageInterface | None = "pickle" self._serialize_result = False # Advertised, but private to indicate diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index e5e05ed4..1689da92 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -143,8 +143,8 @@ def __init__( # empty but the running_children list is not super().__init__( - label, *args, + label=label, parent=parent, delete_existing_savefiles=delete_existing_savefiles, autoload=autoload, diff --git a/tests/unit/mixin/test_run.py b/tests/unit/mixin/test_run.py index 009e03b5..bfa32ece 100644 --- a/tests/unit/mixin/test_run.py +++ b/tests/unit/mixin/test_run.py @@ -8,9 +8,7 @@ class ConcreteRunnable(Runnable): - @property - def label(self) -> str: - return "child_class_with_all_methods_implemented" + _label = "child_class_with_all_methods_implemented" def on_run(self, **kwargs): return kwargs diff --git a/tests/unit/mixin/test_semantics.py b/tests/unit/mixin/test_semantics.py index 874928f7..fbb9bac4 100644 --- a/tests/unit/mixin/test_semantics.py +++ b/tests/unit/mixin/test_semantics.py @@ -12,23 +12,29 @@ class ConcreteSemantic(Semantic["ConcreteParent"]): @classmethod - def parent_type(cls) -> type[ConcreteParent]: - return ConcreteParent + def parent_type(cls) -> type[ConcreteSemanticParent]: + return ConcreteSemanticParent -class ConcreteParent(SemanticParent[ConcreteSemantic], ConcreteSemantic): +class ConcreteParent(SemanticParent[ConcreteSemantic]): + _label = "concrete_parent_default_label" + @classmethod def child_type(cls) -> type[ConcreteSemantic]: return ConcreteSemantic +class ConcreteSemanticParent(ConcreteParent, ConcreteSemantic): + pass + + class TestSemantics(unittest.TestCase): def setUp(self): - self.root = ConcreteParent("root") - self.child1 = ConcreteSemantic("child1", parent=self.root) - self.middle1 = ConcreteParent("middle", parent=self.root) - self.middle2 = ConcreteParent("middle_sub", parent=self.middle1) - self.child2 = ConcreteSemantic("child2", parent=self.middle2) + self.root = ConcreteSemanticParent(label="root") + self.child1 = ConcreteSemantic(label="child1", parent=self.root) + self.middle1 = ConcreteSemanticParent(label="middle", parent=self.root) + self.middle2 = ConcreteSemanticParent(label="middle_sub", parent=self.middle1) + self.child2 = ConcreteSemantic(label="child2", parent=self.middle2) def test_getattr(self): with self.assertRaises(AttributeError) as context: @@ -55,7 +61,14 @@ def test_label_delimiter(self): ValueError, msg=f"Delimiter '{ConcreteSemantic.semantic_delimiter}' not allowed", ): - ConcreteSemantic(f"invalid{ConcreteSemantic.semantic_delimiter}label") + ConcreteSemantic(label=f"invalid{ConcreteSemantic.semantic_delimiter}label") + + non_semantic_parent = ConcreteParent() + with self.assertRaises( + ValueError, + msg=f"Delimiter '{ConcreteSemantic.semantic_delimiter}' not allowed", + ): + non_semantic_parent.label = f"contains_{non_semantic_parent.child_type().semantic_delimiter}_delimiter" def test_semantic_delimiter(self): self.assertEqual( @@ -114,7 +127,7 @@ def test_as_path(self): ) def test_detached_parent_path(self): - orphan = ConcreteSemantic("orphan") + orphan = ConcreteSemantic(label="orphan") orphan.__setstate__(self.child2.__getstate__()) self.assertIsNone( orphan.parent, msg="We 
still should not explicitly have a parent"