From d03963a7389926fbef93bcd15052e1aebe4232c4 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Mon, 16 Dec 2024 09:23:37 -0800 Subject: [PATCH] Fix Signed-off-by: Hemil Desai --- src/nemo_run/core/execution/slurm.py | 10 +++++-- src/nemo_run/core/tunnel/callback.py | 45 ---------------------------- src/nemo_run/core/tunnel/client.py | 31 ++++++++++++++++--- src/nemo_run/devspace/base.py | 2 +- 4 files changed, 36 insertions(+), 52 deletions(-) delete mode 100644 src/nemo_run/core/tunnel/callback.py diff --git a/src/nemo_run/core/execution/slurm.py b/src/nemo_run/core/execution/slurm.py index 871a804..213cd3e 100644 --- a/src/nemo_run/core/execution/slurm.py +++ b/src/nemo_run/core/execution/slurm.py @@ -42,8 +42,14 @@ from nemo_run.core.packaging.base import Packager from nemo_run.core.packaging.git import GitArchivePackager from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer -from nemo_run.core.tunnel.callback import Callback -from nemo_run.core.tunnel.client import LocalTunnel, PackagingJob, SSHConfigFile, SSHTunnel, Tunnel +from nemo_run.core.tunnel.client import ( + Callback, + LocalTunnel, + PackagingJob, + SSHConfigFile, + SSHTunnel, + Tunnel, +) from nemo_run.core.tunnel.server import TunnelMetadata, server_dir from nemo_run.devspace.base import DevSpace diff --git a/src/nemo_run/core/tunnel/callback.py b/src/nemo_run/core/tunnel/callback.py deleted file mode 100644 index bc6949f..0000000 --- a/src/nemo_run/core/tunnel/callback.py +++ /dev/null @@ -1,45 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from nemo_run.core.tunnel.client import Tunnel - - -class Callback: - def setup(self, tunnel: "Tunnel"): - """Called when the tunnel is setup.""" - self.tunnel = tunnel - - def on_start(self): - """Called when the keep_alive loop starts.""" - pass - - def on_interval(self): - """Called at each interval during the keep_alive loop.""" - pass - - def on_stop(self): - """Called when the keep_alive loop stops.""" - pass - - def on_error(self, error: Exception): - """Called when an error occurs during the keep_alive loop. - - Args: - error (Exception): The exception that was raised. - """ - pass diff --git a/src/nemo_run/core/tunnel/client.py b/src/nemo_run/core/tunnel/client.py index 6430bb1..7abb31c 100644 --- a/src/nemo_run/core/tunnel/client.py +++ b/src/nemo_run/core/tunnel/client.py @@ -24,7 +24,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Callable, Optional +from typing import Callable, Optional import paramiko import paramiko.ssh_exception @@ -35,9 +35,6 @@ from nemo_run.config import NEMORUN_HOME, ConfigurableMixin from nemo_run.core.frontend.console.api import CONSOLE -if TYPE_CHECKING: - from nemo_run.core.tunnel.callback import Callback - logger: logging.Logger = logging.getLogger(__name__) TUNNEL_DIR = ".tunnels" TUNNEL_FILE_SUBPATH = os.path.join(NEMORUN_HOME, TUNNEL_DIR) @@ -383,3 +380,29 @@ def remove_entry(self, name: str): file.writelines(lines) print(f"Removed SSH config entry for {host}.") + + +class Callback: + def setup(self, tunnel: "Tunnel"): + """Called when the tunnel is setup.""" + self.tunnel = tunnel + + def on_start(self): + """Called when the keep_alive loop starts.""" + pass + + def on_interval(self): + """Called at each interval during the keep_alive loop.""" + pass + + def on_stop(self): + """Called when the keep_alive loop stops.""" + pass + + def on_error(self, error: Exception): + """Called when an error occurs during the keep_alive loop. + + Args: + error (Exception): The exception that was raised. + """ + pass diff --git a/src/nemo_run/devspace/base.py b/src/nemo_run/devspace/base.py index 03f436a..bccc99c 100644 --- a/src/nemo_run/devspace/base.py +++ b/src/nemo_run/devspace/base.py @@ -19,7 +19,7 @@ import fiddle as fdl from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer -from nemo_run.core.tunnel.callback import Callback +from nemo_run.core.tunnel.client import Callback if TYPE_CHECKING: from nemo_run.core.execution.base import Executor