Skip to content

Commit 836a1c8

Browse files
authored
Merge pull request #318 from SubstraFoundation/aggregation_fix
Homogenize aggregation and aggregation_weighting
2 parents 5392596 + a4d2cdb commit 836a1c8

File tree

4 files changed

+11
-11
lines changed

4 files changed

+11
-11
lines changed

mplc/contributivity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -977,7 +977,7 @@ def PVRL(self, learning_rate):
977977
logger.info(f"Partners selected for the next epoch: {[p.id for p in mpl.partners_list]}")
978978

979979
# apply one epoch with the selected partner to the previous model/ do the action
980-
mpl.aggregator = self.scenario._aggregation(mpl) # we have to reset the weight of _aggregation
980+
mpl.aggregator = self.scenario._aggregation_weighting(mpl) # we have to reset the weight of _aggregation
981981
mpl.fit_epoch()
982982
loss = mpl.history.history['mpl_model']['val_loss'][mpl.epoch_index, -1]
983983
mpl.epoch_index += 1

mplc/multi_partner_learning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def __init__(self, scenario, **kwargs):
5454
self.minibatch_count = scenario.minibatch_count
5555
self.is_early_stopping = scenario.is_early_stopping
5656

57-
# Attributes related to the _aggregation approach
58-
self.aggregation_method = scenario._aggregation
57+
# Attributes related to the _aggregation_weighting approach
58+
self.aggregation_method = scenario._aggregation_weighting
5959

6060
# Attributes to store results
6161
self.save_folder = scenario.save_folder

mplc/scenario.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
params_known += [
9595
"contributivity_methods",
9696
"multi_partner_learning_approach",
97-
"aggregation",
97+
"aggregation_weighting",
9898
] # federated learning related
9999
params_known += [
100100
"partners_count",
@@ -232,9 +232,9 @@ def __init__(
232232

233233
# Define how federated learning aggregation steps are weighted...
234234
# ... Toggle between 'uniform' (default) and 'data_volume'
235-
self.aggregation = aggregation_weighting
235+
self.aggregation_weighting = aggregation_weighting
236236
try:
237-
self._aggregation = AGGREGATORS[aggregation_weighting]
237+
self._aggregation_weighting = AGGREGATORS[aggregation_weighting]
238238
except KeyError:
239239
raise ValueError(f"aggregation approach '{aggregation_weighting}' is not a valid approach. ")
240240

@@ -380,7 +380,7 @@ def copy(self, **kwargs):
380380
for key in ['partners_list',
381381
'mpl',
382382
'_multi_partner_learning_approach',
383-
'_aggregation',
383+
'_aggregation_weighting',
384384
'use_saved_weights',
385385
'contributivity_list',
386386
'scenario_name',
@@ -408,7 +408,7 @@ def log_scenario_description(self):
408408
logger.info(f" Number of partners defined: {self.partners_count}")
409409
logger.info(f" Data distribution scenario chosen: {self.splitter}")
410410
logger.info(f" Multi-partner learning approach: {self.multi_partner_learning_approach}")
411-
logger.info(f" Weighting option: {self.aggregation}")
411+
logger.info(f" Weighting option: {self.aggregation_weighting}")
412412
logger.info(f" Iterations parameters: "
413413
f"{self.epoch_count} epochs > "
414414
f"{self.minibatch_count} mini-batches > "
@@ -515,7 +515,7 @@ def to_dataframe(self):
515515

516516
# Multi-partner learning approach parameters
517517
dict_results["multi_partner_learning_approach"] = self.multi_partner_learning_approach
518-
dict_results["aggregation"] = self.aggregation
518+
dict_results["aggregation_weighting"] = self.aggregation_weighting
519519
dict_results["epoch_count"] = self.epoch_count
520520
dict_results["minibatch_count"] = self.minibatch_count
521521
dict_results["gradient_updates_per_pass_count"] = self.gradient_updates_per_pass_count

tests/unit_tests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def create_MultiPartnerLearning(create_all_datasets):
8080
epoch_count=2,
8181
minibatch_count=2,
8282
dataset=data,
83-
aggregation=UniformAggregator,
83+
aggregation_weighting=UniformAggregator,
8484
is_early_stopping=True,
8585
is_save_data=False,
8686
)
@@ -137,7 +137,7 @@ def create_Scenario(request):
137137
{
138138
"contributivity_methods": ["Shapley values", "Independent scores"],
139139
"multi_partner_learning_approach": "fedavg",
140-
"aggregation": "uniform",
140+
"aggregation_weighting": "uniform",
141141
}
142142
)
143143
params.update(

0 commit comments

Comments
 (0)