Skip to content

Commit c4f3b84

Browse files
authored
Fix custom metric miscalculation (#307)
* Fix custom metric miscalculation * add unit tests
1 parent 4a99a8a commit c4f3b84

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

src/sagemaker_xgboost_container/metrics/custom_metrics.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,12 @@ def compute_multiclass_and_binary_metrics(metricfunc, preds, dtrain):
239239

240240

241241
def get_custom_metrics(eval_metrics):
242-
"""Get container defined metrics from metrics list."""
243-
return set(eval_metrics).intersection(CUSTOM_METRICS.keys())
242+
"""Get container defined metrics from metrics list.
243+
244+
The order of the returning custom metrics need to be consistent with the input for distributed training.
245+
Otherwise, metrics reported from each host will be miscalculated in the master host. (P70679777)
246+
"""
247+
return [eval_m for eval_m in eval_metrics if eval_m in CUSTOM_METRICS.keys()]
244248

245249

246250
def configure_feval(custom_metric_list):

test/unit/algorithm_mode/test_custom_metrics.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
import numpy as np
1414
import xgboost as xgb
1515
from math import log, sqrt
16+
import unittest
1617
from sagemaker_xgboost_container.metrics.custom_metrics import accuracy, f1, mse, r2, f1_binary, f1_macro, \
1718
precision_macro, precision_micro, recall_macro, recall_micro, mae, rmse, balanced_accuracy, \
18-
precision, recall
19+
precision, recall, get_custom_metrics
1920

2021

2122
binary_train_data = np.random.rand(10, 2)
@@ -190,3 +191,10 @@ def test_mae():
190191
mae_score_name, mae_score_result = mae(regression_preds, regression_dtrain)
191192
assert mae_score_name == 'mae'
192193
assert mae_score_result == .5
194+
195+
196+
class TestCustomMetric(unittest.TestCase):
197+
def test_get_custom_metrics(self):
198+
eval_metrics = ["mse", "rmse", "mae", "r2", "wrong_metric"]
199+
res = get_custom_metrics(eval_metrics)
200+
self.assertListEqual(res, ["mse", "rmse", "mae", "r2"])

0 commit comments

Comments
 (0)