# Suffix renames shared by the encoder down-block and decoder up-block paths:
# the original sequential "residual.N" indices map onto named norm/conv modules.
_RESIDUAL_SUFFIX_RENAMES = (
    (".residual.0.gamma", "norm1.gamma"),
    (".residual.2.bias", "conv1.bias"),
    (".residual.2.weight", "conv1.weight"),
    (".residual.3.gamma", "norm2.gamma"),
    (".residual.6.bias", "conv2.bias"),
    (".residual.6.weight", "conv2.weight"),
)


# This code is modified from convert_wan_to_diffusers.py to support an input ckpt path
def convert_vae(vae_ckpt_path):
    """Remap an original JoyImage VAE checkpoint into the diffusers ``AutoencoderKLWan`` layout.

    Args:
        vae_ckpt_path: Path to the original VAE checkpoint file.

    Returns:
        ``AutoencoderKLWan`` with the converted weights loaded (strict, assigned in place).
    """
    # map_location="cpu" keeps conversion GPU-free even for checkpoints saved from CUDA;
    # the caller moves the model to the target device afterwards.
    old_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location="cpu")
    new_state_dict = {}

    # Explicit per-key mappings for the encoder/decoder middle blocks.
    middle_key_mapping = {
        # Encoder middle block
        "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
        "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
        "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
        "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
        "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
        "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
        "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
        "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
        "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
        "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
        "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
        "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
        # Decoder middle block
        "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
        "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
        "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
        "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
        "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
        "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
        "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
        "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
        "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
        "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
        "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
        "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
    }

    # Mid-block attention modules.
    attention_mapping = {
        # Encoder middle attention
        "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
        "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
        "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
        "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
        "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
        # Decoder middle attention
        "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
        "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
        "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
        "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
        "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
    }

    # Output heads.
    head_mapping = {
        # Encoder head
        "encoder.head.0.gamma": "encoder.norm_out.gamma",
        "encoder.head.2.bias": "encoder.conv_out.bias",
        "encoder.head.2.weight": "encoder.conv_out.weight",
        # Decoder head
        "decoder.head.0.gamma": "decoder.norm_out.gamma",
        "decoder.head.2.bias": "decoder.conv_out.bias",
        "decoder.head.2.weight": "decoder.conv_out.weight",
    }

    # Latent quantization convolutions.
    quant_mapping = {
        "conv1.weight": "quant_conv.weight",
        "conv1.bias": "quant_conv.bias",
        "conv2.weight": "post_quant_conv.weight",
        "conv2.bias": "post_quant_conv.bias",
    }

    for key, value in old_state_dict.items():
        if key in middle_key_mapping:
            new_state_dict[middle_key_mapping[key]] = value
        elif key in attention_mapping:
            new_state_dict[attention_mapping[key]] = value
        elif key in head_mapping:
            new_state_dict[head_mapping[key]] = value
        elif key in quant_mapping:
            new_state_dict[quant_mapping[key]] = value
        # Encoder/decoder input convolutions.
        elif key == "encoder.conv1.weight":
            new_state_dict["encoder.conv_in.weight"] = value
        elif key == "encoder.conv1.bias":
            new_state_dict["encoder.conv_in.bias"] = value
        elif key == "decoder.conv1.weight":
            new_state_dict["decoder.conv_in.weight"] = value
        elif key == "decoder.conv1.bias":
            new_state_dict["decoder.conv_in.bias"] = value
        elif key.startswith("encoder.downsamples."):
            # Convert to down_blocks, keeping the original sequential structure.
            new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
            for old_suffix, new_suffix in _RESIDUAL_SUFFIX_RENAMES:
                if old_suffix in new_key:
                    new_key = new_key.replace(old_suffix, "." + new_suffix)
                    break
            else:
                if ".shortcut.bias" in new_key:
                    new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
                elif ".shortcut.weight" in new_key:
                    new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
            new_state_dict[new_key] = value
        elif key.startswith("decoder.upsamples."):
            # Convert to up_blocks. Flat indices group as 3 resnets per up_block,
            # with indices 3/7/11 reserved for the upsamplers in between.
            block_idx = int(key.split(".")[2])

            if "residual" in key:
                if block_idx in (0, 1, 2):
                    new_block_idx, resnet_idx = 0, block_idx
                elif block_idx in (4, 5, 6):
                    new_block_idx, resnet_idx = 1, block_idx - 4
                elif block_idx in (8, 9, 10):
                    new_block_idx, resnet_idx = 2, block_idx - 8
                elif block_idx in (12, 13, 14):
                    new_block_idx, resnet_idx = 3, block_idx - 12
                else:
                    # Keep as is for other blocks.
                    new_state_dict[key] = value
                    continue

                for old_suffix, new_suffix in _RESIDUAL_SUFFIX_RENAMES:
                    if old_suffix in key:
                        new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.{new_suffix}"
                        break
                else:
                    new_key = key
                new_state_dict[new_key] = value
            elif ".shortcut." in key:
                # Only up_block 1 (flat index 4) has a channel-changing shortcut.
                if block_idx == 4:
                    new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
                    new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                    new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
                new_state_dict[new_key] = value
            elif ".resample." in key or ".time_conv." in key:
                upsampler_map = {3: 0, 7: 1, 11: 2}
                if block_idx in upsampler_map:
                    new_key = key.replace(
                        f"decoder.upsamples.{block_idx}",
                        f"decoder.up_blocks.{upsampler_map[block_idx]}.upsamplers.0",
                    )
                else:
                    new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
                new_state_dict[new_key] = value
            else:
                new_state_dict[key.replace("decoder.upsamples.", "decoder.up_blocks.")] = value
        else:
            # Keep other keys unchanged.
            new_state_dict[key] = value

    with init_empty_weights():
        vae = AutoencoderKLWan()
    vae.load_state_dict(new_state_dict, strict=True, assign=True)
    return vae


def get_transformer_config() -> Dict[str, Any]:
    """Return the hard-coded diffusers config for the JoyImage-Edit transformer.

    Note: the previous annotation claimed ``Tuple[Dict[str, Any], ...]`` while a
    plain dict was returned; the annotation is corrected here.
    """
    return {
        "diffusers_config": {
            "hidden_size": 4096,
            "in_channels": 16,
            "heads_num": 32,
            "mm_double_blocks_depth": 40,
            "out_channels": 16,
            "patch_size": [1, 2, 2],
            "rope_dim_list": [16, 56, 56],
            "text_states_dim": 4096,
            "rope_type": "rope",
            "dit_modulation_type": "wanx",
            "unpatchify_new": True,
            "rope_theta": 10000,
        },
    }


def convert_transformer(ckpt_path: str):
    """Load an original transformer checkpoint into a ``JoyImageEditTransformer3DModel``.

    Args:
        ckpt_path: Path to the original transformer checkpoint.

    Returns:
        The model with the original weights assigned (strict load, meta-init shell).
    """
    checkpoint = torch.load(ckpt_path, weights_only=True, map_location="cpu")
    # Training checkpoints wrap the weights under a "model" key; raw state dicts do not.
    original_state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
    config = get_transformer_config()
    with init_empty_weights():
        transformer = JoyImageEditTransformer3DModel(**config["diffusers_config"])
    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer
