Skip to content

Commit 5ae8bc4

Browse files
committed
Update test of flexible splitter to be performed on 10-class datasets only
1 parent 2113eaa commit 5ae8bc4

File tree

1 file changed

+4
-12
lines changed

1 file changed

+4
-12
lines changed

tests/unit_tests.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -380,17 +380,13 @@ def test_flexible_splitter_global(self, create_all_datasets):
380380
[0.33, 0.33, 0.33, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
381381
[0.33, 0.33, 0.33, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0]])
382382
partners_list = [Partner(i) for i in range(len(splitter.amounts_per_partner))]
383-
if dataset.num_classes >= 10:
383+
if dataset.num_classes == 10:
384384
splitter.split(partners_list, dataset)
385385
for p in partners_list:
386386
assert len(p.y_val) == 0, "validation set is not empty in spite of the val_set == 'global'"
387387
assert len(p.y_test) == 0, "test set is not empty in spite of the val_set == 'global'"
388388
assert len(p.x_train) == len(p.y_train), 'labels and samples numbers mismatches'
389-
if dataset.num_classes >= 3:
390-
assert len(p.labels) < dataset.num_classes, f'Partner {p.id} has all labels.'
391-
else:
392-
with pytest.raises(Exception):
393-
splitter.split(partners_list, dataset)
389+
assert len(p.labels) < dataset.num_classes, f'Partner {p.id} has all labels.'
394390

395391
def test_flexible_splitter_local(self, create_all_datasets):
396392
dataset = create_all_datasets
@@ -400,17 +396,13 @@ def test_flexible_splitter_local(self, create_all_datasets):
400396
[0.33, 0.33, 0.33, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0]],
401397
val_set='local', test_set='local')
402398
partners_list = [Partner(i) for i in range(len(splitter.amounts_per_partner))]
403-
if dataset.num_classes >= 10:
399+
if dataset.num_classes == 10:
404400
splitter.split(partners_list, dataset)
405401
for p in partners_list:
406402
assert len(p.y_val) > 0, "validation set is empty in spite of the val_set == 'local'"
407403
assert len(p.y_test) > 0, "test set is empty in spite of the val_set == 'local'"
408404
assert len(p.x_train) == len(p.y_train), 'labels and samples numbers mismatches'
409-
if dataset.num_classes >= 3:
410-
assert len(p.labels) < dataset.num_classes, f'Partner {p.id} has all labels.'
411-
else:
412-
with pytest.raises(Exception):
413-
splitter.split(partners_list, dataset)
405+
assert len(p.labels) < dataset.num_classes, f'Partner {p.id} has all labels.'
414406

415407

416408
######

0 commit comments

Comments
 (0)