Skip to content

[feat] JoyAI-JoyImage-Edit support#13444

Open
Moran232 wants to merge 1 commit intohuggingface:mainfrom
Moran232:joyimage_edit
Open

[feat] JoyAI-JoyImage-Edit support#13444
Moran232 wants to merge 1 commit intohuggingface:mainfrom
Moran232:joyimage_edit

Conversation

@Moran232
Copy link
Copy Markdown

@Moran232 Moran232 commented Apr 10, 2026

Description

We are the JoyAI Team, and this is the Diffusers implementation for the JoyAI-Image-Edit model.

GitHub Repository: [https://github.com/jd-opensource/JoyAI-Image]
Hugging Face Model: [https://huggingface.co/jdopensource/JoyAI-Image-Edit-Diffusers]
Original opensource weights [https://huggingface.co/jdopensource/JoyAI-Image-Edit]
Fixes #13430

Model Overview

JoyAI-Image is a unified multimodal foundation model for image understanding, text-to-image generation, and instruction-guided image editing. It combines an 8B Multimodal Large Language Model (MLLM) with a 16B Multimodal Diffusion Transformer (MMDiT).

Kye Features

  • Advanced Text Rendering Showcase: JoyAI-Image is optimized for challenging text-heavy scenarios, including multi-panel comics, dense multi-line text, multilingual typography, long-form layouts, real-world scene text, and handwritten styles.
  • Multi-view Generation and Spatial Editing Showcase: JoyAI-Image showcases a spatially grounded generation and editing pipeline that supports multi-view generation, geometry-aware transformations, camera control, object rotation, and precise location-specific object editing. Across these settings, it preserves scene content, structure, and visual consistency while following viewpoint-sensitive instructions more accurately.
  • Spatial Editing for Spatial Reasoning Showcase: JoyAI-Image poses high-fidelity spatial editing, serving as a powerful catalyst for enhancing spatial reasoning. Compared with Qwen-Image-Edit and Nano Banana Pro, JoyAI-Image-Edit synthesizes the most diagnostic viewpoints by faithfully executing camera motions. These high-fidelity novel views effectively disambiguate complex spatial relations, providing clearer visual evidence for downstream reasoning.

Image edit examples

spatial-editing-showcase

@github-actions github-actions bot added models pipelines size/L PR with diff > 200 LOC labels Apr 10, 2026
Copy link
Copy Markdown
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the PR! I left some initial feedbacks

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we refactor our einops stuff? it is not a diffusers dependency

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I will refactor and remove einops

return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))


class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohh what's going on here? is this some legancy code? can we remove?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We first developed JoyImage, and then trained JoyImage-Edit based on it. This Transformer 3D model belongs to JoyImage, and JoyImage-Edit is inherited from JoyImage. We will also open-source JoyImage in the future.

They essentially share similar Transformer 3D models. I understand that each pipeline requires a specific Transformer model, which is why we implemented inheritance in this way.

Comment on lines +371 to +391
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)

txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)

q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)

attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)
txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
attn_output, text_attn_output = self.attn(...)

can we refactor the attention implementation to follow diffusers style?
basically you need to move all the layers used in attention calculation here into a JoyImageAttention (similar to FluxAttention https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L275)

also create a JoyImageAttnProcessor (see FluxAttnProcessor as example, I think it is same) https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L75 )

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the reminder. I'll clean up this messy code.

Comment on lines +203 to +211
def load_modulation(modulate_type: str, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
if modulate_type == "wanx":
return ModulateWan(hidden_size, factor, **factory_kwargs)
if modulate_type == "adaLN":
return ModulateDiT(hidden_size, factor, act_layer, **factory_kwargs)
if modulate_type == "jdx":
return ModulateX(hidden_size, factor, **factory_kwargs)
raise ValueError(f"Unknown modulation type: {modulate_type}.")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def load_modulation(modulate_type: str, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
if modulate_type == "wanx":
return ModulateWan(hidden_size, factor, **factory_kwargs)
if modulate_type == "adaLN":
return ModulateDiT(hidden_size, factor, act_layer, **factory_kwargs)
if modulate_type == "jdx":
return ModulateX(hidden_size, factor, **factory_kwargs)
raise ValueError(f"Unknown modulation type: {modulate_type}.")

Comment on lines +242 to +250
class ModulateX(nn.Module):
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
super().__init__()
self.factor = factor

def forward(self, x: torch.Tensor):
if len(x.shape) != 3:
x = x.unsqueeze(1)
return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class ModulateX(nn.Module):
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
super().__init__()
self.factor = factor
def forward(self, x: torch.Tensor):
if len(x.shape) != 3:
x = x.unsqueeze(1)
return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]

Comment on lines +214 to +225
class ModulateDiT(nn.Module):
def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.factor = factor
self.act = act_layer()
self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)

def forward(self, x: torch.Tensor):
return self.linear(self.act(x)).chunk(self.factor, dim=-1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class ModulateDiT(nn.Module):
def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.factor = factor
self.act = act_layer()
self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
def forward(self, x: torch.Tensor):
return self.linear(self.act(x)).chunk(self.factor, dim=-1)

is ModulateWan is one used in the model? if so let's remove the ModulateDit and ModulateX

head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
self.img_mod = JoyImageModulate(...)

let's remove the load_modulation function and use the layer directly, better to rename to JoyImageModulate too

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I will refactor modulation and use ModulateWan

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

models pipelines size/L PR with diff > 200 LOC

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support for JoyAI-Image-Edit

2 participants