Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor connection validity #540

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 31 additions & 35 deletions pyiron_workflow/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,6 @@ def full_label(self) -> str:
"""A label combining the channel's usual label and its owner's semantic path"""
return f"{self.owner.full_label}.{self.label}"

@abstractmethod
def _valid_connection(self, other: object) -> bool:
"""
Logic for determining if a connection is valid.
"""

def connect(self, *others: ConjugateType) -> None:
"""
Form a connection between this and one or more other channels.
Expand All @@ -146,22 +140,30 @@ def connect(self, *others: ConjugateType) -> None:
for other in others:
if other in self.connections:
continue
elif self._valid_connection(other):
# Prepend new connections
# so that connection searches run newest to oldest
self.connections.insert(0, other)
other.connections.insert(0, self)
else:
if isinstance(other, self.connection_conjugate()):
elif isinstance(other, self.connection_conjugate()):
if self._valid_connection(other):
# Prepend new connections
# so that connection searches run newest to oldest
self.connections.insert(0, other)
other.connections.insert(0, self)
else:
raise ChannelConnectionError(
self._connection_conjugate_failure_message(other)
) from None
else:
raise TypeError(
f"Can only connect to {self.connection_conjugate()} "
f"objects, but {self.full_label} ({self.__class__}) "
f"got {other} ({type(other)})"
)
else:
raise TypeError(
f"Can only connect to {self.connection_conjugate()} "
f"objects, but {self.full_label} ({self.__class__}) "
f"got {other} ({type(other)})"
)

def _valid_connection(self, other: ConjugateType) -> bool:
"""
Logic for determining if a connection to a conjugate partner is valid.

Override in child classes as necessary.
"""
return True

def _connection_conjugate_failure_message(self, other: ConjugateType) -> str:
return (
Expand Down Expand Up @@ -466,21 +468,18 @@ def _value_is_data(self) -> bool:
def _has_hint(self) -> bool:
return self.type_hint is not None

def _valid_connection(self, other: object) -> bool:
if isinstance(other, self.connection_conjugate()):
if self._both_typed(other):
out, inp = self._figure_out_who_is_who(other)
if not inp.strict_hints:
return True
else:
return type_hint_is_as_or_more_specific_than(
out.type_hint, inp.type_hint
)
else:
# If either is untyped, don't do type checking
def _valid_connection(self, other: DataChannel) -> bool:
if self._both_typed(other):
out, inp = self._figure_out_who_is_who(other)
if not inp.strict_hints:
return True
else:
return type_hint_is_as_or_more_specific_than(
out.type_hint, inp.type_hint
)
else:
return False
# If either is untyped, don't do type checking
return True

def _connection_conjugate_failure_message(self, other: DataChannel) -> str:
msg = super()._connection_conjugate_failure_message(other)
Expand Down Expand Up @@ -599,9 +598,6 @@ class SignalChannel(FlavorChannel[SignalType], ABC):
def __call__(self) -> None:
pass

def _valid_connection(self, other: object) -> bool:
return isinstance(other, self.connection_conjugate())


class BadCallbackError(ValueError):
pass
Expand Down
Loading