Skip to content

Commit

Permalink
remove my_bootstrap
Browse files Browse the repository at this point in the history
  • Loading branch information
maddyduran committed Nov 13, 2023
1 parent d076361 commit 6d81ece
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 108 deletions.
170 changes: 77 additions & 93 deletions R/cell_count_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ new_cell_count_model <- function(ccs,
pseudocount=0,
pln_min_ratio=0.001,
pln_num_penalties=30,
vhat_method = c("bootstrap", "variational_var", "jackknife", "my_bootstrap"),
vhat_method = c("bootstrap", "variational_var", "jackknife"),
covariance_type = c("spherical", "diagonal"),
num_bootstraps = 10,
inception = NULL,
Expand Down Expand Up @@ -459,34 +459,34 @@ new_cell_count_model <- function(ccs,

# Check the whitelist and blacklist for expected values.
assertthat::assert_that(is.null(whitelist) || (is.numeric(whitelist[[1]][[1]]) &&
range(whitelist[[1]])[[1]] >= 0 &&
range(whitelist[[1]])[[2]] <= ccs_num_cell_group) ||
(is.character(whitelist[[1]][[1]])))
range(whitelist[[1]])[[1]] >= 0 &&
range(whitelist[[1]])[[2]] <= ccs_num_cell_group) ||
(is.character(whitelist[[1]][[1]])))
if (!all(whitelist[[1]] %in% ccs_cell_group_names)){
message ("Warning: whitelist refers to cell groups missing from cell count set")
}

assertthat::assert_that(is.null(whitelist) || (is.numeric(whitelist[[2]][[1]]) &&
range(whitelist[[2]])[[1]] >= 0 &&
range(whitelist[[2]])[[2]] <= ccs_num_cell_group) ||
(is.character(whitelist[[2]][[1]])))
range(whitelist[[2]])[[1]] >= 0 &&
range(whitelist[[2]])[[2]] <= ccs_num_cell_group) ||
(is.character(whitelist[[2]][[1]])))
if (!all(whitelist[[2]] %in% ccs_cell_group_names)){
message ("Warning: whitelist refers to cell groups missing from cell count set")
}

assertthat::assert_that(is.null(blacklist) || (is.numeric(blacklist[[1]][[1]]) &&
range(blacklist[[1]])[[1]] >= 0 &&
range(blacklist[[1]])[[2]] <= ccs_num_cell_group) ||
(is.character(blacklist[[1]][[1]])))
range(blacklist[[1]])[[1]] >= 0 &&
range(blacklist[[1]])[[2]] <= ccs_num_cell_group) ||
(is.character(blacklist[[1]][[1]])))
if (!all(blacklist[[1]] %in% ccs_cell_group_names)){
message ("Warning: blacklist refers to cell groups missing from cell count set")
}


assertthat::assert_that(is.null(blacklist) || (is.numeric(blacklist[[2]][[1]]) &&
range(blacklist[[2]])[[1]] >= 0 &&
range(blacklist[[2]])[[2]] <= ccs_num_cell_group) ||
(is.character(blacklist[[2]][[1]])))
range(blacklist[[2]])[[1]] >= 0 &&
range(blacklist[[2]])[[2]] <= ccs_num_cell_group) ||
(is.character(blacklist[[2]][[1]])))
if (!all(blacklist[[2]] %in% ccs_cell_group_names)){
message ("Warning: blacklist refers to cell groups missing from cell count set")
}
Expand All @@ -504,7 +504,7 @@ new_cell_count_model <- function(ccs,
tryCatch(expr = ifelse(match.arg(vhat_method) == "", TRUE, TRUE),
error = function(e) FALSE),
msg = paste( 'Argument vhat_method must be one of "variational_var",',
'"jackknife","bootstrap", or "my_bootstrap".'))
'"jackknife", or "bootstrap".'))
vhat_method <- match.arg(vhat_method)

assertthat::assert_that(
Expand All @@ -525,28 +525,28 @@ new_cell_count_model <- function(ccs,
full_model_formula_str = paste("Abundance~", main_model_formula_str, "+", nuisance_model_formula_str, " + offset(log(Offset))")
# full_model_formula = as.formula(full_model_formula_str)
full_model_formula <- tryCatch(
{
as.formula(full_model_formula_str)
},
error = function(condition) {
message(paste('Bad full_model_formula string', full_model_formula_str), ': ', condition, '.')
},
warn = function(condition) {
message(paste('Bad full_model_formula string', full_model_formula_str), ': ', condition, '.')
})
{
as.formula(full_model_formula_str)
},
error = function(condition) {
message(paste('Bad full_model_formula string', full_model_formula_str), ': ', condition, '.')
},
warn = function(condition) {
message(paste('Bad full_model_formula string', full_model_formula_str), ': ', condition, '.')
})

reduced_model_formula_str = paste("Abundance~", nuisance_model_formula_str, " + offset(log(Offset))")
# reduced_model_formula = as.formula(reduced_model_formula_str)
# reduced_model_formula = as.formula(reduced_model_formula_str)
reduced_model_formula <- tryCatch(
{
as.formula(reduced_model_formula_str)
},
error = function(condition) {
message(paste('Bad reduced_model_formula string', reduced_model_formula_str), ': ', condition, '.')
},
warn = function(condition) {
message(paste('Bad reduced_model_formula string', reduced_model_formula_str), ': ', condition, '.')
})
{
as.formula(reduced_model_formula_str)
},
error = function(condition) {
message(paste('Bad reduced_model_formula string', reduced_model_formula_str), ': ', condition, '.')
},
warn = function(condition) {
message(paste('Bad reduced_model_formula string', reduced_model_formula_str), ': ', condition, '.')
})

