Skip to content

Commit

Permalink
update handling of shared cluster creds
Browse files Browse the repository at this point in the history
  • Loading branch information
jlewitt1 committed Jan 28, 2025
1 parent 2817f9e commit cd1893f
Show file tree
Hide file tree
Showing 10 changed files with 106 additions and 176 deletions.
44 changes: 33 additions & 11 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def __init__(
self.reqs = []

self._setup_creds(creds)
self._is_shared = self._is_cluster_shared()

if isinstance(image, dict):
# If reloading from config (ex: in Den)
Expand Down Expand Up @@ -343,7 +344,7 @@ def _setup_creds(self, ssh_creds: Union[Dict, "Secret", str]):
return

if not ssh_creds:
self._creds = self._setup_default_creds() if not self.is_shared else None
self._creds = self._setup_default_creds()

elif isinstance(ssh_creds, Dict):
creds, ssh_properties = _setup_creds_from_dict(ssh_creds, self.name)
Expand All @@ -355,6 +356,10 @@ def _setup_default_creds(self):

default_ssh_key = rns_client.default_ssh_key
if default_ssh_key is None:
logger.warning(
"No default SSH key found in the local Runhouse config. To "
"create one please run `runhouse login`"
)
return None

return Secret.from_name(default_ssh_key)
Expand Down Expand Up @@ -552,15 +557,6 @@ def server_address(self):

return self.head_ip

@property
def is_shared(self) -> bool:
    """Whether this cluster lives under another user's folder.

    Returns:
        ``False`` when the cluster has no RNS address. Otherwise ``True`` exactly when the
        base folder of the RNS address differs from ``rns_client.username``.
    """
    address = self.rns_address
    # A shared cluster's RNS address is rooted in a folder other than the current username
    return address is not None and (
        rns_client.base_folder(address) != rns_client.username
    )

def _command_runner(
self, node: Optional[str] = None, use_docker_exec: Optional[bool] = False
) -> "CommandRunner":
Expand Down Expand Up @@ -628,7 +624,7 @@ def up_if_not(self, verbose: bool = True):
Example:
>>> rh.cluster("rh-cpu").up_if_not()
"""
if self.is_shared:
if self._is_shared:
logger.warning(
"Cannot up a shared cluster. Only cluster owners can perform this operation."
)
Expand All @@ -649,6 +645,32 @@ def keep_warm(self):
)
return self

def _is_cluster_shared(self) -> bool:
rns_address = self.rns_address
if rns_address is None:
return False

# If shared the creds used for launching will be attributed to another user
creds: str = (
self._resource_string_for_subconfig(self._creds, True)
if hasattr(self, "_creds") and self._creds
else None
)

if not creds or isinstance(creds, dict):
# specifying custom creds (for non on-demand clusters)
return False

cluster_folder = rns_client.base_folder(rns_address)
cluster_creds_folder = rns_client.base_folder(creds)

if cluster_folder != rns_client.username and (
creds and cluster_creds_folder != rns_client.username
):
return True

return False

def _sync_image_to_cluster(self, parallel: bool = True):
"""
Image stuff that needs to happen over SSH because the daemon won't be up yet, so we can't
Expand Down
8 changes: 7 additions & 1 deletion runhouse/resources/hardware/cluster_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, List, Optional, Union

from runhouse.globals import rns_client
from runhouse.globals import configs, rns_client

from runhouse.logger import get_logger
from runhouse.resources.hardware.cluster import Cluster
Expand Down Expand Up @@ -303,6 +303,12 @@ def ondemand_cluster(
>>> # Load cluster from above
>>> reloaded_cluster = rh.ondemand_cluster(name="rh-4-a100s")
"""
launcher = launcher.lower() if launcher else configs.launcher
if launcher not in LauncherType.strings():
raise ValueError(
f"Invalid launcher type '{launcher}'. Must be one of {LauncherType.strings()}."
)

if vpc_name and launcher == "local":
raise ValueError(
"Custom VPCs are not supported with local launching. To use a custom VPC, please use the "
Expand Down
108 changes: 17 additions & 91 deletions runhouse/resources/hardware/launcher_utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
import ast
import logging
import shutil
import sys
from pathlib import Path
from typing import Any, Optional

import requests

import runhouse as rh
from runhouse.constants import SSH_SKY_SECRET_NAME
from runhouse.globals import configs, rns_client
from runhouse.logger import get_logger
from runhouse.resources.hardware.utils import (
_cluster_set_autostop_command,
ClusterStatus,
SSEClient,
)
from runhouse.rns.utils.api import generate_ssh_keys, load_resp_content, read_resp_data
from runhouse.rns.utils.api import load_resp_content, read_resp_data
from runhouse.utils import ClusterLogsFormatter, ColoredFormatter, Spinner

logger = get_logger(__name__)
Expand Down Expand Up @@ -105,18 +103,6 @@ def keep_warm(cls, cluster, mins: int):
"""Abstract method for keeping a cluster warm."""
raise NotImplementedError

@classmethod
def load_creds(cls):
    """Load the SSH credentials resource required for the launcher.

    This implementation is a placeholder that provides no behavior of its own.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError

@classmethod
def _create_sky_secret(cls):
    """Generate a new set of Sky SSH keys (in the default Sky path).

    Returns:
        The secret produced by ``rh.provider_secret(provider="sky")`` once the key files
        have been (re)written.
    """
    # Imported locally: the helper is defined in ``runhouse.rns.utils.api`` and is not
    # brought into scope under this name by the module-level imports.
    from runhouse.rns.utils.api import generate_and_write_ssh_keys

    # force=True always writes a fresh pair; the returned paths are not needed here
    generate_and_write_ssh_keys(force=True)
    return rh.provider_secret(provider="sky")

@staticmethod
def supported_providers():
"""Return the base list of Sky supported providers."""
Expand Down Expand Up @@ -331,41 +317,25 @@ def teardown(cls, cluster, verbose: bool = True):

@classmethod
def load_creds(cls):
"""Loads the SSH credentials resource required for the Den launcher."""
"""Loads the SSH credentials required for the Den launcher, and for interacting with the cluster
once launched."""
default_ssh_key = rns_client.default_ssh_key
if default_ssh_key:
# try using the default SSH creds already set by the user
try:
secret = rh.Secret.from_name(default_ssh_key)
secret.write(overwrite=True)
return secret
except ValueError:
pass

if not default_ssh_key:
raise ValueError(
"No default SSH key found in the local Runhouse config, "
"please set one by running `runhouse login`"
)
try:
# Try using default Sky keys saved in Den
secret = rh.Secret.from_name(SSH_SKY_SECRET_NAME)
if secret.values:
if not default_ssh_key:
# use the Sky SSH keys as the default going forward
configs.set("default_ssh_key", secret.name)
logger.info(
f"Updated default SSH key in the local Runhouse config to {secret.name}"
)
secret.write(overwrite=True)
return secret
# Note: we still need to load them down locally to use for certain cluster operations (ex: rsync)
secret = rh.Secret.from_name(default_ssh_key)
if not Path(secret.path).expanduser().exists():
# Ensure this specific keypair is written down locally
secret.write()
logger.info(f"Saved default SSH key locally in path: {secret.path}")
except ValueError:
pass

# if none are found create a new Sky SSH pair, save it Den, and set it as the default
secret = cls._create_sky_secret()
secret.save()
logger.info(f"Saved new SSH key pair in Den with name: {secret.rns_address}")

if not default_ssh_key:
configs.set("default_ssh_key", secret.name)
logger.info(
f"Updated default SSH key in the local Runhouse config to {secret.name}"
raise ValueError(
"Failed to load default SSH key, "
"try re-saving by running `runhouse login`"
)

return secret
Expand Down Expand Up @@ -472,50 +442,6 @@ def keep_warm(cls, cluster, mins: int):
set_cluster_autostop_cmd = _cluster_set_autostop_command(mins)
cluster.run_bash_over_ssh([set_cluster_autostop_cmd], node=cluster.head_ip)

@classmethod
def load_creds(cls):
    """Load the SSH credentials resource required for the local launcher.

    Tries the "sky" provider secret first and returns it when it has values and its path is
    the default Sky key path; otherwise a new Sky key pair is created. If a ``ValueError``
    is raised along the way, falls back to the "ssh" provider secret: when that secret has
    values and both of its key files exist they are copied to the Sky key paths and loading
    is retried, otherwise a new Sky key pair is created.

    Returns:
        The secret to use for local launching.
    """
    private_key_path = "~/.ssh/sky-key"
    try:
        # Note: doesn't require the secret to be saved in Den, just the SSH keypair
        # existing in its local path
        secret = rh.provider_secret(provider="sky")
        if not (secret.path == private_key_path and secret.values):
            # Create a new sky key pair that will be saved to its default path
            secret = cls._create_sky_secret()
        return secret

    except ValueError:
        sky_key_target = Path(private_key_path).expanduser()
        sky_pub_target = Path("~/.ssh/sky-key.pub").expanduser()

        # if SSH creds already exist, use those for Sky too
        secret = rh.provider_secret(provider="ssh")
        default_private_key_path = Path(secret.path).expanduser()
        default_public_key_path = Path(f"{secret.path}.pub").expanduser()

        reusable_pair = (
            secret.values
            and default_private_key_path.exists()
            and default_public_key_path.exists()
        )
        if not reusable_pair:
            # if no SSH creds found create a new Sky keypair
            secret = cls._create_sky_secret()
            logger.info(
                f"Saved new SSH key pair for local launching in path: {secret.path}"
            )
            return secret

        # Re-use the same SSH creds for Sky, which need to be saved in a specific path for
        # Sky to recognize
        shutil.copy(default_private_key_path, sky_key_target)
        shutil.copy(default_public_key_path, sky_pub_target)
        logger.info(
            f"Using existing SSH keys for local launching found in path: {default_private_key_path}"
        )
        return cls.load_creds()

@staticmethod
def _set_docker_env_vars(image, task):
"""Helper method to set Docker login environment variables."""
Expand Down
10 changes: 4 additions & 6 deletions runhouse/resources/hardware/on_demand_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def _update_from_sky_status(self, dryrun: bool = False):
return

# Try to get the cluster status from SkyDB
if self.is_shared:
if self._is_shared:
# If the cluster is shared can ignore, since the sky data will only be saved on the machine where
# the cluster was initially upped
return
Expand All @@ -546,12 +546,10 @@ def _update_from_sky_status(self, dryrun: bool = False):
self._populate_connection_from_status_dict(cluster_dict)

def _setup_default_creds(self):
"""Setup the default creds used in launching. For Den launching we load the default ssh creds, and for
local launching we let Sky handle it."""
if self.launcher == LauncherType.DEN:
return DenLauncher.load_creds()
elif self.launcher == LauncherType.LOCAL:
return LocalLauncher.load_creds()
else:
raise ValueError(f"Invalid launcher '{self.launcher}'")

def get_instance_type(self):
"""Returns instance type of the cluster."""
Expand Down Expand Up @@ -643,7 +641,7 @@ def up(self, verbose: bool = True, force: bool = False, start_server: bool = Tru
if self.on_this_cluster():
return self

if self.is_shared:
if self._is_shared:
logger.warning(
"Cannot up a shared cluster. Only cluster owners can perform this operation."
)
Expand Down
4 changes: 4 additions & 0 deletions runhouse/resources/hardware/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class LauncherType(str, Enum):
LOCAL = "local"
DEN = "den"

@classmethod
def strings(cls):
    """Return a list with the ``value`` of each member of this class, in iteration order."""
    return [member.value for member in cls]


class RunhouseDaemonStatus(str, Enum):
RUNNING = "running"
Expand Down
59 changes: 0 additions & 59 deletions runhouse/rns/utils/api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import ast
import datetime
import functools
import json
import os
import uuid
from enum import Enum
from typing import Tuple

from requests import Response

Expand Down Expand Up @@ -89,63 +87,6 @@ def relative_file_path(file_path: str):
return relative_path


def generate_and_write_ssh_keys(force: bool = False) -> Tuple[str, str]:
    """Generate a 2048-bit RSA key pair and write it to the default Sky key paths.

    The private key is written to ``~/.ssh/sky-key`` (PEM, unencrypted) and the public key
    to ``~/.ssh/sky-key.pub`` (OpenSSH format). The parent directory is created if missing.

    Args:
        force: When ``False`` and the private key file already exists, nothing is written
            and the existing paths are returned. When ``True`` a new pair is always
            generated and written over any existing files.

    Returns:
        A ``(private_key_path, public_key_path)`` tuple of expanded paths.
    """
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa

    # Adapted from: https://github.com/skypilot-org/skypilot/blob/2c7419cdd22adb440bd37702be6f1ddc1c287684/sky/authentication.py#L106 # noqa
    private_key_path = os.path.expanduser("~/.ssh/sky-key")
    public_key_path = os.path.expanduser("~/.ssh/sky-key.pub")

    if not force and os.path.exists(private_key_path):
        return private_key_path, public_key_path

    key = rsa.generate_private_key(
        backend=default_backend(), public_exponent=65537, key_size=2048
    )

    private_key = (
        key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )
        .decode("utf-8")
        .strip()
    )

    public_key = (
        key.public_key()
        .public_bytes(
            serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH
        )
        .decode("utf-8")
        .strip()
    )

    key_dir = os.path.dirname(private_key_path)
    os.makedirs(key_dir, exist_ok=True, mode=0o700)

    with open(
        private_key_path,
        "w",
        encoding="utf-8",
        opener=functools.partial(os.open, mode=0o600),
    ) as f:
        # The mode given to os.open only applies when the file is created. When an existing
        # key file is being overwritten, tighten its permissions before any key material
        # is written to it.
        os.chmod(private_key_path, 0o600)
        f.write(private_key)

    with open(
        public_key_path,
        "w",
        encoding="utf-8",
        opener=functools.partial(os.open, mode=0o644),
    ) as f:
        f.write(public_key)

    return private_key_path, public_key_path


class ResourceAccess(str, Enum):
WRITE = "write"
READ = "read"
Expand Down
Loading

0 comments on commit cd1893f

Please sign in to comment.