fix: nan failure during training #3159

Open: wants to merge 16 commits into base: main

@ori-kron-wis (Collaborator) commented Jan 21, 2025

Using the SaveCheckpoint callback with on_exception, the best model found up to the point where training crashed because of NaNs in the loss or gradients can be saved.
See an example (using Michal's data):

import scvi
from scvi.train._callbacks import SaveCheckpoint
from scvi.model import SCANVI
import pandas as pd
import numpy as np
import scanpy as sc
import torch
torch.set_float32_matmul_precision('high')

pd.set_option('display.max_rows', 50)
pd.set_option('display.max_columns', 50)
pd.set_option('display.width', 1000)
scvi.settings.seed = 0

early_stopping_kwargs = {
    'early_stopping': True,
    'early_stopping_monitor': 'elbo_validation', #'train_loss'
    'early_stopping_patience': 50,
    'early_stopping_mode': "min",
    'early_stopping_min_delta': 0.0,
    #'check_val_every_n_epoch': 1,
    #'check_finite': True,
}

ScviM = scvi.model.SCVI.load("/home/access/scvi_forScanVI4")

lvae = scvi.model.SCANVI.from_scvi_model(
    ScviM,
    unlabeled_category='unlabeled',
    labels_key="celltypes_steven2",
    linear_classifier=True,
)
# This run breaks at epoch 58.
lvae.train(batch_size=1024, n_samples_per_label=100, max_epochs=500, gradient_clip_val=0,
           **early_stopping_kwargs, detect_anomaly=False, enable_checkpointing=True,
           callbacks=[SaveCheckpoint(monitor="elbo_validation", load_best_on_end=True)])

# We now want to load the saved model and continue training it
model = SCANVI.load("/home/access/.config/JetBrains/PyCharmCE2024.2/scratches/scvi_log/"
                    "2025-01-23_13-37-44_elbo_validation/"
                    "epoch=54-step=53295-elbo_validation=1255.7066650390625/", adata=ScviM.adata)
model.train(batch_size=2048, n_samples_per_label=50, max_epochs=500, gradient_clip_val=1,
            **early_stopping_kwargs, detect_anomaly=False, enable_checkpointing=True, plan_kwargs={"lr": 1e-2},
            callbacks=[SaveCheckpoint(monitor="elbo_validation", load_best_on_end=True)])

# Running with detect_anomaly=True slows training down considerably
print("done")

We can then load the saved model and continue training it (with or without tweaking the parameters).
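To make the idea concrete, here is a minimal, library-free sketch of the pattern described above: a callback tracks the best validation score, and an `on_exception` hook keeps that best state when training raises on a NaN loss. It is not the scvi-tools implementation. The class `BestStateOnException`, the function `fit`, and the toy loss values are all hypothetical and exist only for illustration. Lightning callbacks expose an `on_exception` hook with a different signature, so a real callback would have to follow that API.

```python
import copy
import math


class BestStateOnException:
    """Hypothetical callback: remembers the best state and keeps it if training crashes."""

    def __init__(self, mode="min"):
        self.mode = mode
        self.best_score = math.inf if mode == "min" else -math.inf
        self.best_state = None
        self.best_epoch = None
        self.saved = None  # stands in for a checkpoint written to disk

    def on_validation_end(self, epoch, score, state):
        # Ignore non-finite scores so a NaN can never become the "best" value.
        if not math.isfinite(score):
            return
        if self.mode == "min":
            improved = score < self.best_score
        else:
            improved = score > self.best_score
        if improved:
            self.best_score = score
            self.best_epoch = epoch
            self.best_state = copy.deepcopy(state)

    def on_exception(self, exc):
        # Keep the best state seen so far instead of losing the whole run.
        if self.best_state is not None:
            self.saved = {
                "epoch": self.best_epoch,
                "score": self.best_score,
                "state": self.best_state,
            }


def fit(losses, callback):
    """Toy training loop; `losses` plays the role of a per-epoch validation metric."""
    state = {"weights": 0.0}
    try:
        for epoch, loss in enumerate(losses):
            if math.isnan(loss):
                raise ValueError(f"NaN loss at epoch {epoch}")
            state["weights"] -= 0.1 * loss  # placeholder for an optimizer step
            callback.on_validation_end(epoch, loss, state)
    except ValueError as exc:
        callback.on_exception(exc)
        raise
    return state


callback = BestStateOnException(mode="min")
try:
    fit([5.0, 3.0, 4.0, float("nan"), 1.0], callback)
except ValueError as err:
    print(f"training stopped: {err}")
print(callback.saved["epoch"], callback.saved["score"])
```

In this sketch the exception is re-raised after the hook runs, so the failure stays visible to the caller while the best state is still recoverable from `callback.saved`.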

@ori-kron-wis ori-kron-wis self-assigned this Jan 21, 2025
@ori-kron-wis ori-kron-wis added the on-merge: backport to 1.3.x label Jan 21, 2025
@ori-kron-wis ori-kron-wis added this to the scvi-tools 1.3 milestone Jan 21, 2025
@ori-kron-wis ori-kron-wis changed the title from "Ori nan crash fix" to "fix: nan failure during training" Jan 21, 2025

codecov bot commented Jan 21, 2025

Codecov Report

Attention: Patch coverage is 38.46154% with 16 lines in your changes missing coverage. Please review.

Project coverage is 82.60%. Comparing base (abfcbfc) to head (21943dc).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/scvi/train/_callbacks.py | 38.46% | 16 Missing ⚠️ |

❗ There is a different number of reports uploaded between BASE (abfcbfc) and HEAD (21943dc).

HEAD has 18 uploads less than BASE

| Flag | BASE (abfcbfc) | HEAD (21943dc) |
| --- | --- | --- |
| | 21 | 3 |

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3159      +/-   ##
==========================================
- Coverage   89.43%   82.60%   -6.84%     
==========================================
  Files         185      185              
  Lines       16182    16208      +26     
==========================================
- Hits        14473    13389    -1084     
- Misses       1709     2819    +1110     

| Files with missing lines | Coverage Δ |
| --- | --- |
| src/scvi/train/_trainer.py | 100.00% <ø> (ø) |
| src/scvi/train/_callbacks.py | 77.97% <38.46%> (-7.24%) ⬇️ |

... and 27 files with indirect coverage changes
