|
6 | 6 | #include "unet.hpp" |
7 | 7 | #include "wan.hpp" |
8 | 8 |
|
| 9 | +struct DiffusionParams { |
| 10 | + struct ggml_tensor* x = NULL; |
| 11 | + struct ggml_tensor* timesteps = NULL; |
| 12 | + struct ggml_tensor* context = NULL; |
| 13 | + struct ggml_tensor* c_concat = NULL; |
| 14 | + struct ggml_tensor* y = NULL; |
| 15 | + struct ggml_tensor* guidance = NULL; |
| 16 | + std::vector<ggml_tensor*> ref_latents = {}; |
| 17 | + bool increase_ref_index = false; |
| 18 | + int num_video_frames = -1; |
| 19 | + std::vector<struct ggml_tensor*> controls = {}; |
| 20 | + float control_strength = 0.f; |
| 21 | + struct ggml_tensor* vace_context = NULL; |
| 22 | + float vace_strength = 1.f; |
| 23 | + std::vector<int> skip_layers = {}; |
| 24 | +}; |
| 25 | + |
9 | 26 | struct DiffusionModel { |
10 | 27 | virtual std::string get_desc() = 0; |
11 | 28 | virtual void compute(int n_threads, |
12 | | - struct ggml_tensor* x, |
13 | | - struct ggml_tensor* timesteps, |
14 | | - struct ggml_tensor* context, |
15 | | - struct ggml_tensor* c_concat, |
16 | | - struct ggml_tensor* y, |
17 | | - struct ggml_tensor* guidance, |
18 | | - std::vector<ggml_tensor*> ref_latents = {}, |
19 | | - bool increase_ref_index = false, |
20 | | - int num_video_frames = -1, |
21 | | - std::vector<struct ggml_tensor*> controls = {}, |
22 | | - float control_strength = 0.f, |
23 | | - struct ggml_tensor** output = NULL, |
24 | | - struct ggml_context* output_ctx = NULL, |
25 | | - std::vector<int> skip_layers = std::vector<int>()) = 0; |
| 29 | + DiffusionParams diffusion_params, |
| 30 | + struct ggml_tensor** output = NULL, |
| 31 | + struct ggml_context* output_ctx = NULL) = 0; |
26 | 32 | virtual void alloc_params_buffer() = 0; |
27 | 33 | virtual void free_params_buffer() = 0; |
28 | 34 | virtual void free_compute_buffer() = 0; |
@@ -71,22 +77,18 @@ struct UNetModel : public DiffusionModel { |
71 | 77 | } |
72 | 78 |
|
73 | 79 | void compute(int n_threads, |
74 | | - struct ggml_tensor* x, |
75 | | - struct ggml_tensor* timesteps, |
76 | | - struct ggml_tensor* context, |
77 | | - struct ggml_tensor* c_concat, |
78 | | - struct ggml_tensor* y, |
79 | | - struct ggml_tensor* guidance, |
80 | | - std::vector<ggml_tensor*> ref_latents = {}, |
81 | | - bool increase_ref_index = false, |
82 | | - int num_video_frames = -1, |
83 | | - std::vector<struct ggml_tensor*> controls = {}, |
84 | | - float control_strength = 0.f, |
85 | | - struct ggml_tensor** output = NULL, |
86 | | - struct ggml_context* output_ctx = NULL, |
87 | | - std::vector<int> skip_layers = std::vector<int>()) { |
88 | | - (void)skip_layers; // SLG doesn't work with UNet models |
89 | | - return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx); |
| 80 | + DiffusionParams diffusion_params, |
| 81 | + struct ggml_tensor** output = NULL, |
| 82 | + struct ggml_context* output_ctx = NULL) { |
| 83 | + return unet.compute(n_threads, |
| 84 | + diffusion_params.x, |
| 85 | + diffusion_params.timesteps, |
| 86 | + diffusion_params.context, |
| 87 | + diffusion_params.c_concat, |
| 88 | + diffusion_params.y, |
| 89 | + diffusion_params.num_video_frames, |
| 90 | + diffusion_params.controls, |
| 91 | + diffusion_params.control_strength, output, output_ctx); |
90 | 92 | } |
91 | 93 | }; |
92 | 94 |
|
@@ -129,21 +131,17 @@ struct MMDiTModel : public DiffusionModel { |
129 | 131 | } |
130 | 132 |
|
131 | 133 | void compute(int n_threads, |
132 | | - struct ggml_tensor* x, |
133 | | - struct ggml_tensor* timesteps, |
134 | | - struct ggml_tensor* context, |
135 | | - struct ggml_tensor* c_concat, |
136 | | - struct ggml_tensor* y, |
137 | | - struct ggml_tensor* guidance, |
138 | | - std::vector<ggml_tensor*> ref_latents = {}, |
139 | | - bool increase_ref_index = false, |
140 | | - int num_video_frames = -1, |
141 | | - std::vector<struct ggml_tensor*> controls = {}, |
142 | | - float control_strength = 0.f, |
143 | | - struct ggml_tensor** output = NULL, |
144 | | - struct ggml_context* output_ctx = NULL, |
145 | | - std::vector<int> skip_layers = std::vector<int>()) { |
146 | | - return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers); |
| 134 | + DiffusionParams diffusion_params, |
| 135 | + struct ggml_tensor** output = NULL, |
| 136 | + struct ggml_context* output_ctx = NULL) { |
| 137 | + return mmdit.compute(n_threads, |
| 138 | + diffusion_params.x, |
| 139 | + diffusion_params.timesteps, |
| 140 | + diffusion_params.context, |
| 141 | + diffusion_params.y, |
| 142 | + output, |
| 143 | + output_ctx, |
| 144 | + diffusion_params.skip_layers); |
147 | 145 | } |
148 | 146 | }; |
149 | 147 |
|
@@ -188,21 +186,21 @@ struct FluxModel : public DiffusionModel { |
188 | 186 | } |
189 | 187 |
|
190 | 188 | void compute(int n_threads, |
191 | | - struct ggml_tensor* x, |
192 | | - struct ggml_tensor* timesteps, |
193 | | - struct ggml_tensor* context, |
194 | | - struct ggml_tensor* c_concat, |
195 | | - struct ggml_tensor* y, |
196 | | - struct ggml_tensor* guidance, |
197 | | - std::vector<ggml_tensor*> ref_latents = {}, |
198 | | - bool increase_ref_index = false, |
199 | | - int num_video_frames = -1, |
200 | | - std::vector<struct ggml_tensor*> controls = {}, |
201 | | - float control_strength = 0.f, |
202 | | - struct ggml_tensor** output = NULL, |
203 | | - struct ggml_context* output_ctx = NULL, |
204 | | - std::vector<int> skip_layers = std::vector<int>()) { |
205 | | - return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, output, output_ctx, skip_layers); |
| 189 | + DiffusionParams diffusion_params, |
| 190 | + struct ggml_tensor** output = NULL, |
| 191 | + struct ggml_context* output_ctx = NULL) { |
| 192 | + return flux.compute(n_threads, |
| 193 | + diffusion_params.x, |
| 194 | + diffusion_params.timesteps, |
| 195 | + diffusion_params.context, |
| 196 | + diffusion_params.c_concat, |
| 197 | + diffusion_params.y, |
| 198 | + diffusion_params.guidance, |
| 199 | + diffusion_params.ref_latents, |
| 200 | + diffusion_params.increase_ref_index, |
| 201 | + output, |
| 202 | + output_ctx, |
| 203 | + diffusion_params.skip_layers); |
206 | 204 | } |
207 | 205 | }; |
208 | 206 |
|
@@ -248,21 +246,20 @@ struct WanModel : public DiffusionModel { |
248 | 246 | } |
249 | 247 |
|
250 | 248 | void compute(int n_threads, |
251 | | - struct ggml_tensor* x, |
252 | | - struct ggml_tensor* timesteps, |
253 | | - struct ggml_tensor* context, |
254 | | - struct ggml_tensor* c_concat, |
255 | | - struct ggml_tensor* y, |
256 | | - struct ggml_tensor* guidance, |
257 | | - std::vector<ggml_tensor*> ref_latents = {}, |
258 | | - bool increase_ref_index = false, |
259 | | - int num_video_frames = -1, |
260 | | - std::vector<struct ggml_tensor*> controls = {}, |
261 | | - float control_strength = 0.f, |
262 | | - struct ggml_tensor** output = NULL, |
263 | | - struct ggml_context* output_ctx = NULL, |
264 | | - std::vector<int> skip_layers = std::vector<int>()) { |
265 | | - return wan.compute(n_threads, x, timesteps, context, y, c_concat, NULL, output, output_ctx); |
| 249 | + DiffusionParams diffusion_params, |
| 250 | + struct ggml_tensor** output = NULL, |
| 251 | + struct ggml_context* output_ctx = NULL) { |
| 252 | + return wan.compute(n_threads, |
| 253 | + diffusion_params.x, |
| 254 | + diffusion_params.timesteps, |
| 255 | + diffusion_params.context, |
| 256 | + diffusion_params.y, |
| 257 | + diffusion_params.c_concat, |
| 258 | + NULL, |
| 259 | + diffusion_params.vace_context, |
| 260 | + diffusion_params.vace_strength, |
| 261 | + output, |
| 262 | + output_ctx); |
266 | 263 | } |
267 | 264 | }; |
268 | 265 |
|
|
0 commit comments