-
Notifications
You must be signed in to change notification settings - Fork 334
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ability to split model components to different backend devices #461
base: master
Are you sure you want to change the base?
Changes from 3 commits
15d2654
11bc8c4
ead839f
9e571a5
282c371
2c9fe5c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -123,16 +123,16 @@ class StableDiffusionGGML { | |
} else if (rng_type == CUDA_RNG) { | ||
rng = std::make_shared<PhiloxRNG>(); | ||
} | ||
} | ||
|
||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove the stray whitespace. |
||
~StableDiffusionGGML() { | ||
if (clip_backend != backend) { | ||
ggml_backend_free(clip_backend); | ||
} | ||
if (control_net_backend != backend) { | ||
if (control_net_backend != backend && control_net_backend != clip_backend) { | ||
ggml_backend_free(control_net_backend); | ||
} | ||
if (vae_backend != backend) { | ||
if (vae_backend != backend && vae_backend != clip_backend && vae_backend != control_net_backend) { | ||
ggml_backend_free(vae_backend); | ||
} | ||
ggml_backend_free(backend); | ||
|
@@ -153,11 +153,15 @@ class StableDiffusionGGML { | |
schedule_t schedule, | ||
bool clip_on_cpu, | ||
bool control_net_cpu, | ||
bool vae_on_cpu) { | ||
bool vae_on_cpu, | ||
int model_backend_index, | ||
int clip_backend_index, | ||
int vae_backend_index) { | ||
use_tiny_autoencoder = taesd_path.size() > 0; | ||
#ifdef SD_USE_CUBLAS | ||
LOG_DEBUG("Using CUDA backend"); | ||
backend = ggml_backend_cuda_init(0); | ||
if (model_backend_index == -1) model_backend_index = 0; | ||
backend = ggml_backend_cuda_init(model_backend_index); | ||
#endif | ||
#ifdef SD_USE_METAL | ||
LOG_DEBUG("Using Metal backend"); | ||
|
@@ -166,16 +170,22 @@ class StableDiffusionGGML { | |
#endif | ||
#ifdef SD_USE_VULKAN | ||
LOG_DEBUG("Using Vulkan backend"); | ||
for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) { | ||
backend = ggml_backend_vk_init(device); | ||
} | ||
if (!backend) { | ||
LOG_WARN("Failed to initialize Vulkan backend"); | ||
} | ||
if (model_backend_index == -1) { | ||
// default behavior, last device selected | ||
for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) { | ||
backend = ggml_backend_vk_init(device); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You are initializing EVERY Vulkan device here (and leaking every one except the last). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I copied this from the original code verbatim in order to ensure that the behavior prior to my changes was preserved. I have just fixed it, though: instead of iterating over all device indexes, initializing every one of them, and keeping only the last, it now simply picks the last index and initializes the corresponding backend device. |
||
} | ||
if (!backend) { | ||
LOG_WARN("Failed to initialize Vulkan backend"); | ||
} | ||
} else { | ||
backend = ggml_backend_vk_init(model_backend_index); | ||
} | ||
#endif | ||
#ifdef SD_USE_SYCL | ||
LOG_DEBUG("Using SYCL backend"); | ||
backend = ggml_backend_sycl_init(0); | ||
if (model_backend_index == -1) model_backend_index = 0; | ||
backend = ggml_backend_sycl_init(model_backend_index); | ||
#endif | ||
|
||
if (!backend) { | ||
|
@@ -321,7 +331,29 @@ class StableDiffusionGGML { | |
if (clip_on_cpu && !ggml_backend_is_cpu(backend)) { | ||
LOG_INFO("CLIP: Using CPU backend"); | ||
clip_backend = ggml_backend_cpu_init(); | ||
} | ||
} else if (clip_backend_index > -1 && clip_backend_index != model_backend_index) { | ||
#ifdef SD_USE_CUBLAS | ||
LOG_DEBUG("CLIP: Using CUDA backend"); | ||
clip_backend = ggml_backend_cuda_init(clip_backend_index); | ||
#endif | ||
#ifdef SD_USE_VULKAN | ||
LOG_DEBUG("CLIP: Using Vulkan backend"); | ||
clip_backend = ggml_backend_vk_init(clip_backend_index); | ||
#endif | ||
#ifdef SD_USE_METAL | ||
LOG_DEBUG("CLIP: Using Metal backend"); | ||
// should be the same | ||
clip_backend = backend; | ||
#endif | ||
#ifdef SD_USE_SYCL | ||
LOG_DEBUG("CLIP: Using SYCL backend"); | ||
clip_backend = ggml_backend_sycl_init(clip_backend_index); | ||
#endif | ||
if (!clip_backend) { | ||
LOG_WARN("No backend device found for CLIP, defaulting to model device."); | ||
clip_backend = backend; | ||
} | ||
} | ||
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { | ||
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype); | ||
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version); | ||
|
@@ -342,7 +374,33 @@ class StableDiffusionGGML { | |
if (vae_on_cpu && !ggml_backend_is_cpu(backend)) { | ||
LOG_INFO("VAE Autoencoder: Using CPU backend"); | ||
vae_backend = ggml_backend_cpu_init(); | ||
} else { | ||
} else if (vae_backend_index == clip_backend_index) { | ||
vae_backend = clip_backend; | ||
} else if (vae_backend_index == model_backend_index) { | ||
vae_backend = backend; | ||
} else if (vae_backend_index > -1) { | ||
#ifdef SD_USE_CUBLAS | ||
LOG_DEBUG("VAE Autoencoder: Using CUDA backend"); | ||
vae_backend = ggml_backend_cuda_init(vae_backend_index); | ||
#endif | ||
#ifdef SD_USE_VULKAN | ||
LOG_DEBUG("VAE Autoencoder: Using Vulkan backend"); | ||
vae_backend = ggml_backend_vk_init(vae_backend_index); | ||
#endif | ||
#ifdef SD_USE_METAL | ||
LOG_DEBUG("CLIP: Using Metal backend"); | ||
// should be the same | ||
vae_backend = backend; | ||
#endif | ||
#ifdef SD_USE_SYCL | ||
LOG_DEBUG("VAE Autoencoder: Using SYCL backend"); | ||
vae_backend = ggml_backend_sycl_init(vae_backend_index); | ||
#endif | ||
if (!vae_backend) { | ||
LOG_WARN("No backend device found for VAE, defaulting to model device."); | ||
vae_backend = backend; | ||
} | ||
} else { | ||
vae_backend = backend; | ||
} | ||
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend, vae_wtype, vae_decode_only, false, version); | ||
|
@@ -1035,7 +1093,10 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, | |
enum schedule_t s, | ||
bool keep_clip_on_cpu, | ||
bool keep_control_net_cpu, | ||
bool keep_vae_on_cpu) { | ||
bool keep_vae_on_cpu, | ||
int model_backend_index, | ||
int clip_backend_index, | ||
int vae_backend_index) { | ||
sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t)); | ||
if (sd_ctx == NULL) { | ||
return NULL; | ||
|
@@ -1076,7 +1137,10 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, | |
s, | ||
keep_clip_on_cpu, | ||
keep_control_net_cpu, | ||
keep_vae_on_cpu)) { | ||
keep_vae_on_cpu, | ||
model_backend_index, | ||
clip_backend_index, | ||
vae_backend_index)) { | ||
delete sd_ctx->sd; | ||
sd_ctx->sd = NULL; | ||
free(sd_ctx); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -142,7 +142,10 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path, | |
enum schedule_t s, | ||
bool keep_clip_on_cpu, | ||
bool keep_control_net_cpu, | ||
bool keep_vae_on_cpu); | ||
bool keep_vae_on_cpu, | ||
int model_backend_index = -1, | ||
int clip_backend_index = -1, | ||
int vae_backend_index = -1); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. C cannot have default parameters, so these default values will not compile when this header is included from C code. |
||
|
||
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are no exceptions used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Anything else that needs changing?