Skip to content

Commit

Permalink
added regular data splitting for scanvi in case of multi-GPU use
Browse files Browse the repository at this point in the history
  • Loading branch information
ori-kron-wis committed Jan 29, 2025
1 parent b1758b9 commit ee0e446
Showing 1 changed file with 28 additions and 11 deletions.
39 changes: 28 additions & 11 deletions src/scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
NumericalJointObsField,
NumericalObsField,
)
from scvi.dataloaders import SemiSupervisedDataSplitter
from scvi.model._utils import _init_library_size, get_max_epochs_heuristic
from scvi.dataloaders import DataSplitter, SemiSupervisedDataSplitter
from scvi.model._utils import _init_library_size, get_max_epochs_heuristic, use_distributed_sampler
from scvi.module import SCANVAE
from scvi.train import SemiSupervisedTrainingPlan, TrainRunner
from scvi.train._callbacks import SubSampleLabels
Expand Down Expand Up @@ -408,15 +408,32 @@ def train(
# if we have labeled cells, we want to subsample labels each epoch
sampler_callback = [SubSampleLabels()] if len(self._labeled_indices) != 0 else []

data_splitter = SemiSupervisedDataSplitter(
adata_manager=self.adata_manager,
train_size=train_size,
validation_size=validation_size,
shuffle_set_split=shuffle_set_split,
n_samples_per_label=n_samples_per_label,
batch_size=batch_size,
**datasplitter_kwargs,
)
# TODO: Create the regular SCVI data splitter in case multi-GPU use is plausible
if (len(self._unlabeled_indices) == 0) and (datasplitter_kwargs["distributed_sampler"]):
# we are in a multi-GPU env and it's all labeled - n_samples_per_label is not used here
# also, I removed the option for load_sparse_tensor; it can still be reached from kwargs
# this way we bypass the use of concat dataloaders for now
data_splitter = DataSplitter(
self.adata_manager,
train_size=train_size,
validation_size=validation_size,
batch_size=batch_size,
shuffle_set_split=shuffle_set_split,
distributed_sampler=use_distributed_sampler(trainer_kwargs.get("strategy", None)),
**datasplitter_kwargs,
)
else:
# what we had so far in scanvi (concat dataloaders)
data_splitter = SemiSupervisedDataSplitter(
adata_manager=self.adata_manager,
train_size=train_size,
validation_size=validation_size,
shuffle_set_split=shuffle_set_split,
n_samples_per_label=n_samples_per_label,
batch_size=batch_size,
**datasplitter_kwargs,
)

training_plan = self._training_plan_cls(self.module, self.n_labels, **plan_kwargs)
if "callbacks" in trainer_kwargs.keys():
trainer_kwargs["callbacks"] + [sampler_callback]
Expand Down

0 comments on commit ee0e446

Please sign in to comment.