Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/go/stablediffusion-ggml/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)

# stablediffusion.cpp (ggml)
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
STABLEDIFFUSION_GGML_VERSION?=0ebe6fe118f125665939b27c89f34ed38716bff8
STABLEDIFFUSION_GGML_VERSION?=710169df5c93d3756bf1fc547512f4724e89c745

CMAKE_ARGS+=-DGGML_MAX_NAME=128

Expand Down
59 changes: 32 additions & 27 deletions backend/go/stablediffusion-ggml/gosd.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "stable-diffusion.h"
#include <cstdint>
#define GGML_MAX_NAME 128

Expand All @@ -23,8 +24,8 @@

// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
const char* sample_method_str[] = {
"default",
"euler",
"euler_a",
"heun",
"dpm2",
"dpm++2s_a",
Expand All @@ -35,29 +36,29 @@ const char* sample_method_str[] = {
"lcm",
"ddim_trailing",
"tcd",
"euler_a",
};

static_assert(std::size(sample_method_str) == SAMPLE_METHOD_COUNT, "sample method mismatch");

// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
const char* schedulers[] = {
"default",
"discrete",
"karras",
"exponential",
"ays",
"gits",
"sgm_uniform",
"simple",
"smoothstep",
"lcm",
};

static_assert(std::size(schedulers) == SCHEDULE_COUNT, "schedulers mismatch");
static_assert(std::size(schedulers) == SCHEDULER_COUNT, "schedulers mismatch");

sd_ctx_t* sd_c;
// Moved from the context (load time) to generation time params
scheduler_t scheduler = scheduler_t::DEFAULT;

sample_method_t sample_method;
scheduler_t scheduler = SCHEDULER_COUNT;
sample_method_t sample_method = SAMPLE_METHOD_COUNT;

// Copied from the upstream CLI
static void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
Expand Down Expand Up @@ -159,26 +160,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads

fprintf(stderr, "parsed options\n");

int sample_method_found = -1;
for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) {
if (!strcmp(sampler, sample_method_str[m])) {
sample_method_found = m;
fprintf(stderr, "Found sampler: %s\n", sampler);
}
}
if (sample_method_found == -1) {
fprintf(stderr, "Invalid sample method, default to EULER_A!\n");
sample_method_found = sample_method_t::SAMPLE_METHOD_DEFAULT;
}
sample_method = (sample_method_t)sample_method_found;

for (int d = 0; d < SCHEDULE_COUNT; d++) {
if (!strcmp(scheduler_str, schedulers[d])) {
scheduler = (scheduler_t)d;
fprintf (stderr, "Found scheduler: %s\n", scheduler_str);
}
}

fprintf (stderr, "Creating context\n");
sd_ctx_params_t ctx_params;
sd_ctx_params_init(&ctx_params);
Expand Down Expand Up @@ -208,6 +189,30 @@ int load_model(const char *model, char *model_path, char* options[], int threads
}
fprintf (stderr, "Created context: OK\n");

int sample_method_found = -1;
for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) {
if (!strcmp(sampler, sample_method_str[m])) {
sample_method_found = m;
fprintf(stderr, "Found sampler: %s\n", sampler);
}
}
if (sample_method_found == -1) {
sample_method_found = sd_get_default_sample_method(sd_ctx);
fprintf(stderr, "Invalid sample method, using default: %s\n", sample_method_str[sample_method_found]);
}
sample_method = (sample_method_t)sample_method_found;

for (int d = 0; d < SCHEDULER_COUNT; d++) {
if (!strcmp(scheduler_str, schedulers[d])) {
scheduler = (scheduler_t)d;
fprintf (stderr, "Found scheduler: %s\n", scheduler_str);
}
}
if (scheduler == SCHEDULER_COUNT) {
scheduler = sd_get_default_scheduler(sd_ctx);
fprintf(stderr, "Invalid scheduler, using default: %s\n", schedulers[scheduler]);
}

sd_c = sd_ctx;

// Clean up allocated memory
Expand Down
Loading