[feat] JoyAI-JoyImage-Edit support#13444
Conversation
yiyixuxu left a comment:

thanks for the PR! I left some initial feedback
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
```
can we refactor our einops stuff? it is not a diffusers dependency
ok, I will refactor and remove einops
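As a reference for that refactor, the `rearrange` pattern `"B L (K H D) -> K B L H D"` used in this PR can be reproduced with plain `torch` ops; a minimal sketch (the shapes are illustrative, not the model's actual sizes):

```python
import torch

B, L, K, H, D = 2, 16, 3, 4, 8
qkv = torch.randn(B, L, K * H * D)

# Pure-torch equivalent of rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=H):
# unflatten the packed last dim, move K to the front, then unbind along K.
q, k, v = qkv.view(B, L, K, H, D).permute(2, 0, 1, 3, 4).unbind(0)

print(q.shape)  # torch.Size([2, 16, 4, 8])
```

This works because `"(K H D)"` packs the last dim K-major, which is exactly how `view(B, L, K, H, D)` reinterprets it.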
```python
    return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))


class JoyImageEditTransformer3DModel(JoyImageTransformer3DModel):
```
ohh what's going on here? is this some legacy code? can we remove?
We first developed JoyImage, and then trained JoyImage-Edit based on it. This Transformer 3D model belongs to JoyImage, and JoyImage-Edit is inherited from JoyImage. We will also open-source JoyImage in the future.
They essentially share similar Transformer 3D models. I understand that each pipeline requires a specific Transformer model, which is why we implemented inheritance in this way.
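A minimal sketch of that inheritance pattern (class names and the extra conditioning argument here are illustrative, not the actual model code):

```python
import torch
import torch.nn as nn


class JoyImageTransformerSketch(nn.Module):
    # Base model: text-to-image transformer (stand-in for JoyImageTransformer3DModel).
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


class JoyImageEditTransformerSketch(JoyImageTransformerSketch):
    # Edit model: inherits all base layers, adds conditioning on a reference image.
    def __init__(self, dim):
        super().__init__(dim)
        self.cond_proj = nn.Linear(dim, dim)

    def forward(self, x, ref):
        return super().forward(x) + self.cond_proj(ref)
```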
```python
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
    img_q, img_k = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)

txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
    txt_q, txt_k = apply_rotary_emb(txt_q, txt_k, txt_freqs_cis, head_first=False)

q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)

attn = attention(q, k, v, attn_kwargs=attn_kwargs).flatten(2, 3)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
```
Suggested change (replace the block above with a single attention-module call):

```python
attn_output, text_attn_output = self.attn(...)
```
can we refactor the attention implementation to follow diffusers style?
basically you need to move all the layers used in the attention calculation here into a JoyImageAttention (similar to FluxAttention https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L275)
also create a JoyImageAttnProcessor (see FluxAttnProcessor as an example, I think it is the same) https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L75
Thanks for the reminder. I'll clean up this messy code.
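For illustration, a minimal sketch of the requested split, where a parameter-free processor holds the attention math and the module holds the layers. The names `JoyImageAttention` and `JoyImageAttnProcessor` follow the reviewer's suggestion, but the signatures are assumptions; QK-norm and rotary embeddings from the original block are omitted for brevity:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class JoyImageAttnProcessor:
    """Stateless processor: holds the attention math, no parameters."""

    def __call__(self, attn, hidden_states, encoder_hidden_states):
        # Project and split both streams into q, k, v.
        img_q, img_k, img_v = attn.to_qkv(hidden_states).chunk(3, dim=-1)
        txt_q, txt_k, txt_v = attn.to_qkv_txt(encoder_hidden_states).chunk(3, dim=-1)

        def split_heads(x):
            b, l, _ = x.shape
            return x.view(b, l, attn.heads, -1).transpose(1, 2)  # (B, H, L, D)

        # Joint attention over the concatenated image + text sequence.
        q = torch.cat([split_heads(img_q), split_heads(txt_q)], dim=2)
        k = torch.cat([split_heads(img_k), split_heads(txt_k)], dim=2)
        v = torch.cat([split_heads(img_v), split_heads(txt_v)], dim=2)
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).flatten(2)  # (B, L_img + L_txt, H * D)
        img_len = hidden_states.shape[1]
        return out[:, :img_len], out[:, img_len:]


class JoyImageAttention(nn.Module):
    """Owns the projection layers; delegates the math to a processor."""

    def __init__(self, dim, heads, processor=None):
        super().__init__()
        self.heads = heads
        self.to_qkv = nn.Linear(dim, 3 * dim)
        self.to_qkv_txt = nn.Linear(dim, 3 * dim)
        self.processor = processor or JoyImageAttnProcessor()

    def forward(self, hidden_states, encoder_hidden_states):
        return self.processor(self, hidden_states, encoder_hidden_states)
```

Keeping the processor swappable is what lets diffusers switch attention backends without touching model weights.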
```python
def load_modulation(modulate_type: str, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
    factory_kwargs = {"dtype": dtype, "device": device}
    if modulate_type == "wanx":
        return ModulateWan(hidden_size, factor, **factory_kwargs)
    if modulate_type == "adaLN":
        return ModulateDiT(hidden_size, factor, act_layer, **factory_kwargs)
    if modulate_type == "jdx":
        return ModulateX(hidden_size, factor, **factory_kwargs)
    raise ValueError(f"Unknown modulation type: {modulate_type}.")
```
```python
class ModulateX(nn.Module):
    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor):
        if len(x.shape) != 3:
            x = x.unsqueeze(1)
        return [o.squeeze(1) for o in x.chunk(self.factor, dim=1)]
```
```python
class ModulateDiT(nn.Module):
    def __init__(self, hidden_size: int, factor: int, act_layer=nn.SiLU, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.factor = factor
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor):
        return self.linear(self.act(x)).chunk(self.factor, dim=-1)
```
is ModulateWan the one used in the model? if so, let's remove ModulateDiT and ModulateX
```python
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

self.img_mod = load_modulation(self.dit_modulation_type, hidden_size, 6, **factory_kwargs)
```
Suggested change:

```python
self.img_mod = JoyImageModulate(...)
```
let's remove the load_modulation function and use the layer directly; better to rename it to JoyImageModulate too
Ok, I will refactor the modulation and use ModulateWan
Description
We are the JoyAI Team, and this is the Diffusers implementation for the JoyAI-Image-Edit model.
GitHub Repository: [https://github.com/jd-opensource/JoyAI-Image]
Hugging Face Model: [https://huggingface.co/jdopensource/JoyAI-Image-Edit-Diffusers]
Original open-source weights: [https://huggingface.co/jdopensource/JoyAI-Image-Edit]
Fixes #13430
Model Overview
JoyAI-Image is a unified multimodal foundation model for image understanding, text-to-image generation, and instruction-guided image editing. It combines an 8B Multimodal Large Language Model (MLLM) with a 16B Multimodal Diffusion Transformer (MMDiT).
Key Features
Image edit examples