Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add get_normalized_expression to more models #3121

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ to [Semantic Versioning]. Full commit history is available in the

- Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable
representation learning in single-cell RNA sequencing data {pr}`3015`, {pr}`3091`.
- Add {meth}`~scvi.model.SCVI.get_normalized_expression` for models: PeakVI, PoissonVI, CondSCVI,
AutoZI, CellAssign and GimVI. {pr}`3121`
- Add {class}`scvi.external.RESOLVI` for bias correction in single-cell resolved spatial
transcriptomics {pr}`3144`.

Expand Down
4 changes: 2 additions & 2 deletions src/scvi/external/cellassign/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from scvi.dataloaders import DataSplitter
from scvi.external.cellassign._module import CellAssignModule
from scvi.model._utils import get_max_epochs_heuristic
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin
from scvi.model.base import BaseModelClass, RNASeqMixin, UnsupervisedTrainingMixin
from scvi.train import LoudEarlyStopping, TrainingPlan, TrainRunner
from scvi.utils import setup_anndata_dsp
from scvi.utils._docstrings import devices_dsp
Expand All @@ -33,7 +33,7 @@
B = 10


class CellAssign(UnsupervisedTrainingMixin, BaseModelClass):
class CellAssign(UnsupervisedTrainingMixin, BaseModelClass, RNASeqMixin):
"""Reimplementation of CellAssign for reference-based annotation :cite:p:`Zhang19`.

Original implementation: https://github.com/irrationone/cellassign.
Expand Down
40 changes: 28 additions & 12 deletions src/scvi/external/cellassign/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@

from scvi import REGISTRY_KEYS
from scvi.distributions import NegativeBinomial
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
from scvi.module import VAE
from scvi.module.base import LossOutput, auto_move_data

LOWER_BOUND = 1e-10
THETA_LOWER_BOUND = 1e-20
B = 10


class CellAssignModule(BaseModuleClass):
class CellAssignModule(VAE):
"""Model for CellAssign.

