From 949201a2f3a84570a9719bec526c2e9632bff05d Mon Sep 17 00:00:00 2001 From: Aaryaman3 Date: Mon, 15 Dec 2025 01:25:42 -0500 Subject: [PATCH] Add zero_division parameter to F1 metric Adds zero_division parameter to F1._compute() to match sklearn.metrics.f1_score interface. Controls behavior when precision/recall denominators are zero. Default value is 0 for backward compatibility. --- metrics/f1/f1.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/metrics/f1/f1.py b/metrics/f1/f1.py index 05b0baad..46a6f9ac 100644 --- a/metrics/f1/f1.py +++ b/metrics/f1/f1.py @@ -39,6 +39,11 @@ - 'weighted': Calculate metrics for each label, and find their average weighted by support (the number of true instances for each label). This alters `'macro'` to account for label imbalance. This option can result in an F-score that is not between precision and recall. - 'samples': Calculate metrics for each instance, and find their average (only meaningful for multilabel classification). sample_weight (`list` of `float`): Sample weights Defaults to None. + zero_division (`int` or `str`): Sets the value to return when there is a zero division. Defaults to 0. + + - 0: Returns 0 when there is a zero division. + - 1: Returns 1 when there is a zero division. + - "warn": Raises a warning and returns 0 when there is a zero division. Returns: f1 (`float` or `array` of `float`): F1 score or list of f1 scores, depending on the value passed to `average`. Minimum possible value is 0. Maximum possible value is 1. Higher f1 scores are better. @@ -123,8 +128,8 @@ def _info(self): reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"], ) - def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", sample_weight=None): + def _compute(self, predictions, references, labels=None, pos_label=1, average="binary", sample_weight=None, zero_division=0): score = f1_score( - references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight + references, predictions, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight, zero_division=zero_division ) return {"f1": score if getattr(score, "size", 1) > 1 else float(score)}