Skip to content

Commit

Permalink
update handling of shared cluster creds
Browse files Browse the repository at this point in the history
  • Loading branch information
jlewitt1 committed Jan 16, 2025
1 parent 57c1672 commit 0304718
Show file tree
Hide file tree
Showing 12 changed files with 158 additions and 91 deletions.
1 change: 1 addition & 0 deletions .github/workflows/setup_rh_config/action.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ runs:
run: |
  mkdir ~/.rh && touch ~/.rh/config.yaml
  echo "default_folder: /${{ inputs.username }}" > ~/.rh/config.yaml
  # Append (>>) instead of truncating (>) so the default_folder entry written above is kept.
  echo "default_ssh_key: ssh-sky-key" >> ~/.rh/config.yaml
  echo "token: ${{ inputs.token }}" >> ~/.rh/config.yaml
  echo "username: ${{ inputs.username }}" >> ~/.rh/config.yaml
  echo "api_server_url: ${{ inputs.api_server_url }}" >> ~/.rh/config.yaml
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/api-resources.rst
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ to notify them.
cpu_cluster.share(
users=["teammate1@email.com"],
access_level="write",
access_level="read",
)
Expand Down
87 changes: 64 additions & 23 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
LOCALHOST,
NUM_PORTS_TO_TRY,
RESERVED_SYSTEM_NAMES,
SSH_SKY_SECRET_NAME,
)
from runhouse.globals import configs, obj_store, rns_client
from runhouse.logger import get_logger
Expand Down Expand Up @@ -303,7 +304,11 @@ def _setup_creds(self, ssh_creds: Union[Dict, "Secret", str]):
self._creds = ssh_creds
return
elif isinstance(ssh_creds, str):
self._creds = Secret.from_name(ssh_creds)
self._creds = (
Secret.from_name(ssh_creds)
if rns_client.base_folder(ssh_creds) == rns_client.username
else Secret.from_name(rns_client.default_ssh_key or SSH_SKY_SECRET_NAME)
)
return

if not ssh_creds:
Expand Down Expand Up @@ -433,11 +438,6 @@ def config(self, condensed: bool = True):
else None
)
if creds:
if "loaded_secret_" in creds:
# user A shares cluster with user B, with "write" permissions. If user B will save the cluster to Den, we
# would NOT like that the loaded secret will overwrite the original secret that was created and shared by
# user A.
creds = creds.replace("loaded_secret_", "")
config["creds"] = creds

