diff --git a/.github/workflows/push-pull.yml b/.github/workflows/push-pull.yml index 9178f1da..6418e835 100644 --- a/.github/workflows/push-pull.yml +++ b/.github/workflows/push-pull.yml @@ -19,3 +19,34 @@ jobs: alternate-tests-env-files: .ci_support/lower_bound.yml alternate-tests-python-version: '3.10' alternate-tests-dir: tests/unit + + mypy: + runs-on: ubuntu-latest + steps: + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + architecture: x64 + - name: Checkout + uses: actions/checkout@v4 + - name: Install mypy + run: pip install mypy + - name: Test + run: mypy --ignore-missing-imports ${{ github.event.repository.name }} + + ruff-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v1 + with: + args: check + + ruff-sort-imports: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v1 + with: + args: check --select I --fix --diff \ No newline at end of file diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml deleted file mode 100644 index 68c0ec0d..00000000 --- a/.github/workflows/ruff.yml +++ /dev/null @@ -1,17 +0,0 @@ -name: Ruff -on: [ push, pull_request ] -jobs: - ruff-check: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: astral-sh/ruff-action@v1 - with: - args: check - ruff-sort-imports: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: astral-sh/ruff-action@v1 - with: - args: check --select I --fix --diff diff --git a/pyiron_workflow/channels.py b/pyiron_workflow/channels.py index 64084eb6..153932fd 100644 --- a/pyiron_workflow/channels.py +++ b/pyiron_workflow/channels.py @@ -14,6 +14,7 @@ from pyiron_snippets.singleton import Singleton +from pyiron_workflow.compatibility import Self from pyiron_workflow.mixin.display_state import HasStateDisplay from pyiron_workflow.mixin.has_interface_mixins import HasChannel, HasLabel from pyiron_workflow.type_hinting 
import ( @@ -25,11 +26,23 @@ from pyiron_workflow.io import HasIO -class ChannelConnectionError(Exception): +class ChannelError(Exception): pass -class Channel(HasChannel, HasLabel, HasStateDisplay, ABC): +class ChannelConnectionError(ChannelError): + pass + + +ConjugateType = typing.TypeVar("ConjugateType", bound="Channel") +InputType = typing.TypeVar("InputType", bound="InputChannel") +OutputType = typing.TypeVar("OutputType", bound="OutputChannel") +FlavorType = typing.TypeVar("FlavorType", bound="FlavorChannel") + + +class Channel( + HasChannel, HasLabel, HasStateDisplay, typing.Generic[ConjugateType], ABC +): """ Channels facilitate the flow of information (data or control signals) into and out of :class:`HasIO` objects (namely nodes). @@ -37,12 +50,9 @@ class Channel(HasChannel, HasLabel, HasStateDisplay, ABC): They must have an identifier (`label: str`) and belong to an `owner: pyiron_workflow.io.HasIO`. - Non-abstract channel classes should come in input/output pairs and specify the - a necessary ancestor for instances they can connect to - (`connection_partner_type: type[Channel]`). Channels may form (:meth:`connect`/:meth:`disconnect`) and store - (:attr:`connections: list[Channel]`) connections with other channels. + (:attr:`connections`) connections with other channels. This connection information is reflexive, and is duplicated to be stored on _both_ channels in the form of a reference to their counterpart in the connection. @@ -51,10 +61,10 @@ class Channel(HasChannel, HasLabel, HasStateDisplay, ABC): these (dis)connections is guaranteed to be handled, and new connections are subjected to a validity test. - In this abstract class the only requirement is that the connecting channels form a - "conjugate pair" of classes, i.e. they are children of each other's partner class - (:attr:`connection_partner_type: type[Channel]`) -- input/output connects to - output/input. 
+ Child classes must specify a conjugate class in order to enforce connection + conjugate pairs which have the same "flavor" (e.g. "data" or "signal"), and + opposite "direction" ("input" vs "output"). And they must define a string + representation. Iterating over channels yields their connections. @@ -80,22 +90,18 @@ def __init__( """ self._label = label self.owner: HasIO = owner - self.connections: list[Channel] = [] - - @property - def label(self) -> str: - return self._label + self.connections: list[ConjugateType] = [] @abstractmethod def __str__(self): pass - @property + @classmethod @abstractmethod - def connection_partner_type(self) -> type[Channel]: + def connection_conjugate(cls) -> type[ConjugateType]: """ - Input and output class pairs must specify a parent class for their valid - connection partners. + The class forming a conjugate pair with this channel class -- i.e. the same + "flavor" of channel, but opposite in I/O. """ @property @@ -108,21 +114,12 @@ def full_label(self) -> str: """A label combining the channel's usual label and its owner's semantic path""" return f"{self.owner.full_label}.{self.label}" - def _valid_connection(self, other: Channel) -> bool: - """ - Logic for determining if a connection is valid. - - Connections only allowed to instances with the right parent type -- i.e. - connection pairs should be an input/output. - """ - return isinstance(other, self.connection_partner_type) - - def connect(self, *others: Channel) -> None: + def connect(self, *others: ConjugateType) -> None: """ Form a connection between this and one or more other channels. Connections are reflexive, and should only occur between input and output channels, i.e. they are instances of each others - :attr:`connection_partner_type`. + :meth:`connection_conjugate()`. New connections get _prepended_ to the connection lists, so they appear first when searching over connections. 
@@ -139,30 +136,40 @@ def connect(self, *others: Channel) -> None: for other in others: if other in self.connections: continue - elif self._valid_connection(other): - # Prepend new connections - # so that connection searches run newest to oldest - self.connections.insert(0, other) - other.connections.insert(0, self) - else: - if isinstance(other, self.connection_partner_type): + elif isinstance(other, self.connection_conjugate()): + if self._valid_connection(other): + # Prepend new connections + # so that connection searches run newest to oldest + self.connections.insert(0, other) + other.connections.insert(0, self) + else: raise ChannelConnectionError( - f"The channel {other.full_label} ({other.__class__.__name__}" - f") has the correct type " - f"({self.connection_partner_type.__name__}) to connect with " - f"{self.full_label} ({self.__class__.__name__}), but is not " - f"a valid connection. Please check type hints, etc." - f"{other.full_label}.type_hint = {other.type_hint}; " - f"{self.full_label}.type_hint = {self.type_hint}" + self._connection_conjugate_failure_message(other) ) from None - else: - raise TypeError( - f"Can only connect to {self.connection_partner_type.__name__} " - f"objects, but {self.full_label} ({self.__class__.__name__}) " - f"got {other} ({type(other)})" - ) + else: + raise TypeError( + f"Can only connect to {self.connection_conjugate()} " + f"objects, but {self.full_label} ({self.__class__}) " + f"got {other} ({type(other)})" + ) + + def _valid_connection(self, other: ConjugateType) -> bool: + """ + Logic for determining if a connection to a conjugate partner is valid. - def disconnect(self, *others: Channel) -> list[tuple[Channel, Channel]]: + Override in child classes as necessary. 
+ """ + return True + + def _connection_conjugate_failure_message(self, other: ConjugateType) -> str: + return ( + f"The channel {other.full_label} ({other.__class__}) has the " + f"correct type ({self.connection_conjugate()}) to connect with " + f"{self.full_label} ({self.__class__}), but is not a valid " + f"connection." + ) + + def disconnect(self, *others: ConjugateType) -> list[tuple[Self, ConjugateType]]: """ If currently connected to any others, removes this and the other from eachothers respective connections lists. @@ -182,7 +189,7 @@ def disconnect(self, *others: Channel) -> list[tuple[Channel, Channel]]: destroyed_connections.append((self, other)) return destroyed_connections - def disconnect_all(self) -> list[tuple[Channel, Channel]]: + def disconnect_all(self) -> list[tuple[Self, ConjugateType]]: """ Disconnect from all other channels currently in the connections list. """ @@ -199,10 +206,10 @@ def __iter__(self): return self.connections.__iter__() @property - def channel(self) -> Channel: + def channel(self) -> Self: return self - def copy_connections(self, other: Channel) -> None: + def copy_connections(self, other: Self) -> None: """ Adds all the connections in another channel to this channel's connections. 
@@ -235,6 +242,18 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) +class FlavorChannel(Channel[FlavorType], ABC): + """Abstract base for all flavor-specific channels.""" + + +class InputChannel(Channel[OutputType], ABC): + """Mixin for input channels.""" + + +class OutputChannel(Channel[InputType], ABC): + """Mixin for output channels.""" + + class NotData(metaclass=Singleton): """ This class exists purely to initialize data channel values where no default value @@ -258,7 +277,7 @@ def __bool__(self): NOT_DATA = NotData() -class DataChannel(Channel, ABC): +class DataChannel(FlavorChannel["DataChannel"], ABC): """ Data channels control the flow of data on the graph. @@ -331,7 +350,7 @@ class DataChannel(Channel, ABC): when this channel is a value receiver. This can potentially be expensive, so consider deactivating strict hints everywhere for production runs. (Default is True, raise exceptions when type hints get violated.) - value_receiver (pyiron_workflow.channel.DataChannel|None): Another channel of + value_receiver (pyiron_workflow.compatibility.Self|None): Another channel of the same class whose value will always get updated when this channel's value gets updated. 
""" @@ -343,7 +362,7 @@ def __init__( default: typing.Any | None = NOT_DATA, type_hint: typing.Any | None = None, strict_hints: bool = True, - value_receiver: InputData | None = None, + value_receiver: Self | None = None, ): super().__init__(label=label, owner=owner) self._value = NOT_DATA @@ -352,7 +371,7 @@ def __init__( self.strict_hints = strict_hints self.default = default self.value = default # Implicitly type check your default by assignment - self.value_receiver = value_receiver + self.value_receiver: Self = value_receiver @property def value(self): @@ -379,7 +398,7 @@ def _type_check_new_value(self, new_value): ) @property - def value_receiver(self) -> InputData | OutputData | None: + def value_receiver(self) -> Self | None: """ Another data channel of the same type to whom new values are always pushed (without type checking of any sort, not even when forming the couple!) @@ -390,7 +409,7 @@ def value_receiver(self) -> InputData | OutputData | None: return self._value_receiver @value_receiver.setter - def value_receiver(self, new_partner: InputData | OutputData | None): + def value_receiver(self, new_partner: Self | None): if new_partner is not None: if not isinstance(new_partner, self.__class__): raise TypeError( @@ -446,26 +465,44 @@ def _has_hint(self) -> bool: return self.type_hint is not None def _valid_connection(self, other: DataChannel) -> bool: - if super()._valid_connection(other): - if self._both_typed(other): - out, inp = self._figure_out_who_is_who(other) - if not inp.strict_hints: - return True - else: - return type_hint_is_as_or_more_specific_than( - out.type_hint, inp.type_hint - ) - else: - # If either is untyped, don't do type checking + if self._both_typed(other): + out, inp = self._figure_out_who_is_who(other) + if not inp.strict_hints: return True + else: + return type_hint_is_as_or_more_specific_than( + out.type_hint, inp.type_hint + ) else: - return False + # If either is untyped, don't do type checking + return True + + def 
_connection_conjugate_failure_message(self, other: DataChannel) -> str: + msg = super()._connection_conjugate_failure_message(other) + msg += ( + f" Please check type hints, etc. {other.full_label}.type_hint = " + f"{other.type_hint}; {self.full_label}.type_hint = {self.type_hint}" + ) + return msg def _both_typed(self, other: DataChannel) -> bool: return self._has_hint and other._has_hint - def _figure_out_who_is_who(self, other: DataChannel) -> (OutputData, InputData): - return (self, other) if isinstance(self, OutputData) else (other, self) + def _figure_out_who_is_who( + self, other: DataChannel + ) -> tuple[OutputData, InputData]: + if isinstance(self, InputData) and isinstance(other, OutputData): + return other, self + elif isinstance(self, OutputData) and isinstance(other, InputData): + return self, other + else: + raise ChannelError( + f"This should be unreachable; data channel conjugate pairs should " + f"always be input/output, but got {type(self)} for {self.full_label} " + f"and {type(other)} for {other.full_label}. If you don't believe you " + f"are responsible for this error, please contact the maintainers via " + f"GitHub." 
+ ) def __str__(self): return str(self.value) @@ -489,9 +526,10 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) -class InputData(DataChannel): - @property - def connection_partner_type(self): +class InputData(DataChannel, InputChannel["OutputData"]): + + @classmethod + def connection_conjugate(cls) -> type[OutputData]: return OutputData def fetch(self) -> None: @@ -528,13 +566,16 @@ def value(self, new_value): self._value = new_value -class OutputData(DataChannel): - @property - def connection_partner_type(self): +class OutputData(DataChannel, OutputChannel["InputData"]): + @classmethod + def connection_conjugate(cls) -> type[InputData]: return InputData -class SignalChannel(Channel, ABC): +SignalType = typing.TypeVar("SignalType", bound="SignalChannel") + + +class SignalChannel(FlavorChannel[SignalType], ABC): """ Signal channels give the option control execution flow by triggering callback functions when the channel is called. @@ -558,16 +599,13 @@ class BadCallbackError(ValueError): pass -class InputSignal(SignalChannel): - @property - def connection_partner_type(self): - return OutputSignal +class InputSignal(SignalChannel["OutputSignal"], InputChannel["OutputSignal"]): def __init__( self, label: str, owner: HasIO, - callback: callable, + callback: typing.Callable, ): """ Make a new input signal channel. 
@@ -589,6 +627,10 @@ def __init__( f"all args are optional: {self._all_args_arg_optional(callback)} " ) + @classmethod + def connection_conjugate(cls) -> type[OutputSignal]: + return OutputSignal + def _is_method_on_owner(self, callback): try: return callback == getattr(self.owner, callback.__name__) @@ -614,7 +656,7 @@ def _has_required_args(func): ) @property - def callback(self) -> callable: + def callback(self) -> typing.Callable: return getattr(self.owner, self._callback) def __call__(self, other: OutputSignal | None = None) -> None: @@ -637,19 +679,20 @@ def __init__( self, label: str, owner: HasIO, - callback: callable, + callback: typing.Callable, ): super().__init__(label=label, owner=owner, callback=callback) self.received_signals: set[str] = set() - def __call__(self, other: OutputSignal) -> None: + def __call__(self, other: OutputSignal | None = None) -> None: """ Fire callback iff you have received at least one signal from each of your current connections. Resets the collection of received signals when firing. 
""" - self.received_signals.update([other.scoped_label]) + if isinstance(other, OutputSignal): + self.received_signals.update([other.scoped_label]) if ( len( {c.scoped_label for c in self.connections}.difference( @@ -673,9 +716,10 @@ def __lshift__(self, others): other._connect_accumulating_input_signal(self) -class OutputSignal(SignalChannel): - @property - def connection_partner_type(self): +class OutputSignal(SignalChannel["InputSignal"], OutputChannel["InputSignal"]): + + @classmethod + def connection_conjugate(cls) -> type[InputSignal]: return InputSignal def __call__(self) -> None: diff --git a/pyiron_workflow/compatibility.py b/pyiron_workflow/compatibility.py new file mode 100644 index 00000000..28b7f773 --- /dev/null +++ b/pyiron_workflow/compatibility.py @@ -0,0 +1,6 @@ +from sys import version_info + +if version_info.minor < 11: + from typing_extensions import Self as Self +else: + from typing import Self as Self diff --git a/pyiron_workflow/executors/cloudpickleprocesspool.py b/pyiron_workflow/executors/cloudpickleprocesspool.py index 038c4c45..983dc525 100644 --- a/pyiron_workflow/executors/cloudpickleprocesspool.py +++ b/pyiron_workflow/executors/cloudpickleprocesspool.py @@ -1,3 +1,4 @@ +from collections.abc import Callable from concurrent.futures import Future, ProcessPoolExecutor from concurrent.futures.process import BrokenProcessPool, _global_shutdown, _WorkItem from sys import version_info @@ -14,7 +15,7 @@ def result(self, timeout=None): class _CloudPickledCallable: - def __init__(self, fnc: callable): + def __init__(self, fnc: Callable): self.fnc_serial = cloudpickle.dumps(fnc) def __call__(self, /, dumped_args, dumped_kwargs): diff --git a/pyiron_workflow/find.py b/pyiron_workflow/find.py index eea058a3..7111ef47 100644 --- a/pyiron_workflow/find.py +++ b/pyiron_workflow/find.py @@ -1,3 +1,9 @@ +""" +A utility for finding public `pyiron_workflow.node.Node` objects. + +Supports the idea of node developers writing independent node packages. 
+""" + from __future__ import annotations import importlib.util @@ -5,23 +11,28 @@ import sys from pathlib import Path from types import ModuleType +from typing import TypeVar, cast from pyiron_workflow.node import Node +NodeType = TypeVar("NodeType", bound=Node) + def _get_subclasses( source: str | Path | ModuleType, - base_class: type, + base_class: type[NodeType], get_private: bool = False, get_abstract: bool = False, get_imports_too: bool = False, -): +) -> list[type[NodeType]]: if isinstance(source, str | Path): source = Path(source) if source.is_file(): # Load the module from the file module_name = source.stem spec = importlib.util.spec_from_file_location(module_name, str(source)) + if spec is None or spec.loader is None: + raise ImportError(f"Could not create a ModuleSpec for {source}") module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module spec.loader.exec_module(module) @@ -54,4 +65,4 @@ def find_nodes(source: str | Path | ModuleType) -> list[type[Node]]: """ Get a list of all public, non-abstract nodes defined in the source. 
""" - return _get_subclasses(source, Node) + return cast(list[type[Node]], _get_subclasses(source, Node)) diff --git a/pyiron_workflow/io.py b/pyiron_workflow/io.py index 5bb6f170..c3c8dada 100644 --- a/pyiron_workflow/io.py +++ b/pyiron_workflow/io.py @@ -9,7 +9,7 @@ import contextlib from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Generic, TypeVar from pyiron_snippets.dotdict import DotDict @@ -20,8 +20,10 @@ DataChannel, InputData, InputSignal, + InputType, OutputData, OutputSignal, + OutputType, SignalChannel, ) from pyiron_workflow.logging import logger @@ -32,8 +34,11 @@ HasRun, ) +OwnedType = TypeVar("OwnedType", bound=Channel) +OwnedConjugate = TypeVar("OwnedConjugate", bound=Channel) -class IO(HasStateDisplay, ABC): + +class IO(HasStateDisplay, Generic[OwnedType, OwnedConjugate], ABC): """ IO is a convenience layer for holding and accessing multiple input/output channels. It allows key and dot-based access to the underlying channels. @@ -52,7 +57,9 @@ class IO(HasStateDisplay, ABC): be assigned with a simple `=`. 
""" - def __init__(self, *channels: Channel): + channel_dict: DotDict[str, OwnedType] + + def __init__(self, *channels: OwnedType): self.__dict__["channel_dict"] = DotDict( { channel.label: channel @@ -63,15 +70,15 @@ def __init__(self, *channels: Channel): @property @abstractmethod - def _channel_class(self) -> type(Channel): + def _channel_class(self) -> type[OwnedType]: pass @abstractmethod - def _assign_a_non_channel_value(self, channel: Channel, value) -> None: + def _assign_a_non_channel_value(self, channel: OwnedType, value) -> None: """What to do when some non-channel value gets assigned to a channel""" pass - def __getattr__(self, item) -> Channel: + def __getattr__(self, item) -> OwnedType: try: return self.channel_dict[item] except KeyError as key_error: @@ -97,20 +104,20 @@ def __setattr__(self, key, value): f"attribute {key} got assigned {value} of type {type(value)}" ) - def _assign_value_to_existing_channel(self, channel: Channel, value) -> None: + def _assign_value_to_existing_channel(self, channel: OwnedType, value) -> None: if isinstance(value, HasChannel): channel.connect(value.channel) else: self._assign_a_non_channel_value(channel, value) - def __getitem__(self, item) -> Channel: + def __getitem__(self, item) -> OwnedType: return self.__getattr__(item) def __setitem__(self, key, value): self.__setattr__(key, value) @property - def connections(self) -> list[Channel]: + def connections(self) -> list[OwnedConjugate]: """All the unique connections across all channels""" return list( {connection for channel in self for connection in channel.connections} @@ -124,7 +131,7 @@ def connected(self): def fully_connected(self): return all(c.connected for c in self) - def disconnect(self) -> list[tuple[Channel, Channel]]: + def disconnect(self) -> list[tuple[OwnedType, OwnedConjugate]]: """ Disconnect all connections that owned channels have. 
@@ -173,7 +180,15 @@ def display_state(self, state=None, ignore_private=True): return super().display_state(state=state, ignore_private=ignore_private) -class DataIO(IO, ABC): +class InputsIO(IO[InputType, OutputType], ABC): + pass + + +class OutputsIO(IO[OutputType, InputType], ABC): + pass + + +class DataIO(IO[DataChannel, DataChannel], ABC): def _assign_a_non_channel_value(self, channel: DataChannel, value) -> None: channel.value = value @@ -195,9 +210,9 @@ def deactivate_strict_hints(self): [c.deactivate_strict_hints() for c in self] -class Inputs(DataIO): +class Inputs(InputsIO, DataIO): @property - def _channel_class(self) -> type(InputData): + def _channel_class(self) -> type[InputData]: return InputData def fetch(self): @@ -205,13 +220,13 @@ def fetch(self): c.fetch() -class Outputs(DataIO): +class Outputs(OutputsIO, DataIO): @property - def _channel_class(self) -> type(OutputData): + def _channel_class(self) -> type[OutputData]: return OutputData -class SignalIO(IO, ABC): +class SignalIO(IO[SignalChannel, SignalChannel], ABC): def _assign_a_non_channel_value(self, channel: SignalChannel, value) -> None: raise TypeError( f"Tried to assign {value} ({type(value)} to the {channel.full_label}, " @@ -220,12 +235,12 @@ def _assign_a_non_channel_value(self, channel: SignalChannel, value) -> None: ) -class InputSignals(SignalIO): +class InputSignals(InputsIO, SignalIO): @property - def _channel_class(self) -> type(InputSignal): + def _channel_class(self) -> type[InputSignal]: return InputSignal - def disconnect_run(self) -> list[tuple[Channel, Channel]]: + def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]: """Disconnect all `run` and `accumulate_and_run` signals, if they exist.""" disconnected = [] with contextlib.suppress(AttributeError): @@ -235,9 +250,9 @@ def disconnect_run(self) -> list[tuple[Channel, Channel]]: return disconnected -class OutputSignals(SignalIO): +class OutputSignals(OutputsIO, SignalIO): @property - def _channel_class(self) 
-> type(OutputSignal): + def _channel_class(self) -> type[OutputSignal]: return OutputSignal @@ -254,7 +269,7 @@ def __init__(self): self.input = InputSignals() self.output = OutputSignals() - def disconnect(self) -> list[tuple[Channel, Channel]]: + def disconnect(self) -> list[tuple[SignalChannel, SignalChannel]]: """ Disconnect all connections in input and output signals. @@ -264,7 +279,7 @@ def disconnect(self) -> list[tuple[Channel, Channel]]: """ return self.input.disconnect() + self.output.disconnect() - def disconnect_run(self) -> list[tuple[Channel, Channel]]: + def disconnect_run(self) -> list[tuple[InputSignal, OutputSignal]]: return self.input.disconnect_run() @property @@ -326,14 +341,14 @@ def connected(self) -> bool: return self.inputs.connected or self.outputs.connected or self.signals.connected @property - def fully_connected(self): + def fully_connected(self) -> bool: return ( self.inputs.fully_connected and self.outputs.fully_connected and self.signals.fully_connected ) - def disconnect(self): + def disconnect(self) -> list[tuple[Channel, Channel]]: """ Disconnect all connections belonging to inputs, outputs, and signals channels. @@ -360,7 +375,7 @@ def deactivate_strict_hints(self): def _connect_output_signal(self, signal: OutputSignal): self.signals.input.run.connect(signal) - def __rshift__(self, other: InputSignal | HasIO): + def __rshift__(self, other: InputSignal | HasIO) -> InputSignal | HasIO: """ Allows users to connect run and ran signals like: `first >> second`. 
""" @@ -456,8 +471,8 @@ def copy_io( try: self._copy_values(other, fail_hard=values_fail_hard) except Exception as e: - for this, other in new_connections: - this.disconnect(other) + for owned, conjugate in new_connections: + owned.disconnect(conjugate) raise e def _copy_connections( @@ -520,7 +535,7 @@ def _copy_values( self, other: HasIO, fail_hard: bool = False, - ) -> list[tuple[Channel, Any]]: + ) -> list[tuple[DataChannel, Any]]: """ Copies all data from input and output channels in the other object onto this one. diff --git a/pyiron_workflow/mixin/has_interface_mixins.py b/pyiron_workflow/mixin/has_interface_mixins.py index 2828ce7e..5183c83e 100644 --- a/pyiron_workflow/mixin/has_interface_mixins.py +++ b/pyiron_workflow/mixin/has_interface_mixins.py @@ -11,7 +11,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING if TYPE_CHECKING: from pyiron_workflow.channels import Channel @@ -39,10 +39,24 @@ class HasLabel(ABC): A mixin to guarantee the label interface exists. """ + _label: str + @property - @abstractmethod def label(self) -> str: """A label for the object.""" + return self._label + + @label.setter + def label(self, new_label: str): + self._check_label(new_label) + self._label = new_label + + def _check_label(self, new_label: str) -> None: + """ + Extensible checking routine for label validity. + """ + if not isinstance(new_label, str): + raise TypeError(f"Expected a string label but got {new_label}") @property def full_label(self) -> str: @@ -53,17 +67,6 @@ def full_label(self) -> str: return self.label -class HasParent(ABC): - """ - A mixin to guarantee the parent interface exists. - """ - - @property - @abstractmethod - def parent(self) -> Any: - """A parent for the object.""" - - class HasChannel(ABC): """ A mix-in class for use with the :class:`Channel` class. 
diff --git a/pyiron_workflow/mixin/preview.py b/pyiron_workflow/mixin/preview.py index 08c06c47..ec74f3e7 100644 --- a/pyiron_workflow/mixin/preview.py +++ b/pyiron_workflow/mixin/preview.py @@ -14,6 +14,7 @@ import inspect from abc import ABC, abstractmethod +from collections.abc import Callable from functools import lru_cache, wraps from typing import ( TYPE_CHECKING, @@ -81,7 +82,7 @@ def preview_io(cls) -> DotDict[str, dict]: ) -def builds_class_io(subclass_factory: callable[..., type[HasIOPreview]]): +def builds_class_io(subclass_factory: Callable[..., type[HasIOPreview]]): """ A decorator for factories producing subclasses of `HasIOPreview` to invoke :meth:`preview_io` after the class is created, thus ensuring the IO has been @@ -129,7 +130,7 @@ class ScrapesIO(HasIOPreview, ABC): @classmethod @abstractmethod - def _io_defining_function(cls) -> callable: + def _io_defining_function(cls) -> Callable: """Must return a static method.""" _output_labels: ClassVar[tuple[str] | None] = None # None: scrape them @@ -287,7 +288,7 @@ def _validate_return_count(cls): ) from type_error @staticmethod - def _io_defining_documentation(io_defining_function: callable, title: str): + def _io_defining_documentation(io_defining_function: Callable, title: str): """ A helper method for building a docstring for classes that have their IO defined by some function. 
diff --git a/pyiron_workflow/mixin/run.py b/pyiron_workflow/mixin/run.py index b704abc7..b7b11c7b 100644 --- a/pyiron_workflow/mixin/run.py +++ b/pyiron_workflow/mixin/run.py @@ -7,6 +7,7 @@ import contextlib from abc import ABC, abstractmethod +from collections.abc import Callable from concurrent.futures import Executor as StdLibExecutor from concurrent.futures import Future, ThreadPoolExecutor from functools import partial @@ -51,14 +52,14 @@ class Runnable(UsesState, HasLabel, HasRun, ABC): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.running = False - self.failed = False - self.executor = None - # We call it an executor, but it's just whether to use one. - # This is a simply stop-gap as we work out more sophisticated ways to reference - # (or create) an executor process without ever trying to pickle a `_thread.lock` + self.running: bool = False + self.failed: bool = False + self.executor: ( + StdLibExecutor | tuple[Callable[..., StdLibExecutor], tuple, dict] | None + ) = None + # We call it an executor, but it can also be instructions on making one self.future: None | Future = None - self._thread_pool_sleep_time = 1e-6 + self._thread_pool_sleep_time: float = 1e-6 @abstractmethod def on_run(self, *args, **kwargs) -> Any: # callable[..., Any | tuple]: @@ -135,7 +136,7 @@ def run( :attr:`running`. (Default is True.) """ - def _none_to_dict(inp): + def _none_to_dict(inp: dict | None) -> dict: return {} if inp is None else inp before_run_kwargs = _none_to_dict(before_run_kwargs) @@ -275,7 +276,7 @@ def _finish_run( run_exception_kwargs: dict, run_finally_kwargs: dict, **kwargs, - ) -> Any | tuple: + ) -> Any | tuple | None: """ Switch the status, then process and return the run result. 
""" @@ -288,6 +289,7 @@ def _finish_run( self._run_exception(**run_exception_kwargs) if raise_run_exceptions: raise e + return None finally: self._run_finally(**run_finally_kwargs) @@ -308,7 +310,7 @@ def _readiness_error_message(self) -> str: @staticmethod def _parse_executor( - executor: StdLibExecutor | (callable[..., StdLibExecutor], tuple, dict), + executor: StdLibExecutor | tuple[Callable[..., StdLibExecutor], tuple, dict], ) -> StdLibExecutor: """ If you've already got an executor, you're done. But if you get callable and diff --git a/pyiron_workflow/mixin/semantics.py b/pyiron_workflow/mixin/semantics.py index de083b87..1ecfd8fb 100644 --- a/pyiron_workflow/mixin/semantics.py +++ b/pyiron_workflow/mixin/semantics.py @@ -2,10 +2,11 @@ Classes for "semantic" reasoning. The motivation here is to be able to provide the object with a unique identifier -in the context of other semantic objects. Each object may have exactly one parent -and an arbitrary number of children, and each child's name must be unique in the -scope of that parent. In this way, the path from the parent-most object to any -child is completely unique. The typical filesystem on a computer is an excellent +in the context of other semantic objects. Each object may have at most one parent, +while semantic parents may have an arbitrary number of children, and each child's name +must be unique in the scope of that parent. In this way, when semantic parents are also +themselves semantic, we can build a path from the parent-most object to any child that +is completely unique. The typical filesystem on a computer is an excellent example and fulfills our requirements, the only reason we depart from it is so that we are free to have objects stored in different locations (possibly even on totally different drives or machines) belong to the same semantic group. 
@@ -13,17 +14,20 @@ from __future__ import annotations -from abc import ABC +from abc import ABC, abstractmethod from difflib import get_close_matches from pathlib import Path +from typing import ClassVar, Generic, TypeVar from bidict import bidict from pyiron_workflow.logging import logger -from pyiron_workflow.mixin.has_interface_mixins import HasLabel, HasParent, UsesState +from pyiron_workflow.mixin.has_interface_mixins import HasLabel, UsesState +ParentType = TypeVar("ParentType", bound="SemanticParent") -class Semantic(UsesState, HasLabel, HasParent, ABC): + +class Semantic(UsesState, HasLabel, Generic[ParentType], ABC): """ An object with a unique semantic path. @@ -31,46 +35,60 @@ class Semantic(UsesState, HasLabel, HasParent, ABC): accessible. """ - semantic_delimiter = "/" + semantic_delimiter: ClassVar[str] = "/" def __init__( - self, label: str, *args, parent: SemanticParent | None = None, **kwargs + self, + *args, + label: str | None = None, + parent: ParentType | None = None, + **kwargs, ): - self._label = None + self._label = "" self._parent = None self._detached_parent_path = None - self.label = label + self.label = self.__class__.__name__ if label is None else label self.parent = parent super().__init__(*args, **kwargs) - @property - def label(self) -> str: - return self._label + @classmethod + @abstractmethod + def parent_type(cls) -> type[ParentType]: + pass - @label.setter - def label(self, new_label: str) -> None: - if not isinstance(new_label, str): - raise TypeError(f"Expected a string label but got {new_label}") + def _check_label(self, new_label: str) -> None: + super()._check_label(new_label) if self.semantic_delimiter in new_label: - raise ValueError(f"{self.semantic_delimiter} cannot be in the label") - self._label = new_label + raise ValueError( + f"Semantic delimiter {self.semantic_delimiter} cannot be in new label " + f"{new_label}" + ) @property - def parent(self) -> SemanticParent | None: + def parent(self) -> ParentType | None: 
return self._parent @parent.setter - def parent(self, new_parent: SemanticParent | None) -> None: + def parent(self, new_parent: ParentType | None) -> None: + self._set_parent(new_parent) + + def _set_parent(self, new_parent: ParentType | None): + """ + mypy is uncooperative with super calls for setters, so we pull the behaviour + out. + """ if new_parent is self._parent: # Exit early if nothing is changing return - if new_parent is not None and not isinstance(new_parent, SemanticParent): + if new_parent is not None and not isinstance(new_parent, self.parent_type()): raise ValueError( - f"Expected None or a {SemanticParent.__name__} for the parent of " + f"Expected None or a {self.parent_type()} for the parent of " f"{self.label}, but got {new_parent}" ) + _ensure_path_is_not_cyclic(new_parent, self) + if ( self._parent is not None and new_parent is not self._parent @@ -88,18 +106,22 @@ def semantic_path(self) -> str: The path of node labels from the graph root (parent-most node) down to this node. """ + prefix: str if self.parent is None and self.detached_parent_path is None: prefix = "" elif self.parent is None and self.detached_parent_path is not None: prefix = self.detached_parent_path elif self.parent is not None and self.detached_parent_path is None: - prefix = self.parent.semantic_path + if isinstance(self.parent, Semantic): + prefix = self.parent.semantic_path + else: + prefix = self.semantic_delimiter + self.parent.label else: raise ValueError( f"The parent and detached path should not be able to take non-None " f"values simultaneously, but got {self.parent} and " - f"{self.detached_parent_path}, respectively. Please raise an issue on GitHub " - f"outlining how your reached this state." + f"{self.detached_parent_path}, respectively. Please raise an issue on " + f"GitHub outlining how your reached this state." 
) return prefix + self.semantic_delimiter + self.label @@ -126,7 +148,10 @@ def full_label(self) -> str: @property def semantic_root(self) -> Semantic: """The parent-most object in this semantic path; may be self.""" - return self.parent.semantic_root if isinstance(self.parent, Semantic) else self + if isinstance(self.parent, Semantic): + return self.parent.semantic_root + else: + return self def as_path(self, root: Path | str | None = None) -> Path: """ @@ -157,9 +182,12 @@ class CyclicPathError(ValueError): """ -class SemanticParent(Semantic, ABC): +ChildType = TypeVar("ChildType", bound=Semantic) + + +class SemanticParent(HasLabel, Generic[ChildType], ABC): """ - A semantic object with a collection of uniquely-named semantic children. + A labeled object with a collection of uniquely-named semantic children. Children should be added or removed via the :meth:`add_child` and :meth:`remove_child` methods and _not_ by direct manipulation of the @@ -176,25 +204,42 @@ class SemanticParent(Semantic, ABC): def __init__( self, - label: str, *args, - parent: SemanticParent | None = None, strict_naming: bool = True, **kwargs, ): - self._children = bidict() + self._children: bidict[str, ChildType] = bidict() self.strict_naming = strict_naming - super().__init__(*args, label=label, parent=parent, **kwargs) + super().__init__(*args, **kwargs) + + @classmethod + @abstractmethod + def child_type(cls) -> type[ChildType]: + # Dev note: In principle, this could be a regular attribute + # However, in other situations this is precluded (e.g. in channels) + # since it would result in circular references. 
+ # Here we favour consistency over brevity, + # and maintain the X_type() class method pattern + pass @property - def children(self) -> bidict[str, Semantic]: + def children(self) -> bidict[str, ChildType]: return self._children @property def child_labels(self) -> tuple[str]: return tuple(child.label for child in self) - def __getattr__(self, key): + def _check_label(self, new_label: str) -> None: + super()._check_label(new_label) + if self.child_type().semantic_delimiter in new_label: + raise ValueError( + f"Child type ({self.child_type()}) semantic delimiter " + f"{self.child_type().semantic_delimiter} cannot be in new label " + f"{new_label}" + ) + + def __getattr__(self, key) -> ChildType: try: return self._children[key] except KeyError as key_error: @@ -210,7 +255,7 @@ def __getattr__(self, key): def __iter__(self): return self.children.values().__iter__() - def __len__(self): + def __len__(self) -> int: return len(self.children) def __dir__(self): @@ -218,15 +263,15 @@ def __dir__(self): def add_child( self, - child: Semantic, + child: ChildType, label: str | None = None, strict_naming: bool | None = None, - ) -> Semantic: + ) -> ChildType: """ Add a child, optionally assigning it a new label in the process. Args: - child (Semantic): The child to add. + child (ChildType): The child to add. label (str|None): A (potentially) new label to assign the child. (Default is None, leave the child's label alone.) strict_naming (bool|None): Whether to append a suffix to the label if @@ -234,7 +279,7 @@ def add_child( use the class-level flag.) Returns: - (Semantic): The child being added. + (ChildType): The child being added. Raises: TypeError: When the child is not of an allowed class. @@ -244,19 +289,13 @@ def add_child( `strict_naming` is true. 
""" - if not isinstance(child, Semantic): + if not isinstance(child, self.child_type()): raise TypeError( - f"{self.label} expected a new child of type {Semantic.__name__} " + f"{self.label} expected a new child of type {self.child_type()} " f"but got {child}" ) - if isinstance(child, ParentMost): - raise ParentMostError( - f"{child.label} is {ParentMost.__name__} and may only take None as a " - f"parent but was added as a child to {self.label}" - ) - - self._ensure_path_is_not_cyclic(self, child) + _ensure_path_is_not_cyclic(self, child) self._ensure_child_has_no_other_parent(child) @@ -277,19 +316,7 @@ def add_child( child.parent = self return child - @staticmethod - def _ensure_path_is_not_cyclic(parent: SemanticParent | None, child: Semantic): - if parent is not None and parent.semantic_path.startswith( - child.semantic_path + child.semantic_delimiter - ): - raise CyclicPathError( - f"{parent.label} cannot be the parent of {child.label}, because its " - f"semantic path is already in {child.label}'s path and cyclic paths " - f"are not allowed. (i.e. {child.semantic_path} is in " - f"{parent.semantic_path})" - ) - - def _ensure_child_has_no_other_parent(self, child: Semantic): + def _ensure_child_has_no_other_parent(self, child: Semantic) -> None: if child.parent is not None and child.parent is not self: raise ValueError( f"The child ({child.label}) already belongs to the parent " @@ -297,17 +324,17 @@ def _ensure_child_has_no_other_parent(self, child: Semantic): f"add it to this parent ({self.label})." 
) - def _this_child_is_already_at_this_label(self, child: Semantic, label: str): + def _this_child_is_already_at_this_label(self, child: Semantic, label: str) -> bool: return ( label == child.label and label in self.child_labels and self.children[label] is child ) - def _this_child_is_already_at_a_different_label(self, child, label): + def _this_child_is_already_at_a_different_label(self, child, label) -> bool: return child.parent is self and label != child.label - def _get_unique_label(self, label: str, strict_naming: bool): + def _get_unique_label(self, label: str, strict_naming: bool) -> str: if label in self.__dir__(): if label in self.child_labels: if strict_naming: @@ -324,7 +351,7 @@ def _get_unique_label(self, label: str, strict_naming: bool): ) return label - def _add_suffix_to_label(self, label): + def _add_suffix_to_label(self, label: str) -> str: i = 0 new_label = label while new_label in self.__dir__(): @@ -339,30 +366,21 @@ def _add_suffix_to_label(self, label): ) return new_label - def remove_child(self, child: Semantic | str) -> Semantic: + def remove_child(self, child: ChildType | str) -> ChildType: if isinstance(child, str): child = self.children.pop(child) - elif isinstance(child, Semantic): + elif isinstance(child, self.child_type()): self.children.inv.pop(child) else: raise TypeError( f"{self.label} expected to remove a child of type str or " - f"{Semantic.__name__} but got {child}" + f"{self.child_type()} but got {child}" ) child.parent = None return child - @property - def parent(self) -> SemanticParent | None: - return self._parent - - @parent.setter - def parent(self, new_parent: SemanticParent | None) -> None: - self._ensure_path_is_not_cyclic(new_parent, self) - super(SemanticParent, type(self)).parent.__set__(self, new_parent) - def __getstate__(self): state = super().__getstate__() @@ -398,25 +416,13 @@ def __setstate__(self, state): child.parent = self -class ParentMostError(TypeError): - """ - To be raised when assigning a parent to a 
parent-most object - """ - - -class ParentMost(SemanticParent, ABC): - """ - A semantic parent that cannot have any other parent. - """ - - @property - def parent(self) -> None: - return None - - @parent.setter - def parent(self, new_parent: None): - if new_parent is not None: - raise ParentMostError( - f"{self.label} is {ParentMost.__name__} and may only take None as a " - f"parent but got {type(new_parent)}" - ) +def _ensure_path_is_not_cyclic(parent, child: Semantic) -> None: + if isinstance(parent, Semantic) and parent.semantic_path.startswith( + child.semantic_path + child.semantic_delimiter + ): + raise CyclicPathError( + f"{parent.label} cannot be the parent of {child.label}, because its " + f"semantic path is already in {child.label}'s path and cyclic paths " + f"are not allowed. (i.e. {child.semantic_path} is in " + f"{parent.semantic_path})" + ) diff --git a/pyiron_workflow/node.py b/pyiron_workflow/node.py index 3b86a5e4..ead473f7 100644 --- a/pyiron_workflow/node.py +++ b/pyiron_workflow/node.py @@ -41,7 +41,7 @@ class Node( HasIOWithInjection, - Semantic, + Semantic["Composite"], Runnable, ExploitsSingleOutput, ABC, @@ -297,10 +297,7 @@ def __init__( **kwargs: Interpreted as node input data, with keys corresponding to channel labels. """ - super().__init__( - label=self.__class__.__name__ if label is None else label, - parent=parent, - ) + super().__init__(label=label, parent=parent) self.checkpoint = checkpoint self.recovery: Literal["pickle"] | StorageInterface | None = "pickle" self._serialize_result = False # Advertised, but private to indicate @@ -319,6 +316,12 @@ def __init__( **kwargs, ) + @classmethod + def parent_type(cls) -> type[Composite]: + from pyiron_workflow.nodes.composite import Composite + + return Composite + def _setup_node(self) -> None: """ Called _before_ :meth:`Node.__init__` finishes. 
diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index 7f745e9b..1689da92 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -6,6 +6,7 @@ from __future__ import annotations from abc import ABC +from collections.abc import Callable from time import sleep from typing import TYPE_CHECKING, Literal @@ -53,7 +54,7 @@ class FailedChildError(RuntimeError): """Raise when one or more child nodes raise exceptions.""" -class Composite(SemanticParent, HasCreator, Node, ABC): +class Composite(SemanticParent[Node], HasCreator, Node, ABC): """ A base class for nodes that have internal graph structure -- i.e. they hold a collection of child nodes and their computation is to execute that graph. @@ -142,8 +143,8 @@ def __init__( # empty but the running_children list is not super().__init__( - label, *args, + label=label, parent=parent, delete_existing_savefiles=delete_existing_savefiles, autoload=autoload, @@ -153,6 +154,10 @@ def __init__( **kwargs, ) + @classmethod + def child_type(cls) -> type[Node]: + return Node + def activate_strict_hints(self): super().activate_strict_hints() for node in self: @@ -299,11 +304,6 @@ def add_child( label: str | None = None, strict_naming: bool | None = None, ) -> Node: - if not isinstance(child, Node): - raise TypeError( - f"Only new {Node.__name__} instances may be added, but got " - f"{type(child)}." 
- ) self._cached_inputs = None # Reset cache after graph change return super().add_child(child, label=label, strict_naming=strict_naming) @@ -419,8 +419,6 @@ def executor_shutdown(self, wait=True, *, cancel_futures=False): def __setattr__(self, key: str, node: Node): if isinstance(node, Composite) and key in ["_parent", "parent"]: # This is an edge case for assigning a node to an attribute - # We either defer to the setter with super, or directly assign the private - # variable (as requested in the setter) super().__setattr__(key, node) elif isinstance(node, Node): self.add_child(node, label=key) @@ -450,7 +448,7 @@ def graph_as_dict(self) -> dict: return _get_graph_as_dict(self) def _get_connections_as_strings( - self, panel_getter: callable + self, panel_getter: Callable ) -> list[tuple[tuple[str, str], tuple[str, str]]]: """ Connections between children in string representation based on labels. @@ -520,8 +518,8 @@ def __setstate__(self, state): def _restore_connections_from_strings( nodes: dict[str, Node] | DotDict[str, Node], connections: list[tuple[tuple[str, str], tuple[str, str]]], - input_panel_getter: callable, - output_panel_getter: callable, + input_panel_getter: Callable, + output_panel_getter: Callable, ) -> None: """ Set connections among a dictionary of nodes. 
diff --git a/pyiron_workflow/nodes/function.py b/pyiron_workflow/nodes/function.py index cd3d9f31..8000a6d9 100644 --- a/pyiron_workflow/nodes/function.py +++ b/pyiron_workflow/nodes/function.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable from inspect import getsource from typing import Any @@ -300,11 +301,11 @@ class Function(StaticNode, ScrapesIO, ABC): @staticmethod @abstractmethod - def node_function(**kwargs) -> callable: + def node_function(**kwargs) -> Callable: """What the node _does_.""" @classmethod - def _io_defining_function(cls) -> callable: + def _io_defining_function(cls) -> Callable: return cls.node_function @classmethod @@ -351,7 +352,7 @@ def _extra_info(cls) -> str: @classfactory def function_node_factory( - node_function: callable, + node_function: Callable, validate_output_labels: bool, use_cache: bool = True, /, @@ -429,7 +430,7 @@ def decorator(node_function): def function_node( - node_function: callable, + node_function: Callable, *node_args, output_labels: str | tuple[str, ...] 
| None = None, validate_output_labels: bool = True, diff --git a/pyiron_workflow/nodes/macro.py b/pyiron_workflow/nodes/macro.py index 566605a5..527bd5de 100644 --- a/pyiron_workflow/nodes/macro.py +++ b/pyiron_workflow/nodes/macro.py @@ -7,6 +7,7 @@ import re from abc import ABC, abstractmethod +from collections.abc import Callable from inspect import getsource from typing import TYPE_CHECKING @@ -271,11 +272,11 @@ def _setup_node(self) -> None: @staticmethod @abstractmethod - def graph_creator(self, *args, **kwargs) -> callable: + def graph_creator(self, *args, **kwargs) -> Callable: """Build the graph the node will run.""" @classmethod - def _io_defining_function(cls) -> callable: + def _io_defining_function(cls) -> Callable: return cls.graph_creator _io_defining_function_uses_self = True @@ -466,7 +467,7 @@ def _extra_info(cls) -> str: @classfactory def macro_node_factory( - graph_creator: callable, + graph_creator: Callable, validate_output_labels: bool, use_cache: bool = True, /, @@ -536,7 +537,7 @@ def decorator(graph_creator): def macro_node( - graph_creator: callable, + graph_creator: Callable, *node_args, output_labels: str | tuple[str, ...] 
| None = None, validate_output_labels: bool = True, diff --git a/pyiron_workflow/nodes/standard.py b/pyiron_workflow/nodes/standard.py index 8c119944..e9b4c683 100644 --- a/pyiron_workflow/nodes/standard.py +++ b/pyiron_workflow/nodes/standard.py @@ -7,6 +7,7 @@ import os import random import shutil +from collections.abc import Callable from pathlib import Path from time import sleep @@ -167,7 +168,7 @@ def ChangeDirectory( @as_function_node -def PureCall(fnc: callable): +def PureCall(fnc: Callable): """ Return a call without any arguments diff --git a/pyiron_workflow/nodes/transform.py b/pyiron_workflow/nodes/transform.py index 97befbbb..8852b426 100644 --- a/pyiron_workflow/nodes/transform.py +++ b/pyiron_workflow/nodes/transform.py @@ -6,6 +6,7 @@ import itertools from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import MISSING from dataclasses import dataclass as as_dataclass from typing import Any, ClassVar @@ -65,7 +66,7 @@ class ToManyOutputs(Transformer, ABC): # Must be commensurate with the dictionary returned by transform_to_output @abstractmethod - def _on_run(self, input_object) -> callable[..., Any | tuple]: + def _on_run(self, input_object) -> Callable[..., Any | tuple]: """Must take the single object to be transformed""" @property diff --git a/pyiron_workflow/topology.py b/pyiron_workflow/topology.py index e96eeb76..650da2b6 100644 --- a/pyiron_workflow/topology.py +++ b/pyiron_workflow/topology.py @@ -6,12 +6,13 @@ from __future__ import annotations +from collections.abc import Callable from typing import TYPE_CHECKING from toposort import CircularDependencyError, toposort, toposort_flatten if TYPE_CHECKING: - from pyiron_workflow.channels import SignalChannel + from pyiron_workflow.channels import InputSignal, OutputSignal from pyiron_workflow.node import Node @@ -74,8 +75,8 @@ def nodes_to_data_digraph(nodes: dict[str, Node]) -> dict[str, set[str]]: ) locally_scoped_dependencies.append(upstream.owner.label) 
node_dependencies.extend(locally_scoped_dependencies) - node_dependencies = set(node_dependencies) - if node.label in node_dependencies: + node_dependencies_set = set(node_dependencies) + if node.label in node_dependencies_set: # the toposort library has a # [known issue](https://gitlab.com/ericvsmith/toposort/-/issues/3) # That self-dependency isn't caught, so we catch it manually here. @@ -84,14 +85,14 @@ def nodes_to_data_digraph(nodes: dict[str, Node]) -> dict[str, set[str]]: f"the execution of non-DAGs: {node.full_label} appears in its own " f"input." ) - digraph[node.label] = node_dependencies + digraph[node.label] = node_dependencies_set return digraph def _set_new_run_connections_with_fallback_recovery( - connection_creator: callable[[dict[str, Node]], list[Node]], nodes: dict[str, Node] -): + connection_creator: Callable[[dict[str, Node]], list[Node]], nodes: dict[str, Node] +) -> tuple[list[tuple[InputSignal, OutputSignal]], list[Node]]: """ Given a function that takes a dictionary of unconnected nodes, connects their execution graph, and returns the new starting nodes, this wrapper makes sure that @@ -143,7 +144,7 @@ def _set_run_connections_according_to_linear_dag(nodes: dict[str, Node]) -> list def set_run_connections_according_to_linear_dag( nodes: dict[str, Node], -) -> tuple[list[tuple[SignalChannel, SignalChannel]], list[Node]]: +) -> tuple[list[tuple[InputSignal, OutputSignal]], list[Node]]: """ Given a set of nodes that all have the same parent, have no upstream data connections outside the nodes provided, and have acyclic data flow, disconnects all @@ -195,7 +196,7 @@ def _set_run_connections_according_to_dag(nodes: dict[str, Node]) -> list[Node]: def set_run_connections_according_to_dag( nodes: dict[str, Node], -) -> tuple[list[tuple[SignalChannel, SignalChannel]], list[Node]]: +) -> tuple[list[tuple[InputSignal, OutputSignal]], list[Node]]: """ Given a set of nodes that all have the same parent, have no upstream data connections outside the 
nodes provided, and have acyclic data flow, disconnects all diff --git a/pyiron_workflow/type_hinting.py b/pyiron_workflow/type_hinting.py index f8619df0..567ac7fa 100644 --- a/pyiron_workflow/type_hinting.py +++ b/pyiron_workflow/type_hinting.py @@ -28,10 +28,9 @@ def valid_value(value, type_hint) -> bool: def type_hint_to_tuple(type_hint) -> tuple: - if isinstance(type_hint, types.UnionType | typing._UnionGenericAlias): + if isinstance(type_hint, types.UnionType): return typing.get_args(type_hint) - else: - return (type_hint,) + return (type_hint,) def type_hint_is_as_or_more_specific_than(hint, other) -> bool: diff --git a/pyiron_workflow/workflow.py b/pyiron_workflow/workflow.py index 791e17c8..8a5ddb29 100644 --- a/pyiron_workflow/workflow.py +++ b/pyiron_workflow/workflow.py @@ -11,7 +11,6 @@ from bidict import bidict from pyiron_workflow.io import Inputs, Outputs -from pyiron_workflow.mixin.semantics import ParentMost from pyiron_workflow.nodes.composite import Composite if TYPE_CHECKING: @@ -20,7 +19,13 @@ from pyiron_workflow.storage import StorageInterface -class Workflow(ParentMost, Composite): +class ParentMostError(TypeError): + """ + To be raised when assigning a parent to a parent-most object + """ + + +class Workflow(Composite): """ Workflows are a dynamic composite node -- i.e. 
they hold and run a collection of nodes (a subgraph) which can be dynamically modified (adding and removing nodes, @@ -495,3 +500,15 @@ def replace_child( raise e return owned_node + + @property + def parent(self) -> None: + return None + + @parent.setter + def parent(self, new_parent: None): + if new_parent is not None: + raise ParentMostError( + f"{self.label} is a {self.__class__} and may only take None as a " + f"parent but got {type(new_parent)}" + ) diff --git a/tests/unit/mixin/test_run.py b/tests/unit/mixin/test_run.py index 009e03b5..bfa32ece 100644 --- a/tests/unit/mixin/test_run.py +++ b/tests/unit/mixin/test_run.py @@ -8,9 +8,7 @@ class ConcreteRunnable(Runnable): - @property - def label(self) -> str: - return "child_class_with_all_methods_implemented" + _label = "child_class_with_all_methods_implemented" def on_run(self, **kwargs): return kwargs diff --git a/tests/unit/mixin/test_semantics.py b/tests/unit/mixin/test_semantics.py index dd40c5f0..fbb9bac4 100644 --- a/tests/unit/mixin/test_semantics.py +++ b/tests/unit/mixin/test_semantics.py @@ -1,21 +1,40 @@ +from __future__ import annotations + import unittest from pathlib import Path from pyiron_workflow.mixin.semantics import ( CyclicPathError, - ParentMost, Semantic, SemanticParent, ) +class ConcreteSemantic(Semantic["ConcreteParent"]): + @classmethod + def parent_type(cls) -> type[ConcreteSemanticParent]: + return ConcreteSemanticParent + + +class ConcreteParent(SemanticParent[ConcreteSemantic]): + _label = "concrete_parent_default_label" + + @classmethod + def child_type(cls) -> type[ConcreteSemantic]: + return ConcreteSemantic + + +class ConcreteSemanticParent(ConcreteParent, ConcreteSemantic): + pass + + class TestSemantics(unittest.TestCase): def setUp(self): - self.root = ParentMost("root") - self.child1 = Semantic("child1", parent=self.root) - self.middle1 = SemanticParent("middle", parent=self.root) - self.middle2 = SemanticParent("middle_sub", parent=self.middle1) - self.child2 = 
Semantic("child2", parent=self.middle2) + self.root = ConcreteSemanticParent(label="root") + self.child1 = ConcreteSemantic(label="child1", parent=self.root) + self.middle1 = ConcreteSemanticParent(label="middle", parent=self.root) + self.middle2 = ConcreteSemanticParent(label="middle_sub", parent=self.middle1) + self.child2 = ConcreteSemantic(label="child2", parent=self.middle2) def test_getattr(self): with self.assertRaises(AttributeError) as context: @@ -35,18 +54,26 @@ def test_getattr(self): def test_label_validity(self): with self.assertRaises(TypeError, msg="Label must be a string"): - Semantic(label=123) + ConcreteSemantic(label=123) def test_label_delimiter(self): with self.assertRaises( - ValueError, msg=f"Delimiter '{Semantic.semantic_delimiter}' not allowed" + ValueError, + msg=f"Delimiter '{ConcreteSemantic.semantic_delimiter}' not allowed", + ): + ConcreteSemantic(label=f"invalid{ConcreteSemantic.semantic_delimiter}label") + + non_semantic_parent = ConcreteParent() + with self.assertRaises( + ValueError, + msg=f"Delimiter '{ConcreteSemantic.semantic_delimiter}' not allowed", ): - Semantic(f"invalid{Semantic.semantic_delimiter}label") + non_semantic_parent.label = f"contains_{non_semantic_parent.child_type().semantic_delimiter}_delimiter" def test_semantic_delimiter(self): self.assertEqual( "/", - Semantic.semantic_delimiter, + ConcreteSemantic.semantic_delimiter, msg="This is just a hard-code to the current value, update it freely so " "the test passes; if it fails it's just a reminder that your change is " "not backwards compatible, and the next release number should reflect " @@ -58,18 +85,6 @@ def test_parent(self): self.assertEqual(self.child1.parent, self.root) self.assertEqual(self.root.parent, None) - with self.subTest(f"{ParentMost.__name__} exceptions"): - with self.assertRaises( - TypeError, msg=f"{ParentMost.__name__} instances can't have parent" - ): - self.root.parent = SemanticParent(label="foo") - - with self.assertRaises( - TypeError, 
msg=f"{ParentMost.__name__} instances can't be children" - ): - some_parent = SemanticParent(label="bar") - some_parent.add_child(self.root) - with self.subTest("Cyclicity exceptions"): with self.assertRaises(CyclicPathError): self.middle1.parent = self.middle2 @@ -112,7 +127,7 @@ def test_as_path(self): ) def test_detached_parent_path(self): - orphan = Semantic("orphan") + orphan = ConcreteSemantic(label="orphan") orphan.__setstate__(self.child2.__getstate__()) self.assertIsNone( orphan.parent, msg="We still should not explicitly have a parent" diff --git a/tests/unit/test_channels.py b/tests/unit/test_channels.py index eaeb4a85..bb6c4690 100644 --- a/tests/unit/test_channels.py +++ b/tests/unit/test_channels.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from pyiron_workflow.channels import ( @@ -6,6 +8,7 @@ BadCallbackError, Channel, ChannelConnectionError, + ConjugateType, InputData, InputSignal, OutputData, @@ -30,25 +33,25 @@ def data_input_locked(self): return self.locked -class InputChannel(Channel): +class DummyChannel(Channel[ConjugateType]): """Just to de-abstract the base class""" def __str__(self): return "non-abstract input" - @property - def connection_partner_type(self) -> type[Channel]: - return OutputChannel + def _valid_connection(self, other: object) -> bool: + return isinstance(other, self.connection_conjugate()) -class OutputChannel(Channel): - """Just to de-abstract the base class""" +class InputChannel(DummyChannel["OutputChannel"]): + @classmethod + def connection_conjugate(cls) -> type[OutputChannel]: + return OutputChannel - def __str__(self): - return "non-abstract output" - @property - def connection_partner_type(self) -> type[Channel]: +class OutputChannel(DummyChannel["InputChannel"]): + @classmethod + def connection_conjugate(cls) -> type[InputChannel]: return InputChannel @@ -389,26 +392,44 @@ def test_aggregating_call(self): owner = DummyOwner() agg = AccumulatingInputSignal(label="agg", owner=owner, 
callback=owner.update) - with self.assertRaises( - TypeError, - msg="For an aggregating input signal, it _matters_ who called it, so " - "receiving an output signal is not optional", - ): - agg() - out2 = OutputSignal(label="out2", owner=DummyOwner()) agg.connect(self.out, out2) + out_unrelated = OutputSignal(label="out_unrelated", owner=DummyOwner()) + + signals_sent = 0 self.assertEqual( 2, len(agg.connections), msg="Sanity check on initial conditions" ) self.assertEqual( - 0, len(agg.received_signals), msg="Sanity check on initial conditions" + signals_sent, + len(agg.received_signals), + msg="Sanity check on initial conditions", ) self.assertListEqual([0], owner.foo, msg="Sanity check on initial conditions") + agg() + signals_sent += 0 + self.assertListEqual( + [0], + owner.foo, + msg="Aggregating calls should only matter when they come from a connection", + ) + agg(out_unrelated) + signals_sent += 1 + self.assertListEqual( + [0], + owner.foo, + msg="Aggregating calls should only matter when they come from a connection", + ) + self.out() - self.assertEqual(1, len(agg.received_signals), msg="Signal should be received") + signals_sent += 1 + self.assertEqual( + signals_sent, + len(agg.received_signals), + msg="Signals from other channels should be received", + ) self.assertListEqual( [0], owner.foo, @@ -416,8 +437,9 @@ def test_aggregating_call(self): ) self.out() + signals_sent += 0 self.assertEqual( - 1, + signals_sent, len(agg.received_signals), msg="Repeatedly receiving the same signal should have no effect", ) diff --git a/tests/unit/test_workflow.py b/tests/unit/test_workflow.py index 174013d3..f19032b7 100644 --- a/tests/unit/test_workflow.py +++ b/tests/unit/test_workflow.py @@ -9,9 +9,8 @@ from pyiron_workflow._tests import ensure_tests_in_python_path from pyiron_workflow.channels import NOT_DATA -from pyiron_workflow.mixin.semantics import ParentMostError from pyiron_workflow.storage import TypeNotFoundError, available_backends -from 
pyiron_workflow.workflow import Workflow +from pyiron_workflow.workflow import ParentMostError, Workflow ensure_tests_in_python_path() @@ -155,7 +154,7 @@ def test_io_map_bijectivity(self): self.assertEqual(3, len(wf.inputs_map), msg="All entries should be stored") self.assertEqual(0, len(wf.inputs), msg="No IO should be left exposed") - def test_is_parentmost(self): + def test_takes_no_parent(self): wf = Workflow("wf") wf2 = Workflow("wf2")