
Commit 7b756a5

Commit message: flux
1 parent 416ff5d commit 7b756a5

File tree

9 files changed, +1286 −13 lines changed


diffsynth/configs/model_configs.py

Lines changed: 52 additions & 1 deletion
@@ -312,7 +312,58 @@
         "model_hash": "0629116fce1472503a66992f96f3eb1a",
         "model_name": "flux_value_controller",
         "model_class": "diffsynth.models.flux_value_control.SingleValueEncoder",
-    }
+    },
+    {
+        # Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors")
+        "model_hash": "52357cb26250681367488a8954c271e8",
+        "model_name": "flux_controlnet",
+        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
+        "extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4},
+    },
+    {
+        # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors")
+        "model_hash": "78d18b9101345ff695f312e7e62538c0",
+        "model_name": "flux_controlnet",
+        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
+        "extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}},
+    },
+    {
+        # Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors")
+        "model_hash": "b001c89139b5f053c715fe772362dd2a",
+        "model_name": "flux_controlnet",
+        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
+        "extra_kwargs": {"num_single_blocks": 0},
+    },
+    {
+        # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin")
+        "model_hash": "c07c0f04f5ff55e86b4e937c7a40d481",
+        "model_name": "infiniteyou_image_projector",
+        "model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector",
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter",
+    },
+    {
+        # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors")
+        "model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16",
+        "model_name": "flux_controlnet",
+        "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
+        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
+        "extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10},
+    },
+    {
+        # Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors")
+        "model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab",
+        "model_name": "flux_lora_encoder",
+        "model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder",
+    },
+    {
+        # Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors")
+        "model_hash": "30143afb2dea73d1ac580e0787628f8c",
+        "model_name": "flux_lora_patcher",
+        "model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher",
+    },
 ]
 
 MODEL_CONFIGS = qwen_image_series + wan_series + flux_series
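Each entry above is keyed by model_hash, an MD5 digest over the checkpoint's state-dict key names (see hash_state_dict_keys in flux_controlnet.py below). A minimal lookup sketch follows; find_model_config is a hypothetical helper, not part of this commit, and it assumes the registry hashes were computed with the default with_shape=True:

    from safetensors.torch import load_file

    def find_model_config(checkpoint_path):
        # Load the tensor dict and hash its key names (and shapes) the same
        # way the registry entries were hashed.
        state_dict = load_file(checkpoint_path)
        key_hash = hash_state_dict_keys(state_dict)
        for config in MODEL_CONFIGS:
            if config["model_hash"] == key_hash:
                return config  # e.g. one of the flux_controlnet entries above
        raise ValueError(f"no model config matches hash {key_hash}")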

diffsynth/models/flux_controlnet.py

Lines changed: 57 additions & 4 deletions
@@ -1,9 +1,62 @@
 import torch
 from einops import rearrange, repeat
 from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
-from .utils import hash_state_dict_keys, init_weights_on_device
+# from .utils import hash_state_dict_keys, init_weights_on_device
+import hashlib
+from contextlib import contextmanager
 
 
+# NOTE: convert_state_dict_keys_to_single_str is still required here; it remains in .utils.
+def hash_state_dict_keys(state_dict, with_shape=True):
+    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
+    keys_str = keys_str.encode(encoding="UTF-8")
+    return hashlib.md5(keys_str).hexdigest()
 
 
+@contextmanager
+def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
+    old_register_parameter = torch.nn.Module.register_parameter
+    if include_buffers:
+        old_register_buffer = torch.nn.Module.register_buffer
+
+    def register_empty_parameter(module, name, param):
+        old_register_parameter(module, name, param)
+        if param is not None:
+            param_cls = type(module._parameters[name])
+            kwargs = module._parameters[name].__dict__
+            kwargs["requires_grad"] = param.requires_grad
+            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
+
+    def register_empty_buffer(module, name, buffer, persistent=True):
+        old_register_buffer(module, name, buffer, persistent=persistent)
+        if buffer is not None:
+            module._buffers[name] = module._buffers[name].to(device)
+
+    def patch_tensor_constructor(fn):
+        def wrapper(*args, **kwargs):
+            kwargs["device"] = device
+            return fn(*args, **kwargs)
+
+        return wrapper
+
+    if include_buffers:
+        tensor_constructors_to_patch = {
+            torch_function_name: getattr(torch, torch_function_name)
+            for torch_function_name in ["empty", "zeros", "ones", "full"]
+        }
+    else:
+        tensor_constructors_to_patch = {}
+
+    try:
+        torch.nn.Module.register_parameter = register_empty_parameter
+        if include_buffers:
+            torch.nn.Module.register_buffer = register_empty_buffer
+        for torch_function_name in tensor_constructors_to_patch.keys():
+            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
+        yield
+    finally:
+        torch.nn.Module.register_parameter = old_register_parameter
+        if include_buffers:
+            torch.nn.Module.register_buffer = old_register_buffer
+        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
+            setattr(torch, torch_function_name, old_torch_function)
 
 
 class FluxControlNet(torch.nn.Module):
     def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
@@ -102,9 +155,9 @@ def forward(
         return controlnet_res_stack, controlnet_single_res_stack
 
 
-    @staticmethod
-    def state_dict_converter():
-        return FluxControlNetStateDictConverter()
+    # @staticmethod
+    # def state_dict_converter():
+    #     return FluxControlNetStateDictConverter()
 
     def quantize(self):
         def cast_to(weight, dtype=None, device=None, copy=False):
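A minimal usage sketch of the vendored init_weights_on_device context manager; the checkpoint path and constructor kwargs below are illustrative, and load_state_dict(..., assign=True) requires PyTorch >= 2.1:

    import torch

    # Register parameters on the meta device: no memory is allocated and no
    # random init runs, so building the module is near-instant.
    with init_weights_on_device(device=torch.device("meta")):
        model = FluxControlNet(num_joint_blocks=6, num_single_blocks=0, additional_input_dim=4)

    # Materialize real weights afterwards; assign=True swaps the meta tensors
    # for the loaded ones instead of copying into them.
    state_dict = torch.load("flux_controlnet.pth", map_location="cpu")  # hypothetical path
    model.load_state_dict(state_dict, assign=True)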
