Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions src/denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1523,12 +1523,10 @@ static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
float eta) {

int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {

float sigma = sigmas[i];
float sigma_to = sigmas[i + 1];
float sigma = sigmas[i];
float sigma_to = sigmas[i + 1];

auto model_output_opt = model(x, sigma, i + 1);
if (model_output_opt.empty()) {
Expand All @@ -1551,12 +1549,11 @@ static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
float std_dev_t = eta * std::sqrt(variance);

x = pred_original_sample +
std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2))/ alpha_prod_t_prev) * model_output;
std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2)) / alpha_prod_t_prev) * model_output;

if (eta > 0) {
x+= std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor<float>::randn_like(x, rng);
x += std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor<float>::randn_like(x, rng);
}

}
return x;
}
Expand Down Expand Up @@ -1584,8 +1581,10 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,

auto get_timestep_from_sigma = [&](float s) -> int {
auto it = std::lower_bound(compvis_sigmas.begin(), compvis_sigmas.end(), s);
if (it == compvis_sigmas.begin()) return 0;
if (it == compvis_sigmas.end()) return TIMESTEPS - 1;
if (it == compvis_sigmas.begin())
return 0;
if (it == compvis_sigmas.end())
return TIMESTEPS - 1;
int idx_high = static_cast<int>(std::distance(compvis_sigmas.begin(), it));
int idx_low = idx_high - 1;
if (std::abs(compvis_sigmas[idx_high] - s) < std::abs(compvis_sigmas[idx_low] - s)) {
Expand All @@ -1596,7 +1595,6 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,

int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {

float sigma_to = sigmas[i + 1];
int prev_timestep = get_timestep_from_sigma(sigma_to);
int timestep_s = (int)floor((1 - eta) * prev_timestep);
Expand Down Expand Up @@ -1626,7 +1624,6 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * x +
std::sqrt(1.0f / alpha_prod_t_prev - 1.0f / alpha_prod_s) * sd::Tensor<float>::randn_like(x, rng);
}

}
return x;
}
Expand Down
89 changes: 63 additions & 26 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <atomic>
#include <chrono>
#include <cstdarg>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <mutex>
Expand All @@ -13,9 +14,10 @@
#include <vector>

#include "model.h"
#include "model_io/ckpt_io.h"
#include "model_io/gguf_io.h"
#include "model_io/safetensors_io.h"
#include "model_io/torch_legacy_io.h"
#include "model_io/torch_zip_io.h"
#include "stable-diffusion.h"
#include "util.h"

Expand Down Expand Up @@ -229,9 +231,12 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string
} else if (is_safetensors_file(file_path)) {
LOG_INFO("load %s using safetensors format", file_path.c_str());
return init_from_safetensors_file(file_path, prefix);
} else if (is_ckpt_file(file_path)) {
LOG_INFO("load %s using checkpoint format", file_path.c_str());
return init_from_ckpt_file(file_path, prefix);
} else if (is_torch_zip_file(file_path)) {
LOG_INFO("load %s using torch zip format", file_path.c_str());
return init_from_torch_zip_file(file_path, prefix);
} else if (init_from_torch_legacy_file(file_path, prefix)) {
LOG_INFO("load %s using torch legacy format", file_path.c_str());
return true;
} else {
if (file_exists(file_path)) {
LOG_WARN("unknown format %s", file_path.c_str());
Expand Down Expand Up @@ -329,40 +334,47 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
return true;
}

/*================================================= DiffusersModelLoader ==================================================*/
/*================================================= TorchLegacyModelLoader ==================================================*/

bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
std::string clip_path = path_join(file_path, "text_encoder/model.safetensors");
std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors");
bool ModelLoader::init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix) {
LOG_DEBUG("init from torch legacy '%s'", file_path.c_str());

if (!init_from_safetensors_file(unet_path, "unet.")) {
std::vector<TensorStorage> tensor_storages;
std::string error;
if (!read_torch_legacy_file(file_path, tensor_storages, &error)) {
if ((!error.empty()) && (ends_with(file_path, ".pt") || ends_with(file_path, ".pth"))) {
LOG_WARN("%s", error.c_str());
}
return false;
}

if (!init_from_safetensors_file(vae_path, "vae.")) {
LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
// return false;
}
if (!init_from_safetensors_file(clip_path, "te.")) {
LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str());
// return false;
}
if (!init_from_safetensors_file(clip_g_path, "te.1.")) {
LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str());
file_paths_.push_back(file_path);
size_t file_index = file_paths_.size() - 1;

for (auto& tensor_storage : tensor_storages) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}

if (!starts_with(tensor_storage.name, prefix)) {
tensor_storage.name = prefix + tensor_storage.name;
}
tensor_storage.file_index = file_index;

add_tensor_storage(tensor_storage);
}

return true;
}

/*================================================= CkptModelLoader ==================================================*/
/*================================================= TorchZipModelLoader ==================================================*/

bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::string& prefix) {
bool ModelLoader::init_from_torch_zip_file(const std::string& file_path, const std::string& prefix) {
LOG_DEBUG("init from '%s'", file_path.c_str());

std::vector<TensorStorage> tensor_storages;
std::string error;
if (!read_ckpt_file(file_path, tensor_storages, &error)) {
if (!read_torch_zip_file(file_path, tensor_storages, &error)) {
LOG_ERROR("%s", error.c_str());
return false;
}
Expand All @@ -384,6 +396,32 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
return true;
}

/*================================================= DiffusersModelLoader ==================================================*/

bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
    // A diffusers checkpoint is a directory with one safetensors file per
    // component; resolve each component path relative to the root directory.
    auto component = [&](const char* rel) {
        return path_join(file_path, rel);
    };

    // The UNet is mandatory — without it the model is unusable.
    if (!init_from_safetensors_file(component("unet/diffusion_pytorch_model.safetensors"), "unet.")) {
        return false;
    }

    // The remaining components are best-effort: log and continue on failure.
    if (!init_from_safetensors_file(component("vae/diffusion_pytorch_model.safetensors"), "vae.")) {
        LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
        // return false;
    }
    if (!init_from_safetensors_file(component("text_encoder/model.safetensors"), "te.")) {
        LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str());
        // return false;
    }
    // Second text encoder only exists for SDXL-style models; debug-level only.
    if (!init_from_safetensors_file(component("text_encoder_2/model.safetensors"), "te.1.")) {
        LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str());
    }
    return true;
}

SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight, input_block_weight;

Expand Down Expand Up @@ -1210,6 +1248,5 @@ bool convert(const char* input_path,
if (convert_name) {
model_loader.convert_tensors_name();
}
bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
return success;
return model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
}
3 changes: 2 additions & 1 deletion src/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ class ModelLoader {

bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = "");
bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");

public:
Expand Down
57 changes: 57 additions & 0 deletions src/model_io/binary_io.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#ifndef __SD_MODEL_IO_BINARY_IO_H__
#define __SD_MODEL_IO_BINARY_IO_H__

#include <cstdint>
#include <cstring>
#include <ostream>

namespace model_io {

// Decodes a 4-byte little-endian value from `buffer` as a signed 32-bit int.
// Assembled byte-by-byte so the result is independent of host endianness.
inline int32_t read_int(const uint8_t* buffer) {
    uint32_t v = 0;
    for (int i = 3; i >= 0; --i) {
        v = (v << 8) | static_cast<uint32_t>(buffer[i]);
    }
    return static_cast<int32_t>(v);
}

// Decodes a 2-byte little-endian value from `buffer` as an unsigned 16-bit int.
inline uint16_t read_short(const uint8_t* buffer) {
    // Integer promotion makes the shift/or well-defined; narrow back at the end.
    return static_cast<uint16_t>((static_cast<unsigned>(buffer[1]) << 8) | buffer[0]);
}

// Decodes an 8-byte little-endian value from `buffer` as an unsigned 64-bit int.
// Assembled byte-by-byte so the result is independent of host endianness.
inline uint64_t read_u64(const uint8_t* buffer) {
    uint64_t v = 0;
    for (int i = 7; i >= 0; --i) {
        v = (v << 8) | static_cast<uint64_t>(buffer[i]);
    }
    return v;
}

// Serializes `value` to `stream` as 8 bytes in little-endian order,
// regardless of host endianness.
inline void write_u64(std::ostream& stream, uint64_t value) {
    char bytes[8];
    for (char& b : bytes) {
        b = static_cast<char>(value & 0xFF);
        value >>= 8;
    }
    stream.write(bytes, sizeof(bytes));
}

// Returns the index of the first occurrence of byte `c` within the first
// `len` bytes of `buffer`, or -1 if it is absent (or `len` is not positive).
inline int find_char(const uint8_t* buffer, int len, char c) {
    if (len <= 0) {
        return -1;  // guard: memchr takes size_t, a negative len would wrap
    }
    // Standard library over hand-rolled loop; memchr is typically vectorized.
    const void* hit = std::memchr(buffer, static_cast<unsigned char>(c), static_cast<size_t>(len));
    if (hit == nullptr) {
        return -1;
    }
    return static_cast<int>(static_cast<const uint8_t*>(hit) - buffer);
}

} // namespace model_io

#endif // __SD_MODEL_IO_BINARY_IO_H__
Loading
Loading