def get_args():
    """Parse command-line arguments for the JoyImage-Edit conversion script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
    parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint")
    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer")
    parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
    # FIX: this flag used to be parsed with default 7.0 and then silently ignored in
    # favor of a hard-coded 1.5. It is now honored, with default 1.5 so that a run
    # without the flag behaves exactly as before.
    parser.add_argument("--flow_shift", type=float, default=1.5)
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

if __name__ == "__main__":
    args = get_args()
    transformer = None
    vae = None
    dtype = DTYPE_MAPPING[args.dtype]

    if args.save_pipeline:
        # A full pipeline export requires every component checkpoint up front.
        assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
        assert args.text_encoder_path is not None
    if args.transformer_ckpt_path is not None:
        transformer = convert_transformer(args.transformer_ckpt_path).to(dtype=dtype)
        if not args.save_pipeline:
            transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
    if args.vae_ckpt_path is not None:
        vae = convert_vae(args.vae_ckpt_path).to(dtype=dtype)
        if not args.save_pipeline:
            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
    if args.save_pipeline:
        processor = AutoProcessor.from_pretrained(args.text_encoder_path)
        text_encoder = Qwen3VLForConditionalGeneration.from_pretrained(
            args.text_encoder_path, torch_dtype=torch.bfloat16
        ).to("cuda")
        # FIX: --tokenizer_path was accepted but never used; fall back to the
        # text-encoder repo when no dedicated tokenizer path is given.
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path or args.text_encoder_path)
        scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.flow_shift)
        # NOTE(review): pipeline assembly assumes a CUDA device is available — confirm
        # whether CPU-only conversion should be supported here.
        transformer = transformer.to("cuda")
        vae = vae.to("cuda")
        pipe = JoyImageEditPipeline(
            processor=processor,
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            scheduler=scheduler,
        ).to("cuda")
        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
        processor.save_pretrained(f"{args.output_path}/processor")
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] + _import_structure["transformers.transformer_joyimage"] = ["JoyImageEditTransformer3DModel", "JoyImageTransformer3DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] @@ -225,6 +226,7 @@ HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, + JoyImageEditTransformer3DModel, HunyuanVideo15Transformer3DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 45157ee91808..1fc78c124618 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -34,6 +34,7 @@ from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_hunyuanimage import HunyuanImageTransformer2DModel + from .transformer_joyimage import JoyImageEditTransformer3DModel, JoyImageTransformer3DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_longcat_image import LongCatImageTransformer2DModel from .transformer_ltx import LTXVideoTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_joyimage.py b/src/diffusers/models/transformers/transformer_joyimage.py new file mode 100644 index 000000000000..c29284b3d8f7 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_joyimage.py @@ -0,0 +1,658 @@ +import math +from types import SimpleNamespace + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from ...configuration_utils import 
ConfigMixin, register_to_config +from ..attention import FeedForward +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin + +ATTN_BACKEND = 'sdpa' +try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func + ATTN_BACKEND = 'flash_attn' +except: + pass + +def _to_tuple(x, dim=2): + if isinstance(x, int): + return (x,) * dim + if len(x) == dim: + return tuple(x) + raise ValueError(f"Expected length {dim} or int, but got {x}") + +def get_meshgrid_nd(start, *args, dim=2): + if len(args) == 0: + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + elif len(args) == 1: + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = [stop[i] - start[i] for i in range(dim)] + elif len(args) == 2: + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = _to_tuple(args[1], dim=dim) + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") + return torch.stack(grid, dim=0) + +def reshape_for_broadcast(freqs_cis, x: torch.Tensor, head_first: bool = False): + ndim = x.ndim + assert 0 <= 1 < ndim + + if isinstance(freqs_cis, tuple): + if head_first: + assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]) + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + + if head_first: + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + 
else: + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def rotate_half(x: torch.Tensor): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis, head_first: bool = False): + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) + cos, sin = cos.to(xq.device), sin.to(xq.device) + xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) + xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) + return xq_out, xk_out + + +def get_1d_rotary_pos_embed( + dim: int, + pos, + theta: float = 10000.0, + use_real: bool = False, + theta_rescale_factor: float = 1.0, + interpolation_factor: float = 1.0, +): + if isinstance(pos, int): + pos = torch.arange(pos).float() + + if theta_rescale_factor != 1.0: + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim)) + freqs = torch.outer(pos.float() * interpolation_factor, freqs) + + if use_real: + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) + return freqs_cos, freqs_sin + + return torch.polar(torch.ones_like(freqs), freqs) + + +def get_nd_rotary_pos_embed( + rope_dim_list, + start, + *args, + theta=10000.0, + use_real=False, + txt_rope_size=None, + theta_rescale_factor=1.0, + interpolation_factor=1.0, +): + rope_dim_list = list(rope_dim_list) + grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) + + if isinstance(theta_rescale_factor, (int, float)): + theta_rescale_factor = [float(theta_rescale_factor)] * len(rope_dim_list) + elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: + theta_rescale_factor = [float(theta_rescale_factor[0])] * 
len(rope_dim_list) + + if isinstance(interpolation_factor, (int, float)): + interpolation_factor = [float(interpolation_factor)] * len(rope_dim_list) + elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: + interpolation_factor = [float(interpolation_factor[0])] * len(rope_dim_list) + + embs = [] + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed( + rope_dim_list[i], + grid[i].reshape(-1), + theta, + use_real=use_real, + theta_rescale_factor=theta_rescale_factor[i], + interpolation_factor=interpolation_factor[i], + ) + embs.append(emb) + + if use_real: + vis_emb = (torch.cat([emb[0] for emb in embs], dim=1), torch.cat([emb[1] for emb in embs], dim=1)) + else: + vis_emb = torch.cat(embs, dim=1) + + if txt_rope_size is None: + return vis_emb, None + + embs_txt = [] + vis_max_ids = grid.view(-1).max().item() + grid_txt = torch.arange(txt_rope_size) + vis_max_ids + 1 + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed( + rope_dim_list[i], + grid_txt, + theta, + use_real=use_real, + theta_rescale_factor=theta_rescale_factor[i], + interpolation_factor=interpolation_factor[i], + ) + embs_txt.append(emb) + + if use_real: + txt_emb = (torch.cat([emb[0] for emb in embs_txt], dim=1), torch.cat([emb[1] for emb in embs_txt], dim=1)) + else: + txt_emb = torch.cat(embs_txt, dim=1) + + return vis_emb, txt_emb + +def get_cu_seqlens(text_mask, img_len): + """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len + + Args: + text_mask (torch.Tensor): the mask of text + img_len (int): the length of image + + Returns: + torch.Tensor: the calculated cu_seqlens for flash attention + """ + batch_size = text_mask.shape[0] + text_len = text_mask.sum(dim=1) + max_len = text_mask.shape[1] + img_len + + cu_seqlens = torch.zeros([2 * batch_size + 1], + dtype=torch.int32, device="cuda") + + for i in range(batch_size): + s = text_len[i] + img_len + s1 = i * max_len + s + s2 = (i + 1) * max_len + cu_seqlens[2 * i + 1] = s1 
+ cu_seqlens[2 * i + 2] = s2 + + return cu_seqlens + +def load_modulation(modulate_type: str, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + if modulate_type == "wanx": + return ModulateWan(hidden_size, factor, **factory_kwargs) + if modulate_type == "adaLN": + return ModulateDiT(hidden_size, factor, act_layer, **factory_kwargs) + if modulate_type == "jdx": + return ModulateX(hidden_size, factor, **factory_kwargs) + raise ValueError(f"Unknown modulation type: {modulate_type}.") + + +class ModulateDiT(nn.Module): + def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.factor = factor + self.act = act_layer() + self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + def forward(self, x: torch.Tensor): + return self.linear(self.act(x)).chunk(self.factor, dim=-1) + + +class ModulateWan(nn.Module): + def __init__(self, hidden_size: int, factor: int, dtype=None, device=None): + super().__init__() + self.factor = factor + self.modulate_table = nn.Parameter( + torch.zeros(1, factor, hidden_size, dtype=dtype, device=device) / hidden_size**0.5, requires_grad=True + ) + + def forward(self, x: torch.Tensor): + if len(x.shape) != 3: + x = x.unsqueeze(1) + return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)] + + +class ModulateX(nn.Module): + def __init__(self, hidden_size: int, factor: int, dtype=None, device=None): + super().__init__() + self.factor = factor + + def forward(self, x: torch.Tensor): + if len(x.shape) != 3: + x = x.unsqueeze(1) + return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)] + + +def modulate(x, shift=None, scale=None): + if scale is None and shift is None: + return x + if shift is None: + return x * (1 + 
def modulate(x, shift=None, scale=None):
    """Apply AdaLN-style modulation ``x * (1 + scale) + shift`` along the sequence dim.

    Either term may be ``None``, in which case it is skipped; with both ``None``,
    ``x`` is returned untouched.
    """
    if shift is None and scale is None:
        return x
    out = x
    if scale is not None:
        out = out * (1 + scale.unsqueeze(1))
    if shift is not None:
        out = out + shift.unsqueeze(1)
    return out


def apply_gate(x, gate=None, tanh=False):
    """Scale ``x`` by a per-sample gate (optionally squashed through tanh)."""
    if gate is None:
        return x
    g = gate.unsqueeze(1)
    if tanh:
        g = g.tanh()
    return x * g


def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_kwargs=None):
    """Joint attention over ``[batch, seq, heads, head_dim]`` tensors.

    Dispatches on the module-level ``ATTN_BACKEND``: PyTorch SDPA (with a dense
    mask) or varlen flash-attention (with cumulative sequence lengths). The
    output keeps the ``[batch, seq, heads, head_dim]`` layout.
    """
    batch_size = q.shape[0]
    if ATTN_BACKEND == 'sdpa':
        # SDPA expects [b, h, l, c]; transpose in, transpose back out.
        out = torch.nn.functional.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            attn_mask=attn_kwargs['attn_mask'],
        )
        return out.transpose(1, 2)
    elif ATTN_BACKEND == 'flash_attn':
        # Varlen kernel consumes packed [(b*s), h, c] tensors plus cu_seqlens.
        packed = flash_attn_varlen_func(
            q.view(q.shape[0] * q.shape[1], *q.shape[2:]),
            k.view(k.shape[0] * k.shape[1], *k.shape[2:]),
            v.view(v.shape[0] * v.shape[1], *v.shape[2:]),
            attn_kwargs['cu_seqlens_q'],
            attn_kwargs['cu_seqlens_kv'],
            attn_kwargs['max_seqlen_q'],
            attn_kwargs['max_seqlen_kv'],
        )
        # Unpack back to [b, s, h, c].
        return packed.view(batch_size, attn_kwargs['max_seqlen_q'], packed.shape[-2], packed.shape[-1])


class RMSNorm(nn.Module):
    """Root-mean-square LayerNorm with an optional learnable scale."""

    def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        # The weight attribute only exists in the affine variant; `forward`
        # checks for its presence instead of storing a flag.
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))

    def _norm(self, x):
        # Normalize by the RMS over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Compute in float32 for stability, then cast back to the input dtype.
        y = self._norm(x.float()).type_as(x)
        return y * self.weight if hasattr(self, "weight") else y


class MMDoubleStreamBlock(nn.Module):
    """Dual-stream (image/text) transformer block with one joint attention pass.

    Each stream is modulated and projected to QKV separately; queries/keys/values
    are concatenated as [image ; text] for joint attention, then routed back
    through per-stream output projections and MLPs with gated residuals.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float,
        dtype=None,
        device=None,
        dit_modulation_type: str = "wanx",
        attn_backend: str = "torch_spda",
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attn_backend = attn_backend
        self.dit_modulation_type = dit_modulation_type
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        # --- image stream ---
        self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
        self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")

        # --- text stream ---
        self.txt_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
        self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")

    def forward(self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor,
                vis_freqs_cis=None, txt_freqs_cis=None, attn_kwargs=None):
        # Six modulation terms per stream: (shift, scale, gate) for attn and MLP.
        i_shift1, i_scale1, i_gate1, i_shift2, i_scale2, i_gate2 = self.img_mod(vec)
        t_shift1, t_scale1, t_gate1, t_shift2, t_scale2, t_gate2 = self.txt_mod(vec)

        # Image stream: modulate, project to QKV, normalize q/k, apply RoPE.
        img_qkv = self.img_attn_qkv(modulate(self.img_norm1(img), shift=i_shift1, scale=i_scale1))
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        img_q = self.img_attn_q_norm(img_q).to(img_v)
        img_k = self.img_attn_k_norm(img_k).to(img_v)
        if vis_freqs_cis is not None:
            img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)

        # Text stream: same treatment with its own modulation and frequencies.
        txt_qkv = self.txt_attn_qkv(modulate(self.txt_norm1(txt), shift=t_shift1, scale=t_scale1))
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
        if txt_freqs_cis is not None:
            txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)

        # Joint attention over the concatenated [image ; text] sequence.
        joint = attention(
            torch.cat((img_q, txt_q), dim=1),
            torch.cat((img_k, txt_k), dim=1),
            torch.cat((img_v, txt_v), dim=1),
            attn_kwargs=attn_kwargs,
        ).flatten(2, 3)
        n_img = img.shape[1]
        img_attn, txt_attn = joint[:, :n_img], joint[:, n_img:]

        # Gated residual updates: attention output, then modulated MLP, per stream.
        img = img + apply_gate(self.img_attn_proj(img_attn), gate=i_gate1)
        img = img + apply_gate(
            self.img_mlp(modulate(self.img_norm2(img), shift=i_shift2, scale=i_scale2)),
            gate=i_gate2,
        )
        txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=t_gate1)
        txt = txt + apply_gate(
            self.txt_mlp(modulate(self.txt_norm2(txt), shift=t_shift2, scale=t_scale2)),
            gate=t_gate2,
        )

        return img, txt
