diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index decee0a94..1bacae158 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -815,12 +815,14 @@ int main(int argc, const char* argv[]) { SDImageOwner current_image(results[i]); results[i] = {0, 0, 0, nullptr}; for (int u = 0; u < gen_params.upscale_repeats; ++u) { - SDImageOwner upscaled_image(upscale(upscaler_ctx.get(), current_image.get(), upscale_factor)); - if (upscaled_image.get().data == nullptr) { + sd_image_t* upscaled_images = upscale(upscaler_ctx.get(), current_image.get(), upscale_factor); + if (upscaled_images == nullptr || upscaled_images[0].data == nullptr) { + free(upscaled_images); LOG_ERROR("upscale failed"); break; } - current_image = std::move(upscaled_image); + current_image.reset(upscaled_images[0]); + free(upscaled_images); } results[i] = current_image.release(); // Set the final upscaled image as the result } diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index 17596f849..0ea36ca58 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -468,7 +468,7 @@ SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path, const char* params_backend); SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx); -SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, +SD_API sd_image_t* upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor); diff --git a/src/upscaler.cpp b/src/upscaler.cpp index 8635f6778..90cd9012b 100644 --- a/src/upscaler.cpp +++ b/src/upscaler.cpp @@ -4,6 +4,7 @@ #include "model_loader.h" #include "stable-diffusion.h" +#include #include UpscalerGGML::UpscalerGGML(int n_threads, @@ -179,8 +180,23 @@ upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str, return upscaler_ctx; } -sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor) { - return upscaler_ctx->upscaler->upscale(input_image, upscale_factor); +sd_image_t* upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor) { + if (upscaler_ctx == nullptr || upscaler_ctx->upscaler == nullptr) { + return nullptr; + } + + sd_image_t* result_images = (sd_image_t*)calloc(1, sizeof(sd_image_t)); + if (result_images == nullptr) { + return nullptr; + } + + result_images[0] = upscaler_ctx->upscaler->upscale(input_image, upscale_factor); + if (result_images[0].data == nullptr) { + free(result_images); + return nullptr; + } + + return result_images; } int get_upscale_factor(upscaler_ctx_t* upscaler_ctx) {