Skip to content

Commit

Permalink
Fixed bug. Before this commit, it was possible to use contexts which had no more attempts
Browse files Browse the repository at this point in the history
  • Loading branch information
LanderOtto committed Jan 17, 2025
1 parent 7c29d2b commit c017f7e
Showing 1 changed file with 24 additions and 27 deletions.
51 changes: 24 additions & 27 deletions streamflow/deployment/connector/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,12 @@ def __init__(
max_concurrent_sessions: int,
):
self._streamflow_config_dir: str = streamflow_config_dir
self._closing: bool = False
self._config: SSHConfig = config
self._max_concurrent_sessions: int = max_concurrent_sessions
self._ssh_connection: asyncssh.SSHClientConnection | None = None
self._connecting: bool = False
self._connect_event: asyncio.Event = asyncio.Event()
self.ssh_attempts: int = 0
self.connection_attempts: int = 0

async def _get_connection(
self, config: SSHConfig
Expand Down Expand Up @@ -93,34 +92,24 @@ def _get_param_from_file(self, file_path: str):
return f.read().strip()

async def close(self):
self._closing = True
if self._ssh_connection is not None:
max_times = 0
while len(self._ssh_connection._channels) > 0:
await asyncio.sleep(5)
max_times += 1
if max_times > 5:
logger.warning(
f"Closing the SSH connection {self.get_hostname()} is running, but the connection "
f"has had open channels for too long. Forcing closure."
)
break
if len(self._ssh_connection._channels) > 0:
logger.warning(
f"Closing the SSH connection {self.get_hostname()} is running, but the connection "
f"has had open channels for too long. Forcing closure."
)
self._ssh_connection.close()
await self._ssh_connection.wait_closed()
self._ssh_connection = None
self._connecting = False

def full(self) -> bool:
return self._closing or (
return (
self._ssh_connection
and len(self._ssh_connection._channels) >= self._max_concurrent_sessions
)

async def get_connection(self) -> asyncssh.SSHClientConnection:
if self._closing:
raise WorkflowExecutionException(
f"Connecting to a closed SSH context {self.get_hostname()}"
)
if self._ssh_connection is None:
if not self._connecting:
self._connecting = True
Expand All @@ -146,13 +135,9 @@ async def get_connection(self) -> asyncssh.SSHClientConnection:
def get_hostname(self) -> str:
return self._config.hostname

def is_closed(self) -> bool:
return self._closing

async def reset(self):
await self.close()
self.ssh_attempts += 1
self._closing = False
self.connection_attempts += 1
self._connect_event.clear()


Expand Down Expand Up @@ -185,14 +170,26 @@ def __init__(

async def __aenter__(self) -> asyncssh.SSHClientProcess:
async with self._condition:
available_contexts = self._contexts
while True:
if all(c.ssh_attempts > self._retries for c in self._contexts):
if (
len(
available_contexts := [
c
for c in available_contexts
if c.connection_attempts < self._retries
]
)
== 0
):
raise WorkflowExecutionException(
f"Hosts {[c.get_hostname() for c in self._contexts]} have no "
f"more available contexts: terminating."
)
elif (
len(free_contexts := [c for c in self._contexts if not c.full()])
len(
free_contexts := [c for c in available_contexts if not c.full()]
)
== 0
):
await self._condition.wait()
Expand All @@ -210,7 +207,7 @@ async def __aenter__(self) -> asyncssh.SSHClientProcess:
encoding=self.encoding,
)
await self._proc.__aenter__()
self._selected_context.ssh_attempts = 0
self._selected_context.connection_attempts = 0
return self._proc
except (
ChannelOpenError,
Expand All @@ -226,7 +223,7 @@ async def __aenter__(self) -> asyncssh.SSHClientProcess:
if not isinstance(exc, ChannelOpenError):
if logger.isEnabledFor(logging.WARNING):
logger.warning(
f"Connection to {context.get_hostname()} attempts: {context.ssh_attempts} "
f"Connection to {context.get_hostname()} attempts: {context.connection_attempts} "
)
self._selected_context = None
await context.reset()
Expand Down

0 comments on commit c017f7e

Please sign in to comment.