diff --git a/metrics/mean_iou/mean_iou.py b/metrics/mean_iou/mean_iou.py index 4c19864d..8fcb2b29 100644 --- a/metrics/mean_iou/mean_iou.py +++ b/metrics/mean_iou/mean_iou.py @@ -256,8 +256,18 @@ def mean_iou( metrics = dict() all_acc = total_area_intersect.sum() / total_area_label.sum() - iou = total_area_intersect / total_area_union - acc = total_area_intersect / total_area_label + iou = np.divide( + total_area_intersect, + total_area_union, + out=np.zeros_like(total_area_intersect), + where=total_area_union != 0, + ) + acc = np.divide( + total_area_intersect, + total_area_label, + out=np.zeros_like(total_area_intersect), + where=total_area_label != 0, + ) metrics["mean_iou"] = np.nanmean(iou) metrics["mean_accuracy"] = np.nanmean(acc)