class WanTimeTextImageEmbedding(nn.Module):
    """Embeds the diffusion timestep and projects text states for conditioning.

    Produces the raw timestep embedding, its modulation projection
    (``time_proj_dim`` wide), and the projected text hidden states.
    """

    def __init__(self, dim: int, time_freq_dim: int, time_proj_dim: int, text_embed_dim: int):
        super().__init__()
        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")

    def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
        freqs = self.timesteps_proj(timestep)
        # Cast the sinusoidal features to the embedder's parameter dtype
        # (int8 excluded: quantized weights do not define the compute dtype).
        target_dtype = next(iter(self.time_embedder.parameters())).dtype
        if freqs.dtype != target_dtype and target_dtype != torch.int8:
            freqs = freqs.to(target_dtype)
        temb = self.time_embedder(freqs).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))
        return temb, timestep_proj, self.text_embedder(encoder_hidden_states)
tuple(patch_size) + self.hidden_size = hidden_size + self.heads_num = heads_num + self.rope_dim_list = tuple(rope_dim_list) + self.dit_modulation_type = dit_modulation_type + self.mm_double_blocks_depth = mm_double_blocks_depth + self.attn_backend = attn_backend + self.rope_type = rope_type + self.unpatchify_new = unpatchify_new + self.theta = rope_theta + + if hidden_size % heads_num != 0: + raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}") + + self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=self.patch_size, stride=self.patch_size) + + self.condition_embedder = WanTimeTextImageEmbedding( + dim=hidden_size, + time_freq_dim=256, + time_proj_dim=hidden_size * 6 if dit_modulation_type != "adaLN" else hidden_size, + text_embed_dim=text_states_dim, + ) + + self.double_blocks = nn.ModuleList( + [ + MMDoubleStreamBlock( + hidden_size=self.hidden_size, + heads_num=self.heads_num, + mlp_width_ratio=mlp_width_ratio, + dit_modulation_type=self.dit_modulation_type, + attn_backend=attn_backend, + ) + for _ in range(mm_double_blocks_depth) + ] + ) + + self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(hidden_size, out_channels * math.prod(self.patch_size)) + + if self.args.is_repa: + self.repa_proj = nn.Linear(hidden_size, text_states_dim) + if self.args.repa_layer > mm_double_blocks_depth: + raise ValueError("repa_layer should be smaller than total depth") + + self.gradient_checkpointing = enable_activation_checkpointing + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def get_rotary_pos_embed(self, vis_rope_size, txt_rope_size=None): + target_ndim = 3 + if len(vis_rope_size) != target_ndim: + vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + list(vis_rope_size) + + head_dim = self.hidden_size // self.heads_num + rope_dim_list = list(self.rope_dim_list) + if rope_dim_list is None: + rope_dim_list = [head_dim // 
target_ndim for _ in range(target_ndim)] + if sum(rope_dim_list) != head_dim: + raise ValueError("sum(rope_dim_list) should equal head_dim") + + return get_nd_rotary_pos_embed( + rope_dim_list, + vis_rope_size, + txt_rope_size=txt_rope_size, + theta=self.theta, + use_real=True, + theta_rescale_factor=1, + ) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + encoder_hidden_states_mask: torch.Tensor = None, + return_dict: bool = True, + ): + if encoder_hidden_states is None: + raise ValueError("encoder_hidden_states is required.") + + is_multi_item = len(hidden_states.shape) == 6 + num_items = 0 + if is_multi_item: + num_items = hidden_states.shape[1] + if num_items > 1: + if self.patch_size[0] != 1: + raise ValueError("For multi-item input, patch_size[0] must be 1") + hidden_states = torch.cat([hidden_states[:, -1:], hidden_states[:, :-1]], dim=1) + hidden_states = rearrange(hidden_states, "b n c t h w -> b c (n t) h w") + + _, _, ot, oh, ow = hidden_states.shape + tt, th, tw = ( + ot // self.patch_size[0], + oh // self.patch_size[1], + ow // self.patch_size[2], + ) + + if encoder_hidden_states_mask is None: + encoder_hidden_states_mask = torch.ones( + (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]), + dtype=torch.bool, + device=encoder_hidden_states.device, + ) + + img = self.img_in(hidden_states).flatten(2).transpose(1, 2) + _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states) + if vec.shape[-1] > self.hidden_size: + vec = vec.unflatten(1, (6, -1)) + + txt_seq_len = txt.shape[1] + vis_freqs_cis, txt_freqs_cis = self.get_rotary_pos_embed( + vis_rope_size=(tt, th, tw), + txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None, + ) + + txt_seq_len = txt.shape[1] + img_seq_len = img.shape[1] + + cu_seqlens_q = get_cu_seqlens( + encoder_hidden_states_mask, img_seq_len) + cu_seqlens_kv = cu_seqlens_q + max_seqlen_q = img_seq_len + txt_seq_len + 
max_seqlen_kv = max_seqlen_q + + + attn_kwargs = {"encoder_hidden_states_mask": encoder_hidden_states_mask} + attn_kwargs.update({ + 'cu_seqlens_q': cu_seqlens_q, + 'cu_seqlens_kv': cu_seqlens_kv, + 'max_seqlen_q': max_seqlen_q, + 'max_seqlen_kv': max_seqlen_kv, + }) + + max_seqlen_q = img_seq_len + txt_seq_len + seq_lens = encoder_hidden_states_mask.sum(dim=1) + img_seq_len + max_len = encoder_hidden_states_mask.shape[1] + img_seq_len + assert max_seqlen_q == max_len + positions = torch.arange(max_seqlen_q, device=img.device).unsqueeze(0) + seq_lens_expanded = seq_lens.unsqueeze(1) + mask = positions < seq_lens_expanded + mask = mask.unsqueeze(1).unsqueeze(2) + attn_mask = mask & mask.transpose(-1, -2) + + attn_kwargs.update({'attn_mask': attn_mask, 'max_seqlen_q': max_seqlen_q}) + + img_hidden_states = [] + for block in self.double_blocks: + img, txt = block(img, txt, vec, vis_freqs_cis, txt_freqs_cis, attn_kwargs) + img_hidden_states.append(img) + + img_len = img.shape[1] + x = torch.cat((img, txt), 1) + img = x[:, :img_len, ...] 
+ + img = self.proj_out(self.norm_out(img)) + img = self.unpatchify(img, tt, th, tw) + + repa_hidden_state = None + if self.args.is_repa: + repa_hidden_state = self.repa_proj(img_hidden_states[self.args.repa_layer]) + repa_hidden_state = repa_hidden_state.view(img.shape[0], tt, th, tw, -1) + + if is_multi_item: + img = rearrange(img, "b c (n t) h w -> b n c t h w", n=num_items) + if num_items > 1: + img = torch.cat([img[:, 1:], img[:, :1]], dim=1) + if repa_hidden_state is not None: + repa_hidden_state = rearrange(repa_hidden_state, "b (n t) h w c -> b n t h w c", n=num_items) + + if not return_dict: + return (img, txt, repa_hidden_state) + + return Transformer2DModelOutput(sample=img) + + def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int): + c = self.out_channels + pt, ph, pw = self.patch_size + if t * h * w != x.shape[1]: + raise ValueError("Invalid token length for unpatchify.") + + if self.unpatchify_new: + x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c)) + x = torch.einsum("nthwopqc->nctohpwq", x) + else: + x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw)) + x = torch.einsum("nthwcopq->nctohpwq", x) + + return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + + +class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel): + """ + Backward-compatible alias of JoyImageTransformer3DModel. 
+ """ + + pass \ No newline at end of file diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3dafb56fdd65..d9159be6156b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -299,6 +299,7 @@ "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", ] + _import_structure["joyimage"] = ["JoyImageEditPipeline"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] @@ -740,6 +741,7 @@ ) from .ltx2 import LTX2ConditionPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline from .lucy import LucyEditPipeline + from .joyimage import JoyImageEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline from .marigold import ( diff --git a/src/diffusers/pipelines/joyimage/__init__.py b/src/diffusers/pipelines/joyimage/__init__.py new file mode 100644 index 000000000000..a6d5f31fe63c --- /dev/null +++ b/src/diffusers/pipelines/joyimage/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa: F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_output"] = ["JoyImageEditPipelineOutput"] + _import_structure["pipeline_joyimage_edit"] = ["JoyImageEditPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not 
(is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_joyimage_edit import JoyImageEditPipeline + from .pipeline_output import JoyImageEditPipelineOutput +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) \ No newline at end of file diff --git a/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py new file mode 100644 index 000000000000..190cb3d90c31 --- /dev/null +++ b/src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py @@ -0,0 +1,1187 @@ +import inspect +import math +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torchvision.transforms.functional as TF +from einops import rearrange +from PIL import Image +from transformers import AutoProcessor, Qwen2Tokenizer, Qwen3VLForConditionalGeneration, Qwen3VLProcessor + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import AutoencoderKLWan, JoyImageEditTransformer3DModel +from ..pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import BaseOutput, deprecate, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import JoyImageEditPipelineOutput + + +EXAMPLE_DOC_STRING = """""" + +# Mapping from precision string to torch dtype. 
+PRECISION_TO_TYPE = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +class BucketGroup: + """Manages dynamic batch grouping buckets for image inference.""" + + def __init__( + self, + bucket_configs: list[tuple[int, int, int, int, int]], + prioritize_frame_matching: bool = True, + ): + """ + Initialize bucket group with predefined configurations. + + Args: + bucket_configs: List of (batch_size, num_items, num_frames, height, width) tuples. + prioritize_frame_matching: Unused, kept for API compatibility. + """ + self.bucket_configs = [tuple(b) for b in bucket_configs] + + def find_best_bucket(self, media_shape: tuple[int, int, int, int]) -> tuple[int, int, int, int, int]: + """ + Find the best matching bucket for given media dimensions. + + Selects the bucket whose aspect ratio (height/width) is closest to that of + the input media. Only image inference (num_frames=1) is supported. + + Args: + media_shape: (num_items, num_frames, height, width) of the input media. + + Returns: + Best matching bucket as (batch_size, num_items, num_frames, height, width). + + Raises: + ValueError: If num_frames != 1 or no valid bucket is found. + """ + num_items, num_frames, height, width = media_shape + target_aspect_ratio = height / width + + if num_frames != 1: + raise ValueError( + f"Only image inference (num_frames=1) is supported, got num_frames={num_frames}" + ) + + valid_buckets = [ + b for b in self.bucket_configs + if b[1] == num_items and b[2] == 1 + ] + if not valid_buckets: + raise ValueError(f"No image buckets found for shape {media_shape}") + + return min( + valid_buckets, + key=lambda bucket: abs((bucket[3] / bucket[4]) - target_aspect_ratio), + ) + + +def _get_text_encoder_ckpt( + text_encoder: Qwen3VLForConditionalGeneration, + fallback: str = "Qwen/Qwen3-VL-8B-Instruct", +) -> str: + """ + Retrieve the checkpoint identifier from the text encoder. + + Args: + text_encoder: The text encoder model instance. 
+ fallback: Default checkpoint name if none can be resolved. + + Returns: + A non-empty string identifying the checkpoint. + """ + candidates = [ + getattr(text_encoder, "name_or_path", None), + getattr(getattr(text_encoder, "config", None), "_name_or_path", None), + ] + for c in candidates: + if isinstance(c, str) and len(c) > 0: + return c + return fallback + + +def _generate_hw_buckets( + base_height: int = 256, + base_width: int = 256, + step_width: int = 16, + step_height: int = 16, + max_ratio: float = 4.0, +) -> list[tuple[int, int, int, int, int]]: + """ + Generate (batch_size=1, num_items=1, num_frames=1, height, width) bucket tuples + covering a range of aspect ratios while keeping total pixels close to + base_height * base_width. + + Args: + base_height: Reference height in pixels. + base_width: Reference width in pixels. + step_width: Width increment per step. + step_height: Height decrement per step. + max_ratio: Maximum allowed aspect ratio (long side / short side). + + Returns: + List of bucket tuples (1, 1, 1, height, width). + """ + buckets = [] + target_pixels = base_height * base_width + + height = target_pixels // step_width + width = step_width + + while height >= step_height: + if max(height, width) / min(height, width) <= max_ratio: + buckets.append((1, 1, 1, height, width)) + if height * (width + step_width) <= target_pixels: + width += step_width + else: + height -= step_height + + return buckets + + +def generate_video_image_bucket( + basesize: int = 256, + min_temporal: int = 65, + max_temporal: int = 129, + bs_img: int = 8, + bs_vid: int = 1, + bs_mimg: int = 4, + min_items: int = 1, + max_items: int = 1, +) -> list[list[int]]: + """ + Generate bucket configurations for image inference. + + Each bucket is represented as [batch_size, num_items, num_frames, height, width]. + Spatial dimensions are scaled by (basesize // 256) when basesize > 256. + + Args: + basesize: Base spatial resolution. Must be one of {256, 512, 768, 1024}. 
+ min_temporal: Unused; kept for API compatibility. + max_temporal: Unused; kept for API compatibility. + bs_img: Batch size for single-image buckets. + bs_vid: Unused; kept for API compatibility. + bs_mimg: Batch size for multi-image buckets. + min_items: Minimum number of items in multi-image buckets. + max_items: Maximum number of items in multi-image buckets. + + Returns: + List of bucket configs as [batch_size, num_items, num_frames, height, width]. + + Raises: + AssertionError: If basesize is not in {256, 512, 768, 1024}. + """ + assert basesize in [256, 512, 768, 1024], ( + f"[generate_video_image_bucket] unsupported basesize {basesize}" + ) + bucket_list = [] + base_bucket_list = _generate_hw_buckets() + + # Single-image buckets. + for _bucket in base_bucket_list: + bucket = list(_bucket) + bucket[0] = bs_img + bucket_list.append(bucket) + + # Multi-image buckets. + for num_items in range(min_items, max_items + 1): + for _bucket in base_bucket_list: + bucket = list(_bucket) + bucket[0] = bs_mimg + bucket[1] = num_items + bucket_list.append(bucket) + + # Scale spatial dimensions when basesize exceeds 256. + if basesize > 256: + ratio = basesize // 256 + + def _scale(bucket: list[int], r: int) -> list[int]: + bucket[-2] *= r + bucket[-1] *= r + return bucket + + bucket_list = [_scale(bucket, ratio) for bucket in bucket_list] + + return bucket_list + +def _resize_center_crop(img: Image.Image, target_size: Tuple[int, int]) -> Image.Image: + """Scale to cover target_size, then center-crop.""" + w, h = img.size # PIL uses (width, height). 
+ bh, bw = target_size + scale = max(bh / h, bw / w) + resize_h = math.ceil(h * scale) + resize_w = math.ceil(w * scale) + img = TF.resize(img, (resize_h, resize_w), interpolation=TF.InterpolationMode.BILINEAR, antialias=True) + img = TF.center_crop(img, target_size) + return img + +def _dynamic_resize_from_bucket(image_size: Tuple[int, int], basesize: int = 512) -> Tuple[int, int]: + """ + Resize and center-crop an image to the nearest bucket dimensions. + + The best-matching bucket is selected based on the image's aspect ratio. + The image is first scaled so that neither dimension is smaller than the + target, then center-cropped to the exact target size. + + Args: + image_size: Source image size as (width, height). + basesize: Base resolution used to generate candidate buckets. + + Returns: + Target (height, width) of the best-matching bucket. + """ + bucket_config = generate_video_image_bucket( + basesize=basesize, + min_temporal=56, + max_temporal=56, + bs_img=4, + bs_vid=4, + bs_mimg=8, + min_items=2, + max_items=2, + ) + bucket_group = BucketGroup(bucket_config) + src_w, src_h = image_size + bucket = bucket_group.find_best_bucket((1, 1, src_h, src_w)) + target_height, target_width = bucket[-2], bucket[-1] + return target_height, target_width + + + +def _build_args( + args: Any, + text_encoder: Qwen3VLForConditionalGeneration, +) -> Any: + """ + Return args unchanged if provided, otherwise construct a default namespace. + + Args: + args: Existing args object, or None. + text_encoder: Text encoder used to resolve the checkpoint path when args is None. + + Returns: + The original args object, or a SimpleNamespace with sensible defaults. 
+ """ + if args is not None: + return args + + text_encoder_ckpt = _get_text_encoder_ckpt(text_encoder) + return SimpleNamespace( + enable_multi_task_training=False, + text_token_max_length=2048, + dit_precision="bf16", + vae_precision="bf16", + text_encoder_arch_config={"params": {"text_encoder_ckpt": text_encoder_ckpt}}, + ) + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Configure the scheduler and return its timestep sequence. + + Exactly one of ``timesteps``, ``sigmas``, or ``num_inference_steps`` should be + provided to control the denoising schedule. + + Args: + scheduler: The diffusion scheduler. + num_inference_steps: Number of denoising steps (used when neither + ``timesteps`` nor ``sigmas`` is given). + device: Target device for the timestep tensor. + timesteps: Custom discrete timesteps. + sigmas: Custom sigma values (alternative to ``timesteps``). + **kwargs: Additional keyword arguments forwarded to ``set_timesteps``. + + Returns: + Tuple of (timesteps tensor, num_inference_steps int). + + Raises: + ValueError: If both ``timesteps`` and ``sigmas`` are provided, or if the + scheduler does not support the requested schedule parameterisation. 
+ """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed.") + + if timesteps is not None: + if "timesteps" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()): + raise ValueError(f"{scheduler.__class__} does not support custom timesteps.") + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + if "sigmas" not in set(inspect.signature(scheduler.set_timesteps).parameters.keys()): + raise ValueError(f"{scheduler.__class__} does not support custom sigmas.") + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + + return timesteps, num_inference_steps + + +@dataclass +class _LegacyPipelineOutput(BaseOutput): + """Legacy output dataclass retained for backward compatibility.""" + + videos: Union[torch.Tensor, np.ndarray] + + +class JoyImageEditPipeline(DiffusionPipeline): + """ + Diffusion pipeline for image editing using the JoyImage architecture. + + The pipeline encodes text and image conditioning via a Qwen3-VL text encoder, + denoises latents with a 3-D transformer, and decodes the result with a WAN VAE. + + Model offloading order: text_encoder -> transformer -> vae. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLWan, + text_encoder: Qwen3VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: JoyImageEditTransformer3DModel, + processor: Qwen3VLProcessor, + args: Any = None, + ): + """ + Initialise the pipeline and register all sub-modules. 
+ + Args: + scheduler: Noise scheduler for the denoising process. + vae: Variational autoencoder used for encoding / decoding latents. + text_encoder: Qwen3-VL multimodal language model for prompt encoding. + tokenizer: Tokenizer paired with the text encoder. + transformer: 3-D transformer denoising network. + processor: Qwen3-VL processor for multi-image prompt preparation. + args: Optional configuration namespace. Defaults are inferred when None. + """ + super().__init__() + self.args = _build_args(args=args, text_encoder=text_encoder) + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + processor=processor, + ) + + self.vae_scale_factor_temporal = ( + self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + ) + self.vae_scale_factor_spatial = ( + self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + text_encoder_ckpt = dict(self.args.text_encoder_arch_config.get("params", {})).get( + "text_encoder_ckpt", _get_text_encoder_ckpt(self.text_encoder) + ) + self.qwen_processor = ( + processor if processor is not None else AutoProcessor.from_pretrained(text_encoder_ckpt) + ) + + self.text_token_max_length = self.args.text_token_max_length + + # Prompt templates used when encoding text with / without image tokens. 
+ self.prompt_template_encode = { + "image": ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + ), + "multiple_images": ( + "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" + "{}<|im_start|>assistant\n" + ), + } + # Number of system-prompt tokens to drop from the beginning of hidden states. + self.prompt_template_encode_start_idx = { + "image": 34, + "multiple_images": 34, + } + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _extract_masked_hidden( + self, hidden_states: torch.Tensor, mask: torch.Tensor + ) -> tuple[torch.Tensor, ...]: + """ + Extract valid (non-padded) hidden states for each sequence in the batch. + + Args: + hidden_states: Shape (B, T, D). + mask: Binary attention mask of shape (B, T). + + Returns: + Tuple of tensors, one per batch element, each of shape (valid_T, D). + """ + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + return torch.split(selected, valid_lengths.tolist(), dim=0) + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + template_type: str = "image", + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Encode text prompts using the Qwen tokenizer (text-only path). + + Args: + prompt: A single prompt string or a list of prompt strings. + template_type: Key into ``prompt_template_encode`` / ``prompt_template_encode_start_idx``. + device: Target device. + dtype: Target floating-point dtype. 
+ + Returns: + Tuple of (prompt_embeds, encoder_attention_mask) where both tensors + have shape (B, max_seq_len, D) and (B, max_seq_len) respectively, + zero-padded to the same length. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + template = self.prompt_template_encode[template_type] + drop_idx = self.prompt_template_encode_start_idx[template_type] + + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + max_length=self.text_token_max_length + drop_idx, + padding=True, + truncation=True, + return_tensors="pt", + ).to(device) + + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + + # Drop system-prompt prefix tokens and re-pack into a padded batch. + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [ + torch.ones(e.size(0), dtype=torch.long, device=e.device) + for e in split_hidden_states + ] + + max_seq_len = min( + self.text_token_max_length, + max(u.size(0) for u in split_hidden_states), + max(u.size(0) for u in attn_mask_list), + ) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds, encoder_attention_mask + + def encode_prompt_multiple_images( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + images: Optional[torch.Tensor] = None, + template_type: Optional[str] = "multiple_images", + max_sequence_length: Optional[int] 
= None, + drop_vit_feature: Optional[float] = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Encode prompts that contain inline image tokens via the Qwen processor. + + ``\\n`` placeholders in each prompt string are replaced by the + Qwen vision special tokens before being fed to the multimodal encoder. + + Args: + prompt: Prompt string(s), optionally containing ``\\n`` tokens. + device: Target device. + images: Pixel tensors corresponding to the inline image tokens. + template_type: Must be ``"multiple_images"``. + max_sequence_length: If set, truncate the output to this length + (keeping the last ``max_sequence_length`` tokens). + drop_vit_feature: When True, drop all tokens up to and including the + last vision-end token so that only the text portion is returned. + + Returns: + Tuple of (prompt_embeds, prompt_embeds_mask). + """ + assert template_type == "multiple_images" + device = device or self._execution_device + template = self.prompt_template_encode[template_type] + drop_idx = self.prompt_template_encode_start_idx[template_type] + + prompt = [prompt] if isinstance(prompt, str) else prompt + + # If no image tokens are present, discard the image tensors. + if not any("\n" in p for p in prompt): + images = None + + prompt = [p.replace("\n", "<|vision_start|><|image_pad|><|vision_end|>") for p in prompt] + prompt = [template.format(p) for p in prompt] + + inputs = self.qwen_processor( + text=prompt, + images=images, + padding=True, + return_tensors="pt", + ).to(device) + + encoder_hidden_states = self.text_encoder(**inputs, output_hidden_states=True) + last_hidden_states = encoder_hidden_states.hidden_states[-1] + + if drop_vit_feature: + # Find the last vision-end token and drop everything before it. 
+ input_ids = inputs["input_ids"] + vlm_image_end_idx = torch.where(input_ids[0] == 151653)[0][-1] + drop_idx = vlm_image_end_idx + 1 + + prompt_embeds = last_hidden_states[:, drop_idx:] + prompt_embeds_mask = inputs["attention_mask"][:, drop_idx:] + + if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length: + prompt_embeds = prompt_embeds[:, -max_sequence_length:, :] + prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:] + + return prompt_embeds, prompt_embeds_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + images: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + template_type: str = "image", + drop_vit_feature: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Encode a text prompt (and optional inline images) into embeddings. + + When ``images`` is provided the multi-image encoding path is used; + otherwise the text-only Qwen tokenizer path is used. Pre-computed + ``prompt_embeds`` bypass encoding entirely. + + Args: + prompt: Prompt string or list of prompt strings. + images: Optional image tensors for multi-image conditioning. + device: Target device. + num_images_per_prompt: Number of outputs to generate per prompt. + prompt_embeds: Pre-computed prompt embeddings. + prompt_embeds_mask: Attention mask for pre-computed embeddings. + max_sequence_length: Maximum output sequence length. + template_type: Prompt template key (``"image"`` or ``"multiple_images"``). + drop_vit_feature: Drop vision tokens in the multi-image path. + + Returns: + Tuple of (prompt_embeds, prompt_embeds_mask). 
+ """ + if images is not None: + return self.encode_prompt_multiple_images( + prompt=prompt, + images=images, + device=device, + max_sequence_length=max_sequence_length, + drop_vit_feature=drop_vit_feature, + ) + + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt, template_type, device + ) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def decode_latents(self, latents: torch.Tensor, enable_tiling: bool = True) -> torch.Tensor: + """ + Decode latents to pixel values. + + .. deprecated:: 1.0.0 + Use the VAE directly instead of calling this method. + + Args: + latents: Latent tensor to decode. + enable_tiling: Whether to enable tiled decoding to save memory. + + Returns: + Float tensor of shape (..., H, W, C) with values in [0, 1]. + """ + deprecation_message = "The decode_latents method is deprecated." 
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + if enable_tiling: + self.vae.enable_tiling() + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + if image.ndim == 4: + image = image.cpu().permute(0, 2, 3, 1).float() + else: + image = image.cpu().float() + return image + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + ): + """ + Validate pipeline inputs before the forward pass. + + Raises: + ValueError: On any invalid combination of arguments. + """ + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError("`callback_on_step_end_tensor_inputs` has invalid keys.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError("Cannot forward both `prompt` and `prompt_embeds`.") + elif prompt is None and prompt_embeds is None: + raise ValueError("Provide either `prompt` or `prompt_embeds`.") + elif prompt is not None and not isinstance(prompt, (str, list)): + raise ValueError("`prompt` has to be of type `str` or `list`.") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError("Cannot forward both `negative_prompt` and `negative_prompt_embeds`.") + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError("If `prompt_embeds` are provided, `prompt_embeds_mask` is required.") + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` is required." 
+ ) + + def normalize_latents(self, latent: torch.Tensor) -> torch.Tensor: + """ + Normalise latents using per-channel statistics from the VAE config. + + Uses (latent - mean) / std when the VAE exposes ``latents_mean`` and + ``latents_std``; otherwise falls back to scaling by ``scaling_factor``. + + Args: + latent: Raw latent tensor from ``vae.encode``. + + Returns: + Normalised latent tensor. + """ + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to( + device=latent.device, dtype=latent.dtype + ) + latents_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to( + device=latent.device, dtype=latent.dtype + ) + latent = (latent - latents_mean) / latents_std + else: + latent = latent * self.vae.config.scaling_factor + return latent + + def denormalize_latents(self, latent: torch.Tensor) -> torch.Tensor: + """ + Invert :meth:`normalize_latents` to recover the original latent scale. + + Args: + latent: Normalised latent tensor. + + Returns: + Latent tensor in the scale expected by ``vae.decode``. 
+ """ + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to( + device=latent.device, dtype=latent.dtype + ) + latents_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to( + device=latent.device, dtype=latent.dtype + ) + latent = latent * latents_std + latents_mean + else: + latent = latent / self.vae.config.scaling_factor + return latent + + def prepare_latents( + self, + batch_size: int, + num_items: int, + num_channels_latents: int, + height: int, + width: int, + video_length: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + latents: Optional[torch.Tensor] = None, + reference_images: Optional[List[Image.Image]] = None, + enable_denormalization: bool = True, + ) -> torch.Tensor: + """ + Prepare the initial noisy latent tensor for the denoising loop. + + When ``reference_images`` is provided the first (num_items - 1) slots are + filled with VAE-encoded reference image latents; the last slot is random noise. + When ``latents`` is provided it is moved to ``device`` without modification. + Otherwise pure random noise is returned. + + Args: + batch_size: Number of samples in the batch. + num_items: Number of image slots (reference + target). + num_channels_latents: Latent channel dimension from the transformer config. + height: Spatial height in pixels. + width: Spatial width in pixels. + video_length: Number of frames (1 for image inference). + dtype: Floating-point dtype for the latent tensor. + device: Target device. + generator: RNG generator(s) for reproducible sampling. + latents: Optional pre-allocated latent tensor. + reference_images: Optional list of PIL images to encode as conditioning. + enable_denormalization: Whether to normalise encoded reference latents. + + Returns: + Latent tensor of shape (B, num_items, C, T, H', W'). 

        Raises:
            ValueError: If ``generator`` is a list whose length differs from ``batch_size``.
        """
        # Latent grid: temporal/spatial dims are downsampled by the VAE factors.
        shape = (
            batch_size,
            num_items,
            num_channels_latents,
            (video_length - 1) // self.vae_scale_factor_temporal + 1,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError("Generator list length must match batch size.")

        if latents is None:
            if reference_images is not None:
                # Encode reference images and concatenate with a noise slot.
                # Pixels move from [0, 255] to the [-1, 1] range the VAE expects.
                ref_img = [torch.from_numpy(np.array(x.convert("RGB"))) for x in reference_images]
                ref_img = torch.stack(ref_img).to(device=device, dtype=dtype)
                ref_img = ref_img / 127.5 - 1.0
                ref_img = rearrange(ref_img, "x h w c -> x c 1 h w")
                # NOTE(review): `.sample()` does not receive `generator`, so the
                # reference-latent draw is not seeded — confirm this is intended.
                ref_vae = self.vae.encode(ref_img).latent_dist.sample()
                # Despite the flag's name, this *normalizes* the freshly encoded
                # latents so they match the scale of the noise slot.
                if enable_denormalization:
                    ref_vae = self.normalize_latents(ref_vae)
                ref_vae = rearrange(ref_vae, "(b n) c 1 h w -> b n c 1 h w", n=(num_items - 1))
                noise = randn_tensor((shape[0], 1, *shape[2:]), generator=generator, device=device, dtype=dtype)
                latents = torch.cat([ref_vae, noise], dim=1)
            else:
                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            # Caller-provided latents are used as-is, only moved to the device.
            latents = latents.to(device)

        return latents

    # ------------------------------------------------------------------
    # Pipeline properties
    # ------------------------------------------------------------------

    @property
    def guidance_scale(self) -> float:
        """Classifier-free guidance scale used in the current forward pass."""
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self) -> bool:
        """True when guidance_scale > 1, enabling classifier-free guidance."""
        return self._guidance_scale > 1

    @property
    def num_timesteps(self) -> int:
        """Total number of denoising timesteps in the current forward pass."""
        return self._num_timesteps

    @property
    def 
interrupt(self) -> bool:
        """When True, the denoising loop is interrupted at the next step."""
        return self._interrupt

    # ------------------------------------------------------------------
    # Utility
    # ------------------------------------------------------------------

    def pad_sequence(self, x: torch.Tensor, target_length: int) -> torch.Tensor:
        """
        Truncate or zero-pad a sequence tensor along dimension 1.

        If the sequence is longer than ``target_length`` the last
        ``target_length`` elements are kept. If it is shorter, zero-padding
        is appended on the right.

        Args:
            x: Input tensor of shape (B, T, ...) or (B, T).
            target_length: Desired sequence length.

        Returns:
            Tensor of shape (B, target_length, ...) or (B, target_length).
        """
        current_length = x.shape[1]
        if current_length >= target_length:
            # Too long: keep the trailing `target_length` tokens.
            return x[:, -target_length:]
        padding_length = target_length - current_length
        if x.ndim >= 3:
            padding = torch.zeros((x.shape[0], padding_length, *x.shape[2:]), dtype=x.dtype, device=x.device)
        else:
            padding = torch.zeros((x.shape[0], padding_length), dtype=x.dtype, device=x.device)
        return torch.cat([x, padding], dim=1)

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image: PipelineImageInput | None = None,
        prompt: str | list[str] = None,
        height: int | None = None,
        width: int | None = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        guidance_scale: float = 4.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[
                Callable[[int, int, Dict], None],
                PipelineCallback,
                MultiPipelineCallbacks,
            ]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        enable_tiling: bool = False,
        max_sequence_length: int = 4096,
        drop_vit_feature: bool = False,
        enable_denormalization: bool = True,
        **kwargs,
    ):
        r"""
        Generate an edited image conditioned on a reference image and a text prompt.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide generation.
            height (`int`):
                Height of the generated output in pixels.
            width (`int`):
                Width of the generated output in pixels.
            image (`PipelineImageInput`, *optional*):
                Reference image used for conditioning. When provided the pipeline
                operates in image-editing mode with ``num_items=2``.
            num_inference_steps (`int`, *optional*, defaults to 50):
                Number of denoising steps. More steps generally improve quality at
                the cost of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps for the denoising process. When provided,
                ``num_inference_steps`` is inferred from the list length.
            sigmas (`List[float]`, *optional*):
                Custom sigmas for the denoising process. Mutually exclusive with
                ``timesteps``.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Classifier-free guidance scale.
            negative_prompt (`str` or `List[str]`, *optional*):
                Negative prompt(s) used to suppress undesired content.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                Number of generated samples per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                RNG generator(s) for deterministic sampling.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents. Sampled from a Gaussian distribution
                when not provided.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-computed prompt embeddings. When provided ``prompt`` can be omitted.
            prompt_embeds_mask (`torch.Tensor`, *optional*):
                Attention mask for ``prompt_embeds``.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-computed negative prompt embeddings.
            negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
                Attention mask for ``negative_prompt_embeds``.
            output_type (`str`, *optional*, defaults to ``"pil"``):
                Output format. Pass ``"latent"`` to return raw latents.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a :class:`JoyImageEditPipelineOutput` or a plain tensor.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                Callback invoked at the end of each denoising step with signature
                ``(self, step: int, timestep: int, callback_kwargs: Dict)``.
            callback_on_step_end_tensor_inputs (`List[str]`, *optional*, defaults to ``["latents"]``):
                Tensor keys included in ``callback_kwargs`` for ``callback_on_step_end``.
            enable_tiling (`bool`, *optional*, defaults to `False`):
                Enable tiled VAE decoding to reduce peak memory usage.
            max_sequence_length (`int`, *optional*, defaults to 4096):
                Maximum sequence length for prompt encoding.
            drop_vit_feature (`bool`, *optional*, defaults to `False`):
                Drop vision tokens in the multi-image encoding path.
            enable_denormalization (`bool`, *optional*, defaults to `True`):
                Denormalise latents before VAE decoding.
            **kwargs:
                Additional keyword arguments for forward compatibility.

        Examples:

        Returns:
            [`~pipelines.joyimage.JoyImageEditPipelineOutput`] or `torch.Tensor`:
                If ``return_dict`` is ``True``, returns a pipeline output object
                containing the generated image(s). Otherwise returns the image
                tensor directly.
        """
        # Resize the input image to the nearest bucket resolution.
