@@ -380,17 +380,13 @@ def test_flexible_splitter_global(self, create_all_datasets):
380380 [0.33 , 0.33 , 0.33 , 0.0 , 1.0 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 ],
381381 [0.33 , 0.33 , 0.33 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 1.0 , 0.0 ]])
382382 partners_list = [Partner (i ) for i in range (len (splitter .amounts_per_partner ))]
383- if dataset .num_classes > = 10 :
383+ if dataset .num_classes = = 10 :
384384 splitter .split (partners_list , dataset )
385385 for p in partners_list :
386386 assert len (p .y_val ) == 0 , "validation set is not empty in spite of the val_set == 'global'"
387387 assert len (p .y_test ) == 0 , "test set is not empty in spite of the val_set == 'global'"
388388 assert len (p .x_train ) == len (p .y_train ), 'labels and samples numbers mismatches'
389- if dataset .num_classes >= 3 :
390- assert len (p .labels ) < dataset .num_classes , f'Partner { p .id } has all labels.'
391- else :
392- with pytest .raises (Exception ):
393- splitter .split (partners_list , dataset )
389+ assert len (p .labels ) < dataset .num_classes , f'Partner { p .id } has all labels.'
394390
395391 def test_flexible_splitter_local (self , create_all_datasets ):
396392 dataset = create_all_datasets
@@ -400,17 +396,13 @@ def test_flexible_splitter_local(self, create_all_datasets):
400396 [0.33 , 0.33 , 0.33 , 0.0 , 0.0 , 1.0 , 0.0 , 0.0 , 1.0 , 0.0 ]],
401397 val_set = 'local' , test_set = 'local' )
402398 partners_list = [Partner (i ) for i in range (len (splitter .amounts_per_partner ))]
403- if dataset .num_classes > = 10 :
399+ if dataset .num_classes = = 10 :
404400 splitter .split (partners_list , dataset )
405401 for p in partners_list :
406402 assert len (p .y_val ) > 0 , "validation set is empty in spite of the val_set == 'local'"
407403 assert len (p .y_test ) > 0 , "test set is empty in spite of the val_set == 'local'"
408404 assert len (p .x_train ) == len (p .y_train ), 'labels and samples numbers mismatches'
409- if dataset .num_classes >= 3 :
410- assert len (p .labels ) < dataset .num_classes , f'Partner { p .id } has all labels.'
411- else :
412- with pytest .raises (Exception ):
413- splitter .split (partners_list , dataset )
405+ assert len (p .labels ) < dataset .num_classes , f'Partner { p .id } has all labels.'
414406
415407
416408######
0 commit comments