diff --git a/backend/go/stablediffusion-ggml/Makefile b/backend/go/stablediffusion-ggml/Makefile index 4cbf9dcf136e..c1c22680b001 100644 --- a/backend/go/stablediffusion-ggml/Makefile +++ b/backend/go/stablediffusion-ggml/Makefile @@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1) # stablediffusion.cpp (ggml) STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp -STABLEDIFFUSION_GGML_VERSION?=0ebe6fe118f125665939b27c89f34ed38716bff8 +STABLEDIFFUSION_GGML_VERSION?=710169df5c93d3756bf1fc547512f4724e89c745 CMAKE_ARGS+=-DGGML_MAX_NAME=128 diff --git a/backend/go/stablediffusion-ggml/gosd.cpp b/backend/go/stablediffusion-ggml/gosd.cpp index db51b0b0978a..76889447061f 100644 --- a/backend/go/stablediffusion-ggml/gosd.cpp +++ b/backend/go/stablediffusion-ggml/gosd.cpp @@ -1,3 +1,4 @@ +#include "stable-diffusion.h" #include #define GGML_MAX_NAME 128 @@ -23,8 +24,8 @@ // Names of the sampler method, same order as enum sample_method in stable-diffusion.h const char* sample_method_str[] = { - "default", "euler", + "euler_a", "heun", "dpm2", "dpm++2s_a", @@ -35,29 +36,29 @@ const char* sample_method_str[] = { "lcm", "ddim_trailing", "tcd", - "euler_a", }; static_assert(std::size(sample_method_str) == SAMPLE_METHOD_COUNT, "sample method mismatch"); // Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h const char* schedulers[] = { - "default", "discrete", "karras", "exponential", "ays", "gits", + "sgm_uniform", + "simple", "smoothstep", + "lcm", }; -static_assert(std::size(schedulers) == SCHEDULE_COUNT, "schedulers mismatch"); +static_assert(std::size(schedulers) == SCHEDULER_COUNT, "schedulers mismatch"); sd_ctx_t* sd_c; // Moved from the context (load time) to generation time params -scheduler_t scheduler = scheduler_t::DEFAULT; - -sample_method_t sample_method; +scheduler_t scheduler = SCHEDULER_COUNT; +sample_method_t sample_method = SAMPLE_METHOD_COUNT; // Copied from the upstream CLI static void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { @@ -159,26 +160,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads fprintf(stderr, "parsed options\n"); - int sample_method_found = -1; - for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) { - if (!strcmp(sampler, sample_method_str[m])) { - sample_method_found = m; - fprintf(stderr, "Found sampler: %s\n", sampler); - } - } - if (sample_method_found == -1) { - fprintf(stderr, "Invalid sample method, default to EULER_A!\n"); - sample_method_found = sample_method_t::SAMPLE_METHOD_DEFAULT; - } - sample_method = (sample_method_t)sample_method_found; - - for (int d = 0; d < SCHEDULE_COUNT; d++) { - if (!strcmp(scheduler_str, schedulers[d])) { - scheduler = (scheduler_t)d; - fprintf (stderr, "Found scheduler: %s\n", scheduler_str); - } - } - fprintf (stderr, "Creating context\n"); sd_ctx_params_t ctx_params; sd_ctx_params_init(&ctx_params); @@ -208,6 +189,30 @@ int load_model(const char *model, char *model_path, char* options[], int threads } fprintf (stderr, "Created context: OK\n"); + int sample_method_found = -1; + for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) { + if (!strcmp(sampler, sample_method_str[m])) { + sample_method_found = m; + fprintf(stderr, "Found sampler: %s\n", sampler); + } + } + if (sample_method_found == -1) { + sample_method_found = sd_get_default_sample_method(sd_ctx); + fprintf(stderr, "Invalid sample method, using default: %s\n", sample_method_str[sample_method_found]); + } + sample_method = (sample_method_t)sample_method_found; + + for (int d = 0; d < SCHEDULER_COUNT; d++) { + if (!strcmp(scheduler_str, schedulers[d])) { + scheduler = (scheduler_t)d; + fprintf (stderr, "Found scheduler: %s\n", scheduler_str); + } + } + if (scheduler == SCHEDULER_COUNT) { + scheduler = sd_get_default_scheduler(sd_ctx); + fprintf(stderr, "Invalid scheduler, using default: %s\n", schedulers[scheduler]); + } + sd_c = sd_ctx; // Clean up allocated memory