From 7b756a518ed61f37ac5e22049b14ae8f69dbdce0 Mon Sep 17 00:00:00 2001 From: yjy415 <2471352175@qq.com> Date: Tue, 18 Nov 2025 20:19:37 +0800 Subject: [PATCH 1/3] flux --- diffsynth/configs/model_configs.py | 53 +- diffsynth/models/flux_controlnet.py | 61 ++- diffsynth/models/flux_lora_encoder.py | 483 +++++++++++++++++- diffsynth/models/flux_lora_patcher.py | 60 +++ diffsynth/models/sd_text_encoder.py | 412 +++++++++++++++ diffsynth/pipelines/flux_image.py | 118 ++++- .../state_dict_converters/flux_controlnet.py | 104 ++++ .../state_dict_converters/flux_infiniteyou.py | 4 + .../FLUX.1-dev-LoRA-Encoder.py | 4 +- 9 files changed, 1286 insertions(+), 13 deletions(-) create mode 100644 diffsynth/models/flux_lora_patcher.py create mode 100644 diffsynth/models/sd_text_encoder.py create mode 100644 diffsynth/utils/state_dict_converters/flux_controlnet.py create mode 100644 diffsynth/utils/state_dict_converters/flux_infiniteyou.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 1fade416..3aab5c72 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -312,7 +312,58 @@ "model_hash": "0629116fce1472503a66992f96f3eb1a", "model_name": "flux_value_controller", "model_class": "diffsynth.models.flux_value_control.SingleValueEncoder", - } + }, + { + # Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors") + "model_hash": "52357cb26250681367488a8954c271e8", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}, + }, + { + # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors") + "model_hash": "78d18b9101345ff695f312e7e62538c0", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}, + }, + { + # Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors") + "model_hash": "b001c89139b5f053c715fe772362dd2a", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_single_blocks": 0}, + }, + { + # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin") + "model_hash": "c07c0f04f5ff55e86b4e937c7a40d481", + "model_name": "infiniteyou_image_projector", + "model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter", + }, + { + # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors") + "model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16", + "model_name": "flux_controlnet", + "model_class": 
"diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10}, + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors") + "model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab", + "model_name": "flux_lora_encoder", + "model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors") + "model_hash": "30143afb2dea73d1ac580e0787628f8c", + "model_name": "flux_lora_patcher", + "model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher", + }, ] MODEL_CONFIGS = qwen_image_series + wan_series + flux_series diff --git a/diffsynth/models/flux_controlnet.py b/diffsynth/models/flux_controlnet.py index 85fccd7d..7fb1138b 100644 --- a/diffsynth/models/flux_controlnet.py +++ b/diffsynth/models/flux_controlnet.py @@ -1,9 +1,62 @@ import torch from einops import rearrange, repeat from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm -from .utils import hash_state_dict_keys, init_weights_on_device +# from .utils import hash_state_dict_keys, init_weights_on_device +from contextlib import contextmanager +def hash_state_dict_keys(state_dict, with_shape=True): + keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape) + keys_str = keys_str.encode(encoding="UTF-8") + return hashlib.md5(keys_str).hexdigest() +@contextmanager +def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False): + + old_register_parameter = torch.nn.Module.register_parameter + if include_buffers: + old_register_buffer = torch.nn.Module.register_buffer + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) + + def register_empty_buffer(module, name, buffer, persistent=True): + old_register_buffer(module, name, buffer, persistent=persistent) + if buffer is not None: + module._buffers[name] = module._buffers[name].to(device) + + def patch_tensor_constructor(fn): + def wrapper(*args, **kwargs): + kwargs["device"] = device + return fn(*args, **kwargs) + + return wrapper + + if include_buffers: + tensor_constructors_to_patch = { + torch_function_name: getattr(torch, torch_function_name) + for torch_function_name in ["empty", "zeros", "ones", "full"] + } + else: + tensor_constructors_to_patch = {} + + try: + torch.nn.Module.register_parameter = register_empty_parameter + if include_buffers: + torch.nn.Module.register_buffer = register_empty_buffer + for torch_function_name in tensor_constructors_to_patch.keys(): + setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) + yield + finally: + torch.nn.Module.register_parameter = old_register_parameter + if include_buffers: + torch.nn.Module.register_buffer = old_register_buffer + for torch_function_name, old_torch_function in tensor_constructors_to_patch.items(): + setattr(torch, torch_function_name, old_torch_function) class FluxControlNet(torch.nn.Module): def 
__init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0): @@ -102,9 +155,9 @@ def forward( return controlnet_res_stack, controlnet_single_res_stack - @staticmethod - def state_dict_converter(): - return FluxControlNetStateDictConverter() + # @staticmethod + # def state_dict_converter(): + # return FluxControlNetStateDictConverter() def quantize(self): def cast_to(weight, dtype=None, device=None, copy=False): diff --git a/diffsynth/models/flux_lora_encoder.py b/diffsynth/models/flux_lora_encoder.py index 695640a8..2be5dbbe 100644 --- a/diffsynth/models/flux_lora_encoder.py +++ b/diffsynth/models/flux_lora_encoder.py @@ -1,5 +1,415 @@ import torch -from .sd_text_encoder import CLIPEncoderLayer +from einops import rearrange + + +def low_version_attention(query, key, value, attn_bias=None): + scale = 1 / query.shape[-1] ** 0.5 + query = query * scale + attn = torch.matmul(query, key.transpose(-2, -1)) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(-1) + return attn @ value + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size = q.shape[0] + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states + + def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if qkv_preprocessor is not None: + q, k, v = qkv_preprocessor(q, k, v) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + if ipadapter_kwargs is not None: + hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads) 
+ k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads) + v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads) + + if attn_mask is not None: + hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask) + else: + import xformers.ops as xops + hidden_states = xops.memory_efficient_attention(q, k, v) + hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads) + + hidden_states = hidden_states.to(q.dtype) + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor) + + + + + +class CLIPEncoderLayer(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): + super().__init__() + self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) + self.layer_norm1 = torch.nn.LayerNorm(embed_dim) + self.layer_norm2 = torch.nn.LayerNorm(embed_dim) + self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) + self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) + + self.use_quick_gelu = use_quick_gelu + + def quickGELU(self, x): + return x * torch.sigmoid(1.702 * x) + + def forward(self, hidden_states, attn_mask=None): + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.attn(hidden_states, attn_mask=attn_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.fc1(hidden_states) + if self.use_quick_gelu: + hidden_states = self.quickGELU(hidden_states) + else: + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class SDTextEncoder(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=1): + embeds = self.token_embedding(input_ids) + self.position_embeds + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + break + embeds = self.final_layer_norm(embeds) + return embeds + + @staticmethod + def state_dict_converter(): + return SDTextEncoderStateDictConverter() + + +class SDTextEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + 
rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias" + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": 
"encoders.3.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": 
"encoders.5.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": 
"encoders.7.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": 
"encoders.9.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight", + "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias", + "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight", + "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds" + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + return state_dict_ + class LoRALayerBlock(torch.nn.Module): @@ -58,13 +468,80 @@ def default_lora_patterns(self): "type": suffix, }) return lora_patterns + + def get_lora_param_pair(self, lora, name, dim, device, dtype): + key_A = name + ".lora_A.default.weight" + key_B = name + ".lora_B.default.weight" + if key_A in lora and key_B in lora: + return lora[key_A], lora[key_B] + + if "to_qkv" in name: + base_name = name.replace("to_qkv", "") + suffixes = ["to_q", "to_k", "to_v"] + + found_As = [] + found_Bs = [] + + all_found = True + for suffix in suffixes: + sub_name = base_name + suffix + k_A = sub_name + ".lora_A.default.weight" + k_B = sub_name + ".lora_B.default.weight" + + if k_A in lora and k_B in lora: + found_As.append(lora[k_A]) + found_Bs.append(lora[k_B]) + else: + all_found = False + break + if all_found: + pass + + rank = 16 + for k, v in lora.items(): + if "lora_A" in k: + rank = v.shape[0] + device = v.device + dtype = v.dtype + break + lora_A = torch.zeros((rank, dim[0]), device=device, dtype=dtype) + lora_B = torch.zeros((dim[1], rank), device=device, dtype=dtype) + + return lora_A, lora_B + def forward(self, lora): lora_emb = [] + device = None + dtype = None + for v in lora.values(): + device = v.device + dtype = v.dtype + break + for lora_pattern in self.lora_patterns: name, layer_type = lora_pattern["name"], lora_pattern["type"] - lora_A = lora[name + ".lora_A.default.weight"] - lora_B = lora[name + ".lora_B.default.weight"] + dim = 
lora_pattern["dim"] + + lora_A, lora_B = self.get_lora_param_pair(lora, name, dim, device, dtype) + + if "to_qkv" in name and (lora_A is None or (torch.equal(lora_A, torch.zeros_like(lora_A)))): + base_name = name.replace("to_qkv", "") + try: + q_name = base_name + "to_q" + k_name = base_name + "to_k" + v_name = base_name + "to_v" + + real_A = lora[q_name + ".lora_A.default.weight"] + B_q = lora[q_name + ".lora_B.default.weight"] + B_k = lora[k_name + ".lora_B.default.weight"] + B_v = lora[v_name + ".lora_B.default.weight"] + real_B = torch.cat([B_q, B_k, B_v], dim=0) + + lora_A, lora_B = real_A, real_B + except KeyError: + pass + lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B) lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out) lora_emb.append(lora_out) diff --git a/diffsynth/models/flux_lora_patcher.py b/diffsynth/models/flux_lora_patcher.py new file mode 100644 index 00000000..a249feb1 --- /dev/null +++ b/diffsynth/models/flux_lora_patcher.py @@ -0,0 +1,60 @@ +import torch + +class LoraMerger(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.weight_base = torch.nn.Parameter(torch.randn((dim,))) + self.weight_lora = torch.nn.Parameter(torch.randn((dim,))) + self.weight_cross = torch.nn.Parameter(torch.randn((dim,))) + self.weight_out = torch.nn.Parameter(torch.ones((dim,))) + self.bias = torch.nn.Parameter(torch.randn((dim,))) + self.activation = torch.nn.Sigmoid() + self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5) + self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5) + + def forward(self, base_output, lora_outputs): + norm_base_output = self.norm_base(base_output) + norm_lora_outputs = self.norm_lora(lora_outputs) + gate = self.activation( + norm_base_output * self.weight_base \ + + norm_lora_outputs * self.weight_lora \ + + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias + ) + output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0) + return output + +class FluxLoraPatcher(torch.nn.Module): + def __init__(self, lora_patterns=None): + super().__init__() + if lora_patterns is None: + lora_patterns = self.default_lora_patterns() + model_dict = {} + for lora_pattern in lora_patterns: + name, dim = lora_pattern["name"], lora_pattern["dim"] + model_dict[name.replace(".", "___")] = LoraMerger(dim) + self.model_dict = torch.nn.ModuleDict(model_dict) + + def default_lora_patterns(self): + lora_patterns = [] + lora_dict = { + "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432, + "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432, + } + for i in range(19): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216} + for i in range(38): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"single_blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + return lora_patterns + + def forward(self, base_output, lora_outputs, name): + return self.model_dict[name.replace(".", "___")](base_output, lora_outputs) + diff --git a/diffsynth/models/sd_text_encoder.py b/diffsynth/models/sd_text_encoder.py new file mode 100644 index 00000000..b0a1171c --- /dev/null +++ b/diffsynth/models/sd_text_encoder.py @@ -0,0 +1,412 @@ +import torch +from .attention import Attention +from einops import rearrange + + +def low_version_attention(query, key, value, attn_bias=None): + 
scale = 1 / query.shape[-1] ** 0.5 + query = query * scale + attn = torch.matmul(query, key.transpose(-2, -1)) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(-1) + return attn @ value + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size = q.shape[0] + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states + + def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if qkv_preprocessor is not None: + q, k, v = qkv_preprocessor(q, k, v) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + if ipadapter_kwargs is not None: + hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads) + k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads) + v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads) + + if attn_mask is not None: + hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask) + else: + import xformers.ops as xops + hidden_states = xops.memory_efficient_attention(q, k, v) + hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads) + + hidden_states = hidden_states.to(q.dtype) + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor) + + + + + +class 
CLIPEncoderLayer(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): + super().__init__() + self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) + self.layer_norm1 = torch.nn.LayerNorm(embed_dim) + self.layer_norm2 = torch.nn.LayerNorm(embed_dim) + self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) + self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) + + self.use_quick_gelu = use_quick_gelu + + def quickGELU(self, x): + return x * torch.sigmoid(1.702 * x) + + def forward(self, hidden_states, attn_mask=None): + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.attn(hidden_states, attn_mask=attn_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.fc1(hidden_states) + if self.use_quick_gelu: + hidden_states = self.quickGELU(hidden_states) + else: + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class SDTextEncoder(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=1): + embeds = self.token_embedding(input_ids) + self.position_embeds + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + break + embeds = self.final_layer_norm(embeds) + return embeds + + @staticmethod + def state_dict_converter(): + return SDTextEncoderStateDictConverter() + + +class SDTextEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias" + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif 
name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": 
"encoders.1.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": 
"encoders.3.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": 
"encoders.5.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": 
"encoders.7.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": 
"encoders.9.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight", + "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias", + "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight", + "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds" + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + return state_dict_ diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 3c0151d3..78f0cb1c 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -102,8 +102,122 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16): ] self.model_fn = model_fn_flux_image self.lora_loader = FluxLoRALoader - - + + def enable_lora_magic(self): + pass + + # def load_lora(self, model, lora_config, alpha=1, hotload=False): + # if isinstance(lora_config, str): + # path = lora_config + # else: + # lora_config.download_if_necessary() + # path = lora_config.path + + # state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu") + # loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device) + # state_dict = loader.convert_state_dict(state_dict) + # loaded_count = 0 + # for key in tqdm(state_dict, desc="Applying LoRA"): + # if ".lora_A." in key: + # layer_name = key.split(".lora_A.")[0] + # module = model + # try: + # parts = layer_name.split(".") + # for part in parts: + # if part.isdigit(): + # module = module[int(part)] + # else: + # module = getattr(module, part) + # except AttributeError: + # continue + + # w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype) + # w_b_key = key.replace("lora_A", "lora_B") + # if w_b_key not in state_dict: continue + # w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype) + # delta_w = torch.mm(w_b, w_a) + # module.weight.data += delta_w * alpha + # loaded_count += 1 + + + def load_lora(self, model, lora_config, alpha=1.0, hotload=False): + if isinstance(lora_config, str): + path = lora_config + else: + lora_config.download_if_necessary() + path = lora_config.path + + state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu") + loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device) + state_dict = loader.convert_state_dict(state_dict) + + print(f"Merging LoRA weights from {path}...") + loaded_count = 0 + + # [新增] 键名映射表,处理 FW2 Loader 与 DiT 模型名称不一致的情况 + # 针对 Single Blocks 常见的命名差异进行修正 + key_mapping = { + ".linear1.": ".to_qkv_mlp.", # 常见差异点 1 + ".linear2.": ".proj_out.", # 常见差异点 2 + ".modulation.lin.": ".norm.linear." 
# common difference 3 } + + for key in tqdm(state_dict, desc="Applying LoRA"): + if ".lora_A." in key: + layer_name = key.split(".lora_A.")[0] + + # [Added] Try to apply the key-name correction + target_layer_name = layer_name + for src, dst in key_mapping.items(): + if src in target_layer_name: + target_layer_name = target_layer_name.replace(src, dst) + + # Locate the layer in the model + module = model + try: + parts = target_layer_name.split(".") + for part in parts: + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + except AttributeError: + # If the corrected name still cannot be found, fall back to the original name + try: + module = model + parts = layer_name.split(".") + for part in parts: + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + except AttributeError: + # Layer genuinely not found; skip (optionally print a warning) + # print(f"Warning: Could not find layer for {layer_name}") + continue + + # Fetch the LoRA parameters and compute the weight delta + try: + w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype) + w_b_key = key.replace("lora_A", "lora_B") + if w_b_key not in state_dict: continue + w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype) + + # Check that the shapes match (important: prevents broadcasting from silently masking errors) + # Linear weight: (out, in). B@A: (out, in) + delta_w = torch.mm(w_b, w_a) + if delta_w.shape != module.weight.shape: + # A shape mismatch usually means the QKV fused/split state is inconsistent + # Either skip or try a transpose depending on the case; we conservatively skip here + continue + + module.weight.data += delta_w * alpha + loaded_count += 1 + except Exception as e: + continue + + print(f"Applied LoRA to {loaded_count} layers.") + @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, diff --git a/diffsynth/utils/state_dict_converters/flux_controlnet.py b/diffsynth/utils/state_dict_converters/flux_controlnet.py new file mode 100644 index 00000000..926590ca --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_controlnet.py @@ -0,0 +1,104 @@ +import torch +import hashlib +import json + +def FluxControlNetStateDictConverter(state_dict): + global_rename_dict = { + "context_embedder": "context_embedder", + "x_embedder": "x_embedder", + "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0", + "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2", + "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0", + "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2", + "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0", + "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2", + "norm_out.linear": "final_norm_out.linear", + "proj_out": "final_proj_out", + } + rename_dict = { + "proj_out": "proj_out", + "norm1.linear": "norm1_a.linear", + "norm1_context.linear": "norm1_b.linear", + "attn.to_q": "attn.a_to_q", + "attn.to_k": "attn.a_to_k", + "attn.to_v": "attn.a_to_v", + "attn.to_out.0": "attn.a_to_out", + "attn.add_q_proj": "attn.b_to_q", + "attn.add_k_proj": "attn.b_to_k", + "attn.add_v_proj": "attn.b_to_v", + "attn.to_add_out": "attn.b_to_out", + "ff.net.0.proj": "ff_a.0", + "ff.net.2": "ff_a.2", + "ff_context.net.0.proj": "ff_b.0", + "ff_context.net.2": "ff_b.2", + "attn.norm_q": "attn.norm_q_a", + "attn.norm_k": "attn.norm_k_a", + "attn.norm_added_q": "attn.norm_q_b", + "attn.norm_added_k": "attn.norm_k_b", + } + rename_dict_single = { + "attn.to_q": "a_to_q", + "attn.to_k": "a_to_k", + "attn.to_v": "a_to_v", + "attn.norm_q": "norm_q_a", + "attn.norm_k": "norm_k_a", + "norm.linear": "norm.linear", + "proj_mlp": "proj_in_besides_attn", + "proj_out": "proj_out", + } + 
state_dict_ = {} + + for name in state_dict: + param = state_dict[name] + if name.endswith(".weight") or name.endswith(".bias"): + suffix = ".weight" if name.endswith(".weight") else ".bias" + prefix = name[:-len(suffix)] + if prefix in global_rename_dict: + state_dict_[global_rename_dict[prefix] + suffix] = param + elif prefix.startswith("transformer_blocks."): + names = prefix.split(".") + names[0] = "blocks" + middle = ".".join(names[2:]) + if middle in rename_dict: + name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]]) + state_dict_[name_] = param + elif prefix.startswith("single_transformer_blocks."): + names = prefix.split(".") + names[0] = "single_blocks" + middle = ".".join(names[2:]) + if middle in rename_dict_single: + name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]]) + state_dict_[name_] = param + else: + state_dict_[name] = param + else: + state_dict_[name] = param + for name in list(state_dict_.keys()): + if ".proj_in_besides_attn." in name: + name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.") + param = torch.concat([ + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")], + state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")], + state_dict_[name], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v.")) + state_dict_.pop(name) + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/flux_infiniteyou.py b/diffsynth/utils/state_dict_converters/flux_infiniteyou.py new file mode 100644 index 00000000..826f9966 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_infiniteyou.py @@ -0,0 +1,4 @@ +import torch + +def FluxInfiniteYouImageProjectorStateDictConverter(state_dict): + return state_dict['image_proj'] \ No newline at end of file diff --git a/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py b/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py index 9e3d74bc..75f1bc80 100644 --- a/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py +++ b/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py @@ -13,10 +13,8 @@ ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors"), ], ) -pipe.enable_lora_magic() - lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors") -pipe.load_lora(pipe.dit, lora, hotload=True) # Use `pipe.clear_lora()` to drop the loaded LoRA. +pipe.load_lora(pipe.dit, lora) # Use `pipe.clear_lora()` to drop the loaded LoRA. # Empty prompt can automatically activate LoRA capabilities. 
image = pipe(prompt="", seed=0, lora_encoder_inputs=lora) From 3f9e9cad9d6d8e2fb3b0ed1e4e1da518861f18ac Mon Sep 17 00:00:00 2001 From: yjy415 <2471352175@qq.com> Date: Tue, 18 Nov 2025 20:37:14 +0800 Subject: [PATCH 2/3] fix:flux --- diffsynth/pipelines/flux_image.py | 100 +++--------------------------- 1 file changed, 10 insertions(+), 90 deletions(-) diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 78f0cb1c..9ac03739 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -106,41 +106,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16): def enable_lora_magic(self): pass - # def load_lora(self, model, lora_config, alpha=1, hotload=False): - # if isinstance(lora_config, str): - # path = lora_config - # else: - # lora_config.download_if_necessary() - # path = lora_config.path - - # state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu") - # loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device) - # state_dict = loader.convert_state_dict(state_dict) - # loaded_count = 0 - # for key in tqdm(state_dict, desc="Applying LoRA"): - # if ".lora_A." in key: - # layer_name = key.split(".lora_A.")[0] - # module = model - # try: - # parts = layer_name.split(".") - # for part in parts: - # if part.isdigit(): - # module = module[int(part)] - # else: - # module = getattr(module, part) - # except AttributeError: - # continue - - # w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype) - # w_b_key = key.replace("lora_A", "lora_B") - # if w_b_key not in state_dict: continue - # w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype) - # delta_w = torch.mm(w_b, w_a) - # module.weight.data += delta_w * alpha - # loaded_count += 1 - - - def load_lora(self, model, lora_config, alpha=1.0, hotload=False): + def load_lora(self, model, lora_config, alpha=1, hotload=False): if isinstance(lora_config, str): path = lora_config else: @@ -150,74 +116,28 @@ def load_lora(self, model, lora_config, alpha=1.0, hotload=False): state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu") loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device) state_dict = loader.convert_state_dict(state_dict) - - print(f"Merging LoRA weights from {path}...") loaded_count = 0 - - # [新增] 键名映射表,处理 FW2 Loader 与 DiT 模型名称不一致的情况 - # 针对 Single Blocks 常见的命名差异进行修正 - key_mapping = { - ".linear1.": ".to_qkv_mlp.", # 常见差异点 1 - ".linear2.": ".proj_out.", # 常见差异点 2 - ".modulation.lin.": ".norm.linear." # 常见差异点 3 - } - for key in tqdm(state_dict, desc="Applying LoRA"): if ".lora_A." 
in key: layer_name = key.split(".lora_A.")[0] - - # [Added] Try to apply the key-name correction - target_layer_name = layer_name - for src, dst in key_mapping.items(): - if src in target_layer_name: - target_layer_name = target_layer_name.replace(src, dst) - - # Locate the layer in the model module = model try: - parts = target_layer_name.split(".") + parts = layer_name.split(".") for part in parts: if part.isdigit(): module = module[int(part)] else: module = getattr(module, part) except AttributeError: - # If the corrected name still cannot be found, fall back to the original name - try: - module = model - parts = layer_name.split(".") - for part in parts: - if part.isdigit(): - module = module[int(part)] - else: - module = getattr(module, part) - except AttributeError: - # Layer genuinely not found; skip (optionally print a warning) - # print(f"Warning: Could not find layer for {layer_name}") - continue - - # Fetch the LoRA parameters and compute the weight delta - try: - w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype) - w_b_key = key.replace("lora_A", "lora_B") - if w_b_key not in state_dict: continue - w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype) - - # Check that the shapes match (important: prevents broadcasting from silently masking errors) - # Linear weight: (out, in). B@A: (out, in) - delta_w = torch.mm(w_b, w_a) - if delta_w.shape != module.weight.shape: - # A shape mismatch usually means the QKV fused/split state is inconsistent - # Either skip or try a transpose depending on the case; we conservatively skip here - continue - - module.weight.data += delta_w * alpha - loaded_count += 1 - except Exception as e: continue - - print(f"Applied LoRA to {loaded_count} layers.") - + + w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype) + w_b_key = key.replace("lora_A", "lora_B") + if w_b_key not in state_dict: continue + w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype) + delta_w = torch.mm(w_b, w_a) + module.weight.data += delta_w * alpha + loaded_count += 1 @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, From 2d23c897c20c60a877b834b689f80e786b5d7495 Mon Sep 17 00:00:00 2001 From: yjy415 <2471352175@qq.com> Date: Tue, 18 Nov 2025 21:29:35 +0800 Subject: [PATCH 3/3] add: LoRA Encoder --- diffsynth/models/flux_lora_encoder.py | 71 +------------------ diffsynth/pipelines/flux_image.py | 32 --------- .../model_inference/FLUX.1-dev-LoRA-Fusion.py | 2 - 3 files changed, 2 insertions(+), 103 deletions(-) diff --git a/diffsynth/models/flux_lora_encoder.py b/diffsynth/models/flux_lora_encoder.py index 2be5dbbe..13589b06 100644 --- a/diffsynth/models/flux_lora_encoder.py +++ b/diffsynth/models/flux_lora_encoder.py @@ -468,80 +468,13 @@ def default_lora_patterns(self): "type": suffix, }) return lora_patterns - - def get_lora_param_pair(self, lora, name, dim, device, dtype): - key_A = name + ".lora_A.default.weight" - key_B = name + ".lora_B.default.weight" - if key_A in lora and key_B in lora: - return lora[key_A], lora[key_B] - - if "to_qkv" in name: - base_name = name.replace("to_qkv", "") - suffixes = ["to_q", "to_k", "to_v"] - - found_As = [] - found_Bs = [] - - all_found = True - for suffix in suffixes: - sub_name = base_name + suffix - k_A = sub_name + ".lora_A.default.weight" - k_B = sub_name + ".lora_B.default.weight" - - if k_A in lora and k_B in lora: - found_As.append(lora[k_A]) - found_Bs.append(lora[k_B]) - else: - all_found = False - break - if all_found: - pass - - rank = 16 - for k, v in lora.items(): - if "lora_A" in k: - rank = v.shape[0] - device = v.device - dtype = v.dtype - break - lora_A = torch.zeros((rank, dim[0]), device=device, dtype=dtype) - lora_B = torch.zeros((dim[1], rank), device=device, dtype=dtype) - - return lora_A, lora_B - def 
forward(self, lora): lora_emb = [] - device = None - dtype = None - for v in lora.values(): - device = v.device - dtype = v.dtype - break - for lora_pattern in self.lora_patterns: name, layer_type = lora_pattern["name"], lora_pattern["type"] - dim = lora_pattern["dim"] - - lora_A, lora_B = self.get_lora_param_pair(lora, name, dim, device, dtype) - - if "to_qkv" in name and (lora_A is None or (torch.equal(lora_A, torch.zeros_like(lora_A)))): - base_name = name.replace("to_qkv", "") - try: - q_name = base_name + "to_q" - k_name = base_name + "to_k" - v_name = base_name + "to_v" - - real_A = lora[q_name + ".lora_A.default.weight"] - B_q = lora[q_name + ".lora_B.default.weight"] - B_k = lora[k_name + ".lora_B.default.weight"] - B_v = lora[v_name + ".lora_B.default.weight"] - real_B = torch.cat([B_q, B_k, B_v], dim=0) - - lora_A, lora_B = real_A, real_B - except KeyError: - pass - + lora_A = lora[name + ".lora_A.weight"] + lora_B = lora[name + ".lora_B.weight"] lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B) lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out) lora_emb.append(lora_out) diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 9ac03739..dcd9e8ea 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -106,38 +106,6 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16): def enable_lora_magic(self): pass - def load_lora(self, model, lora_config, alpha=1, hotload=False): - if isinstance(lora_config, str): - path = lora_config - else: - lora_config.download_if_necessary() - path = lora_config.path - - state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu") - loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device) - state_dict = loader.convert_state_dict(state_dict) - loaded_count = 0 - for key in tqdm(state_dict, desc="Applying LoRA"): - if ".lora_A." in key: - layer_name = key.split(".lora_A.")[0] - module = model - try: - parts = layer_name.split(".") - for part in parts: - if part.isdigit(): - module = module[int(part)] - else: - module = getattr(module, part) - except AttributeError: - continue - - w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype) - w_b_key = key.replace("lora_A", "lora_B") - if w_b_key not in state_dict: continue - w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype) - delta_w = torch.mm(w_b, w_a) - module.weight.data += delta_w * alpha - loaded_count += 1 @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, diff --git a/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py b/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py index 5339230d..9d0b189d 100644 --- a/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py +++ b/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py @@ -18,12 +18,10 @@ pipe.load_lora( pipe.dit, ModelConfig(model_id="cancel13/cxsk", origin_file_pattern="30.safetensors"), - hotload=True, ) pipe.load_lora( pipe.dit, ModelConfig(model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", origin_file_pattern="merged_lora.safetensors"), - hotload=True, ) image = pipe(prompt="a cat", seed=0) image.save("image_fused.jpg")
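For reference, the following is a minimal standalone sketch of the weight-merge step that the load_lora implementations in the first two patches perform: for every layer with a lora_A/lora_B pair, the update lora_B @ lora_A (scaled by alpha) is added to the base Linear weight, and shape mismatches (typically fused vs. split QKV layers) are skipped. The helper name merge_lora_into_linear is illustrative only and is not part of the patch.

import torch

def merge_lora_into_linear(linear: torch.nn.Linear, lora_A: torch.Tensor, lora_B: torch.Tensor, alpha: float = 1.0) -> bool:
    # lora_A: (rank, in_features); lora_B: (out_features, rank)
    delta_w = torch.mm(lora_B, lora_A)  # (out_features, in_features), same shape as linear.weight
    if delta_w.shape != linear.weight.shape:
        # A mismatch typically means fused vs. split QKV naming; such layers are skipped
        return False
    linear.weight.data += alpha * delta_w.to(device=linear.weight.device, dtype=linear.weight.dtype)
    return True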