diff --git a/qlib/data/ops.py b/qlib/data/ops.py index d9a2ffbb3e3..671948d98cf 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -1337,7 +1337,7 @@ def _load_internal(self, instrument, start_index, end_index, *args): def weighted_mean(x): w = np.arange(len(x)) + 1 w = w / w.sum() - return np.nanmean(w * x) + return np.nansum(w * x) if self.N == 0: series = series.expanding(min_periods=1).apply(weighted_mean, raw=True) diff --git a/tests/ops/test_rolling_ops.py b/tests/ops/test_rolling_ops.py new file mode 100644 index 00000000000..a6e46600b43 --- /dev/null +++ b/tests/ops/test_rolling_ops.py @@ -0,0 +1,20 @@ +import numpy as np +import pandas as pd + +from qlib.data.ops import WMA + + +class _Feature: + def __init__(self, series): + self.series = series + + def load(self, *args): + return self.series + + +def test_wma_uses_weighted_sum(): + series = pd.Series([1.0, 2.0, 3.0]) + + result = WMA(_Feature(series), 3)._load_internal("SH600000", 0, 2) + + np.testing.assert_allclose(result.to_numpy(), [1.0, 5.0 / 3.0, 14.0 / 6.0])