From d083c48164f3008f58a4fb799fe40551f65637c1 Mon Sep 17 00:00:00 2001 From: warren618 Date: Mon, 23 Mar 2026 02:07:40 +0800 Subject: [PATCH] fix(model): use `in` instead of `==` for metric comparison in HIST and IGMTF The metric_fn() in pytorch_hist.py and pytorch_igmtf.py compares self.metric against the tuple ("", "loss") using `==`, which only matches when self.metric is exactly that tuple. Since self.metric is always a string (e.g. "" or "loss"), this condition is never true, causing the function to raise ValueError("unknown metric") instead of returning the loss-based metric. All other models (ALSTM, GRU, LSTM, etc.) correctly use `in ("", "loss")`. Fixes the default metric path for HIST and IGMTF models. --- qlib/contrib/model/pytorch_hist.py | 2 +- qlib/contrib/model/pytorch_igmtf.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/qlib/contrib/model/pytorch_hist.py b/qlib/contrib/model/pytorch_hist.py index 779cde9c859..50ef64ec432 100644 --- a/qlib/contrib/model/pytorch_hist.py +++ b/qlib/contrib/model/pytorch_hist.py @@ -170,7 +170,7 @@ def metric_fn(self, pred, label): vy = y - torch.mean(y) return torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2))) - if self.metric == ("", "loss"): + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric) diff --git a/qlib/contrib/model/pytorch_igmtf.py b/qlib/contrib/model/pytorch_igmtf.py index 0bddc5a0f5f..1e8be1c8f3f 100644 --- a/qlib/contrib/model/pytorch_igmtf.py +++ b/qlib/contrib/model/pytorch_igmtf.py @@ -163,7 +163,7 @@ def metric_fn(self, pred, label): vy = y - torch.mean(y) return torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2))) - if self.metric == ("", "loss"): + if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) raise ValueError("unknown metric `%s`" % self.metric)