Rethink the SCANVI multi-GPU problem: add support for the AnnDataLoader to handle labels and resample them as SCANVI needs in the fully labeled case, and add support for a distributed sampler in the SCANVI model. Re-uploaded the test, added a changelog entry, and validated the tutorials for biological correctness.
ori-kron-wis committed Feb 2, 2025
1 parent a97efc7 commit 8530004
Showing 7 changed files with 96 additions and 45 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -12,7 +12,7 @@ to [Semantic Versioning]. Full commit history is available in the

- Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable
representation learning in single-cell RNA sequencing data {pr}`3015`, {pr}`3091`.
- Add support for {class}`~scvi.model.SCANVI` multiGPU training, only for the full labeled case {pr}`3125`.
- Add support for {class}`~scvi.model.SCANVI` multiGPU training {pr}`3125`.
- Add {class}`scvi.external.RESOLVI` for bias correction in single-cell resolved spatial
transcriptomics {pr}`3144`.

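For context alongside the changelog entry above, here is a minimal usage sketch of the multi-GPU SCANVI path this commit enables. The toy dataset, epoch counts, device count, and the particular DDP strategy string are assumptions for illustration, not part of the commit.

```python
# Hypothetical end-to-end sketch (assumes 2 GPUs are available and uses toy data).
import scvi

adata = scvi.data.synthetic_iid()  # small synthetic dataset with "batch" and "labels" columns
scvi.model.SCVI.setup_anndata(adata, batch_key="batch", labels_key="labels")

scvi_model = scvi.model.SCVI(adata)
scvi_model.train(max_epochs=1)

# "unknown" is absent from adata.obs["labels"], so every cell counts as labeled --
# the fully labeled case this commit's distributed sampler path targets.
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    scvi_model, unlabeled_category="unknown", labels_key="labels"
)
scanvi_model.train(
    max_epochs=1,
    accelerator="gpu",
    devices=2,
    strategy="ddp_find_unused_parameters_true",  # any "ddp*" strategy opts into the distributed sampler
)
```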
60 changes: 59 additions & 1 deletion src/scvi/dataloaders/_ann_dataloader.py
@@ -10,8 +10,9 @@
SequentialSampler,
)

from scvi import settings
from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data._utils import get_anndata_attribute

from ._samplers import BatchDistributedSampler

@@ -56,6 +57,8 @@ class AnnDataLoader(DataLoader):
distributed_sampler
``EXPERIMENTAL`` Whether to use :class:`~scvi.dataloaders.BatchDistributedSampler` as the
sampler. If `True`, `sampler` must be `None`.
n_samples_per_label
Number of samples to draw from each label class per epoch; if ``None``, all labeled samples are used.
load_sparse_tensor
``EXPERIMENTAL`` If ``True``, loads data with sparse CSR or CSC layout as a
:class:`~torch.Tensor` with the same layout. Can lead to speedups in data transfers to
@@ -83,6 +86,7 @@ def __init__(
data_and_attributes: list[str] | dict[str, np.dtype] | None = None,
iter_ndarray: bool = False,
distributed_sampler: bool = False,
n_samples_per_label: int | None = None,
load_sparse_tensor: bool = False,
**kwargs,
):
@@ -93,6 +97,7 @@ def __init__(
indices = np.where(indices)[0].ravel()
indices = np.asarray(indices)
self.indices = indices
self.n_samples_per_label = n_samples_per_label
self.dataset = adata_manager.create_torch_dataset(
indices=indices,
data_and_attributes=data_and_attributes,
@@ -104,10 +109,32 @@ def __init__(
kwargs["persistent_workers"] = settings.dl_persistent_workers

self.kwargs = copy.deepcopy(kwargs)
self.adata_manager = adata_manager
self.data_and_attributes = data_and_attributes
self._shuffle = shuffle
self._batch_size = batch_size
self._drop_last = drop_last
self.load_sparse_tensor = load_sparse_tensor

if sampler is not None and distributed_sampler:
raise ValueError("Cannot specify both `sampler` and `distributed_sampler`.")

# The next block handles the fully labeled AnnDataLoader case
labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY)
labels = get_anndata_attribute(
adata_manager.adata,
adata_manager.data_registry.labels.attr_name,
labels_state_registry.original_key,
).ravel()
if hasattr(labels_state_registry, "unlabeled_category"):
# save a nested list of the indices per labeled category (if any exist)
self.labeled_locs = []
for label in np.unique(labels):
if label != labels_state_registry.unlabeled_category:
label_loc_idx = np.where(labels[indices] == label)[0]
label_loc = self.indices[label_loc_idx]
self.labeled_locs.append(label_loc)

# custom sampler for efficient minibatching on sparse matrices
if sampler is None:
if not distributed_sampler:
@@ -136,3 +163,34 @@ def __init__(
self.kwargs.update({"collate_fn": lambda x: x})

super().__init__(self.dataset, **self.kwargs)

def resample_labels(self):
"""Resamples the labeled data."""
self.kwargs.pop("batch_size", None)
self.kwargs.pop("shuffle", None)
self.kwargs.pop("sampler", None)
self.kwargs.pop("collate_fn", None)
AnnDataLoader(
self.adata_manager,
indices=self.subsample_labels(),
shuffle=self._shuffle,
batch_size=self._batch_size,
data_and_attributes=self.data_and_attributes,
drop_last=self._drop_last,
**self.kwargs,
)

def subsample_labels(self):
"""Subsamples each label class by taking up to n_samples_per_label samples per class."""
if self.n_samples_per_label is None:
return np.concatenate(self.labeled_locs)

sample_idx = []
for loc in self.labeled_locs:
if len(loc) < self.n_samples_per_label:
sample_idx.append(loc)
else:
label_subset = np.random.choice(loc, self.n_samples_per_label, replace=False)
sample_idx.append(label_subset)
sample_idx = np.concatenate(sample_idx)
return sample_idx
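To make the new resampling path concrete, here is a standalone NumPy sketch of what `subsample_labels` does each epoch; the label array, the `unknown` category, and the cap of 20 are made-up illustration values, not part of the commit.

```python
# Standalone illustration of the per-label subsampling added above (illustrative values).
import numpy as np

rng = np.random.default_rng(0)
labels = np.array(["A"] * 5 + ["B"] * 50 + ["unknown"] * 10)
unlabeled_category = "unknown"
n_samples_per_label = 20

# Nested list of indices per labeled category, mirroring self.labeled_locs.
labeled_locs = [
    np.where(labels == lab)[0] for lab in np.unique(labels) if lab != unlabeled_category
]

# Per epoch: keep small classes whole, randomly cap large classes at n_samples_per_label.
sample_idx = np.concatenate(
    [
        loc if len(loc) < n_samples_per_label else rng.choice(loc, n_samples_per_label, replace=False)
        for loc in labeled_locs
    ]
)
print(sample_idx.size)  # 5 (all of "A") + 20 (subset of "B") = 25 indices for this epoch
```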
15 changes: 11 additions & 4 deletions src/scvi/dataloaders/_concat_dataloader.py
@@ -48,7 +48,7 @@ def __init__(
**data_loader_kwargs,
):
self.adata_manager = adata_manager
self.dataloader_kwargs = data_loader_kwargs
self.data_loader_kwargs = data_loader_kwargs
self.data_and_attributes = data_and_attributes
self._shuffle = shuffle
self._batch_size = batch_size
@@ -57,6 +57,7 @@ def __init__(

self.dataloaders = []
for indices in indices_list:
self.data_loader_kwargs.pop("sampler", None)
self.dataloaders.append(
AnnDataLoader(
adata_manager,
@@ -66,11 +67,12 @@
data_and_attributes=data_and_attributes,
drop_last=drop_last,
distributed_sampler=distributed_sampler,
**self.dataloader_kwargs,
**self.data_loader_kwargs,
)
)
lens = [len(dl) for dl in self.dataloaders]
self.largest_dl = self.dataloaders[np.argmax(lens)]
self.data_loader_kwargs.pop("drop_dataset_tail", None)
super().__init__(self.largest_dl, **data_loader_kwargs)

def __len__(self):
@@ -83,5 +85,10 @@ def __iter__(self):
the data in the other dataloaders. The order of data in returned iter_list
is the same as indices_list.
"""
iter_list = [cycle(dl) if dl != self.largest_dl else dl for dl in self.dataloaders]
return zip(*iter_list, strict=True)
if not self._distributed_sampler:
iter_list = [cycle(dl) if dl != self.largest_dl else dl for dl in self.dataloaders]
strict = True
else:
iter_list = self.dataloaders
strict = False
return zip(*iter_list, strict=strict)
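A toy sketch of the two iteration paths above, with plain lists standing in for the dataloaders; the values are illustrative, and `islice` stands in for the training loop consuming only `len(largest)` batches per epoch.

```python
# Toy model of ConcatDataLoader.__iter__ (lists stand in for dataloaders; values illustrative).
from itertools import cycle, islice

loaders = [[10, 20, 30, 40], [1, 2]]  # e.g. full dataloader vs. smaller labeled dataloader
largest = max(loaders, key=len)

# Non-distributed path: shorter loaders are cycled so every batch of the largest
# loader gets a partner; only len(largest) pairs are consumed per epoch.
paired = zip(*[dl if dl is largest else cycle(dl) for dl in loaders], strict=True)
print(list(islice(paired, len(largest))))  # [(10, 1), (20, 2), (30, 1), (40, 2)]

# Distributed-sampler path (new here): iterate each loader as-is and stop at the
# shortest with strict=False, so DDP ranks never wait on batches that do not exist.
print(list(zip(*loaders, strict=False)))  # [(10, 1), (20, 2)]
```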
7 changes: 6 additions & 1 deletion src/scvi/dataloaders/_data_splitting.py
@@ -516,11 +516,16 @@ def setup(self, stage: str | None = None):
self.test_idx = indices_test.astype(int)

if len(self._labeled_indices) != 0:
self.data_loader_class = SemiSupervisedDataLoader
if len(self._unlabeled_indices) != 0:
self.data_loader_class = SemiSupervisedDataLoader
else:
# data is all labeled
self.data_loader_class = AnnDataLoader
dl_kwargs = {
"n_samples_per_label": self.n_samples_per_label,
}
else:
# data is all unlabeled
self.data_loader_class = AnnDataLoader
dl_kwargs = {}

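Restating the branch above as a self-contained sketch (the wrapper function and string return values are illustrative, not scvi API): mixed data keeps the semi-supervised loader, a fully labeled dataset now uses the plain `AnnDataLoader`, and both pass `n_samples_per_label` through.

```python
# Illustrative restatement of the loader selection in setup() above (not scvi API).
def pick_loader(labeled_indices: list[int], unlabeled_indices: list[int], n_samples_per_label: int | None):
    if len(labeled_indices) != 0:
        if len(unlabeled_indices) != 0:
            loader = "SemiSupervisedDataLoader"  # mixed labeled/unlabeled data
        else:
            loader = "AnnDataLoader"  # data is all labeled
        return loader, {"n_samples_per_label": n_samples_per_label}
    return "AnnDataLoader", {}  # data is all unlabeled


print(pick_loader([0, 1, 2], [3, 4], 10))  # ('SemiSupervisedDataLoader', {'n_samples_per_label': 10})
print(pick_loader([0, 1, 2], [], 10))      # ('AnnDataLoader', {'n_samples_per_label': 10})
print(pick_loader([], [3, 4], 10))         # ('AnnDataLoader', {})
```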
4 changes: 3 additions & 1 deletion src/scvi/dataloaders/_semi_dataloader.py
@@ -53,6 +53,7 @@ def __init__(
return None

self.n_samples_per_label = n_samples_per_label
self.data_loader_kwargs = data_loader_kwargs

labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY)
labels = get_anndata_attribute(
@@ -77,7 +78,7 @@ def __init__(
batch_size=batch_size,
data_and_attributes=data_and_attributes,
drop_last=drop_last,
**data_loader_kwargs,
**self.data_loader_kwargs,
)

def resample_labels(self):
@@ -93,6 +94,7 @@ def resample_labels(self):
batch_size=self._batch_size,
data_and_attributes=self.data_and_attributes,
drop_last=self._drop_last,
**self.data_loader_kwargs,
)

def subsample_labels(self):
52 changes: 15 additions & 37 deletions src/scvi/model/_scanvi.py
@@ -24,7 +24,7 @@
NumericalJointObsField,
NumericalObsField,
)
from scvi.dataloaders import DataSplitter, SemiSupervisedDataSplitter
from scvi.dataloaders import SemiSupervisedDataSplitter
from scvi.model._utils import _init_library_size, get_max_epochs_heuristic, use_distributed_sampler
from scvi.module import SCANVAE
from scvi.train import SemiSupervisedTrainingPlan, TrainRunner
@@ -408,43 +408,21 @@ def train(
# if we have labeled cells, we want to subsample labels each epoch
sampler_callback = [SubSampleLabels()] if len(self._labeled_indices) != 0 else []

# Create the regular SCVI data splitter in case multi-GPU is plausible
if (len(self._unlabeled_indices) == 0) and use_distributed_sampler(
trainer_kwargs.get("strategy", None)
):
# we are in a multi-GPU env and everything is labeled - n_samples_per_label is not used here
# also I removed the option for load_sparse_tensor; it can still be reached from kwargs
# in this way we bypass the use of ConcatDataLoaders for now
run_multi_gpu = True
data_splitter = DataSplitter(
self.adata_manager,
train_size=train_size,
validation_size=validation_size,
batch_size=batch_size,
shuffle_set_split=shuffle_set_split,
distributed_sampler=True,
**datasplitter_kwargs,
)
else:
# what we had so far in scanvi (concat dataloaders)
run_multi_gpu = False
data_splitter = SemiSupervisedDataSplitter(
adata_manager=self.adata_manager,
train_size=train_size,
validation_size=validation_size,
shuffle_set_split=shuffle_set_split,
n_samples_per_label=n_samples_per_label,
batch_size=batch_size,
**datasplitter_kwargs,
)

data_splitter = SemiSupervisedDataSplitter(
adata_manager=self.adata_manager,
train_size=train_size,
validation_size=validation_size,
shuffle_set_split=shuffle_set_split,
n_samples_per_label=n_samples_per_label,
distributed_sampler=use_distributed_sampler(trainer_kwargs.get("strategy", None)),
batch_size=batch_size,
**datasplitter_kwargs,
)
training_plan = self._training_plan_cls(self.module, self.n_labels, **plan_kwargs)
if not run_multi_gpu:
# TODO: how do we generate subsamples per class in multigpu?
if "callbacks" in trainer_kwargs.keys():
trainer_kwargs["callbacks"] + [sampler_callback]
else:
trainer_kwargs["callbacks"] = sampler_callback
if "callbacks" in trainer_kwargs.keys():
trainer_kwargs["callbacks"] + [sampler_callback]
else:
trainer_kwargs["callbacks"] = sampler_callback

runner = TrainRunner(
self,
1 change: 1 addition & 0 deletions src/scvi/model/_utils.py
@@ -29,6 +29,7 @@ def use_distributed_sampler(strategy: str | Strategy) -> bool:
"""
if isinstance(strategy, str):
# ["ddp", "ddp_spawn", "ddp_find_unused_parameters_true"]
# ["ddp_notebook","ddp_notebook_find_unused_parameters_true"] - for jupyter nb run
return "ddp" in strategy
return isinstance(strategy, DDPStrategy)

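As a quick usage note for the comment added above: any strategy whose name contains "ddp", including the notebook variants, opts into the distributed sampler. The import path below is the one `_scanvi.py` already uses; the specific strategy strings are just examples.

```python
# Strategy strings vs. distributed-sampler selection (see the "ddp" substring check above).
from scvi.model._utils import use_distributed_sampler

print(use_distributed_sampler("ddp_find_unused_parameters_true"))  # True
print(use_distributed_sampler("ddp_notebook"))                     # True - Jupyter-friendly DDP
print(use_distributed_sampler("auto"))                             # False - single-process default
```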
