Commit f5e5f34

[modular] add tests for qwen modular (#12585)
* add tests for qwenimage modular.
* qwenimage edit.
* qwenimage edit plus.
* empty
* align with the latest structure
* up
* up
* reason
* up
* fix multiple issues.
* up
* up
* fix
* up
* make it similar to the original pipeline.
1 parent 093cd3f commit f5e5f34

File tree: 9 files changed, +194 -123 lines changed

src/diffusers/modular_pipelines/qwenimage/before_denoise.py

Lines changed: 7 additions & 7 deletions

@@ -132,6 +132,7 @@ def expected_components(self) -> List[ComponentSpec]:
     @property
     def inputs(self) -> List[InputParam]:
         return [
+            InputParam("latents"),
             InputParam(name="height"),
             InputParam(name="width"),
             InputParam(name="num_images_per_prompt", default=1),
@@ -196,11 +197,11 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
                 f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-
-        block_state.latents = randn_tensor(
-            shape, generator=block_state.generator, device=device, dtype=block_state.dtype
-        )
-        block_state.latents = components.pachifier.pack_latents(block_state.latents)
+        if block_state.latents is None:
+            block_state.latents = randn_tensor(
+                shape, generator=block_state.generator, device=device, dtype=block_state.dtype
+            )
+            block_state.latents = components.pachifier.pack_latents(block_state.latents)

         self.set_block_state(state, block_state)
         return components, state
@@ -549,8 +550,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
                 block_state.width // components.vae_scale_factor // 2,
             )
         ]
-        * block_state.batch_size
-        ]
+        ] * block_state.batch_size
         block_state.txt_seq_lens = (
             block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
         )
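
The net effect of the first two hunks: `latents` becomes an optional pipeline input, and fresh noise is drawn only when the caller did not supply any, so precomputed or partially denoised latents can be fed straight into the denoise loop. The last hunk only moves the batch-size replication of `img_shapes` outside the inner list. A minimal standalone sketch of the latents pattern (illustrative names, not the diffusers API):

import torch

def prepare_latents(shape, generator=None, latents=None, device="cpu", dtype=torch.float32):
    # Only sample fresh noise when the caller did not pass latents of their own.
    if latents is None:
        latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
    return latents.to(device=device, dtype=dtype)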

src/diffusers/modular_pipelines/qwenimage/decoders.py

Lines changed: 2 additions & 1 deletion

@@ -74,8 +74,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
         block_state = self.get_block_state(state)

         # YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular
+        vae_scale_factor = components.vae_scale_factor
         block_state.latents = components.pachifier.unpack_latents(
-            block_state.latents, block_state.height, block_state.width
+            block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
         )
         block_state.latents = block_state.latents.to(components.vae.dtype)
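
Threading `vae_scale_factor` through to `unpack_latents` ties the token-grid-to-latent-grid math to the VAE actually loaded rather than a hard-coded default. A sketch of what unpacking patch tokens back into a latent image typically looks like (an assumption about the layout, not the exact diffusers implementation):

import torch

def unpack_latents(latents, height, width, vae_scale_factor=8, patch_size=2):
    # (B, N, C*p*p) -> (B, C, H/vsf, W/vsf), where N = (H/vsf/p) * (W/vsf/p).
    b, _, packed_c = latents.shape
    c = packed_c // (patch_size * patch_size)
    grid_h = height // vae_scale_factor // patch_size
    grid_w = width // vae_scale_factor // patch_size
    latents = latents.view(b, grid_h, grid_w, c, patch_size, patch_size)
    latents = latents.permute(0, 3, 1, 4, 2, 5)  # (B, C, grid_h, p, grid_w, p)
    return latents.reshape(b, c, grid_h * patch_size, grid_w * patch_size)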

src/diffusers/modular_pipelines/qwenimage/encoders.py

Lines changed: 6 additions & 0 deletions

@@ -503,6 +503,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length]
         block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length]

+        block_state.negative_prompt_embeds = None
+        block_state.negative_prompt_embeds_mask = None
         if components.requires_unconditional_embeds:
             negative_prompt = block_state.negative_prompt or ""
             block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds(
@@ -627,6 +629,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
             device=device,
         )

+        block_state.negative_prompt_embeds = None
+        block_state.negative_prompt_embeds_mask = None
         if components.requires_unconditional_embeds:
             negative_prompt = block_state.negative_prompt or " "
             block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
@@ -679,6 +683,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
             device=device,
         )

+        block_state.negative_prompt_embeds = None
+        block_state.negative_prompt_embeds_mask = None
         if components.requires_unconditional_embeds:
             negative_prompt = block_state.negative_prompt or " "
             block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = (
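
All three encoder blocks now write `negative_prompt_embeds` and `negative_prompt_embeds_mask` into the block state unconditionally, defaulting them to None, so downstream blocks can branch on `is None` instead of guessing whether the attributes exist when unconditional embeds are not required. The consumer-side pattern this enables, roughly (illustrative, not the diffusers code):

import torch

def build_denoiser_input(prompt_embeds, negative_prompt_embeds=None):
    # CFG off: the negative branch is None rather than missing entirely.
    if negative_prompt_embeds is None:
        return prompt_embeds
    # CFG on: stack unconditional and conditional embeddings along the batch dim.
    return torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)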

src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py

Lines changed: 1 addition & 4 deletions

@@ -26,10 +26,7 @@ class QwenImagePachifier(ConfigMixin):
     config_name = "config.json"

     @register_to_config
-    def __init__(
-        self,
-        patch_size: int = 2,
-    ):
+    def __init__(self, patch_size: int = 2):
         super().__init__()

     def pack_latents(self, latents):
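
The `__init__` collapse is purely cosmetic. For context, `pack_latents` is the patchifier's core job; a common way to implement a 2x2 pack, shown here as a sketch (the real `QwenImagePachifier` may differ in detail):

import torch

def pack_latents(latents: torch.Tensor, patch_size: int = 2) -> torch.Tensor:
    # (B, C, H, W) -> (B, (H/p)*(W/p), C*p*p): fold each p x p patch into the channel dim.
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // patch_size, patch_size, w // patch_size, patch_size)
    latents = latents.permute(0, 2, 4, 1, 3, 5)  # (B, H/p, W/p, C, p, p)
    return latents.reshape(b, (h // patch_size) * (w // patch_size), c * patch_size * patch_size)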

tests/modular_pipelines/flux/test_modular_pipeline_flux.py

Lines changed: 9 additions & 0 deletions

@@ -55,6 +55,9 @@ def get_dummy_inputs(self, seed=0):
         }
         return inputs

+    def test_float16_inference(self):
+        super().test_float16_inference(9e-2)
+

 class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
     pipeline_class = FluxModularPipeline
@@ -118,6 +121,9 @@ def test_save_from_pretrained(self):

         assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3

+    def test_float16_inference(self):
+        super().test_float16_inference(8e-2)
+

 class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
     pipeline_class = FluxKontextModularPipeline
@@ -170,3 +176,6 @@ def test_save_from_pretrained(self):
         image_slices.append(image[0, -3:, -3:, -1].flatten())

         assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+
+    def test_float16_inference(self):
+        super().test_float16_inference(9e-2)
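
Each Flux test class overrides the shared float16 test only to loosen its tolerance; tiny test models amplify fp16 rounding, so drifts approaching 1e-1 against the fp32 baseline are expected. The comparison presumably boils down to something like this (a sketch of the mixin's check, not its actual code):

import torch

def assert_fp16_close(out_fp32: torch.Tensor, out_fp16: torch.Tensor, expected_max_diff: float):
    # The overrides merely pass a looser expected_max_diff (8e-2 or 9e-2) into a check like this.
    max_diff = (out_fp32.float() - out_fp16.float()).abs().max().item()
    assert max_diff < expected_max_diff, f"float16 output drifted by {max_diff}"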

tests/modular_pipelines/qwen/__init__.py

Whitespace-only changes.
Lines changed: 120 additions & 0 deletions

@@ -0,0 +1,120 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import PIL
+import pytest
+
+from diffusers.modular_pipelines import (
+    QwenImageAutoBlocks,
+    QwenImageEditAutoBlocks,
+    QwenImageEditModularPipeline,
+    QwenImageEditPlusAutoBlocks,
+    QwenImageEditPlusModularPipeline,
+    QwenImageModularPipeline,
+)
+
+from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin
+
+
+class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
+    pipeline_class = QwenImageModularPipeline
+    pipeline_blocks_class = QwenImageAutoBlocks
+    repo = "hf-internal-testing/tiny-qwenimage-modular"
+
+    params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
+    batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
+
+    def get_dummy_inputs(self):
+        generator = self.get_generator()
+        inputs = {
+            "prompt": "dance monkey",
+            "negative_prompt": "bad quality",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "height": 32,
+            "width": 32,
+            "max_sequence_length": 16,
+            "output_type": "pt",
+        }
+        return inputs
+
+    def test_inference_batch_single_identical(self):
+        super().test_inference_batch_single_identical(expected_max_diff=5e-4)
+
+
+class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
+    pipeline_class = QwenImageEditModularPipeline
+    pipeline_blocks_class = QwenImageEditAutoBlocks
+    repo = "hf-internal-testing/tiny-qwenimage-edit-modular"
+
+    params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
+    batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
+
+    def get_dummy_inputs(self):
+        generator = self.get_generator()
+        inputs = {
+            "prompt": "dance monkey",
+            "negative_prompt": "bad quality",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "height": 32,
+            "width": 32,
+            "output_type": "pt",
+        }
+        inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
+        return inputs
+
+    def test_guider_cfg(self):
+        super().test_guider_cfg(7e-5)
+
+
+class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
+    pipeline_class = QwenImageEditPlusModularPipeline
+    pipeline_blocks_class = QwenImageEditPlusAutoBlocks
+    repo = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
+
+    # No `mask_image` yet.
+    params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
+    batch_params = frozenset(["prompt", "negative_prompt", "image"])
+
+    def get_dummy_inputs(self):
+        generator = self.get_generator()
+        inputs = {
+            "prompt": "dance monkey",
+            "negative_prompt": "bad quality",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "height": 32,
+            "width": 32,
+            "output_type": "pt",
+        }
+        inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
+        return inputs
+
+    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
+    def test_num_images_per_prompt(self):
+        super().test_num_images_per_prompt()
+
+    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
+    def test_inference_batch_consistent(self):
+        super().test_inference_batch_consistent()
+
+    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
+    def test_inference_batch_single_identical(self):
+        super().test_inference_batch_single_identical()
+
+    def test_guider_cfg(self):
+        super().test_guider_cfg(1e-3)
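
The Edit Plus batching tests are marked `xfail` with `strict=True`, the self-cleaning variant: the suite records the known failure today, but once multi-image batching is fixed, an unexpected pass (XPASS) is itself reported as a failure, forcing the stale marks to be removed. A minimal illustration of the semantics:

import pytest

@pytest.mark.xfail(condition=True, reason="known limitation", strict=True)
def test_known_broken():
    # Fails today -> reported as XFAIL. If the underlying bug is fixed and this
    # passes, strict=True turns the XPASS into a hard failure, flagging the mark.
    assert False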

tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py

Lines changed: 11 additions & 65 deletions

@@ -25,7 +25,7 @@

 from ...models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
 from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
-from ..test_modular_pipelines_common import ModularPipelineTesterMixin
+from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin


 enable_full_determinism()
@@ -37,13 +37,11 @@ class SDXLModularTesterMixin:
     """

     def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
-        sd_pipe = self.get_pipeline()
-        sd_pipe = sd_pipe.to(torch_device)
-        sd_pipe.set_progress_bar_config(disable=None)
+        sd_pipe = self.get_pipeline().to(torch_device)

         inputs = self.get_dummy_inputs()
         image = sd_pipe(**inputs, output="images")
-        image_slice = image[0, -3:, -3:, -1]
+        image_slice = image[0, -3:, -3:, -1].cpu()

         assert image.shape == expected_image_shape
         max_diff = torch.abs(image_slice.flatten() - expected_slice).max()
@@ -110,7 +108,7 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N
         pipe = blocks.init_pipeline(self.repo)
         pipe.load_components(torch_dtype=torch.float32)
         pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+
         cross_attention_dim = pipe.unet.config.get("cross_attention_dim")

         # forward pass without ip adapter
@@ -219,9 +217,7 @@ def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N
         # compare against static slices and that can be shaky (with a VVVV low probability).
         expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff

-        pipe = self.get_pipeline()
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe = self.get_pipeline().to(torch_device)

         # forward pass without controlnet
         inputs = self.get_dummy_inputs()
@@ -251,9 +247,7 @@ def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N
         assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"

     def test_controlnet_cfg(self):
-        pipe = self.get_pipeline()
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe = self.get_pipeline().to(torch_device)

         # forward pass with CFG not applied
         guider = ClassifierFreeGuidance(guidance_scale=1.0)
@@ -273,35 +267,11 @@ def test_controlnet_cfg(self):
         assert max_diff > 1e-2, "Output with CFG must be different from normal inference"


-class SDXLModularGuiderTesterMixin:
-    def test_guider_cfg(self):
-        pipe = self.get_pipeline()
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        # forward pass with CFG not applied
-        guider = ClassifierFreeGuidance(guidance_scale=1.0)
-        pipe.update_components(guider=guider)
-
-        inputs = self.get_dummy_inputs()
-        out_no_cfg = pipe(**inputs, output="images")
-
-        # forward pass with CFG applied
-        guider = ClassifierFreeGuidance(guidance_scale=7.5)
-        pipe.update_components(guider=guider)
-        inputs = self.get_dummy_inputs()
-        out_cfg = pipe(**inputs, output="images")
-
-        assert out_cfg.shape == out_no_cfg.shape
-        max_diff = np.abs(out_cfg - out_no_cfg).max()
-        assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
-
-
 class TestSDXLModularPipelineFast(
     SDXLModularTesterMixin,
     SDXLModularIPAdapterTesterMixin,
     SDXLModularControlNetTesterMixin,
-    SDXLModularGuiderTesterMixin,
+    ModularGuiderTesterMixin,
     ModularPipelineTesterMixin,
 ):
     """Test cases for Stable Diffusion XL modular pipeline fast tests."""
@@ -335,18 +305,7 @@ def test_stable_diffusion_xl_euler(self):
         self._test_stable_diffusion_xl_euler(
             expected_image_shape=self.expected_image_output_shape,
             expected_slice=torch.tensor(
-                [
-                    0.5966781,
-                    0.62939394,
-                    0.48465094,
-                    0.51573336,
-                    0.57593524,
-                    0.47035995,
-                    0.53410417,
-                    0.51436996,
-                    0.47313565,
-                ],
-                device=torch_device,
+                [0.3886, 0.4685, 0.4953, 0.4217, 0.4317, 0.3945, 0.4847, 0.4704, 0.4731],
             ),
             expected_max_diff=1e-2,
         )
@@ -359,7 +318,7 @@ class TestSDXLImg2ImgModularPipelineFast(
     SDXLModularTesterMixin,
     SDXLModularIPAdapterTesterMixin,
     SDXLModularControlNetTesterMixin,
-    SDXLModularGuiderTesterMixin,
+    ModularGuiderTesterMixin,
     ModularPipelineTesterMixin,
 ):
     """Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
@@ -400,20 +359,7 @@ def get_dummy_inputs(self, seed=0):
     def test_stable_diffusion_xl_euler(self):
         self._test_stable_diffusion_xl_euler(
             expected_image_shape=self.expected_image_output_shape,
-            expected_slice=torch.tensor(
-                [
-                    0.56943184,
-                    0.4702148,
-                    0.48048905,
-                    0.6235963,
-                    0.551138,
-                    0.49629188,
-                    0.60031277,
-                    0.5688907,
-                    0.43996853,
-                ],
-                device=torch_device,
-            ),
+            expected_slice=torch.tensor([0.5246, 0.4466, 0.444, 0.3246, 0.4443, 0.5108, 0.5225, 0.559, 0.5147]),
             expected_max_diff=1e-2,
         )

@@ -425,7 +371,7 @@ class SDXLInpaintingModularPipelineFastTests(
     SDXLModularTesterMixin,
     SDXLModularIPAdapterTesterMixin,
     SDXLModularControlNetTesterMixin,
-    SDXLModularGuiderTesterMixin,
+    ModularGuiderTesterMixin,
     ModularPipelineTesterMixin,
 ):
     """Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
