Skip to content

Commit 88316e3

Browse files
committed
fix(stablediffusion-ggml): fixup schedulers and samplers arrays, use default getters
Signed-off-by: Richard Palethorpe <io@richiejp.com>
1 parent 6f4df82 commit 88316e3

File tree

1 file changed

+32
-27
lines changed

1 file changed

+32
-27
lines changed

backend/go/stablediffusion-ggml/gosd.cpp

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "stable-diffusion.h"
12
#include <cstdint>
23
#define GGML_MAX_NAME 128
34

@@ -23,8 +24,8 @@
2324

2425
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
2526
const char* sample_method_str[] = {
26-
"default",
2727
"euler",
28+
"euler_a",
2829
"heun",
2930
"dpm2",
3031
"dpm++2s_a",
@@ -35,29 +36,29 @@ const char* sample_method_str[] = {
3536
"lcm",
3637
"ddim_trailing",
3738
"tcd",
38-
"euler_a",
3939
};
4040

4141
static_assert(std::size(sample_method_str) == SAMPLE_METHOD_COUNT, "sample method mismatch");
4242

4343
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
4444
const char* schedulers[] = {
45-
"default",
4645
"discrete",
4746
"karras",
4847
"exponential",
4948
"ays",
5049
"gits",
50+
"sgm_uniform",
51+
"simple",
5152
"smoothstep",
53+
"lcm",
5254
};
5355

54-
static_assert(std::size(schedulers) == SCHEDULE_COUNT, "schedulers mismatch");
56+
static_assert(std::size(schedulers) == SCHEDULER_COUNT, "schedulers mismatch");
5557

5658
sd_ctx_t* sd_c;
5759
// Moved from the context (load time) to generation time params
58-
scheduler_t scheduler = scheduler_t::DEFAULT;
59-
60-
sample_method_t sample_method;
60+
scheduler_t scheduler = SCHEDULER_COUNT;
61+
sample_method_t sample_method = SAMPLE_METHOD_COUNT;
6162

6263
// Copied from the upstream CLI
6364
static void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
@@ -159,26 +160,6 @@ int load_model(const char *model, char *model_path, char* options[], int threads
159160

160161
fprintf(stderr, "parsed options\n");
161162

162-
int sample_method_found = -1;
163-
for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) {
164-
if (!strcmp(sampler, sample_method_str[m])) {
165-
sample_method_found = m;
166-
fprintf(stderr, "Found sampler: %s\n", sampler);
167-
}
168-
}
169-
if (sample_method_found == -1) {
170-
fprintf(stderr, "Invalid sample method, default to EULER_A!\n");
171-
sample_method_found = sample_method_t::SAMPLE_METHOD_DEFAULT;
172-
}
173-
sample_method = (sample_method_t)sample_method_found;
174-
175-
for (int d = 0; d < SCHEDULE_COUNT; d++) {
176-
if (!strcmp(scheduler_str, schedulers[d])) {
177-
scheduler = (scheduler_t)d;
178-
fprintf (stderr, "Found scheduler: %s\n", scheduler_str);
179-
}
180-
}
181-
182163
fprintf (stderr, "Creating context\n");
183164
sd_ctx_params_t ctx_params;
184165
sd_ctx_params_init(&ctx_params);
@@ -208,6 +189,30 @@ int load_model(const char *model, char *model_path, char* options[], int threads
208189
}
209190
fprintf (stderr, "Created context: OK\n");
210191

192+
int sample_method_found = -1;
193+
for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) {
194+
if (!strcmp(sampler, sample_method_str[m])) {
195+
sample_method_found = m;
196+
fprintf(stderr, "Found sampler: %s\n", sampler);
197+
}
198+
}
199+
if (sample_method_found == -1) {
200+
sample_method_found = sd_get_default_sample_method(sd_ctx);
201+
fprintf(stderr, "Invalid sample method, using default: %s\n", sample_method_str[sample_method_found]);
202+
}
203+
sample_method = (sample_method_t)sample_method_found;
204+
205+
for (int d = 0; d < SCHEDULER_COUNT; d++) {
206+
if (!strcmp(scheduler_str, schedulers[d])) {
207+
scheduler = (scheduler_t)d;
208+
fprintf (stderr, "Found scheduler: %s\n", scheduler_str);
209+
}
210+
}
211+
if (scheduler == SCHEDULER_COUNT) {
212+
scheduler = sd_get_default_scheduler(sd_ctx);
213+
fprintf(stderr, "Invalid scheduler, using default: %s\n", schedulers[scheduler]);
214+
}
215+
211216
sd_c = sd_ctx;
212217

213218
// Clean up allocated memory

0 commit comments

Comments
 (0)