Skip to content

Commit

Permalink
Merge pull request #540 from pyiron/valid_connection_refactor
Browse files Browse the repository at this point in the history
Refactor connection validity
  • Loading branch information
XzzX authored Jan 13, 2025
2 parents 3577158 + c77bcbd commit 3e8c92f
Showing 1 changed file with 31 additions and 35 deletions.
66 changes: 31 additions & 35 deletions pyiron_workflow/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,6 @@ def full_label(self) -> str:
"""A label combining the channel's usual label and its owner's semantic path"""
return f"{self.owner.full_label}.{self.label}"

@abstractmethod
def _valid_connection(self, other: object) -> bool:
"""
Logic for determining if a connection is valid.
"""

def connect(self, *others: ConjugateType) -> None:
"""
Form a connection between this and one or more other channels.
Expand All @@ -146,22 +140,30 @@ def connect(self, *others: ConjugateType) -> None:
for other in others:
if other in self.connections:
continue
elif self._valid_connection(other):
# Prepend new connections
# so that connection searches run newest to oldest
self.connections.insert(0, other)
other.connections.insert(0, self)
else:
if isinstance(other, self.connection_conjugate()):
elif isinstance(other, self.connection_conjugate()):
if self._valid_connection(other):
# Prepend new connections
# so that connection searches run newest to oldest
self.connections.insert(0, other)
other.connections.insert(0, self)
else:
raise ChannelConnectionError(
self._connection_conjugate_failure_message(other)
) from None
else:
raise TypeError(
f"Can only connect to {self.connection_conjugate()} "
f"objects, but {self.full_label} ({self.__class__}) "
f"got {other} ({type(other)})"
)
else:
raise TypeError(
f"Can only connect to {self.connection_conjugate()} "
f"objects, but {self.full_label} ({self.__class__}) "
f"got {other} ({type(other)})"
)

def _valid_connection(self, other: ConjugateType) -> bool:
"""
Logic for determining if a connection to a conjugate partner is valid.
Override in child classes as necessary.
"""
return True

def _connection_conjugate_failure_message(self, other: ConjugateType) -> str:
return (
Expand Down Expand Up @@ -466,21 +468,18 @@ def _value_is_data(self) -> bool:
def _has_hint(self) -> bool:
return self.type_hint is not None

def _valid_connection(self, other: object) -> bool:
if isinstance(other, self.connection_conjugate()):
if self._both_typed(other):
out, inp = self._figure_out_who_is_who(other)
if not inp.strict_hints:
return True
else:
return type_hint_is_as_or_more_specific_than(
out.type_hint, inp.type_hint
)
else:
# If either is untyped, don't do type checking
def _valid_connection(self, other: DataChannel) -> bool:
if self._both_typed(other):
out, inp = self._figure_out_who_is_who(other)
if not inp.strict_hints:
return True
else:
return type_hint_is_as_or_more_specific_than(
out.type_hint, inp.type_hint
)
else:
return False
# If either is untyped, don't do type checking
return True

def _connection_conjugate_failure_message(self, other: DataChannel) -> str:
msg = super()._connection_conjugate_failure_message(other)
Expand Down Expand Up @@ -599,9 +598,6 @@ class SignalChannel(FlavorChannel[SignalType], ABC):
def __call__(self) -> None:
pass

def _valid_connection(self, other: object) -> bool:
return isinstance(other, self.connection_conjugate())


class BadCallbackError(ValueError):
pass
Expand Down

0 comments on commit 3e8c92f

Please sign in to comment.