Skip to content

Commit

Permalink
test for callback
Browse files Browse the repository at this point in the history
  • Loading branch information
ori-kron-wis committed Jan 30, 2025
1 parent bcee76c commit 22951d1
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/scvi/train/_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,23 @@ def on_exception(self, trainer, pl_module, exception) -> None:
"Saving model....Please load it back and continue training\033[0m"
)
trainer.should_stop = True

# Loads the best model state into the model after the exception.
_, _, best_state_dict, _ = _load_saved_files(
self.best_model_path,
load_adata=False,
map_location=pl_module.module.device,
)
pyro_param_store = best_state_dict.pop("pyro_param_store", None)
pl_module.module.load_state_dict(best_state_dict)
self.save_path = self.on_save_checkpoint(trainer)
if pyro_param_store is not None:
# For scArches shapes are changed and we don't want to overwrite
# these changed shapes.
pyro.get_param_store().set_state(pyro_param_store)
print(self.reason)
print(f"Model saved to {self.save_path}")
print(f"Model saved to {self.best_model_path}")
self._log_info(trainer, self.reason, False)
return

@staticmethod
def _log_info(trainer: pl.Trainer, message: str, log_rank_zero_only: bool) -> None:
Expand Down
29 changes: 29 additions & 0 deletions tests/train/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

import scvi
from scvi.data import synthetic_iid


@pytest.mark.parametrize("load_best_on_end", [True, False])
Expand Down Expand Up @@ -83,3 +84,31 @@ def test_model_cls(model_cls, adata: AnnData):
test_model_cls(scvi.model.SCANVI, adata)

scvi.settings.logging_dir = old_logging_dir


def test_exception_callback():
    """Smoke-test SCVI training with a ``SaveCheckpoint`` callback attached.

    Trains an SCVI model on synthetic data once without callbacks and then
    again with ``SaveCheckpoint(monitor="elbo_validation",
    load_best_on_end=True)`` and checkpointing enabled. The test passes as
    long as neither training call raises.

    NOTE(review): nothing here forces training to fail, so this does not
    appear to exercise the callback's exception-handling path yet -- see the
    TODO below.
    """
    import torch

    from scvi.model import SCVI
    from scvi.train._callbacks import SaveCheckpoint

    torch.set_float32_matmul_precision("high")
    scvi.settings.seed = 0

    # TODO: find a proper way to simulate an adata that fails quickly during
    # training (an earlier attempt cast ``adata.X`` to float and overwrote its
    # first row) so that the exception path of the callback is triggered.
    adata = synthetic_iid(n_genes=1000, batch_size=2)

    SCVI.setup_anndata(adata, batch_key="batch")

    model = SCVI(adata)
    model.train(max_epochs=5)

    # Second round of training on the same model, now with checkpointing and
    # the SaveCheckpoint callback monitoring the validation ELBO.
    model.train(
        max_epochs=5,
        callbacks=[SaveCheckpoint(monitor="elbo_validation", load_best_on_end=True)],
        enable_checkpointing=True,
    )

0 comments on commit 22951d1

Please sign in to comment.