diff --git a/mttl/models/library/backend_engine.py b/mttl/models/library/backend_engine.py
index e2fa9d464..0fae961f8 100644
--- a/mttl/models/library/backend_engine.py
+++ b/mttl/models/library/backend_engine.py
@@ -137,6 +137,7 @@ def __init__(self, token: Optional[str] = None, cache_dir: Optional[str] = None)
         self.cache_dir = cache_dir
         # Quiet down the azure logging
         logging.getLogger("azure").setLevel(logging.WARNING)
+        self.last_modified_cache = None
 
     @property
     def cache_dir(self):
@@ -200,16 +201,27 @@ def _get_container_client(self, repo_id, use_async=False):
         storage_uri, container = self._parse_repo_id_to_storage_info(repo_id)
         return self._get_blob_client(repo_id, use_async).get_container_client(container)
 
-    def _last_modified(self, repo_id: str) -> datetime.datetime:
+    def _last_modified(
+        self, repo_id: str, set_cache: bool = False
+    ) -> datetime.datetime:
         """Get the last modified date of a repository."""
-        try:
-            return (
-                self._get_container_client(repo_id)
-                .get_container_properties()
-                .last_modified
-            )
-        except ResourceNotFoundError as error:
-            raise ValueError(f"Repository {repo_id} not found") from error
+
+        # if a cached value exists, return it; this avoids repeated calls to the API
+        if self.last_modified_cache:
+            return self.last_modified_cache
+
+        else:
+            try:
+                last_modified = (
+                    self._get_container_client(repo_id)
+                    .get_container_properties()
+                    .last_modified
+                )
+                if set_cache:
+                    self.last_modified_cache = last_modified
+                return last_modified
+            except ResourceNotFoundError as error:
+                raise ValueError(f"Repository {repo_id} not found") from error
 
     def get_repository_cache_dir(self, repo_id: str) -> Path:
         """Get the cache directory for a repository. If it doesn't exist, create it.
@@ -280,23 +292,17 @@ def delete_repo(self, repo_id, repo_type=None):
         except ResourceNotFoundError:
             logger.info(f"Container {repo_id} not found.")
 
-    def create_commit(self, repo_id, operations, commit_message="", async_mode=False):
+    def create_commit(self, repo_id, operations, commit_message="", async_mode=True):
         asyncio.run(
             self.async_create_commit(repo_id, operations, async_mode=async_mode)
         )
 
     async def async_create_commit(self, repo_id, operations, async_mode=False):
         tasks = []
+        upload_batch = []
         for op in operations:
            if isinstance(op, CommitOperationAdd):
-                tasks.append(
-                    self._async_upload_blob(
-                        repo_id=repo_id,
-                        filename=op.path_in_repo,
-                        buffer=op.path_or_fileobj,
-                        overwrite=True,
-                    )
-                )
+                upload_batch.append(op)
             elif isinstance(op, CommitOperationCopy):
                 tasks.append(
                     self._async_copy_blob(
@@ -314,11 +320,14 @@ async def async_create_commit(self, repo_id, operations, async_mode=False):
                         filename=op.path_in_repo,
                     )
                 )
-        if async_mode:
-            await asyncio.gather(*tasks)
-        else:
-            for task in tasks:
-                await task
+
+        # run pending copy/delete operations, then upload the added blobs in one async batch
+        await asyncio.gather(*tasks)
+        await self.async_upload_blobs(
+            repo_id,
+            filenames=[op.path_in_repo for op in upload_batch],
+            buffers=[op.path_or_fileobj for op in upload_batch],
+        )
 
     def preupload_lfs_files(self, repo_id, additions):
         # for blob storage, these operations are done in create_commit
@@ -390,28 +398,41 @@ async def async_upload_blobs(
         else:
             if len(buffers) != len(filenames):
                 raise ValueError("Filenames and buffers must have the same length.")
-        tasks = [
-            self._async_upload_blob(repo_id, filename, buffer, overwrite)
-            for filename, buffer in zip(filenames, buffers)
-        ]
-        await asyncio.gather(*tasks)
-        return filenames[0] if is_str else filenames
 
-    async def _async_upload_blob(self, repo_id, filename, buffer=None, overwrite=True):
-        storage_uri, container = self._parse_repo_id_to_storage_info(repo_id)
+        self._last_modified(repo_id, set_cache=True)  # set the cache for last_modified
         async with self._get_blob_client(
             repo_id, use_async=True
         ) as blob_service_client:
-            blob_client = blob_service_client.get_blob_client(
-                container=container, blob=filename
-            )
-            if buffer is not None:
-                await blob_client.upload_blob(buffer, overwrite=overwrite)
-            else:
-                local_cache = self._get_local_filepath(repo_id, filename)
-                with open(file=local_cache, mode="rb") as blob_file:
-                    await blob_client.upload_blob(blob_file, overwrite=overwrite)
+            tasks = [
+                self._async_upload_blob(
+                    blob_service_client, repo_id, filename, buffer, overwrite
+                )
+                for filename, buffer in zip(filenames, buffers)
+            ]
+            await asyncio.gather(*tasks)
+
+        self.last_modified_cache = None  # reset the cache
+
+        return filenames[0] if is_str else filenames
+
+    async def _async_upload_blob(
+        self, blob_service_client, repo_id, filename, buffer=None, overwrite=True
+    ):
+        storage_uri, container = self._parse_repo_id_to_storage_info(repo_id)
+
+        blob_client = blob_service_client.get_blob_client(
+            container=container, blob=filename
+        )
+
+        if buffer is not None:
+            await blob_client.upload_blob(buffer, overwrite=overwrite)
+
+        else:
+            local_cache = self._get_local_filepath(repo_id, filename)
+
+            with open(file=local_cache, mode="rb") as blob_file:
+                await blob_client.upload_blob(blob_file, overwrite=overwrite)
 
     async def async_download_blobs(
         self, repo_id: str, filesnames: Union[List[str], str]
@@ -419,33 +440,40 @@ async def async_download_blobs(
         is_str = isinstance(filesnames, str)
         if is_str:
             filesnames = [filesnames]
-        tasks = [
-            self._async_download_blob(repo_id, filename) for filename in filesnames
-        ]
-        local_filenames = await asyncio.gather(*tasks)
-        return local_filenames[0] if is_str else local_filenames
 
-    async def _async_download_blob(self, repo_id, filename):
-        storage_uri, container = self._parse_repo_id_to_storage_info(repo_id)
+        self._last_modified(repo_id, set_cache=True)  # set the cache for last_modified
+
         async with self._get_blob_client(
             repo_id, use_async=True
         ) as blob_service_client:
-            # already cached!
-            local_filename = self._get_local_filepath(repo_id, filename)
-            if local_filename.exists():
-                return local_filename
+            tasks = [
+                self._async_download_blob(blob_service_client, repo_id, filename)
+                for filename in filesnames
+            ]
+            local_filesnames = await asyncio.gather(*tasks)
 
-            blob_client = blob_service_client.get_blob_client(
-                container=container, blob=filename
-            )
+        self.last_modified_cache = None  # reset the cache
+
+        return local_filesnames[0] if is_str else local_filesnames
 
-            os.makedirs(os.path.dirname(local_filename), exist_ok=True)
-            with open(file=local_filename, mode="wb") as blob_file:
-                download_stream = await blob_client.download_blob()
-                data = await download_stream.readall()
-                blob_file.write(data)
+    async def _async_download_blob(self, blob_service_client, repo_id, filename):
+        # already cached!
+        local_filename = self._get_local_filepath(repo_id, filename)
+        if local_filename.exists():
             return local_filename
 
+        storage_uri, container = self._parse_repo_id_to_storage_info(repo_id)
+        blob_client = blob_service_client.get_blob_client(
+            container=container, blob=filename
+        )
+
+        os.makedirs(os.path.dirname(local_filename), exist_ok=True)
+        with open(file=local_filename, mode="wb") as blob_file:
+            download_stream = await blob_client.download_blob()
+            data = await download_stream.readall()
+            blob_file.write(data)
+        return local_filename
+
     async def async_copy_blobs(
         self,
         source_repo_ids,
diff --git a/mttl/models/library/expert_library.py b/mttl/models/library/expert_library.py
index 8065dccdc..8ebdb5c7c 100644
--- a/mttl/models/library/expert_library.py
+++ b/mttl/models/library/expert_library.py
@@ -1,3 +1,4 @@
+import asyncio
 import glob
 import io
 import os
@@ -141,31 +142,49 @@ def _build_lib(self):
             logger.error("Repository not found: %s", self.repo_id)
             raise e
 
-        # Function to download and process a single .meta file
-        def download_and_process_meta_file(file):
-            path_or_bytes = self.hf_hub_download(self.repo_id, file)
-
-            metadata_entry = MetadataEntry.fromdict(
-                torch.load(path_or_bytes, map_location="cpu", weights_only=False)
+        if isinstance(self, BlobExpertLibrary):
+            local_filenames = asyncio.run(
+                self.async_download_blobs(
+                    self.repo_id,
+                    meta_files,
+                )
             )
-            return metadata_entry
-
-        # Use ThreadPoolExecutor for multithreading
-        with ThreadPoolExecutor() as executor:
-            # Submit tasks to the executor
-            future_to_file = {
-                executor.submit(download_and_process_meta_file, file): file
-                for file in meta_files
-            }
 
+            # process every meta file from the local cache directory
             metadata = []
-            for future in as_completed(future_to_file):
-                file = future_to_file[future]
-                try:
-                    data = future.result()
-                    metadata.append(data)
-                except Exception as exc:
-                    logger.error("%r generated an exception: %s" % (file, exc))
+            for file in local_filenames:
+                metadata_entry = MetadataEntry.fromdict(
+                    torch.load(file, map_location="cpu", weights_only=False)
+                )
+                metadata.append(metadata_entry)
+
+        else:
+
+            # Function to download and process a single .meta file
+            def download_and_process_meta_file(file):
+                path_or_bytes = self.hf_hub_download(self.repo_id, file)
+
+                metadata_entry = MetadataEntry.fromdict(
+                    torch.load(path_or_bytes, map_location="cpu", weights_only=False)
+                )
+                return metadata_entry
+
+            # Use ThreadPoolExecutor for multithreading
+            with ThreadPoolExecutor() as executor:
+                # Submit tasks to the executor
+                future_to_file = {
+                    executor.submit(download_and_process_meta_file, file): file
+                    for file in meta_files
+                }
+
+                metadata = []
+                for future in as_completed(future_to_file):
+                    file = future_to_file[future]
+                    try:
+                        data = future.result()
+                        metadata.append(data)
+                    except Exception as exc:
+                        logger.error("%r generated an exception: %s" % (file, exc))
 
         for metadatum in metadata:
             if self.model_name is not None and metadatum.model != self.model_name:
@@ -282,7 +301,11 @@ def __len__(self):
         return len(self.data)
 
     def add_expert(
-        self, expert_dump: Expert, expert_name: str = None, force: bool = False
+        self,
+        expert_dump: Expert,
+        expert_name: str = None,
+        force: bool = False,
+        update_readme: bool = True,
     ):
         if self.sliced:
             raise ValueError("Cannot add expert to sliced library.")
@@ -307,7 +330,9 @@ def add_expert(
         self._upload_weights(metadata.expert_name, expert_dump)
         self._upload_metadata(metadata)
         self.data[metadata.expert_name] = metadata
-        self._update_readme()
+        # only update the readme if requested; this avoids redundant updates when adding experts in a batch
+        if update_readme:
+            self._update_readme()
 
     def list_auxiliary_data(self) -> Dict[str, Tuple[int, str]]:
         """List auxiliary data in the library, returns a dictionary with the data type, the number of records, and a string representation of the config file."""
@@ -770,9 +795,15 @@ def clone(
         only_tasks = only_tasks or self.tasks
 
         with new_lib.batched_commit():
+            update_readme = False
             for name, expert in self.items():
                 if expert.name not in new_lib:
-                    new_lib.add_expert(expert, name, force=force)
+                    new_lib.add_expert(expert, name, force=force, update_readme=False)
+                    update_readme = True
+
+            # only update the readme if we added new experts
+            if update_readme:
+                new_lib._update_readme()
 
         # if the new_lib already exists, delete experts that
         # are in this lib but were deleted from the expert_lib
@@ -929,7 +960,11 @@ class LocalExpertLibrary(ExpertLibrary, LocalFSEngine):
     """A local library stored on disk."""
 
     def add_expert(
-        self, expert_dump: Expert, expert_name: str = None, force: bool = False
+        self,
+        expert_dump: Expert,
+        expert_name: str = None,
+        force: bool = False,
+        update_readme: bool = True,
     ):
         expert_name = expert_name or expert_dump.expert_info.expert_name
         if "/" in expert_name:
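
Usage sketch (illustrative only, not part of the patch): the new update_readme flag lets callers add several experts and refresh the readme a single time, which is what the updated clone() does. Here `library` (an ExpertLibrary instance) and `experts` (a dict of name -> Expert) are hypothetical placeholders; batched_commit, add_expert, and _update_readme are the methods touched above.

    # illustrative batched-add pattern; `library` and `experts` are hypothetical placeholders
    with library.batched_commit():
        for name, expert in experts.items():
            # skip the per-expert readme update; it is refreshed once below
            library.add_expert(expert, expert_name=name, update_readme=False)
    # update the readme a single time after the batch
    library._update_readme()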