diff --git a/tests/models/testing_utils/single_file.py b/tests/models/testing_utils/single_file.py index e2b9dadb6140..2e2d7a435fb8 100644 --- a/tests/models/testing_utils/single_file.py +++ b/tests/models/testing_utils/single_file.py @@ -107,8 +107,8 @@ def teardown_method(self): backend_empty_cache(torch_device) def test_single_file_model_config(self): - pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs} - single_file_kwargs = {"device": torch_device} + pretrained_kwargs = {"device_map": "auto", **self.pretrained_model_kwargs} + single_file_kwargs = {"device_map": "auto"} if self.torch_dtype: pretrained_kwargs["torch_dtype"] = self.torch_dtype @@ -127,8 +127,8 @@ def test_single_file_model_config(self): ) def test_single_file_model_parameters(self): - pretrained_kwargs = {"device_map": str(torch_device), **self.pretrained_model_kwargs} - single_file_kwargs = {"device": torch_device} + pretrained_kwargs = {"device_map": "auto", **self.pretrained_model_kwargs} + single_file_kwargs = {"device_map": "auto"} if self.torch_dtype: pretrained_kwargs["torch_dtype"] = self.torch_dtype @@ -259,7 +259,7 @@ def test_checkpoint_variant_loading(self): backend_empty_cache(torch_device) def test_single_file_loading_with_device_map(self): - single_file_kwargs = {"device_map": torch_device} + single_file_kwargs = {"device_map": "auto"} if self.torch_dtype: single_file_kwargs["torch_dtype"] = self.torch_dtype diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index b5e65f6e0dea..4acc62899bc2 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -344,6 +344,10 @@ def alternate_ckpt_paths(self): def pretrained_model_name_or_path(self): return "black-forest-labs/FLUX.1-dev" + @property + def torch_dtype(self): + return torch.bfloat16 + class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin): """BitsAndBytes quantization tests for Flux Transformer.""" diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py index 0642a71c5756..bec886a02e45 100644 --- a/tests/single_file/test_model_flux_transformer_single_file.py +++ b/tests/single_file/test_model_flux_transformer_single_file.py @@ -15,6 +15,8 @@ import gc +import torch + from diffusers import ( FluxTransformer2DModel, ) @@ -38,9 +40,9 @@ class TestFluxTransformer2DModelSingleFile(SingleFileModelTesterMixin): repo_id = "black-forest-labs/FLUX.1-dev" subfolder = "transformer" - def test_device_map_cuda(self): + def test_device_map_auto(self): backend_empty_cache(torch_device) - model = self.model_class.from_single_file(self.ckpt_path, device_map="cuda") + model = self.model_class.from_single_file(self.ckpt_path, device_map="auto", torch_dtype=torch.bfloat16) del model gc.collect()