Parameters
Expand Down Expand Up @@ -51,7 +52,7 @@ def __init__(
n_cats_per_cov: Iterable[int] | None = None,
n_continuous_cov: int = 0,
):
super().__init__()
super().__init__(n_genes)
self.n_genes = n_genes
self.n_labels = rho.shape[1]
self.n_batch = n_batch
Expand Down Expand Up @@ -103,10 +104,7 @@ def __init__(

self.register_buffer("basis_means", torch.tensor(basis_means, dtype=torch.float32))

def _get_inference_input(self, tensors):
return {}

def _get_generative_input(self, tensors, inference_outputs):
def _get_generative_input(self, tensors, inference_outputs, transform_batch=None):
x = tensors[REGISTRY_KEYS.X_KEY]
size_factor = tensors[REGISTRY_KEYS.SIZE_FACTOR_KEY]

Expand All @@ -127,19 +125,27 @@ def _get_generative_input(self, tensors, inference_outputs):

design_matrix = torch.cat(to_cat, dim=1) if len(to_cat) > 0 else None

batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
if transform_batch is not None:
batch_index = torch.ones_like(batch_index) * transform_batch

input_dict = {
"x": x,
"size_factor": size_factor,
"design_matrix": design_matrix,
"batch_index": batch_index,
}
return input_dict

@auto_move_data
def inference(self):
return {}

@auto_move_data
def generative(self, x, size_factor, design_matrix=None):
def generative(
self,
x,
size_factor,
batch_index,
design_matrix=None,
transform_batch: torch.Tensor | None = None,
):
"""Run the generative model."""
# x has shape (n, g)
delta = torch.exp(self.delta_log) # (g, c)
Expand Down Expand Up @@ -193,12 +199,22 @@ def generative(self, x, size_factor, design_matrix=None):
normalizer_over_c = normalizer_over_c.unsqueeze(-1).expand(n_cells, self.n_labels)
gamma = torch.exp(p_x_c - normalizer_over_c) # (n, c)

px = torch.sum(x_log_prob_raw, -1)
normalizer_over_c2 = torch.logsumexp(px, 1)
normalizer_over_c2 = normalizer_over_c2.unsqueeze(-1).expand(n_cells, self.n_genes)
gamma2 = torch.exp(px - normalizer_over_c2) # (n, g)

if transform_batch is not None:
batch_index = torch.ones_like(batch_index) * transform_batch

return {
"mu": mu_ngc,
"phi": phi,
"gamma": gamma,
"p_x_c": p_x_c,
"px": gamma2,
"s": size_factor,
"batch_index": batch_index,
}

def loss(
Expand Down
4 changes: 2 additions & 2 deletions src/scvi/external/gimvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from scvi.data.fields import CategoricalObsField, LayerField
from scvi.dataloaders import DataSplitter
from scvi.model._utils import _init_library_size, parse_device_args
from scvi.model.base import BaseModelClass, VAEMixin
from scvi.model.base import BaseModelClass, RNASeqMixin, VAEMixin
from scvi.train import Trainer
from scvi.utils import setup_anndata_dsp
from scvi.utils._docstrings import devices_dsp
Expand All @@ -41,7 +41,7 @@ def _unpack_tensors(tensors):
return x, batch_index, y


class GIMVI(VAEMixin, BaseModelClass):
class GIMVI(VAEMixin, BaseModelClass, RNASeqMixin):
"""Joint VAE for imputing missing genes in spatial data :cite:p:`Lopez19`.

Parameters
Expand Down
35 changes: 28 additions & 7 deletions src/scvi/external/gimvi/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,20 +364,31 @@ def reconstruction_loss(
reconstruction_loss = -Poisson(px_rate).log_prob(x).sum(dim=1)
return reconstruction_loss

def _get_inference_input(self, tensors):
def _get_inference_input(self, tensors) -> dict[str, torch.Tensor | None]:
"""Get the input for the inference model."""
return {"x": tensors[REGISTRY_KEYS.X_KEY]}
return {
"x": tensors[REGISTRY_KEYS.X_KEY],
"batch_index": tensors.get(REGISTRY_KEYS.BATCH_KEY, None),
}

def _get_generative_input(self, tensors, inference_outputs):
def _get_generative_input(self, tensors, inference_outputs, transform_batch=None):
"""Get the input for the generative model."""
z = inference_outputs["z"]
library = inference_outputs["library"]
batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
y = tensors[REGISTRY_KEYS.LABELS_KEY]
if transform_batch is not None:
batch_index = torch.ones_like(batch_index) * transform_batch
return {"z": z, "library": library, "batch_index": batch_index, "y": y}

@auto_move_data
def inference(self, x: torch.Tensor, mode: int | None = None) -> dict:
def inference(
self,
x: torch.Tensor,
mode: int | None = 0,
n_samples: int | None = 1,
batch_index: torch.Tensor | None = None,
) -> dict:
"""Run the inference model."""
x_ = x
if self.log_variational:
Expand All @@ -390,6 +401,11 @@ def inference(self, x: torch.Tensor, mode: int | None = None) -> dict:
else:
library = torch.log(torch.sum(x, dim=1)).view(-1, 1)

if n_samples > 1:
# when z is normal, untran_z == z
untran_z = qz.sample((n_samples,))
z = self.z_encoder.z_transformation(untran_z)

return {"qz": qz, "z": z, "ql": ql, "library": library}

@auto_move_data
Expand All @@ -399,7 +415,8 @@ def generative(
library: torch.Tensor,
batch_index: torch.Tensor | None = None,
y: torch.Tensor | None = None,
mode: int | None = None,
mode: int | None = 0,
transform_batch: torch.Tensor | None = None,
) -> dict:
"""Run the generative model."""
px_scale, px_r, px_rate, px_dropout = self.decoder(
Expand All @@ -418,11 +435,15 @@ def generative(
)
px_rate = px_scale * torch.exp(library)

if transform_batch is not None:
batch_index = torch.ones_like(batch_index) * transform_batch

return {
"px_scale": px_scale,
"px_r": px_r,
"px": px_r,
"px_rate": px_rate,
"px_dropout": px_dropout,
"batch_index": batch_index,
}

def loss(
Expand Down Expand Up @@ -463,7 +484,7 @@ def loss(
qz = inference_outputs["qz"]
ql = inference_outputs["ql"]
px_rate = generative_outputs["px_rate"]
px_r = generative_outputs["px_r"]
px_r = generative_outputs["px"]
px_dropout = generative_outputs["px_dropout"]

# mask loss to observed genes
Expand Down
10 changes: 0 additions & 10 deletions src/scvi/external/poissonvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,6 @@ def get_region_factors(self):
raise RuntimeError("region factors were not included in this model")
return region_factors

def get_normalized_expression(
self,
):
# Refer to function get_accessibility_estimates
msg = (
f"differential_expression is not implemented for {self.__class__.__name__}, please "
f"use {self.__class__.__name__}.get_accessibility_estimates"
)
raise NotImplementedError(msg)

@de_dsp.dedent
def differential_accessibility(
self,
Expand Down
4 changes: 2 additions & 2 deletions src/scvi/model/_autozi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from scvi.module import AutoZIVAE
from scvi.utils import setup_anndata_dsp

from .base import BaseModelClass, VAEMixin
from .base import BaseModelClass, RNASeqMixin, VAEMixin

if TYPE_CHECKING:
from collections.abc import Sequence
Expand All @@ -29,7 +29,7 @@
# register buffer


class AUTOZI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
class AUTOZI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, RNASeqMixin):
"""Automatic identification of zero-inflated genes :cite:p:`Clivio19`.

Parameters
Expand Down
3 changes: 1 addition & 2 deletions src/scvi/model/_condscvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,8 @@ def setup_anndata(
anndata_fields = [
fields.LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
fields.CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
]
if batch_key is not None:
anndata_fields.append(fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key))
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)
12 changes: 11 additions & 1 deletion src/scvi/model/_destvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import LayerField, NumericalObsField
from scvi.data.fields import CategoricalObsField, LayerField, NumericalObsField
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin
from scvi.module import MRDeconv
from scvi.utils import setup_anndata_dsp
Expand Down Expand Up @@ -220,6 +220,13 @@ def get_proportions(
index=index_names,
)

def get_normalized_expression(
self,
):
# Refer to function get_accessibility_estimates
msg = f"get_normalized_expression is not implemented for {self.__class__.__name__}."
raise NotImplementedError(msg)

def get_gamma(
self,
indices: Sequence[int] | None = None,
Expand Down Expand Up @@ -390,6 +397,7 @@ def setup_anndata(
cls,
adata: AnnData,
layer: str | None = None,
batch_key: str | None = None,
**kwargs,
):
"""%(summary)s.
Expand All @@ -398,13 +406,15 @@ def setup_anndata(
----------
%(param_adata)s
%(param_layer)s
%(param_batch_key)s
"""
setup_method_args = cls._get_setup_method_args(**locals())
# add index for each cell (provided to pyro plate for correct minibatching)
adata.obs["_indices"] = np.arange(adata.n_obs)
anndata_fields = [
LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
]
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
Expand Down
6 changes: 3 additions & 3 deletions src/scvi/model/_peakvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from scvi.train._callbacks import SaveBestState
from scvi.utils._docstrings import de_dsp, devices_dsp, setup_anndata_dsp

from .base import ArchesMixin, BaseModelClass, VAEMixin
from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
Expand All @@ -40,7 +40,7 @@
logger = logging.getLogger(__name__)


class PEAKVI(ArchesMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
class PEAKVI(ArchesMixin, RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
"""Peak Variational Inference for chromatin accessilibity analysis :cite:p:`Ashuach22`.

Parameters
Expand Down Expand Up @@ -393,7 +393,7 @@ def get_accessibility_estimates(
generative_kwargs=generative_kwargs,
compute_loss=False,
)
p = generative_outputs["p"].cpu()
p = generative_outputs["px"].cpu()

if normalize_cells:
p *= inference_outputs["d"].cpu()
Expand Down
6 changes: 5 additions & 1 deletion src/scvi/model/base/_rnamixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,11 @@ def get_normalized_expression(
generative_kwargs=generative_kwargs,
compute_loss=False,
)
exp_ = generative_outputs["px"].get_normalized(generative_output_key)
px_generative = generative_outputs["px"]
if isinstance(px_generative, torch.Tensor):
exp_ = px_generative
else:
exp_ = px_generative.get_normalized(generative_output_key)
exp_ = exp_[..., gene_mask]
exp_ *= scaling
per_batch_exprs.append(exp_[None].cpu())
Expand Down
2 changes: 2 additions & 0 deletions src/scvi/module/_autozivae.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def generative(
cat_covs=None,
n_samples: int = 1,
eps_log: float = 1e-8,
transform_batch: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Run the generative model."""
outputs = super().generative(
Expand All @@ -293,6 +294,7 @@ def generative(
cat_covs=cat_covs,
y=y,
size_factor=size_factor,
transform_batch=transform_batch,
)
# Rescale dropout
rescaled_dropout = self.rescale_dropout(outputs["px"].zi_logits, eps_log=eps_log)
Expand Down
Loading
Loading