From cd1893f4b84ed898deb3852610ea761c70cc9a7a Mon Sep 17 00:00:00 2001
From: jlewitt1
Date: Wed, 15 Jan 2025 00:48:14 +0200
Subject: [PATCH] update handling of shared cluster creds

---
 runhouse/resources/hardware/cluster.py        |  44 +++++--
 .../resources/hardware/cluster_factory.py     |   8 +-
 runhouse/resources/hardware/launcher_utils.py | 108 +++---------------
 .../resources/hardware/on_demand_cluster.py   |  10 +-
 runhouse/resources/hardware/utils.py          |   4 +
 runhouse/rns/utils/api.py                     |  59 ----------
 tests/fixtures/on_demand_cluster_fixtures.py  |  23 +++-
 tests/fixtures/static_cluster_fixtures.py     |  16 ++-
 tests/fixtures/utils.py                       |   8 ++
 .../test_clusters/test_cluster.py             |   2 -
 10 files changed, 106 insertions(+), 176 deletions(-)

diff --git a/runhouse/resources/hardware/cluster.py b/runhouse/resources/hardware/cluster.py
index 6fc6e2ee3..d42f14088 100644
--- a/runhouse/resources/hardware/cluster.py
+++ b/runhouse/resources/hardware/cluster.py
@@ -177,6 +177,7 @@ def __init__(
         self.reqs = []
 
         self._setup_creds(creds)
+        self._is_shared = self._is_cluster_shared()
 
         if isinstance(image, dict):
             # If reloading from config (ex: in Den)
@@ -343,7 +344,7 @@ def _setup_creds(self, ssh_creds: Union[Dict, "Secret", str]):
             return
 
         if not ssh_creds:
-            self._creds = self._setup_default_creds() if not self.is_shared else None
+            self._creds = self._setup_default_creds()
 
         elif isinstance(ssh_creds, Dict):
             creds, ssh_properties = _setup_creds_from_dict(ssh_creds, self.name)
@@ -355,6 +356,10 @@ def _setup_default_creds(self):
         default_ssh_key = rns_client.default_ssh_key
 
         if default_ssh_key is None:
+            logger.warning(
+                "No default SSH key found in the local Runhouse config. To "
+                "create one, please run `runhouse login`"
+            )
             return None
 
         return Secret.from_name(default_ssh_key)
@@ -552,15 +557,6 @@ def server_address(self):
 
         return self.head_ip
 
-    @property
-    def is_shared(self) -> bool:
-        rns_address = self.rns_address
-        if rns_address is None:
-            return False
-
-        # If the cluster is shared, the base directory of the rns address will differ from the current username
-        return rns_client.base_folder(rns_address) != rns_client.username
-
     def _command_runner(
         self, node: Optional[str] = None, use_docker_exec: Optional[bool] = False
     ) -> "CommandRunner":
@@ -628,7 +624,7 @@ def up_if_not(self, verbose: bool = True):
         Example:
             >>> rh.cluster("rh-cpu").up_if_not()
         """
-        if self.is_shared:
+        if self._is_shared:
             logger.warning(
                 "Cannot up a shared cluster. Only cluster owners can perform this operation."
             )
@@ -649,6 +645,32 @@ def keep_warm(self):
         )
         return self
 
+    def _is_cluster_shared(self) -> bool:
+        rns_address = self.rns_address
+        if rns_address is None:
+            return False
+
+        # If shared, the creds used for launching will be attributed to another user
+        creds: str = (
+            self._resource_string_for_subconfig(self._creds, True)
+            if hasattr(self, "_creds") and self._creds
+            else None
+        )
+
+        if not creds or isinstance(creds, dict):
+            # specifying custom creds (for non on-demand clusters)
+            return False
+
+        cluster_folder = rns_client.base_folder(rns_address)
+        cluster_creds_folder = rns_client.base_folder(creds)
+
+        if cluster_folder != rns_client.username and (
+            creds and cluster_creds_folder != rns_client.username
+        ):
+            return True
+
+        return False
+
     def _sync_image_to_cluster(self, parallel: bool = True):
         """
         Image stuff that needs to happen over SSH because the daemon won't be up yet, so we can't
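The new _is_cluster_shared check above boils down to comparing the Den folder of the cluster's rns_address with the folder that owns the creds it was saved with: if neither matches the current username, the cluster is treated as shared and launch operations are skipped. A rough standalone sketch of that comparison, using made-up addresses and a simplified stand-in for rns_client.base_folder (both are illustrative, not part of the patch):

    # Simplified stand-in for rns_client.base_folder, assuming addresses like "/<folder>/<name>"
    def base_folder(rns_address: str) -> str:
        return rns_address.strip("/").split("/")[0]

    current_username = "my-user"                      # hypothetical local Den username
    cluster_address = "/team-a/shared-gpu-cluster"    # cluster saved under another user's folder
    creds_address = "/team-a/team-a-ssh-secret"       # creds saved under that same folder

    is_shared = (
        base_folder(cluster_address) != current_username
        and base_folder(creds_address) != current_username
    )
    print(is_shared)  # True -> up()/up_if_not() would warn and return early
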
diff --git a/runhouse/resources/hardware/cluster_factory.py b/runhouse/resources/hardware/cluster_factory.py
index 866b5b9e3..92850c158 100644
--- a/runhouse/resources/hardware/cluster_factory.py
+++ b/runhouse/resources/hardware/cluster_factory.py
@@ -1,6 +1,6 @@
 from typing import Dict, List, Optional, Union
 
-from runhouse.globals import rns_client
+from runhouse.globals import configs, rns_client
 from runhouse.logger import get_logger
 
 from runhouse.resources.hardware.cluster import Cluster
@@ -303,6 +303,12 @@ def ondemand_cluster(
         >>> # Load cluster from above
         >>> reloaded_cluster = rh.ondemand_cluster(name="rh-4-a100s")
     """
+    launcher = launcher.lower() if launcher else configs.launcher
+    if launcher not in LauncherType.strings():
+        raise ValueError(
+            f"Invalid launcher type '{launcher}'. Must be one of {LauncherType.strings()}."
+        )
+
    if vpc_name and launcher == "local":
         raise ValueError(
             "Custom VPCs are not supported with local launching. To use a custom VPC, please use the "
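The factory change above lower-cases the requested launcher, falls back to the launcher set in the local Runhouse config when none is passed, and validates the result against LauncherType.strings() up front rather than later inside the cluster class. A hedged sketch of the resulting behavior (the cluster name and instance type are illustrative):

    import runhouse as rh

    # An unsupported launcher should now fail fast with a clear error:
    try:
        rh.ondemand_cluster(name="example-cpu", instance_type="CPU:4", launcher="slurm")
    except ValueError as err:
        print(err)  # Invalid launcher type 'slurm'. Must be one of ['local', 'den'].
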
diff --git a/runhouse/resources/hardware/launcher_utils.py b/runhouse/resources/hardware/launcher_utils.py
index 5052d5f25..643c4c2ea 100644
--- a/runhouse/resources/hardware/launcher_utils.py
+++ b/runhouse/resources/hardware/launcher_utils.py
@@ -1,6 +1,5 @@
 import ast
 import logging
-import shutil
 import sys
 from pathlib import Path
 from typing import Any, Optional
@@ -8,7 +7,6 @@
 import requests
 
 import runhouse as rh
-from runhouse.constants import SSH_SKY_SECRET_NAME
 from runhouse.globals import configs, rns_client
 from runhouse.logger import get_logger
 from runhouse.resources.hardware.utils import (
@@ -16,7 +14,7 @@
     ClusterStatus,
     SSEClient,
 )
-from runhouse.rns.utils.api import generate_ssh_keys, load_resp_content, read_resp_data
+from runhouse.rns.utils.api import load_resp_content, read_resp_data
 from runhouse.utils import ClusterLogsFormatter, ColoredFormatter, Spinner
 
 logger = get_logger(__name__)
@@ -105,18 +103,6 @@ def keep_warm(cls, cluster, mins: int):
         """Abstract method for keeping a cluster warm."""
         raise NotImplementedError
 
-    @classmethod
-    def load_creds(cls):
-        """Loads the SSH credentials resource required for the launcher."""
-        raise NotImplementedError
-
-    @classmethod
-    def _create_sky_secret(cls):
-        """Generate a new set of Sky SSH keys (in the default Sky path)."""
-        private_key_path, _ = generate_and_write_ssh_keys(force=True)
-        secret = rh.provider_secret(provider="sky")
-        return secret
-
     @staticmethod
     def supported_providers():
         """Return the base list of Sky supported providers."""
@@ -331,41 +317,25 @@ def teardown(cls, cluster, verbose: bool = True):
 
     @classmethod
     def load_creds(cls):
-        """Loads the SSH credentials resource required for the Den launcher."""
+        """Loads the SSH credentials required for the Den launcher, and for interacting with the cluster
+        once launched."""
         default_ssh_key = rns_client.default_ssh_key
-        if default_ssh_key:
-            # try using the default SSH creds already set by the user
-            try:
-                secret = rh.Secret.from_name(default_ssh_key)
-                secret.write(overwrite=True)
-                return secret
-            except ValueError:
-                pass
-
+        if not default_ssh_key:
+            raise ValueError(
+                "No default SSH key found in the local Runhouse config, "
+                "please set one by running `runhouse login`"
+            )
         try:
-            # Try using default Sky keys saved in Den
-            secret = rh.Secret.from_name(SSH_SKY_SECRET_NAME)
-            if secret.values:
-                if not default_ssh_key:
-                    # use the Sky SSH keys as the default going forward
-                    configs.set("default_ssh_key", secret.name)
-                    logger.info(
-                        f"Updated default SSH key in the local Runhouse config to {secret.name}"
-                    )
-                secret.write(overwrite=True)
-                return secret
+            # Note: we still need to load the keys locally to use for certain cluster operations (ex: rsync)
+            secret = rh.Secret.from_name(default_ssh_key)
+            if not Path(secret.path).expanduser().exists():
+                # Ensure this specific keypair is written down locally
+                secret.write()
+                logger.info(f"Saved default SSH key locally in path: {secret.path}")
         except ValueError:
-            pass
-
-        # if none are found create a new Sky SSH pair, save it Den, and set it as the default
-        secret = cls._create_sky_secret()
-        secret.save()
-        logger.info(f"Saved new SSH key pair in Den with name: {secret.rns_address}")
-
-        if not default_ssh_key:
-            configs.set("default_ssh_key", secret.name)
-            logger.info(
-                f"Updated default SSH key in the local Runhouse config to {secret.name}"
+            raise ValueError(
+                "Failed to load default SSH key, "
+                "try re-saving by running `runhouse login`"
             )
 
         return secret
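DenLauncher.load_creds is now strict: it requires a default_ssh_key entry in the local Runhouse config and only writes that keypair to disk if it is not already there. A hedged sketch of wiring this up manually, mirroring the save_default_ssh_creds test helper added later in this patch (the secret provider and config key come from the patch; the rest is assumed):

    import runhouse as rh

    # Save the locally available Sky SSH keypair to Den as a secret ...
    sky_creds = rh.provider_secret("sky").save()

    # ... and point the local Runhouse config at it, which is what
    # `runhouse login` would normally take care of.
    rh.configs.set("default_ssh_key", sky_creds.name)
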
@@ -472,50 +442,6 @@ def keep_warm(cls, cluster, mins: int):
         set_cluster_autostop_cmd = _cluster_set_autostop_command(mins)
         cluster.run_bash_over_ssh([set_cluster_autostop_cmd], node=cluster.head_ip)
 
-    @classmethod
-    def load_creds(cls):
-        """Loads the SSH credentials resource required for the local launcher."""
-        private_key_path = "~/.ssh/sky-key"
-        try:
-            # Note: doesn't require the secret to be saved in Den, just the SSH keypair existing in its local path
-            secret = rh.provider_secret(provider="sky")
-            if secret.path == private_key_path and secret.values:
-                return secret
-
-            # Create a new sky key pair that will be saved to its default path
-            secret = cls._create_sky_secret()
-            return secret
-
-        except ValueError:
-            sky_private_key = Path(private_key_path).expanduser()
-            sky_public_key = Path("~/.ssh/sky-key.pub").expanduser()
-
-            # if SSH creds already exist, use those for Sky too
-            secret = rh.provider_secret(provider="ssh")
-            default_private_key_path = Path(secret.path).expanduser()
-            default_public_key_path = Path(f"{secret.path}.pub").expanduser()
-
-            if (
-                secret.values
-                and default_private_key_path.exists()
-                and default_public_key_path.exists()
-            ):
-                # Re-use the same SSH creds for Sky, which need to be saved in a specific path for Sky to recognize
-                shutil.copy(default_private_key_path, sky_private_key)
-                shutil.copy(default_public_key_path, sky_public_key)
-                logger.info(
-                    f"Using existing SSH keys for local launching found in path: {default_private_key_path}"
-                )
-                return cls.load_creds()
-            else:
-                # if no SSH creds found create a new Sky keypair
-                secret = cls._create_sky_secret()
-                logger.info(
-                    f"Saved new SSH key pair for local launching in path: {secret.path}"
-                )
-
-            return secret
-
     @staticmethod
     def _set_docker_env_vars(image, task):
         """Helper method to set Docker login environment variables."""
diff --git a/runhouse/resources/hardware/on_demand_cluster.py b/runhouse/resources/hardware/on_demand_cluster.py
index 9500bf915..2b207aa6d 100644
--- a/runhouse/resources/hardware/on_demand_cluster.py
+++ b/runhouse/resources/hardware/on_demand_cluster.py
@@ -537,7 +537,7 @@ def _update_from_sky_status(self, dryrun: bool = False):
             return
 
         # Try to get the cluster status from SkyDB
-        if self.is_shared:
+        if self._is_shared:
             # If the cluster is shared can ignore, since the sky data will only be saved on the machine where
             # the cluster was initially upped
             return
@@ -546,12 +546,10 @@
             self._populate_connection_from_status_dict(cluster_dict)
 
     def _setup_default_creds(self):
+        """Set up the default creds used for launching. For Den launching we load the default SSH creds, and for
+        local launching we let Sky handle it."""
         if self.launcher == LauncherType.DEN:
             return DenLauncher.load_creds()
-        elif self.launcher == LauncherType.LOCAL:
-            return LocalLauncher.load_creds()
-        else:
-            raise ValueError(f"Invalid launcher '{self.launcher}'")
 
     def get_instance_type(self):
         """Returns instance type of the cluster."""
@@ -643,7 +641,7 @@ def up(self, verbose: bool = True, force: bool = False, start_server: bool = Tru
         if self.on_this_cluster():
             return self
 
-        if self.is_shared:
+        if self._is_shared:
             logger.warning(
                 "Cannot up a shared cluster. Only cluster owners can perform this operation."
             )
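With the LocalLauncher.load_creds path removed, _setup_default_creds only resolves creds when launching through Den; for local launching it returns None and SkyPilot's own key management (the ~/.ssh/sky-key pair referenced in the removed code) takes over. A condensed, hypothetical sketch of the resulting branching, not the full method:

    from runhouse.resources.hardware.utils import LauncherType

    def resolve_default_creds(launcher, den_load_creds):
        # Den launching: look up the default SSH key saved in the local config / Den.
        if launcher == LauncherType.DEN:
            return den_load_creds()
        # Local launching: nothing to resolve here; Sky creates and uses ~/.ssh/sky-key itself.
        return None
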
diff --git a/runhouse/resources/hardware/utils.py b/runhouse/resources/hardware/utils.py
index 91a1990b7..a133eaaf5 100644
--- a/runhouse/resources/hardware/utils.py
+++ b/runhouse/resources/hardware/utils.py
@@ -50,6 +50,10 @@ class LauncherType(str, Enum):
     LOCAL = "local"
     DEN = "den"
 
+    @classmethod
+    def strings(cls):
+        return [s.value for s in cls]
+
 
 class RunhouseDaemonStatus(str, Enum):
     RUNNING = "running"
diff --git a/runhouse/rns/utils/api.py b/runhouse/rns/utils/api.py
index 8db7dc111..6d8904fca 100644
--- a/runhouse/rns/utils/api.py
+++ b/runhouse/rns/utils/api.py
@@ -1,11 +1,9 @@
 import ast
 import datetime
-import functools
 import json
 import os
 import uuid
 from enum import Enum
-from typing import Tuple
 
 from requests import Response
 
@@ -89,63 +87,6 @@ def relative_file_path(file_path: str):
     return relative_path
 
 
-def generate_and_write_ssh_keys(force: bool = False) -> Tuple[str, str]:
-    from cryptography.hazmat.backends import default_backend
-    from cryptography.hazmat.primitives import serialization
-    from cryptography.hazmat.primitives.asymmetric import rsa
-
-    # Adapted from: https://github.com/skypilot-org/skypilot/blob/2c7419cdd22adb440bd37702be6f1ddc1c287684/sky/authentication.py#L106 # noqa
-    private_key_path = os.path.expanduser("~/.ssh/sky-key")
-    public_key_path = os.path.expanduser("~/.ssh/sky-key.pub")
-
-    if not force and os.path.exists(private_key_path):
-        return private_key_path, public_key_path
-
-    key = rsa.generate_private_key(
-        backend=default_backend(), public_exponent=65537, key_size=2048
-    )
-
-    private_key = (
-        key.private_bytes(
-            encoding=serialization.Encoding.PEM,
-            format=serialization.PrivateFormat.TraditionalOpenSSL,
-            encryption_algorithm=serialization.NoEncryption(),
-        )
-        .decode("utf-8")
-        .strip()
-    )
-
-    public_key = (
-        key.public_key()
-        .public_bytes(
-            serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH
-        )
-        .decode("utf-8")
-        .strip()
-    )
-
-    key_dir = os.path.dirname(private_key_path)
-    os.makedirs(key_dir, exist_ok=True, mode=0o700)
-
-    with open(
-        private_key_path,
-        "w",
-        encoding="utf-8",
-        opener=functools.partial(os.open, mode=0o600),
-    ) as f:
-        f.write(private_key)
-
-    with open(
-        public_key_path,
-        "w",
-        encoding="utf-8",
-        opener=functools.partial(os.open, mode=0o644),
-    ) as f:
-        f.write(public_key)
-
-    return private_key_path, public_key_path
-
-
 class ResourceAccess(str, Enum):
     WRITE = "write"
     READ = "read"
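LauncherType.strings() is a small convenience so callers (like the factory validation above) can check user input against the enum values directly, and generate_and_write_ssh_keys is removed along with the launcher code paths that called it. A quick, assumed illustration of what the new helper returns:

    from runhouse.resources.hardware.utils import LauncherType

    print(LauncherType.strings())            # expected: ['local', 'den']
    print("den" in LauncherType.strings())   # expected: True
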
diff --git a/tests/fixtures/on_demand_cluster_fixtures.py b/tests/fixtures/on_demand_cluster_fixtures.py
index 28a037fa0..c6d0a9ede 100644
--- a/tests/fixtures/on_demand_cluster_fixtures.py
+++ b/tests/fixtures/on_demand_cluster_fixtures.py
@@ -23,9 +23,9 @@ def restart_server(request):
 
 def setup_test_cluster(args, request, test_rns_folder, setup_base=False):
     if request.config.getoption("--ci"):
-        rh.constants.SSH_SKY_SECRET_NAME = (
-            f"{test_rns_folder}-{rh.constants.SSH_SKY_SECRET_NAME}"
-        )
+        sky_key_name = f"{test_rns_folder}-{rh.constants.SSH_SKY_SECRET_NAME}"
+        rh.constants.SSH_SKY_SECRET_NAME = sky_key_name
+
     cluster = rh.ondemand_cluster(**args)
     init_args[id(cluster)] = args
     cluster.up_if_not()
@@ -131,6 +131,11 @@ def den_launched_ondemand_aws_docker_cluster(request, test_rns_folder):
         "launcher": LauncherType.DEN,
     }
 
+    if request.config.getoption("--ci"):
+        from tests.fixtures.utils import save_default_ssh_creds
+
+        save_default_ssh_creds()
+
     cluster = setup_test_cluster(
         args, request, setup_base=True, test_rns_folder=test_rns_folder
     )
@@ -240,6 +245,12 @@ def den_launched_ondemand_aws_k8s_cluster(request, test_rns_folder):
         "launcher": LauncherType.DEN,
         "context": os.getenv("EKS_ARN"),
     }
+
+    if request.config.getoption("--ci"):
+        from tests.fixtures.utils import save_default_ssh_creds
+
+        save_default_ssh_creds()
+
     cluster = setup_test_cluster(args, request, test_rns_folder=test_rns_folder)
     yield cluster
     if not request.config.getoption("--detached"):
@@ -264,6 +275,12 @@ def den_launched_ondemand_gcp_k8s_cluster(request, test_rns_folder):
         "launcher": LauncherType.DEN,
         "context": "gke_testing",
     }
+
+    if request.config.getoption("--ci"):
+        from tests.fixtures.utils import save_default_ssh_creds
+
+        save_default_ssh_creds()
+
     cluster = setup_test_cluster(args, request, test_rns_folder=test_rns_folder)
     yield cluster
     if not request.config.getoption("--detached"):
diff --git a/tests/fixtures/static_cluster_fixtures.py b/tests/fixtures/static_cluster_fixtures.py
index 92e8a64c7..3e47e903c 100644
--- a/tests/fixtures/static_cluster_fixtures.py
+++ b/tests/fixtures/static_cluster_fixtures.py
@@ -21,9 +21,8 @@ def setup_static_cluster(
     compute_type: computeType = computeType.cpu,
     remote: bool = False,  # whether the fixture is used on a remote-running test or on a local one.
 ):
-    rh.constants.SSH_SKY_SECRET_NAME = (
-        f"{test_rns_folder}-{rh.constants.SSH_SKY_SECRET_NAME}"
-    )
+    sky_key_name = f"{test_rns_folder}-{rh.constants.SSH_SKY_SECRET_NAME}"
+    rh.constants.SSH_SKY_SECRET_NAME = sky_key_name
     instance_type = "CPU:4" if compute_type == computeType.cpu else "g5.xlarge"
     launcher = launcher if launcher else LauncherType.LOCAL
     cluster_name = (
@@ -31,6 +30,7 @@
         if not remote
         else f"{test_rns_folder}-{launcher}-aws-{compute_type}-password"
     )
+
     cluster = rh.cluster(
         name=cluster_name,
         instance_type=instance_type,
@@ -97,6 +97,11 @@ def static_cpu_pwd_cluster(request, test_rns_folder):
 @pytest.fixture(scope="session")
 def static_cpu_pwd_cluster_den_launcher(request, test_rns_folder):
+    if request.config.getoption("--ci"):
+        from tests.fixtures.utils import save_default_ssh_creds
+
+        save_default_ssh_creds()
+
     cluster = setup_static_cluster(
         launcher=LauncherType.DEN,
         test_rns_folder=test_rns_folder,
     )
@@ -113,6 +118,11 @@ def static_cpu_pwd_cluster_den_launcher(request, test_rns_folder):
 
 @pytest.fixture(scope="session")
 def static_gpu_pwd_cluster_den_launcher(request, test_rns_folder):
+    if request.config.getoption("--ci"):
+        from tests.fixtures.utils import save_default_ssh_creds
+
+        save_default_ssh_creds()
+
     cluster = setup_static_cluster(
         launcher=LauncherType.DEN,
         compute_type=computeType.gpu,
diff --git a/tests/fixtures/utils.py b/tests/fixtures/utils.py
index 79dc8354d..b40728943 100644
--- a/tests/fixtures/utils.py
+++ b/tests/fixtures/utils.py
@@ -12,3 +12,11 @@ def create_gcs_bucket(bucket_name: str):
     gcs_store = GcsStore(name=bucket_name, source="")
 
     return gcs_store
+
+
+def save_default_ssh_creds():
+    """Save default creds required by the Den launcher."""
+    import runhouse as rh
+
+    sky_creds = rh.provider_secret("sky").save()
+    rh.configs.set("default_ssh_key", sky_creds.name)
config.pop("creds", None) - assert sky_secret in orig_creds or generated_secret in orig_creds assert new_config == config