#pln_data <- as.name(deparse(substitute(pln_data)))

Expand All @@ -555,23 +555,23 @@ new_cell_count_model <- function(ccs,
# arguments in the call to init_penalty_matrix() are unused there.
# This is so that the whitelist and blacklist penalties are applied
# to the user supplied penalty matrix.
if (is.null(penalty_matrix)){
if (penalize_by_distance){
initial_penalties = init_penalty_matrix(ccs,
whitelist=whitelist,
blacklist=blacklist,
base_penalty=base_penalty,
min_penalty=min_penalty,
max_penalty=max_penalty,
penalty_scale_exponent=penalty_scale_exponent,
reduction_method=reduction_method)
initial_penalties = initial_penalties[colnames(pln_data$Abundance), colnames(pln_data$Abundance)]
}else{
initial_penalties = NULL
}
if (is.null(penalty_matrix)){
if (penalize_by_distance){
initial_penalties = init_penalty_matrix(ccs,
whitelist=whitelist,
blacklist=blacklist,
base_penalty=base_penalty,
min_penalty=min_penalty,
max_penalty=max_penalty,
penalty_scale_exponent=penalty_scale_exponent,
reduction_method=reduction_method)
initial_penalties = initial_penalties[colnames(pln_data$Abundance), colnames(pln_data$Abundance)]
}else{
initial_penalties = penalty_matrix
initial_penalties = NULL
}
}else{
initial_penalties = penalty_matrix
}

# FIXME: This might only actually work when grouping cells by clusters and cluster names are
# integers. We should make sure this generalizes when making white/black lists of cell groups
Expand Down Expand Up @@ -630,9 +630,8 @@ new_cell_count_model <- function(ccs,
sandwich_var = FALSE
jackknife = FALSE
bootstrap = FALSE
my_bootstrap = FALSE

if (vhat_method == "variational_var" | vhat_method == "my_bootstrap") {
if (vhat_method == "variational_var") {
variational_var = TRUE
}else{
variational_var = FALSE # Don't compute the variational variance unless we have to, because it sometimes throws exceptions
Expand All @@ -651,9 +650,9 @@ new_cell_count_model <- function(ccs,
}


# bge (20221227): notes:
# o I am trying to track the code in the PLNmodels master branch at Github
# o I revert to the original because the PLNmodels changes break hooke.
# bge (20221227): notes:
# o I am trying to track the code in the PLNmodels master branch at Github
# o I revert to the original because the PLNmodels changes break hooke.
reduced_pln_model <- do.call(PLNmodels::PLNnetwork, args=list(reduced_model_formula_str,
data=pln_data,
control = PLNmodels::PLNnetwork_param(backend = backend,
Expand All @@ -671,29 +670,29 @@ new_cell_count_model <- function(ccs,
...),)

full_pln_model <- do.call(PLNmodels::PLN, args=list(full_model_formula_str,
data=pln_data,
control = PLNmodels::PLN_param(backend = backend,
covariance = covariance_type,
trace = ifelse(verbose, 2, 0),
config_post = list(jackknife = jackknife,
bootstrap = bootstrap,
variational_var = variational_var,
sandwich_var = sandwich_var,
rsquared = FALSE),
config_optim = control_optim_args),
...),)


# bge (20221227): notes:
# o the previous version of PLNmodels was PLNmodels * 0.11.7-9600 2022-11-29 [1] Github (PLN-team/PLNmodels@022d59d)
# full_pln_model <- do.call(PLNmodels::PLNnetwork, args=list(full_model_formula_str,
# data=pln_data,
# penalties = reduced_pln_model$penalties,
# control_init=list(min.ratio=pln_min_ratio,
# nPenalties=pln_num_penalties,
# penalty_weights=initial_penalties),
# control_main=list(trace = ifelse(verbose, 2, 0)),
# ...),)
data=pln_data,
control = PLNmodels::PLN_param(backend = backend,
covariance = covariance_type,
trace = ifelse(verbose, 2, 0),
config_post = list(jackknife = jackknife,
bootstrap = bootstrap,
variational_var = variational_var,
sandwich_var = sandwich_var,
rsquared = FALSE),
config_optim = control_optim_args),
...),)


# bge (20221227): notes:
# o the previous version of PLNmodels was PLNmodels * 0.11.7-9600 2022-11-29 [1] Github (PLN-team/PLNmodels@022d59d)
# full_pln_model <- do.call(PLNmodels::PLNnetwork, args=list(full_model_formula_str,
# data=pln_data,
# penalties = reduced_pln_model$penalties,
# control_init=list(min.ratio=pln_min_ratio,
# nPenalties=pln_num_penalties,
# penalty_weights=initial_penalties),
# control_main=list(trace = ifelse(verbose, 2, 0)),
# ...),)
}, finally = {
RhpcBLASctl::omp_set_num_threads(1)
RhpcBLASctl::blas_set_num_threads(1)
Expand All @@ -712,22 +711,7 @@ new_cell_count_model <- function(ccs,
# best_full_model <- PLNmodels::getModel(full_pln_model, var=best_reduced_model$penalty)
best_full_model <- full_pln_model

if (vhat_method == "my_bootstrap") {
vhat = bootstrap_vhat(ccs,
full_model_formula_str,
best_full_model,
best_reduced_model,
reduced_pln_model,
pseudocount,
initial_penalties,
pln_min_ratio,
pln_num_penalties,
verbose,
num_bootstraps,
backend,
covariance_type)

} else if (vhat_method == "jackknife" | vhat_method == "bootstrap") {
if (vhat_method == "jackknife" | vhat_method == "bootstrap") {

vhat_coef = coef(best_full_model, type = "main")
var_jack_mat = attributes(vhat_coef)[[paste0("vcov_", vhat_method)]]
Expand Down Expand Up @@ -762,7 +746,7 @@ new_cell_count_model <- function(ccs,
model_aux = SimpleList(model_frame=model_frame, xlevels=xlevels),
vhat = vhat,
vhat_method = vhat_method
)
)
#
# metadata(cds)$cds_version <- Biobase::package.version("monocle3")
# clusters <- stats::setNames(SimpleList(), character(0))
Expand Down
16 changes: 10 additions & 6 deletions examples/csg.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
library(monocle3)
library(hooke)
library(ggplot2)
library(splines)
library(tidyverse)

cds = readRDS("~/OneDrive/UW/Trapnell/hooke/examples/R_objects/all-geno_sensory-cranial-ganglion_neuron_29k_cds.RDS")
cds = readRDS("all-geno_sensory-cranial-ganglion_neuron_29k_cds.RDS")

ganglia_colors =
c("cranial ganglion progenitor" = "#A29ADE",
Expand Down Expand Up @@ -33,8 +36,7 @@ stop_time = 72
time_formula = build_interval_formula(ccs, num_breaks = 3, interval_start = 18, interval_stop = 72)

ccm = new_cell_count_model(ccs,
main_model_formula_str = paste0("perturbation +", time_formula),
nuissance_model_formula_str = "~ expt")
main_model_formula_str = paste0("perturbation +", time_formula))

# predict for 48 hpf
cond_wt = estimate_abundances(ccm, tibble(timepoint = 48, perturbation = "control"))
Expand All @@ -45,8 +47,8 @@ wt_v_phox2a_tbl = compare_abundances(ccm, cond_wt, cond_phox2a)
wt_v_foxi1_tbl = compare_abundances(ccm, cond_wt, cond_foxi1)


plot_contrast(ccm, wt_v_phox2a_tbl, x=1, y=3, q_value_threshold = 0.05)
plot_contrast(ccm, wt_v_foxi1_tbl, x=1, y=3, q_value_threshold = 0.05)
plot_contrast(ccm, wt_v_phox2a_tbl, x=1, y=3, q_value_thresh = 0.05)
plot_contrast(ccm, wt_v_foxi1_tbl, x=1, y=3, q_value_thresh = 0.05)



Expand Down Expand Up @@ -79,6 +81,8 @@ wt_expt_ccm = new_cell_count_model(wt_ccs,
main_model_formula_str = "ns(timepoint, df=3)",
nuisance_model_formula_str = "~ expt")


batches = data.frame(batch = unique(colData(wt_ccs)$expt))
batches = batches %>% mutate(tp_preds = purrr::map(.f = function(batch) {
estimate_abundances_over_interval(wt_expt_ccm,
start_time,
Expand Down Expand Up @@ -113,7 +117,7 @@ stop_time = 72
time_formula = build_interval_formula(foxi1_ccs, num_breaks = 3, interval_start = 18, interval_stop = 72)

foxi1_ccm = new_cell_count_model(foxi1_ccs,
main_model_formula_str = past0("perturbation" + time_formula))
main_model_formula_str = paste0("perturbation +", time_formula))

wt_timepoint_pred_df = estimate_abundances_over_interval(foxi1_ccm,
interval_start=start_time,
Expand Down
Loading

0 comments on commit 6d81ece

Please sign in to comment.