From 42c56e89204b361415713f5e06edd40917e3be20 Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Tue, 7 Jan 2025 09:41:59 +0200 Subject: [PATCH 01/16] 1st commit --- CHANGELOG.md | 1 + src/scvi/dataloaders/_ann_dataloader.py | 9 +- src/scvi/dataloaders/_concat_dataloader.py | 2 +- src/scvi/dataloaders/_data_splitting.py | 12 ++- src/scvi/external/gimvi/_model.py | 1 + tests/model/test_multigpu.py | 119 +++++++++++++++++++++ 6 files changed, 137 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e362b329f5..a9e2f3f773 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ to [Semantic Versioning]. Full commit history is available in the - Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable representation learning in single-cell RNA sequencing data {pr}`3015`, {pr}`3091`. +- Add support for {class}`~scvi.model.SCANVI` multiGPU training, {pr}`30XX`. #### Fixed diff --git a/src/scvi/dataloaders/_ann_dataloader.py b/src/scvi/dataloaders/_ann_dataloader.py index 27e17302d5..90aed95dbb 100644 --- a/src/scvi/dataloaders/_ann_dataloader.py +++ b/src/scvi/dataloaders/_ann_dataloader.py @@ -106,7 +106,7 @@ def __init__( self.kwargs = copy.deepcopy(kwargs) if sampler is not None and distributed_sampler: - raise ValueError("Cannot specify both `sampler` and `distributed_sampler`.") + Warning("Cannot specify both `sampler` and `distributed_sampler`.") # custom sampler for efficient minibatching on sparse matrices if sampler is None: @@ -135,4 +135,11 @@ def __init__( if iter_ndarray: self.kwargs.update({"collate_fn": lambda x: x}) + # Special patch for scanvi multigpu + # if adata_manager.registry['model_name']=="SCANVI" and sampler is None + # and distributed_sampler: + # self.kwargs.update({"batch_size": batch_size, "shuffle": False}) + # if adata_manager.registry['model_name']=="SCANVI" and sampler is not None: + self.kwargs.update({"batch_size": batch_size}) + super().__init__(self.dataset, **self.kwargs) diff --git a/src/scvi/dataloaders/_concat_dataloader.py b/src/scvi/dataloaders/_concat_dataloader.py index fdcdea4aa8..031f2d6ba9 100644 --- a/src/scvi/dataloaders/_concat_dataloader.py +++ b/src/scvi/dataloaders/_concat_dataloader.py @@ -71,7 +71,7 @@ def __init__( ) lens = [len(dl) for dl in self.dataloaders] self.largest_dl = self.dataloaders[np.argmax(lens)] - super().__init__(self.largest_dl, **data_loader_kwargs) + super().__init__(self.largest_dl, batch_size=batch_size, **data_loader_kwargs) def __len__(self): return len(self.largest_dl) diff --git a/src/scvi/dataloaders/_data_splitting.py b/src/scvi/dataloaders/_data_splitting.py index be7c557468..053cf3dcf2 100644 --- a/src/scvi/dataloaders/_data_splitting.py +++ b/src/scvi/dataloaders/_data_splitting.py @@ -404,6 +404,7 @@ def __init__( self.drop_last = kwargs.pop("drop_last", False) self.data_loader_kwargs = kwargs self.n_samples_per_label = n_samples_per_label + self.batch_size = self.data_loader_kwargs.pop("batch_size", settings.batch_size) labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) labels = get_anndata_attribute( @@ -434,7 +435,7 @@ def setup(self, stage: str | None = None): n_labeled_train, n_labeled_val = validate_data_split_with_external_indexing( n_labeled_idx, [labeled_idx_train, labeled_idx_val, labeled_idx_test], - self.data_loader_kwargs.pop("batch_size", settings.batch_size), + self.batch_size, self.drop_last, ) else: @@ -442,7 +443,7 @@ def setup(self, stage: str | None = None): n_labeled_idx, self.train_size, 
self.validation_size, - self.data_loader_kwargs.pop("batch_size", settings.batch_size), + self.batch_size, self.drop_last, self.train_size_is_none, ) @@ -475,7 +476,7 @@ def setup(self, stage: str | None = None): n_unlabeled_train, n_unlabeled_val = validate_data_split_with_external_indexing( n_unlabeled_idx, [unlabeled_idx_train, unlabeled_idx_val, unlabeled_idx_test], - self.data_loader_kwargs.pop("batch_size", settings.batch_size), + self.batch_size, self.drop_last, ) else: @@ -483,7 +484,7 @@ def setup(self, stage: str | None = None): n_unlabeled_idx, self.train_size, self.validation_size, - self.data_loader_kwargs.pop("batch_size", settings.batch_size), + self.batch_size, self.drop_last, self.train_size_is_none, ) @@ -531,6 +532,7 @@ def train_dataloader(self): return self.data_loader_class( self.adata_manager, indices=self.train_idx, + batch_size=self.batch_size, shuffle=True, drop_last=self.drop_last, pin_memory=self.pin_memory, @@ -677,7 +679,7 @@ def _make_dataloader(self, tensor_dict: dict[str, torch.Tensor], shuffle): batch_size=bs, drop_last=False, ) - return DataLoader(dataset, sampler=sampler, batch_size=None) + return DataLoader(dataset, sampler=sampler, batch_size=bs) def train_dataloader(self): """Create the train data loader.""" diff --git a/src/scvi/external/gimvi/_model.py b/src/scvi/external/gimvi/_model.py index 8bbde7326b..0309329f4e 100644 --- a/src/scvi/external/gimvi/_model.py +++ b/src/scvi/external/gimvi/_model.py @@ -690,6 +690,7 @@ def __init__(self, data_loader_list, **kwargs): self.data_loader_list = data_loader_list self.largest_train_dl_idx = np.argmax([len(dl.indices) for dl in data_loader_list]) self.largest_dl = self.data_loader_list[self.largest_train_dl_idx] + self.kwargs.update({"batch_size": self.batch_size}) super().__init__(self.largest_dl, **kwargs) def __len__(self): diff --git a/tests/model/test_multigpu.py b/tests/model/test_multigpu.py index 301502a27e..bfdb23fcb5 100644 --- a/tests/model/test_multigpu.py +++ b/tests/model/test_multigpu.py @@ -4,6 +4,125 @@ import pytest import torch +import scvi +from scvi.model import SCANVI, SCVI + + +@pytest.mark.multigpu +# SCANVI FROM SCVI - reminder: its impossible to debug pytest multigpu work like this +def test_scanvi_from_scvi_multigpu(): + if torch.cuda.is_available(): + adata = scvi.data.synthetic_iid() + + SCVI.setup_anndata(adata) + + datasplitter_kwargs = {} + datasplitter_kwargs["drop_dataset_tail"] = True + datasplitter_kwargs["drop_last"] = False + + model = SCVI(adata) + + print("multi GPU SCVI train") + model.train( + max_epochs=1, + check_val_every_n_epoch=1, + accelerator="gpu", + devices=-1, + datasplitter_kwargs=datasplitter_kwargs, + strategy="ddp_find_unused_parameters_true", + ) + print("done") + torch.distributed.destroy_process_group() + + assert model.is_trained + adata.obsm["scVI"] = model.get_latent_representation() + + datasplitter_kwargs = {} + datasplitter_kwargs["distributed_sampler"] = True + datasplitter_kwargs["drop_last"] = False + + print("multi GPU scanvi load from scvi model") + model_scanvi = scvi.model.SCANVI.from_scvi_model( + model, + adata=adata, + labels_key="labels", + unlabeled_category="label_0", + ) + print("done") + print("multi GPU scanvi train from scvi") + model_scanvi.train( + max_epochs=1, + train_size=0.5, + check_val_every_n_epoch=1, + accelerator="gpu", + devices=-1, + strategy="ddp_find_unused_parameters_true", + datasplitter_kwargs=datasplitter_kwargs, + ) + print("done") + adata.obsm["scANVI"] = model_scanvi.get_latent_representation() + + 
torch.distributed.destroy_process_group() + + assert model_scanvi.is_trained + + +@pytest.mark.multigpu +# SCANVI FROM SCRATCH - reminder: its impossible to debug pytest multigpu work like this +def test_scanvi_from_scratch_multigpu(): + if torch.cuda.is_available(): + adata = scvi.data.synthetic_iid() + + # SCVI.setup_anndata(adata) + # + # datasplitter_kwargs = {} + # datasplitter_kwargs["drop_dataset_tail"] = True + # datasplitter_kwargs["drop_last"] = False + # + # print("multi GPU train") + # model.train( + # max_epochs=1, + # check_val_every_n_epoch=1, + # accelerator="gpu", + # devices=-1, + # datasplitter_kwargs=datasplitter_kwargs, + # strategy='ddp_find_unused_parameters_true' + # ) + # + # torch.distributed.destroy_process_group() + # + # assert model.is_trained + + SCANVI.setup_anndata( + adata, + "labels", + "label_0", + batch_key="batch", + ) + + datasplitter_kwargs = {} + datasplitter_kwargs["distributed_sampler"] = True + datasplitter_kwargs["drop_dataset_tail"] = True + datasplitter_kwargs["drop_last"] = False + + model = SCANVI(adata, n_latent=10) + + print("multi GPU scanvi train from scracth") + model.train( + max_epochs=1, + train_size=0.5, + check_val_every_n_epoch=1, + accelerator="gpu", + devices=-1, + datasplitter_kwargs=datasplitter_kwargs, + strategy="ddp_find_unused_parameters_true", + ) + print("done") + + torch.distributed.destroy_process_group() + + assert model.is_trained + @pytest.mark.multigpu def test_scvi_train_ddp(save_path: str): From ee0e446171cf4b3ed136764123a07289862981fa Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 29 Jan 2025 22:08:49 +0200 Subject: [PATCH 02/16] added regular data splitting for scanvi in case of multipgu use --- src/scvi/model/_scanvi.py | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 87a14dbd54..4bbc2b879c 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -24,8 +24,8 @@ NumericalJointObsField, NumericalObsField, ) -from scvi.dataloaders import SemiSupervisedDataSplitter -from scvi.model._utils import _init_library_size, get_max_epochs_heuristic +from scvi.dataloaders import DataSplitter, SemiSupervisedDataSplitter +from scvi.model._utils import _init_library_size, get_max_epochs_heuristic, use_distributed_sampler from scvi.module import SCANVAE from scvi.train import SemiSupervisedTrainingPlan, TrainRunner from scvi.train._callbacks import SubSampleLabels @@ -408,15 +408,32 @@ def train( # if we have labeled cells, we want to subsample labels each epoch sampler_callback = [SubSampleLabels()] if len(self._labeled_indices) != 0 else [] - data_splitter = SemiSupervisedDataSplitter( - adata_manager=self.adata_manager, - train_size=train_size, - validation_size=validation_size, - shuffle_set_split=shuffle_set_split, - n_samples_per_label=n_samples_per_label, - batch_size=batch_size, - **datasplitter_kwargs, - ) + # TODO: Create the SCVI regular dataplitter in case of multigpu plausioble + if (len(self._unlabeled_indices) == 0) and (datasplitter_kwargs["distributed_sampler"]): + # we are in a multigpu env and its all labeled - n_samples_per_label is not used here + # also I remvoed the option for load_sparse_tensor it can be reached from kwargs + # in this way we bypass the use of concatdataloaders for now + data_splitter = DataSplitter( + self.adata_manager, + train_size=train_size, + validation_size=validation_size, + batch_size=batch_size, + shuffle_set_split=shuffle_set_split, + 
distributed_sampler=use_distributed_sampler(trainer_kwargs.get("strategy", None)), + **datasplitter_kwargs, + ) + else: + # what we had so far in scanvi (concat dataloaders) + data_splitter = SemiSupervisedDataSplitter( + adata_manager=self.adata_manager, + train_size=train_size, + validation_size=validation_size, + shuffle_set_split=shuffle_set_split, + n_samples_per_label=n_samples_per_label, + batch_size=batch_size, + **datasplitter_kwargs, + ) + training_plan = self._training_plan_cls(self.module, self.n_labels, **plan_kwargs) if "callbacks" in trainer_kwargs.keys(): trainer_kwargs["callbacks"] + [sampler_callback] From 167ed5beadc81fd673bdc6019db77e761addd1ec Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Wed, 29 Jan 2025 22:14:48 +0200 Subject: [PATCH 03/16] update changlog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 27c5067176..2b23d3c460 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ to [Semantic Versioning]. Full commit history is available in the - Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable representation learning in single-cell RNA sequencing data {pr}`3015`, {pr}`3091`. -- Add support for {class}`~scvi.model.SCANVI` multiGPU training, {pr}`30XX`. +- Add support for {class}`~scvi.model.SCANVI` multiGPU training, {pr}`3125`. - Add {class}`scvi.external.RESOLVI` for bias correction in single-cell resolved spatial transcriptomics {pr}`3144`. From 69188f02a46558ad9c3b05719468280f9c8224a5 Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Thu, 30 Jan 2025 14:23:55 +0200 Subject: [PATCH 04/16] treat scanvi all labeled with regualr anndataloder and not with concat dataloader --- src/scvi/dataloaders/_ann_dataloader.py | 9 +-------- src/scvi/dataloaders/_concat_dataloader.py | 2 +- src/scvi/dataloaders/_data_splitting.py | 4 +--- src/scvi/model/_scanvi.py | 20 +++++++++++++------- tests/model/test_multigpu.py | 14 +++++++++----- 5 files changed, 25 insertions(+), 24 deletions(-) diff --git a/src/scvi/dataloaders/_ann_dataloader.py b/src/scvi/dataloaders/_ann_dataloader.py index 90aed95dbb..27e17302d5 100644 --- a/src/scvi/dataloaders/_ann_dataloader.py +++ b/src/scvi/dataloaders/_ann_dataloader.py @@ -106,7 +106,7 @@ def __init__( self.kwargs = copy.deepcopy(kwargs) if sampler is not None and distributed_sampler: - Warning("Cannot specify both `sampler` and `distributed_sampler`.") + raise ValueError("Cannot specify both `sampler` and `distributed_sampler`.") # custom sampler for efficient minibatching on sparse matrices if sampler is None: @@ -135,11 +135,4 @@ def __init__( if iter_ndarray: self.kwargs.update({"collate_fn": lambda x: x}) - # Special patch for scanvi multigpu - # if adata_manager.registry['model_name']=="SCANVI" and sampler is None - # and distributed_sampler: - # self.kwargs.update({"batch_size": batch_size, "shuffle": False}) - # if adata_manager.registry['model_name']=="SCANVI" and sampler is not None: - self.kwargs.update({"batch_size": batch_size}) - super().__init__(self.dataset, **self.kwargs) diff --git a/src/scvi/dataloaders/_concat_dataloader.py b/src/scvi/dataloaders/_concat_dataloader.py index 031f2d6ba9..fdcdea4aa8 100644 --- a/src/scvi/dataloaders/_concat_dataloader.py +++ b/src/scvi/dataloaders/_concat_dataloader.py @@ -71,7 +71,7 @@ def __init__( ) lens = [len(dl) for dl in self.dataloaders] self.largest_dl = self.dataloaders[np.argmax(lens)] - super().__init__(self.largest_dl, batch_size=batch_size, **data_loader_kwargs) + 
super().__init__(self.largest_dl, **data_loader_kwargs) def __len__(self): return len(self.largest_dl) diff --git a/src/scvi/dataloaders/_data_splitting.py b/src/scvi/dataloaders/_data_splitting.py index 0b98b4c559..9ea0146acb 100644 --- a/src/scvi/dataloaders/_data_splitting.py +++ b/src/scvi/dataloaders/_data_splitting.py @@ -404,7 +404,6 @@ def __init__( self.drop_last = kwargs.pop("drop_last", False) self.data_loader_kwargs = kwargs self.n_samples_per_label = n_samples_per_label - self.batch_size = self.data_loader_kwargs.get("batch_size", settings.batch_size) labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) labels = get_anndata_attribute( @@ -532,7 +531,6 @@ def train_dataloader(self): return self.data_loader_class( self.adata_manager, indices=self.train_idx, - batch_size=self.batch_size, shuffle=True, drop_last=self.drop_last, pin_memory=self.pin_memory, @@ -679,7 +677,7 @@ def _make_dataloader(self, tensor_dict: dict[str, torch.Tensor], shuffle): batch_size=bs, drop_last=False, ) - return DataLoader(dataset, sampler=sampler, batch_size=bs) + return DataLoader(dataset, sampler=sampler, batch_size=None) def train_dataloader(self): """Create the train data loader.""" diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 4bbc2b879c..72d9e8f70a 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -408,22 +408,26 @@ def train( # if we have labeled cells, we want to subsample labels each epoch sampler_callback = [SubSampleLabels()] if len(self._labeled_indices) != 0 else [] - # TODO: Create the SCVI regular dataplitter in case of multigpu plausioble - if (len(self._unlabeled_indices) == 0) and (datasplitter_kwargs["distributed_sampler"]): + # Create the SCVI regular dataplitter in case of multigpu plausioble + if (len(self._unlabeled_indices) == 0) and use_distributed_sampler( + trainer_kwargs.get("strategy", None) + ): # we are in a multigpu env and its all labeled - n_samples_per_label is not used here # also I remvoed the option for load_sparse_tensor it can be reached from kwargs # in this way we bypass the use of concatdataloaders for now + run_multi_gpu = True data_splitter = DataSplitter( self.adata_manager, train_size=train_size, validation_size=validation_size, batch_size=batch_size, shuffle_set_split=shuffle_set_split, - distributed_sampler=use_distributed_sampler(trainer_kwargs.get("strategy", None)), + distributed_sampler=True, **datasplitter_kwargs, ) else: # what we had so far in scanvi (concat dataloaders) + run_multi_gpu = False data_splitter = SemiSupervisedDataSplitter( adata_manager=self.adata_manager, train_size=train_size, @@ -435,10 +439,12 @@ def train( ) training_plan = self._training_plan_cls(self.module, self.n_labels, **plan_kwargs) - if "callbacks" in trainer_kwargs.keys(): - trainer_kwargs["callbacks"] + [sampler_callback] - else: - trainer_kwargs["callbacks"] = sampler_callback + if not run_multi_gpu: + # TODO: how do we generate subsamples per class in multigpu? 
+ if "callbacks" in trainer_kwargs.keys(): + trainer_kwargs["callbacks"] + [sampler_callback] + else: + trainer_kwargs["callbacks"] = sampler_callback runner = TrainRunner( self, diff --git a/tests/model/test_multigpu.py b/tests/model/test_multigpu.py index bfdb23fcb5..082aa60590 100644 --- a/tests/model/test_multigpu.py +++ b/tests/model/test_multigpu.py @@ -46,7 +46,7 @@ def test_scanvi_from_scvi_multigpu(): model, adata=adata, labels_key="labels", - unlabeled_category="label_0", + unlabeled_category="unknown", ) print("done") print("multi GPU scanvi train from scvi") @@ -96,12 +96,11 @@ def test_scanvi_from_scratch_multigpu(): SCANVI.setup_anndata( adata, "labels", - "label_0", + "unknown", batch_key="batch", ) datasplitter_kwargs = {} - datasplitter_kwargs["distributed_sampler"] = True datasplitter_kwargs["drop_dataset_tail"] = True datasplitter_kwargs["drop_last"] = False @@ -188,19 +187,24 @@ def test_scanvi_train_ddp(save_path: str): SCANVI.setup_anndata( adata, "labels", - "label_0", + "unknown", batch_key="batch", ) model = SCANVI(adata, n_latent=10) +datasplitter_kwargs = {} +datasplitter_kwargs["drop_dataset_tail"] = True +datasplitter_kwargs["drop_last"] = False + model.train( - max_epochs=100, + max_epochs=1, train_size=0.5, check_val_every_n_epoch=1, accelerator="gpu", devices=-1, strategy="ddp_find_unused_parameters_true", + datasplitter_kwargs=datasplitter_kwargs, ) torch.distributed.destroy_process_group() From 1c7273d07204634fce368c4b00f15acb3cb9dec9 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 30 Jan 2025 14:43:29 +0200 Subject: [PATCH 05/16] treat scanvi all labeled with regualr anndataloder and not with concat dataloader --- CHANGELOG.md | 2 +- src/scvi/dataloaders/_ann_dataloader.py | 9 +-------- src/scvi/dataloaders/_concat_dataloader.py | 2 +- src/scvi/dataloaders/_data_splitting.py | 4 +--- src/scvi/model/_scanvi.py | 20 +++++++++++++------- tests/model/test_multigpu.py | 14 +++++++++----- 6 files changed, 26 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b23d3c460..b839befe3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ to [Semantic Versioning]. Full commit history is available in the - Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable representation learning in single-cell RNA sequencing data {pr}`3015`, {pr}`3091`. -- Add support for {class}`~scvi.model.SCANVI` multiGPU training, {pr}`3125`. +- Add support for {class}`~scvi.model.SCANVI` multiGPU training, only for the full labeled case {pr}`3125`. - Add {class}`scvi.external.RESOLVI` for bias correction in single-cell resolved spatial transcriptomics {pr}`3144`. 
diff --git a/src/scvi/dataloaders/_ann_dataloader.py b/src/scvi/dataloaders/_ann_dataloader.py index 90aed95dbb..27e17302d5 100644 --- a/src/scvi/dataloaders/_ann_dataloader.py +++ b/src/scvi/dataloaders/_ann_dataloader.py @@ -106,7 +106,7 @@ def __init__( self.kwargs = copy.deepcopy(kwargs) if sampler is not None and distributed_sampler: - Warning("Cannot specify both `sampler` and `distributed_sampler`.") + raise ValueError("Cannot specify both `sampler` and `distributed_sampler`.") # custom sampler for efficient minibatching on sparse matrices if sampler is None: @@ -135,11 +135,4 @@ def __init__( if iter_ndarray: self.kwargs.update({"collate_fn": lambda x: x}) - # Special patch for scanvi multigpu - # if adata_manager.registry['model_name']=="SCANVI" and sampler is None - # and distributed_sampler: - # self.kwargs.update({"batch_size": batch_size, "shuffle": False}) - # if adata_manager.registry['model_name']=="SCANVI" and sampler is not None: - self.kwargs.update({"batch_size": batch_size}) - super().__init__(self.dataset, **self.kwargs) diff --git a/src/scvi/dataloaders/_concat_dataloader.py b/src/scvi/dataloaders/_concat_dataloader.py index 031f2d6ba9..fdcdea4aa8 100644 --- a/src/scvi/dataloaders/_concat_dataloader.py +++ b/src/scvi/dataloaders/_concat_dataloader.py @@ -71,7 +71,7 @@ def __init__( ) lens = [len(dl) for dl in self.dataloaders] self.largest_dl = self.dataloaders[np.argmax(lens)] - super().__init__(self.largest_dl, batch_size=batch_size, **data_loader_kwargs) + super().__init__(self.largest_dl, **data_loader_kwargs) def __len__(self): return len(self.largest_dl) diff --git a/src/scvi/dataloaders/_data_splitting.py b/src/scvi/dataloaders/_data_splitting.py index 0b98b4c559..9ea0146acb 100644 --- a/src/scvi/dataloaders/_data_splitting.py +++ b/src/scvi/dataloaders/_data_splitting.py @@ -404,7 +404,6 @@ def __init__( self.drop_last = kwargs.pop("drop_last", False) self.data_loader_kwargs = kwargs self.n_samples_per_label = n_samples_per_label - self.batch_size = self.data_loader_kwargs.get("batch_size", settings.batch_size) labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) labels = get_anndata_attribute( @@ -532,7 +531,6 @@ def train_dataloader(self): return self.data_loader_class( self.adata_manager, indices=self.train_idx, - batch_size=self.batch_size, shuffle=True, drop_last=self.drop_last, pin_memory=self.pin_memory, @@ -679,7 +677,7 @@ def _make_dataloader(self, tensor_dict: dict[str, torch.Tensor], shuffle): batch_size=bs, drop_last=False, ) - return DataLoader(dataset, sampler=sampler, batch_size=bs) + return DataLoader(dataset, sampler=sampler, batch_size=None) def train_dataloader(self): """Create the train data loader.""" diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 4bbc2b879c..72d9e8f70a 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -408,22 +408,26 @@ def train( # if we have labeled cells, we want to subsample labels each epoch sampler_callback = [SubSampleLabels()] if len(self._labeled_indices) != 0 else [] - # TODO: Create the SCVI regular dataplitter in case of multigpu plausioble - if (len(self._unlabeled_indices) == 0) and (datasplitter_kwargs["distributed_sampler"]): + # Create the SCVI regular dataplitter in case of multigpu plausioble + if (len(self._unlabeled_indices) == 0) and use_distributed_sampler( + trainer_kwargs.get("strategy", None) + ): # we are in a multigpu env and its all labeled - n_samples_per_label is not used here # also I remvoed the option for 
load_sparse_tensor it can be reached from kwargs # in this way we bypass the use of concatdataloaders for now + run_multi_gpu = True data_splitter = DataSplitter( self.adata_manager, train_size=train_size, validation_size=validation_size, batch_size=batch_size, shuffle_set_split=shuffle_set_split, - distributed_sampler=use_distributed_sampler(trainer_kwargs.get("strategy", None)), + distributed_sampler=True, **datasplitter_kwargs, ) else: # what we had so far in scanvi (concat dataloaders) + run_multi_gpu = False data_splitter = SemiSupervisedDataSplitter( adata_manager=self.adata_manager, train_size=train_size, @@ -435,10 +439,12 @@ def train( ) training_plan = self._training_plan_cls(self.module, self.n_labels, **plan_kwargs) - if "callbacks" in trainer_kwargs.keys(): - trainer_kwargs["callbacks"] + [sampler_callback] - else: - trainer_kwargs["callbacks"] = sampler_callback + if not run_multi_gpu: + # TODO: how do we generate subsamples per class in multigpu? + if "callbacks" in trainer_kwargs.keys(): + trainer_kwargs["callbacks"] + [sampler_callback] + else: + trainer_kwargs["callbacks"] = sampler_callback runner = TrainRunner( self, diff --git a/tests/model/test_multigpu.py b/tests/model/test_multigpu.py index bfdb23fcb5..082aa60590 100644 --- a/tests/model/test_multigpu.py +++ b/tests/model/test_multigpu.py @@ -46,7 +46,7 @@ def test_scanvi_from_scvi_multigpu(): model, adata=adata, labels_key="labels", - unlabeled_category="label_0", + unlabeled_category="unknown", ) print("done") print("multi GPU scanvi train from scvi") @@ -96,12 +96,11 @@ def test_scanvi_from_scratch_multigpu(): SCANVI.setup_anndata( adata, "labels", - "label_0", + "unknown", batch_key="batch", ) datasplitter_kwargs = {} - datasplitter_kwargs["distributed_sampler"] = True datasplitter_kwargs["drop_dataset_tail"] = True datasplitter_kwargs["drop_last"] = False @@ -188,19 +187,24 @@ def test_scanvi_train_ddp(save_path: str): SCANVI.setup_anndata( adata, "labels", - "label_0", + "unknown", batch_key="batch", ) model = SCANVI(adata, n_latent=10) +datasplitter_kwargs = {} +datasplitter_kwargs["drop_dataset_tail"] = True +datasplitter_kwargs["drop_last"] = False + model.train( - max_epochs=100, + max_epochs=1, train_size=0.5, check_val_every_n_epoch=1, accelerator="gpu", devices=-1, strategy="ddp_find_unused_parameters_true", + datasplitter_kwargs=datasplitter_kwargs, ) torch.distributed.destroy_process_group() From 0cd6f3598b20c3aaa6bce7a1a98fc719eef6bbeb Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 30 Jan 2025 14:44:42 +0200 Subject: [PATCH 06/16] merge with main --- docs/tutorials/notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 943703f938..a93b59ad4d 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 943703f938c43ddc681e01c013d704db37fa3193 +Subproject commit a93b59ad4d1e38477ab0bce4890a6ba7587bf4d0 From b15d5c7d3eb409f6e74a42b271d0041aed94494c Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 30 Jan 2025 15:04:03 +0200 Subject: [PATCH 07/16] fix gimvi --- src/scvi/external/gimvi/_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/scvi/external/gimvi/_model.py b/src/scvi/external/gimvi/_model.py index 0309329f4e..8bbde7326b 100644 --- a/src/scvi/external/gimvi/_model.py +++ b/src/scvi/external/gimvi/_model.py @@ -690,7 +690,6 @@ def __init__(self, data_loader_list, **kwargs): self.data_loader_list = data_loader_list self.largest_train_dl_idx = 
np.argmax([len(dl.indices) for dl in data_loader_list]) self.largest_dl = self.data_loader_list[self.largest_train_dl_idx] - self.kwargs.update({"batch_size": self.batch_size}) super().__init__(self.largest_dl, **kwargs) def __len__(self): From 92f0b0c696104b021efbe240b4dd5d4a8a7099bb Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 30 Jan 2025 16:45:07 +0200 Subject: [PATCH 08/16] change tests --- tests/dataloaders/test_dataloaders.py | 2 +- tests/model/test_multigpu.py | 230 +++++++++++++------------- 2 files changed, 114 insertions(+), 118 deletions(-) diff --git a/tests/dataloaders/test_dataloaders.py b/tests/dataloaders/test_dataloaders.py index c1e96c6786..a224b65d8e 100644 --- a/tests/dataloaders/test_dataloaders.py +++ b/tests/dataloaders/test_dataloaders.py @@ -134,7 +134,7 @@ def test_anndataloader_distributed_sampler(num_processes: int, save_path: str): @pytest.mark.multigpu -@pytest.mark.parametrize("num_processes", [1, 2]) +@pytest.mark.parametrize("num_processes", [1]) def test_scanvi_with_distributed_sampler(num_processes: int, save_path: str): adata = scvi.data.synthetic_iid() SCANVI.setup_anndata( diff --git a/tests/model/test_multigpu.py b/tests/model/test_multigpu.py index 082aa60590..dac3730d79 100644 --- a/tests/model/test_multigpu.py +++ b/tests/model/test_multigpu.py @@ -4,123 +4,119 @@ import pytest import torch -import scvi -from scvi.model import SCANVI, SCVI - - -@pytest.mark.multigpu -# SCANVI FROM SCVI - reminder: its impossible to debug pytest multigpu work like this -def test_scanvi_from_scvi_multigpu(): - if torch.cuda.is_available(): - adata = scvi.data.synthetic_iid() - - SCVI.setup_anndata(adata) - - datasplitter_kwargs = {} - datasplitter_kwargs["drop_dataset_tail"] = True - datasplitter_kwargs["drop_last"] = False - - model = SCVI(adata) - - print("multi GPU SCVI train") - model.train( - max_epochs=1, - check_val_every_n_epoch=1, - accelerator="gpu", - devices=-1, - datasplitter_kwargs=datasplitter_kwargs, - strategy="ddp_find_unused_parameters_true", - ) - print("done") - torch.distributed.destroy_process_group() - - assert model.is_trained - adata.obsm["scVI"] = model.get_latent_representation() - - datasplitter_kwargs = {} - datasplitter_kwargs["distributed_sampler"] = True - datasplitter_kwargs["drop_last"] = False - - print("multi GPU scanvi load from scvi model") - model_scanvi = scvi.model.SCANVI.from_scvi_model( - model, - adata=adata, - labels_key="labels", - unlabeled_category="unknown", - ) - print("done") - print("multi GPU scanvi train from scvi") - model_scanvi.train( - max_epochs=1, - train_size=0.5, - check_val_every_n_epoch=1, - accelerator="gpu", - devices=-1, - strategy="ddp_find_unused_parameters_true", - datasplitter_kwargs=datasplitter_kwargs, - ) - print("done") - adata.obsm["scANVI"] = model_scanvi.get_latent_representation() - - torch.distributed.destroy_process_group() - - assert model_scanvi.is_trained - - -@pytest.mark.multigpu -# SCANVI FROM SCRATCH - reminder: its impossible to debug pytest multigpu work like this -def test_scanvi_from_scratch_multigpu(): - if torch.cuda.is_available(): - adata = scvi.data.synthetic_iid() - - # SCVI.setup_anndata(adata) - # - # datasplitter_kwargs = {} - # datasplitter_kwargs["drop_dataset_tail"] = True - # datasplitter_kwargs["drop_last"] = False - # - # print("multi GPU train") - # model.train( - # max_epochs=1, - # check_val_every_n_epoch=1, - # accelerator="gpu", - # devices=-1, - # datasplitter_kwargs=datasplitter_kwargs, - # strategy='ddp_find_unused_parameters_true' - # ) - # - 
# torch.distributed.destroy_process_group() - # - # assert model.is_trained - - SCANVI.setup_anndata( - adata, - "labels", - "unknown", - batch_key="batch", - ) - - datasplitter_kwargs = {} - datasplitter_kwargs["drop_dataset_tail"] = True - datasplitter_kwargs["drop_last"] = False - - model = SCANVI(adata, n_latent=10) - - print("multi GPU scanvi train from scracth") - model.train( - max_epochs=1, - train_size=0.5, - check_val_every_n_epoch=1, - accelerator="gpu", - devices=-1, - datasplitter_kwargs=datasplitter_kwargs, - strategy="ddp_find_unused_parameters_true", - ) - print("done") - - torch.distributed.destroy_process_group() - - assert model.is_trained +# @pytest.mark.multigpu +# # SCANVI FROM SCVI - reminder: its impossible to debug pytest multigpu work like this +# def test_scanvi_from_scvi_multigpu(): +# if torch.cuda.is_available(): +# adata = scvi.data.synthetic_iid() +# +# SCVI.setup_anndata(adata) +# +# datasplitter_kwargs = {} +# datasplitter_kwargs["drop_dataset_tail"] = True +# datasplitter_kwargs["drop_last"] = False +# +# model = SCVI(adata) +# +# print("multi GPU SCVI train") +# model.train( +# max_epochs=1, +# check_val_every_n_epoch=1, +# accelerator="gpu", +# devices=-1, +# datasplitter_kwargs=datasplitter_kwargs, +# strategy="ddp_find_unused_parameters_true", +# ) +# print("done") +# torch.distributed.destroy_process_group() +# +# assert model.is_trained +# adata.obsm["scVI"] = model.get_latent_representation() +# +# datasplitter_kwargs = {} +# datasplitter_kwargs["distributed_sampler"] = True +# datasplitter_kwargs["drop_last"] = False +# +# print("multi GPU scanvi load from scvi model") +# model_scanvi = scvi.model.SCANVI.from_scvi_model( +# model, +# adata=adata, +# labels_key="labels", +# unlabeled_category="unknown", +# ) +# print("done") +# print("multi GPU scanvi train from scvi") +# model_scanvi.train( +# max_epochs=1, +# train_size=0.5, +# check_val_every_n_epoch=1, +# accelerator="gpu", +# devices=-1, +# strategy="ddp_find_unused_parameters_true", +# datasplitter_kwargs=datasplitter_kwargs, +# ) +# print("done") +# adata.obsm["scANVI"] = model_scanvi.get_latent_representation() +# +# torch.distributed.destroy_process_group() +# +# assert model_scanvi.is_trained +# +# +# @pytest.mark.multigpu +# # SCANVI FROM SCRATCH - reminder: its impossible to debug pytest multigpu work like this +# def test_scanvi_from_scratch_multigpu(): +# if torch.cuda.is_available(): +# adata = scvi.data.synthetic_iid() +# +# # SCVI.setup_anndata(adata) +# # +# # datasplitter_kwargs = {} +# # datasplitter_kwargs["drop_dataset_tail"] = True +# # datasplitter_kwargs["drop_last"] = False +# # +# # print("multi GPU train") +# # model.train( +# # max_epochs=1, +# # check_val_every_n_epoch=1, +# # accelerator="gpu", +# # devices=-1, +# # datasplitter_kwargs=datasplitter_kwargs, +# # strategy='ddp_find_unused_parameters_true' +# # ) +# # +# # torch.distributed.destroy_process_group() +# # +# # assert model.is_trained +# +# SCANVI.setup_anndata( +# adata, +# "labels", +# "unknown", +# batch_key="batch", +# ) +# +# datasplitter_kwargs = {} +# datasplitter_kwargs["drop_dataset_tail"] = True +# datasplitter_kwargs["drop_last"] = False +# +# model = SCANVI(adata, n_latent=10) +# +# print("multi GPU scanvi train from scracth") +# model.train( +# max_epochs=1, +# train_size=0.5, +# check_val_every_n_epoch=1, +# accelerator="gpu", +# devices=-1, +# datasplitter_kwargs=datasplitter_kwargs, +# strategy="ddp_find_unused_parameters_true", +# ) +# print("done") +# +# 
torch.distributed.destroy_process_group() +# +# assert model.is_trained @pytest.mark.multigpu From 7f6ac72be2489235d19b05308c03a5ed8e3fcc35 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Thu, 30 Jan 2025 16:59:07 +0200 Subject: [PATCH 09/16] change tests --- tests/model/test_multigpu.py | 217 ++++++++++++++++------------------- 1 file changed, 102 insertions(+), 115 deletions(-) diff --git a/tests/model/test_multigpu.py b/tests/model/test_multigpu.py index dac3730d79..0e930d06ea 100644 --- a/tests/model/test_multigpu.py +++ b/tests/model/test_multigpu.py @@ -4,123 +4,110 @@ import pytest import torch -# @pytest.mark.multigpu -# # SCANVI FROM SCVI - reminder: its impossible to debug pytest multigpu work like this -# def test_scanvi_from_scvi_multigpu(): -# if torch.cuda.is_available(): -# adata = scvi.data.synthetic_iid() -# -# SCVI.setup_anndata(adata) -# -# datasplitter_kwargs = {} -# datasplitter_kwargs["drop_dataset_tail"] = True -# datasplitter_kwargs["drop_last"] = False -# -# model = SCVI(adata) -# -# print("multi GPU SCVI train") -# model.train( -# max_epochs=1, -# check_val_every_n_epoch=1, -# accelerator="gpu", -# devices=-1, -# datasplitter_kwargs=datasplitter_kwargs, -# strategy="ddp_find_unused_parameters_true", -# ) -# print("done") -# torch.distributed.destroy_process_group() -# -# assert model.is_trained -# adata.obsm["scVI"] = model.get_latent_representation() -# -# datasplitter_kwargs = {} -# datasplitter_kwargs["distributed_sampler"] = True -# datasplitter_kwargs["drop_last"] = False -# -# print("multi GPU scanvi load from scvi model") -# model_scanvi = scvi.model.SCANVI.from_scvi_model( -# model, -# adata=adata, -# labels_key="labels", -# unlabeled_category="unknown", -# ) -# print("done") -# print("multi GPU scanvi train from scvi") -# model_scanvi.train( -# max_epochs=1, -# train_size=0.5, -# check_val_every_n_epoch=1, -# accelerator="gpu", -# devices=-1, -# strategy="ddp_find_unused_parameters_true", -# datasplitter_kwargs=datasplitter_kwargs, -# ) -# print("done") -# adata.obsm["scANVI"] = model_scanvi.get_latent_representation() -# -# torch.distributed.destroy_process_group() -# -# assert model_scanvi.is_trained -# -# -# @pytest.mark.multigpu -# # SCANVI FROM SCRATCH - reminder: its impossible to debug pytest multigpu work like this -# def test_scanvi_from_scratch_multigpu(): -# if torch.cuda.is_available(): -# adata = scvi.data.synthetic_iid() -# -# # SCVI.setup_anndata(adata) -# # -# # datasplitter_kwargs = {} -# # datasplitter_kwargs["drop_dataset_tail"] = True -# # datasplitter_kwargs["drop_last"] = False -# # -# # print("multi GPU train") -# # model.train( -# # max_epochs=1, -# # check_val_every_n_epoch=1, -# # accelerator="gpu", -# # devices=-1, -# # datasplitter_kwargs=datasplitter_kwargs, -# # strategy='ddp_find_unused_parameters_true' -# # ) -# # -# # torch.distributed.destroy_process_group() -# # -# # assert model.is_trained -# -# SCANVI.setup_anndata( -# adata, -# "labels", -# "unknown", -# batch_key="batch", -# ) -# -# datasplitter_kwargs = {} -# datasplitter_kwargs["drop_dataset_tail"] = True -# datasplitter_kwargs["drop_last"] = False -# -# model = SCANVI(adata, n_latent=10) -# -# print("multi GPU scanvi train from scracth") -# model.train( -# max_epochs=1, -# train_size=0.5, -# check_val_every_n_epoch=1, -# accelerator="gpu", -# devices=-1, -# datasplitter_kwargs=datasplitter_kwargs, -# strategy="ddp_find_unused_parameters_true", -# ) -# print("done") -# -# torch.distributed.destroy_process_group() -# -# assert model.is_trained + 
+@pytest.mark.multigpu +# SCANVI FROM SCVI - reminder: its impossible to debug pytest multigpu work like this +def test_scanvi_from_scvi_multigpu(): + if torch.cuda.is_available(): + import scvi + from scvi.model import SCVI + + adata = scvi.data.synthetic_iid() + + SCVI.setup_anndata(adata) + + datasplitter_kwargs = {} + datasplitter_kwargs["drop_dataset_tail"] = True + datasplitter_kwargs["drop_last"] = False + + model = SCVI(adata) + + print("multi GPU SCVI train") + model.train( + max_epochs=1, + check_val_every_n_epoch=1, + accelerator="gpu", + devices=-1, + datasplitter_kwargs=datasplitter_kwargs, + strategy="ddp_find_unused_parameters_true", + ) + print("done") + torch.distributed.destroy_process_group() + + assert model.is_trained + adata.obsm["scVI"] = model.get_latent_representation() + + datasplitter_kwargs = {} + datasplitter_kwargs["drop_dataset_tail"] = True + datasplitter_kwargs["drop_last"] = False + + print("multi GPU scanvi load from scvi model") + model_scanvi = scvi.model.SCANVI.from_scvi_model( + model, + adata=adata, + labels_key="labels", + unlabeled_category="unknown", + ) + print("done") + print("multi GPU scanvi train from scvi") + model_scanvi.train( + max_epochs=1, + train_size=0.5, + check_val_every_n_epoch=1, + accelerator="gpu", + devices=-1, + strategy="ddp_find_unused_parameters_true", + datasplitter_kwargs=datasplitter_kwargs, + ) + print("done") + adata.obsm["scANVI"] = model_scanvi.get_latent_representation() + + torch.distributed.destroy_process_group() + + assert model_scanvi.is_trained + + +@pytest.mark.multigpu +# SCANVI FROM SCRATCH - reminder: its impossible to debug pytest multigpu work like this +def test_scanvi_from_scratch_multigpu(): + if torch.cuda.is_available(): + import scvi + from scvi.model import SCANVI + + adata = scvi.data.synthetic_iid() + + SCANVI.setup_anndata( + adata, + "labels", + "unknown", + batch_key="batch", + ) + + datasplitter_kwargs = {} + datasplitter_kwargs["drop_dataset_tail"] = True + datasplitter_kwargs["drop_last"] = False + + model = SCANVI(adata, n_latent=10) + + print("multi GPU scanvi train from scracth") + model.train( + max_epochs=1, + train_size=0.5, + check_val_every_n_epoch=1, + accelerator="gpu", + devices=-1, + datasplitter_kwargs=datasplitter_kwargs, + strategy="ddp_find_unused_parameters_true", + ) + print("done") + + torch.distributed.destroy_process_group() + + assert model.is_trained @pytest.mark.multigpu -def test_scvi_train_ddp(save_path: str): +def test_scvi_train_ddp(save_path: str = "."): training_code = """ import torch import scvi @@ -173,7 +160,7 @@ def launch_ddp(world_size, temp_file_path): @pytest.mark.multigpu -def test_scanvi_train_ddp(save_path: str): +def test_scanvi_train_ddp(save_path: str = "."): training_code = """ import torch import scvi From 85300040d30b44400b2c7bd9558cc507bac5e537 Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Sun, 2 Feb 2025 16:34:59 +0200 Subject: [PATCH 10/16] re think the scanvi multi gpu problem: adding support for anndataloder to deal with labels as well as resample them as scanvi needs in the labelled case. adding the support for scanvi model distributed sampler. reuplaod the test and added a changlog. also validated tutorials for biological correctness. 
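For reference, a minimal usage sketch of the feature this commit enables. It mirrors the
test added in tests/model/test_multigpu.py; the synthetic data, label keys, and
datasplitter_kwargs values are taken from that test and are illustrative only, not a
prescribed configuration:

    import scvi
    from scvi.model import SCANVI

    adata = scvi.data.synthetic_iid()
    SCANVI.setup_anndata(adata, "labels", "unknown", batch_key="batch")
    model = SCANVI(adata, n_latent=10)

    # Multi-GPU training via DDP; the distributed sampler is enabled
    # automatically when a "ddp" strategy string is passed to train().
    model.train(
        max_epochs=1,
        train_size=0.5,
        check_val_every_n_epoch=1,
        accelerator="gpu",
        devices=-1,
        strategy="ddp_find_unused_parameters_true",
        datasplitter_kwargs={"drop_dataset_tail": True, "drop_last": False},
    )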
--- CHANGELOG.md | 2 +- src/scvi/dataloaders/_ann_dataloader.py | 60 +++++++++++++++++++++- src/scvi/dataloaders/_concat_dataloader.py | 15 ++++-- src/scvi/dataloaders/_data_splitting.py | 7 ++- src/scvi/dataloaders/_semi_dataloader.py | 4 +- src/scvi/model/_scanvi.py | 52 ++++++------------- src/scvi/model/_utils.py | 1 + 7 files changed, 96 insertions(+), 45 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b839befe3c..d458b6755c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ to [Semantic Versioning]. Full commit history is available in the - Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable representation learning in single-cell RNA sequencing data {pr}`3015`, {pr}`3091`. -- Add support for {class}`~scvi.model.SCANVI` multiGPU training, only for the full labeled case {pr}`3125`. +- Add support for {class}`~scvi.model.SCANVI` multiGPU training {pr}`3125`. - Add {class}`scvi.external.RESOLVI` for bias correction in single-cell resolved spatial transcriptomics {pr}`3144`. diff --git a/src/scvi/dataloaders/_ann_dataloader.py b/src/scvi/dataloaders/_ann_dataloader.py index 27e17302d5..1673c8f260 100644 --- a/src/scvi/dataloaders/_ann_dataloader.py +++ b/src/scvi/dataloaders/_ann_dataloader.py @@ -10,8 +10,9 @@ SequentialSampler, ) -from scvi import settings +from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager +from scvi.data._utils import get_anndata_attribute from ._samplers import BatchDistributedSampler @@ -56,6 +57,8 @@ class AnnDataLoader(DataLoader): distributed_sampler ``EXPERIMENTAL`` Whether to use :class:`~scvi.dataloaders.BatchDistributedSampler` as the sampler. If `True`, `sampler` must be `None`. + n_samples_per_label + Number of subsamples for each label class to sample per epoch load_sparse_tensor ``EXPERIMENTAL`` If ``True``, loads data with sparse CSR or CSC layout as a :class:`~torch.Tensor` with the same layout. 
Can lead to speedups in data transfers to @@ -83,6 +86,7 @@ def __init__( data_and_attributes: list[str] | dict[str, np.dtype] | None = None, iter_ndarray: bool = False, distributed_sampler: bool = False, + n_samples_per_label: int | None = None, load_sparse_tensor: bool = False, **kwargs, ): @@ -93,6 +97,7 @@ def __init__( indices = np.where(indices)[0].ravel() indices = np.asarray(indices) self.indices = indices + self.n_samples_per_label = n_samples_per_label self.dataset = adata_manager.create_torch_dataset( indices=indices, data_and_attributes=data_and_attributes, @@ -104,10 +109,32 @@ def __init__( kwargs["persistent_workers"] = settings.dl_persistent_workers self.kwargs = copy.deepcopy(kwargs) + self.adata_manager = adata_manager + self.data_and_attributes = data_and_attributes + self._shuffle = shuffle + self._batch_size = batch_size + self._drop_last = drop_last + self.load_sparse_tensor = load_sparse_tensor if sampler is not None and distributed_sampler: raise ValueError("Cannot specify both `sampler` and `distributed_sampler`.") + # Next block of code is for the case of fully labeled anndataloder + labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) + labels = get_anndata_attribute( + adata_manager.adata, + adata_manager.data_registry.labels.attr_name, + labels_state_registry.original_key, + ).ravel() + if hasattr(labels_state_registry, "unlabeled_category"): + # save a nested list of the indices per labeled category (if exists) + self.labeled_locs = [] + for label in np.unique(labels): + if label != labels_state_registry.unlabeled_category: + label_loc_idx = np.where(labels[indices] == label)[0] + label_loc = self.indices[label_loc_idx] + self.labeled_locs.append(label_loc) + # custom sampler for efficient minibatching on sparse matrices if sampler is None: if not distributed_sampler: @@ -136,3 +163,34 @@ def __init__( self.kwargs.update({"collate_fn": lambda x: x}) super().__init__(self.dataset, **self.kwargs) + + def resample_labels(self): + """Resamples the labeled data.""" + self.kwargs.pop("batch_size", None) + self.kwargs.pop("shuffle", None) + self.kwargs.pop("sampler", None) + self.kwargs.pop("collate_fn", None) + AnnDataLoader( + self.adata_manager, + indices=self.subsample_labels(), + shuffle=self._shuffle, + batch_size=self._batch_size, + data_and_attributes=self.data_and_attributes, + drop_last=self._drop_last, + **self.kwargs, + ) + + def subsample_labels(self): + """Subsamples each label class by taking up to n_samples_per_label samples per class.""" + if self.n_samples_per_label is None: + return np.concatenate(self.labeled_locs) + + sample_idx = [] + for loc in self.labeled_locs: + if len(loc) < self.n_samples_per_label: + sample_idx.append(loc) + else: + label_subset = np.random.choice(loc, self.n_samples_per_label, replace=False) + sample_idx.append(label_subset) + sample_idx = np.concatenate(sample_idx) + return sample_idx diff --git a/src/scvi/dataloaders/_concat_dataloader.py b/src/scvi/dataloaders/_concat_dataloader.py index fdcdea4aa8..7242d49108 100644 --- a/src/scvi/dataloaders/_concat_dataloader.py +++ b/src/scvi/dataloaders/_concat_dataloader.py @@ -48,7 +48,7 @@ def __init__( **data_loader_kwargs, ): self.adata_manager = adata_manager - self.dataloader_kwargs = data_loader_kwargs + self.data_loader_kwargs = data_loader_kwargs self.data_and_attributes = data_and_attributes self._shuffle = shuffle self._batch_size = batch_size @@ -57,6 +57,7 @@ def __init__( self.dataloaders = [] for indices in indices_list: + 
self.data_loader_kwargs.pop("sampler", None) self.dataloaders.append( AnnDataLoader( adata_manager, @@ -66,11 +67,12 @@ def __init__( data_and_attributes=data_and_attributes, drop_last=drop_last, distributed_sampler=distributed_sampler, - **self.dataloader_kwargs, + **self.data_loader_kwargs, ) ) lens = [len(dl) for dl in self.dataloaders] self.largest_dl = self.dataloaders[np.argmax(lens)] + self.data_loader_kwargs.pop("drop_dataset_tail", None) super().__init__(self.largest_dl, **data_loader_kwargs) def __len__(self): @@ -83,5 +85,10 @@ def __iter__(self): the data in the other dataloaders. The order of data in returned iter_list is the same as indices_list. """ - iter_list = [cycle(dl) if dl != self.largest_dl else dl for dl in self.dataloaders] - return zip(*iter_list, strict=True) + if not self._distributed_sampler: + iter_list = [cycle(dl) if dl != self.largest_dl else dl for dl in self.dataloaders] + strict = True + else: + iter_list = self.dataloaders + strict = False + return zip(*iter_list, strict=strict) diff --git a/src/scvi/dataloaders/_data_splitting.py b/src/scvi/dataloaders/_data_splitting.py index 9ea0146acb..371de8fab8 100644 --- a/src/scvi/dataloaders/_data_splitting.py +++ b/src/scvi/dataloaders/_data_splitting.py @@ -516,11 +516,16 @@ def setup(self, stage: str | None = None): self.test_idx = indices_test.astype(int) if len(self._labeled_indices) != 0: - self.data_loader_class = SemiSupervisedDataLoader + if len(self._unlabeled_indices) != 0: + self.data_loader_class = SemiSupervisedDataLoader + else: + # data is all labeled + self.data_loader_class = AnnDataLoader dl_kwargs = { "n_samples_per_label": self.n_samples_per_label, } else: + # data is all unlabeled self.data_loader_class = AnnDataLoader dl_kwargs = {} diff --git a/src/scvi/dataloaders/_semi_dataloader.py b/src/scvi/dataloaders/_semi_dataloader.py index 545f3c6d9b..6e8b5d937f 100644 --- a/src/scvi/dataloaders/_semi_dataloader.py +++ b/src/scvi/dataloaders/_semi_dataloader.py @@ -53,6 +53,7 @@ def __init__( return None self.n_samples_per_label = n_samples_per_label + self.data_loader_kwargs = data_loader_kwargs labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) labels = get_anndata_attribute( @@ -77,7 +78,7 @@ def __init__( batch_size=batch_size, data_and_attributes=data_and_attributes, drop_last=drop_last, - **data_loader_kwargs, + **self.data_loader_kwargs, ) def resample_labels(self): @@ -93,6 +94,7 @@ def resample_labels(self): batch_size=self._batch_size, data_and_attributes=self.data_and_attributes, drop_last=self._drop_last, + **self.data_loader_kwargs, ) def subsample_labels(self): diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 72d9e8f70a..5a43075754 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -24,7 +24,7 @@ NumericalJointObsField, NumericalObsField, ) -from scvi.dataloaders import DataSplitter, SemiSupervisedDataSplitter +from scvi.dataloaders import SemiSupervisedDataSplitter from scvi.model._utils import _init_library_size, get_max_epochs_heuristic, use_distributed_sampler from scvi.module import SCANVAE from scvi.train import SemiSupervisedTrainingPlan, TrainRunner @@ -408,43 +408,21 @@ def train( # if we have labeled cells, we want to subsample labels each epoch sampler_callback = [SubSampleLabels()] if len(self._labeled_indices) != 0 else [] - # Create the SCVI regular dataplitter in case of multigpu plausioble - if (len(self._unlabeled_indices) == 0) and use_distributed_sampler( - 
trainer_kwargs.get("strategy", None) - ): - # we are in a multigpu env and its all labeled - n_samples_per_label is not used here - # also I remvoed the option for load_sparse_tensor it can be reached from kwargs - # in this way we bypass the use of concatdataloaders for now - run_multi_gpu = True - data_splitter = DataSplitter( - self.adata_manager, - train_size=train_size, - validation_size=validation_size, - batch_size=batch_size, - shuffle_set_split=shuffle_set_split, - distributed_sampler=True, - **datasplitter_kwargs, - ) - else: - # what we had so far in scanvi (concat dataloaders) - run_multi_gpu = False - data_splitter = SemiSupervisedDataSplitter( - adata_manager=self.adata_manager, - train_size=train_size, - validation_size=validation_size, - shuffle_set_split=shuffle_set_split, - n_samples_per_label=n_samples_per_label, - batch_size=batch_size, - **datasplitter_kwargs, - ) - + data_splitter = SemiSupervisedDataSplitter( + adata_manager=self.adata_manager, + train_size=train_size, + validation_size=validation_size, + shuffle_set_split=shuffle_set_split, + n_samples_per_label=n_samples_per_label, + distributed_sampler=use_distributed_sampler(trainer_kwargs.get("strategy", None)), + batch_size=batch_size, + **datasplitter_kwargs, + ) training_plan = self._training_plan_cls(self.module, self.n_labels, **plan_kwargs) - if not run_multi_gpu: - # TODO: how do we generate subsamples per class in multigpu? - if "callbacks" in trainer_kwargs.keys(): - trainer_kwargs["callbacks"] + [sampler_callback] - else: - trainer_kwargs["callbacks"] = sampler_callback + if "callbacks" in trainer_kwargs.keys(): + trainer_kwargs["callbacks"] + [sampler_callback] + else: + trainer_kwargs["callbacks"] = sampler_callback runner = TrainRunner( self, diff --git a/src/scvi/model/_utils.py b/src/scvi/model/_utils.py index 20dafdb1da..51db4f098e 100644 --- a/src/scvi/model/_utils.py +++ b/src/scvi/model/_utils.py @@ -29,6 +29,7 @@ def use_distributed_sampler(strategy: str | Strategy) -> bool: """ if isinstance(strategy, str): # ["ddp", "ddp_spawn", "ddp_find_unused_parameters_true"] + # ["ddp_notebook","ddp_notebook_find_unused_parameters_true"] - for jupyter nb run return "ddp" in strategy return isinstance(strategy, DDPStrategy) From fa0a77ff4f3bc7b126cccae1a938ec8aa4182ddb Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Sun, 2 Feb 2025 17:13:01 +0200 Subject: [PATCH 11/16] limit the anndataloader fix to scanvi only (other models will need more work) --- src/scvi/dataloaders/_ann_dataloader.py | 31 +++++++++++++------------ 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/scvi/dataloaders/_ann_dataloader.py b/src/scvi/dataloaders/_ann_dataloader.py index 1673c8f260..14a71c6657 100644 --- a/src/scvi/dataloaders/_ann_dataloader.py +++ b/src/scvi/dataloaders/_ann_dataloader.py @@ -119,21 +119,22 @@ def __init__( if sampler is not None and distributed_sampler: raise ValueError("Cannot specify both `sampler` and `distributed_sampler`.") - # Next block of code is for the case of fully labeled anndataloder - labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) - labels = get_anndata_attribute( - adata_manager.adata, - adata_manager.data_registry.labels.attr_name, - labels_state_registry.original_key, - ).ravel() - if hasattr(labels_state_registry, "unlabeled_category"): - # save a nested list of the indices per labeled category (if exists) - self.labeled_locs = [] - for label in np.unique(labels): - if label != labels_state_registry.unlabeled_category: - 
label_loc_idx = np.where(labels[indices] == label)[0] - label_loc = self.indices[label_loc_idx] - self.labeled_locs.append(label_loc) + # Next block of code is for the case of labeled anndataloder used in scanvi multigpu: + if adata_manager.registry["model_name"] == "SCANVI": + labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) + labels = get_anndata_attribute( + adata_manager.adata, + adata_manager.data_registry.labels.attr_name, + labels_state_registry.original_key, + ).ravel() + if hasattr(labels_state_registry, "unlabeled_category"): + # save a nested list of the indices per labeled category (if exists) + self.labeled_locs = [] + for label in np.unique(labels): + if label != labels_state_registry.unlabeled_category: + label_loc_idx = np.where(labels[indices] == label)[0] + label_loc = self.indices[label_loc_idx] + self.labeled_locs.append(label_loc) # custom sampler for efficient minibatching on sparse matrices if sampler is None: From 1c2ea3c582b031686858ded8a5825c4ad907d6d0 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Tue, 4 Feb 2025 00:41:44 +0200 Subject: [PATCH 12/16] cleaning code --- docs/tutorials/notebooks | 2 +- src/scvi/dataloaders/_ann_dataloader.py | 39 +++--------- src/scvi/dataloaders/_concat_dataloader.py | 2 +- src/scvi/dataloaders/_semi_dataloader.py | 72 ++++++++++++---------- tests/dataloaders/test_dataloaders.py | 2 +- 5 files changed, 51 insertions(+), 66 deletions(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index a93b59ad4d..feb577cecc 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit a93b59ad4d1e38477ab0bce4890a6ba7587bf4d0 +Subproject commit feb577cecc2f1873027724a4e56af2b6ff31921e diff --git a/src/scvi/dataloaders/_ann_dataloader.py b/src/scvi/dataloaders/_ann_dataloader.py index 14a71c6657..993429eef9 100644 --- a/src/scvi/dataloaders/_ann_dataloader.py +++ b/src/scvi/dataloaders/_ann_dataloader.py @@ -10,11 +10,11 @@ SequentialSampler, ) -from scvi import REGISTRY_KEYS, settings +from scvi import settings from scvi.data import AnnDataManager -from scvi.data._utils import get_anndata_attribute from ._samplers import BatchDistributedSampler +from ._semi_dataloader import labelled_indices_generator, subsample_labels logger = logging.getLogger(__name__) @@ -120,21 +120,11 @@ def __init__( raise ValueError("Cannot specify both `sampler` and `distributed_sampler`.") # Next block of code is for the case of labeled anndataloder used in scanvi multigpu: + self.labeled_locs, labelled_idx = [], [] if adata_manager.registry["model_name"] == "SCANVI": - labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) - labels = get_anndata_attribute( - adata_manager.adata, - adata_manager.data_registry.labels.attr_name, - labels_state_registry.original_key, - ).ravel() - if hasattr(labels_state_registry, "unlabeled_category"): - # save a nested list of the indices per labeled category (if exists) - self.labeled_locs = [] - for label in np.unique(labels): - if label != labels_state_registry.unlabeled_category: - label_loc_idx = np.where(labels[indices] == label)[0] - label_loc = self.indices[label_loc_idx] - self.labeled_locs.append(label_loc) + self.labeled_locs, labelled_idx = labelled_indices_generator( + adata_manager, indices, self.indices, self.n_samples_per_label + ) # custom sampler for efficient minibatching on sparse matrices if sampler is None: @@ -173,25 +163,10 @@ def resample_labels(self): self.kwargs.pop("collate_fn", None) 
AnnDataLoader( self.adata_manager, - indices=self.subsample_labels(), + indices=subsample_labels(self.labeled_locs, self.n_samples_per_label), shuffle=self._shuffle, batch_size=self._batch_size, data_and_attributes=self.data_and_attributes, drop_last=self._drop_last, **self.kwargs, ) - - def subsample_labels(self): - """Subsamples each label class by taking up to n_samples_per_label samples per class.""" - if self.n_samples_per_label is None: - return np.concatenate(self.labeled_locs) - - sample_idx = [] - for loc in self.labeled_locs: - if len(loc) < self.n_samples_per_label: - sample_idx.append(loc) - else: - label_subset = np.random.choice(loc, self.n_samples_per_label, replace=False) - sample_idx.append(label_subset) - sample_idx = np.concatenate(sample_idx) - return sample_idx diff --git a/src/scvi/dataloaders/_concat_dataloader.py b/src/scvi/dataloaders/_concat_dataloader.py index 7242d49108..5553627d41 100644 --- a/src/scvi/dataloaders/_concat_dataloader.py +++ b/src/scvi/dataloaders/_concat_dataloader.py @@ -73,7 +73,7 @@ def __init__( lens = [len(dl) for dl in self.dataloaders] self.largest_dl = self.dataloaders[np.argmax(lens)] self.data_loader_kwargs.pop("drop_dataset_tail", None) - super().__init__(self.largest_dl, **data_loader_kwargs) + super().__init__(self.largest_dl, **self.data_loader_kwargs) def __len__(self): return len(self.largest_dl) diff --git a/src/scvi/dataloaders/_semi_dataloader.py b/src/scvi/dataloaders/_semi_dataloader.py index 6e8b5d937f..a09bf79356 100644 --- a/src/scvi/dataloaders/_semi_dataloader.py +++ b/src/scvi/dataloaders/_semi_dataloader.py @@ -8,6 +8,43 @@ from ._concat_dataloader import ConcatDataLoader +def subsample_labels(labeled_locs, n_samples_per_label): + """Subsamples each label class by taking up to n_samples_per_label samples per class.""" + if n_samples_per_label is None: + return np.concatenate(labeled_locs) + + sample_idx = [] + for loc in labeled_locs: + if len(loc) < n_samples_per_label: + sample_idx.append(loc) + else: + label_subset = np.random.choice(loc, n_samples_per_label, replace=False) + sample_idx.append(label_subset) + sample_idx = np.concatenate(sample_idx) + return sample_idx + + +def labelled_indices_generator(adata_manager, indices, indices_asarray, n_samples_per_label): + # Next block of code is for the case of labeled anndataloder used in scanvi multigpu: + labelled_idx = [] + labeled_locs = [] + labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) + labels = get_anndata_attribute( + adata_manager.adata, + adata_manager.data_registry.labels.attr_name, + labels_state_registry.original_key, + ).ravel() + if hasattr(labels_state_registry, "unlabeled_category"): + # save a nested list of the indices per labeled category (if exists) + for label in np.unique(labels): + if label != labels_state_registry.unlabeled_category: + label_loc_idx = np.where(labels[indices] == label)[0] + label_loc = indices_asarray[label_loc_idx] + labeled_locs.append(label_loc) + labelled_idx = subsample_labels(labeled_locs, n_samples_per_label) + return labeled_locs, labelled_idx + + class SemiSupervisedDataLoader(ConcatDataLoader): """DataLoader that supports semisupervised training. 
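[Editor's note] For readers following the refactor above: the per-label subsampling that the new module-level helpers perform can be reproduced in isolation. The snippet below is a standalone re-implementation for illustration only (plain NumPy, made-up label index groups in place of real AnnData-derived indices); it is not part of the patch.

import numpy as np

def subsample_per_label(labeled_locs, n_samples_per_label):
    # Mirrors the subsample_labels helper above: keep at most
    # n_samples_per_label indices per labeled class, all of them
    # when a class has fewer.
    if n_samples_per_label is None:
        return np.concatenate(labeled_locs)
    sample_idx = []
    for loc in labeled_locs:
        if len(loc) < n_samples_per_label:
            sample_idx.append(loc)
        else:
            sample_idx.append(np.random.choice(loc, n_samples_per_label, replace=False))
    return np.concatenate(sample_idx)

# Hypothetical label groups: cell indices for three labeled classes of size 40, 5 and 55.
labeled_locs = [np.arange(0, 40), np.arange(40, 45), np.arange(45, 100)]
print(subsample_per_label(labeled_locs, n_samples_per_label=10).shape)  # (25,) = 10 + 5 + 10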
@@ -55,21 +92,9 @@ def __init__( self.n_samples_per_label = n_samples_per_label self.data_loader_kwargs = data_loader_kwargs - labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) - labels = get_anndata_attribute( - adata_manager.adata, - adata_manager.data_registry.labels.attr_name, - labels_state_registry.original_key, - ).ravel() - - # save a nested list of the indices per labeled category - self.labeled_locs = [] - for label in np.unique(labels): - if label != labels_state_registry.unlabeled_category: - label_loc_idx = np.where(labels[indices] == label)[0] - label_loc = self.indices[label_loc_idx] - self.labeled_locs.append(label_loc) - labelled_idx = self.subsample_labels() + self.labeled_locs, labelled_idx = labelled_indices_generator( + adata_manager, indices, self.indices, self.n_samples_per_label + ) super().__init__( adata_manager=adata_manager, @@ -83,7 +108,7 @@ def __init__( def resample_labels(self): """Resamples the labeled data.""" - labelled_idx = self.subsample_labels() + labelled_idx = subsample_labels(self.labeled_locs, self.n_samples_per_label) # self.dataloaders[0] iterates over full_indices # self.dataloaders[1] iterates over the labelled_indices # change the indices of the labelled set @@ -96,18 +121,3 @@ def resample_labels(self): drop_last=self._drop_last, **self.data_loader_kwargs, ) - - def subsample_labels(self): - """Subsamples each label class by taking up to n_samples_per_label samples per class.""" - if self.n_samples_per_label is None: - return np.concatenate(self.labeled_locs) - - sample_idx = [] - for loc in self.labeled_locs: - if len(loc) < self.n_samples_per_label: - sample_idx.append(loc) - else: - label_subset = np.random.choice(loc, self.n_samples_per_label, replace=False) - sample_idx.append(label_subset) - sample_idx = np.concatenate(sample_idx) - return sample_idx diff --git a/tests/dataloaders/test_dataloaders.py b/tests/dataloaders/test_dataloaders.py index a224b65d8e..c1e96c6786 100644 --- a/tests/dataloaders/test_dataloaders.py +++ b/tests/dataloaders/test_dataloaders.py @@ -134,7 +134,7 @@ def test_anndataloader_distributed_sampler(num_processes: int, save_path: str): @pytest.mark.multigpu -@pytest.mark.parametrize("num_processes", [1]) +@pytest.mark.parametrize("num_processes", [1, 2]) def test_scanvi_with_distributed_sampler(num_processes: int, save_path: str): adata = scvi.data.synthetic_iid() SCANVI.setup_anndata( From 60de4756a7674b8f167b41c1226fdebf622af3f7 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Tue, 4 Feb 2025 00:52:41 +0200 Subject: [PATCH 13/16] fix circular bug --- src/scvi/dataloaders/_ann_dataloader.py | 42 ++++++++++++++++++++++-- src/scvi/dataloaders/_semi_dataloader.py | 41 +---------------------- 2 files changed, 41 insertions(+), 42 deletions(-) diff --git a/src/scvi/dataloaders/_ann_dataloader.py b/src/scvi/dataloaders/_ann_dataloader.py index 993429eef9..83669d1eab 100644 --- a/src/scvi/dataloaders/_ann_dataloader.py +++ b/src/scvi/dataloaders/_ann_dataloader.py @@ -10,15 +10,52 @@ SequentialSampler, ) -from scvi import settings +from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager +from scvi.data._utils import get_anndata_attribute from ._samplers import BatchDistributedSampler -from ._semi_dataloader import labelled_indices_generator, subsample_labels logger = logging.getLogger(__name__) +def subsample_labels(labeled_locs, n_samples_per_label): + """Subsamples each label class by taking up to n_samples_per_label samples per class.""" + if 
n_samples_per_label is None: + return np.concatenate(labeled_locs) + + sample_idx = [] + for loc in labeled_locs: + if len(loc) < n_samples_per_label: + sample_idx.append(loc) + else: + label_subset = np.random.choice(loc, n_samples_per_label, replace=False) + sample_idx.append(label_subset) + sample_idx = np.concatenate(sample_idx) + return sample_idx + + +def labelled_indices_generator(adata_manager, indices, indices_asarray, n_samples_per_label): + """Generates indices for each label class""" + labelled_idx = [] + labeled_locs = [] + labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) + labels = get_anndata_attribute( + adata_manager.adata, + adata_manager.data_registry.labels.attr_name, + labels_state_registry.original_key, + ).ravel() + if hasattr(labels_state_registry, "unlabeled_category"): + # save a nested list of the indices per labeled category (if exists) + for label in np.unique(labels): + if label != labels_state_registry.unlabeled_category: + label_loc_idx = np.where(labels[indices] == label)[0] + label_loc = indices_asarray[label_loc_idx] + labeled_locs.append(label_loc) + labelled_idx = subsample_labels(labeled_locs, n_samples_per_label) + return labeled_locs, labelled_idx + + class AnnDataLoader(DataLoader): """DataLoader for loading tensors from AnnData objects. @@ -122,6 +159,7 @@ def __init__( # Next block of code is for the case of labeled anndataloder used in scanvi multigpu: self.labeled_locs, labelled_idx = [], [] if adata_manager.registry["model_name"] == "SCANVI": + # Next block of code is for the case of labeled anndataloder used in scanvi multigpu: self.labeled_locs, labelled_idx = labelled_indices_generator( adata_manager, indices, self.indices, self.n_samples_per_label ) diff --git a/src/scvi/dataloaders/_semi_dataloader.py b/src/scvi/dataloaders/_semi_dataloader.py index a09bf79356..2241fb82df 100644 --- a/src/scvi/dataloaders/_semi_dataloader.py +++ b/src/scvi/dataloaders/_semi_dataloader.py @@ -1,50 +1,11 @@ import numpy as np -from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager -from scvi.data._utils import get_anndata_attribute -from ._ann_dataloader import AnnDataLoader +from ._ann_dataloader import AnnDataLoader, labelled_indices_generator, subsample_labels from ._concat_dataloader import ConcatDataLoader -def subsample_labels(labeled_locs, n_samples_per_label): - """Subsamples each label class by taking up to n_samples_per_label samples per class.""" - if n_samples_per_label is None: - return np.concatenate(labeled_locs) - - sample_idx = [] - for loc in labeled_locs: - if len(loc) < n_samples_per_label: - sample_idx.append(loc) - else: - label_subset = np.random.choice(loc, n_samples_per_label, replace=False) - sample_idx.append(label_subset) - sample_idx = np.concatenate(sample_idx) - return sample_idx - - -def labelled_indices_generator(adata_manager, indices, indices_asarray, n_samples_per_label): - # Next block of code is for the case of labeled anndataloder used in scanvi multigpu: - labelled_idx = [] - labeled_locs = [] - labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) - labels = get_anndata_attribute( - adata_manager.adata, - adata_manager.data_registry.labels.attr_name, - labels_state_registry.original_key, - ).ravel() - if hasattr(labels_state_registry, "unlabeled_category"): - # save a nested list of the indices per labeled category (if exists) - for label in np.unique(labels): - if label != labels_state_registry.unlabeled_category: - label_loc_idx = 
np.where(labels[indices] == label)[0] - label_loc = indices_asarray[label_loc_idx] - labeled_locs.append(label_loc) - labelled_idx = subsample_labels(labeled_locs, n_samples_per_label) - return labeled_locs, labelled_idx - - class SemiSupervisedDataLoader(ConcatDataLoader): """DataLoader that supports semisupervised training. From 9c55311e9565640a1ec23ac5667fdc8016ef5e89 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Tue, 4 Feb 2025 01:12:25 +0200 Subject: [PATCH 14/16] fix test --- src/scvi/dataloaders/_ann_dataloader.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/scvi/dataloaders/_ann_dataloader.py b/src/scvi/dataloaders/_ann_dataloader.py index 83669d1eab..1105638db5 100644 --- a/src/scvi/dataloaders/_ann_dataloader.py +++ b/src/scvi/dataloaders/_ann_dataloader.py @@ -22,7 +22,10 @@ def subsample_labels(labeled_locs, n_samples_per_label): """Subsamples each label class by taking up to n_samples_per_label samples per class.""" if n_samples_per_label is None: - return np.concatenate(labeled_locs) + if len(labeled_locs) == 0: + return labeled_locs + else: + return np.concatenate(labeled_locs) sample_idx = [] for loc in labeled_locs: From 28ef89b4b2bb788637a2e1847d34719b87ae67ec Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Tue, 4 Feb 2025 09:15:50 +0200 Subject: [PATCH 15/16] update notebooks link (?)) --- docs/tutorials/notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index feb577cecc..4f851184e3 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit feb577cecc2f1873027724a4e56af2b6ff31921e +Subproject commit 4f851184e3db1665c8ffa6295c111418d32cc59f From dd7547652c1d9dda469e41b32ec1c48ad98eb1c3 Mon Sep 17 00:00:00 2001 From: Ori Kronfeld Date: Tue, 4 Feb 2025 18:00:07 +0200 Subject: [PATCH 16/16] validate more models to run with multi GPU: condscvi, linearscvi, peakvi, totalvi and multivi (without early stopping callbacks for those last3) --- CHANGELOG.md | 5 +- src/scvi/model/_multivi.py | 2 + src/scvi/model/_totalvi.py | 2 + tests/model/test_multigpu.py | 149 ++++++++++++++++++++++++++++++++--- 4 files changed, 146 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d458b6755c..e5d9d01958 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,10 @@ to [Semantic Versioning]. Full commit history is available in the - Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable representation learning in single-cell RNA sequencing data {pr}`3015`, {pr}`3091`. -- Add support for {class}`~scvi.model.SCANVI` multiGPU training {pr}`3125`. +- Add support for {class}`~scvi.model.SCANVI`, {class}`~scvi.model.CondSCVI` and + {class}`~scvi.model.LinearSCVI` multiGPU training {pr}`3125`. Also added this support for + {class}`~scvi.model.TOTALVI`, {class}`~scvi.model.MULTIVI` and {class}`~scvi.model.PEAKVI` + but need to disable early_stopping first in order to use multiGPU for those models. - Add {class}`scvi.external.RESOLVI` for bias correction in single-cell resolved spatial transcriptomics {pr}`3144`. 
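[Editor's note] The CHANGELOG entry above states that TOTALVI, MULTIVI and PEAKVI can now train on multiple GPUs only when early stopping is disabled. A minimal usage sketch of that pattern is shown here for PEAKVI; it assumes a machine with more than one CUDA device, and the dataset and epoch count are placeholders taken from the tests added further down in this patch. It is illustrative, not part of the diff.

import scvi
from scvi.model import PEAKVI

adata = scvi.data.synthetic_iid()
PEAKVI.setup_anndata(adata, batch_key="batch")

model = PEAKVI(adata)
model.train(
    max_epochs=10,
    early_stopping=False,  # early-stopping callbacks are not supported with DDP here yet
    save_best=False,       # checkpoint-saving callback likewise disabled
    accelerator="gpu",
    devices=-1,            # use all visible GPUs
    strategy="ddp_find_unused_parameters_true",
)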
diff --git a/src/scvi/model/_multivi.py b/src/scvi/model/_multivi.py index 61ffb04556..3a657b1eac 100644 --- a/src/scvi/model/_multivi.py +++ b/src/scvi/model/_multivi.py @@ -27,6 +27,7 @@ _get_batch_code_from_category, scatac_raw_counts_properties, scrna_raw_counts_properties, + use_distributed_sampler, ) from scvi.model.base import ( ArchesMixin, @@ -348,6 +349,7 @@ def train( train_size=train_size, validation_size=validation_size, shuffle_set_split=shuffle_set_split, + distributed_sampler=use_distributed_sampler(kwargs.get("strategy", None)), batch_size=batch_size, **datasplitter_kwargs, ) diff --git a/src/scvi/model/_totalvi.py b/src/scvi/model/_totalvi.py index d23c533da6..46af6ac7f5 100644 --- a/src/scvi/model/_totalvi.py +++ b/src/scvi/model/_totalvi.py @@ -22,6 +22,7 @@ _init_library_size, cite_seq_raw_counts_properties, get_max_epochs_heuristic, + use_distributed_sampler, ) from scvi.model.base._de_core import _de_core from scvi.module import TOTALVAE @@ -322,6 +323,7 @@ def train( validation_size=validation_size, shuffle_set_split=shuffle_set_split, batch_size=batch_size, + distributed_sampler=use_distributed_sampler(kwargs.get("strategy", None)), external_indexing=external_indexing, **datasplitter_kwargs, ) diff --git a/tests/model/test_multigpu.py b/tests/model/test_multigpu.py index 0e930d06ea..0f1bf0b2a3 100644 --- a/tests/model/test_multigpu.py +++ b/tests/model/test_multigpu.py @@ -3,6 +3,10 @@ import pytest import torch +from mudata import MuData + +import scvi +from scvi.model import MULTIVI, PEAKVI, TOTALVI, CondSCVI, LinearSCVI @pytest.mark.multigpu @@ -32,7 +36,6 @@ def test_scanvi_from_scvi_multigpu(): strategy="ddp_find_unused_parameters_true", ) print("done") - torch.distributed.destroy_process_group() assert model.is_trained adata.obsm["scVI"] = model.get_latent_representation() @@ -62,8 +65,6 @@ def test_scanvi_from_scvi_multigpu(): print("done") adata.obsm["scANVI"] = model_scanvi.get_latent_representation() - torch.distributed.destroy_process_group() - assert model_scanvi.is_trained @@ -101,13 +102,143 @@ def test_scanvi_from_scratch_multigpu(): ) print("done") - torch.distributed.destroy_process_group() - assert model.is_trained @pytest.mark.multigpu -def test_scvi_train_ddp(save_path: str = "."): +def test_totalvi_multigpu(): + adata = scvi.data.synthetic_iid() + protein_adata = scvi.data.synthetic_iid(n_genes=50) + mdata = MuData({"rna": adata, "protein": protein_adata}) + # TOTALVI.setup_anndata( + # adata, + # batch_key="batch", + # protein_expression_obsm_key="protein_expression", + # protein_names_uns_key="protein_names", + # ) + TOTALVI.setup_mudata( + mdata, + batch_key="batch", + modalities={"rna_layer": "rna", "batch_key": "rna", "protein_layer": "protein"}, + ) + n_latent = 10 + model = TOTALVI(mdata, n_latent=n_latent) + model.train( + 100, + train_size=0.5, + early_stopping=False, + check_val_every_n_epoch=1, + accelerator="gpu", + devices=-1, + strategy="ddp_find_unused_parameters_true", + ) + assert model.is_trained is True + + +@pytest.mark.multigpu +def test_multivi_multigpu(): + # adata = scvi.data.synthetic_iid() + mdata = scvi.data.synthetic_iid(return_mudata=True) + # MULTIVI.setup_anndata( + # adata, + # batch_key="batch", + # ) + MULTIVI.setup_mudata( + mdata, + batch_key="batch", + modalities={ + "rna_layer": "rna", + "protein_layer": "protein_expression", + "atac_layer": "accessibility", + }, + ) + n_latent = 10 + model = MULTIVI( + mdata, + n_latent=n_latent, + n_genes=50, + n_regions=50, + ) + model.train( + 100, + train_size=0.5, + 
early_stopping=False, + check_val_every_n_epoch=1, + accelerator="gpu", + devices=-1, + strategy="ddp_find_unused_parameters_true", + ) + assert model.is_trained is True + + +@pytest.mark.multigpu +def test_peakvi_multigpu(): + adata = scvi.data.synthetic_iid() + PEAKVI.setup_anndata( + adata, + batch_key="batch", + ) + + model = PEAKVI( + adata, + model_depth=False, + ) + + model.train( + max_epochs=100, + train_size=0.5, + check_val_every_n_epoch=1, + early_stopping=False, + accelerator="gpu", + devices=-1, + strategy="ddp_find_unused_parameters_true", + save_best=False, + ) + assert model.is_trained + + +@pytest.mark.multigpu +def test_condscvi_multigpu(): + adata = scvi.data.synthetic_iid() + adata.obs["overclustering_vamp"] = list(range(adata.n_obs)) + CondSCVI.setup_anndata( + adata, + labels_key="labels", + ) + + model = CondSCVI(adata) + + model.train( + max_epochs=100, + train_size=0.9, + check_val_every_n_epoch=1, + accelerator="gpu", + devices=-1, + strategy="ddp_find_unused_parameters_true", + ) + assert model.is_trained + + +@pytest.mark.multigpu +def test_linearcvi_multigpu(): + adata = scvi.data.synthetic_iid() + adata = adata[:, :10].copy() + LinearSCVI.setup_anndata(adata) + model = LinearSCVI(adata, n_latent=10) + + model.train( + max_epochs=100, + train_size=0.5, + check_val_every_n_epoch=1, + accelerator="gpu", + devices=-1, + strategy="ddp_find_unused_parameters_true", + ) + assert model.is_trained + + +@pytest.mark.multigpu +def test_scvi_train_ddp(save_path: str): training_code = """ import torch import scvi @@ -126,8 +257,6 @@ def test_scvi_train_ddp(save_path: str = "."): strategy="ddp_find_unused_parameters_true", ) -torch.distributed.destroy_process_group() - assert model.is_trained """ # Define the file path for the temporary script in the current working directory @@ -160,7 +289,7 @@ def launch_ddp(world_size, temp_file_path): @pytest.mark.multigpu -def test_scanvi_train_ddp(save_path: str = "."): +def test_scanvi_train_ddp(save_path: str): training_code = """ import torch import scvi @@ -190,8 +319,6 @@ def test_scanvi_train_ddp(save_path: str = "."): datasplitter_kwargs=datasplitter_kwargs, ) -torch.distributed.destroy_process_group() - assert model.is_trained """ # Define the file path for the temporary script in the current working directory