feat: Adding multiGPU option to SCANVI #3125

Open

ori-kron-wis wants to merge 25 commits into main from Ori-scanvi-MultiGPU-support

Commits (25)
42c56e8
1st commit
ori-kron-wis Jan 7, 2025
c8995cc
Merge remote-tracking branch 'origin/main' into Ori-scanvi-MultiGPU-s…
ori-kron-wis Jan 8, 2025
02c01f2
Merge branch 'main' into Ori-scanvi-MultiGPU-support
ori-kron-wis Jan 16, 2025
c8ac786
Merge branch 'main' into Ori-scanvi-MultiGPU-support
ori-kron-wis Jan 20, 2025
b1758b9
Merge branch 'main' into Ori-scanvi-MultiGPU-support
ori-kron-wis Jan 26, 2025
ee0e446
added regular data splitting for scanvi in case of multiGPU use
ori-kron-wis Jan 29, 2025
167ed5b
update changelog
ori-kron-wis Jan 29, 2025
69188f0
treat scanvi all labeled with regular anndataloader and not with conca…
ori-kron-wis Jan 30, 2025
1c7273d
treat scanvi all labeled with regular anndataloader and not with conca…
ori-kron-wis Jan 30, 2025
6d39d18
Merge remote-tracking branch 'origin/main' into Ori-scanvi-MultiGPU-s…
ori-kron-wis Jan 30, 2025
0cd6f35
merge with main
ori-kron-wis Jan 30, 2025
939c246
Merge remote-tracking branch 'origin/main' into Ori-scanvi-MultiGPU-s…
ori-kron-wis Jan 30, 2025
d25e56d
Merge remote-tracking branch 'origin/Ori-scanvi-MultiGPU-support' int…
ori-kron-wis Jan 30, 2025
b15d5c7
fix gimvi
ori-kron-wis Jan 30, 2025
92f0b0c
change tests
ori-kron-wis Jan 30, 2025
7f6ac72
change tests
ori-kron-wis Jan 30, 2025
a97efc7
Merge remote-tracking branch 'origin/Ori-scanvi-MultiGPU-support' int…
ori-kron-wis Feb 2, 2025
8530004
rethink the scanvi multi-GPU problem: adding support for anndatalode…
ori-kron-wis Feb 2, 2025
fa0a77f
limit the anndataloader fix to scanvi only (other models will need mo…
ori-kron-wis Feb 2, 2025
1c2ea3c
cleaning code
ori-kron-wis Feb 3, 2025
fad1daf
Merge remote-tracking branch 'origin/main' into Ori-scanvi-MultiGPU-s…
ori-kron-wis Feb 3, 2025
60de475
fix circular bug
ori-kron-wis Feb 3, 2025
9c55311
fix test
ori-kron-wis Feb 3, 2025
28ef89b
update notebooks link (?)
ori-kron-wis Feb 4, 2025
dd75476
validate more models to run with multi GPU: condscvi, linearscvi, pea…
ori-kron-wis Feb 4, 2025
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,10 @@ to [Semantic Versioning]. Full commit history is available in the

- Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable
representation learning in single-cell RNA sequencing data {pr}`3015`, {pr}`3091`.
- Add multi-GPU training support for {class}`~scvi.model.SCANVI`, {class}`~scvi.model.CondSCVI` and
{class}`~scvi.model.LinearSCVI` {pr}`3125`. The same support was added for
{class}`~scvi.model.TOTALVI`, {class}`~scvi.model.MULTIVI` and {class}`~scvi.model.PEAKVI`,
but `early_stopping` must be disabled for those models before multi-GPU training can be used.
- Add {class}`scvi.external.RESOLVI` for bias correction in single-cell resolved spatial
transcriptomics {pr}`3144`.

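As a usage illustration for this changelog entry (not taken from the PR itself): a minimal multi-GPU SCANVI run, assuming a machine with at least two CUDA devices; the dataset, epoch count, and strategy string are placeholders.

```python
# Illustrative sketch only (not part of the diff): multi-GPU SCANVI training
# on a synthetic dataset, assuming at least two CUDA devices are available.
import scvi

adata = scvi.data.synthetic_iid()
scvi.model.SCANVI.setup_anndata(
    adata, labels_key="labels", unlabeled_category="label_0"
)
model = scvi.model.SCANVI(adata)
model.train(
    max_epochs=10,
    accelerator="gpu",
    devices=2,
    # Any "ddp" strategy makes use_distributed_sampler() return True, which
    # switches the data splitter to a per-GPU BatchDistributedSampler.
    strategy="ddp_find_unused_parameters_true",
)
```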
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
77 changes: 76 additions & 1 deletion src/scvi/dataloaders/_ann_dataloader.py
@@ -10,14 +10,55 @@
SequentialSampler,
)

from scvi import settings
from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data._utils import get_anndata_attribute

from ._samplers import BatchDistributedSampler

logger = logging.getLogger(__name__)


def subsample_labels(labeled_locs, n_samples_per_label):
"""Subsamples each label class by taking up to n_samples_per_label samples per class."""
if n_samples_per_label is None:
if len(labeled_locs) == 0:
return labeled_locs
else:
return np.concatenate(labeled_locs)

sample_idx = []
for loc in labeled_locs:
if len(loc) < n_samples_per_label:
sample_idx.append(loc)
else:
label_subset = np.random.choice(loc, n_samples_per_label, replace=False)
sample_idx.append(label_subset)
sample_idx = np.concatenate(sample_idx)
return sample_idx


def labelled_indices_generator(adata_manager, indices, indices_asarray, n_samples_per_label):
"""Generates indices for each label class"""
labelled_idx = []
labeled_locs = []
labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY)
labels = get_anndata_attribute(
adata_manager.adata,
adata_manager.data_registry.labels.attr_name,
labels_state_registry.original_key,
).ravel()
if hasattr(labels_state_registry, "unlabeled_category"):
# save a nested list of the indices per labeled category (if exists)
for label in np.unique(labels):
if label != labels_state_registry.unlabeled_category:
label_loc_idx = np.where(labels[indices] == label)[0]
label_loc = indices_asarray[label_loc_idx]
labeled_locs.append(label_loc)
labelled_idx = subsample_labels(labeled_locs, n_samples_per_label)
return labeled_locs, labelled_idx


class AnnDataLoader(DataLoader):
"""DataLoader for loading tensors from AnnData objects.

@@ -56,6 +97,8 @@ class AnnDataLoader(DataLoader):
distributed_sampler
``EXPERIMENTAL`` Whether to use :class:`~scvi.dataloaders.BatchDistributedSampler` as the
sampler. If `True`, `sampler` must be `None`.
n_samples_per_label
Number of samples to draw from each label class per epoch; ``None`` keeps all labeled samples.
load_sparse_tensor
``EXPERIMENTAL`` If ``True``, loads data with sparse CSR or CSC layout as a
:class:`~torch.Tensor` with the same layout. Can lead to speedups in data transfers to
@@ -83,6 +126,7 @@ def __init__(
data_and_attributes: list[str] | dict[str, np.dtype] | None = None,
iter_ndarray: bool = False,
distributed_sampler: bool = False,
n_samples_per_label: int | None = None,
load_sparse_tensor: bool = False,
**kwargs,
):
@@ -93,6 +137,7 @@ def __init__(
indices = np.where(indices)[0].ravel()
indices = np.asarray(indices)
self.indices = indices
self.n_samples_per_label = n_samples_per_label
self.dataset = adata_manager.create_torch_dataset(
indices=indices,
data_and_attributes=data_and_attributes,
@@ -104,10 +149,24 @@ def __init__(
kwargs["persistent_workers"] = settings.dl_persistent_workers

self.kwargs = copy.deepcopy(kwargs)
self.adata_manager = adata_manager
self.data_and_attributes = data_and_attributes
self._shuffle = shuffle
self._batch_size = batch_size
self._drop_last = drop_last
self.load_sparse_tensor = load_sparse_tensor

if sampler is not None and distributed_sampler:
raise ValueError("Cannot specify both `sampler` and `distributed_sampler`.")

# Handle the labeled AnnDataLoader used for SCANVI multi-GPU training:
self.labeled_locs, labelled_idx = [], []
if adata_manager.registry["model_name"] == "SCANVI":
self.labeled_locs, labelled_idx = labelled_indices_generator(
adata_manager, indices, self.indices, self.n_samples_per_label
)

# custom sampler for efficient minibatching on sparse matrices
if sampler is None:
if not distributed_sampler:
@@ -136,3 +195,19 @@ def __init__(
self.kwargs.update({"collate_fn": lambda x: x})

super().__init__(self.dataset, **self.kwargs)

def resample_labels(self):
"""Resample the labeled data, returning a fresh loader over the new subsample."""
self.kwargs.pop("batch_size", None)
self.kwargs.pop("shuffle", None)
self.kwargs.pop("sampler", None)
self.kwargs.pop("collate_fn", None)
return AnnDataLoader(
self.adata_manager,
indices=subsample_labels(self.labeled_locs, self.n_samples_per_label),
shuffle=self._shuffle,
batch_size=self._batch_size,
data_and_attributes=self.data_and_attributes,
drop_last=self._drop_last,
**self.kwargs,
)
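To make the subsampling semantics concrete, a small usage sketch; the import path is assumed from the diff above, and the values are illustrative.

```python
import numpy as np

# subsample_labels is added to this module by the PR; path assumed from the diff.
from scvi.dataloaders._ann_dataloader import subsample_labels

# Two labeled classes: cells 0-9 carry one label, cells 10-12 another.
labeled_locs = [np.arange(10), np.arange(10, 13)]

# Quota of 5 per class: 5 random indices from the first class, plus all 3
# from the second (classes smaller than the quota are kept whole).
print(subsample_labels(labeled_locs, 5))

# No quota: every labeled index is returned.
print(subsample_labels(labeled_locs, None))
```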
17 changes: 12 additions & 5 deletions src/scvi/dataloaders/_concat_dataloader.py
@@ -48,7 +48,7 @@ def __init__(
**data_loader_kwargs,
):
self.adata_manager = adata_manager
self.dataloader_kwargs = data_loader_kwargs
self.data_loader_kwargs = data_loader_kwargs
self.data_and_attributes = data_and_attributes
self._shuffle = shuffle
self._batch_size = batch_size
@@ -57,6 +57,7 @@ def __init__(

self.dataloaders = []
for indices in indices_list:
self.data_loader_kwargs.pop("sampler", None)
self.dataloaders.append(
AnnDataLoader(
adata_manager,
@@ -66,12 +67,13 @@
data_and_attributes=data_and_attributes,
drop_last=drop_last,
distributed_sampler=distributed_sampler,
**self.dataloader_kwargs,
**self.data_loader_kwargs,
)
)
lens = [len(dl) for dl in self.dataloaders]
self.largest_dl = self.dataloaders[np.argmax(lens)]
super().__init__(self.largest_dl, **data_loader_kwargs)
self.data_loader_kwargs.pop("drop_dataset_tail", None)
super().__init__(self.largest_dl, **self.data_loader_kwargs)

def __len__(self):
return len(self.largest_dl)
@@ -83,5 +85,10 @@ def __iter__(self):
the data in the other dataloaders. The order of data in returned iter_list
is the same as indices_list.
"""
iter_list = [cycle(dl) if dl != self.largest_dl else dl for dl in self.dataloaders]
return zip(*iter_list, strict=True)
if not self._distributed_sampler:
iter_list = [cycle(dl) if dl != self.largest_dl else dl for dl in self.dataloaders]
strict = True
else:
iter_list = self.dataloaders
strict = False
return zip(*iter_list, strict=strict)
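The `__iter__` change is easiest to see with plain iterables: without a distributed sampler, shorter loaders are cycled so every batch of the largest loader gets a companion; with one, each rank's shard lengths can differ, so cycling is dropped and `strict=False` lets `zip` stop at the shortest loader. A standalone sketch:

```python
from itertools import cycle

full = ["f0", "f1", "f2", "f3"]  # stands in for the largest dataloader
labelled = ["l0", "l1"]          # stands in for a shorter labelled loader

# Non-distributed path: cycle the shorter loader; zip stops once the finite
# `full` iterable is exhausted, so every full batch gets a labelled partner.
print(list(zip(full, cycle(labelled))))
# [('f0', 'l0'), ('f1', 'l1'), ('f2', 'l0'), ('f3', 'l1')]

# Distributed path: no cycling; strict=False tolerates unequal lengths.
print(list(zip(full, labelled, strict=False)))
# [('f0', 'l0'), ('f1', 'l1')]
```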
7 changes: 6 additions & 1 deletion src/scvi/dataloaders/_data_splitting.py
@@ -516,11 +516,16 @@ def setup(self, stage: str | None = None):
self.test_idx = indices_test.astype(int)

if len(self._labeled_indices) != 0:
self.data_loader_class = SemiSupervisedDataLoader
if len(self._unlabeled_indices) != 0:
self.data_loader_class = SemiSupervisedDataLoader
else:
# data is all labeled
self.data_loader_class = AnnDataLoader
dl_kwargs = {
"n_samples_per_label": self.n_samples_per_label,
}
else:
# data is all unlabeled
self.data_loader_class = AnnDataLoader
dl_kwargs = {}

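The diff viewer flattens indentation, so the nesting of this branch is easy to misread. A standalone sketch of the intended selection logic, using a hypothetical `pick_loader` helper:

```python
# Illustrative restatement of the loader selection above (hypothetical helper).
def pick_loader(n_labeled: int, n_unlabeled: int, n_samples_per_label: int | None):
    if n_labeled != 0:
        if n_unlabeled != 0:
            cls = "SemiSupervisedDataLoader"  # mixed labeled/unlabeled data
        else:
            cls = "AnnDataLoader"  # all labeled: the SCANVI multi-GPU case
        return cls, {"n_samples_per_label": n_samples_per_label}
    return "AnnDataLoader", {}  # all unlabeled

print(pick_loader(100, 900, 10))   # ('SemiSupervisedDataLoader', {'n_samples_per_label': 10})
print(pick_loader(1000, 0, 10))    # ('AnnDataLoader', {'n_samples_per_label': 10})
print(pick_loader(0, 1000, None))  # ('AnnDataLoader', {})
```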
43 changes: 8 additions & 35 deletions src/scvi/dataloaders/_semi_dataloader.py
@@ -1,10 +1,8 @@
import numpy as np

from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data._utils import get_anndata_attribute

from ._ann_dataloader import AnnDataLoader
from ._ann_dataloader import AnnDataLoader, labelled_indices_generator, subsample_labels
from ._concat_dataloader import ConcatDataLoader


@@ -53,22 +51,11 @@ def __init__(
return None

self.n_samples_per_label = n_samples_per_label
self.data_loader_kwargs = data_loader_kwargs

labels_state_registry = adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY)
labels = get_anndata_attribute(
adata_manager.adata,
adata_manager.data_registry.labels.attr_name,
labels_state_registry.original_key,
).ravel()

# save a nested list of the indices per labeled category
self.labeled_locs = []
for label in np.unique(labels):
if label != labels_state_registry.unlabeled_category:
label_loc_idx = np.where(labels[indices] == label)[0]
label_loc = self.indices[label_loc_idx]
self.labeled_locs.append(label_loc)
labelled_idx = self.subsample_labels()
self.labeled_locs, labelled_idx = labelled_indices_generator(
adata_manager, indices, self.indices, self.n_samples_per_label
)

super().__init__(
adata_manager=adata_manager,
@@ -77,12 +64,12 @@
batch_size=batch_size,
data_and_attributes=data_and_attributes,
drop_last=drop_last,
**data_loader_kwargs,
**self.data_loader_kwargs,
)

def resample_labels(self):
"""Resamples the labeled data."""
labelled_idx = self.subsample_labels()
labelled_idx = subsample_labels(self.labeled_locs, self.n_samples_per_label)
# self.dataloaders[0] iterates over full_indices
# self.dataloaders[1] iterates over the labelled_indices
# change the indices of the labelled set
@@ -93,19 +80,5 @@
batch_size=self._batch_size,
data_and_attributes=self.data_and_attributes,
drop_last=self._drop_last,
**self.data_loader_kwargs,
)

def subsample_labels(self):
"""Subsamples each label class by taking up to n_samples_per_label samples per class."""
if self.n_samples_per_label is None:
return np.concatenate(self.labeled_locs)

sample_idx = []
for loc in self.labeled_locs:
if len(loc) < self.n_samples_per_label:
sample_idx.append(loc)
else:
label_subset = np.random.choice(loc, self.n_samples_per_label, replace=False)
sample_idx.append(label_subset)
sample_idx = np.concatenate(sample_idx)
return sample_idx
2 changes: 2 additions & 0 deletions src/scvi/model/_multivi.py
@@ -27,6 +27,7 @@
_get_batch_code_from_category,
scatac_raw_counts_properties,
scrna_raw_counts_properties,
use_distributed_sampler,
)
from scvi.model.base import (
ArchesMixin,
@@ -348,6 +349,7 @@ def train(
train_size=train_size,
validation_size=validation_size,
shuffle_set_split=shuffle_set_split,
distributed_sampler=use_distributed_sampler(kwargs.get("strategy", None)),
batch_size=batch_size,
**datasplitter_kwargs,
)
3 changes: 2 additions & 1 deletion src/scvi/model/_scanvi.py
@@ -25,7 +25,7 @@
NumericalObsField,
)
from scvi.dataloaders import SemiSupervisedDataSplitter
from scvi.model._utils import _init_library_size, get_max_epochs_heuristic
from scvi.model._utils import _init_library_size, get_max_epochs_heuristic, use_distributed_sampler
from scvi.module import SCANVAE
from scvi.train import SemiSupervisedTrainingPlan, TrainRunner
from scvi.train._callbacks import SubSampleLabels
@@ -414,6 +414,7 @@ def train(
validation_size=validation_size,
shuffle_set_split=shuffle_set_split,
n_samples_per_label=n_samples_per_label,
distributed_sampler=use_distributed_sampler(trainer_kwargs.get("strategy", None)),
batch_size=batch_size,
**datasplitter_kwargs,
)
2 changes: 2 additions & 0 deletions src/scvi/model/_totalvi.py
@@ -22,6 +22,7 @@
_init_library_size,
cite_seq_raw_counts_properties,
get_max_epochs_heuristic,
use_distributed_sampler,
)
from scvi.model.base._de_core import _de_core
from scvi.module import TOTALVAE
@@ -322,6 +323,7 @@ def train(
validation_size=validation_size,
shuffle_set_split=shuffle_set_split,
batch_size=batch_size,
distributed_sampler=use_distributed_sampler(kwargs.get("strategy", None)),
external_indexing=external_indexing,
**datasplitter_kwargs,
)
1 change: 1 addition & 0 deletions src/scvi/model/_utils.py
@@ -29,6 +29,7 @@ def use_distributed_sampler(strategy: str | Strategy) -> bool:
"""
if isinstance(strategy, str):
# ["ddp", "ddp_spawn", "ddp_find_unused_parameters_true"]
# ["ddp_notebook","ddp_notebook_find_unused_parameters_true"] - for jupyter nb run
return "ddp" in strategy
return isinstance(strategy, DDPStrategy)

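This helper gates every `distributed_sampler=` call site added above. A quick sketch of its contract, assuming the import paths shown in the diff (`DDPStrategy` comes from Lightning):

```python
from lightning.pytorch.strategies import DDPStrategy

from scvi.model._utils import use_distributed_sampler

# String strategies: any "ddp" variant counts, including the notebook ones.
assert use_distributed_sampler("ddp") is True
assert use_distributed_sampler("ddp_find_unused_parameters_true") is True
assert use_distributed_sampler("ddp_notebook") is True
assert use_distributed_sampler("auto") is False

# Strategy objects: only DDPStrategy instances enable the distributed sampler.
assert use_distributed_sampler(DDPStrategy()) is True
```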