Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ability to split model components to different backend devices #461

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <random>
#include <string>
#include <vector>
#include <stdexcept>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No exceptions are used in this file, so this include is unnecessary.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anything else that needs changing?


// #include "preprocessing.hpp"
#include "flux.hpp"
Expand Down Expand Up @@ -119,6 +120,10 @@ struct SDParams {
bool canny_preprocess = false;
bool color = false;
int upscale_repeats = 1;

int model_backend_index = -1;
int clip_backend_index = -1;
int vae_backend_index = -1;
};

void print_params(SDParams params) {
Expand Down Expand Up @@ -164,6 +169,9 @@ void print_params(SDParams params) {
printf(" batch_count: %d\n", params.batch_count);
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
printf(" upscale_repeats: %d\n", params.upscale_repeats);
printf(" model_backend_index %d\n", params.model_backend_index);
printf(" clip_backend_index %d\n", params.clip_backend_index);
printf(" vae_backend_index %d\n", params.vae_backend_index);
}

void print_usage(int argc, const char* argv[]) {
Expand Down Expand Up @@ -219,6 +227,9 @@ void print_usage(int argc, const char* argv[]) {
printf(" --canny apply canny preprocessor (edge detection)\n");
printf(" --color Colors the logging tags according to level\n");
printf(" -v, --verbose print extra info\n");
printf(" --model-backend-index specify which device the model defaults to using\n");
printf(" --clip-backend-index specify which device the CLIP model uses\n");
printf(" --vae-backend-index specify which device the VAE model uses\n");
}

void parse_args(int argc, const char** argv, SDParams& params) {
Expand Down Expand Up @@ -534,7 +545,14 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.verbose = true;
} else if (arg == "--color") {
params.color = true;
} else {
}
else if (arg == "--model-backend-index") {
params.model_backend_index = std::stoi(argv[++i]);
} else if (arg == "--clip-backend-index") {
params.clip_backend_index = std::stoi(argv[++i]);
} else if (arg == "--vae-backend-index") {
params.vae_backend_index = std::stoi(argv[++i]);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
print_usage(argc, argv);
exit(1);
Expand Down Expand Up @@ -791,7 +809,10 @@ int main(int argc, const char* argv[]) {
params.schedule,
params.clip_on_cpu,
params.control_net_cpu,
params.vae_on_cpu);
params.vae_on_cpu,
params.model_backend_index,
params.clip_backend_index,
params.vae_backend_index);

if (sd_ctx == NULL) {
printf("new_sd_ctx_t failed\n");
Expand Down
98 changes: 81 additions & 17 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,16 @@ class StableDiffusionGGML {
} else if (rng_type == CUDA_RNG) {
rng = std::make_shared<PhiloxRNG>();
}
}

}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove whitespace

~StableDiffusionGGML() {
if (clip_backend != backend) {
ggml_backend_free(clip_backend);
}
if (control_net_backend != backend) {
if (control_net_backend != backend && control_net_backend != clip_backend) {
ggml_backend_free(control_net_backend);
}
if (vae_backend != backend) {
if (vae_backend != backend && vae_backend != clip_backend && vae_backend != control_net_backend) {
ggml_backend_free(vae_backend);
}
ggml_backend_free(backend);
Expand All @@ -153,11 +153,15 @@ class StableDiffusionGGML {
schedule_t schedule,
bool clip_on_cpu,
bool control_net_cpu,
bool vae_on_cpu) {
bool vae_on_cpu,
int model_backend_index,
int clip_backend_index,
int vae_backend_index) {
use_tiny_autoencoder = taesd_path.size() > 0;
#ifdef SD_USE_CUBLAS
LOG_DEBUG("Using CUDA backend");
backend = ggml_backend_cuda_init(0);
if (model_backend_index == -1) model_backend_index = 0;
backend = ggml_backend_cuda_init(model_backend_index);
#endif
#ifdef SD_USE_METAL
LOG_DEBUG("Using Metal backend");
Expand All @@ -166,16 +170,22 @@ class StableDiffusionGGML {
#endif
#ifdef SD_USE_VULKAN
LOG_DEBUG("Using Vulkan backend");
for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
backend = ggml_backend_vk_init(device);
}
if (!backend) {
LOG_WARN("Failed to initialize Vulkan backend");
}
if (model_backend_index == -1) {
// default behavior, last device selected
for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
backend = ggml_backend_vk_init(device);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are initializing EVERY Vulkan device here and leaking all of them except the last one.

Copy link
Author

@softcookiepp softcookiepp Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied this from the original code verbatim in order to ensure that behavior prior to my changes was preserved. I just fixed it though. Instead of iterating over all device indexes, initializing all of them, and only keeping the last one, it instead simply picks the last index and initializes the corresponding backend device.

}
if (!backend) {
LOG_WARN("Failed to initialize Vulkan backend");
}
} else {
backend = ggml_backend_vk_init(model_backend_index);
}
#endif
#ifdef SD_USE_SYCL
LOG_DEBUG("Using SYCL backend");
backend = ggml_backend_sycl_init(0);
if (model_backend_index == -1) model_backend_index = 0;
backend = ggml_backend_sycl_init(model_backend_index);
#endif

if (!backend) {
Expand Down Expand Up @@ -321,7 +331,29 @@ class StableDiffusionGGML {
if (clip_on_cpu && !ggml_backend_is_cpu(backend)) {
LOG_INFO("CLIP: Using CPU backend");
clip_backend = ggml_backend_cpu_init();
}
} else if (clip_backend_index > -1 && clip_backend_index != model_backend_index) {
#ifdef SD_USE_CUBLAS
LOG_DEBUG("CLIP: Using CUDA backend");
clip_backend = ggml_backend_cuda_init(clip_backend_index);
#endif
#ifdef SD_USE_VULKAN
LOG_DEBUG("CLIP: Using Vulkan backend");
clip_backend = ggml_backend_vk_init(clip_backend_index);
#endif
#ifdef SD_USE_METAL
LOG_DEBUG("CLIP: Using Metal backend");
// should be the same
clip_backend = backend;
#endif
#ifdef SD_USE_SYCL
LOG_DEBUG("CLIP: Using SYCL backend");
clip_backend = ggml_backend_sycl_init(clip_backend_index);
#endif
if (!clip_backend) {
LOG_WARN("No backend device found for CLIP, defaulting to model device.");
clip_backend = backend;
}
}
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
Expand All @@ -342,7 +374,33 @@ class StableDiffusionGGML {
if (vae_on_cpu && !ggml_backend_is_cpu(backend)) {
LOG_INFO("VAE Autoencoder: Using CPU backend");
vae_backend = ggml_backend_cpu_init();
} else {
} else if (vae_backend_index == clip_backend_index) {
vae_backend = clip_backend;
} else if (vae_backend_index == model_backend_index) {
vae_backend = backend;
} else if (vae_backend_index > -1) {
#ifdef SD_USE_CUBLAS
LOG_DEBUG("VAE Autoencoder: Using CUDA backend");
vae_backend = ggml_backend_cuda_init(vae_backend_index);
#endif
#ifdef SD_USE_VULKAN
LOG_DEBUG("VAE Autoencoder: Using Vulkan backend");
vae_backend = ggml_backend_vk_init(vae_backend_index);
#endif
#ifdef SD_USE_METAL
LOG_DEBUG("CLIP: Using Metal backend");
// should be the same
vae_backend = backend;
#endif
#ifdef SD_USE_SYCL
LOG_DEBUG("VAE Autoencoder: Using SYCL backend");
vae_backend = ggml_backend_sycl_init(vae_backend_index);
#endif
if (!vae_backend) {
LOG_WARN("No backend device found for VAE, defaulting to model device.");
vae_backend = backend;
}
} else {
vae_backend = backend;
}
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend, vae_wtype, vae_decode_only, false, version);
Expand Down Expand Up @@ -1035,7 +1093,10 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
enum schedule_t s,
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
bool keep_vae_on_cpu) {
bool keep_vae_on_cpu,
int model_backend_index,
int clip_backend_index,
int vae_backend_index) {
sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t));
if (sd_ctx == NULL) {
return NULL;
Expand Down Expand Up @@ -1076,7 +1137,10 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
s,
keep_clip_on_cpu,
keep_control_net_cpu,
keep_vae_on_cpu)) {
keep_vae_on_cpu,
model_backend_index,
clip_backend_index,
vae_backend_index)) {
delete sd_ctx->sd;
sd_ctx->sd = NULL;
free(sd_ctx);
Expand Down
5 changes: 4 additions & 1 deletion stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,10 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
enum schedule_t s,
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
bool keep_vae_on_cpu);
bool keep_vae_on_cpu,
int model_backend_index = -1,
int clip_backend_index = -1,
int vae_backend_index = -1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

C cannot have default parameter values; this header must remain C-compatible.


SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);

Expand Down