Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
Signed-off-by: Hemil Desai <[email protected]>
  • Loading branch information
hemildesai committed Dec 16, 2024
1 parent 9f852c3 commit d03963a
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 52 deletions.
10 changes: 8 additions & 2 deletions src/nemo_run/core/execution/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,14 @@
from nemo_run.core.packaging.base import Packager
from nemo_run.core.packaging.git import GitArchivePackager
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
from nemo_run.core.tunnel.callback import Callback
from nemo_run.core.tunnel.client import LocalTunnel, PackagingJob, SSHConfigFile, SSHTunnel, Tunnel
from nemo_run.core.tunnel.client import (
Callback,
LocalTunnel,
PackagingJob,
SSHConfigFile,
SSHTunnel,
Tunnel,
)
from nemo_run.core.tunnel.server import TunnelMetadata, server_dir
from nemo_run.devspace.base import DevSpace

Expand Down
45 changes: 0 additions & 45 deletions src/nemo_run/core/tunnel/callback.py

This file was deleted.

31 changes: 27 additions & 4 deletions src/nemo_run/core/tunnel/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional
from typing import Callable, Optional

import paramiko
import paramiko.ssh_exception
Expand All @@ -35,9 +35,6 @@
from nemo_run.config import NEMORUN_HOME, ConfigurableMixin
from nemo_run.core.frontend.console.api import CONSOLE

if TYPE_CHECKING:
from nemo_run.core.tunnel.callback import Callback

logger: logging.Logger = logging.getLogger(__name__)
TUNNEL_DIR = ".tunnels"
TUNNEL_FILE_SUBPATH = os.path.join(NEMORUN_HOME, TUNNEL_DIR)
Expand Down Expand Up @@ -383,3 +380,29 @@ def remove_entry(self, name: str):
file.writelines(lines)

print(f"Removed SSH config entry for {host}.")


class Callback:
def setup(self, tunnel: "Tunnel"):
"""Called when the tunnel is setup."""
self.tunnel = tunnel

def on_start(self):
"""Called when the keep_alive loop starts."""
pass

def on_interval(self):
"""Called at each interval during the keep_alive loop."""
pass

def on_stop(self):
"""Called when the keep_alive loop stops."""
pass

def on_error(self, error: Exception):
"""Called when an error occurs during the keep_alive loop.
Args:
error (Exception): The exception that was raised.
"""
pass
2 changes: 1 addition & 1 deletion src/nemo_run/devspace/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import fiddle as fdl

from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
from nemo_run.core.tunnel.callback import Callback
from nemo_run.core.tunnel.client import Callback

if TYPE_CHECKING:
from nemo_run.core.execution.base import Executor
Expand Down

0 comments on commit d03963a

Please sign in to comment.