Merge pull request #156 from fverac/fverac/blob_engine_speedup
optimizations to BlobStorageEngine, ExpertLibrary
sordonia authored Jan 23, 2025
2 parents a561f27 + 739b249 commit 00882d6
Showing 2 changed files with 148 additions and 85 deletions.
146 changes: 87 additions & 59 deletions mttl/models/library/backend_engine.py
@@ -137,6 +137,7 @@ def __init__(self, token: Optional[str] = None, cache_dir: Optional[str] = None)
self.cache_dir = cache_dir
# Quiet down the azure logging
logging.getLogger("azure").setLevel(logging.WARNING)
self.last_modified_cache = None

@property
def cache_dir(self):
@@ -200,16 +201,27 @@ def _get_container_client(self, repo_id, use_async=False):
storage_uri, container = self._parse_repo_id_to_storage_info(repo_id)
return self._get_blob_client(repo_id, use_async).get_container_client(container)

def _last_modified(self, repo_id: str) -> datetime.datetime:
def _last_modified(
self, repo_id: str, set_cache: bool = False
) -> datetime.datetime:
"""Get the last modified date of a repository."""
try:
return (
self._get_container_client(repo_id)
.get_container_properties()
.last_modified
)
except ResourceNotFoundError as error:
raise ValueError(f"Repository {repo_id} not found") from error

# if cached version exists, return cache. We want to avoid repetitive calls to the API
if self.last_modified_cache:
return self.last_modified_cache

else:
try:
last_modified = (
self._get_container_client(repo_id)
.get_container_properties()
.last_modified
)
if set_cache:
self.last_modified_cache = last_modified
return last_modified
except ResourceNotFoundError as error:
raise ValueError(f"Repository {repo_id} not found") from error

def get_repository_cache_dir(self, repo_id: str) -> Path:
"""Get the cache directory for a repository. If it doesn't exist, create it.
@@ -280,23 +292,17 @@ def delete_repo(self, repo_id, repo_type=None):
except ResourceNotFoundError:
logger.info(f"Container {repo_id} not found.")

def create_commit(self, repo_id, operations, commit_message="", async_mode=False):
def create_commit(self, repo_id, operations, commit_message="", async_mode=True):
asyncio.run(
self.async_create_commit(repo_id, operations, async_mode=async_mode)
)

async def async_create_commit(self, repo_id, operations, async_mode=False):
tasks = []
upload_batch = []
for op in operations:
if isinstance(op, CommitOperationAdd):
tasks.append(
self._async_upload_blob(
repo_id=repo_id,
filename=op.path_in_repo,
buffer=op.path_or_fileobj,
overwrite=True,
)
)
upload_batch.append(op)
elif isinstance(op, CommitOperationCopy):
tasks.append(
self._async_copy_blob(
@@ -314,11 +320,13 @@ async def async_create_commit(self, repo_id, operations, async_mode=False):
filename=op.path_in_repo,
)
)
if async_mode:
await asyncio.gather(*tasks)
else:
for task in tasks:
await task

# upload blobs in batch, using async!
await self.async_upload_blobs(
repo_id,
filenames=[op.path_in_repo for op in upload_batch],
buffers=[op.path_or_fileobj for op in upload_batch],
)

def preupload_lfs_files(self, repo_id, additions):
# for blob storage, these operations are done in create_commit
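
The reworked create_commit above no longer fires one upload coroutine per CommitOperationAdd inside the loop; it collects the additions into upload_batch and hands them to async_upload_blobs in a single call, so one client session serves the whole batch. A compact sketch of that collect-then-gather shape (upload_one and the operation tuples are hypothetical, not the library's types):

import asyncio


async def upload_one(client, filename, buffer):
    """Stand-in for a single-blob upload against an already-open client."""
    await asyncio.sleep(0)  # placeholder for the real network call


async def commit_uploads(client, operations):
    # Collect every pending addition first, then run the uploads concurrently
    # against the one shared client instead of opening a client per file.
    upload_batch = [(name, buf) for name, buf in operations]
    await asyncio.gather(*(upload_one(client, name, buf) for name, buf in upload_batch))


asyncio.run(commit_uploads(client=None, operations=[("a.meta", b"1"), ("b.meta", b"2")]))
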
@@ -390,62 +398,82 @@ async def async_upload_blobs(
else:
if len(buffers) != len(filenames):
raise ValueError("Filenames and buffers must have the same length.")
tasks = [
self._async_upload_blob(repo_id, filename, buffer, overwrite)
for filename, buffer in zip(filenames, buffers)
]
await asyncio.gather(*tasks)
return filenames[0] if is_str else filenames

async def _async_upload_blob(self, repo_id, filename, buffer=None, overwrite=True):
storage_uri, container = self._parse_repo_id_to_storage_info(repo_id)
self._last_modified(repo_id, set_cache=True) # set the cache for last_modified

async with self._get_blob_client(
repo_id, use_async=True
) as blob_service_client:
blob_client = blob_service_client.get_blob_client(
container=container, blob=filename
)
if buffer is not None:
await blob_client.upload_blob(buffer, overwrite=overwrite)
else:
local_cache = self._get_local_filepath(repo_id, filename)
with open(file=local_cache, mode="rb") as blob_file:
await blob_client.upload_blob(blob_file, overwrite=overwrite)
tasks = [
self._async_upload_blob(
blob_service_client, repo_id, filename, buffer, overwrite
)
for filename, buffer in zip(filenames, buffers)
]
await asyncio.gather(*tasks)

self.last_modified_cache = None # reset the cache

return filenames[0] if is_str else filenames

async def _async_upload_blob(
self, blob_service_client, repo_id, filename, buffer=None, overwrite=True
):
storage_uri, container = self._parse_repo_id_to_storage_info(repo_id)

blob_client = blob_service_client.get_blob_client(
container=container, blob=filename
)

if buffer is not None:
await blob_client.upload_blob(buffer, overwrite=overwrite)

else:
local_cache = self._get_local_filepath(repo_id, filename)

with open(file=local_cache, mode="rb") as blob_file:
await blob_client.upload_blob(blob_file, overwrite=overwrite)

async def async_download_blobs(
self, repo_id: str, filesnames: Union[List[str], str]
) -> str:
is_str = isinstance(filesnames, str)
if is_str:
filesnames = [filesnames]
tasks = [
self._async_download_blob(repo_id, filename) for filename in filesnames
]
local_filenames = await asyncio.gather(*tasks)
return local_filenames[0] if is_str else local_filenames

async def _async_download_blob(self, repo_id, filename):
storage_uri, container = self._parse_repo_id_to_storage_info(repo_id)
self._last_modified(repo_id, set_cache=True) # set the cache for last_modified

async with self._get_blob_client(
repo_id, use_async=True
) as blob_service_client:
# already cached!
local_filename = self._get_local_filepath(repo_id, filename)
if local_filename.exists():
return local_filename
tasks = [
self._async_download_blob(blob_service_client, repo_id, filename)
for filename in filesnames
]
local_filesnames = await asyncio.gather(*tasks)

blob_client = blob_service_client.get_blob_client(
container=container, blob=filename
)
self.last_modified_cache = None # reset the cache

return local_filesnames[0] if is_str else local_filesnames

os.makedirs(os.path.dirname(local_filename), exist_ok=True)
with open(file=local_filename, mode="wb") as blob_file:
download_stream = await blob_client.download_blob()
data = await download_stream.readall()
blob_file.write(data)
async def _async_download_blob(self, blob_service_client, repo_id, filename):
# already cached!
local_filename = self._get_local_filepath(repo_id, filename)
if local_filename.exists():
return local_filename

storage_uri, container = self._parse_repo_id_to_storage_info(repo_id)
blob_client = blob_service_client.get_blob_client(
container=container, blob=filename
)

os.makedirs(os.path.dirname(local_filename), exist_ok=True)
with open(file=local_filename, mode="wb") as blob_file:
download_stream = await blob_client.download_blob()
data = await download_stream.readall()
blob_file.write(data)
return local_filename
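
The structural change running through this hunk is that the async BlobServiceClient is opened once per batch (in async_upload_blobs / async_download_blobs) and passed into the per-blob helpers, rather than each helper opening and tearing down its own client. A minimal sketch of that pattern with the Azure aio SDK; the account URL, container name, and helper names are illustrative assumptions:

import asyncio

from azure.storage.blob.aio import BlobServiceClient


async def _upload_one(service, container, name, data):
    # Per-blob work reuses the already-open service client; only the
    # lightweight blob client is created per file.
    blob_client = service.get_blob_client(container=container, blob=name)
    await blob_client.upload_blob(data, overwrite=True)


async def upload_many(account_url, container, items):
    # One client (and one connection pool) for the whole batch.
    async with BlobServiceClient(account_url) as service:
        await asyncio.gather(
            *(_upload_one(service, container, name, data) for name, data in items)
        )


# Example (assumes a SAS token baked into the URL):
# asyncio.run(upload_many("https://myaccount.blob.core.windows.net?sv=...",
#                         "my-container", [("a.bin", b"..."), ("b.bin", b"...")]))
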

async def async_copy_blobs(
self,
source_repo_ids,
87 changes: 61 additions & 26 deletions mttl/models/library/expert_library.py
@@ -1,3 +1,4 @@
import asyncio
import glob
import io
import os
@@ -141,31 +142,49 @@ def _build_lib(self):
logger.error("Repository not found: %s", self.repo_id)
raise e

# Function to download and process a single .meta file
def download_and_process_meta_file(file):
path_or_bytes = self.hf_hub_download(self.repo_id, file)

metadata_entry = MetadataEntry.fromdict(
torch.load(path_or_bytes, map_location="cpu", weights_only=False)
if isinstance(self, BlobExpertLibrary):
local_filenames = asyncio.run(
self.async_download_blobs(
self.repo_id,
meta_files,
)
)
return metadata_entry

# Use ThreadPoolExecutor for multithreading
with ThreadPoolExecutor() as executor:
# Submit tasks to the executor
future_to_file = {
executor.submit(download_and_process_meta_file, file): file
for file in meta_files
}

# process every meta file in new local directory
metadata = []
for future in as_completed(future_to_file):
file = future_to_file[future]
try:
data = future.result()
metadata.append(data)
except Exception as exc:
logger.error("%r generated an exception: %s" % (file, exc))
for file in local_filenames:
metadata_entry = MetadataEntry.fromdict(
torch.load(file, map_location="cpu", weights_only=False)
)
metadata.append(metadata_entry)

else:

# Function to download and process a single .meta file
def download_and_process_meta_file(file):
path_or_bytes = self.hf_hub_download(self.repo_id, file)

metadata_entry = MetadataEntry.fromdict(
torch.load(path_or_bytes, map_location="cpu", weights_only=False)
)
return metadata_entry

# Use ThreadPoolExecutor for multithreading
with ThreadPoolExecutor() as executor:
# Submit tasks to the executor
future_to_file = {
executor.submit(download_and_process_meta_file, file): file
for file in meta_files
}

metadata = []
for future in as_completed(future_to_file):
file = future_to_file[future]
try:
data = future.result()
metadata.append(data)
except Exception as exc:
logger.error("%r generated an exception: %s" % (file, exc))

for metadatum in metadata:
if self.model_name is not None and metadatum.model != self.model_name:
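
For blob-backed libraries, _build_lib now splits metadata loading into one batched network phase (async_download_blobs over all .meta files) followed by purely local torch.load calls, instead of one thread-pool worker per file. A schematic of that two-phase shape (fetch_all and the load callable are hypothetical stand-ins):

import asyncio
from typing import Callable, List


async def fetch_all(repo_id: str, filenames: List[str]) -> List[str]:
    # Stand-in for the batched blob download; returns local file paths.
    return [f"/tmp/{repo_id}/{name}" for name in filenames]


def build_metadata(repo_id: str, meta_files: List[str], load: Callable[[str], object]):
    # Phase 1: a single batched round of downloads.
    local_paths = asyncio.run(fetch_all(repo_id, meta_files))
    # Phase 2: deserialize locally, no further network traffic.
    return [load(path) for path in local_paths]
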
@@ -282,7 +301,11 @@ def __len__(self):
return len(self.data)

def add_expert(
self, expert_dump: Expert, expert_name: str = None, force: bool = False
self,
expert_dump: Expert,
expert_name: str = None,
force: bool = False,
update_readme: bool = True,
):
if self.sliced:
raise ValueError("Cannot add expert to sliced library.")
@@ -307,7 +330,9 @@ def add_expert(
self._upload_weights(metadata.expert_name, expert_dump)
self._upload_metadata(metadata)
self.data[metadata.expert_name] = metadata
self._update_readme()
# only update readme if requested. This is useful when adding multiple experts in a batch
if update_readme:
self._update_readme()

def list_auxiliary_data(self) -> Dict[str, Tuple[int, str]]:
"""List auxiliary data in the library, returns a dictionary with the data type, the number of records, and a string representation of the config file."""
@@ -770,9 +795,15 @@ def clone(

only_tasks = only_tasks or self.tasks
with new_lib.batched_commit():
update_readme = False
for name, expert in self.items():
if expert.name not in new_lib:
new_lib.add_expert(expert, name, force=force)
new_lib.add_expert(expert, name, force=force, update_readme=False)
update_readme = True

# only update readme if we added new experts
if update_readme:
new_lib._update_readme()

# if the new_lib already exists, delete experts that
# are in this lib but were deleted from the expert_lib
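
clone() applies the same deferral: a flag records whether any expert was actually copied, and the destination README is rebuilt at most once afterwards. A compact sketch of that flag-and-flush loop, written against the hypothetical Library sketch above (names are illustrative):

def clone_missing(src, dst):
    # Copy experts absent from the destination without touching its README,
    # then rebuild the README once if anything was added.
    added = False
    for name, expert in src.data.items():
        if name not in dst.data:
            dst.add_expert(name, expert, update_readme=False)
            added = True
    if added:
        dst._update_readme()
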
@@ -929,7 +960,11 @@ class LocalExpertLibrary(ExpertLibrary, LocalFSEngine):
"""A local library stored on disk."""

def add_expert(
self, expert_dump: Expert, expert_name: str = None, force: bool = False
self,
expert_dump: Expert,
expert_name: str = None,
force: bool = False,
update_readme: bool = True,
):
expert_name = expert_name or expert_dump.expert_info.expert_name
if "/" in expert_name:
