From 5ff736fd569ac3754242331e5a03ebd72b3585ab Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 19 Apr 2026 22:48:21 +0800 Subject: [PATCH 1/2] feat: add restricted torch legacy checkpoint loading --- src/denoiser.hpp | 19 +- src/model.cpp | 55 +- src/model.h | 3 +- src/model_io/binary_io.h | 57 ++ src/model_io/ckpt_io.cpp | 403 ----------- src/model_io/ckpt_io.h | 14 - src/model_io/pickle_io.cpp | 1064 ++++++++++++++++++++++++++++++ src/model_io/pickle_io.h | 21 + src/model_io/safetensors_io.cpp | 19 +- src/model_io/tensor_storage.h | 1 + src/model_io/torch_legacy_io.cpp | 252 +++++++ src/model_io/torch_legacy_io.h | 13 + src/model_io/torch_zip_io.cpp | 140 ++++ src/model_io/torch_zip_io.h | 14 + 14 files changed, 1620 insertions(+), 455 deletions(-) create mode 100644 src/model_io/binary_io.h delete mode 100644 src/model_io/ckpt_io.cpp delete mode 100644 src/model_io/ckpt_io.h create mode 100644 src/model_io/pickle_io.cpp create mode 100644 src/model_io/pickle_io.h create mode 100644 src/model_io/torch_legacy_io.cpp create mode 100644 src/model_io/torch_legacy_io.h create mode 100644 src/model_io/torch_zip_io.cpp create mode 100644 src/model_io/torch_zip_io.h diff --git a/src/denoiser.hpp b/src/denoiser.hpp index 14b6d3beb..a6e81d597 100644 --- a/src/denoiser.hpp +++ b/src/denoiser.hpp @@ -1523,12 +1523,10 @@ static sd::Tensor sample_ddim_trailing(denoise_cb_t model, const std::vector& sigmas, std::shared_ptr rng, float eta) { - int steps = static_cast(sigmas.size()) - 1; for (int i = 0; i < steps; i++) { - - float sigma = sigmas[i]; - float sigma_to = sigmas[i + 1]; + float sigma = sigmas[i]; + float sigma_to = sigmas[i + 1]; auto model_output_opt = model(x, sigma, i + 1); if (model_output_opt.empty()) { @@ -1551,12 +1549,11 @@ static sd::Tensor sample_ddim_trailing(denoise_cb_t model, float std_dev_t = eta * std::sqrt(variance); x = pred_original_sample + - std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2))/ alpha_prod_t_prev) * model_output; + std::sqrt((1.0f - alpha_prod_t_prev - std::pow(std_dev_t, 2)) / alpha_prod_t_prev) * model_output; if (eta > 0) { - x+= std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor::randn_like(x, rng); + x += std_dev_t / std::sqrt(alpha_prod_t_prev) * sd::Tensor::randn_like(x, rng); } - } return x; } @@ -1584,8 +1581,10 @@ static sd::Tensor sample_tcd(denoise_cb_t model, auto get_timestep_from_sigma = [&](float s) -> int { auto it = std::lower_bound(compvis_sigmas.begin(), compvis_sigmas.end(), s); - if (it == compvis_sigmas.begin()) return 0; - if (it == compvis_sigmas.end()) return TIMESTEPS - 1; + if (it == compvis_sigmas.begin()) + return 0; + if (it == compvis_sigmas.end()) + return TIMESTEPS - 1; int idx_high = static_cast(std::distance(compvis_sigmas.begin(), it)); int idx_low = idx_high - 1; if (std::abs(compvis_sigmas[idx_high] - s) < std::abs(compvis_sigmas[idx_low] - s)) { @@ -1596,7 +1595,6 @@ static sd::Tensor sample_tcd(denoise_cb_t model, int steps = static_cast(sigmas.size()) - 1; for (int i = 0; i < steps; i++) { - float sigma_to = sigmas[i + 1]; int prev_timestep = get_timestep_from_sigma(sigma_to); int timestep_s = (int)floor((1 - eta) * prev_timestep); @@ -1626,7 +1624,6 @@ static sd::Tensor sample_tcd(denoise_cb_t model, x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * x + std::sqrt(1.0f / alpha_prod_t_prev - 1.0f / alpha_prod_s) * sd::Tensor::randn_like(x, rng); } - } return x; } diff --git a/src/model.cpp b/src/model.cpp index 2594267fb..5cf577acd 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -13,9 +14,10 @@ #include #include "model.h" -#include "model_io/ckpt_io.h" #include "model_io/gguf_io.h" #include "model_io/safetensors_io.h" +#include "model_io/torch_legacy_io.h" +#include "model_io/torch_zip_io.h" #include "stable-diffusion.h" #include "util.h" @@ -229,9 +231,12 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string } else if (is_safetensors_file(file_path)) { LOG_INFO("load %s using safetensors format", file_path.c_str()); return init_from_safetensors_file(file_path, prefix); - } else if (is_ckpt_file(file_path)) { - LOG_INFO("load %s using checkpoint format", file_path.c_str()); - return init_from_ckpt_file(file_path, prefix); + } else if (is_torch_zip_file(file_path)) { + LOG_INFO("load %s using torch zip format", file_path.c_str()); + return init_from_torch_zip_file(file_path, prefix); + } else if (init_from_torch_legacy_file(file_path, prefix)) { + LOG_INFO("load %s using torch legacy format", file_path.c_str()); + return true; } else { if (file_exists(file_path)) { LOG_WARN("unknown format %s", file_path.c_str()); @@ -329,6 +334,39 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const return true; } +/*================================================= TorchLegacyModelLoader ==================================================*/ + +bool ModelLoader::init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix) { + LOG_DEBUG("init from torch legacy '%s'", file_path.c_str()); + + std::vector tensor_storages; + std::string error; + if (!read_torch_legacy_file(file_path, tensor_storages, &error)) { + if ((!error.empty()) && (ends_with(file_path, ".pt") || ends_with(file_path, ".pth"))) { + LOG_WARN("%s", error.c_str()); + } + return false; + } + + file_paths_.push_back(file_path); + size_t file_index = file_paths_.size() - 1; + + for (auto& tensor_storage : tensor_storages) { + if (is_unused_tensor(tensor_storage.name)) { + continue; + } + + if (!starts_with(tensor_storage.name, prefix)) { + tensor_storage.name = prefix + tensor_storage.name; + } + tensor_storage.file_index = file_index; + + add_tensor_storage(tensor_storage); + } + + return true; +} + /*================================================= DiffusersModelLoader ==================================================*/ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) { @@ -355,14 +393,12 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s return true; } -/*================================================= CkptModelLoader ==================================================*/ - -bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::string& prefix) { +bool ModelLoader::init_from_torch_zip_file(const std::string& file_path, const std::string& prefix) { LOG_DEBUG("init from '%s'", file_path.c_str()); std::vector tensor_storages; std::string error; - if (!read_ckpt_file(file_path, tensor_storages, &error)) { + if (!read_torch_zip_file(file_path, tensor_storages, &error)) { LOG_ERROR("%s", error.c_str()); return false; } @@ -1210,6 +1246,5 @@ bool convert(const char* input_path, if (convert_name) { model_loader.convert_tensors_name(); } - bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules); - return success; + return model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules); } diff --git a/src/model.h b/src/model.h index de15431f4..90b8f6ca7 100644 --- a/src/model.h +++ b/src/model.h @@ -200,7 +200,8 @@ class ModelLoader { bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = ""); - bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = ""); + bool init_from_torch_zip_file(const std::string& file_path, const std::string& prefix = ""); + bool init_from_torch_legacy_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = ""); public: diff --git a/src/model_io/binary_io.h b/src/model_io/binary_io.h new file mode 100644 index 000000000..9093eeaf9 --- /dev/null +++ b/src/model_io/binary_io.h @@ -0,0 +1,57 @@ +#ifndef __SD_MODEL_IO_BINARY_IO_H__ +#define __SD_MODEL_IO_BINARY_IO_H__ + +#include +#include + +namespace model_io { + + inline int32_t read_int(const uint8_t* buffer) { + uint32_t value = 0; + value |= static_cast(buffer[3]) << 24; + value |= static_cast(buffer[2]) << 16; + value |= static_cast(buffer[1]) << 8; + value |= static_cast(buffer[0]); + return static_cast(value); + } + + inline uint16_t read_short(const uint8_t* buffer) { + uint16_t value = 0; + value |= static_cast(buffer[1]) << 8; + value |= static_cast(buffer[0]); + return value; + } + + inline uint64_t read_u64(const uint8_t* buffer) { + uint64_t value = 0; + value |= static_cast(buffer[7]) << 56; + value |= static_cast(buffer[6]) << 48; + value |= static_cast(buffer[5]) << 40; + value |= static_cast(buffer[4]) << 32; + value |= static_cast(buffer[3]) << 24; + value |= static_cast(buffer[2]) << 16; + value |= static_cast(buffer[1]) << 8; + value |= static_cast(buffer[0]); + return value; + } + + inline void write_u64(std::ostream& stream, uint64_t value) { + uint8_t buffer[8]; + for (int i = 0; i < 8; ++i) { + buffer[i] = static_cast((value >> (8 * i)) & 0xFF); + } + stream.write((const char*)buffer, sizeof(buffer)); + } + + inline int find_char(const uint8_t* buffer, int len, char c) { + for (int pos = 0; pos < len; pos++) { + if (buffer[pos] == (uint8_t)c) { + return pos; + } + } + return -1; + } + +} // namespace model_io + +#endif // __SD_MODEL_IO_BINARY_IO_H__ diff --git a/src/model_io/ckpt_io.cpp b/src/model_io/ckpt_io.cpp deleted file mode 100644 index 63fd262d5..000000000 --- a/src/model_io/ckpt_io.cpp +++ /dev/null @@ -1,403 +0,0 @@ -#include "ckpt_io.h" - -#include -#include -#include -#include -#include -#include - -#include "zip.h" - -static constexpr int MAX_STRING_BUFFER = 512; - -static void set_error(std::string* error, const std::string& message) { - if (error != nullptr) { - *error = message; - } -} - -static int32_t read_int(const uint8_t* buffer) { - // little endian - uint32_t value = 0; - value |= static_cast(buffer[3]) << 24; - value |= static_cast(buffer[2]) << 16; - value |= static_cast(buffer[1]) << 8; - value |= static_cast(buffer[0]); - return static_cast(value); -} - -static uint16_t read_short(const uint8_t* buffer) { - // little endian - uint16_t value = 0; - value |= static_cast(buffer[1]) << 8; - value |= static_cast(buffer[0]); - return value; -} - -bool is_ckpt_file(const std::string& file_path) { - zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); - if (zip == nullptr) { - return false; - } - zip_close(zip); - return true; -} - -/*================================================= CkptModelLoader ==================================================*/ - -// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100 -// 0: \x80 PROTO 2 -// 2: } EMPTY_DICT -// 3: q BINPUT 0 -// 5: ( MARK -// 6: X BINUNICODE 'epoch' -// 16: q BINPUT 1 -// 18: K BININT1 6 -// 20: X BINUNICODE 'global_step' -// 36: q BINPUT 2 -// 38: J BININT 470000 -// 43: X BINUNICODE 'pytorch-lightning_version' -// 73: q BINPUT 3 -// 75: X BINUNICODE '1.4.2' -// 85: q BINPUT 4 -// 87: X BINUNICODE 'state_dict' -// 102: q BINPUT 5 -// 104: } EMPTY_DICT -// 105: q BINPUT 6 -// 107: ( MARK -// 108: X BINUNICODE 'betas' -// 118: q BINPUT 7 -// 120: c GLOBAL 'torch._utils _rebuild_tensor_v2' -// 153: q BINPUT 8 -// 155: ( MARK -// 156: ( MARK -// 157: X BINUNICODE 'storage' -// 169: q BINPUT 9 -// 171: c GLOBAL 'torch FloatStorage' -// 191: q BINPUT 10 -// 193: X BINUNICODE '0' -// 199: q BINPUT 11 -// 201: X BINUNICODE 'cpu' -// 209: q BINPUT 12 -// 211: M BININT2 1000 -// 214: t TUPLE (MARK at 156) -// 215: q BINPUT 13 -// 217: Q BINPERSID -// 218: K BININT1 0 -// 220: M BININT2 1000 -// ............................... -// 3201: q BINPUT 250 -// 3203: R REDUCE -// 3204: q BINPUT 251 -// 3206: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight' -// 3264: q BINPUT 252 -// 3266: h BINGET 8 -// 3268: ( MARK -// 3269: ( MARK -// 3270: h BINGET 9 -// 3272: h BINGET 10 -// 3274: X BINUNICODE '30' -// 3281: q BINPUT 253 -// 3283: h BINGET 12 -// 3285: J BININT 102400 -// 3290: t TUPLE (MARK at 3269) -// 3291: q BINPUT 254 -// 3293: Q BINPERSID -// 3294: K BININT1 0 -// 3296: ( MARK -// 3297: M BININT2 320 -// 3300: M BININT2 320 -// 3303: K BININT1 1 -// 3305: K BININT1 1 -// 3307: t TUPLE (MARK at 3296) -// 3308: q BINPUT 255 -// 3310: ( MARK -// 3311: M BININT2 320 -// 3314: K BININT1 1 -// 3316: K BININT1 1 -// 3318: K BININT1 1 -// 3320: t TUPLE (MARK at 3310) -// 3321: r LONG_BINPUT 256 -// 3326: \x89 NEWFALSE -// 3327: h BINGET 16 -// 3329: ) EMPTY_TUPLE -// 3330: R REDUCE -// 3331: r LONG_BINPUT 257 -// 3336: t TUPLE (MARK at 3268) -// 3337: r LONG_BINPUT 258 -// 3342: R REDUCE -// 3343: r LONG_BINPUT 259 -// 3348: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias' -// 3404: r LONG_BINPUT 260 -// 3409: h BINGET 8 -// 3411: ( MARK -// 3412: ( MARK -// 3413: h BINGET 9 -// 3415: h BINGET 10 -// 3417: X BINUNICODE '31' - -struct PickleTensorReader { - enum ReadPhase { - READ_NAME, - READ_DATA, - CHECK_SIZE, - READ_DIMENS - }; - ReadPhase phase = READ_NAME; - size_t entry_size = 0; - int32_t nelements = 0; - - TensorStorage tensor_storage; - - static ggml_type global_type; // all pickle_tensors data type - static bool read_global_type; - - bool read_int_value(uint32_t value) { - if (phase == CHECK_SIZE) { - if (entry_size == value * ggml_type_size(tensor_storage.type)) { - nelements = value; - phase = READ_DIMENS; - return true; - } else { - phase = READ_NAME; - } - } else if (phase == READ_DIMENS) { - if (tensor_storage.n_dims + 1 > SD_MAX_DIMS) { // too many dimens - phase = READ_NAME; - tensor_storage.n_dims = 0; - } - if (nelements % value == 0) { - tensor_storage.ne[tensor_storage.n_dims] = value; - tensor_storage.n_dims++; - } - } - return false; - } - - void read_global(const std::string& str) { - if (str == "FloatStorage") { - if (read_global_type) { - global_type = GGML_TYPE_F32; - read_global_type = false; - } - tensor_storage.type = GGML_TYPE_F32; - } else if (str == "HalfStorage") { - if (read_global_type) { - global_type = GGML_TYPE_F16; - read_global_type = false; - } - tensor_storage.type = GGML_TYPE_F16; - } - } - - void read_string(const std::string& str, zip_t* zip, std::string dir) { - if (str == "storage") { - read_global_type = true; - } else if (str != "state_dict") { - if (phase == READ_DATA) { - std::string entry_name = dir + "data/" + std::string(str); - - size_t i, n = zip_entries_total(zip); - for (i = 0; i < n; ++i) { - zip_entry_openbyindex(zip, i); - { - std::string name = zip_entry_name(zip); - if (name == entry_name) { - tensor_storage.index_in_zip = (int)i; - entry_size = zip_entry_size(zip); - zip_entry_close(zip); - break; - } - } - zip_entry_close(zip); - } - - phase = entry_size > 0 ? CHECK_SIZE : READ_NAME; - } - if (!read_global_type && phase == READ_NAME) { - tensor_storage.name = str; - phase = READ_DATA; - tensor_storage.type = global_type; - } - } - } -}; - -ggml_type PickleTensorReader::global_type = GGML_TYPE_F32; // all pickle_tensors data type -bool PickleTensorReader::read_global_type = false; - -static int find_char(uint8_t* buffer, int len, char c) { - for (int pos = 0; pos < len; pos++) { - if (buffer[pos] == c) { - return pos; - } - } - return -1; -} - -static bool parse_data_pkl(uint8_t* buffer, - size_t buffer_size, - zip_t* zip, - std::string dir, - std::vector& tensor_storages, - std::string* error) { - uint8_t* buffer_end = buffer + buffer_size; - if (buffer[0] == 0x80) { // proto - if (buffer[1] != 2) { - set_error(error, "unsupported pickle protocol"); - return false; - } - buffer += 2; // 0x80 and version - char string_buffer[MAX_STRING_BUFFER]; - bool finish = false; - PickleTensorReader reader; - // read pickle binary file - while (!finish && buffer < buffer_end) { - uint8_t opcode = *buffer; - buffer++; - // https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048 - // https://github.com/python/cpython/blob/main/Lib/pickle.py#L105 - switch (opcode) { - case '}': // EMPTY_DICT = b'}' # push empty dict - break; - case ']': // EMPTY_LIST = b']' # push empty list - break; - // skip unused sections - case 'h': // BINGET = b'h' # " " " " " " ; " " 1-byte arg - case 'q': // BINPUT = b'q' # " " " " " ; " " 1-byte arg - case 'Q': // BINPERSID = b'Q' # " " " ; " " " " stack - buffer++; - break; - case 'r': // LONG_BINPUT = b'r' # " " " " " ; " " 4-byte arg - buffer += 4; - break; - case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame - buffer += 8; - break; - case 0x94: // MEMOIZE = b'\x94' # store top of the stack in memo - break; - case '(': // MARK = b'(' # push special markobject on stack - break; - case 'K': // BININT1 = b'K' # push 1-byte unsigned int - { - uint8_t value = *buffer; - if (reader.read_int_value(value)) { - buffer++; - } - buffer++; - } break; - case 'M': // BININT2 = b'M' # push 2-byte unsigned int - { - uint16_t value = read_short(buffer); - if (reader.read_int_value(value)) { - buffer++; - } - buffer += 2; - } break; - case 'J': // BININT = b'J' # push four-byte signed int - { - const int32_t value = read_int(buffer); - if (reader.read_int_value(value)) { - buffer++; // skip tuple after read num_elements - } - buffer += 4; - } break; - case 'X': // BINUNICODE = b'X' # " " " ; counted UTF-8 string argument - { - const int32_t len = read_int(buffer); - buffer += 4; - memset(string_buffer, 0, MAX_STRING_BUFFER); - if (len > MAX_STRING_BUFFER) { - // keep truncated names null-terminated, matching the old parser behavior - } - memcpy(string_buffer, buffer, len < MAX_STRING_BUFFER ? len : (MAX_STRING_BUFFER - 1)); - buffer += len; - reader.read_string(string_buffer, zip, dir); - } break; - case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes - { - const int8_t len = *buffer; - buffer++; - memset(string_buffer, 0, MAX_STRING_BUFFER); - memcpy(string_buffer, buffer, len); - buffer += len; - // printf("String: '%s'\n", string_buffer); - } break; - case 'c': // GLOBAL = b'c' # push self.find_class(modname, name); 2 string args - { - int len = find_char(buffer, MAX_STRING_BUFFER, '\n'); - - buffer += len + 1; - len = find_char(buffer, MAX_STRING_BUFFER, '\n'); - - memset(string_buffer, 0, MAX_STRING_BUFFER); - memcpy(string_buffer, buffer, len); - buffer += len + 1; - reader.read_global(string_buffer); - } break; - case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from two topmost stack items - case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack top - case 't': // TUPLE = b't' # build tuple from topmost stack items - if (reader.phase == PickleTensorReader::READ_DIMENS) { - reader.tensor_storage.reverse_ne(); - tensor_storages.push_back(reader.tensor_storage); - - // LOG_DEBUG("%s", reader.tensor_storage.name.c_str()); - // reset - reader = PickleTensorReader(); - } - break; - case '.': // STOP = b'.' # every pickle ends with STOP - finish = true; - break; - default: - break; - } - } - } - return true; -} - -bool read_ckpt_file(const std::string& file_path, - std::vector& tensor_storages, - std::string* error) { - zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); - if (zip == nullptr) { - set_error(error, "failed to open '" + file_path + "'"); - return false; - } - - tensor_storages.clear(); - bool success = true; - int n = (int)zip_entries_total(zip); - for (int i = 0; i < n; ++i) { - zip_entry_openbyindex(zip, i); - { - std::string name = zip_entry_name(zip); - size_t pos = name.find("data.pkl"); - if (pos != std::string::npos) { - std::string dir = name.substr(0, pos); - printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str()); - void* pkl_data = nullptr; - size_t pkl_size; - zip_entry_read(zip, &pkl_data, &pkl_size); - - // LOG_DEBUG("%lld", pkl_size); - - if (!parse_data_pkl((uint8_t*)pkl_data, pkl_size, zip, dir, tensor_storages, error)) { - success = false; - } - - free(pkl_data); - } - } - zip_entry_close(zip); - - if (!success) { - break; - } - } - zip_close(zip); - return success; -} diff --git a/src/model_io/ckpt_io.h b/src/model_io/ckpt_io.h deleted file mode 100644 index 72667ce22..000000000 --- a/src/model_io/ckpt_io.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef __SD_MODEL_IO_CKPT_IO_H__ -#define __SD_MODEL_IO_CKPT_IO_H__ - -#include -#include - -#include "tensor_storage.h" - -bool is_ckpt_file(const std::string& file_path); -bool read_ckpt_file(const std::string& file_path, - std::vector& tensor_storages, - std::string* error = nullptr); - -#endif // __SD_MODEL_IO_CKPT_IO_H__ diff --git a/src/model_io/pickle_io.cpp b/src/model_io/pickle_io.cpp new file mode 100644 index 000000000..3a978178a --- /dev/null +++ b/src/model_io/pickle_io.cpp @@ -0,0 +1,1064 @@ +#include "pickle_io.h" + +#include +#include +#include +#include +#include +#include + +#include "binary_io.h" +#include "util.h" + +// $ python -m pickletools sd-v1-4/archive/data.pkl | head -n 100 +// 0: \x80 PROTO 2 +// 2: } EMPTY_DICT +// 3: q BINPUT 0 +// 5: ( MARK +// 6: X BINUNICODE 'epoch' +// 16: q BINPUT 1 +// 18: K BININT1 6 +// 20: X BINUNICODE 'global_step' +// 36: q BINPUT 2 +// 38: J BININT 470000 +// 43: X BINUNICODE 'pytorch-lightning_version' +// 73: q BINPUT 3 +// 75: X BINUNICODE '1.4.2' +// 85: q BINPUT 4 +// 87: X BINUNICODE 'state_dict' +// 102: q BINPUT 5 +// 104: } EMPTY_DICT +// 105: q BINPUT 6 +// 107: ( MARK +// 108: X BINUNICODE 'betas' +// 118: q BINPUT 7 +// 120: c GLOBAL 'torch._utils _rebuild_tensor_v2' +// 153: q BINPUT 8 +// 155: ( MARK +// 156: ( MARK +// 157: X BINUNICODE 'storage' +// 169: q BINPUT 9 +// 171: c GLOBAL 'torch FloatStorage' +// 191: q BINPUT 10 +// 193: X BINUNICODE '0' +// 199: q BINPUT 11 +// 201: X BINUNICODE 'cpu' +// 209: q BINPUT 12 +// 211: M BININT2 1000 +// 214: t TUPLE (MARK at 156) +// 215: q BINPUT 13 +// 217: Q BINPERSID +// 218: K BININT1 0 +// 220: M BININT2 1000 +// ............................... +// 3201: q BINPUT 250 +// 3203: R REDUCE +// 3204: q BINPUT 251 +// 3206: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.weight' +// 3264: q BINPUT 252 +// 3266: h BINGET 8 +// 3268: ( MARK +// 3269: ( MARK +// 3270: h BINGET 9 +// 3272: h BINGET 10 +// 3274: X BINUNICODE '30' +// 3281: q BINPUT 253 +// 3283: h BINGET 12 +// 3285: J BININT 102400 +// 3290: t TUPLE (MARK at 3269) +// 3291: q BINPUT 254 +// 3293: Q BINPERSID +// 3294: K BININT1 0 +// 3296: ( MARK +// 3297: M BININT2 320 +// 3300: M BININT2 320 +// 3303: K BININT1 1 +// 3305: K BININT1 1 +// 3307: t TUPLE (MARK at 3296) +// 3308: q BINPUT 255 +// 3310: ( MARK +// 3311: M BININT2 320 +// 3314: K BININT1 1 +// 3316: K BININT1 1 +// 3318: K BININT1 1 +// 3320: t TUPLE (MARK at 3310) +// 3321: r LONG_BINPUT 256 +// 3326: \x89 NEWFALSE +// 3327: h BINGET 16 +// 3329: ) EMPTY_TUPLE +// 3330: R REDUCE +// 3331: r LONG_BINPUT 257 +// 3336: t TUPLE (MARK at 3268) +// 3337: r LONG_BINPUT 258 +// 3342: R REDUCE +// 3343: r LONG_BINPUT 259 +// 3348: X BINUNICODE 'model.diffusion_model.input_blocks.1.1.proj_in.bias' +// 3404: r LONG_BINPUT 260 +// 3409: h BINGET 8 +// 3411: ( MARK +// 3412: ( MARK +// 3413: h BINGET 9 +// 3415: h BINGET 10 +// 3417: X BINUNICODE '31' +// https://github.com/python/cpython/blob/3.7/Lib/pickletools.py#L1048 +// https://github.com/python/cpython/blob/main/Lib/pickle.py#L105 + +using model_io::find_char; +using model_io::read_int; +using model_io::read_short; +using model_io::read_u64; + +static void set_error(std::string* error, const std::string& message) { + if (error != nullptr) { + *error = message; + } +} + +bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size) { + const uint8_t* p = buffer; + const uint8_t* end = buffer + buffer_size; + + while (p < end) { + uint8_t opcode = *p++; + switch (opcode) { + case '.': // STOP = b'.' # every pickle ends with STOP + *object_size = (size_t)(p - buffer); + return true; + case 0x80: // PROTO = b'\x80' # protocol version indicator + case 'K': // BININT1 = b'K' # push 1-byte unsigned int + case 'h': // BINGET = b'h' # read memo index, 1-byte arg + case 'q': // BINPUT = b'q' # write memo index, 1-byte arg + case 'C': // SHORT_BINBYTES = b'C' # push bytes; length < 256 + case 0x82: // EXT1 = b'\x82' # extension code, 1-byte arg + p += 1; + break; + case 'M': // BININT2 = b'M' # push 2-byte unsigned int + case 0x83: // EXT2 = b'\x83' # extension code, 2-byte arg + p += 2; + break; + case 'J': // BININT = b'J' # push 4-byte signed int + case 'j': // LONG_BINGET = b'j' # read memo index, 4-byte arg + case 'r': // LONG_BINPUT = b'r' # write memo index, 4-byte arg + case 0x84: // EXT4 = b'\x84' # extension code, 4-byte arg + p += 4; + break; + case 'I': // INT = b'I' # push decimal integer line + case 'L': // LONG = b'L' # push decimal long integer line + case 'F': // FLOAT = b'F' # push decimal float line + case 'S': // STRING = b'S' # push quoted string line + case 'V': { // UNICODE = b'V' # push raw-unicode string line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + case 'G': // BINFLOAT = b'G' # push 8-byte binary float + p += 8; + break; + case 0x8A: // LONG1 = b'\x8a' # push long integer; 1-byte length + if (p >= end) { + return false; + } + p += 1 + p[0]; + break; + case 0x8B: { // LONG4 = b'\x8b' # push long integer; 4-byte length + if (p + 4 > end) { + return false; + } + uint32_t n = (uint32_t)read_int(p); + p += 4 + n; + } break; + case 'B': { // BINBYTES = b'B' # push bytes; 4-byte length + if (p + 4 > end) { + return false; + } + uint32_t n = (uint32_t)read_int(p); + p += 4 + n; + } break; + case 'T': // BINSTRING = b'T' # push string; 4-byte length + case 'X': { // BINUNICODE = b'X' # push UTF-8 string; 4-byte length + if (p + 4 > end) { + return false; + } + uint32_t n = (uint32_t)read_int(p); + p += 4 + n; + } break; + case 0x8D: // BINUNICODE8 = b'\x8d' # push UTF-8 string; 8-byte length + case 0x8E: // BINBYTES8 = b'\x8e' # push bytes; 8-byte length + case 0x96: { // BYTEARRAY8 = b'\x96' # push bytearray; 8-byte length + if (p + 8 > end) { + return false; + } + uint64_t n = read_u64(p); + p += 8; + if (n > (uint64_t)(end - p)) { + return false; + } + p += n; + } break; + case 'U': // SHORT_BINSTRING = b'U' # push string; length < 256 + case 0x8C: // SHORT_BINUNICODE = b'\x8c' # push UTF-8 string; length < 256 + if (p >= end) { + return false; + } + p += 1 + p[0]; + break; + case 'P': { // PERSID = b'P' # persistent id, newline-terminated + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame + p += 8; + break; + case 'c': { // GLOBAL = b'c' # push module/name global reference + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + case '}': // EMPTY_DICT = b'}' # push empty dict + case ']': // EMPTY_LIST = b']' # push empty list + case '(': // MARK = b'(' # push markobject + case 't': // TUPLE = b't' # build tuple from mark + case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack + case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from stack + case 0x87: // TUPLE3 = b'\x87' # build 3-tuple from stack + case ')': // EMPTY_TUPLE = b')' # push empty tuple + case 'l': // LIST = b'l' # build list from mark + case 'Q': // BINPERSID = b'Q' # persistent id from stack + case 0x94: // MEMOIZE = b'\x94' # store top of stack in memo + case 0x88: // NEWTRUE = b'\x88' # push True + case 0x89: // NEWFALSE = b'\x89' # push False + case 'R': // REDUCE = b'R' # apply callable to args + case 'u': // SETITEMS = b'u' # add mark-delimited items to dict + case 's': // SETITEM = b's' # add key/value to dict + case 'e': // APPENDS = b'e' # extend list with mark-delimited items + case 'a': // APPEND = b'a' # append item to list + case 'b': // BUILD = b'b' # build object state + case 0x81: // NEWOBJ = b'\x81' # build object via __new__ + case 0x8F: // EMPTY_SET = b'\x8f' # push empty set + case 0x90: // ADDITEMS = b'\x90' # add mark-delimited items to set + case 0x91: // FROZENSET = b'\x91' # build frozenset from mark + case 0x92: // NEWOBJ_EX = b'\x92' # build object with kwargs + case 0x93: // STACK_GLOBAL = b'\x93' # build global from module/name strings + case 0x97: // NEXT_BUFFER = b'\x97' # out-of-band buffer marker + case 0x98: // READONLY_BUFFER = b'\x98' # mark buffer readonly + case 'N': // NONE = b'N' # push None + case '0': // POP = b'0' # discard top stack item + case '1': // POP_MARK = b'1' # discard stack through topmost mark + case '2': // DUP = b'2' # duplicate top stack item + case 'o': // OBJ = b'o' # build class instance from mark + break; + case 'i': { // INST = b'i' # build class instance from module/name + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + } break; + default: + return false; + } + if (p > end) { + return false; + } + } + + return false; +} + +bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size) { + static const uint8_t torch_magic_bytes[] = {0x6C, 0xFC, 0x9C, 0x46, 0xF9, 0x20, 0x6A, 0xA8, 0x50, 0x19}; + + if (buffer_size < 5 || buffer[0] != 0x80) { + return false; + } + + size_t pos = 2; + if (pos >= buffer_size) { + return false; + } + + uint8_t opcode = buffer[pos++]; + if (opcode != 0x8A || pos >= buffer_size) { + return false; + } + + uint8_t len = buffer[pos++]; + if (len != sizeof(torch_magic_bytes) || pos + len >= buffer_size) { + return false; + } + + if (memcmp(buffer + pos, torch_magic_bytes, sizeof(torch_magic_bytes)) != 0) { + return false; + } + pos += len; + + return pos < buffer_size && buffer[pos] == '.'; +} + +bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value) { + if (buffer_size < 4 || buffer[0] != 0x80) { + return false; + } + + size_t pos = 2; + if (pos >= buffer_size) { + return false; + } + + uint8_t opcode = buffer[pos++]; + switch (opcode) { + case 'K': // BININT1 = b'K' # push 1-byte unsigned int + if (pos + 1 >= buffer_size) { + return false; + } + *value = buffer[pos]; + pos += 1; + break; + case 'M': // BININT2 = b'M' # push 2-byte unsigned int + if (pos + 2 >= buffer_size) { + return false; + } + *value = read_short(buffer + pos); + pos += 2; + break; + case 'J': // BININT = b'J' # push 4-byte signed int + if (pos + 4 >= buffer_size) { + return false; + } + *value = (uint32_t)read_int(buffer + pos); + pos += 4; + break; + default: + return false; + } + + return pos < buffer_size && buffer[pos] == '.'; +} + +struct PickleStorageInfo { + std::string key; + ggml_type type = GGML_TYPE_COUNT; + bool is_f64 = false; + bool is_i64 = false; + uint64_t raw_element_nbytes = 0; + uint64_t nbytes = 0; +}; + +struct PickleTensorInfo { + TensorStorage tensor_storage; + int stride_n_dims = 0; + int64_t stride[SD_MAX_DIMS]{1, 1, 1, 1, 1}; +}; + +struct PickleValue { + enum Kind { + MARK, + NONE, + BOOL, + INT, + STRING, + GLOBAL, + TUPLE, + LIST, + DICT, + ORDERED_DICT, + STORAGE, + TENSOR, + }; + + Kind kind = NONE; + int64_t int_value = 0; + bool bool_value = false; + std::string str_value; + std::vector items; + std::vector> dict_items; + PickleStorageInfo storage; + PickleTensorInfo tensor; +}; + +static PickleValue make_mark_value() { + PickleValue value; + value.kind = PickleValue::MARK; + return value; +} + +static PickleValue make_none_value() { + PickleValue value; + value.kind = PickleValue::NONE; + return value; +} + +static PickleValue make_bool_value(bool b) { + PickleValue value; + value.kind = PickleValue::BOOL; + value.bool_value = b; + return value; +} + +static PickleValue make_int_value(int64_t x) { + PickleValue value; + value.kind = PickleValue::INT; + value.int_value = x; + return value; +} + +static PickleValue make_string_value(const std::string& s) { + PickleValue value; + value.kind = PickleValue::STRING; + value.str_value = s; + return value; +} + +static PickleValue make_global_value(const std::string& s) { + PickleValue value; + value.kind = PickleValue::GLOBAL; + value.str_value = s; + return value; +} + +static PickleValue make_tuple_value(std::vector items) { + PickleValue value; + value.kind = PickleValue::TUPLE; + value.items = std::move(items); + return value; +} + +static PickleValue make_list_value() { + PickleValue value; + value.kind = PickleValue::LIST; + return value; +} + +static PickleValue make_dict_value(bool ordered) { + PickleValue value; + value.kind = ordered ? PickleValue::ORDERED_DICT : PickleValue::DICT; + return value; +} + +static PickleValue make_storage_value(const PickleStorageInfo& storage) { + PickleValue value; + value.kind = PickleValue::STORAGE; + value.storage = storage; + return value; +} + +static PickleValue make_tensor_value(const PickleTensorInfo& tensor) { + PickleValue value; + value.kind = PickleValue::TENSOR; + value.tensor = tensor; + return value; +} + +static std::string pickle_value_to_string(const PickleValue& value) { + if (value.kind == PickleValue::STRING) { + return value.str_value; + } + if (value.kind == PickleValue::INT) { + return std::to_string(value.int_value); + } + return ""; +} + +static bool parse_storage_type(const std::string& global_name, PickleStorageInfo* storage) { + if (global_name == "torch.FloatStorage") { + storage->type = GGML_TYPE_F32; + storage->raw_element_nbytes = 4; + return true; + } + if (global_name == "torch.DoubleStorage") { + storage->type = GGML_TYPE_F32; + storage->is_f64 = true; + storage->raw_element_nbytes = 8; + return true; + } + if (global_name == "torch.HalfStorage") { + storage->type = GGML_TYPE_F16; + storage->raw_element_nbytes = 2; + return true; + } + if (global_name == "torch.BFloat16Storage") { + storage->type = GGML_TYPE_BF16; + storage->raw_element_nbytes = 2; + return true; + } + if (global_name == "torch.IntStorage") { + storage->type = GGML_TYPE_I32; + storage->raw_element_nbytes = 4; + return true; + } + if (global_name == "torch.LongStorage") { + storage->type = GGML_TYPE_I32; + storage->is_i64 = true; + storage->raw_element_nbytes = 8; + return true; + } + return false; +} + +static bool tensor_is_contiguous(const PickleTensorInfo& tensor) { + if (tensor.tensor_storage.nelements() == 0) { + return true; + } + if (tensor.stride_n_dims != tensor.tensor_storage.n_dims) { + return false; + } + + int64_t expected_stride = 1; + for (int i = tensor.tensor_storage.n_dims - 1; i >= 0; --i) { + if (tensor.stride[i] != expected_stride) { + return false; + } + expected_stride *= tensor.tensor_storage.ne[i]; + } + return true; +} + +static void collect_tensors_from_pickle_value(const PickleValue& value, + std::vector& tensor_storages) { + if (value.kind != PickleValue::DICT && value.kind != PickleValue::ORDERED_DICT) { + return; + } + + for (const auto& item : value.dict_items) { + if (item.first.kind == PickleValue::STRING && item.second.kind == PickleValue::TENSOR) { + TensorStorage tensor_storage = item.second.tensor.tensor_storage; + tensor_storage.name = item.first.str_value; + tensor_storage.reverse_ne(); + tensor_storages.push_back(tensor_storage); + } else if (item.second.kind == PickleValue::DICT || item.second.kind == PickleValue::ORDERED_DICT) { + collect_tensors_from_pickle_value(item.second, tensor_storages); + } + } +} + +bool parse_torch_state_dict_pickle(const uint8_t* buffer, + size_t buffer_size, + std::vector& tensor_storages, + std::unordered_map& storage_nbytes, + std::string* error) { + if (buffer_size < 2 || buffer[0] != 0x80 || buffer[1] < 2 || buffer[1] > 5) { + set_error(error, "unsupported torch pickle protocol"); + return false; + } + + const uint8_t* p = buffer + 2; + const uint8_t* end = buffer + buffer_size; + std::vector stack; + std::unordered_map memo; + + while (p < end) { + uint8_t opcode = *p++; + switch (opcode) { + case '.': { // STOP = b'.' # every pickle ends with STOP + if (stack.empty()) { + set_error(error, "empty torch pickle stack"); + return false; + } + size_t old_tensor_count = tensor_storages.size(); + collect_tensors_from_pickle_value(stack.back(), tensor_storages); + if (tensor_storages.size() == old_tensor_count) { + set_error(error, "torch pickle does not contain a supported state_dict"); + return false; + } + return true; + } + case '}': // EMPTY_DICT = b'}' # push empty dict + stack.push_back(make_dict_value(false)); + break; + case ']': // EMPTY_LIST = b']' # push empty list + stack.push_back(make_list_value()); + break; + case 'l': { // LIST = b'l' # build list from mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + set_error(error, "torch pickle list without mark"); + return false; + } + std::vector items(stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + PickleValue list_value = make_list_value(); + list_value.items = std::move(items); + stack.push_back(std::move(list_value)); + } break; + case '(': // MARK = b'(' # push markobject + stack.push_back(make_mark_value()); + break; + case ')': // EMPTY_TUPLE = b')' # push empty tuple + stack.push_back(make_tuple_value({})); + break; + case 'N': // NONE = b'N' # push None + stack.push_back(make_none_value()); + break; + case 0x88: // NEWTRUE = b'\x88' # push True + stack.push_back(make_bool_value(true)); + break; + case 0x89: // NEWFALSE = b'\x89' # push False + stack.push_back(make_bool_value(false)); + break; + case 'K': // BININT1 = b'K' # push 1-byte unsigned int + if (p >= end) { + return false; + } + stack.push_back(make_int_value(*p++)); + break; + case 'M': // BININT2 = b'M' # push 2-byte unsigned int + if (p + 2 > end) { + return false; + } + stack.push_back(make_int_value(read_short(p))); + p += 2; + break; + case 'J': // BININT = b'J' # push 4-byte signed int + if (p + 4 > end) { + return false; + } + stack.push_back(make_int_value(read_int(p))); + p += 4; + break; + case 'I': { // INT = b'I' # push decimal integer line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string s((const char*)p, len); + p += len + 1; + if (s == "01") { + stack.push_back(make_bool_value(true)); + } else if (s == "00") { + stack.push_back(make_bool_value(false)); + } else { + stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10))); + } + } break; + case 'L': { // LONG = b'L' # push decimal long integer line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string s((const char*)p, len); + p += len + 1; + if (!s.empty() && s.back() == 'L') { + s.pop_back(); + } + stack.push_back(make_int_value(std::strtoll(s.c_str(), nullptr, 10))); + } break; + case 'F': { // FLOAT = b'F' # push decimal float line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + p += len + 1; + stack.push_back(make_none_value()); + } break; + case 'G': // BINFLOAT = b'G' # push 8-byte binary float + if (p + 8 > end) { + return false; + } + p += 8; + stack.push_back(make_none_value()); + break; + case 0x8A: { // LONG1 = b'\x8a' # push long integer; 1-byte length + if (p >= end) { + return false; + } + uint8_t n = *p++; + if (p + n > end || n > 8) { + return false; + } + int64_t value = 0; + for (uint8_t i = 0; i < n; ++i) { + value |= (int64_t)p[i] << (i * 8); + } + p += n; + stack.push_back(make_int_value(value)); + } break; + case 'C': { // SHORT_BINBYTES = b'C' # push bytes; length < 256 + if (p >= end) { + return false; + } + uint8_t len = *p++; + if (p + len > end) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 'B': { // BINBYTES = b'B' # push bytes; 4-byte length + if (p + 4 > end) { + return false; + } + int32_t len = read_int(p); + p += 4; + if (len < 0 || p + len > end) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 'T': // BINSTRING = b'T' # push string; 4-byte length + case 'X': { // BINUNICODE = b'X' # push UTF-8 string; 4-byte length + if (p + 4 > end) { + return false; + } + int32_t len = read_int(p); + p += 4; + if (len < 0 || p + len > end) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 0x8D: // BINUNICODE8 = b'\x8d' # push UTF-8 string; 8-byte length + case 0x8E: // BINBYTES8 = b'\x8e' # push bytes; 8-byte length + case 0x96: { // BYTEARRAY8 = b'\x96' # push bytearray; 8-byte length + if (p + 8 > end) { + return false; + } + uint64_t len = read_u64(p); + p += 8; + if (len > (uint64_t)(end - p)) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, (size_t)len))); + p += len; + } break; + case 'U': // SHORT_BINSTRING = b'U' # push string; length < 256 + case 0x8C: { // SHORT_BINUNICODE = b'\x8c' # push UTF-8 string; length < 256 + if (p >= end) { + return false; + } + uint8_t len = *p++; + if (p + len > end) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len; + } break; + case 'S': { // STRING = b'S' # push quoted string line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string s((const char*)p, len); + p += len + 1; + if (s.size() >= 2 && (s[0] == '\'' || s[0] == '"') && s.back() == s[0]) { + s = s.substr(1, s.size() - 2); + } + stack.push_back(make_string_value(s)); + } break; + case 'V': { // UNICODE = b'V' # push raw-unicode string line + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + stack.push_back(make_string_value(std::string((const char*)p, len))); + p += len + 1; + } break; + case 'c': { // GLOBAL = b'c' # push module/name global reference + int len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string module((const char*)p, len); + p += len + 1; + len = find_char(p, (int)(end - p), '\n'); + if (len < 0) { + return false; + } + std::string name((const char*)p, len); + p += len + 1; + stack.push_back(make_global_value(module + "." + name)); + } break; + case 0x93: { // STACK_GLOBAL = b'\x93' # build global from module/name strings + if (stack.size() < 2 || stack[stack.size() - 2].kind != PickleValue::STRING || + stack.back().kind != PickleValue::STRING) { + return false; + } + std::string name = stack.back().str_value; + stack.pop_back(); + std::string module = stack.back().str_value; + stack.pop_back(); + stack.push_back(make_global_value(module + "." + name)); + } break; + case 'h': // BINGET = b'h' # read memo index, 1-byte arg + if (p >= end || !memo.count(*p)) { + return false; + } + stack.push_back(memo[*p++]); + break; + case 'j': { // LONG_BINGET = b'j' # read memo index, 4-byte arg + if (p + 4 > end) { + return false; + } + int32_t memo_idx = read_int(p); + if (!memo.count(memo_idx)) { + return false; + } + stack.push_back(memo[memo_idx]); + p += 4; + } break; + case 'q': // BINPUT = b'q' # write memo index, 1-byte arg + if (p >= end || stack.empty()) { + return false; + } + memo[*p++] = stack.back(); + break; + case 'r': // LONG_BINPUT = b'r' # write memo index, 4-byte arg + if (p + 4 > end || stack.empty()) { + return false; + } + memo[read_int(p)] = stack.back(); + p += 4; + break; + case 0x94: // MEMOIZE = b'\x94' # store top of stack in memo + if (stack.empty()) { + return false; + } + memo[(int32_t)memo.size()] = stack.back(); + break; + case 0x95: // FRAME = b'\x95' # indicate the beginning of a new frame + if (p + 8 > end) { + return false; + } + p += 8; + break; + case '0': // POP = b'0' # discard top stack item + if (stack.empty()) { + return false; + } + stack.pop_back(); + break; + case '1': { // POP_MARK = b'1' # discard stack through topmost mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + return false; + } + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case '2': // DUP = b'2' # duplicate top stack item + if (stack.empty()) { + return false; + } + stack.push_back(stack.back()); + break; + case 0x8F: // EMPTY_SET = b'\x8f' # push empty set + stack.push_back(make_list_value()); + break; + case 0x90: { // ADDITEMS = b'\x90' # add mark-delimited items to set + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) { + return false; + } + PickleValue& set_value = stack[mark_idx - 1]; + set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case 0x91: { // FROZENSET = b'\x91' # build frozenset from mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + return false; + } + PickleValue set_value = make_list_value(); + set_value.items.insert(set_value.items.end(), stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + stack.push_back(std::move(set_value)); + } break; + case 0x85: // TUPLE1 = b'\x85' # build 1-tuple from stack + case 0x86: // TUPLE2 = b'\x86' # build 2-tuple from stack + case 0x87: { // TUPLE3 = b'\x87' # build 3-tuple from stack + int tuple_size = opcode == 0x85 ? 1 : (opcode == 0x86 ? 2 : 3); + if ((int)stack.size() < tuple_size) { + return false; + } + std::vector items(stack.end() - tuple_size, stack.end()); + stack.erase(stack.end() - tuple_size, stack.end()); + stack.push_back(make_tuple_value(std::move(items))); + } break; + case 't': { // TUPLE = b't' # build tuple from mark + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx < 0) { + return false; + } + std::vector items(stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + stack.push_back(make_tuple_value(std::move(items))); + } break; + case 'Q': { // BINPERSID = b'Q' # persistent id from stack + if (stack.empty()) { + return false; + } + PickleValue pid = stack.back(); + stack.pop_back(); + if (pid.kind != PickleValue::TUPLE || pid.items.size() < 5 || pid.items[0].kind != PickleValue::STRING || + pid.items[1].kind != PickleValue::GLOBAL || pid.items[4].kind != PickleValue::INT || + pid.items[0].str_value != "storage") { + return false; + } + + PickleStorageInfo storage; + storage.key = pickle_value_to_string(pid.items[2]); + if (storage.key.empty() || !parse_storage_type(pid.items[1].str_value, &storage)) { + return false; + } + storage.nbytes = (uint64_t)pid.items[4].int_value * storage.raw_element_nbytes; + storage_nbytes[storage.key] = storage.nbytes; + stack.push_back(make_storage_value(storage)); + } break; + case 'R': { // REDUCE = b'R' # apply callable to args + if (stack.size() < 2) { + return false; + } + PickleValue args = stack.back(); + stack.pop_back(); + PickleValue callable = stack.back(); + stack.pop_back(); + if (callable.kind != PickleValue::GLOBAL || args.kind != PickleValue::TUPLE) { + stack.push_back(make_none_value()); + break; + } + + if (callable.str_value == "collections.OrderedDict" && args.items.empty()) { + stack.push_back(make_dict_value(true)); + break; + } + + if ((callable.str_value == "torch._utils._rebuild_tensor_v2" || callable.str_value == "torch._utils._rebuild_tensor") && + args.items.size() >= 4 && args.items[0].kind == PickleValue::STORAGE && + args.items[1].kind == PickleValue::INT && args.items[2].kind == PickleValue::TUPLE && + args.items[3].kind == PickleValue::TUPLE) { + PickleTensorInfo tensor; + tensor.tensor_storage.type = args.items[0].storage.type; + tensor.tensor_storage.is_f64 = args.items[0].storage.is_f64; + tensor.tensor_storage.is_i64 = args.items[0].storage.is_i64; + tensor.tensor_storage.storage_key = args.items[0].storage.key; + tensor.tensor_storage.offset = (uint64_t)args.items[1].int_value * args.items[0].storage.raw_element_nbytes; + + for (const auto& item : args.items[2].items) { + if (item.kind != PickleValue::INT || tensor.tensor_storage.n_dims >= SD_MAX_DIMS) { + return false; + } + tensor.tensor_storage.ne[tensor.tensor_storage.n_dims++] = item.int_value; + } + + for (const auto& item : args.items[3].items) { + if (item.kind != PickleValue::INT || tensor.stride_n_dims >= SD_MAX_DIMS) { + return false; + } + tensor.stride[tensor.stride_n_dims++] = item.int_value; + } + + if (!tensor_is_contiguous(tensor)) { + return false; + } + stack.push_back(make_tensor_value(tensor)); + break; + } + + // Non-tensor checkpoint metadata can use REDUCE for arbitrary + // Python objects. Do not execute it; keep stack shape only. + stack.push_back(make_none_value()); + break; + } + case 'b': // BUILD = b'b' # build object state + if (stack.size() < 2) { + return false; + } + stack.pop_back(); + break; + case 'u': { // SETITEMS = b'u' # add mark-delimited items to dict + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx <= 0) { + return false; + } + PickleValue& dict = stack[mark_idx - 1]; + if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) { + return false; + } + for (int i = mark_idx + 1; i + 1 < (int)stack.size(); i += 2) { + dict.dict_items.emplace_back(stack[i], stack[i + 1]); + } + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case 's': { // SETITEM = b's' # add key/value to dict + if (stack.size() < 3) { + return false; + } + PickleValue value = stack.back(); + stack.pop_back(); + PickleValue key = stack.back(); + stack.pop_back(); + PickleValue& dict = stack.back(); + if (dict.kind != PickleValue::DICT && dict.kind != PickleValue::ORDERED_DICT) { + return false; + } + dict.dict_items.emplace_back(key, value); + } break; + case 'e': { // APPENDS = b'e' # extend list with mark-delimited items + int mark_idx = (int)stack.size() - 1; + while (mark_idx >= 0 && stack[mark_idx].kind != PickleValue::MARK) { + --mark_idx; + } + if (mark_idx <= 0 || stack[mark_idx - 1].kind != PickleValue::LIST) { + return false; + } + PickleValue& list_value = stack[mark_idx - 1]; + list_value.items.insert(list_value.items.end(), stack.begin() + mark_idx + 1, stack.end()); + stack.erase(stack.begin() + mark_idx, stack.end()); + } break; + case 'a': { // APPEND = b'a' # append item to list + if (stack.size() < 2) { + return false; + } + PickleValue item = stack.back(); + stack.pop_back(); + if (stack.back().kind != PickleValue::LIST) { + return false; + } + stack.back().items.push_back(item); + } break; + default: + set_error(error, + "unsupported torch pickle opcode 0x" + sd_format("%02X", opcode) + + " at offset " + std::to_string((p - buffer) - 1)); + return false; + } + } + + set_error(error, "unterminated torch state_dict pickle"); + return false; +} diff --git a/src/model_io/pickle_io.h b/src/model_io/pickle_io.h new file mode 100644 index 000000000..6a3db37b9 --- /dev/null +++ b/src/model_io/pickle_io.h @@ -0,0 +1,21 @@ +#ifndef __SD_MODEL_IO_PICKLE_IO_H__ +#define __SD_MODEL_IO_PICKLE_IO_H__ + +#include +#include +#include +#include +#include + +#include "tensor_storage.h" + +bool skip_pickle_object(const uint8_t* buffer, size_t buffer_size, size_t* object_size); +bool pickle_object_is_torch_magic_number(const uint8_t* buffer, size_t buffer_size); +bool parse_pickle_uint32_object(const uint8_t* buffer, size_t buffer_size, uint32_t* value); +bool parse_torch_state_dict_pickle(const uint8_t* buffer, + size_t buffer_size, + std::vector& tensor_storages, + std::unordered_map& storage_nbytes, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_PICKLE_IO_H__ diff --git a/src/model_io/safetensors_io.cpp b/src/model_io/safetensors_io.cpp index 1ae485214..786a83b7d 100644 --- a/src/model_io/safetensors_io.cpp +++ b/src/model_io/safetensors_io.cpp @@ -6,6 +6,7 @@ #include #include +#include "binary_io.h" #include "json.hpp" static constexpr size_t ST_HEADER_SIZE_LEN = 8; @@ -16,20 +17,6 @@ static void set_error(std::string* error, const std::string& message) { } } -static uint64_t read_u64(const uint8_t* buffer) { - // little endian - uint64_t value = 0; - value |= static_cast(buffer[7]) << 56; - value |= static_cast(buffer[6]) << 48; - value |= static_cast(buffer[5]) << 40; - value |= static_cast(buffer[4]) << 32; - value |= static_cast(buffer[3]) << 24; - value |= static_cast(buffer[2]) << 16; - value |= static_cast(buffer[1]) << 8; - value |= static_cast(buffer[0]); - return value; -} - bool is_safetensors_file(const std::string& file_path) { std::ifstream file(file_path, std::ios::binary); if (!file.is_open()) { @@ -52,7 +39,7 @@ bool is_safetensors_file(const std::string& file_path) { return false; } - size_t header_size_ = read_u64(header_size_buf); + size_t header_size_ = model_io::read_u64(header_size_buf); if (header_size_ >= file_size_ || header_size_ <= 2) { return false; } @@ -123,7 +110,7 @@ bool read_safetensors_file(const std::string& file_path, return false; } - size_t header_size_ = read_u64(header_size_buf); + size_t header_size_ = model_io::read_u64(header_size_buf); if (header_size_ >= file_size_) { set_error(error, "invalid safetensor file '" + file_path + "'"); return false; diff --git a/src/model_io/tensor_storage.h b/src/model_io/tensor_storage.h index 20b58a19d..4779bfc23 100644 --- a/src/model_io/tensor_storage.h +++ b/src/model_io/tensor_storage.h @@ -24,6 +24,7 @@ struct TensorStorage { int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; int n_dims = 0; + std::string storage_key; size_t file_index = 0; int index_in_zip = -1; // >= means stored in a zip file uint64_t offset = 0; // offset in file diff --git a/src/model_io/torch_legacy_io.cpp b/src/model_io/torch_legacy_io.cpp new file mode 100644 index 000000000..816547252 --- /dev/null +++ b/src/model_io/torch_legacy_io.cpp @@ -0,0 +1,252 @@ +#include "torch_legacy_io.h" + +#include +#include +#include +#include +#include +#include + +#include "pickle_io.h" +#include "util.h" + +// torch.save format background: +// +// - Before PyTorch 1.6.0, torch.save used this legacy non-zip format by +// default. +// - Since PyTorch 1.6.0, torch.save defaults to an uncompressed ZIP64 archive +// containing data.pkl, data/, version, and, since PyTorch 2.1.0, byteorder. +// - The old format can still be produced explicitly with: +// torch.save(obj, path, _use_new_zipfile_serialization=False) +// +// Whether obj is a state_dict or a whole nn.Module does not change the outer +// container format selected by torch.save. It changes the pickled object inside: +// +// - state_dict: usually an OrderedDict[str, Tensor]. pickle_io.cpp supports a +// restricted subset of this layout because tensor metadata and raw storages +// can be recovered without executing pickle callables. +// - whole module/checkpoint object: arbitrary Python object graph. This may +// require importing user classes and executing pickle GLOBAL/REDUCE rebuild +// logic, so it is intentionally not supported here. +// +// Legacy non-zip PyTorch files are not a single pickle object: +// +// 1. pickle object: PyTorch legacy magic number +// 2. pickle object: legacy protocol version, expected to be 1001 +// 3. pickle object: sys_info metadata, ignored by this reader +// 4. pickle object: state_dict metadata, parsed by pickle_io.cpp +// 5. pickle object: serialized storage key list, skipped here +// 6. raw storage data payloads +// - PyTorch writes storages after the pickles, ordered by storage key +// - each storage has an 8-byte legacy storage header followed by raw bytes +static constexpr size_t LEGACY_STORAGE_HEADER_SIZE = 8; + +static void set_error(std::string* error, const std::string& message) { + if (error != nullptr) { + *error = message; + } +} + +static std::string bytes_to_hex(const std::vector& bytes) { + static const char* hex = "0123456789ABCDEF"; + std::string result; + result.reserve(bytes.size() * 3); + for (size_t i = 0; i < bytes.size(); ++i) { + if (i > 0) { + result.push_back('-'); + } + result.push_back(hex[(bytes[i] >> 4) & 0x0F]); + result.push_back(hex[bytes[i] & 0x0F]); + } + return result; +} + +static bool is_probably_tar_file(const std::vector& header) { + return header.size() >= 262 && + header[257] == 'u' && + header[258] == 's' && + header[259] == 't' && + header[260] == 'a' && + header[261] == 'r'; +} + +static std::string torch_legacy_diagnostics(const std::string& file_path, const std::vector& buffer) { + if (!ends_with(file_path, ".pt") && !ends_with(file_path, ".pth")) { + return ""; + } + if (buffer.empty()) { + return "unsupported PyTorch file '" + file_path + "': empty file"; + } + + size_t short_len = std::min(buffer.size(), 32); + std::vector short_header(buffer.begin(), buffer.begin() + short_len); + const bool raw_pickle = buffer[0] == 0x80; + const bool tar_file = is_probably_tar_file(buffer); + + std::string message = "unsupported PyTorch file '" + file_path + "': first bytes " + + bytes_to_hex(short_header) + + ", raw_pickle=" + (raw_pickle ? "true" : "false") + + ", tar=" + (tar_file ? "true" : "false"); + if (raw_pickle) { + message += "; raw pickle did not match the restricted state_dict layouts currently supported"; + } else if (tar_file) { + message += "; legacy tar PyTorch checkpoints are not supported yet"; + } + return message; +} + +bool read_torch_legacy_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error) { + std::ifstream file(file_path, std::ios::binary); + if (!file.is_open()) { + set_error(error, "failed to open '" + file_path + "'"); + return false; + } + + file.seekg(0, file.end); + size_t file_size = (size_t)file.tellg(); + file.seekg(0, file.beg); + if (file_size == 0) { + set_error(error, "empty file '" + file_path + "'"); + return false; + } + + std::vector buffer(file_size); + file.read((char*)buffer.data(), file_size); + if (!file) { + set_error(error, "failed to read '" + file_path + "'"); + return false; + } + + auto finalize_tensor_offsets = [&](size_t storage_data_offset, + const std::unordered_map& legacy_storage_map) -> bool { + if (storage_data_offset > file_size) { + return false; + } + + std::vector storage_keys; + storage_keys.reserve(legacy_storage_map.size()); + for (const auto& [storage_key, _] : legacy_storage_map) { + storage_keys.push_back(storage_key); + } + std::sort(storage_keys.begin(), storage_keys.end()); + + std::unordered_map storage_offsets; + uint64_t current_offset = storage_data_offset; + for (const auto& storage_key : storage_keys) { + auto it = legacy_storage_map.find(storage_key); + if (it == legacy_storage_map.end()) { + return false; + } + if (current_offset + LEGACY_STORAGE_HEADER_SIZE + it->second > file_size) { + return false; + } + storage_offsets[storage_key] = current_offset + LEGACY_STORAGE_HEADER_SIZE; + current_offset += LEGACY_STORAGE_HEADER_SIZE + it->second; + } + + for (auto& tensor_storage : tensor_storages) { + if (tensor_storage.storage_key.empty()) { + continue; + } + + auto it_offset = storage_offsets.find(tensor_storage.storage_key); + auto it_size = legacy_storage_map.find(tensor_storage.storage_key); + if (it_offset == storage_offsets.end() || it_size == legacy_storage_map.end()) { + return false; + } + + uint64_t base_offset = it_offset->second; + uint64_t storage_nbytes = it_size->second; + uint64_t tensor_nbytes = tensor_storage.nbytes_to_read(); + if (tensor_storage.offset + tensor_nbytes > storage_nbytes) { + return false; + } + + tensor_storage.offset = base_offset + tensor_storage.offset; + tensor_storage.storage_key.clear(); + } + + return true; + }; + + auto parse_state_dict_at = [&](size_t state_dict_offset, size_t state_dict_size, size_t* storage_data_offset) -> bool { + tensor_storages.clear(); + std::unordered_map legacy_storage_map; + if (!parse_torch_state_dict_pickle(buffer.data() + state_dict_offset, + state_dict_size, + tensor_storages, + legacy_storage_map, + error)) { + return false; + } + + size_t offset_after_state_dict = state_dict_offset + state_dict_size; + size_t storage_keys_size = 0; + if (!skip_pickle_object(buffer.data() + offset_after_state_dict, + buffer.size() - offset_after_state_dict, + &storage_keys_size)) { + return false; + } + + *storage_data_offset = offset_after_state_dict + storage_keys_size; + return finalize_tensor_offsets(*storage_data_offset, legacy_storage_map); + }; + + size_t object_size_1 = 0; + size_t offset = 0; + + if (skip_pickle_object(buffer.data(), buffer.size(), &object_size_1) && + pickle_object_is_torch_magic_number(buffer.data(), object_size_1)) { + offset += object_size_1; + + size_t object_size_2 = 0; + if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_2)) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + uint32_t protocol_version = 0; + if (!parse_pickle_uint32_object(buffer.data() + offset, object_size_2, &protocol_version) || protocol_version != 1001) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + offset += object_size_2; + + size_t object_size_3 = 0; + if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &object_size_3)) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + offset += object_size_3; + + size_t state_dict_size = 0; + if (!skip_pickle_object(buffer.data() + offset, buffer.size() - offset, &state_dict_size)) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + return false; + } + + size_t storage_data_offset = 0; + if (parse_state_dict_at(offset, state_dict_size, &storage_data_offset)) { + return true; + } + + if (error != nullptr && error->empty()) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + } + return false; + } + + size_t state_dict_size = 0; + if (skip_pickle_object(buffer.data(), buffer.size(), &state_dict_size)) { + size_t storage_data_offset = 0; + if (parse_state_dict_at(0, state_dict_size, &storage_data_offset)) { + return true; + } + } + + if (error != nullptr && error->empty()) { + set_error(error, torch_legacy_diagnostics(file_path, buffer)); + } + return false; +} diff --git a/src/model_io/torch_legacy_io.h b/src/model_io/torch_legacy_io.h new file mode 100644 index 000000000..6680e02a1 --- /dev/null +++ b/src/model_io/torch_legacy_io.h @@ -0,0 +1,13 @@ +#ifndef __SD_MODEL_IO_TORCH_LEGACY_IO_H__ +#define __SD_MODEL_IO_TORCH_LEGACY_IO_H__ + +#include +#include + +#include "tensor_storage.h" + +bool read_torch_legacy_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_TORCH_LEGACY_IO_H__ diff --git a/src/model_io/torch_zip_io.cpp b/src/model_io/torch_zip_io.cpp new file mode 100644 index 000000000..9eaf6c53a --- /dev/null +++ b/src/model_io/torch_zip_io.cpp @@ -0,0 +1,140 @@ +#include "torch_zip_io.h" + +#include +#include +#include +#include +#include + +#include "pickle_io.h" + +#include "zip.h" + +static void set_error(std::string* error, const std::string& message) { + if (error != nullptr) { + *error = message; + } +} + +bool is_torch_zip_file(const std::string& file_path) { + zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); + if (zip == nullptr) { + return false; + } + zip_close(zip); + return true; +} + +static bool find_zip_entry(zip_t* zip, const std::string& entry_name, int* index, uint64_t* size) { + size_t n = zip_entries_total(zip); + for (size_t i = 0; i < n; ++i) { + zip_entry_openbyindex(zip, i); + std::string name = zip_entry_name(zip); + if (name == entry_name) { + *index = (int)i; + *size = zip_entry_size(zip); + zip_entry_close(zip); + return true; + } + zip_entry_close(zip); + } + return false; +} + +static bool parse_zip_data_pkl(const uint8_t* buffer, + size_t buffer_size, + zip_t* zip, + const std::string& dir, + std::vector& tensor_storages, + std::string* error) { + std::vector parsed_tensors; + std::unordered_map storage_nbytes; + if (!parse_torch_state_dict_pickle(buffer, buffer_size, parsed_tensors, storage_nbytes, error)) { + if (error != nullptr && error->empty()) { + *error = "failed to parse torch zip pickle metadata"; + } + return false; + } + + for (auto& tensor_storage : parsed_tensors) { + if (tensor_storage.storage_key.empty()) { + set_error(error, "tensor '" + tensor_storage.name + "' has no storage key"); + return false; + } + + const std::string entry_name = dir + "data/" + tensor_storage.storage_key; + int zip_index = -1; + uint64_t entry_size = 0; + if (!find_zip_entry(zip, entry_name, &zip_index, &entry_size)) { + set_error(error, "storage entry '" + entry_name + "' was not found"); + return false; + } + + auto it_storage_size = storage_nbytes.find(tensor_storage.storage_key); + if (it_storage_size != storage_nbytes.end() && entry_size < it_storage_size->second) { + set_error(error, "storage entry '" + entry_name + "' is smaller than pickle metadata"); + return false; + } + + uint64_t tensor_nbytes = tensor_storage.nbytes_to_read(); + if (tensor_storage.offset + tensor_nbytes > entry_size) { + set_error(error, "tensor '" + tensor_storage.name + "' exceeds storage entry '" + entry_name + "'"); + return false; + } + + tensor_storage.index_in_zip = zip_index; + tensor_storage.storage_key.clear(); + tensor_storages.push_back(tensor_storage); + } + + return true; +} + +bool read_torch_zip_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error) { + zip_t* zip = zip_open(file_path.c_str(), 0, 'r'); + if (zip == nullptr) { + set_error(error, "failed to open '" + file_path + "'"); + return false; + } + + tensor_storages.clear(); + bool success = true; + bool found_data_pkl = false; + int n = (int)zip_entries_total(zip); + for (int i = 0; i < n; ++i) { + zip_entry_openbyindex(zip, i); + std::string name = zip_entry_name(zip); + size_t pos = name.find("data.pkl"); + if (pos != std::string::npos) { + found_data_pkl = true; + std::string dir = name.substr(0, pos); + void* pkl_data = nullptr; + size_t pkl_size = 0; + zip_entry_read(zip, &pkl_data, &pkl_size); + + if (pkl_data == nullptr || pkl_size == 0) { + set_error(error, "failed to read '" + name + "' from '" + file_path + "'"); + success = false; + } else if (!parse_zip_data_pkl((const uint8_t*)pkl_data, pkl_size, zip, dir, tensor_storages, error)) { + success = false; + } + + free(pkl_data); + } + zip_entry_close(zip); + + if (!success) { + break; + } + } + + if (success && !found_data_pkl) { + set_error(error, "data.pkl was not found in '" + file_path + "'"); + success = false; + } + + zip_close(zip); + return success; +} diff --git a/src/model_io/torch_zip_io.h b/src/model_io/torch_zip_io.h new file mode 100644 index 000000000..54fb099a7 --- /dev/null +++ b/src/model_io/torch_zip_io.h @@ -0,0 +1,14 @@ +#ifndef __SD_MODEL_IO_TORCH_ZIP_IO_H__ +#define __SD_MODEL_IO_TORCH_ZIP_IO_H__ + +#include +#include + +#include "tensor_storage.h" + +bool is_torch_zip_file(const std::string& file_path); +bool read_torch_zip_file(const std::string& file_path, + std::vector& tensor_storages, + std::string* error = nullptr); + +#endif // __SD_MODEL_IO_TORCH_ZIP_IO_H__ From e8c6581a7b4de3f88ace6db90ee9a75c86206054 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 19 Apr 2026 22:52:23 +0800 Subject: [PATCH 2/2] update --- src/model.cpp | 52 ++++++++++++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/src/model.cpp b/src/model.cpp index 5cf577acd..9d9a357dc 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -367,31 +367,7 @@ bool ModelLoader::init_from_torch_legacy_file(const std::string& file_path, cons return true; } -/*================================================= DiffusersModelLoader ==================================================*/ - -bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) { - std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors"); - std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors"); - std::string clip_path = path_join(file_path, "text_encoder/model.safetensors"); - std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors"); - - if (!init_from_safetensors_file(unet_path, "unet.")) { - return false; - } - - if (!init_from_safetensors_file(vae_path, "vae.")) { - LOG_WARN("Couldn't find working VAE in %s", file_path.c_str()); - // return false; - } - if (!init_from_safetensors_file(clip_path, "te.")) { - LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str()); - // return false; - } - if (!init_from_safetensors_file(clip_g_path, "te.1.")) { - LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str()); - } - return true; -} +/*================================================= TorchZipModelLoader ==================================================*/ bool ModelLoader::init_from_torch_zip_file(const std::string& file_path, const std::string& prefix) { LOG_DEBUG("init from '%s'", file_path.c_str()); @@ -420,6 +396,32 @@ bool ModelLoader::init_from_torch_zip_file(const std::string& file_path, const s return true; } +/*================================================= DiffusersModelLoader ==================================================*/ + +bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) { + std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors"); + std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors"); + std::string clip_path = path_join(file_path, "text_encoder/model.safetensors"); + std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors"); + + if (!init_from_safetensors_file(unet_path, "unet.")) { + return false; + } + + if (!init_from_safetensors_file(vae_path, "vae.")) { + LOG_WARN("Couldn't find working VAE in %s", file_path.c_str()); + // return false; + } + if (!init_from_safetensors_file(clip_path, "te.")) { + LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str()); + // return false; + } + if (!init_from_safetensors_file(clip_g_path, "te.1.")) { + LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str()); + } + return true; +} + SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight, input_block_weight;