Commit d6b43a2

Merge pull request #320 from SubstraFoundation/fully-flexible-split
Create a fully flexible splitter
2 parents 3aa6dc0 + 53ebdfe commit d6b43a2

5 files changed: +165 -38 lines changed

mplc/doc/documentation.md

Lines changed: 15 additions & 13 deletions
@@ -238,23 +238,25 @@ There are 2 ways to select a dataset. You can either choose a pre-implemented da
238238
Example: `amounts_per_partner=[0.3, 0.3, 0.1, 0.3]`
239239

240240
<a id="sample_split_option"></a>
241-
- `samples_split_option`: Used to set the strategy of samples data split. You can either instantiate a Splitter before passing it to Scenario, as in the below example, or you can pass it by its string identifier. In the latter case, the default parameters for the Splitter selected will be used.
241+
- `samples_split_option`: Used to set the strategy of samples data split. You can either instantiate a `Splitter` before passing it to `Scenario`, or you can pass it by its string identifier. In the latter case, the default parameters for the `Splitter` selected will be used.
242242
How the original dataset data samples are split among partners:
243-
- `RandomSplitter`: the dataset is shuffled and partners receive data samples selected randomly
244-
String identifier: `'random'`
245-
- `StratifiedSplitter`: the dataset is stratified per class and each partner receives certain classes only (note: depending on the `amounts_per_partner` specified, there might be small overlaps of classes)
246-
String identifier: `'stratified'``[[nb of clusters (int), 'shared' or 'specific']]`
247-
- `'AdvancedSplitter'`: in certain cases it might be interesting to split the dataset among partners in a more elaborate way. For that we consider the data samples from the initial dataset as split in clusters per data labels. The advanced split is configured by indicating, for each partner in sequence, the following 2 elements: `[[nb of clusters (int), 'shared' or 'specific']]`. Practically, you can either instantiate your `AdvancedSplitter` object, and pass this list `[[nb of clusters (int), 'shared' or 'specific']]` to the keyword argument `description`, or use the string identifier and pass the list `[[nb of clusters (int), 'shared' or 'specific']]` to the scenario via the keyword argument `samples_split_configuration`.
248-
String identifier:`'advanced'`.
249-
Configuration:
243+
244+
- `RandomSplitter`: the dataset is shuffled and partners receive data samples selected randomly. String identifier: `'random'`
245+
246+
- `StratifiedSplitter`: the dataset is stratified per class and each partner receives certain classes only (note: depending on the `amounts_per_partner` specified, there might be some overlap of classes). String identifier: `'stratified'`
247+
248+
- `AdvancedSplitter`: in certain cases it might be interesting to split the dataset among partners in a more elaborate way. For that we consider the data samples from the initial dataset as split in clusters per data labels. The advanced split is configured by indicating, for each partner in sequence, the following 2 elements: `[[nb of clusters (int), 'shared' or 'specific']]`. Practically, you can either instantiate your `AdvancedSplitter` object and pass this list to the keyword argument `configuration`, or use the string identifier and pass the list to the `Scenario` via the keyword argument `samples_split_configuration`. String identifier: `'advanced'`. Configuration:
250249
- `nb of clusters (int)`: the given partner will receive data samples from that many different clusters (clusters of data samples per labels/classes)
251250
- `'shared'` or `'specific'`:
252-
- `'shared'`: all partners with option `'shared'` receive data samples picked
251+
- `'shared'`: all partners with option `'shared'` receive data samples picked
253252
from clusters they all share data samples from
254-
- `'specific'`: each partner with option `'specific'` receives data samples picked
255-
from cluster(s) it is the only one to receive from
256-
257-
Example: `samples_split_option='advanced', samples_split_configuration=[[7, 'shared'], [6, 'shared'], [2, 'specific'], [1, 'specific']]]`
253+
- `'specific'`: each partner with option `'specific'` receives data samples picked
254+
from cluster(s) it is the only one to receive from
255+
Example: `samples_split_option='advanced', samples_split_configuration=[[7, 'shared'], [6, 'shared'], [2, 'specific'], [1, 'specific']]`
256+
257+
- `FlexibleSplitter`: in other cases one might want to specify the split among partners in full detail (partner per partner and class per class). For that the `FlexibleSplitter` can be used. It is configured by indicating, for each partner in sequence, a list of the percentage of samples for each class: `[[% for class 1, ..., % for class n]]`. As above, it can be instantiated separately and then passed to the `Scenario` instance, or the string identifier `'flexible'` can be used for the parameter `samples_split_option`, coupled with the split configuration passed to the keyword argument `samples_split_configuration`. String identifier: `'flexible'`.
258+
Example: `samples_split_option='flexible', samples_split_configuration=[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.0]]` (this corresponds to 50% of the last 3 classes for partner 1, and 50% or 100% of each of the first 9 classes for partner 2).
259+
Note: in the list of percentages per class, the order of the entries should not be interpreted as any human-readable ordering of the labels (e.g. alphabetical or numerical). The implementation uses the order in which the samples appear in the dataset. One can therefore enforce a particular order, if desired, by sorting the dataset beforehand.
258260

259261
![Example of the advanced split option](../../img/advanced_split_example.png)
260262
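To make the `'flexible'` configuration format concrete, here is a minimal standalone sketch of how the per-class percentages map to sample counts. This is illustrative only: `flexible_split_counts` is a hypothetical helper, not part of the mplc API.

```python
import numpy as np

def flexible_split_counts(y, configuration):
    """Return, per partner, the number of samples taken from each class.

    Each row of `configuration` gives, for one partner, the fraction of
    each class's samples that partner receives (mirroring the documented
    'flexible' split behaviour).
    """
    labels = sorted(set(y))
    counts = {label: int(np.sum(np.array(y) == label)) for label in labels}
    per_partner = []
    for fractions in configuration:
        per_partner.append(
            {label: int(counts[label] * frac)
             for label, frac in zip(labels, fractions)})
    return per_partner

# Toy dataset: 10 samples of class 0, 10 samples of class 1
y = [0] * 10 + [1] * 10
config = [[0.5, 0.0],   # partner 1: 50% of class 0, none of class 1
          [0.5, 1.0]]   # partner 2: 50% of class 0, all of class 1
print(flexible_split_counts(y, config))
# [{0: 5, 1: 0}, {0: 5, 1: 10}]
```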

mplc/scenario.py

Lines changed: 9 additions & 2 deletions
@@ -174,6 +174,11 @@ def __init__(
174174
# (% of samples of the dataset for each partner, ...
175175
# ... has to sum to 1, and number of items has to equal partners_count)
176176
self.amounts_per_partner = amounts_per_partner
177+
if np.sum(self.amounts_per_partner) != 1:
178+
raise ValueError("The sum of the amounts_per_partner list you provided isn't equal to 1")
179+
if len(self.amounts_per_partner) != self.partners_count:
180+
raise AttributeError(f"The amounts_per_partner list should have a size ({len(self.amounts_per_partner)}) "
181+
f"equal to partners_count ({self.partners_count})")
177182

178183
# To configure how validation set and test set will be organized.
179184
if test_set in ['local', 'global']:
@@ -341,9 +346,11 @@ def __init__(
341346
self.save_folder = Path(save_path) / self.scenario_name
342347
else:
343348
self.save_folder = None
344-
# ------------------------------------------------------------------
349+
350+
# -------------------------------------------------------------------
345351
# Select in the kwargs the parameters to be transferred to sub object
346-
# ------------------------------------------------------------------
352+
# -------------------------------------------------------------------
353+
347354
self.mpl_kwargs = {}
348355
for key, value in kwargs.items():
349356
if key.startswith('mpl_'):
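The validation this PR adds to `Scenario.__init__` can be sketched in isolation as follows. This is a minimal sketch: `check_amounts` is a hypothetical name, and the exact float equality mirrors the PR's check, which can be brittle (e.g. `0.1 + 0.2 + 0.3 + 0.4 != 1.0` in floating point).

```python
import numpy as np

def check_amounts(amounts_per_partner, partners_count):
    # The amounts must sum to 1 (exact float comparison, as in the PR).
    if np.sum(amounts_per_partner) != 1:
        raise ValueError("The sum of the amounts_per_partner list "
                         "you provided isn't equal to 1")
    # There must be exactly one amount per partner.
    if len(amounts_per_partner) != partners_count:
        raise AttributeError(
            f"The amounts_per_partner list has size "
            f"{len(amounts_per_partner)}, expected partners_count "
            f"({partners_count})")

check_amounts([0.25, 0.25, 0.25, 0.25], 4)  # passes silently

try:
    check_amounts([0.5, 0.6], 2)  # sums to 1.1
except ValueError as e:
    print(e)  # prints the sum error message
```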

mplc/splitter.py

Lines changed: 97 additions & 17 deletions
@@ -23,10 +23,6 @@ def __init__(self, amounts_per_partner, val_set='global', test_set='global', **k
2323
self.dataset = None
2424
self.partners_list = None
2525

26-
# Check the percentages of samples per partner and control its coherence
27-
if np.sum(self.amounts_per_partner) != 1:
28-
raise ValueError("The sum of the amount per partners you provided isn't equal to 1")
29-
3026
@property
3127
def partners_count(self):
3228
return len(self.partners_list)
@@ -37,29 +33,40 @@ def __str__(self):
3733
def split(self, partners_list, dataset):
3834
self.dataset = dataset
3935
self.partners_list = partners_list
40-
if len(self.amounts_per_partner) != self.partners_count:
41-
raise AttributeError(f"The amounts_per_partner list should have a size ({len(self.amounts_per_partner)}) "
42-
f"equals to partners_count ({self.partners_count})")
4336

44-
logger.info("### Splitting data among partners:")
45-
logger.info("Train data split:")
37+
logger.info("Splitting data among partners: starting now.")
38+
self._test_config_coherence()
39+
logger.info("Coherence of config parameters: OK.")
40+
41+
logger.info("Train data split: starting now.")
4642
self._split_train()
4743

4844
if self.val_set == 'local':
49-
logger.info("Validation data split:")
45+
logger.info("Validation data split: starting now.")
5046
self._split_val()
5147

5248
if self.test_set == 'local':
53-
logger.info("Test data split:")
49+
logger.info("Test data split: starting now.")
5450
self._split_test()
5551

5652
for partner in self.partners_list:
5753
logger.info(
58-
f" Partner #{partner.id}: "
59-
f"{partner.final_nb_samples} samples "
60-
f"with labels {partner.labels}"
54+
f"Partner #{partner.id}: {partner.final_nb_samples} samples with labels {partner.labels}"
6155
)
6256

57+
def _test_config_coherence(self):
58+
self._test_amounts_per_partner_total()
59+
self._test_amounts_per_partner_length()
60+
61+
def _test_amounts_per_partner_total(self):
62+
if np.sum(self.amounts_per_partner) != 1:
63+
raise ValueError("The sum of the amounts_per_partner list you provided isn't equal to 1; it has to be.")
64+
65+
def _test_amounts_per_partner_length(self):
66+
if len(self.amounts_per_partner) != self.partners_count:
67+
raise AttributeError(f"The amounts_per_partner list should have a size ({len(self.amounts_per_partner)}) "
68+
f"equal to partners_count ({self.partners_count})")
69+
6370
def _split_train(self):
6471
subsets = self._generate_subset(self.dataset.x_train, self.dataset.y_train)
6572
for idx, p in enumerate(self.partners_list):
@@ -89,8 +96,79 @@ def copy(self):
8996
return self.__copy__()
9097

9198

99+
class FlexibleSplitter(Splitter):
100+
name = 'Fully Flexible Splitter'
101+
102+
def __init__(self, amounts_per_partner, configuration, **kwargs):
103+
104+
logger.info("Proceeding to a flexible split as requested. Please note that the flexible "
105+
"split currently discards the amounts_per_partner (if provided) and infers amounts of samples "
106+
"per partner from the samples_split_configuration provided.")
107+
108+
# First we re-assemble the split configuration per cluster
109+
self.configuration = configuration
110+
self.split_configuration = configuration
111+
self.samples_split_grouped_by_cluster = list(zip(*configuration))
112+
113+
# Init of the superclass to inherit its methods
114+
super().__init__(amounts_per_partner, **kwargs)
115+
116+
def _test_config_coherence(self):
117+
118+
# First, we test if the splitter configuration is coherent with the number of partners
119+
if len(self.split_configuration) != self.partners_count:
120+
raise AttributeError(f"The split configuration should have a size ({len(self.split_configuration)}) "
121+
f"equal to partners_count ({self.partners_count})")
122+
123+
# Second, we test for each class that the amount of samples split across partners is <= 100%
124+
for idx, cluster in enumerate(self.samples_split_grouped_by_cluster):
125+
if np.sum(cluster) > 1:
126+
raise ValueError(f"Amounts of samples of class {idx} split among partners exceed 100%; "
127+
f"the dataset split cannot be performed.")
128+
129+
def _generate_subset(self, x, y):
130+
131+
# Convert raw labels in y to simplify operations on the dataset
132+
lb = LabelEncoder()
133+
y_str = lb.fit_transform([str(label) for label in y])
134+
labels = list(set(y_str))
135+
136+
# Split the datasets (x and y) into subsets of samples of each label (called "clusters")
137+
x_for_cluster, y_for_cluster, nb_samples_per_cluster = {}, {}, {}
138+
for label in labels:
139+
idx_in_full_set = np.where(y_str == label)
140+
x_for_cluster[label] = x[idx_in_full_set]
141+
y_for_cluster[label] = y[idx_in_full_set]
142+
nb_samples_per_cluster[label] = len(y_for_cluster[label])
143+
144+
# Assemble datasets per partner by looping over partners and labels
145+
res = []
146+
nb_samples_split = []
147+
for p_idx, p in enumerate(self.partners_list):
148+
149+
list_arrays_x, list_arrays_y = [], []
150+
151+
for idx, label in enumerate(labels):
152+
nb_samples_to_pick = int(nb_samples_per_cluster[label] * self.samples_split_grouped_by_cluster[idx][
153+
p_idx])
154+
list_arrays_x.append(x_for_cluster[label][:nb_samples_to_pick])
155+
x_for_cluster[label] = x_for_cluster[label][nb_samples_to_pick:]
156+
list_arrays_y.append(y_for_cluster[label][:nb_samples_to_pick])
157+
y_for_cluster[label] = y_for_cluster[label][nb_samples_to_pick:]
158+
159+
res.append((np.concatenate(list_arrays_x), np.concatenate(list_arrays_y)))
160+
nb_samples_split.append(len(np.concatenate(list_arrays_y)))
161+
162+
# Log the relative amounts of samples split among partners
163+
total_nb_samples_split = np.sum(nb_samples_split)
164+
relative_nb_samples = [round(nb / total_nb_samples_split, 2) for nb in nb_samples_split]
165+
logger.info(f"Partners' relative number of samples: {relative_nb_samples}")
166+
167+
return res
168+
169+
92170
class RandomSplitter(Splitter):
93-
name = 'Random samples split'
171+
name = 'Random Splitter'
94172

95173
def _generate_subset(self, x, y):
96174
if self.partners_count == 1:
@@ -107,7 +185,7 @@ def _generate_subset(self, x, y):
107185

108186

109187
class StratifiedSplitter(Splitter):
110-
name = 'Stratified samples split'
188+
name = 'Stratified Splitter'
111189

112190
def _generate_subset(self, x, y):
113191
if self.partners_count == 1:
@@ -124,9 +202,10 @@ def _generate_subset(self, x, y):
124202

125203

126204
class AdvancedSplitter(Splitter):
127-
name = 'Advanced samples split'
205+
name = 'Advanced Splitter'
128206

129207
def __init__(self, amounts_per_partner, configuration, **kwargs):
208+
self.configuration = configuration
130209
self.num_clusters, self.specific_shared = list(zip(*configuration))
131210
super().__init__(amounts_per_partner, **kwargs)
132211

@@ -271,6 +350,7 @@ def _generate_subset(self, x, y):
271350

272351

273352
IMPLEMENTED_SPLITTERS = {
353+
'flexible': FlexibleSplitter,
274354
'random': RandomSplitter,
275355
'stratified': StratifiedSplitter,
276356
'advanced': AdvancedSplitter
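The mechanics of `FlexibleSplitter` — regrouping the per-partner configuration into per-class columns, checking that no class is oversubscribed, then handing out non-overlapping slices — can be sketched on a toy dataset. This is a simplified sketch of the logic shown in the diff above, not the mplc implementation itself.

```python
import numpy as np

# Per-partner rows: fraction of each class's samples that partner gets.
configuration = [[0.5, 0.0],   # partner 1
                 [0.5, 1.0]]   # partner 2

# Regroup into per-class columns, as FlexibleSplitter does with
# list(zip(*configuration)).
per_cluster = list(zip(*configuration))  # [(0.5, 0.5), (0.0, 1.0)]

# Coherence check: no class may be split beyond 100% across partners.
for idx, cluster in enumerate(per_cluster):
    if np.sum(cluster) > 1:
        raise ValueError(f"class {idx} is oversubscribed")

# Toy "samples": class 0 is items 0-9, class 1 is items 10-19.
x_cluster = {0: list(range(10)), 1: list(range(10, 20))}
n_per_cluster = {label: len(s) for label, s in x_cluster.items()}

# Each partner takes its share from the front of what remains of each
# class, so the shares never overlap; the fraction always applies to the
# cluster's *original* size, as in _generate_subset.
partners = []
for p_idx in range(len(configuration)):
    taken = []
    for label, fracs in enumerate(per_cluster):
        n = int(n_per_cluster[label] * fracs[p_idx])
        taken += x_cluster[label][:n]
        x_cluster[label] = x_cluster[label][n:]
    partners.append(taken)

print(partners[0])  # partner 1: the first half of class 0
print(partners[1])  # partner 2: the rest of class 0 plus all of class 1
```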

tests/contrib_end_to_end_test.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def test_titanic_contrib(self):
2727

2828
df = test_utils.get_latest_dataframe("*end_to_end_test*")
2929

30-
# 2 contributivity methods X 2 parters x 2 repeats = 12
30+
# 2 contributivity methods X 2 partners x 2 repeats = 12
3131
assert len(df) == 12
3232

3333
def test_mnist_contrib(self):

tests/unit_tests.py

Lines changed: 43 additions & 5 deletions
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
This enables to parameterize unit tests - the tests are run by Travis each time you commit to the github repo
3+
This enables parameterizing the unit tests - the tests are run by GitHub Actions each time you commit to the GitHub repo
44
"""
55

66
#########
@@ -55,7 +55,7 @@
5555
from mplc.partner import Partner
5656
from mplc.scenario import Scenario
5757
# create_Mpl uses create_Dataset and create_Contributivity uses create_Scenario
58-
from mplc.splitter import AdvancedSplitter, RandomSplitter, StratifiedSplitter
58+
from mplc.splitter import FlexibleSplitter, AdvancedSplitter, RandomSplitter, StratifiedSplitter
5959

6060

6161
######
@@ -91,7 +91,12 @@ def create_MultiPartnerLearning(create_all_datasets):
9191
@pytest.fixture(scope="class", params=(RandomSplitter([0.1, 0.2, 0.3, 0.4]),
9292
StratifiedSplitter([0.1, 0.2, 0.3, 0.4]),
9393
AdvancedSplitter([0.3, 0.5, 0.2],
94-
[[4, "specific"], [6, "shared"], [4, "shared"]])))
94+
[[4, "specific"], [6, "shared"], [4, "shared"]]),
95+
FlexibleSplitter([1.0, 0.0, 0.0], [
96+
[0.33, 0.33, 0.33, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
97+
[0.33, 0.33, 0.33, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
98+
[0.33, 0.33, 0.33, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
99+
])))
95100
def create_splitter(request):
96101
return request.param()
97102

@@ -113,13 +118,15 @@ def create_Partner(create_all_datasets):
113118
['not-corrupted'] * 3),
114119
(Cifar10, "random", ['not-corrupted'] * 3),
115120
(Cifar10,
116-
AdvancedSplitter([0.3, 0.5, 0.2], [[4, "specific"], [6, "shared"], [4, "shared"]]),
121+
FlexibleSplitter([0.3, 0.5, 0.2], [[0.33, 0.33, 0.33, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
122+
[0.33, 0.33, 0.33, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
123+
[0.33, 0.33, 0.33, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0]]),
117124
['not-corrupted'] * 3)),
118125
ids=['Mnist - basic',
119126
'Mnist - basic - corrupted',
120127
'Mnist - advanced',
121128
'Cifar10 - basic',
122-
'Cifar10 - advanced'])
129+
'Cifar10 - flex'])
123130
def create_Scenario(request):
124131
dataset = request.param[0]()
125132
samples_split_option = request.param[1]
@@ -366,6 +373,37 @@ def test_advanced_splitter_local(self, create_all_datasets):
366373
with pytest.raises(Exception):
367374
splitter.split(partners_list, dataset)
368375

376+
def test_flexible_splitter_global(self, create_all_datasets):
377+
dataset = create_all_datasets
378+
splitter = FlexibleSplitter([0.3, 0.3, 0.4], configuration=[
379+
[0.33, 0.33, 0.33, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
380+
[0.33, 0.33, 0.33, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
381+
[0.33, 0.33, 0.33, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0]])
382+
partners_list = [Partner(i) for i in range(len(splitter.amounts_per_partner))]
383+
if dataset.num_classes == 10:
384+
splitter.split(partners_list, dataset)
385+
for p in partners_list:
386+
assert len(p.y_val) == 0, "validation set is not empty in spite of the val_set == 'global'"
387+
assert len(p.y_test) == 0, "test set is not empty in spite of test_set == 'global'"
388+
assert len(p.x_train) == len(p.y_train), 'numbers of samples and labels mismatch'
389+
assert len(p.labels) < dataset.num_classes, f'Partner {p.id} has all labels.'
390+
391+
def test_flexible_splitter_local(self, create_all_datasets):
392+
dataset = create_all_datasets
393+
splitter = FlexibleSplitter([0.3, 0.3, 0.4], configuration=[
394+
[0.33, 0.33, 0.33, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
395+
[0.33, 0.33, 0.33, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
396+
[0.33, 0.33, 0.33, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0]],
397+
val_set='local', test_set='local')
398+
partners_list = [Partner(i) for i in range(len(splitter.amounts_per_partner))]
399+
if dataset.num_classes == 10:
400+
splitter.split(partners_list, dataset)
401+
for p in partners_list:
402+
assert len(p.y_val) > 0, "validation set is empty in spite of the val_set == 'local'"
403+
assert len(p.y_test) > 0, "test set is empty in spite of test_set == 'local'"
404+
assert len(p.x_train) == len(p.y_train), 'numbers of samples and labels mismatch'
405+
assert len(p.labels) < dataset.num_classes, f'Partner {p.id} has all labels.'
406+
369407

370408
######
371409
#
