From 5c95362398ee3924272fbe3ddffc344b3d829475 Mon Sep 17 00:00:00 2001 From: Ivan Kharin Date: Mon, 20 Apr 2026 10:43:38 +0500 Subject: [PATCH] fix: build with GGML_BACKEND=ON --- src/convert.cpp | 2 +- src/flux.hpp | 2 +- src/ggml_extend.hpp | 30 ++++++++++++++++++++--------- src/llm.hpp | 2 +- src/lora.hpp | 4 ++-- src/mmdit.hpp | 2 +- src/qwen_image.hpp | 2 +- src/stable-diffusion.cpp | 41 ++++++++++++++++++++++++---------------- src/t5.hpp | 2 +- src/upscaler.cpp | 2 +- src/util.cpp | 30 ++++++++++++++++------------- src/wan.hpp | 4 ++-- src/z_image.hpp | 2 +- 13 files changed, 75 insertions(+), 50 deletions(-) diff --git a/src/convert.cpp b/src/convert.cpp index 7cae8df0f..f723b70bb 100644 --- a/src/convert.cpp +++ b/src/convert.cpp @@ -103,7 +103,7 @@ bool convert(const char* input_path, bool output_is_safetensors = ends_with(output_path, ".safetensors"); TensorTypeRules type_rules = parse_tensor_type_rules(tensor_type_rules); - auto backend = ggml_backend_cpu_init(); + auto backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); size_t mem_size = 1 * 1024 * 1024; // for padding mem_size += model_loader.get_tensor_storage_map().size() * ggml_tensor_overhead(); mem_size += model_loader.get_params_mem_size(backend, type); diff --git a/src/flux.hpp b/src/flux.hpp index e6bf002fb..a175e992f 100644 --- a/src/flux.hpp +++ b/src/flux.hpp @@ -1539,7 +1539,7 @@ namespace Flux { static void load_from_file_and_test(const std::string& file_path) { // ggml_backend_t backend = ggml_backend_cuda_init(0); - ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_backend_t backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); ggml_type model_data_type = GGML_TYPE_COUNT; ModelLoader model_loader; diff --git a/src/ggml_extend.hpp b/src/ggml_extend.hpp index 859270cbd..359d0b887 100644 --- a/src/ggml_extend.hpp +++ b/src/ggml_extend.hpp @@ -64,6 +64,10 @@ #define SD_UNUSED(x) (void)(x) #endif +bool inline sd_ggml_backend_is_cpu(ggml_backend_t backend) noexcept { + return std::string_view{"CPU"} == ggml_backend_name(backend); +} + __STATIC_INLINE__ int align_up_offset(int n, int multiple) { return (multiple - n % multiple) % multiple; } @@ -1497,7 +1501,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_group_norm(ggml_context* ctx, __STATIC_INLINE__ void ggml_ext_backend_tensor_get_and_sync(ggml_backend_t backend, const ggml_tensor* tensor, void* data, size_t offset, size_t size) { #if defined(SD_USE_CUDA) || defined(SD_USE_SYCL) - if (!ggml_backend_is_cpu(backend)) { + if (!sd_ggml_backend_is_cpu(backend)) { ggml_backend_tensor_get_async(backend, tensor, data, offset, size); ggml_backend_synchronize(backend); } else { @@ -1859,7 +1863,7 @@ struct GGMLRunner { LOG_DEBUG("%s compute buffer size: %.2f MB(%s)", get_desc().c_str(), compute_buffer_size / 1024.0 / 1024.0, - ggml_backend_is_cpu(runtime_backend) ? "RAM" : "VRAM"); + sd_ggml_backend_is_cpu(runtime_backend) ? "RAM" : "VRAM"); return true; } @@ -1895,7 +1899,7 @@ struct GGMLRunner { LOG_DEBUG("%s cache backend buffer size = % 6.2f MB(%s) (%i tensors)", get_desc().c_str(), cache_buffer_size / (1024.f * 1024.f), - ggml_backend_is_cpu(runtime_backend) ? "RAM" : "VRAM", + sd_ggml_backend_is_cpu(runtime_backend) ? "RAM" : "VRAM", num_tensors); } @@ -1998,8 +2002,8 @@ struct GGMLRunner { GGMLRunner(ggml_backend_t backend, bool offload_params_to_cpu = false) : runtime_backend(backend) { alloc_params_ctx(); - if (!ggml_backend_is_cpu(runtime_backend) && offload_params_to_cpu) { - params_backend = ggml_backend_cpu_init(); + if (!sd_ggml_backend_is_cpu(runtime_backend) && offload_params_to_cpu) { + params_backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); } else { params_backend = runtime_backend; } @@ -2046,7 +2050,7 @@ struct GGMLRunner { LOG_DEBUG("%s params backend buffer size = % 6.2f MB(%s) (%i tensors)", get_desc().c_str(), params_buffer_size / (1024.f * 1024.f), - ggml_backend_is_cpu(params_backend) ? "RAM" : "VRAM", + sd_ggml_backend_is_cpu(params_backend) ? "RAM" : "VRAM", num_tensors); return true; } @@ -2112,7 +2116,7 @@ struct GGMLRunner { return nullptr; } // it's performing a compute, check if backend isn't cpu - if (!ggml_backend_is_cpu(runtime_backend) && (tensor->buffer == nullptr || ggml_backend_buffer_is_host(tensor->buffer))) { + if (!sd_ggml_backend_is_cpu(runtime_backend) && (tensor->buffer == nullptr || ggml_backend_buffer_is_host(tensor->buffer))) { // pass input tensors to gpu memory auto backend_tensor = ggml_dup_tensor(compute_ctx, tensor); @@ -2154,8 +2158,16 @@ struct GGMLRunner { return std::nullopt; } copy_data_to_backend_tensor(); - if (ggml_backend_is_cpu(runtime_backend)) { - ggml_backend_cpu_set_n_threads(runtime_backend, n_threads); + if (sd_ggml_backend_is_cpu(runtime_backend)) { + if (auto reg = ggml_backend_reg_by_name("CPU")) { + if (auto fn = (ggml_backend_set_n_threads_t)ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads")) { + fn(runtime_backend, n_threads); + } else { + LOG_ERROR("ggml_backend_reg_get_proc_address(\"ggml_backend_set_n_threads\") == nullptr"); + } + } else { + LOG_ERROR("ggml_backend_reg_by_name(\"CPU\") == nullptr"); + } } ggml_status status = ggml_backend_graph_compute(runtime_backend, gf); diff --git a/src/llm.hpp b/src/llm.hpp index 4afaa3ba6..7228e6ddc 100644 --- a/src/llm.hpp +++ b/src/llm.hpp @@ -1214,7 +1214,7 @@ namespace LLM { static void load_from_file_and_test(const std::string& file_path) { // cpu f16: pass // ggml_backend_t backend = ggml_backend_cuda_init(0); - ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_backend_t backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); ggml_type model_data_type = GGML_TYPE_COUNT; ModelLoader model_loader; diff --git a/src/lora.hpp b/src/lora.hpp index d4a749ef9..f278bba98 100644 --- a/src/lora.hpp +++ b/src/lora.hpp @@ -767,7 +767,7 @@ struct LoraModel : public GGMLRunner { } ggml_tensor* original_tensor = model_tensor; - if (!ggml_backend_is_cpu(runtime_backend) && ggml_backend_buffer_is_host(original_tensor->buffer)) { + if (!sd_ggml_backend_is_cpu(runtime_backend) && ggml_backend_buffer_is_host(original_tensor->buffer)) { model_tensor = ggml_dup_tensor(compute_ctx, model_tensor); set_backend_tensor_data(model_tensor, original_tensor->data); } @@ -781,7 +781,7 @@ struct LoraModel : public GGMLRunner { final_tensor = ggml_add_inplace(compute_ctx, model_tensor, diff); } ggml_build_forward_expand(gf, final_tensor); - if (!ggml_backend_is_cpu(runtime_backend) && ggml_backend_buffer_is_host(original_tensor->buffer)) { + if (!sd_ggml_backend_is_cpu(runtime_backend) && ggml_backend_buffer_is_host(original_tensor->buffer)) { original_tensor_to_final_tensor[original_tensor] = final_tensor; } } diff --git a/src/mmdit.hpp b/src/mmdit.hpp index e75736c5d..1ffe89d74 100644 --- a/src/mmdit.hpp +++ b/src/mmdit.hpp @@ -925,7 +925,7 @@ struct MMDiTRunner : public GGMLRunner { static void load_from_file_and_test(const std::string& file_path) { // ggml_backend_t backend = ggml_backend_cuda_init(0); - ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_backend_t backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); ggml_type model_data_type = GGML_TYPE_F16; std::shared_ptr mmdit = std::make_shared(backend, false); { diff --git a/src/qwen_image.hpp b/src/qwen_image.hpp index 83c8cec66..533cecbb4 100644 --- a/src/qwen_image.hpp +++ b/src/qwen_image.hpp @@ -662,7 +662,7 @@ namespace Qwen { // cuda q8: pass // cuda q8 fa: pass // ggml_backend_t backend = ggml_backend_cuda_init(0); - ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_backend_t backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); ggml_type model_data_type = GGML_TYPE_Q8_0; ModelLoader model_loader; diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index b9d3e9af1..e5243b72d 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -222,8 +222,17 @@ class StableDiffusionGGML { #endif if (!backend) { + static bool need_load = true; + if (need_load) { + ggml_backend_load_all(); + need_load = false; + } LOG_DEBUG("Using CPU backend"); - backend = ggml_backend_cpu_init(); + backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); + if (!backend) { + LOG_ERROR("CPU backend is nullptr!"); + std::terminate(); + } } } @@ -429,9 +438,9 @@ class StableDiffusionGGML { { clip_backend = backend; - if (clip_on_cpu && !ggml_backend_is_cpu(backend)) { + if (clip_on_cpu && !sd_ggml_backend_is_cpu(backend)) { LOG_INFO("CLIP: Using CPU backend"); - clip_backend = ggml_backend_cpu_init(); + clip_backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); } if (sd_version_is_sd3(version)) { cond_stage_model = std::make_shared(clip_backend, @@ -607,9 +616,9 @@ class StableDiffusionGGML { high_noise_diffusion_model->get_param_tensors(tensors); } - if (sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) { + if (sd_ctx_params->keep_vae_on_cpu && !sd_ggml_backend_is_cpu(backend)) { LOG_INFO("VAE Autoencoder: Using CPU backend"); - vae_backend = ggml_backend_cpu_init(); + vae_backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); } else { vae_backend = backend; } @@ -700,9 +709,9 @@ class StableDiffusionGGML { if (strlen(SAFE_STR(sd_ctx_params->control_net_path)) > 0) { ggml_backend_t controlnet_backend = nullptr; - if (sd_ctx_params->keep_control_net_on_cpu && !ggml_backend_is_cpu(backend)) { + if (sd_ctx_params->keep_control_net_on_cpu && !sd_ggml_backend_is_cpu(backend)) { LOG_DEBUG("ControlNet: Using CPU backend"); - controlnet_backend = ggml_backend_cpu_init(); + controlnet_backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); } else { controlnet_backend = backend; } @@ -869,25 +878,25 @@ class StableDiffusionGGML { size_t total_params_ram_size = 0; size_t total_params_vram_size = 0; - if (ggml_backend_is_cpu(clip_backend)) { + if (sd_ggml_backend_is_cpu(clip_backend)) { total_params_ram_size += clip_params_mem_size + pmid_params_mem_size; } else { total_params_vram_size += clip_params_mem_size + pmid_params_mem_size; } - if (ggml_backend_is_cpu(backend)) { + if (sd_ggml_backend_is_cpu(backend)) { total_params_ram_size += unet_params_mem_size; } else { total_params_vram_size += unet_params_mem_size; } - if (ggml_backend_is_cpu(vae_backend)) { + if (sd_ggml_backend_is_cpu(vae_backend)) { total_params_ram_size += vae_params_mem_size; } else { total_params_vram_size += vae_params_mem_size; } - if (ggml_backend_is_cpu(control_net_backend)) { + if (sd_ggml_backend_is_cpu(control_net_backend)) { total_params_ram_size += control_net_params_mem_size; } else { total_params_vram_size += control_net_params_mem_size; @@ -901,15 +910,15 @@ class StableDiffusionGGML { total_params_vram_size / 1024.0 / 1024.0, total_params_ram_size / 1024.0 / 1024.0, clip_params_mem_size / 1024.0 / 1024.0, - ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM", + sd_ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM", unet_params_mem_size / 1024.0 / 1024.0, - ggml_backend_is_cpu(backend) ? "RAM" : "VRAM", + sd_ggml_backend_is_cpu(backend) ? "RAM" : "VRAM", vae_params_mem_size / 1024.0 / 1024.0, - ggml_backend_is_cpu(vae_backend) ? "RAM" : "VRAM", + sd_ggml_backend_is_cpu(vae_backend) ? "RAM" : "VRAM", control_net_params_mem_size / 1024.0 / 1024.0, - ggml_backend_is_cpu(control_net_backend) ? "RAM" : "VRAM", + sd_ggml_backend_is_cpu(control_net_backend) ? "RAM" : "VRAM", pmid_params_mem_size / 1024.0 / 1024.0, - ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM"); + sd_ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM"); } // init denoiser diff --git a/src/t5.hpp b/src/t5.hpp index bbd13e498..ff95a803f 100644 --- a/src/t5.hpp +++ b/src/t5.hpp @@ -555,7 +555,7 @@ struct T5Embedder { // cuda f32: pass // cuda q8_0: pass // ggml_backend_t backend = ggml_backend_cuda_init(0); - ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_backend_t backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); ggml_type model_data_type = GGML_TYPE_F16; ModelLoader model_loader; diff --git a/src/upscaler.cpp b/src/upscaler.cpp index 03f7714e5..532c78298 100644 --- a/src/upscaler.cpp +++ b/src/upscaler.cpp @@ -52,7 +52,7 @@ struct UpscalerGGML { model_loader.set_wtype_override(model_data_type); if (!backend) { LOG_DEBUG("Using CPU backend"); - backend = ggml_backend_cpu_init(); + backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); } LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type)); esrgan_upscaler = std::make_shared(backend, offload_params_to_cpu, tile_size, model_loader.get_tensor_storage_map()); diff --git a/src/util.cpp b/src/util.cpp index e01876268..a19bb92fe 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -499,19 +499,23 @@ const char* sd_get_system_info() { static char buffer[1024]; std::stringstream ss; ss << "System Info: \n"; - ss << " SSE3 = " << ggml_cpu_has_sse3() << " | "; - ss << " AVX = " << ggml_cpu_has_avx() << " | "; - ss << " AVX2 = " << ggml_cpu_has_avx2() << " | "; - ss << " AVX512 = " << ggml_cpu_has_avx512() << " | "; - ss << " AVX512_VBMI = " << ggml_cpu_has_avx512_vbmi() << " | "; - ss << " AVX512_VNNI = " << ggml_cpu_has_avx512_vnni() << " | "; - ss << " FMA = " << ggml_cpu_has_fma() << " | "; - ss << " NEON = " << ggml_cpu_has_neon() << " | "; - ss << " ARM_FMA = " << ggml_cpu_has_arm_fma() << " | "; - ss << " F16C = " << ggml_cpu_has_f16c() << " | "; - ss << " FP16_VA = " << ggml_cpu_has_fp16_va() << " | "; - ss << " WASM_SIMD = " << ggml_cpu_has_wasm_simd() << " | "; - ss << " VSX = " << ggml_cpu_has_vsx() << " | "; + if (auto reg = ggml_backend_reg_by_name("CPU")) { + ggml_backend_get_features_t fn = (ggml_backend_get_features_t)ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features"); + if (fn) { + auto ptr = fn(reg); + if ( !ptr || !ptr->name ) { + ss << " [None]"; + } else { + for ( ; ptr->name; ++ptr ) { + ss << ptr->name << " = " << ptr->value << " | "; + } + } + } else { + LOG_ERROR("ggml_backend_reg_get_proc_address() failed on \"ggml_backend_get_features\""); + } + } else { + LOG_ERROR("ggml_backend_reg_by_name(\"CPU\") == nullptr"); + } snprintf(buffer, sizeof(buffer), "%s", ss.str().c_str()); return buffer; } diff --git a/src/wan.hpp b/src/wan.hpp index 6860262c5..1e8e9aed8 100644 --- a/src/wan.hpp +++ b/src/wan.hpp @@ -1315,7 +1315,7 @@ namespace WAN { static void load_from_file_and_test(const std::string& file_path) { // ggml_backend_t backend = ggml_backend_cuda_init(0); - ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_backend_t backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); ggml_type model_data_type = GGML_TYPE_F16; std::shared_ptr vae = std::make_shared(backend, false, String2TensorStorage{}, "", false, VERSION_WAN2_2_TI2V); { @@ -2305,7 +2305,7 @@ namespace WAN { static void load_from_file_and_test(const std::string& file_path) { // ggml_backend_t backend = ggml_backend_cuda_init(0); - ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_backend_t backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); ggml_type model_data_type = GGML_TYPE_F16; LOG_INFO("loading from '%s'", file_path.c_str()); diff --git a/src/z_image.hpp b/src/z_image.hpp index 363ce5f4f..da340b082 100644 --- a/src/z_image.hpp +++ b/src/z_image.hpp @@ -592,7 +592,7 @@ namespace ZImage { // cuda q8: pass // cuda q8 fa: pass // ggml_backend_t backend = ggml_backend_cuda_init(0); - ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_backend_t backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); ggml_type model_data_type = GGML_TYPE_Q8_0; ModelLoader model_loader;