306 changes: 306 additions & 0 deletions scripts/convert_joyimage_edit_to_diffusers.py
@@ -0,0 +1,306 @@
import argparse
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from transformers import AutoProcessor, AutoTokenizer, Qwen3VLForConditionalGeneration

from diffusers import (
    AutoencoderKLWan,
    JoyImageEditPipeline,
    JoyImageEditTransformer3DModel,
)
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler


# This script is adapted from convert_wan_to_diffusers.py to support loading from local checkpoint paths.
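# Example invocation (all paths below are illustrative placeholders, not files shipped with this PR):
#
#   python scripts/convert_joyimage_edit_to_diffusers.py \
#       --transformer_ckpt_path /path/to/transformer.pt \
#       --vae_ckpt_path /path/to/vae.pt \
#       --text_encoder_path /path/to/qwen3-vl \
#       --output_path ./joyimage-edit-diffusers \
#       --save_pipeline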
def convert_vae(vae_ckpt_path):
old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
new_state_dict = {}

# Create mappings for specific components
middle_key_mapping = {
# Encoder middle block
"encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
"encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
"encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
"encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
"encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
"encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
"encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
"encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
"encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
"encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
"encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
"encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
# Decoder middle block
"decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
"decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
"decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
"decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
"decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
"decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
"decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
"decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
"decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
"decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
"decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
"decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
}

# Create a mapping for attention blocks
attention_mapping = {
# Encoder middle attention
"encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
"encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
"encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
"encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
"encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
# Decoder middle attention
"decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
"decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
"decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
"decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
"decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
}

# Create a mapping for the head components
head_mapping = {
# Encoder head
"encoder.head.0.gamma": "encoder.norm_out.gamma",
"encoder.head.2.bias": "encoder.conv_out.bias",
"encoder.head.2.weight": "encoder.conv_out.weight",
# Decoder head
"decoder.head.0.gamma": "decoder.norm_out.gamma",
"decoder.head.2.bias": "decoder.conv_out.bias",
"decoder.head.2.weight": "decoder.conv_out.weight",
}

# Create a mapping for the quant components
quant_mapping = {
"conv1.weight": "quant_conv.weight",
"conv1.bias": "quant_conv.bias",
"conv2.weight": "post_quant_conv.weight",
"conv2.bias": "post_quant_conv.bias",
}

# Process each key in the state dict
for key, value in old_state_dict.items():
# Handle middle block keys using the mapping
if key in middle_key_mapping:
new_key = middle_key_mapping[key]
new_state_dict[new_key] = value
# Handle attention blocks using the mapping
elif key in attention_mapping:
new_key = attention_mapping[key]
new_state_dict[new_key] = value
# Handle head keys using the mapping
elif key in head_mapping:
new_key = head_mapping[key]
new_state_dict[new_key] = value
# Handle quant keys using the mapping
elif key in quant_mapping:
new_key = quant_mapping[key]
new_state_dict[new_key] = value
# Handle encoder conv1
elif key == "encoder.conv1.weight":
new_state_dict["encoder.conv_in.weight"] = value
elif key == "encoder.conv1.bias":
new_state_dict["encoder.conv_in.bias"] = value
# Handle decoder conv1
elif key == "decoder.conv1.weight":
new_state_dict["decoder.conv_in.weight"] = value
elif key == "decoder.conv1.bias":
new_state_dict["decoder.conv_in.bias"] = value
# Handle encoder downsamples
elif key.startswith("encoder.downsamples."):
# Convert to down_blocks
new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")

# Convert residual block naming but keep the original structure
if ".residual.0.gamma" in new_key:
new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
elif ".residual.2.bias" in new_key:
new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
elif ".residual.2.weight" in new_key:
new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
elif ".residual.3.gamma" in new_key:
new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
elif ".residual.6.bias" in new_key:
new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
elif ".residual.6.weight" in new_key:
new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
elif ".shortcut.bias" in new_key:
new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
elif ".shortcut.weight" in new_key:
new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")

new_state_dict[new_key] = value

# Handle decoder upsamples
elif key.startswith("decoder.upsamples."):
# Convert to up_blocks
parts = key.split(".")
block_idx = int(parts[2])
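            # The original decoder keeps everything in one flat "upsamples" list:
            # indices 0-2, 4-6, 8-10, and 12-14 hold the residual blocks of
            # up_blocks 0-3, while indices 3, 7, and 11 are the upsamplers.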

# Group residual blocks
if "residual" in key:
if block_idx in [0, 1, 2]:
new_block_idx = 0
resnet_idx = block_idx
elif block_idx in [4, 5, 6]:
new_block_idx = 1
resnet_idx = block_idx - 4
elif block_idx in [8, 9, 10]:
new_block_idx = 2
resnet_idx = block_idx - 8
elif block_idx in [12, 13, 14]:
new_block_idx = 3
resnet_idx = block_idx - 12
else:
# Keep as is for other blocks
new_state_dict[key] = value
continue

# Convert residual block naming
if ".residual.0.gamma" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
elif ".residual.2.bias" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
elif ".residual.2.weight" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
elif ".residual.3.gamma" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
elif ".residual.6.bias" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
elif ".residual.6.weight" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
else:
new_key = key

new_state_dict[new_key] = value

# Handle shortcut connections
elif ".shortcut." in key:
if block_idx == 4:
new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
else:
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
new_key = new_key.replace(".shortcut.", ".conv_shortcut.")

new_state_dict[new_key] = value

# Handle upsamplers
elif ".resample." in key or ".time_conv." in key:
if block_idx == 3:
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
elif block_idx == 7:
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
elif block_idx == 11:
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
else:
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")

new_state_dict[new_key] = value
else:
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
new_state_dict[new_key] = value
else:
# Keep other keys unchanged
new_state_dict[key] = value

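    # Instantiate the VAE on the meta device so no memory is spent on random
    # initialization, then materialize the converted weights via `assign=True`.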
with init_empty_weights():
vae = AutoencoderKLWan()
vae.load_state_dict(new_state_dict, strict=True, assign=True)
return vae

def get_transformer_config() -> Dict[str, Any]:
config = {
"diffusers_config": {
"hidden_size": 4096,
"in_channels": 16,
"heads_num": 32,
"mm_double_blocks_depth": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"rope_dim_list": [16, 56, 56],
"text_states_dim": 4096,
"rope_type": "rope",
"dit_modulation_type": "wanx",
"unpatchify_new": True,
"rope_theta": 10000,
},
}
return config


def convert_transformer(ckpt_path: str):
    checkpoint = torch.load(ckpt_path, weights_only=True)
    # Training checkpoints may wrap the weights under a "model" key.
    if "model" in checkpoint:
        original_state_dict = checkpoint["model"]
    else:
        original_state_dict = checkpoint

    config = get_transformer_config()
    with init_empty_weights():
        transformer = JoyImageEditTransformer3DModel(**config["diffusers_config"])
    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
    parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to the original Qwen3-VL text encoder")
    parser.add_argument(
        "--tokenizer_path", type=str, default=None, help="Path to the tokenizer; defaults to --text_encoder_path"
    )
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the converted models in.")
    parser.add_argument("--flow_shift", type=float, default=1.5, help="Shift value for the flow-match scheduler.")
return parser.parse_args()

DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}


if __name__ == "__main__":
args = get_args()
transformer = None
vae = None
dtype = DTYPE_MAPPING[args.dtype]

    if args.save_pipeline:
        assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
        assert args.text_encoder_path is not None
if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_ckpt_path)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.vae_ckpt_path is not None:
vae = convert_vae(args.vae_ckpt_path)
vae = vae.to(dtype=dtype)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.save_pipeline:
        processor = AutoProcessor.from_pretrained(args.text_encoder_path)
        text_encoder = Qwen3VLForConditionalGeneration.from_pretrained(
            args.text_encoder_path, torch_dtype=torch.bfloat16
        ).to("cuda")
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path or args.text_encoder_path)
        flow_shift = args.flow_shift
        scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=flow_shift)
transformer = transformer.to("cuda")
vae = vae.to("cuda")
pipe = JoyImageEditPipeline(
processor=processor,
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
).to("cuda")
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
processor.save_pretrained(f"{args.output_path}/processor")
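# A minimal smoke test for the converted pipeline. This is an illustrative sketch:
# the call signature and file names below are assumptions, not part of this script.
#
#   import torch
#   from diffusers import JoyImageEditPipeline
#   from diffusers.utils import load_image
#
#   pipe = JoyImageEditPipeline.from_pretrained("./joyimage-edit-diffusers", torch_dtype=torch.bfloat16).to("cuda")
#   source = load_image("input.png")
#   edited = pipe(image=source, prompt="make the sky a sunset").images[0]
#   edited.save("edited.png")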
1 change: 1 addition & 0 deletions setup.py
@@ -99,6 +99,7 @@
"accelerate>=0.31.0",
"compel==0.1.8",
"datasets",
"einops",
"filelock",
"flax>=0.4.1",
"ftfy",
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -237,6 +237,7 @@
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
"HunyuanImageTransformer2DModel",
"JoyImageEditTransformer3DModel",
"HunyuanVideo15Transformer3DModel",
"HunyuanVideoFramepackTransformer3DModel",
"HunyuanVideoTransformer3DModel",
@@ -596,6 +597,7 @@
"LTXLatentUpsamplePipeline",
"LTXPipeline",
"LucyEditPipeline",
"JoyImageEditPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
"LuminaPipeline",
@@ -1025,6 +1027,7 @@
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
HunyuanImageTransformer2DModel,
JoyImageEditTransformer3DModel,
HunyuanVideo15Transformer3DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
@@ -1359,6 +1362,7 @@
LTXLatentUpsamplePipeline,
LTXPipeline,
LucyEditPipeline,
JoyImageEditPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,
LuminaPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -110,6 +110,7 @@
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
_import_structure["transformers.transformer_joyimage"] = ["JoyImageEditTransformer3DModel", "JoyImageTransformer3DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
@@ -225,6 +226,7 @@
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,
JoyImageEditTransformer3DModel,
HunyuanVideo15Transformer3DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/__init__.py
@@ -34,6 +34,7 @@
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
from .transformer_joyimage import JoyImageEditTransformer3DModel, JoyImageTransformer3DModel
from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_longcat_image import LongCatImageTransformer2DModel
from .transformer_ltx import LTXVideoTransformer3DModel