From 9a5d2264a0780fc2afc1dc98f82e7d958b5ef2fd Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 27 May 2026 14:48:31 +0800 Subject: [PATCH 1/2] fix flux tests OOM on 24G GPU Signed-off-by: jiqing-feng --- tests/models/testing_utils/single_file.py | 10 ++-- .../test_models_transformer_flux.py | 56 ++----------------- ...test_model_flux_transformer_single_file.py | 6 +- 3 files changed, 15 insertions(+), 57 deletions(-) diff --git a/tests/models/testing_utils/single_file.py b/tests/models/testing_utils/single_file.py index e2b9dadb6140..2e2d7a435fb8 100644 --- a/tests/models/testing_utils/single_file.py +++ b/tests/models/testing_utils/single_file.py @@ -107,8 +107,8 @@ def teardown_method(self): backend_empty_cache(torch_device) def test_single_file_model_config(self): - pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs} - single_file_kwargs = {"device": torch_device} + pretrained_kwargs = {"device_map": "auto", **self.pretrained_model_kwargs} + single_file_kwargs = {"device_map": "auto"} if self.torch_dtype: pretrained_kwargs["torch_dtype"] = self.torch_dtype @@ -127,8 +127,8 @@ def test_single_file_model_config(self): ) def test_single_file_model_parameters(self): - pretrained_kwargs = {"device_map": str(torch_device), **self.pretrained_model_kwargs} - single_file_kwargs = {"device": torch_device} + pretrained_kwargs = {"device_map": "auto", **self.pretrained_model_kwargs} + single_file_kwargs = {"device_map": "auto"} if self.torch_dtype: pretrained_kwargs["torch_dtype"] = self.torch_dtype @@ -259,7 +259,7 @@ def test_checkpoint_variant_loading(self): backend_empty_cache(torch_device) def test_single_file_loading_with_device_map(self): - single_file_kwargs = {"device_map": torch_device} + single_file_kwargs = {"device_map": "auto"} if self.torch_dtype: single_file_kwargs["torch_dtype"] = self.torch_dtype diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index b5e65f6e0dea..119fd17a21a0 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -13,13 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tempfile from typing import Any import pytest import torch -from diffusers import BitsAndBytesConfig, FluxTransformer2DModel +from diffusers import FluxTransformer2DModel from diffusers.models.embeddings import ImageProjection from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor from diffusers.utils.torch_utils import randn_tensor @@ -344,6 +343,10 @@ def alternate_ckpt_paths(self): def pretrained_model_name_or_path(self): return "black-forest-labs/FLUX.1-dev" + @property + def torch_dtype(self): + return torch.bfloat16 + class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin): """BitsAndBytes quantization tests for Flux Transformer.""" @@ -449,57 +452,10 @@ class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCo """ModelOpt + compile tests for Flux Transformer.""" +@pytest.mark.skip(reason="torch.compile is not supported by BitsAndBytes") class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin): """BitsAndBytes + compile tests for Flux Transformer.""" - def get_init_dict(self) -> dict[str, int | list[int]]: - # Dims must be multiples of 64 (bnb 4bit blocksize) so single-token activations - # don't trigger the runtime `warn()` inside bnb.matmul_4bit that breaks fullgraph compile. - return { - "patch_size": 1, - "in_channels": 4, - "num_layers": 1, - "num_single_layers": 1, - "attention_head_dim": 32, - "num_attention_heads": 2, - "joint_attention_dim": 64, - "pooled_projection_dim": 64, - "axes_dims_rope": [8, 8, 16], - } - - def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: - inputs = super().get_dummy_inputs(batch_size=batch_size) - embedding_dim = 64 - sequence_length = inputs["encoder_hidden_states"].shape[1] - inputs["encoder_hidden_states"] = randn_tensor( - (batch_size, sequence_length, embedding_dim), - generator=self.generator, - device=torch_device, - dtype=self.torch_dtype, - ) - inputs["pooled_projections"] = randn_tensor( - (batch_size, embedding_dim), generator=self.generator, device=torch_device, dtype=self.torch_dtype - ) - return inputs - - def _create_quantized_model(self, config_kwargs, **extra_kwargs): - config_kwargs = {**config_kwargs, "bnb_4bit_compute_dtype": self.torch_dtype} - bnb_config = BitsAndBytesConfig(**config_kwargs) - base_model = self.model_class(**self.get_init_dict()).to(self.torch_dtype) - with tempfile.TemporaryDirectory() as tmp_dir: - base_model.save_pretrained(tmp_dir) - del base_model - return self.model_class.from_pretrained( - tmp_dir, quantization_config=bnb_config, torch_dtype=self.torch_dtype, **extra_kwargs - ) - - @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) - def test_bnb_torch_compile_with_group_offload(self, config_name): - # use_stream=True is required: bnb 4bit kernels read device pointers eagerly, so - # without an explicit prefetch-stream sync we hit "illegal memory access" in - # bnb/csrc/ops.cu. The pipeline-level Bnb4BitCompileTests override does the same. - self._test_torch_compile_with_group_offload(self.BNB_CONFIGS[config_name], use_stream=True) - class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin): """FirstBlockCache tests for Flux Transformer.""" diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py index 0642a71c5756..bec886a02e45 100644 --- a/tests/single_file/test_model_flux_transformer_single_file.py +++ b/tests/single_file/test_model_flux_transformer_single_file.py @@ -15,6 +15,8 @@ import gc +import torch + from diffusers import ( FluxTransformer2DModel, ) @@ -38,9 +40,9 @@ class TestFluxTransformer2DModelSingleFile(SingleFileModelTesterMixin): repo_id = "black-forest-labs/FLUX.1-dev" subfolder = "transformer" - def test_device_map_cuda(self): + def test_device_map_auto(self): backend_empty_cache(torch_device) - model = self.model_class.from_single_file(self.ckpt_path, device_map="cuda") + model = self.model_class.from_single_file(self.ckpt_path, device_map="auto", torch_dtype=torch.bfloat16) del model gc.collect() From 2db10d5d3a04b63d506681fa74b33def83b40180 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 27 May 2026 15:35:53 +0800 Subject: [PATCH 2/2] revert wrong change Signed-off-by: jiqing-feng --- .../test_models_transformer_flux.py | 52 ++++++++++++++++++- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 119fd17a21a0..4acc62899bc2 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import tempfile from typing import Any import pytest import torch -from diffusers import FluxTransformer2DModel +from diffusers import BitsAndBytesConfig, FluxTransformer2DModel from diffusers.models.embeddings import ImageProjection from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor from diffusers.utils.torch_utils import randn_tensor @@ -452,10 +453,57 @@ class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCo """ModelOpt + compile tests for Flux Transformer.""" -@pytest.mark.skip(reason="torch.compile is not supported by BitsAndBytes") class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin): """BitsAndBytes + compile tests for Flux Transformer.""" + def get_init_dict(self) -> dict[str, int | list[int]]: + # Dims must be multiples of 64 (bnb 4bit blocksize) so single-token activations + # don't trigger the runtime `warn()` inside bnb.matmul_4bit that breaks fullgraph compile. + return { + "patch_size": 1, + "in_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 32, + "num_attention_heads": 2, + "joint_attention_dim": 64, + "pooled_projection_dim": 64, + "axes_dims_rope": [8, 8, 16], + } + + def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]: + inputs = super().get_dummy_inputs(batch_size=batch_size) + embedding_dim = 64 + sequence_length = inputs["encoder_hidden_states"].shape[1] + inputs["encoder_hidden_states"] = randn_tensor( + (batch_size, sequence_length, embedding_dim), + generator=self.generator, + device=torch_device, + dtype=self.torch_dtype, + ) + inputs["pooled_projections"] = randn_tensor( + (batch_size, embedding_dim), generator=self.generator, device=torch_device, dtype=self.torch_dtype + ) + return inputs + + def _create_quantized_model(self, config_kwargs, **extra_kwargs): + config_kwargs = {**config_kwargs, "bnb_4bit_compute_dtype": self.torch_dtype} + bnb_config = BitsAndBytesConfig(**config_kwargs) + base_model = self.model_class(**self.get_init_dict()).to(self.torch_dtype) + with tempfile.TemporaryDirectory() as tmp_dir: + base_model.save_pretrained(tmp_dir) + del base_model + return self.model_class.from_pretrained( + tmp_dir, quantization_config=bnb_config, torch_dtype=self.torch_dtype, **extra_kwargs + ) + + @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) + def test_bnb_torch_compile_with_group_offload(self, config_name): + # use_stream=True is required: bnb 4bit kernels read device pointers eagerly, so + # without an explicit prefetch-stream sync we hit "illegal memory access" in + # bnb/csrc/ops.cu. The pipeline-level Bnb4BitCompileTests override does the same. + self._test_torch_compile_with_group_offload(self.BNB_CONFIGS[config_name], use_stream=True) + class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin): """FirstBlockCache tests for Flux Transformer."""