Skip to content

Commit

Permalink
update handling of shared cluster creds
Browse files Browse the repository at this point in the history
  • Loading branch information
jlewitt1 authored and Alexandra Belousov committed Jan 21, 2025
1 parent f9f4421 commit ed5b09f
Show file tree
Hide file tree
Showing 19 changed files with 331 additions and 254 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/api-resources.rst
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ to notify them.
cpu_cluster.share(
users=["teammate1@email.com"],
access_level="write",
access_level="read",
)
Expand Down
3 changes: 2 additions & 1 deletion runhouse/resources/folders/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,12 +372,13 @@ def _to_cluster(self, dest_cluster, path=None):
def _cluster_to_cluster(self, dest_cluster, dest_path):
src_path = self.path

cluster_ssh_properties = self.system.ssh_properties
cluster_creds = self.system.creds_values

if not cluster_creds.get("password") and not dest_cluster.creds_values.get(
"password"
):
creds_file = cluster_creds.get("ssh_private_key")
creds_file = cluster_ssh_properties.get("ssh_private_key")
creds_cmd = f"-i '{creds_file}' " if creds_file else ""

dest_cluster.run_bash([f"mkdir -p {dest_path}"])
Expand Down
157 changes: 92 additions & 65 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from runhouse.resources.hardware.utils import (
_setup_creds_from_dict,
_setup_default_creds,
ClusterStatus,
get_clusters_from_den,
get_running_and_not_running_clusters,
Expand Down Expand Up @@ -143,7 +142,6 @@ def __init__(
ssh_properties: Dict = None,
den_auth: bool = False,
dryrun: bool = False,
skip_creds: bool = False,
image: Optional["Image"] = None,
**kwargs, # We have this here to ignore extra arguments when calling from from_config
):
Expand Down Expand Up @@ -178,10 +176,7 @@ def __init__(

self.reqs = []

if skip_creds and not creds:
self._creds = None
else:
self._setup_creds(creds)
self._setup_creds(creds)

if isinstance(image, dict):
# If reloading from config (ex: in Den)
Expand Down Expand Up @@ -249,7 +244,7 @@ def creds_values(self) -> Dict:
if not self._creds:
return {}

return {**self._creds.values, **self.ssh_properties}
return self._creds.values

@property
def docker_user(self) -> Optional[str]:
Expand Down Expand Up @@ -338,22 +333,32 @@ def _setup_creds(self, ssh_creds: Union[Dict, "Secret", str]):
if isinstance(ssh_creds, Secret):
self._creds = ssh_creds
return

elif isinstance(ssh_creds, str):
self._creds = Secret.from_name(ssh_creds)
self._creds = (
Secret.from_name(ssh_creds)
if rns_client.base_folder(ssh_creds) == rns_client.username
else None
)
return

if not ssh_creds:
from runhouse.resources.hardware.on_demand_cluster import OnDemandCluster
self._creds = self._setup_default_creds() if not self.is_shared else None

cluster_subtype = (
"OnDemandCluster" if isinstance(self, OnDemandCluster) else "Cluster"
)
self._creds = _setup_default_creds(cluster_subtype)
elif isinstance(ssh_creds, Dict):
creds, ssh_properties = _setup_creds_from_dict(ssh_creds, self.name)
self._creds = creds
self.ssh_properties = ssh_properties

def _setup_default_creds(self):
    """Build the default credentials from ``rns_client.default_ssh_key``.

    Returns:
        The ``Secret`` loaded by name via ``Secret.from_name`` using the value of
        ``rns_client.default_ssh_key``, or ``None`` when that value is ``None``.
    """
    # Imported locally, as in the original block.
    from runhouse.resources.secrets import Secret

    key_name = rns_client.default_ssh_key
    return None if key_name is None else Secret.from_name(key_name)

def _should_save_creds(self, folder: str = None) -> bool:
"""Checks whether to save the creds associated with the cluster.
Only do so as part of the save() if the user making the call is the creator"""
Expand Down Expand Up @@ -471,11 +476,6 @@ def config(self, condensed: bool = True):
else None
)
if creds:
if "loaded_secret_" in creds:
# User A shares a cluster with user B with "write" permissions. If user B then saves the cluster to Den,
# the loaded secret must NOT overwrite the original secret that was created and shared by
# user A.
creds = creds.replace("loaded_secret_", "")
config["creds"] = creds

if self._use_custom_certs:
Expand Down Expand Up @@ -554,22 +554,12 @@ def server_address(self):

@property
def is_shared(self) -> bool:
from runhouse import Secret

ssh_creds = self.creds_values
if not ssh_creds:
rns_address = self.rns_address
if rns_address is None:
return False

ssh_private_key = ssh_creds.get("ssh_private_key")
if ssh_private_key:
ssh_private_key_path = Path(ssh_private_key).expanduser()
secrets_base_dir = Path(Secret.DEFAULT_DIR).expanduser()

# Check if the key path is saved down in the local .rh directory, which we only do for shared credentials
if str(ssh_private_key_path).startswith(str(secrets_base_dir)):
return True
return f"{self._creds.name}/" in ssh_private_key
return False
# If the cluster is shared, the base directory of the rns address will differ from the current username
return rns_client.base_folder(rns_address) != rns_client.username

def _command_runner(
self, node: Optional[str] = None, use_docker_exec: Optional[bool] = False
Expand All @@ -595,16 +585,23 @@ def _command_runner(
(namespace, pod_name), docker_user=self.docker_user
)
else:
ssh_credentials = copy.copy(self.creds_values) or {}
ssh_control_name = ssh_credentials.pop(
ssh_properties = copy.copy(self.ssh_properties) or {}
ssh_control_name = ssh_properties.pop(
"ssh_control_name", f"{node}:{self.ssh_port}"
)
ssh_user = ssh_properties.get("ssh_user")
ssh_private_key = (
str(Path(ssh_properties.get("ssh_private_key")).expanduser())
if ssh_properties.get("ssh_private_key")
else None
)
ssh_proxy_command = ssh_properties.get("ssh_proxy_command")

runner = SkySSHRunner(
(node, self.ssh_port),
ssh_user=ssh_credentials.get("ssh_user"),
ssh_private_key=ssh_credentials.get("ssh_private_key"),
ssh_proxy_command=ssh_credentials.get("ssh_proxy_command"),
ssh_user=ssh_user,
ssh_private_key=ssh_private_key,
ssh_proxy_command=ssh_proxy_command,
ssh_control_name=ssh_control_name,
docker_user=self.docker_user if not use_docker_exec else None,
use_docker_exec=use_docker_exec,
Expand All @@ -631,6 +628,12 @@ def up_if_not(self, verbose: bool = True):
Example:
>>> rh.cluster("rh-cpu").up_if_not()
"""
if self.is_shared:
logger.warning(
"Cannot up a shared cluster. Only cluster owners can perform this operation."
)
return self

if not self.is_up():
self.up(verbose=verbose, force=False)
return self
Expand Down Expand Up @@ -1058,7 +1061,7 @@ def ssh_tunnel(
cloud = self.compute_properties.get("cloud")
return ssh_tunnel(
address=self.head_ip,
ssh_creds=self.creds_values,
ssh_properties=self.ssh_properties,
docker_user=self.docker_user,
local_port=local_port,
ssh_port=self.ssh_port,
Expand Down Expand Up @@ -1569,11 +1572,11 @@ def rsync(
subprocess.run(cmd, check=True, capture_output=not stream_logs, text=True)
return

ssh_credentials = copy.copy(self.creds_values) or {}
ssh_credentials.pop("ssh_host", node)
pwd = ssh_credentials.pop("password", None)
ssh_credentials.pop("private_key", None)
ssh_credentials.pop("public_key", None)
ssh_properties = copy.copy(self.ssh_properties) or {}
ssh_properties.pop("ssh_host", node)

creds_values = copy.copy(self.creds_values) or {}
pwd = creds_values.pop("password", None)

# If we're syncing between nodes on the cluster, we need to use the internal ip of the destination node
if src_node:
Expand Down Expand Up @@ -1680,7 +1683,7 @@ def _local_rsync(
# use internal ip of destination node to sync between cluster nodes w/o additional creds
dest_node_idx = self.ips.index(node)
dest_node_internal_ip = self.internal_ips[dest_node_idx]
ssh_user = self.creds_values.get("ssh_user")
ssh_user = self.ssh_properties.get("ssh_user")
node_destination = f"{ssh_user}@{dest_node_internal_ip}"

rsync_cmd = ["rsync", "-Pavz", "-e", "ssh"]
Expand All @@ -1707,12 +1710,12 @@ def ssh(self):
Example:
>>> rh.cluster("rh-cpu").ssh()
"""
creds = self.creds_values
ssh_properties = self.ssh_properties
_run_ssh_command(
address=self.head_ip,
ssh_user=creds["ssh_user"],
ssh_user=ssh_properties["ssh_user"],
ssh_port=self.ssh_port,
ssh_private_key=creds["ssh_private_key"],
ssh_private_key=ssh_properties["ssh_private_key"],
docker_user=self.docker_user,
)

Expand Down Expand Up @@ -2040,10 +2043,8 @@ def _run_commands_with_runner(

return_codes = []

ssh_credentials = copy.copy(self.creds_values)
pwd = ssh_credentials.pop("password", None)
ssh_credentials.pop("private_key", None)
ssh_credentials.pop("public_key", None)
creds_values = copy.copy(self.creds_values)
pwd = creds_values.pop("password", None)

runner = self._command_runner(
node=node, use_docker_exec=self.docker_user is not None
Expand Down Expand Up @@ -2424,23 +2425,45 @@ def share(
notify_users: bool = True,
headers: Optional[Dict] = None,
) -> Tuple[Dict[str, ResourceAccess], Dict[str, ResourceAccess]]:
"""Grant access to the cluster for a single user or list of users. By default, the user(s) will
receive an email notification of access (if they have a Runhouse account) or instructions on creating
an account to access the cluster. If ``visibility`` is set to ``public``, users will not be notified.
# save cluster and creds if not saved
self.save()
Args:
users (Union[str, list], optional): Single user or list of user emails and / or Runhouse account usernames.
If none are provided and ``visibility`` is set to ``public``, cluster will be made publicly
available to all users. (Default: ``None``)
access_level (:obj:`ResourceAccess`, optional): Access level to provide for the cluster.
Note that for clusters only read access is currently supported.
visibility (:obj:`ResourceVisibility`, optional): Type of visibility to provide for the shared
resource. By default, the visibility is private. (Default: ``None``)
notify_users (bool, optional): Whether to send an email notification to users who have been given access.
(Default: ``True``)
headers (Dict, optional): Request headers to provide for the request to Den. Contains the user's auth token.
Example: ``{"Authorization": f"Bearer {token}"}``
# share creds
logger.info(
"Sharing cluster credentials, which enables the recipient to SSH into the cluster."
)
if self._creds:
self._creds.share(
users=users,
access_level=access_level,
visibility=visibility,
notify_users=notify_users,
headers=headers,
Returns:
Tuple(Dict, Dict, Set):
`added_users`:
Users who already have a Runhouse account and have been granted access to the cluster.
`new_users`:
Users who do not have Runhouse accounts and received notifications via their emails.
`valid_users`:
Set of valid usernames and emails from ``users`` parameter.
Example:
>>> # Visibility will be set to private (users can search for and view resource in Den dashboard)
>>> cluster.share(users=["username1", "user2@gmail.com"])
"""
if access_level != ResourceAccess.READ:
raise ValueError(
f"Clusters can only be shared with read access, not {access_level}."
)

# save cluster in case it's not already saved
self.save()

# share cluster
return super().share(
users=users,
Expand All @@ -2459,7 +2482,11 @@ def _check_for_child_configs(cls, config: dict):
creds = config.pop("creds", None) or config.pop("ssh_creds", None)

if isinstance(creds, str):
creds = Secret.from_config(config=load_config(name=creds))
creds = (
Secret.from_config(config=load_config(name=creds))
if rns_client.base_folder(creds) == rns_client.username
else None
)
elif isinstance(creds, dict):
creds = Secret.from_config(creds)

Expand Down
Loading

0 comments on commit ed5b09f

Please sign in to comment.