From 9bba0db066666b24a7537fad583bc4e734c382fc Mon Sep 17 00:00:00 2001 From: Nathan Date: Mon, 3 Feb 2025 11:03:11 +0200 Subject: [PATCH] fix test folder name --- tests/external/nichevi/test_nichevi.py | 188 +++++++++++++++++++++++++ 1 file changed, 188 insertions(+) create mode 100644 tests/external/nichevi/test_nichevi.py diff --git a/tests/external/nichevi/test_nichevi.py b/tests/external/nichevi/test_nichevi.py new file mode 100644 index 0000000000..72eafba45e --- /dev/null +++ b/tests/external/nichevi/test_nichevi.py @@ -0,0 +1,188 @@ +import numpy as np +import pytest +from anndata import AnnData + +from scvi.data import synthetic_iid +from scvi.external import nicheSCVI + +N_LATENT = 10 +K_NN = 5 +N_EPOCHS_NICHEVI = 2 + +setup_kwargs = { + "sample_key": "batch", + "labels_key": "labels", + "cell_coordinates_key": "coordinates", + "expression_embedding_key": "qz1_m", + "expression_embedding_niche_key": "qz1_m_niche_ct", + "niche_composition_key": "neighborhood_composition", + "niche_indexes_key": "niche_indexes", + "niche_distances_key": "niche_distances", +} + + +@pytest.fixture(scope="session") +def adata(): + adata = synthetic_iid( + batch_size=256, + n_genes=100, + n_proteins=0, + n_regions=0, + n_batches=2, + n_labels=3, + dropout_ratio=0.5, + generate_coordinates=True, + sparse_format=None, + return_mudata=False, + ) + + adata.obsm["qz1_m"] = np.random.normal(size=(adata.shape[0], N_LATENT)) + adata.layers["counts"] = adata.X.copy() + + return adata + + +def test_nichevi_train(adata: AnnData): + nicheSCVI.preprocessing_anndata( + adata, + k_nn=K_NN, + **setup_kwargs, + ) + + nicheSCVI.setup_anndata( + adata, + layer="counts", + batch_key="batch", + **setup_kwargs, + ) + nichevae = nicheSCVI( + adata, + prior_mixture=False, + semisupervised=True, + linear_classifier=True, + ) + + nichevae.train( + max_epochs=N_EPOCHS_NICHEVI, + train_size=0.8, + validation_size=0.2, + early_stopping=True, + check_val_every_n_epoch=1, + accelerator="cpu", + ) + + +def test_nichevi_save_load(adata): + nicheSCVI.preprocessing_anndata( + adata, + k_nn=K_NN, + **setup_kwargs, + ) + + nicheSCVI.setup_anndata( + adata, + layer="counts", + batch_key="batch", + **setup_kwargs, + ) + nichevae = nicheSCVI( + adata, + prior_mixture=False, + semisupervised=True, + linear_classifier=True, + ) + + nichevae.train( + max_epochs=N_EPOCHS_NICHEVI, + train_size=0.8, + validation_size=0.2, + early_stopping=True, + check_val_every_n_epoch=1, + accelerator="cpu", + ) + hist_elbo = nichevae.history["elbo_train"] + latent = nichevae.get_latent_representation() + assert latent.shape == (adata.n_obs, nichevae.module.n_latent) + nichevae.save("test_nichevi", save_anndata=True, overwrite=True) + model2 = nichevae.load("test_nichevi") + np.testing.assert_array_equal(model2.history_["elbo_train"], hist_elbo) + latent2 = model2.get_latent_representation() + assert np.allclose(latent, latent2, atol=1e-5) + + nichevae.get_elbo(indices=nichevae.validation_indices) + nichevae.get_composition_error(return_mean=False, indices=nichevae.validation_indices) + nichevae.get_niche_error(return_mean=False, indices=nichevae.validation_indices) + nichevae.get_normalized_expression() + nichevae.predict_neighborhood() # specific to nicheSCVI + nichevae.predict_niche_activation() # specific to nicheSCVI + + +def test_nichevi_differential(adata): + nicheSCVI.preprocessing_anndata( + adata, + k_nn=K_NN, + **setup_kwargs, + ) + + nicheSCVI.setup_anndata( + adata, + layer="counts", + batch_key="batch", + **setup_kwargs, + ) + nichevae = nicheSCVI( + adata, + prior_mixture=False, + semisupervised=True, + linear_classifier=True, + ) + + nichevae.train( + max_epochs=N_EPOCHS_NICHEVI, + train_size=0.8, + validation_size=0.2, + early_stopping=True, + check_val_every_n_epoch=1, + accelerator="cpu", + ) + + nichevae.differential_expression( + groupby="labels", + group1="label_1", + group2="label_2", + batch_correction=False, + niche_mode=False, + fdr_target=1, + delta=0.5, + ) + + nichevae.differential_expression( + groupby="labels", + group1="label_1", + # group2="label_2", + batch_correction=False, + radius=None, + k_nn=5, + fdr_target=1, + delta=0.5, + ) + nichevae.differential_expression( + groupby="labels", + group1="label_1", + group2="label_2", + batch_correction=False, + radius=50, + k_nn=None, + fdr_target=[1, 1, 1, 1], + delta=[0.5, 0.5, 0.5, 0.5], + ) + nichevae.differential_expression( + groupby="labels", + group1="label_1", + group2="label_2", + batch_correction=False, + radius=50, + k_nn=None, + fdr_target=[1, 1, 1, 1], + delta=[0.5, 0.5, 0.5, 0.5], + )