Skip to content

Commit 8ecdf05

Browse files
authored
feat: add image preview support (#522)
1 parent ee89afc commit 8ecdf05

File tree

9 files changed

+563
-10
lines changed

9 files changed

+563
-10
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ test/
1212
output*.png
1313
models*
1414
*.log
15+
preview.png

examples/cli/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Options:
3232
-o, --output <string> path to write result image to (default: ./output.png)
3333
-p, --prompt <string> the prompt to render
3434
-n, --negative-prompt <string> the negative prompt (default: "")
35+
--preview-path <string> path to write preview image to (default: ./preview.png)
3536
--upscale-model <string> path to esrgan model.
3637
-t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of
3738
CPU physical cores
@@ -48,6 +49,8 @@ Options:
4849
--fps <int> fps (default: 24)
4950
--timestep-shift <int> shift timestep for NitroFusion models (default: 0). recommended N for NitroSD-Realism around 250 and 500 for
5051
NitroSD-Vibrant
52+
--preview-interval <int> interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at
53+
every step)
5154
--cfg-scale <float> unconditional guidance scale: (default: 7.0)
5255
--img-cfg-scale <float> image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)
5356
--guidance <float> distilled guidance scale for models with guidance input (default: 3.5)
@@ -86,6 +89,8 @@ Options:
8689
--chroma-enable-t5-mask enable t5 mask for chroma
8790
--increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1).
8891
--disable-auto-resize-ref-image disable auto resize of ref images
92+
--taesd-preview-only prevents usage of taesd for decoding the final image. (for use with --preview tae)
93+
--preview-noisy enables previewing noisy inputs of the models rather than the denoised outputs
8994
-M, --mode run mode, one of [img_gen, vid_gen, upscale, convert], default: img_gen
9095
--type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the
9196
type of the weight file
@@ -107,4 +112,5 @@ Options:
107112
--vae-tile-size tile size for vae tiling, format [X]x[Y] (default: 32x32)
108113
--vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1
109114
(overrides --vae-tile-size)
115+
--preview preview method. must be one of the following [none, proj, tae, vae] (default is none)
110116
```

examples/cli/main.cpp

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ const char* modes_str[] = {
4646
};
4747
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale"
4848

49+
const char* previews_str[] = {
50+
"none",
51+
"proj",
52+
"tae",
53+
"vae",
54+
};
55+
4956
enum SDMode {
5057
IMG_GEN,
5158
VID_GEN,
@@ -135,6 +142,12 @@ struct SDParams {
135142
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
136143
bool force_sdxl_vae_conv_scale = false;
137144

145+
preview_t preview_method = PREVIEW_NONE;
146+
int preview_interval = 1;
147+
std::string preview_path = "preview.png";
148+
bool taesd_preview = false;
149+
bool preview_noisy = false;
150+
138151
SDParams() {
139152
sd_sample_params_init(&sample_params);
140153
sd_sample_params_init(&high_noise_sample_params);
@@ -210,6 +223,8 @@ void print_params(SDParams params) {
210223
printf(" video_frames: %d\n", params.video_frames);
211224
printf(" vace_strength: %.2f\n", params.vace_strength);
212225
printf(" fps: %d\n", params.fps);
226+
printf(" preview_mode: %s (%s)\n", previews_str[params.preview_method], params.preview_noisy ? "noisy" : "denoised");
227+
printf(" preview_interval: %d\n", params.preview_interval);
213228
free(sample_params_str);
214229
free(high_noise_sample_params_str);
215230
}
@@ -589,6 +604,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
589604
"--negative-prompt",
590605
"the negative prompt (default: \"\")",
591606
&params.negative_prompt},
607+
{"",
608+
"--preview-path",
609+
"path to write preview image to (default: ./preview.png)",
610+
&params.preview_path},
592611
{"",
593612
"--upscale-model",
594613
"path to esrgan model.",
@@ -647,6 +666,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
647666
"shift timestep for NitroFusion models (default: 0). "
648667
"recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant",
649668
&params.sample_params.shifted_timestep},
669+
{"",
670+
"--preview-interval",
671+
"interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at every step)",
672+
&params.preview_interval},
650673
};
651674

652675
options.float_options = {
@@ -801,7 +824,14 @@ void parse_args(int argc, const char** argv, SDParams& params) {
801824
"--disable-auto-resize-ref-image",
802825
"disable auto resize of ref images",
803826
false, &params.auto_resize_ref_image},
804-
};
827+
{"",
828+
"--taesd-preview-only",
829+
std::string("prevents usage of taesd for decoding the final image. (for use with --preview ") + previews_str[PREVIEW_TAE] + ")",
830+
true, &params.taesd_preview},
831+
{"",
832+
"--preview-noisy",
833+
"enables previewing noisy inputs of the models rather than the denoised outputs",
834+
true, &params.preview_noisy}};
805835

806836
auto on_mode_arg = [&](int argc, const char** argv, int index) {
807837
if (++index >= argc) {
@@ -1046,6 +1076,26 @@ void parse_args(int argc, const char** argv, SDParams& params) {
10461076
return 1;
10471077
};
10481078

1079+
auto on_preview_arg = [&](int argc, const char** argv, int index) {
1080+
if (++index >= argc) {
1081+
return -1;
1082+
}
1083+
const char* preview = argv[index];
1084+
int preview_method = -1;
1085+
for (int m = 0; m < PREVIEW_COUNT; m++) {
1086+
if (!strcmp(preview, previews_str[m])) {
1087+
preview_method = m;
1088+
}
1089+
}
1090+
if (preview_method == -1) {
1091+
fprintf(stderr, "error: preview method %s\n",
1092+
preview);
1093+
return -1;
1094+
}
1095+
params.preview_method = (preview_t)preview_method;
1096+
return 1;
1097+
};
1098+
10491099
options.manual_options = {
10501100
{"-M",
10511101
"--mode",
@@ -1110,6 +1160,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
11101160
"--vae-relative-tile-size",
11111161
"relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)",
11121162
on_relative_tile_size_arg},
1163+
{"",
1164+
"--preview",
1165+
std::string("preview method. must be one of the following [") + previews_str[0] + ", " + previews_str[1] + ", " + previews_str[2] + ", " + previews_str[3] + "] (default is " + previews_str[PREVIEW_NONE] + ")\n",
1166+
on_preview_arg},
11131167
};
11141168

11151169
if (!parse_options(argc, argv, options)) {
@@ -1452,15 +1506,50 @@ bool load_images_from_dir(const std::string dir,
14521506
return true;
14531507
}
14541508

1509+
const char* preview_path;
1510+
float preview_fps;
1511+
1512+
void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy) {
1513+
(void)step;
1514+
(void)is_noisy;
1515+
// is_noisy is set to true if the preview corresponds to noisy latents, false if it's denoised latents
1516+
// unused in this app, it will either be always noisy or always denoised here
1517+
if (frame_count == 1) {
1518+
stbi_write_png(preview_path, image->width, image->height, image->channel, image->data, 0);
1519+
} else {
1520+
create_mjpg_avi_from_sd_images(preview_path, image, frame_count, preview_fps);
1521+
}
1522+
}
1523+
14551524
int main(int argc, const char* argv[]) {
14561525
SDParams params;
14571526
parse_args(argc, argv, params);
1527+
preview_path = params.preview_path.c_str();
1528+
if (params.video_frames > 4) {
1529+
size_t last_dot_pos = params.preview_path.find_last_of(".");
1530+
std::string base_path = params.preview_path;
1531+
std::string file_ext = "";
1532+
if (last_dot_pos != std::string::npos) { // filename has extension
1533+
base_path = params.preview_path.substr(0, last_dot_pos);
1534+
file_ext = params.preview_path.substr(last_dot_pos);
1535+
std::transform(file_ext.begin(), file_ext.end(), file_ext.begin(), ::tolower);
1536+
}
1537+
if (file_ext == ".png") {
1538+
base_path = base_path + ".avi";
1539+
preview_path = base_path.c_str();
1540+
}
1541+
}
1542+
preview_fps = params.fps;
1543+
if (params.preview_method == PREVIEW_PROJ)
1544+
preview_fps /= 4.0f;
1545+
14581546
params.sample_params.guidance.slg.layers = params.skip_layers.data();
14591547
params.sample_params.guidance.slg.layer_count = params.skip_layers.size();
14601548
params.high_noise_sample_params.guidance.slg.layers = params.high_noise_skip_layers.data();
14611549
params.high_noise_sample_params.guidance.slg.layer_count = params.high_noise_skip_layers.size();
14621550

14631551
sd_set_log_callback(sd_log_cb, (void*)&params);
1552+
sd_set_preview_callback((sd_preview_cb_t)step_callback, params.preview_method, params.preview_interval, !params.preview_noisy, params.preview_noisy);
14641553

14651554
if (params.verbose) {
14661555
print_params(params);
@@ -1654,6 +1743,7 @@ int main(int argc, const char* argv[]) {
16541743
params.control_net_cpu,
16551744
params.vae_on_cpu,
16561745
params.diffusion_flash_attn,
1746+
params.taesd_preview,
16571747
params.diffusion_conv_direct,
16581748
params.vae_conv_direct,
16591749
params.force_sdxl_vae_conv_scale,

ggml_extend.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,7 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
875875
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], input->ne[3]);
876876
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], output->ne[3]);
877877
int num_tiles = num_tiles_x * num_tiles_y;
878-
LOG_INFO("processing %i tiles", num_tiles);
878+
LOG_DEBUG("processing %i tiles", num_tiles);
879879
pretty_progress(0, num_tiles, 0.0f);
880880
int tile_count = 1;
881881
bool last_y = false, last_x = false;

latent-preview.h

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
#include <cstddef>
2+
#include <cstdint>
3+
#include "ggml.h"
4+
5+
const float wan_21_latent_rgb_proj[16][3] = {
6+
{0.015123f, -0.148418f, 0.479828f},
7+
{0.003652f, -0.010680f, -0.037142f},
8+
{0.212264f, 0.063033f, 0.016779f},
9+
{0.232999f, 0.406476f, 0.220125f},
10+
{-0.051864f, -0.082384f, -0.069396f},
11+
{0.085005f, -0.161492f, 0.010689f},
12+
{-0.245369f, -0.506846f, -0.117010f},
13+
{-0.151145f, 0.017721f, 0.007207f},
14+
{-0.293239f, -0.207936f, -0.421135f},
15+
{-0.187721f, 0.050783f, 0.177649f},
16+
{-0.013067f, 0.265964f, 0.166578f},
17+
{0.028327f, 0.109329f, 0.108642f},
18+
{-0.205343f, 0.043991f, 0.148914f},
19+
{0.014307f, -0.048647f, -0.007219f},
20+
{0.217150f, 0.053074f, 0.319923f},
21+
{0.155357f, 0.083156f, 0.064780f}};
22+
float wan_21_latent_rgb_bias[3] = {-0.270270f, -0.234976f, -0.456853f};
23+
24+
const float wan_22_latent_rgb_proj[48][3] = {
25+
{0.017126f, -0.027230f, -0.019257f},
26+
{-0.113739f, -0.028715f, -0.022885f},
27+
{-0.000106f, 0.021494f, 0.004629f},
28+
{-0.013273f, -0.107137f, -0.033638f},
29+
{-0.000381f, 0.000279f, 0.025877f},
30+
{-0.014216f, -0.003975f, 0.040528f},
31+
{0.001638f, -0.000748f, 0.011022f},
32+
{0.029238f, -0.006697f, 0.035933f},
33+
{0.021641f, -0.015874f, 0.040531f},
34+
{-0.101984f, -0.070160f, -0.028855f},
35+
{0.033207f, -0.021068f, 0.002663f},
36+
{-0.104711f, 0.121673f, 0.102981f},
37+
{0.082647f, -0.004991f, 0.057237f},
38+
{-0.027375f, 0.031581f, 0.006868f},
39+
{-0.045434f, 0.029444f, 0.019287f},
40+
{-0.046572f, -0.012537f, 0.006675f},
41+
{0.074709f, 0.033690f, 0.025289f},
42+
{-0.008251f, -0.002745f, -0.006999f},
43+
{0.012685f, -0.061856f, -0.048658f},
44+
{0.042304f, -0.007039f, 0.000295f},
45+
{-0.007644f, -0.060843f, -0.033142f},
46+
{0.159909f, 0.045628f, 0.367541f},
47+
{0.095171f, 0.086438f, 0.010271f},
48+
{0.006812f, 0.019643f, 0.029637f},
49+
{0.003467f, -0.010705f, 0.014252f},
50+
{-0.099681f, -0.066272f, -0.006243f},
51+
{0.047357f, 0.037040f, 0.000185f},
52+
{-0.041797f, -0.089225f, -0.032257f},
53+
{0.008928f, 0.017028f, 0.018684f},
54+
{-0.042255f, 0.016045f, 0.006849f},
55+
{0.011268f, 0.036462f, 0.037387f},
56+
{0.011553f, -0.016375f, -0.048589f},
57+
{0.046266f, -0.027189f, 0.056979f},
58+
{0.009640f, -0.017576f, 0.030324f},
59+
{-0.045794f, -0.036083f, -0.010616f},
60+
{0.022418f, 0.039783f, -0.032939f},
61+
{-0.052714f, -0.015525f, 0.007438f},
62+
{0.193004f, 0.223541f, 0.264175f},
63+
{-0.059406f, -0.008188f, 0.022867f},
64+
{-0.156742f, -0.263791f, -0.007385f},
65+
{-0.015717f, 0.016570f, 0.033969f},
66+
{0.037969f, 0.109835f, 0.200449f},
67+
{-0.000782f, -0.009566f, -0.008058f},
68+
{0.010709f, 0.052960f, -0.044195f},
69+
{0.017271f, 0.045839f, 0.034569f},
70+
{0.009424f, 0.013088f, -0.001714f},
71+
{-0.024805f, -0.059378f, -0.033756f},
72+
{-0.078293f, 0.029070f, 0.026129f}};
73+
float wan_22_latent_rgb_bias[3] = {0.013160f, -0.096492f, -0.071323f};
74+
75+
const float flux_latent_rgb_proj[16][3] = {
76+
{-0.041168f, 0.019917f, 0.097253f},
77+
{0.028096f, 0.026730f, 0.129576f},
78+
{0.065618f, -0.067950f, -0.014651f},
79+
{-0.012998f, -0.014762f, 0.081251f},
80+
{0.078567f, 0.059296f, -0.024687f},
81+
{-0.015987f, -0.003697f, 0.005012f},
82+
{0.033605f, 0.138999f, 0.068517f},
83+
{-0.024450f, -0.063567f, -0.030101f},
84+
{-0.040194f, -0.016710f, 0.127185f},
85+
{0.112681f, 0.088764f, -0.041940f},
86+
{-0.023498f, 0.093664f, 0.025543f},
87+
{0.082899f, 0.048320f, 0.007491f},
88+
{0.075712f, 0.074139f, 0.081965f},
89+
{-0.143501f, 0.018263f, -0.136138f},
90+
{-0.025767f, -0.082035f, -0.040023f},
91+
{-0.111849f, -0.055589f, -0.032361f}};
92+
float flux_latent_rgb_bias[3] = {0.024600f, -0.006937f, -0.008089f};
93+
94+
// This one was taken straight from
95+
// https://github.com/Stability-AI/sd3.5/blob/8565799a3b41eb0c7ba976d18375f0f753f56402/sd3_impls.py#L288-L303
96+
// (MIT License)
97+
const float sd3_latent_rgb_proj[16][3] = {
98+
{-0.0645f, 0.0177f, 0.1052f},
99+
{0.0028f, 0.0312f, 0.0650f},
100+
{0.1848f, 0.0762f, 0.0360f},
101+
{0.0944f, 0.0360f, 0.0889f},
102+
{0.0897f, 0.0506f, -0.0364f},
103+
{-0.0020f, 0.1203f, 0.0284f},
104+
{0.0855f, 0.0118f, 0.0283f},
105+
{-0.0539f, 0.0658f, 0.1047f},
106+
{-0.0057f, 0.0116f, 0.0700f},
107+
{-0.0412f, 0.0281f, -0.0039f},
108+
{0.1106f, 0.1171f, 0.1220f},
109+
{-0.0248f, 0.0682f, -0.0481f},
110+
{0.0815f, 0.0846f, 0.1207f},
111+
{-0.0120f, -0.0055f, -0.0867f},
112+
{-0.0749f, -0.0634f, -0.0456f},
113+
{-0.1418f, -0.1457f, -0.1259f},
114+
};
115+
float sd3_latent_rgb_bias[3] = {0, 0, 0};
116+
117+
const float sdxl_latent_rgb_proj[4][3] = {
118+
{0.258303f, 0.277640f, 0.329699f},
119+
{-0.299701f, 0.105446f, 0.014194f},
120+
{0.050522f, 0.186163f, -0.143257f},
121+
{-0.211938f, -0.149892f, -0.080036f}};
122+
float sdxl_latent_rgb_bias[3] = {0.144381f, -0.033313f, 0.007061f};
123+
124+
const float sd_latent_rgb_proj[4][3] = {
125+
{0.337366f, 0.216344f, 0.257386f},
126+
{0.165636f, 0.386828f, 0.046994f},
127+
{-0.267803f, 0.237036f, 0.223517f},
128+
{-0.178022f, -0.200862f, -0.678514f}};
129+
float sd_latent_rgb_bias[3] = {-0.017478f, -0.055834f, -0.105825f};
130+
131+
void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int width, int height, int frames, int dim) {
132+
size_t buffer_head = 0;
133+
for (int k = 0; k < frames; k++) {
134+
for (int j = 0; j < height; j++) {
135+
for (int i = 0; i < width; i++) {
136+
size_t latent_id = (i * latents->nb[0] + j * latents->nb[1] + k * latents->nb[2]);
137+
float r = 0, g = 0, b = 0;
138+
if (latent_rgb_proj != nullptr) {
139+
for (int d = 0; d < dim; d++) {
140+
float value = *(float*)((char*)latents->data + latent_id + d * latents->nb[ggml_n_dims(latents) - 1]);
141+
r += value * latent_rgb_proj[d][0];
142+
g += value * latent_rgb_proj[d][1];
143+
b += value * latent_rgb_proj[d][2];
144+
}
145+
} else {
146+
// interpret first 3 channels as RGB
147+
r = *(float*)((char*)latents->data + latent_id + 0 * latents->nb[ggml_n_dims(latents) - 1]);
148+
g = *(float*)((char*)latents->data + latent_id + 1 * latents->nb[ggml_n_dims(latents) - 1]);
149+
b = *(float*)((char*)latents->data + latent_id + 2 * latents->nb[ggml_n_dims(latents) - 1]);
150+
}
151+
if (latent_rgb_bias != nullptr) {
152+
// bias
153+
r += latent_rgb_bias[0];
154+
g += latent_rgb_bias[1];
155+
b += latent_rgb_bias[2];
156+
}
157+
// change range
158+
r = r * .5f + .5f;
159+
g = g * .5f + .5f;
160+
b = b * .5f + .5f;
161+
162+
// clamp rgb values to [0,1] range
163+
r = r >= 0 ? r <= 1 ? r : 1 : 0;
164+
g = g >= 0 ? g <= 1 ? g : 1 : 0;
165+
b = b >= 0 ? b <= 1 ? b : 1 : 0;
166+
167+
buffer[buffer_head++] = (uint8_t)(r * 255);
168+
buffer[buffer_head++] = (uint8_t)(g * 255);
169+
buffer[buffer_head++] = (uint8_t)(b * 255);
170+
}
171+
}
172+
}
173+
}

0 commit comments

Comments
 (0)