Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions tests/models/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,8 @@ def _accepts_norm_num_groups(model_class):
def test_forward_with_norm_groups(self):
if not self._accepts_norm_num_groups(self.model_class):
pytest.skip(f"Test not supported for {self.model_class.__name__}")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()

init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)
Expand All @@ -480,9 +481,9 @@ def test_forward_with_norm_groups(self):
if isinstance(output, dict):
output = output.to_tuple()[0]

self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"


class ModelTesterMixin:
Expand Down
38 changes: 24 additions & 14 deletions tests/models/testing_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,9 @@ def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
)

image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
inputs_dict = self.get_dummy_inputs()
image = model(**inputs_dict, return_dict=False)[0]
new_image = new_model(**inputs_dict, return_dict=False)[0]
Comment on lines -295 to +297
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To ensure reproducibility.


assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")

Expand All @@ -313,8 +314,9 @@ def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):

new_model.to(torch_device)

image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
inputs_dict = self.get_dummy_inputs()
image = model(**inputs_dict, return_dict=False)[0]
new_image = new_model(**inputs_dict, return_dict=False)[0]

assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")

Expand Down Expand Up @@ -342,8 +344,9 @@ def test_determinism(self, atol=1e-5, rtol=0):
model.to(torch_device)
model.eval()

first = model(**self.get_dummy_inputs(), return_dict=False)[0]
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
inputs_dict = self.get_dummy_inputs()
first = model(**inputs_dict, return_dict=False)[0]
second = model(**inputs_dict, return_dict=False)[0]

first_flat = first.flatten()
second_flat = second.flatten()
Expand Down Expand Up @@ -400,8 +403,9 @@ def recursive_check(tuple_object, dict_object):
model.to(torch_device)
model.eval()

outputs_dict = model(**self.get_dummy_inputs())
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
inputs_dict = self.get_dummy_inputs()
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)

recursive_check(outputs_tuple, outputs_dict)

Expand Down Expand Up @@ -528,8 +532,10 @@ def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0):
new_model = new_model.to(torch_device)

torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
# Re-create inputs only if they contain a generator (which needs to be reset)
if "generator" in inputs_dict:
inputs_dict = self.get_dummy_inputs()
new_output = new_model(**inputs_dict, return_dict=False)[0]

assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
Expand Down Expand Up @@ -568,8 +574,10 @@ def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0):
new_model = new_model.to(torch_device)

torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
# Re-create inputs only if they contain a generator (which needs to be reset)
if "generator" in inputs_dict:
inputs_dict = self.get_dummy_inputs()
new_output = new_model(**inputs_dict, return_dict=False)[0]

assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
Expand Down Expand Up @@ -619,8 +627,10 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rt
model_parallel = model_parallel.to(torch_device)

torch.manual_seed(0)
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
# Re-create inputs only if they contain a generator (which needs to be reset)
if "generator" in inputs_dict:
inputs_dict = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict, return_dict=False)[0]

assert_tensors_close(
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
Expand Down
3 changes: 0 additions & 3 deletions tests/models/testing_utils/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ def test_torch_compile_repeated_blocks(self, recompile_limit=1):
model.eval()
model.compile_repeated_blocks(fullgraph=True)

if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2

Comment on lines -95 to -97
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed as we pass recompile_limit explicitly now.

with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=recompile_limit),
Expand Down
Loading
Loading