+ # Or resize the specified height and width to the nearest bucket resolution. + image_size = image[0].size if isinstance(image, list) else image.size + if height is not None and width is not None: + # Override the image size if specified. + image_size = (width, height) + + height, width = _dynamic_resize_from_bucket(image_size, basesize=1024) + processed_image = _resize_center_crop(image, (height, width)) + + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # num_items: 1 for unconditional generation, 2 for reference-image editing. + num_items = 1 if image is None else 2 + + # Encode the conditioning prompt (and reference image when present). + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + images=processed_image, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + template_type="image", + drop_vit_feature=drop_vit_feature, + ) + + if self.do_classifier_free_guidance: + # Build default negative prompts when none are provided. 
+ if negative_prompt is None and negative_prompt_embeds is None: + if num_items <= 1: + negative_prompt = ["<|im_start|>user\n<|im_end|>\n"] * batch_size + else: + negative_prompt = ["<|im_start|>user\n\n<|im_end|>\n"] * batch_size + + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + images=processed_image, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + template_type="image", + ) + + # Pad both embeddings to the same sequence length and concatenate + # in (unconditional, conditional) order for a single forward pass. + max_seq_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1]) + prompt_embeds = torch.cat( + [ + self.pad_sequence(negative_prompt_embeds, max_seq_len), + self.pad_sequence(prompt_embeds, max_seq_len), + ] + ) + if prompt_embeds_mask is not None: + prompt_embeds_mask = torch.cat( + [ + self.pad_sequence(negative_prompt_embeds_mask, max_seq_len), + self.pad_sequence(prompt_embeds_mask, max_seq_len), + ] + ) + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas, + ) + + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_items, + num_channels_latents, + height, + width, + 1, # video_length = 1 for image inference + prompt_embeds.dtype, + device, + generator, + latents, + reference_images=[processed_image], + enable_denormalization=enable_denormalization, + ) + + target_dtype = PRECISION_TO_TYPE[self.args.dit_precision] + autocast_enabled = target_dtype != torch.float32 + vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision] + vae_autocast_enabled = vae_dtype != torch.float32 + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = 
len(timesteps) + + # Cache reference latents to restore them at each denoising step. + if num_items > 1: + ref_latents = latents[:, :(num_items - 1)].clone() + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # Restore reference latents so they are never overwritten by the scheduler. + if num_items > 1: + latents[:, :(num_items - 1)] = ref_latents.clone() + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + t_expand = t.repeat(latent_model_input.shape[0]) + + with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=t_expand, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + # Rescale to match the conditional prediction norm (guidance rescaling). 
+ cond_norm = torch.norm(noise_pred_text, dim=2, keepdim=True) + noise_norm = torch.norm(noise_pred, dim=2, keepdim=True) + noise_pred = noise_pred * (cond_norm / noise_norm.clamp_min(1e-6)) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs} + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + if progress_bar is not None: + progress_bar.update() + + if output_type != "latent": + latents = rearrange(latents, "b n c f h w -> (b n) c f h w") + if enable_denormalization: + latents = self.denormalize_latents(latents) + + with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled): + if enable_tiling: + self.vae.enable_tiling() + image = self.vae.decode(latents, return_dict=False)[0] + image = rearrange(image, "(b n) c f h w -> b n c f h w", b=batch_size) + else: + image = latents + + # Extract the last item (target slot) from the batch, shape: (F, C, H, W). 
        # NOTE(review): indexing [0, -1] keeps only the first batch element's
        # target slot — additional samples from batch_size > 1 or
        # num_images_per_prompt > 1 are silently dropped. Confirm intended.
        image = image.float().permute(0, 1, 3, 2, 4, 5)[0, -1]

        image = self.image_processor.postprocess(image, output_type=output_type)

        # Re-offload models when CPU-offload hooks are installed.
        self.maybe_free_model_hooks()

        if not return_dict:
            # NOTE(review): diffusers pipelines conventionally return a tuple
            # `(image,)` when `return_dict=False`; this returns the bare value.
            return image

        return JoyImageEditPipelineOutput(images=image)
\ No newline at end of file
diff --git a/src/diffusers/pipelines/joyimage/pipeline_output.py b/src/diffusers/pipelines/joyimage/pipeline_output.py
new file mode 100644
index 000000000000..a98b9066c69a
--- /dev/null
+++ b/src/diffusers/pipelines/joyimage/pipeline_output.py
@@ -0,0 +1,16 @@
from dataclasses import dataclass

import torch

from ...utils import BaseOutput
import PIL.Image
from typing import Union, List, Tuple, Optional
import numpy as np

@dataclass
class JoyImageEditPipelineOutput(BaseOutput):
    """
    Output class for JoyImageEdit generation pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            Denoised images produced by the pipeline — a list of PIL images or
            a numpy array, depending on the requested ``output_type``.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]