From 6d81ece556bcda465a5b32b2eaee0b370519c742 Mon Sep 17 00:00:00 2001 From: Madeleine Duran Date: Sun, 12 Nov 2023 20:22:00 -0800 Subject: [PATCH] remove my_bootstrap --- R/cell_count_model.R | 170 ++++++++++++++++++++----------------------- examples/csg.R | 16 ++-- examples/silicosis.R | 21 +++--- 3 files changed, 99 insertions(+), 108 deletions(-) diff --git a/R/cell_count_model.R b/R/cell_count_model.R index e93088f..ddd56b6 100644 --- a/R/cell_count_model.R +++ b/R/cell_count_model.R @@ -416,7 +416,7 @@ new_cell_count_model <- function(ccs, pseudocount=0, pln_min_ratio=0.001, pln_num_penalties=30, - vhat_method = c("bootstrap", "variational_var", "jackknife", "my_bootstrap"), + vhat_method = c("bootstrap", "variational_var", "jackknife"), covariance_type = c("spherical", "diagonal"), num_bootstraps = 10, inception = NULL, @@ -459,34 +459,34 @@ new_cell_count_model <- function(ccs, # Check the whitelist and blacklist for expected values. assertthat::assert_that(is.null(whitelist) || (is.numeric(whitelist[[1]][[1]]) && - range(whitelist[[1]])[[1]] >= 0 && - range(whitelist[[1]])[[2]] <= ccs_num_cell_group) || - (is.character(whitelist[[1]][[1]]))) + range(whitelist[[1]])[[1]] >= 0 && + range(whitelist[[1]])[[2]] <= ccs_num_cell_group) || + (is.character(whitelist[[1]][[1]]))) if (!all(whitelist[[1]] %in% ccs_cell_group_names)){ message ("Warning: whitelist refers to cell groups missing from cell count set") } assertthat::assert_that(is.null(whitelist) || (is.numeric(whitelist[[2]][[1]]) && - range(whitelist[[2]])[[1]] >= 0 && - range(whitelist[[2]])[[2]] <= ccs_num_cell_group) || - (is.character(whitelist[[2]][[1]]))) + range(whitelist[[2]])[[1]] >= 0 && + range(whitelist[[2]])[[2]] <= ccs_num_cell_group) || + (is.character(whitelist[[2]][[1]]))) if (!all(whitelist[[2]] %in% ccs_cell_group_names)){ message ("Warning: whitelist refers to cell groups missing from cell count set") } assertthat::assert_that(is.null(blacklist) || (is.numeric(blacklist[[1]][[1]]) && - range(blacklist[[1]])[[1]] >= 0 && - range(blacklist[[1]])[[2]] <= ccs_num_cell_group) || - (is.character(blacklist[[1]][[1]]))) + range(blacklist[[1]])[[1]] >= 0 && + range(blacklist[[1]])[[2]] <= ccs_num_cell_group) || + (is.character(blacklist[[1]][[1]]))) if (!all(blacklist[[1]] %in% ccs_cell_group_names)){ message ("Warning: blacklist refers to cell groups missing from cell count set") } assertthat::assert_that(is.null(blacklist) || (is.numeric(blacklist[[2]][[1]]) && - range(blacklist[[2]])[[1]] >= 0 && - range(blacklist[[2]])[[2]] <= ccs_num_cell_group) || - (is.character(blacklist[[2]][[1]]))) + range(blacklist[[2]])[[1]] >= 0 && + range(blacklist[[2]])[[2]] <= ccs_num_cell_group) || + (is.character(blacklist[[2]][[1]]))) if (!all(blacklist[[2]] %in% ccs_cell_group_names)){ message ("Warning: blacklist refers to cell groups missing from cell count set") } @@ -504,7 +504,7 @@ new_cell_count_model <- function(ccs, tryCatch(expr = ifelse(match.arg(vhat_method) == "", TRUE, TRUE), error = function(e) FALSE), msg = paste( 'Argument vhat_method must be one of "variational_var",', - '"jackknife","bootstrap", or "my_bootstrap".')) + '"jackknife", or "bootstrap".')) vhat_method <- match.arg(vhat_method) assertthat::assert_that( @@ -525,28 +525,28 @@ new_cell_count_model <- function(ccs, full_model_formula_str = paste("Abundance~", main_model_formula_str, "+", nuisance_model_formula_str, " + offset(log(Offset))") # full_model_formula = as.formula(full_model_formula_str) full_model_formula <- tryCatch( - { - as.formula(full_model_formula_str) - }, - error = function(condition) { - message(paste('Bad full_model_formula string', full_model_formula_str), ': ', condition, '.') - }, - warn = function(condition) { - message(paste('Bad full_model_formula string', full_model_formula_str), ': ', condition, '.') - }) + { + as.formula(full_model_formula_str) + }, + error = function(condition) { + message(paste('Bad full_model_formula string', full_model_formula_str), ': ', condition, '.') + }, + warn = function(condition) { + message(paste('Bad full_model_formula string', full_model_formula_str), ': ', condition, '.') + }) reduced_model_formula_str = paste("Abundance~", nuisance_model_formula_str, " + offset(log(Offset))") -# reduced_model_formula = as.formula(reduced_model_formula_str) + # reduced_model_formula = as.formula(reduced_model_formula_str) reduced_model_formula <- tryCatch( - { - as.formula(reduced_model_formula_str) - }, - error = function(condition) { - message(paste('Bad reduced_model_formula string', reduced_model_formula_str), ': ', condition, '.') - }, - warn = function(condition) { - message(paste('Bad reduced_model_formula string', reduced_model_formula_str), ': ', condition, '.') - }) + { + as.formula(reduced_model_formula_str) + }, + error = function(condition) { + message(paste('Bad reduced_model_formula string', reduced_model_formula_str), ': ', condition, '.') + }, + warn = function(condition) { + message(paste('Bad reduced_model_formula string', reduced_model_formula_str), ': ', condition, '.') + }) #pln_data <- as.name(deparse(substitute(pln_data))) @@ -555,23 +555,23 @@ new_cell_count_model <- function(ccs, # arguments in the call to init_penalty_matrix() are unused there. # This is so that the whitelist and blacklist penalties are applied # to the user supplied penalty matrix. - if (is.null(penalty_matrix)){ - if (penalize_by_distance){ - initial_penalties = init_penalty_matrix(ccs, - whitelist=whitelist, - blacklist=blacklist, - base_penalty=base_penalty, - min_penalty=min_penalty, - max_penalty=max_penalty, - penalty_scale_exponent=penalty_scale_exponent, - reduction_method=reduction_method) - initial_penalties = initial_penalties[colnames(pln_data$Abundance), colnames(pln_data$Abundance)] - }else{ - initial_penalties = NULL - } + if (is.null(penalty_matrix)){ + if (penalize_by_distance){ + initial_penalties = init_penalty_matrix(ccs, + whitelist=whitelist, + blacklist=blacklist, + base_penalty=base_penalty, + min_penalty=min_penalty, + max_penalty=max_penalty, + penalty_scale_exponent=penalty_scale_exponent, + reduction_method=reduction_method) + initial_penalties = initial_penalties[colnames(pln_data$Abundance), colnames(pln_data$Abundance)] }else{ - initial_penalties = penalty_matrix + initial_penalties = NULL } + }else{ + initial_penalties = penalty_matrix + } # FIXME: This might only actually work when grouping cells by clusters and cluster names are # integers. We should make sure this generalizes when making white/black lists of cell groups @@ -630,9 +630,8 @@ new_cell_count_model <- function(ccs, sandwich_var = FALSE jackknife = FALSE bootstrap = FALSE - my_bootstrap = FALSE - if (vhat_method == "variational_var" | vhat_method == "my_bootstrap") { + if (vhat_method == "variational_var") { variational_var = TRUE }else{ variational_var = FALSE # Don't compute the variational variance unless we have to, because it sometimes throws exceptions @@ -651,9 +650,9 @@ new_cell_count_model <- function(ccs, } -# bge (20221227): notes: -# o I am trying to track the code in the PLNmodels master branch at Github -# o I revert to the original because the PLNmodels changes break hooke. + # bge (20221227): notes: + # o I am trying to track the code in the PLNmodels master branch at Github + # o I revert to the original because the PLNmodels changes break hooke. reduced_pln_model <- do.call(PLNmodels::PLNnetwork, args=list(reduced_model_formula_str, data=pln_data, control = PLNmodels::PLNnetwork_param(backend = backend, @@ -671,29 +670,29 @@ new_cell_count_model <- function(ccs, ...),) full_pln_model <- do.call(PLNmodels::PLN, args=list(full_model_formula_str, - data=pln_data, - control = PLNmodels::PLN_param(backend = backend, - covariance = covariance_type, - trace = ifelse(verbose, 2, 0), - config_post = list(jackknife = jackknife, - bootstrap = bootstrap, - variational_var = variational_var, - sandwich_var = sandwich_var, - rsquared = FALSE), - config_optim = control_optim_args), - ...),) - - -# bge (20221227): notes: -# o the previous version of PLNmodels was PLNmodels * 0.11.7-9600 2022-11-29 [1] Github (PLN-team/PLNmodels@022d59d) -# full_pln_model <- do.call(PLNmodels::PLNnetwork, args=list(full_model_formula_str, -# data=pln_data, -# penalties = reduced_pln_model$penalties, -# control_init=list(min.ratio=pln_min_ratio, -# nPenalties=pln_num_penalties, -# penalty_weights=initial_penalties), -# control_main=list(trace = ifelse(verbose, 2, 0)), -# ...),) + data=pln_data, + control = PLNmodels::PLN_param(backend = backend, + covariance = covariance_type, + trace = ifelse(verbose, 2, 0), + config_post = list(jackknife = jackknife, + bootstrap = bootstrap, + variational_var = variational_var, + sandwich_var = sandwich_var, + rsquared = FALSE), + config_optim = control_optim_args), + ...),) + + + # bge (20221227): notes: + # o the previous version of PLNmodels was PLNmodels * 0.11.7-9600 2022-11-29 [1] Github (PLN-team/PLNmodels@022d59d) + # full_pln_model <- do.call(PLNmodels::PLNnetwork, args=list(full_model_formula_str, + # data=pln_data, + # penalties = reduced_pln_model$penalties, + # control_init=list(min.ratio=pln_min_ratio, + # nPenalties=pln_num_penalties, + # penalty_weights=initial_penalties), + # control_main=list(trace = ifelse(verbose, 2, 0)), + # ...),) }, finally = { RhpcBLASctl::omp_set_num_threads(1) RhpcBLASctl::blas_set_num_threads(1) @@ -712,22 +711,7 @@ new_cell_count_model <- function(ccs, # best_full_model <- PLNmodels::getModel(full_pln_model, var=best_reduced_model$penalty) best_full_model <- full_pln_model - if (vhat_method == "my_bootstrap") { - vhat = bootstrap_vhat(ccs, - full_model_formula_str, - best_full_model, - best_reduced_model, - reduced_pln_model, - pseudocount, - initial_penalties, - pln_min_ratio, - pln_num_penalties, - verbose, - num_bootstraps, - backend, - covariance_type) - - } else if (vhat_method == "jackknife" | vhat_method == "bootstrap") { + if (vhat_method == "jackknife" | vhat_method == "bootstrap") { vhat_coef = coef(best_full_model, type = "main") var_jack_mat = attributes(vhat_coef)[[paste0("vcov_", vhat_method)]] @@ -762,7 +746,7 @@ new_cell_count_model <- function(ccs, model_aux = SimpleList(model_frame=model_frame, xlevels=xlevels), vhat = vhat, vhat_method = vhat_method - ) + ) # # metadata(cds)$cds_version <- Biobase::package.version("monocle3") # clusters <- stats::setNames(SimpleList(), character(0)) diff --git a/examples/csg.R b/examples/csg.R index 2deb1e2..e0bd6ae 100644 --- a/examples/csg.R +++ b/examples/csg.R @@ -1,7 +1,10 @@ library(monocle3) library(hooke) +library(ggplot2) +library(splines) +library(tidyverse) -cds = readRDS("~/OneDrive/UW/Trapnell/hooke/examples/R_objects/all-geno_sensory-cranial-ganglion_neuron_29k_cds.RDS") +cds = readRDS("all-geno_sensory-cranial-ganglion_neuron_29k_cds.RDS") ganglia_colors = c("cranial ganglion progenitor" = "#A29ADE", @@ -33,8 +36,7 @@ stop_time = 72 time_formula = build_interval_formula(ccs, num_breaks = 3, interval_start = 18, interval_stop = 72) ccm = new_cell_count_model(ccs, - main_model_formula_str = paste0("perturbation +", time_formula), - nuissance_model_formula_str = "~ expt") + main_model_formula_str = paste0("perturbation +", time_formula)) # predict for 48 hpf cond_wt = estimate_abundances(ccm, tibble(timepoint = 48, perturbation = "control")) @@ -45,8 +47,8 @@ wt_v_phox2a_tbl = compare_abundances(ccm, cond_wt, cond_phox2a) wt_v_foxi1_tbl = compare_abundances(ccm, cond_wt, cond_foxi1) -plot_contrast(ccm, wt_v_phox2a_tbl, x=1, y=3, q_value_threshold = 0.05) -plot_contrast(ccm, wt_v_foxi1_tbl, x=1, y=3, q_value_threshold = 0.05) +plot_contrast(ccm, wt_v_phox2a_tbl, x=1, y=3, q_value_thresh = 0.05) +plot_contrast(ccm, wt_v_foxi1_tbl, x=1, y=3, q_value_thresh = 0.05) @@ -79,6 +81,8 @@ wt_expt_ccm = new_cell_count_model(wt_ccs, main_model_formula_str = "ns(timepoint, df=3)", nuisance_model_formula_str = "~ expt") + +batches = data.frame(batch = unique(colData(wt_ccs)$expt)) batches = batches %>% mutate(tp_preds = purrr::map(.f = function(batch) { estimate_abundances_over_interval(wt_expt_ccm, start_time, @@ -113,7 +117,7 @@ stop_time = 72 time_formula = build_interval_formula(foxi1_ccs, num_breaks = 3, interval_start = 18, interval_stop = 72) foxi1_ccm = new_cell_count_model(foxi1_ccs, - main_model_formula_str = past0("perturbation" + time_formula)) + main_model_formula_str = paste0("perturbation +", time_formula)) wt_timepoint_pred_df = estimate_abundances_over_interval(foxi1_ccm, interval_start=start_time, diff --git a/examples/silicosis.R b/examples/silicosis.R index e231b21..e93e294 100644 --- a/examples/silicosis.R +++ b/examples/silicosis.R @@ -1,9 +1,12 @@ library(monocle3) library(hooke) +library(ggplot2) +library(splines) +library(tidyverse) +cds = readRDS("~/OneDrive/UW/Trapnell/hooke_manuscript/main_figures_v1/R_objects/silicosis_cds.rds") +cds = readRDS("silicosis_cds.rds") -cds = readRDS("silicosis_cds.cds") - # for simplicity, we are lumping together pre and post i.t. silica colData(cds)$exposed = ifelse(colData(cds)$Timepoint == 0, "not exposed", "exposed") colData(cds)$Rep = as.factor(colData(cds)$Rep) @@ -23,20 +26,20 @@ cond_not_exp = estimate_abundances(ccm, tibble::tibble(exposed = "not exposed")) cond_ne_v_e_tbl = compare_abundances(ccm, cond_not_exp, cond_exp) -cond_ne_v_e_tbl %>% select(cell_group, perturbation_x, perturbation_y, +cond_ne_v_e_tbl %>% select(cell_group, exposed_x, exposed_y, delta_log_abund, delta_log_abund_se, delta_q_value) -plot_contrast(ccm, cond_ne_v_e_tbl, q_value_threshold = 0.05) +plot_contrast(ccm, cond_ne_v_e_tbl, q_value_thresh = 0.05) # controlling for batch -ccm = new_cell_count_model(ccs, +ccm_rep = new_cell_count_model(ccs, main_model_formula_str = "~ exposed", nuisance_model_formula_str = "~ Rep") -cond_exp = estimate_abundances(ccm, tibble::tibble(exposed = "exposed", Rep = "3")) -cond_not_exp = estimate_abundances(ccm, tibble::tibble(exposed = "not exposed", Rep = "3")) -cond_ne_v_e_tbl = compare_abundances(ccm, cond_not_exp, cond_exp) +cond_exp_rep = estimate_abundances(ccm_rep, tibble::tibble(exposed = "exposed", Rep = "1")) +cond_not_exp_rep = estimate_abundances(ccm_rep, tibble::tibble(exposed = "not exposed", Rep = "1")) +cond_ne_v_e_tbl_rep = compare_abundances(ccm_rep, cond_not_exp_rep, cond_exp_rep) -plot_contrast(ccm, cond_ne_v_e_tbl, q_value_thresh = 0.05) \ No newline at end of file +plot_contrast(ccm_rep, cond_ne_v_e_tbl_rep, q_value_thresh = 0.05)