From bf353f129b30d4eba6a01ce0f4193ea66e2220dd Mon Sep 17 00:00:00 2001 From: Yue Shi Lai Date: Sat, 11 Jan 2025 23:30:12 -0800 Subject: [PATCH 1/3] Implement DDIM with the "trailing" timestep spacing --- denoiser.hpp | 168 ++++++++++++++++++++++++++++++++++++++++++ examples/cli/main.cpp | 3 +- stable-diffusion.cpp | 1 + stable-diffusion.h | 1 + 4 files changed, 172 insertions(+), 1 deletion(-) diff --git a/denoiser.hpp b/denoiser.hpp index 975699d2..380fd76b 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -1005,6 +1005,174 @@ static void sample_k_diffusion(sample_method_t method, } } } break; + case DDIM_TRAILING: // Denoising Diffusion Implicit Models + // with the "trailing" timestep spacing + { + // DDIM itself needs alphas_cumprod (DDPM, Ho et al., + // arXiv:2006.11239 [cs.LG] with k-diffusion's start and + // end beta) (which unfortunately k-diffusion's data + // structure hides from the denoiser), and the sigmas are + // also needed to invert the behavior of CompVisDenoiser + // (k-diffusion's LMSDiscreteScheduler) + std::vector alphas_cumprod; + std::vector compvis_sigmas; + + alphas_cumprod.reserve(TIMESTEPS); + compvis_sigmas.reserve(TIMESTEPS); + for (int i = 0; i < TIMESTEPS; i++) { + alphas_cumprod[i] = + (i == 0 ? 1.0f : alphas_cumprod[i - 1]) * + (1.0f - + std::pow(sqrtf(0.00085f) + + (sqrtf(0.0120f) - sqrtf(0.00085f)) * + ((float)i / (TIMESTEPS - 1)), 2)); + compvis_sigmas[i] = + std::sqrt((1 - alphas_cumprod[i]) / + alphas_cumprod[i]); + } + for (int i = 0; i < steps; i++) { + // The "trailing" DDIM timestep, see S. Lin et al., + // "Common Diffusion Noise Schedules and Sample Steps + // are Flawed", arXiv:2305.08891 [cs], p. 4, Table + // 2. Most variables below follow Diffusers naming. + int timestep = + roundf(TIMESTEPS - + i * ((float)TIMESTEPS / steps)) - 1; + int prev_timestep = timestep - TIMESTEPS / steps; + // The sigma here is chosen to cause the + // CompVisDenoiser to produce t = timestep + float sigma = compvis_sigmas[timestep]; + if (i == 0) { + // The function add_noise intializes x to + // Diffusers' latents * sigma (as in Diffusers' + // pipeline) or sample * sigma (Diffusers' + // scheduler), where this sigma = init_noise_sigma + // in Diffusers. For DDPM and DDIM however, + // init_noise_sigma = 1. But the k-diffusion + // model() also evaluates F_theta(c_in(sigma) x; + // ...) instead of the bare U-net F_theta, with + // c_in = 1 / sqrt(sigma^2 + 1), as defined in + // T. Karras et al., "Elucidating the Design Space + // of Diffusion-Based Generative Models", + // arXiv:2206.00364 [cs.CV], p. 3, Table 1. Hence + // the first call has to be prescaled as x <- x / + // (c_in * sigma) with the k-diffusion pipeline + // and CompVisDenoiser. + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] *= std::sqrt(sigma * sigma + 1) / + sigma; + } + } + else { + // For the subsequent steps after the first one, + // at this point x = latents (pipeline) or x = + // sample (scheduler), and needs to be prescaled + // with x <- latents / c_in to compensate for + // model() applying the scale c_in before the + // U-net F_theta + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] *= std::sqrt(sigma * sigma + 1); + } + } + // Note model() is the D(x, sigma) as defined in + // T. Karras et al., arXiv:2206.00364, p. 3, Table 1 + // and p. 
8 (7) + struct ggml_tensor* noise_pred = + model(x, sigma, i + 1); + // Here noise_pred is still the k-diffusion denoiser + // output, not the U-net output F_theta(c_in(sigma) x; + // ...) in Karras et al. (2022), whereas Diffusers' + // noise_pred is F_theta(...). Recover the actual + // noise_pred, which is also referred to as the + // "Karras ODE derivative" d or d_cur in several + // samplers above. + { + float* vec_x = (float*)x->data; + float* vec_noise_pred = (float*)noise_pred->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_noise_pred[j] = + (vec_x[j] - vec_noise_pred[j]) * + (1 / sigma); + } + } + // 2. compute alphas, betas + float alpha_prod_t = alphas_cumprod[timestep]; + // Note final_alpha_cumprod = alphas_cumprod[0] + float alpha_prod_t_prev = prev_timestep >= 0 ? + alphas_cumprod[prev_timestep] : alphas_cumprod[0]; + float beta_prod_t = 1 - alpha_prod_t; + // 3. compute predicted original sample from predicted + // noise also called "predicted x_0" of formula (12) + // from https://arxiv.org/pdf/2010.02502.pdf + struct ggml_tensor* pred_original_sample = + ggml_dup_tensor(work_ctx, x); + { + float* vec_x = (float*)x->data; + float* vec_noise_pred = (float*)noise_pred->data; + float* vec_pred_original_sample = + (float*)pred_original_sample->data; + // Note the substitution of latents or sample = x + // * c_in = x / sqrt(sigma^2 + 1) + for (int j = 0; j < ggml_nelements(x); j++) { + vec_pred_original_sample[j] = + (vec_x[j] / std::sqrt(sigma * sigma + 1) - + std::sqrt(beta_prod_t) * + vec_noise_pred[j]) * + (1 / std::sqrt(alpha_prod_t)); + } + } + // Assuming the "epsilon" prediction type, where below + // pred_epsilon = noise_pred is inserted, and is not + // defined/copied explicitly. + // + // 5. compute variance: "sigma_t(eta)" -> see formula + // (16) + // + // sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) * + // sqrt(1 - alpha_t/alpha_t-1) + float beta_prod_t_prev = 1 - alpha_prod_t_prev; + float variance = (beta_prod_t_prev / beta_prod_t) * + (1 - alpha_prod_t / alpha_prod_t_prev); + float std_dev_t = 0 * std::sqrt(variance); + // 6. compute "direction pointing to x_t" of formula + // (12) from https://arxiv.org/pdf/2010.02502.pdf + struct ggml_tensor* pred_sample_direction = + ggml_dup_tensor(work_ctx, noise_pred); + { + float* vec_noise_pred = (float*)noise_pred->data; + float* vec_pred_sample_direction = + (float*)pred_sample_direction->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_pred_sample_direction[j] = + std::sqrt(1 - alpha_prod_t_prev - + std::pow(std_dev_t, 2)) * + vec_noise_pred[j]; + } + } + // 7. compute x_t without "random noise" of formula + // (12) from https://arxiv.org/pdf/2010.02502.pdf + { + float* vec_pred_original_sample = + (float*)pred_original_sample->data; + float* vec_pred_sample_direction = + (float*)pred_sample_direction->data; + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] = std::sqrt(alpha_prod_t_prev) * + vec_pred_original_sample[j] + + vec_pred_sample_direction[j]; + } + } + // See the note above: x = latents or sample here, and + // is not scaled by the c_in. For the final output + // this is correct, but for subsequent iterations, x + // needs to be prescaled again, since k-diffusion's + // model() differes from the bare U-net F_theta by the + // factor c_in. 
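+                // Taken together, steps 2-7 above implement the deterministic
+                // (sigma_t(eta) = 0) case of formula (12) in
+                // https://arxiv.org/pdf/2010.02502.pdf:
+                //
+                //   x_(t-1) = sqrt(alpha_prod_t_prev) * pred_original_sample
+                //           + sqrt(1 - alpha_prod_t_prev) * pred_epsilon
+                //
+                // with pred_epsilon = noise_pred recovered from the
+                // k-diffusion denoiser output D(x, sigma) above.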
+ } + } break; default: LOG_ERROR("Attempting to sample with nonexisting sample method %i", method); diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 3c35b103..7e52b3ed 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -39,6 +39,7 @@ const char* sample_method_str[] = { "ipndm", "ipndm_v", "lcm", + "ddim_trailing", }; // Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h @@ -219,7 +220,7 @@ void print_usage(int argc, const char* argv[]) { printf(" 1.0 corresponds to full destruction of information in init image\n"); printf(" -H, --height H image height, in pixel space (default: 512)\n"); printf(" -W, --width W image width, in pixel space (default: 512)\n"); - printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm}\n"); + printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing}\n"); printf(" sampling method (default: \"euler_a\")\n"); printf(" --steps STEPS number of sample steps (default: 20)\n"); printf(" --rng {std_default, cuda} RNG (default: cuda)\n"); diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index e2daf576..8046908f 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -47,6 +47,7 @@ const char* sampling_methods_str[] = { "iPNDM", "iPNDM_v", "LCM", + "DDIM \"trailing\"" }; /*================================================== Helper Functions ================================================*/ diff --git a/stable-diffusion.h b/stable-diffusion.h index 5a758df6..94a51291 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -44,6 +44,7 @@ enum sample_method_t { IPNDM, IPNDM_V, LCM, + DDIM_TRAILING, N_SAMPLE_METHODS }; From 8086045f919ca6d0ee94ef2d6907ef4aadc3b7fa Mon Sep 17 00:00:00 2001 From: Yue Shi Lai Date: Tue, 14 Jan 2025 01:50:56 -0800 Subject: [PATCH 2/3] Implement TCD, avoid repeated allocation in DDIM, implement eta parameter for DDIM and TCD, minor comment clarification --- denoiser.hpp | 278 +++++++++++++++++++++++++++++++++++------- examples/cli/main.cpp | 13 ++ stable-diffusion.cpp | 13 +- stable-diffusion.h | 3 + 4 files changed, 263 insertions(+), 44 deletions(-) diff --git a/denoiser.hpp b/denoiser.hpp index 380fd76b..619bcd36 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -474,7 +474,8 @@ static void sample_k_diffusion(sample_method_t method, ggml_context* work_ctx, ggml_tensor* x, std::vector sigmas, - std::shared_ptr rng) { + std::shared_ptr rng, + float eta) { size_t steps = sigmas.size() - 1; // sample_euler_ancestral switch (method) { @@ -1014,6 +1015,8 @@ static void sample_k_diffusion(sample_method_t method, // structure hides from the denoiser), and the sigmas are // also needed to invert the behavior of CompVisDenoiser // (k-diffusion's LMSDiscreteScheduler) + float beta_start = 0.00085f; + float beta_end = 0.0120f; std::vector alphas_cumprod; std::vector compvis_sigmas; @@ -1023,21 +1026,41 @@ static void sample_k_diffusion(sample_method_t method, alphas_cumprod[i] = (i == 0 ? 
1.0f : alphas_cumprod[i - 1]) * (1.0f - - std::pow(sqrtf(0.00085f) + - (sqrtf(0.0120f) - sqrtf(0.00085f)) * + std::pow(sqrtf(beta_start) + + (sqrtf(beta_end) - sqrtf(beta_start)) * ((float)i / (TIMESTEPS - 1)), 2)); compvis_sigmas[i] = std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]); } + + struct ggml_tensor* pred_original_sample = + ggml_dup_tensor(work_ctx, x); + struct ggml_tensor* variance_noise = + ggml_dup_tensor(work_ctx, x); + for (int i = 0; i < steps; i++) { // The "trailing" DDIM timestep, see S. Lin et al., // "Common Diffusion Noise Schedules and Sample Steps // are Flawed", arXiv:2305.08891 [cs], p. 4, Table - // 2. Most variables below follow Diffusers naming. + // 2. Most variables below follow Diffusers naming + // + // Diffuser naming vs. J. Song et al., "Denoising + // Diffusion Implicit Models", arXiv:2010.02502, p. 5, + // (12) and p. 16, (16) ( -> ): + // + // - pred_noise_t -> epsilon_theta^(t)(x_t) + // - pred_original_sample -> f_theta^(t)(x_t) or x_0 + // - std_dev_t -> sigma_t (not the LMS sigma) + // - eta -> eta (set to 0 at the moment) + // - pred_sample_direction -> "direction pointing to + // x_t" + // - pred_prev_sample -> "x_t-1" int timestep = roundf(TIMESTEPS - i * ((float)TIMESTEPS / steps)) - 1; + // 1. get previous step value (=t-1) int prev_timestep = timestep - TIMESTEPS / steps; // The sigma here is chosen to cause the // CompVisDenoiser to produce t = timestep @@ -1066,51 +1089,53 @@ static void sample_k_diffusion(sample_method_t method, } else { // For the subsequent steps after the first one, - // at this point x = latents (pipeline) or x = - // sample (scheduler), and needs to be prescaled - // with x <- latents / c_in to compensate for - // model() applying the scale c_in before the - // U-net F_theta + // at this point x = latents or x = sample, and + // needs to be prescaled with x <- sample / c_in + // to compensate for model() applying the scale + // c_in before the U-net F_theta float* vec_x = (float*)x->data; for (int j = 0; j < ggml_nelements(x); j++) { vec_x[j] *= std::sqrt(sigma * sigma + 1); } } - // Note model() is the D(x, sigma) as defined in - // T. Karras et al., arXiv:2206.00364, p. 3, Table 1 - // and p. 8 (7) - struct ggml_tensor* noise_pred = + // Note (also noise_pred in Diffuser's pipeline) + // model_output = model() is the D(x, sigma) as + // defined in T. Karras et al., arXiv:2206.00364, + // p. 3, Table 1 and p. 8 (7), compare also p. 38 + // (226) therein. + struct ggml_tensor* model_output = model(x, sigma, i + 1); - // Here noise_pred is still the k-diffusion denoiser + // Here model_output is still the k-diffusion denoiser // output, not the U-net output F_theta(c_in(sigma) x; // ...) in Karras et al. (2022), whereas Diffusers' - // noise_pred is F_theta(...). Recover the actual - // noise_pred, which is also referred to as the + // model_output is F_theta(...). Recover the actual + // model_output, which is also referred to as the // "Karras ODE derivative" d or d_cur in several // samplers above. { float* vec_x = (float*)x->data; - float* vec_noise_pred = (float*)noise_pred->data; + float* vec_model_output = + (float*)model_output->data; for (int j = 0; j < ggml_nelements(x); j++) { - vec_noise_pred[j] = - (vec_x[j] - vec_noise_pred[j]) * + vec_model_output[j] = + (vec_x[j] - vec_model_output[j]) * (1 / sigma); } } // 2. 
compute alphas, betas float alpha_prod_t = alphas_cumprod[timestep]; - // Note final_alpha_cumprod = alphas_cumprod[0] + // Note final_alpha_cumprod = alphas_cumprod[0] due to + // trailing timestep spacing float alpha_prod_t_prev = prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]; float beta_prod_t = 1 - alpha_prod_t; // 3. compute predicted original sample from predicted // noise also called "predicted x_0" of formula (12) // from https://arxiv.org/pdf/2010.02502.pdf - struct ggml_tensor* pred_original_sample = - ggml_dup_tensor(work_ctx, x); { float* vec_x = (float*)x->data; - float* vec_noise_pred = (float*)noise_pred->data; + float* vec_model_output = + (float*)model_output->data; float* vec_pred_original_sample = (float*)pred_original_sample->data; // Note the substitution of latents or sample = x @@ -1119,12 +1144,12 @@ static void sample_k_diffusion(sample_method_t method, vec_pred_original_sample[j] = (vec_x[j] / std::sqrt(sigma * sigma + 1) - std::sqrt(beta_prod_t) * - vec_noise_pred[j]) * + vec_model_output[j]) * (1 / std::sqrt(alpha_prod_t)); } } // Assuming the "epsilon" prediction type, where below - // pred_epsilon = noise_pred is inserted, and is not + // pred_epsilon = model_output is inserted, and is not // defined/copied explicitly. // // 5. compute variance: "sigma_t(eta)" -> see formula @@ -1135,34 +1160,35 @@ static void sample_k_diffusion(sample_method_t method, float beta_prod_t_prev = 1 - alpha_prod_t_prev; float variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev); - float std_dev_t = 0 * std::sqrt(variance); + float std_dev_t = eta * std::sqrt(variance); // 6. compute "direction pointing to x_t" of formula // (12) from https://arxiv.org/pdf/2010.02502.pdf - struct ggml_tensor* pred_sample_direction = - ggml_dup_tensor(work_ctx, noise_pred); - { - float* vec_noise_pred = (float*)noise_pred->data; - float* vec_pred_sample_direction = - (float*)pred_sample_direction->data; - for (int j = 0; j < ggml_nelements(x); j++) { - vec_pred_sample_direction[j] = - std::sqrt(1 - alpha_prod_t_prev - - std::pow(std_dev_t, 2)) * - vec_noise_pred[j]; - } - } // 7. compute x_t without "random noise" of formula // (12) from https://arxiv.org/pdf/2010.02502.pdf { + float* vec_model_output = (float*)model_output->data; float* vec_pred_original_sample = (float*)pred_original_sample->data; - float* vec_pred_sample_direction = - (float*)pred_sample_direction->data; float* vec_x = (float*)x->data; for (int j = 0; j < ggml_nelements(x); j++) { + // Two step inner loop without an explicit + // tensor + float pred_sample_direction = + std::sqrt(1 - alpha_prod_t_prev - + std::pow(std_dev_t, 2)) * + vec_model_output[j]; vec_x[j] = std::sqrt(alpha_prod_t_prev) * vec_pred_original_sample[j] + - vec_pred_sample_direction[j]; + pred_sample_direction; + } + } + if (eta > 0) { + ggml_tensor_set_f32_randn(variance_noise, rng); + float* vec_variance_noise = + (float*)variance_noise->data; + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] += std_dev_t * vec_variance_noise[j]; } } // See the note above: x = latents or sample here, and @@ -1173,6 +1199,174 @@ static void sample_k_diffusion(sample_method_t method, // factor c_in. 
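+                // With eta > 0 the update above is formula (12) of
+                // https://arxiv.org/pdf/2010.02502.pdf with sigma_t(eta)
+                // taken from formula (16): eta = 0 keeps the deterministic
+                // DDIM step, while eta > 0 re-injects std_dev_t * z with
+                // fresh Gaussian noise z at every step (eta = 1 corresponds
+                // to DDPM-like ancestral sampling).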
} } break; + case TCD: // Strategic Stochastic Sampling (Algorithm 4) in + // Trajectory Consistency Distillation + { + float beta_start = 0.00085f; + float beta_end = 0.0120f; + std::vector alphas_cumprod; + std::vector compvis_sigmas; + + alphas_cumprod.reserve(TIMESTEPS); + compvis_sigmas.reserve(TIMESTEPS); + for (int i = 0; i < TIMESTEPS; i++) { + alphas_cumprod[i] = + (i == 0 ? 1.0f : alphas_cumprod[i - 1]) * + (1.0f - + std::pow(sqrtf(beta_start) + + (sqrtf(beta_end) - sqrtf(beta_start)) * + ((float)i / (TIMESTEPS - 1)), 2)); + compvis_sigmas[i] = + std::sqrt((1 - alphas_cumprod[i]) / + alphas_cumprod[i]); + } + int original_steps = 50; + + struct ggml_tensor* pred_original_sample = + ggml_dup_tensor(work_ctx, x); + struct ggml_tensor* noise = + ggml_dup_tensor(work_ctx, x); + + for (int i = 0; i < steps; i++) { + // Analytic form for TCD timesteps + int timestep = TIMESTEPS - 1 - + (TIMESTEPS / original_steps) * + (int)floor(i * ((float)original_steps / steps)); + // 1. get previous step value + int prev_timestep = i >= steps - 1 ? 0 : + TIMESTEPS - 1 - (TIMESTEPS / original_steps) * + (int)floor((i + 1) * + ((float)original_steps / steps)); + // Here timestep_s is tau_n' in Algorithm 4. The _s + // notation appears to be that from DPM-Solver, C. Lu, + // arXiv:2206.00927 [cs.LG], but this notation is not + // continued in Algorithm 4, where _n' is used. + int timestep_s = + (int)floor((1 - eta) * prev_timestep); + // Begin k-diffusion specific workaround for + // evaluating F_theta(x; ...) from D(x, sigma), same + // as in DDIM (and see there for detailed comments) + float sigma = compvis_sigmas[timestep]; + if (i == 0) { + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] *= std::sqrt(sigma * sigma + 1) / + sigma; + } + } + else { + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] *= std::sqrt(sigma * sigma + 1); + } + } + struct ggml_tensor* model_output = + model(x, sigma, i + 1); + { + float* vec_x = (float*)x->data; + float* vec_model_output = + (float*)model_output->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_model_output[j] = + (vec_x[j] - vec_model_output[j]) * + (1 / sigma); + } + } + // 2. compute alphas, betas + // + // When comparing TCD with DDPM/DDIM note that Zheng + // et al. (2024) follows the DPM-Solver notation for + // alpha. One can find the following comment in the + // original DPM-Solver code + // (https://github.com/LuChengTHU/dpm-solver/): + // "**Important**: Please pay special attention for + // the args for `alphas_cumprod`: The `alphas_cumprod` + // is the \hat{alpha_n} arrays in the notations of + // DDPM. [...] Therefore, the notation \hat{alpha_n} + // is different from the notation alpha_t in + // DPM-Solver. In fact, we have alpha_{t_n} = + // \sqrt{\hat{alpha_n}}, [...]" + float alpha_prod_t = alphas_cumprod[timestep]; + float beta_prod_t = 1 - alpha_prod_t; + // Note final_alpha_cumprod = alphas_cumprod[0] since + // TCD is always "trailing" + float alpha_prod_t_prev = prev_timestep >= 0 ? + alphas_cumprod[prev_timestep] : alphas_cumprod[0]; + // The subscript _s are the only portion in this + // section (2) unique to TCD + float alpha_prod_s = alphas_cumprod[timestep_s]; + float beta_prod_s = 1 - alpha_prod_s; + // 3. 
Compute the predicted noised sample x_s based on + // the model parameterization + // + // This section is also exactly the same as DDIM + { + float* vec_x = (float*)x->data; + float* vec_model_output = + (float*)model_output->data; + float* vec_pred_original_sample = + (float*)pred_original_sample->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_pred_original_sample[j] = + (vec_x[j] / std::sqrt(sigma * sigma + 1) - + std::sqrt(beta_prod_t) * + vec_model_output[j]) * + (1 / std::sqrt(alpha_prod_t)); + } + } + // This consistency function step can be difficult to + // decipher from Algorithm 4, as it involves a + // difficult notation ("|->"). In Diffusers it is + // borrowed verbatim (with the same comments below for + // step (4)) from LCMScheduler's noise injection step, + // compare in S. Luo et al., arXiv:2310.04378 p. 14, + // Algorithm 3. + { + float* vec_pred_original_sample = + (float*)pred_original_sample->data; + float* vec_model_output = + (float*)model_output->data; + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + // Substituting x = pred_noised_sample and + // pred_epsilon = model_output + vec_x[j] = + std::sqrt(alpha_prod_s) * + vec_pred_original_sample[j] + + std::sqrt(beta_prod_s) * + vec_model_output[j]; + } + } + // 4. Sample and inject noise z ~ N(0, I) for + // MultiStep Inference Noise is not used on the final + // timestep of the timestep schedule. This also means + // that noise is not used for one-step sampling. Eta + // (referred to as "gamma" in the paper) was + // introduced to control the stochasticity in every + // step. When eta = 0, it represents deterministic + // sampling, whereas eta = 1 indicates full stochastic + // sampling. + if (eta > 0 && i != steps - 1) { + // In this case, x is still pred_noised_sample, + // continue in-place + ggml_tensor_set_f32_randn(noise, rng); + float* vec_x = (float*)x->data; + float* vec_noise = (float*)noise->data; + for (int j = 0; j < ggml_nelements(x); j++) { + // Corresponding to (35) in Zheng et + // al. 
(2024), substituting x = + // pred_noised_sample + vec_x[j] = + std::sqrt(alpha_prod_t_prev / + alpha_prod_s) * + vec_x[j] + + std::sqrt(1 - alpha_prod_t_prev / + alpha_prod_s) * + vec_noise[j]; + } + } + } + } break; default: LOG_ERROR("Attempting to sample with nonexisting sample method %i", method); diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 7e52b3ed..bbb7e03d 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -40,6 +40,7 @@ const char* sample_method_str[] = { "ipndm_v", "lcm", "ddim_trailing", + "tcd", }; // Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h @@ -94,6 +95,7 @@ struct SDParams { float min_cfg = 1.0f; float cfg_scale = 7.0f; float guidance = 3.5f; + float eta = 0.f; float style_ratio = 20.f; int clip_skip = -1; // <= 0 represents unspecified int width = 512; @@ -163,6 +165,7 @@ void print_params(SDParams params) { printf(" cfg_scale: %.2f\n", params.cfg_scale); printf(" slg_scale: %.2f\n", params.slg_scale); printf(" guidance: %.2f\n", params.guidance); + printf(" eta: %.2f\n", params.eta); printf(" clip_skip: %d\n", params.clip_skip); printf(" width: %d\n", params.width); printf(" height: %d\n", params.height); @@ -210,6 +213,7 @@ void print_usage(int argc, const char* argv[]) { printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n"); printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n"); printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n"); + printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n"); printf(" --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n"); printf(" --skip-layer-start START SLG enabling point: (default: 0.01)\n"); printf(" --skip-layer-end END SLG disabling point: (default: 0.2)\n"); @@ -439,6 +443,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.guidance = std::stof(argv[i]); + } else if (arg == "--eta") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.eta = std::stof(argv[i]); } else if (arg == "--strength") { if (++i >= argc) { invalid_arg = true; @@ -718,6 +728,7 @@ std::string get_image_params(SDParams params, int64_t seed) { parameter_string += "Skip layer end: " + std::to_string(params.skip_layer_end) + ", "; } parameter_string += "Guidance: " + std::to_string(params.guidance) + ", "; + parameter_string += "Eta: " + std::to_string(params.eta) + ", "; parameter_string += "Seed: " + std::to_string(seed) + ", "; parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", "; parameter_string += "Model: " + sd_basename(params.model_path) + ", "; @@ -938,6 +949,7 @@ int main(int argc, const char* argv[]) { params.clip_skip, params.cfg_scale, params.guidance, + params.eta, params.width, params.height, params.sample_method, @@ -1005,6 +1017,7 @@ int main(int argc, const char* argv[]) { params.clip_skip, params.cfg_scale, params.guidance, + params.eta, params.width, params.height, params.sample_method, diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 8046908f..f5f782ab 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -47,7 +47,8 @@ const char* sampling_methods_str[] = { "iPNDM", "iPNDM_v", "LCM", - "DDIM \"trailing\"" + "DDIM \"trailing\"", + "TCD" }; /*================================================== Helper Functions ================================================*/ @@ -787,6 +788,7 @@ class StableDiffusionGGML { float min_cfg, 
float cfg_scale, float guidance, + float eta, sample_method_t method, const std::vector& sigmas, int start_merge_step, @@ -982,7 +984,7 @@ class StableDiffusionGGML { return denoised; }; - sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng); + sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta); x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x); @@ -1188,6 +1190,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, int clip_skip, float cfg_scale, float guidance, + float eta, int width, int height, enum sample_method_t sample_method, @@ -1451,6 +1454,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, cfg_scale, cfg_scale, guidance, + eta, sample_method, sigmas, start_merge_step, @@ -1516,6 +1520,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, int clip_skip, float cfg_scale, float guidance, + float eta, int width, int height, enum sample_method_t sample_method, @@ -1594,6 +1599,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, clip_skip, cfg_scale, guidance, + eta, width, height, sample_method, @@ -1625,6 +1631,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, int clip_skip, float cfg_scale, float guidance, + float eta, int width, int height, sample_method_t sample_method, @@ -1772,6 +1779,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, clip_skip, cfg_scale, guidance, + eta, width, height, sample_method, @@ -1885,6 +1893,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, min_cfg, cfg_scale, 0.f, + 0.f, sample_method, sigmas, -1, diff --git a/stable-diffusion.h b/stable-diffusion.h index 94a51291..8872bbaa 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -45,6 +45,7 @@ enum sample_method_t { IPNDM_V, LCM, DDIM_TRAILING, + TCD, N_SAMPLE_METHODS }; @@ -156,6 +157,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, int clip_skip, float cfg_scale, float guidance, + float eta, int width, int height, enum sample_method_t sample_method, @@ -181,6 +183,7 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, int clip_skip, float cfg_scale, float guidance, + float eta, int width, int height, enum sample_method_t sample_method, From a2d97bba1c28f2b939f197d4dd50882cae5e1427 Mon Sep 17 00:00:00 2001 From: Yue Shi Lai Date: Sat, 18 Jan 2025 01:05:41 -0800 Subject: [PATCH 3/3] Add the missing "tcd" in help, simplification of comments and consistent reference style --- denoiser.hpp | 35 +++++++++++++++++++++-------------- examples/cli/main.cpp | 2 +- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/denoiser.hpp b/denoiser.hpp index 619bcd36..66799109 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -1009,7 +1009,10 @@ static void sample_k_diffusion(sample_method_t method, case DDIM_TRAILING: // Denoising Diffusion Implicit Models // with the "trailing" timestep spacing { - // DDIM itself needs alphas_cumprod (DDPM, Ho et al., + // See J. Song et al., "Denoising Diffusion Implicit + // Models", arXiv:2010.02502 [cs.LG] + // + // DDIM itself needs alphas_cumprod (DDPM, J. Ho et al., // arXiv:2006.11239 [cs.LG] with k-diffusion's start and // end beta) (which unfortunately k-diffusion's data // structure hides from the denoiser), and the sigmas are @@ -1045,9 +1048,8 @@ static void sample_k_diffusion(sample_method_t method, // are Flawed", arXiv:2305.08891 [cs], p. 4, Table // 2. Most variables below follow Diffusers naming // - // Diffuser naming vs. J. Song et al., "Denoising - // Diffusion Implicit Models", arXiv:2010.02502, p. 5, - // (12) and p. 
16, (16) ( -> -> ): // // - pred_noise_t -> epsilon_theta^(t)(x_t) @@ -1100,9 +1102,8 @@ static void sample_k_diffusion(sample_method_t method, } // Note (also noise_pred in Diffuser's pipeline) // model_output = model() is the D(x, sigma) as - // defined in T. Karras et al., arXiv:2206.00364, - // p. 3, Table 1 and p. 8 (7), compare also p. 38 - // (226) therein. + // defined in Karras et al. (2022), p. 3, Table 1 and + // p. 8 (7), compare also p. 38 (226) therein. struct ggml_tensor* model_output = model(x, sigma, i + 1); // Here model_output is still the k-diffusion denoiser @@ -1202,6 +1203,10 @@ static void sample_k_diffusion(sample_method_t method, case TCD: // Strategic Stochastic Sampling (Algorithm 4) in // Trajectory Consistency Distillation { + // See J. Zheng et al., "Trajectory Consistency + // Distillation: Improved Latent Consistency Distillation + // by Semi-Linear Consistency Function with Trajectory + // Mapping", arXiv:2402.19159 [cs.CV] float beta_start = 0.00085f; float beta_end = 0.0120f; std::vector alphas_cumprod; @@ -1238,7 +1243,9 @@ static void sample_k_diffusion(sample_method_t method, (int)floor((i + 1) * ((float)original_steps / steps)); // Here timestep_s is tau_n' in Algorithm 4. The _s - // notation appears to be that from DPM-Solver, C. Lu, + // notation appears to be that from C. Lu, + // "DPM-Solver: A Fast ODE Solver for Diffusion + // Probabilistic Model Sampling in Around 10 Steps", // arXiv:2206.00927 [cs.LG], but this notation is not // continued in Algorithm 4, where _n' is used. int timestep_s = @@ -1315,12 +1322,12 @@ static void sample_k_diffusion(sample_method_t method, } } // This consistency function step can be difficult to - // decipher from Algorithm 4, as it involves a - // difficult notation ("|->"). In Diffusers it is - // borrowed verbatim (with the same comments below for - // step (4)) from LCMScheduler's noise injection step, - // compare in S. Luo et al., arXiv:2310.04378 p. 14, - // Algorithm 3. + // decipher from Algorithm 4, as it is simply stated + // using a consistency function. This step is the + // modified DDIM, i.e. p. 8 (32) in Zheng et + // al. (2024), with eta set to 0 (see the paragraph + // immediately thereafter that states this somewhat + // obliquely). { float* vec_pred_original_sample = (float*)pred_original_sample->data; diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index bbb7e03d..52eca4c7 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -224,7 +224,7 @@ void print_usage(int argc, const char* argv[]) { printf(" 1.0 corresponds to full destruction of information in init image\n"); printf(" -H, --height H image height, in pixel space (default: 512)\n"); printf(" -W, --width W image width, in pixel space (default: 512)\n"); - printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing}\n"); + printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n"); printf(" sampling method (default: \"euler_a\")\n"); printf(" --steps STEPS number of sample steps (default: 20)\n"); printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
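
For reference, a minimal standalone sketch of the two timestep schedules implemented above (illustrative only, not part of the patches; TIMESTEPS and original_steps mirror the values used in denoiser.hpp, while steps and eta are arbitrary example settings):

    #include <cmath>
    #include <cstdio>

    int main() {
        const int TIMESTEPS      = 1000; // training timesteps, as in denoiser.hpp
        const int original_steps = 50;   // distillation step count assumed by TCD
        const int steps          = 10;   // example inference step count
        const float eta          = 0.3f; // TCD stochasticity ("gamma" in the paper)

        for (int i = 0; i < steps; i++) {
            // "Trailing" DDIM spacing (Lin et al., arXiv:2305.08891, Table 2);
            // ddim_prev may go below 0 on the last step, in which case
            // denoiser.hpp falls back to alphas_cumprod[0]
            int ddim_t    = (int)std::roundf(TIMESTEPS - i * ((float)TIMESTEPS / steps)) - 1;
            int ddim_prev = ddim_t - TIMESTEPS / steps;

            // TCD schedule plus the intermediate timestep tau_n' = floor((1 - eta) * prev)
            int tcd_t = TIMESTEPS - 1 -
                (TIMESTEPS / original_steps) *
                (int)std::floor(i * ((float)original_steps / steps));
            int tcd_prev = (i >= steps - 1) ? 0 :
                TIMESTEPS - 1 -
                (TIMESTEPS / original_steps) *
                (int)std::floor((i + 1) * ((float)original_steps / steps));
            int tcd_s = (int)std::floor((1 - eta) * tcd_prev);

            std::printf("step %2d: ddim t=%4d prev=%4d | tcd t=%4d prev=%4d s=%4d\n",
                        i, ddim_t, ddim_prev, tcd_t, tcd_prev, tcd_s);
        }
        return 0;
    }

From the CLI side, the new samplers are selected with --sampling-method ddim_trailing or --sampling-method tcd, and --eta controls the amount of noise re-injected per step (0 by default).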