feat: add get_normalized_expression to more models #3121

Open · wants to merge 12 commits into base: main
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,8 @@ to [Semantic Versioning]. Full commit history is available in the

- Add {class}`scvi.external.Decipher` for dimensionality reduction and interpretable
representation learning in single-cell RNA sequencing data {pr}`3015`, {pr}`3091`.
- Add {meth}`~scvi.model.SCVI.get_normalized_expression` for models: PeakVI, DestVI, AutoZI,
CellAssign. {pr}`3121`

#### Fixed

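For context, a minimal usage sketch of the method this entry adds; the synthetic data, model choice, and training settings below are illustrative, not taken from this PR:

```python
import numpy as np
from anndata import AnnData
from scvi.model import PEAKVI

# Toy count matrix purely for illustration.
adata = AnnData(np.random.poisson(1.0, size=(50, 20)).astype(np.float32))
PEAKVI.setup_anndata(adata)
model = PEAKVI(adata)
model.train(max_epochs=1)

# The method wired up by this PR via RNASeqMixin; with transform_batch set,
# expression would be decoded as if all cells came from that batch.
norm = model.get_normalized_expression(adata, n_samples=1)
```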
4 changes: 2 additions & 2 deletions src/scvi/external/cellassign/_model.py
@@ -20,7 +20,7 @@
from scvi.dataloaders import DataSplitter
from scvi.external.cellassign._module import CellAssignModule
from scvi.model._utils import get_max_epochs_heuristic
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin
from scvi.model.base import BaseModelClass, RNASeqMixin, UnsupervisedTrainingMixin, VAEMixin
from scvi.train import LoudEarlyStopping, TrainingPlan, TrainRunner
from scvi.utils import setup_anndata_dsp
from scvi.utils._docstrings import devices_dsp
@@ -33,7 +33,7 @@
B = 10


class CellAssign(UnsupervisedTrainingMixin, BaseModelClass):
class CellAssign(UnsupervisedTrainingMixin, BaseModelClass, RNASeqMixin, VAEMixin):
Member review comment:
Is CellAssign intended for these functions? I would remove it here.

"""Reimplementation of CellAssign for reference-based annotation :cite:p:`Zhang19`.

Original implementation: https://github.com/irrationone/cellassign.
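For readers unfamiliar with the pattern, a generic sketch of why listing RNASeqMixin among the bases is enough to expose get_normalized_expression; the class names below are stand-ins, not scvi-tools API. Note that the new bases are appended after BaseModelClass here, unlike PEAKVI and AUTOZI below, which list the mixins first; the method resolves through the MRO either way.

```python
# Stand-in classes: the point is only that a method defined on a mixin
# becomes available on any model that includes the mixin among its bases.
class ExpressionMixin:
    def get_normalized_expression(self):
        return "decoded, library-size-normalized expression"

class BaseModel:
    pass

class MyModel(BaseModel, ExpressionMixin):  # mixin listed last, as in CellAssign
    pass

print(MyModel().get_normalized_expression())  # resolved via the MRO
```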
40 changes: 28 additions & 12 deletions src/scvi/external/cellassign/_module.py
@@ -6,14 +6,15 @@

from scvi import REGISTRY_KEYS
from scvi.distributions import NegativeBinomial
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
from scvi.module import VAE
from scvi.module.base import LossOutput, auto_move_data

LOWER_BOUND = 1e-10
THETA_LOWER_BOUND = 1e-20
B = 10


class CellAssignModule(BaseModuleClass):
class CellAssignModule(VAE):
"""Model for CellAssign.

Parameters
@@ -51,7 +52,7 @@ def __init__(
n_cats_per_cov: Iterable[int] | None = None,
n_continuous_cov: int = 0,
):
super().__init__()
super().__init__(n_genes)
self.n_genes = n_genes
self.n_labels = rho.shape[1]
self.n_batch = n_batch
@@ -103,10 +104,7 @@ def __init__(

self.register_buffer("basis_means", torch.tensor(basis_means, dtype=torch.float32))

def _get_inference_input(self, tensors):
return {}

def _get_generative_input(self, tensors, inference_outputs):
def _get_generative_input(self, tensors, inference_outputs, transform_batch=None):
x = tensors[REGISTRY_KEYS.X_KEY]
size_factor = tensors[REGISTRY_KEYS.SIZE_FACTOR_KEY]

@@ -127,19 +125,27 @@ def _get_generative_input(self, tensors, inference_outputs):

design_matrix = torch.cat(to_cat, dim=1) if len(to_cat) > 0 else None

batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
if transform_batch is not None:
batch_index = torch.ones_like(batch_index) * transform_batch

input_dict = {
"x": x,
"size_factor": size_factor,
"design_matrix": design_matrix,
"batch_index": batch_index,
}
return input_dict

@auto_move_data
def inference(self):
return {}

@auto_move_data
def generative(self, x, size_factor, design_matrix=None):
def generative(
self,
x,
size_factor,
batch_index,
design_matrix=None,
transform_batch: torch.Tensor | None = None,
):
"""Run the generative model."""
# x has shape (n, g)
delta = torch.exp(self.delta_log) # (g, c)
@@ -193,12 +199,22 @@ def generative(self, x, size_factor, design_matrix=None):
normalizer_over_c = normalizer_over_c.unsqueeze(-1).expand(n_cells, self.n_labels)
gamma = torch.exp(p_x_c - normalizer_over_c) # (n, c)

px = torch.sum(x_log_prob_raw, -1)
normalizer_over_c2 = torch.logsumexp(px, 1)
normalizer_over_c2 = normalizer_over_c2.unsqueeze(-1).expand(n_cells, self.n_genes)
gamma2 = torch.exp(px - normalizer_over_c2) # (n, g)

if transform_batch is not None:
batch_index = torch.ones_like(batch_index) * transform_batch

return {
"mu": mu_ngc,
"phi": phi,
"gamma": gamma,
"p_x_c": p_x_c,
"px": gamma2,
"s": size_factor,
"batch_index": batch_index,
}

def loss(
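The gamma and px (gamma2) blocks above use the same numerically stable normalization; a standalone sketch of the logsumexp idiom, with toy shapes:

```python
import torch

# exp(logp - logsumexp(logp)) equals softmax(logp), but avoids the overflow
# that exp(logp) / exp(logp).sum() hits for large-magnitude log-probabilities.
log_p = torch.randn(4, 3) * 50  # (cells, labels); extreme values on purpose
normalizer = torch.logsumexp(log_p, dim=1, keepdim=True)
gamma = torch.exp(log_p - normalizer)
assert torch.allclose(gamma.sum(dim=1), torch.ones(4))  # rows sum to 1
```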
4 changes: 2 additions & 2 deletions src/scvi/model/_autozi.py
@@ -16,7 +16,7 @@
from scvi.module import AutoZIVAE
from scvi.utils import setup_anndata_dsp

from .base import BaseModelClass, VAEMixin
from .base import BaseModelClass, RNASeqMixin, VAEMixin

if TYPE_CHECKING:
from collections.abc import Sequence
@@ -29,7 +29,7 @@
# register buffer


class AUTOZI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
class AUTOZI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, RNASeqMixin):
"""Automatic identification of zero-inflated genes :cite:p:`Clivio19`.

Parameters
3 changes: 1 addition & 2 deletions src/scvi/model/_condscvi.py
@@ -297,9 +297,8 @@ def setup_anndata(
anndata_fields = [
fields.LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
fields.CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
]
if batch_key is not None:
anndata_fields.append(fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key))
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)
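The conditional registration can be dropped because CategoricalObsField accepts a None key; as I understand current scvi-tools behavior, it then places every cell in a single default batch category. A hedged sketch (the `_scvi_batch` column name reflects current internals and is not guaranteed by this diff):

```python
import numpy as np
from anndata import AnnData
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager, fields

adata = AnnData(np.random.poisson(1.0, size=(10, 5)).astype(np.float32))
manager = AnnDataManager(
    fields=[fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, None)]
)
manager.register_fields(adata)
# With batch_key=None, every cell lands in one default category.
print(adata.obs["_scvi_batch"].unique())
```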
9 changes: 6 additions & 3 deletions src/scvi/model/_destvi.py
@@ -9,8 +9,8 @@

from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import LayerField, NumericalObsField
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin
from scvi.data.fields import CategoricalObsField, LayerField, NumericalObsField
from scvi.model.base import BaseModelClass, RNASeqMixin, UnsupervisedTrainingMixin
from scvi.module import MRDeconv
from scvi.utils import setup_anndata_dsp
from scvi.utils._docstrings import devices_dsp
@@ -26,7 +26,7 @@
logger = logging.getLogger(__name__)


class DestVI(UnsupervisedTrainingMixin, BaseModelClass):
class DestVI(UnsupervisedTrainingMixin, BaseModelClass, RNASeqMixin):
"""Multi-resolution deconvolution of Spatial Transcriptomics data (DestVI) :cite:p:`Lopez22`.

Most users will use the alternate constructor (see example).
@@ -390,6 +390,7 @@ def setup_anndata(
cls,
adata: AnnData,
layer: str | None = None,
batch_key: str | None = None,
**kwargs,
):
"""%(summary)s.
@@ -398,13 +399,15 @@ def setup_anndata(
----------
%(param_adata)s
%(param_layer)s
%(param_batch_key)s
"""
setup_method_args = cls._get_setup_method_args(**locals())
# add index for each cell (provided to pyro plate for correct minibatching)
adata.obs["_indices"] = np.arange(adata.n_obs)
anndata_fields = [
LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"),
CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
]
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
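With the new keyword, spatial datasets can carry batch annotations at setup time; a minimal illustrative call (the toy AnnData and the "batch" column are placeholders, not values from this PR):

```python
import numpy as np
from anndata import AnnData
from scvi.model import DestVI

# Toy spatial AnnData purely for illustration.
st_adata = AnnData(np.random.poisson(1.0, size=(30, 20)).astype(np.float32))
st_adata.obs["batch"] = np.random.choice(["s1", "s2"], size=st_adata.n_obs)

# batch_key is the parameter added by this PR; previously only `layer`
# (and **kwargs) was accepted.
DestVI.setup_anndata(st_adata, batch_key="batch")
```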
6 changes: 3 additions & 3 deletions src/scvi/model/_peakvi.py
@@ -29,7 +29,7 @@
from scvi.train._callbacks import SaveBestState
from scvi.utils._docstrings import de_dsp, devices_dsp, setup_anndata_dsp

from .base import ArchesMixin, BaseModelClass, VAEMixin
from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin

if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
@@ -40,7 +40,7 @@
logger = logging.getLogger(__name__)


class PEAKVI(ArchesMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
class PEAKVI(ArchesMixin, RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
"""Peak Variational Inference for chromatin accessilibity analysis :cite:p:`Ashuach22`.

Parameters
@@ -393,7 +393,7 @@ def get_accessibility_estimates(
generative_kwargs=generative_kwargs,
compute_loss=False,
)
p = generative_outputs["p"].cpu()
p = generative_outputs["px"].cpu()

if normalize_cells:
p *= inference_outputs["d"].cpu()
6 changes: 5 additions & 1 deletion src/scvi/model/base/_rnamixin.py
@@ -269,7 +269,11 @@ def get_normalized_expression(
generative_kwargs=generative_kwargs,
compute_loss=False,
)
exp_ = generative_outputs["px"].get_normalized(generative_output_key)
px_generative = generative_outputs["px"]
if isinstance(px_generative, torch.Tensor):
exp_ = px_generative
else:
exp_ = px_generative.get_normalized(generative_output_key)
exp_ = exp_[..., gene_mask]
exp_ *= scaling
per_batch_exprs.append(exp_[None].cpu())
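The new branch lets the mixin serve two module conventions: modules reworked in this PR (MRDeconv, CellAssignModule) return a plain tensor under "px", while scvi's distribution-returning modules expose get_normalized(). A reduced sketch of the dispatch:

```python
import torch

def extract_normalized(px_generative, generative_output_key: str) -> torch.Tensor:
    # Plain tensors (e.g. a decoded rate) pass straight through; anything
    # else is treated as a distribution-like object exposing get_normalized(),
    # mirroring the isinstance branch added above.
    if isinstance(px_generative, torch.Tensor):
        return px_generative
    return px_generative.get_normalized(generative_output_key)
```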
2 changes: 2 additions & 0 deletions src/scvi/module/_autozivae.py
@@ -283,6 +283,7 @@ def generative(
cat_covs=None,
n_samples: int = 1,
eps_log: float = 1e-8,
transform_batch: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Run the generative model."""
outputs = super().generative(
@@ -293,6 +294,7 @@ def generative(
cat_covs=cat_covs,
y=y,
size_factor=size_factor,
transform_batch=transform_batch,
)
# Rescale dropout
rescaled_dropout = self.rescale_dropout(outputs["px"].zi_logits, eps_log=eps_log)
63 changes: 49 additions & 14 deletions src/scvi/module/_mrdeconv.py
@@ -1,15 +1,23 @@
from collections import OrderedDict
from typing import Literal
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import torch
from torch.distributions import Normal

from scvi import REGISTRY_KEYS
from scvi.distributions import NegativeBinomial
from scvi.module._constants import MODULE_KEYS
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
from scvi.nn import FCLayers

if TYPE_CHECKING:
from collections import OrderedDict
from typing import Literal

import numpy as np
from torch.distributions import Distribution


def identity(x):
"""Identity function."""
@@ -179,24 +187,47 @@ def __init__(
torch.nn.Linear(n_hidden, n_labels + 1),
)

def _get_inference_input(self, tensors):
# we perform MAP here, so we just need to subsample the variables
return {}
def _get_inference_input(
self, tensors: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor | None]:
return {
MODULE_KEYS.X_KEY: tensors[REGISTRY_KEYS.X_KEY],
MODULE_KEYS.BATCH_INDEX_KEY: tensors.get(REGISTRY_KEYS.BATCH_KEY, None),
}

def _get_generative_input(self, tensors, inference_outputs):
def _get_generative_input(self, tensors, inference_outputs, transform_batch=None):
x = tensors[REGISTRY_KEYS.X_KEY]
ind_x = tensors[REGISTRY_KEYS.INDICES_KEY].long().ravel()

input_dict = {"x": x, "ind_x": ind_x}
batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
if transform_batch is not None:
batch_index = torch.ones_like(batch_index) * transform_batch

input_dict = {"x": x, "ind_x": ind_x, "batch_index": batch_index}
return input_dict

@auto_move_data
def inference(self):
"""Run the inference model."""
return {}
def inference(
self,
x: torch.Tensor,
batch_index: torch.Tensor | None = None,
n_samples: int = 1,
) -> dict[str, torch.Tensor | Distribution]:
"""High level inference method.

Runs the inference (encoder) model.
"""
encoder_input = [x]

z = self.V_encoder(*encoder_input)
# z = self.gamma_encoder(*encoder_input)

return {
MODULE_KEYS.Z_KEY: z,
Member review comment:
This doesn't really make sense. Please remove VAEmixin from DestVI

}

@auto_move_data
def generative(self, x, ind_x):
def generative(self, x, ind_x, batch_index, transform_batch: torch.Tensor | None = None):
"""Build the deconvolution model for every cell in the minibatch."""
m = x.shape[0]
library = torch.sum(x, dim=1, keepdim=True)
@@ -206,6 +237,9 @@ def generative(self, x, ind_x):
x_ = torch.log(1 + x)
# subsample parameters

if transform_batch is not None:
batch_index = torch.ones_like(batch_index) * transform_batch

if self.amortization in ["both", "latent"]:
gamma_ind = torch.transpose(self.gamma_encoder(x_), 0, 1).reshape(
(self.n_latent, self.n_labels, -1)
@@ -245,10 +279,11 @@ def generative(self, x, ind_x):

return {
"px_o": self.px_o,
"px_rate": px_rate,
"px": px_rate,
"px_scale": px_scale,
"gamma": gamma_ind,
"v": v_ind,
"batch_index": batch_index,
}

def loss(
@@ -261,7 +296,7 @@ def loss(
):
"""Compute the loss."""
x = tensors[REGISTRY_KEYS.X_KEY]
px_rate = generative_outputs["px_rate"]
px_rate = generative_outputs["px"]
px_o = generative_outputs["px_o"]
gamma = generative_outputs["gamma"]
v = generative_outputs["v"]
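The same transform_batch override appears in CellAssignModule, AutoZIVAE, and MRDeconv in this PR; a standalone sketch of the idiom:

```python
import torch

def override_batch(batch_index: torch.Tensor, transform_batch=None) -> torch.Tensor:
    # When transform_batch is set, every cell is decoded as if it belonged to
    # that batch, which is what makes
    # get_normalized_expression(transform_batch=...) return batch-corrected
    # values.
    if transform_batch is not None:
        return torch.ones_like(batch_index) * transform_batch
    return batch_index

batch_index = torch.zeros(5, 1, dtype=torch.long)
print(override_batch(batch_index, transform_batch=2))  # a column of twos
```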