Add option to symlink from remote dir in packager #122

Merged: 5 commits, Dec 18, 2024
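For context, a minimal usage sketch of the new symlink_from_remote_dir option. This is not part of the PR itself: the account, partition, host, and paths are made up, and the top-level run.Packager, run.SlurmExecutor, and run.SSHTunnel aliases are assumed.

import nemo_run as run

# Point the base Packager at a checkout that already exists on the cluster.
# Instead of archiving and uploading code, the SlurmExecutor symlinks this
# directory into the job's "code" directory on the remote side.
packager = run.Packager(symlink_from_remote_dir="/lustre/checkouts/my_repo")

executor = run.SlurmExecutor(
    account="my_account",      # illustrative
    partition="batch",         # illustrative
    tunnel=run.SSHTunnel(host="login.cluster", user="me", job_dir="/lustre/nemo_run/jobs"),
    packager=packager,
)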
47 changes: 42 additions & 5 deletions src/nemo_run/core/execution/slurm.py
@@ -42,8 +42,14 @@
from nemo_run.core.packaging.base import Packager
from nemo_run.core.packaging.git import GitArchivePackager
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
from nemo_run.core.tunnel.callback import Callback
from nemo_run.core.tunnel.client import LocalTunnel, SSHConfigFile, SSHTunnel, Tunnel
from nemo_run.core.tunnel.client import (
Callback,
LocalTunnel,
PackagingJob,
SSHConfigFile,
SSHTunnel,
Tunnel,
)
from nemo_run.core.tunnel.server import TunnelMetadata, server_dir
from nemo_run.devspace.base import DevSpace

@@ -388,7 +394,7 @@ def __post_init__(self):
self.wait_time_for_group_job = 0

def info(self) -> str:
return f"{self.__class__.__qualname__} on {self.tunnel._key}"
return f"{self.__class__.__qualname__} on {self.tunnel.key}"

def alloc(self, job_name="interactive"):
self.job_name = f"{self.job_name_prefix}{job_name}"
@@ -537,13 +543,39 @@ def package_configs(self, *cfgs: tuple[str, str]) -> list[str]:
return filenames

def package(self, packager: Packager, job_name: str):
if job_name in self.tunnel.packaging_jobs:
if job_name in self.tunnel.packaging_jobs and not packager.symlink_from_remote_dir:
logger.info(
f"Packaging for job {job_name} in tunnel {self.tunnel} already done. Skipping subsequent packagings.\n"
"This may cause issues if you have multiple tasks with the same name but different packagers, as only the first packager will be used."
)
return

if packager.symlink_from_remote_dir:
logger.info(
f"Packager {packager} is configured to symlink from remote dir. Skipping packaging."
)
if type(packager) is Packager:
self.tunnel.packaging_jobs[job_name] = PackagingJob(symlink=False)
return

self.tunnel.packaging_jobs[job_name] = PackagingJob(
symlink=True,
src_path=packager.symlink_from_remote_dir,
dst_path=os.path.join(self.tunnel.job_dir, Path(self.job_dir).name, "code"),
)

# Tunnel job dir is the directory of the experiment id, so the base job dir is two levels up
base_remote_dir = str(Path(self.tunnel.job_dir).parent.parent)
base_remote_mount = f"{base_remote_dir}:{base_remote_dir}"
if base_remote_mount not in self.container_mounts:
self.container_mounts.append(f"{base_remote_dir}:{base_remote_dir}")

for req in self.resource_group:
if base_remote_mount not in req.container_mounts:
req.container_mounts.append(base_remote_mount)

return

assert self.experiment_id, "Executor not assigned to an experiment."
if isinstance(packager, GitArchivePackager):
output = subprocess.run(
@@ -573,7 +605,12 @@ def package(self, packager: Packager, job_name: str):
f"tar -xvzf {local_pkg} -C {local_code_extraction_path} --ignore-zeros", hide=True
)

self.tunnel.packaging_jobs.add(job_name)
self.tunnel.packaging_jobs[job_name] = PackagingJob(
symlink=False,
dst_path=None
if type(packager) is Packager
else os.path.join(self.tunnel.job_dir, Path(self.job_dir).name, "code"),
)

def parse_deps(self) -> list[str]:
"""
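A rough illustration of the path arithmetic in the new symlink branch of SlurmExecutor.package above. All directories below are made up; the layout simply follows the "two levels up" comment in the diff.

import os
from pathlib import Path

# Hypothetical layout: the tunnel job dir is <base>/experiments/<exp_name>/<exp_id> on the remote side.
tunnel_job_dir = "/remote/nemo_run/experiments/my_exp/my_exp_1734500000"
local_job_dir = "/home/me/.nemo_run/experiments/my_exp/my_exp_1734500000/my_task"
symlink_src = "/remote/checkouts/my_repo"  # packager.symlink_from_remote_dir

# Destination of the symlink: <tunnel job dir>/<task name>/code
dst_path = os.path.join(tunnel_job_dir, Path(local_job_dir).name, "code")
# -> /remote/nemo_run/experiments/my_exp/my_exp_1734500000/my_task/code

# Base job dir is two levels up from the experiment id dir; it is added to container_mounts as-is.
base_remote_dir = str(Path(tunnel_job_dir).parent.parent)
# -> /remote/nemo_run/experiments
base_remote_mount = f"{base_remote_dir}:{base_remote_dir}"

# Command later emitted by PackagingJob.symlink_cmd():
print(f"ln -s {symlink_src} {dst_path}")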
6 changes: 5 additions & 1 deletion src/nemo_run/core/packaging/base.py
@@ -16,7 +16,7 @@
import logging
from dataclasses import dataclass
from pathlib import Path

from typing import Optional

from nemo_run.config import ConfigurableMixin

@@ -45,6 +45,10 @@ class Packager(ConfigurableMixin):
#: Uses component or executor specific debug flags if set to True.
debug: bool = False

#: Symlinks the package from the provided remote dir.
#: Only applicable when using SlurmExecutor at the moment.
symlink_from_remote_dir: Optional[str] = None
Collaborator:
Would be nice to do the same for local runs as well. But not as important as the current changes, so we can keep it in the backlog.

Collaborator (Author):
Created #126


def package(self, path: Path, job_dir: str, name: str) -> str: ...

def setup(self):
45 changes: 0 additions & 45 deletions src/nemo_run/core/tunnel/callback.py

This file was deleted.

53 changes: 41 additions & 12 deletions src/nemo_run/core/tunnel/client.py
@@ -24,20 +24,17 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional
from typing import Callable, Optional

import paramiko
import paramiko.ssh_exception
from fabric import Config, Connection
from invoke.context import Context
from invoke.runners import Result as RunResult

from nemo_run.config import NEMORUN_HOME
from nemo_run.config import NEMORUN_HOME, ConfigurableMixin
from nemo_run.core.frontend.console.api import CONSOLE

if TYPE_CHECKING:
from nemo_run.core.tunnel.callback import Callback

logger: logging.Logger = logging.getLogger(__name__)
TUNNEL_DIR = ".tunnels"
TUNNEL_FILE_SUBPATH = os.path.join(NEMORUN_HOME, TUNNEL_DIR)
@@ -58,18 +55,24 @@ def authentication_handler(title, instructions, prompt_list):


@dataclass(kw_only=True)
class Tunnel(ABC):
class PackagingJob(ConfigurableMixin):
symlink: bool = False
src_path: Optional[str] = None
dst_path: Optional[str] = None

def symlink_cmd(self):
return f"ln -s {self.src_path} {self.dst_path}"


@dataclass(kw_only=True)
class Tunnel(ABC, ConfigurableMixin):
job_dir: str
host: str
user: str
packaging_jobs: dict[str, PackagingJob] = field(default_factory=dict)

def __post_init__(self):
self._key = f"{self.user}@{self.host}"
self._packaging_jobs = set()

@property
def packaging_jobs(self):
return self._packaging_jobs
self.key = f"{self.user}@{self.host}"

def _set_job_dir(self, experiment_id: str): ...

@@ -377,3 +380,29 @@ def remove_entry(self, name: str):
file.writelines(lines)

print(f"Removed SSH config entry for {host}.")


class Callback:
def setup(self, tunnel: "Tunnel"):
"""Called when the tunnel is setup."""
self.tunnel = tunnel

def on_start(self):
"""Called when the keep_alive loop starts."""
pass

def on_interval(self):
"""Called at each interval during the keep_alive loop."""
pass

def on_stop(self):
"""Called when the keep_alive loop stops."""
pass

def on_error(self, error: Exception):
"""Called when an error occurs during the keep_alive loop.

Args:
error (Exception): The exception that was raised.
"""
pass
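A minimal Callback subclass, purely illustrative: the hook names are the ones defined above, but how a callback gets registered with a tunnel's keep_alive loop is not shown in this diff.

import logging

from nemo_run.core.tunnel.client import Callback


class LoggingCallback(Callback):
    def on_start(self):
        # self.tunnel is set by setup() before the loop starts
        logging.info("keep_alive started for %s", self.tunnel.key)

    def on_interval(self):
        logging.debug("tunnel %s still alive", self.tunnel.key)

    def on_error(self, error: Exception):
        logging.error("keep_alive error on %s: %s", self.tunnel.key, error)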
2 changes: 1 addition & 1 deletion src/nemo_run/devspace/base.py
@@ -19,7 +19,7 @@
import fiddle as fdl

from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
from nemo_run.core.tunnel.callback import Callback
from nemo_run.core.tunnel.client import Callback

if TYPE_CHECKING:
from nemo_run.core.execution.base import Executor
33 changes: 32 additions & 1 deletion src/nemo_run/run/experiment.py
@@ -190,6 +190,7 @@ class Experiment(ConfigurableMixin):
_VERSION_FILE = "_VERSION"
_TASK_FILE = "_TASKS"
_DONE_FILE = "_DONE"
_TUNNELS_FILE = "_TUNNELS"
_current_experiment_token: Optional[contextvars.Token]

@classmethod
@@ -221,6 +222,12 @@ def _from_config(cls: Type["Experiment"], exp_dir: str) -> "Experiment":

exp: "Experiment" = fdl.build(cfg)
exp._jobs = exp._load_jobs()
try:
exp.tunnels = exp._load_tunnels()
except Exception as e:
exp.console.log(
f"Exception {e} loading tunnels for experiment {id}, will continue without loading tunnels."
)

return exp

@@ -327,6 +334,20 @@ def _save_config(self):
with open(os.path.join(self._exp_dir, self.__class__._VERSION_FILE), "w+") as f:
f.write(f"{run.__version__}\n")

def _save_tunnels(self):
serializer = ZlibJSONSerializer()
serialized_tunnels = {
k: serializer.serialize(v.to_config()) for k, v in self.tunnels.items()
}
with open(os.path.join(self._exp_dir, self.__class__._TUNNELS_FILE), "w+") as f:
json.dump(serialized_tunnels, f)

def _load_tunnels(self) -> dict[str, Tunnel]:
with open(os.path.join(self._exp_dir, self.__class__._TUNNELS_FILE)) as f:
serialized_tunnels = json.load(f)
serializer = ZlibJSONSerializer()
return {k: fdl.build(serializer.deserialize(v)) for k, v in serialized_tunnels.items()}

def _save_jobs(self):
serialized_jobs = list(map(lambda job: job.serialize(), self.jobs))
with open(os.path.join(self._exp_dir, self.__class__._TASK_FILE), "w+") as f:
@@ -645,9 +666,19 @@ def run(
for tunnel in self.tunnels.values():
if isinstance(tunnel, SSHTunnel):
tunnel.connect()
assert tunnel.session, f"SSH tunnel {tunnel._key} failed to connect."
assert tunnel.session, f"SSH tunnel {tunnel.key} failed to connect."
rsync(tunnel.session, self._exp_dir, os.path.dirname(tunnel.job_dir))

symlink_cmds = []
for packaging_job in tunnel.packaging_jobs.values():
if packaging_job.symlink:
symlink_cmds.append(packaging_job.symlink_cmd())

if symlink_cmds:
tunnel.run(" && ".join(symlink_cmds))

self._save_tunnels()

return self._run_dag(detach=detach, tail_logs=tail_logs, executors=executors)

def _run_dag(self, detach: bool, tail_logs: bool, executors: set[Executor]):
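A simplified sketch of what the new block in Experiment.run does per SSH tunnel: collect the symlink commands from the tunnel's packaging jobs and send them as a single remote command. The task names and paths are made up.

from nemo_run.core.tunnel.client import PackagingJob

packaging_jobs = {
    "task_a": PackagingJob(
        symlink=True,
        src_path="/remote/checkouts/my_repo",
        dst_path="/remote/nemo_run/experiments/my_exp/my_exp_1734500000/task_a/code",
    ),
    "task_b": PackagingJob(symlink=False),  # regular packaging job, nothing to symlink
}

symlink_cmds = [job.symlink_cmd() for job in packaging_jobs.values() if job.symlink]
if symlink_cmds:
    remote_cmd = " && ".join(symlink_cmds)
    # Experiment.run passes this string to tunnel.run(...); printed here for illustration:
    print(remote_cmd)
    # ln -s /remote/checkouts/my_repo /remote/nemo_run/experiments/my_exp/my_exp_1734500000/task_a/code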
5 changes: 0 additions & 5 deletions src/nemo_run/run/torchx_backend/packaging.py
@@ -82,11 +82,6 @@ def package(

args.append(fn_or_script_filename)
else:
args += [
"-p",
_serialize(executor.packager.to_config()),
]

args.append(_serialize(fn_or_script))

role_args = default_cmd + args
10 changes: 5 additions & 5 deletions src/nemo_run/run/torchx_backend/schedulers/slurm.py
@@ -72,14 +72,14 @@ def _initialize_tunnel(self, tunnel: SSHTunnel | LocalTunnel):
return

experiment = run_experiment._current_experiment.get(None)
if experiment and tunnel._key in experiment.tunnels:
self.tunnel = experiment.tunnels[tunnel._key]
if experiment and tunnel.key in experiment.tunnels:
self.tunnel = experiment.tunnels[tunnel.key]
return

self.tunnel = tunnel

if experiment:
experiment.tunnels[tunnel._key] = self.tunnel
experiment.tunnels[tunnel.key] = self.tunnel

def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # type: ignore
assert isinstance(cfg, SlurmExecutor), f"{cfg.__class__} not supported for slurm scheduler."
@@ -96,6 +96,8 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t
partition = executor.partition
assert partition is None or isinstance(partition, str), "partition must be str"

executor.package(packager=executor.packager, job_name=Path(job_dir).name)

srun_cmds: list[list[str]] = []
jobs = []
envs = {}
@@ -137,8 +139,6 @@ def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[Any]: # t
with open(path, "w") as f:
f.write(script)

executor.package(packager=executor.packager, job_name=Path(job_dir).name)

return AppDryRunInfo(req, repr)

def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str: # type: ignore
2 changes: 0 additions & 2 deletions test/run/torchx_backend/test_packaging.py
@@ -71,8 +71,6 @@ def test_package_partial(mock_executor):
"nemo_run.core.runners.fdl_runner",
"-n",
"test",
"-p",
"eJzdVE1Lw0AQ_StlL60goelRrKCePAgevJUSNtlJunazG_ajGEr_uzvbRNPYtBU9eQnZYea9N59bopWy5Ga0JbauwP8QDTm5HpE11PiqaLamBegkJjtvVekbZNbsA1wlwNs7wV8oVd3glIo5EUyp48JyadAqaRlsASMgcwsl4i4W5EkyeJ9w_M6nV-jOIHUFWS4RDyxl1FLvKp0QGMp4Zn-pAyEOZfS4_jTFY3l8JjL7B4lgOKOpHw-r6Qa08QPU-v2gUzlnTECUGJ1FmZI5L7p6HlqS15bjuZXSG6h7a_UEw-bjXCZKJ5kwYxysk-wSSpVoJz21hmi_CFwWUUoNdHW8NCtCdr4cB2RUF64EaRN89hkP96xdpmEMS4vTEM0aDCOsuLFK1-dBZh5kaGz-tk_YqM6JuXQwOq3psz3uLcMTEG5JrwYCaGDYUOHQkFNhIEizq4BAbvFO3kXNITpRnsNynlmE3REOj45mj5MpJ7_d2rh78OLu0YgvWby4X_AYydCTK4mKp9E08sI-AMy28Wg=",
"eJzdlEtPwzAMgP_KlMuGhKp1RwRIcOOAxIHbNEVuk21haVKlzrRq2n8nDuto9wSxE5eqduLPdvxYM2ctsrvemmFdyvDDnJyy2x5byJok4Yui5iAET9kmqG32IXOsvix8qWXQt6y_MWW9BRVWeB1VmVcalalIa6CIusiIZIWyIO54zF6MkKuBou_D8IauA5vc9roHaTzI2GRCTiSCAIRgb7zWxBMqxz8GR4hubHu-rpr3sTx2iYz-QSJkLiALPYMOltJV0vHm3i8qNVVCaJnwyuVJbs1UzdrxPDdO3hsfr00oe132hOgGZPbQnxpuHc911aemOusdrcvnK55BvpBGJCgr5GUQYKZMJ5Dd5LBN7N2WO3AzX0iDnMR9n935a2bsNANhdh6xHYTThLmqQlb1ZcgoQE41znUrFfu-tXp-2htGFpY7b464ewOHCvSZLoC9F9ASIn4J2pMiDf8l4DxasnvanI9J2EwHL5tdAI2OgTICnXzdbjUuTNLmCD_QSR04ufXmYIOn7Y25E0Zb4eLkpgf1SskbXVXWUMjDZJiEyD4BcFMKTQ==",
]

8 changes: 6 additions & 2 deletions test/test_api.py
@@ -16,8 +16,9 @@
from dataclasses import dataclass
from unittest.mock import Mock

import nemo_run as run
import pytest

import nemo_run as run
from nemo_run.api import dryrun_fn


@@ -117,7 +118,10 @@ def test_dryrun_fn_with_executor(self, capsys, configured_fn):

captured = capsys.readouterr()
assert "Dry run for task test.test_api:some_fn" in captured.out
assert "LocalExecutor(packager=Packager(debug=False)" in captured.out
assert (
"LocalExecutor(packager=Packager(debug=False, symlink_from_remote_dir=None)"
in captured.out
)

def test_dryrun_fn_with_build(self, mocker, configured_fn):
build_mock = Mock()