if self._use_custom_certs:
Expand Down Expand Up @@ -561,12 +561,24 @@ def _command_runner(
ssh_control_name = ssh_credentials.pop(
"ssh_control_name", f"{node}:{self.ssh_port}"
)
ssh_user = ssh_credentials.get("ssh_user") or self.ssh_properties.get(
"ssh_user"
)
# Prefer the key from the credentials; fall back to the cluster's ssh_properties.
# The fallback is parenthesised on purpose: a conditional expression binds more loosely
# than `or`, so without the parentheses a key present in `ssh_credentials` would be
# discarded (result None) whenever `ssh_properties` has no "ssh_private_key" entry.
fallback_private_key = self.ssh_properties.get("ssh_private_key")
ssh_private_key = ssh_credentials.get("ssh_private_key") or (
    str(Path(fallback_private_key).expanduser()) if fallback_private_key else None
)
ssh_proxy_command = ssh_credentials.get(
"ssh_proxy_command"
) or self.ssh_properties.get("ssh_proxy_command")

runner = SkySSHRunner(
(node, self.ssh_port),
ssh_user=ssh_credentials.get("ssh_user"),
ssh_private_key=ssh_credentials.get("ssh_private_key"),
ssh_proxy_command=ssh_credentials.get("ssh_proxy_command"),
ssh_user=ssh_user,
ssh_private_key=ssh_private_key,
ssh_proxy_command=ssh_proxy_command,
ssh_control_name=ssh_control_name,
docker_user=self.docker_user if not use_docker_exec else None,
use_docker_exec=use_docker_exec,
Expand Down Expand Up @@ -2291,23 +2303,48 @@ def share(
notify_users: bool = True,
headers: Optional[Dict] = None,
) -> Tuple[Dict[str, ResourceAccess], Dict[str, ResourceAccess]]:
"""Grant access to the cluster for a list of users (or a single user). By default, the user(s) will
receive an email notification of access (if they have a Runhouse account) or instructions on creating
an account to access the cluster. If ``visibility`` is set to ``public``, users will not be notified.
# save cluster and creds if not saved
self.save()
.. note::
You can only grant access to other users if you have write access to the cluster.
# share creds
logger.info(
"Sharing cluster credentials, which enables the recipient to SSH into the cluster."
)
if self._creds:
self._creds.share(
users=users,
access_level=access_level,
visibility=visibility,
notify_users=notify_users,
headers=headers,
Args:
users (Union[str, list], optional): Single user or list of user emails and / or Runhouse account usernames.
If none are provided and ``visibility`` is set to ``public``, cluster will be made publicly
available to all users. (Default: ``None``)
access_level (:obj:`ResourceAccess`, optional): Access level to provide for the resource.
Note that for clusters only read access is currently supported.
visibility (:obj:`ResourceVisibility`, optional): Type of visibility to provide for the shared
resource. By default, the visibility is private. (Default: ``None``)
notify_users (bool, optional): Whether to send an email notification to users who have been given access.
(Default: ``True``)
headers (Dict, optional): Request headers to provide for the request to Den. Contains the user's auth token.
Example: ``{"Authorization": f"Bearer {token}"}``
Returns:
Tuple(Dict, Dict, Set):
`added_users`:
Users who already have a Runhouse account and have been granted access to the resource.
`new_users`:
Users who do not have Runhouse accounts and received notifications via their emails.
`valid_users`:
Set of valid usernames and emails from ``users`` parameter.
Example:
>>> # Visibility will be set to private (users can search for and view resource in Den dashboard)
>>> cluster.share(users=["username1", "user2@gmail.com"])
"""
if access_level != ResourceAccess.READ:
raise ValueError(
f"Clusters can only be shared with read access, not {access_level}."
)

# save cluster in case it's not already saved
self.save()

# share cluster
return super().share(
users=users,
Expand All @@ -2326,7 +2363,11 @@ def _check_for_child_configs(cls, config: dict):
creds = config.pop("creds", None) or config.pop("ssh_creds", None)

if isinstance(creds, str):
creds = Secret.from_config(config=load_config(name=creds))
creds = (
Secret.from_config(config=load_config(name=creds))
if rns_client.base_folder(creds) == rns_client.username
else Secret.from_name(rns_client.default_ssh_key or SSH_SKY_SECRET_NAME)
)
elif isinstance(creds, dict):
creds = Secret.from_config(creds)

Expand Down
85 changes: 50 additions & 35 deletions runhouse/resources/hardware/launcher_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import requests

import runhouse as rh
from runhouse.constants import SSH_SKY_SECRET_NAME
from runhouse.globals import configs, rns_client
from runhouse.logger import get_logger
from runhouse.resources.hardware.utils import (
Expand Down Expand Up @@ -98,41 +99,18 @@ def keep_warm(cls, cluster, mins: int):
"""Abstract method for keeping a cluster warm."""
raise NotImplementedError

@classmethod
def load_creds(cls):
"""Loads the SSH credentials resource required for the launcher."""
raise NotImplementedError

@staticmethod
def supported_providers():
    """List the entries of Sky's cloud registry (``sky.clouds.CLOUD_REGISTRY``).

    ``sky`` is imported inside the function, so the import only happens when this is called.
    """
    import sky

    registry = sky.clouds.CLOUD_REGISTRY
    return [*registry]

@classmethod
def sky_secret(cls):
from runhouse.constants import SSH_SKY_SECRET_NAME

try:
sky_secret = rh.secret(SSH_SKY_SECRET_NAME)
except ValueError:
# Create a new default key pair required for the Den launcher and save it to Den
from runhouse import provider_secret

default_ssh_path, _ = generate_ssh_keys()
logger.info(f"Saved new SSH key to path: {default_ssh_path} ")
sky_secret = provider_secret(
provider="sky", path=default_ssh_path, name=SSH_SKY_SECRET_NAME
)
sky_secret.save()

secret_values = sky_secret.values
if (
not secret_values
or "public_key" not in secret_values
or "private_key" not in secret_values
):
raise ValueError(
f"Public key and private key values not found in secret {sky_secret.name}"
)
return sky_secret

@classmethod
def run_verbose(
cls,
Expand Down Expand Up @@ -244,17 +222,11 @@ def keep_warm(cls, cluster, mins: int):
@classmethod
def up(cls, cluster, verbose: bool = True, force: bool = False):
"""Launch the cluster via Den."""
sky_secret = cls.sky_secret()
cluster._setup_creds(sky_secret)
cluster.save()

cluster_config = cluster.config()

payload = {
"cluster_config": {
**cluster_config,
"ssh_creds": sky_secret.rns_address,
},
"cluster_config": cluster_config,
"force": force,
"verbose": verbose,
"observability": configs.observability_enabled,
Expand Down Expand Up @@ -332,6 +304,32 @@ def teardown(cls, cluster, verbose: bool = True):
)
cluster._cluster_status = ClusterStatus.TERMINATED

@classmethod
def load_creds(cls):
    """Load the SSH credentials secret required for the Den launcher.

    Resolution order:
      1. The secret named by ``rns_client.default_ssh_key``, when one is configured
         and can be loaded by name.
      2. An ``ssh`` provider secret built from the local key at ``~/.ssh/id_rsa``.

    When no default key was configured, the fallback secret's name is stored as
    ``default_ssh_key`` in the local Runhouse config.

    Returns:
        The loaded secret resource.

    Raises:
        ValueError: If neither the configured default key nor the local fallback
            key yields any secret values.
    """
    default_ssh_key = rns_client.default_ssh_key
    if default_ssh_key:
        try:
            return rh.Secret.from_name(default_ssh_key)
        except ValueError as e:
            # Best-effort lookup: keep falling back to the local key, but leave a trace
            # instead of silently hiding why the configured key was not used.
            logger.warning(
                f"Failed to load default SSH key '{default_ssh_key}': {e}. "
                "Falling back to the local key at ~/.ssh/id_rsa."
            )

    # try loading the default SSH key
    secret = rh.provider_secret("ssh", path="~/.ssh/id_rsa")
    if not secret.values:
        raise ValueError(
            "No default SSH key found locally or in Den. Please run `runhouse login` to save one."
        )

    if not default_ssh_key:
        # Only record a default when none was configured, so an existing setting is never overwritten.
        configs.set("default_ssh_key", secret.name)
        logger.info(
            f"Updated default SSH key in the local Runhouse config to {secret.name}"
        )

    return secret


class LocalLauncher(Launcher):
"""Launcher APIs for operations handled locally via Sky."""
Expand Down Expand Up @@ -434,6 +432,23 @@ def keep_warm(cls, cluster, mins: int):
set_cluster_autostop_cmd = _cluster_set_autostop_command(mins)
cluster.run_bash_over_ssh([set_cluster_autostop_cmd], node=cluster.head_ip)

@classmethod
def load_creds(cls):
    """Return the Sky provider secret, creating and saving a new key pair if loading fails.

    ``rh.provider_secret("sky")`` is tried first. If it raises ``ValueError``, a new
    SSH key pair is generated, wrapped in a ``sky`` provider secret named
    ``SSH_SKY_SECRET_NAME``, saved, and returned.
    """
    try:
        existing_secret = rh.provider_secret("sky")
    except ValueError:
        # Create a new key pair required by Sky
        key_path, _ = generate_ssh_keys()
        new_secret = rh.provider_secret(
            provider="sky",
            path=key_path,
            name=SSH_SKY_SECRET_NAME,
        )
        new_secret.save()
        logger.info(f"Saved new Sky key pair locally in path: {key_path}")
        return new_secret
    return existing_secret

@staticmethod
def _set_docker_env_vars(image, task):
"""Helper method to set Docker login environment variables."""
Expand Down
9 changes: 9 additions & 0 deletions runhouse/resources/hardware/on_demand_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,15 @@ def image_id(self) -> str:
return self.image.image_id
return None

@property
def creds_values(self) -> Dict:
    """Return the values of the cluster's credentials, loading them lazily if unset.

    When ``self._creds`` is ``None``, the credentials are loaded through the launcher
    matching ``self.launcher`` (``DenLauncher`` or ``LocalLauncher``) and cached on
    the instance.

    Returns:
        The credential values, or an empty dict when no credentials could be
        resolved (e.g. the launcher type is unset or unrecognised).
    """
    if self._creds is None:
        if self.launcher == LauncherType.DEN:
            self._creds = DenLauncher.load_creds()
        elif self.launcher == LauncherType.LOCAL:
            self._creds = LocalLauncher.load_creds()

    # Neither branch may have run; avoid an AttributeError on `None.values`.
    # NOTE(review): assumes callers treat an empty dict as "no credentials" — confirm
    # against the base class's `creds_values`.
    if self._creds is None:
        return {}
    return self._creds.values

@property
def docker_user(self) -> str:
if self._docker_user:
Expand Down
4 changes: 2 additions & 2 deletions runhouse/resources/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,15 +331,15 @@ def share(
notify_users: bool = True,
headers: Dict = None,
) -> Tuple[Dict[str, ResourceAccess], Dict[str, ResourceAccess]]:
"""Grant access to the resource for a list of users (or a single user). By default, the user will
"""Grant access to the resource for a list of users (or a single user). By default, the user(s) will
receive an email notification of access (if they have a Runhouse account) or instructions on creating
an account to access the resource. If ``visibility`` is set to ``public``, users will not be notified.
.. note::
You can only grant access to other users if you have write access to the resource.
Args:
users (Union[str, list], optional): Single user or list of user emails and / or runhouse account usernames.
users (Union[str, list], optional): Single user or list of user emails and / or Runhouse account usernames.
If none are provided and ``visibility`` is set to ``public``, resource will be made publicly
available to all users. (Default: ``None``)
access_level (:obj:`ResourceAccess`, optional): Access level to provide for the resource.
Expand Down
11 changes: 1 addition & 10 deletions runhouse/resources/secrets/secret.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ def _write_shared_secret_to_local(config):
"ssh_private_key": str(private_key_path / "ssh-key"),
"ssh_user": new_creds_values.get("ssh_user"),
}
return rh.secret(
values=new_creds_values, name=f"loaded_secret_{config['name']}"
)
return rh.secret(values=new_creds_values, name=config["name"])

@staticmethod
def from_config(config: dict, dryrun: bool = False, _resolve_children: bool = True):
Expand All @@ -101,13 +99,6 @@ def from_config(config: dict, dryrun: bool = False, _resolve_children: bool = Tr
provider_class = _get_provider_class(config["provider"])
return provider_class.from_config(config, dryrun=dryrun)

# checks if the config is a of a shared secret
current_user = configs.username
owner_user = config["owner"]["username"] if "owner" in config.keys() else None

if owner_user and current_user != owner_user and config["values"]:
return Secret._write_shared_secret_to_local(config)

return Secret(**config, dryrun=dryrun)

@classmethod
Expand Down
6 changes: 3 additions & 3 deletions runhouse/rns/rns_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def grant_resource_access(
)
if resp.status_code != 200:
raise Exception(
f"Received [{resp.status_code}] from Den PUT '{uri}': Failed to grant access to: {load_resp_content(resp)}"
f"Received [{resp.status_code}] from Den PUT '{uri}': Failed to grant access to resource {rns_address}"
)

resp_data: dict = read_resp_data(resp)
Expand Down Expand Up @@ -524,11 +524,11 @@ def _save_config_in_rns(self, config, resource_name):
)
if resp.status_code != 200:
raise Exception(
f"Received [{resp.status_code}] from Den POST '{post_uri}': Failed to create new resource '{resource_uri}': {load_resp_content(resp)}"
f"Received [{resp.status_code}] from Den POST '{post_uri}': Failed to create new resource: {load_resp_content(resp)}"
)
else:
raise Exception(
f"Received [{resp.status_code}] from Den PUT '{put_uri}': Failed to save resource '{resource_uri}': {load_resp_content(resp)}"
f"Received [{resp.status_code}] from Den PUT '{put_uri}': Failed to save resource: {load_resp_content(resp)}"
)

def delete_configs(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_resources/test_clusters/cluster_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_cluster_config(local_launched_ondemand_aws_docker_cluster):
def test_cluster_sharing(local_launched_ondemand_aws_docker_cluster):
local_launched_ondemand_aws_docker_cluster.share(
users=["donny@run.house", "josh@run.house"],
access_level="write",
access_level="read",
notify_users=False,
)
assert True
Expand Down
Loading

0 comments on commit 0304718

Please sign in to comment.