diff --git a/README.md b/README.md
index cc2bf5f..6cf2540 100644
--- a/README.md
+++ b/README.md
@@ -1,129 +1,195 @@

FISBe: A real-world benchmark dataset for instance segmentation of long-range thin filamentous structures

- ![Alt text](assets/Cover_Image.png) ## About ⚠️ *Currently under construction.* -This is the official implementation of the **FISBe (FlyLight Instance Segmentation Benchmark)** -evaluation pipeline, the first publicly available multi-neuron light microscopy dataset with -pixel-wise annotations. +This is the official implementation of the **FISBe (FlyLight Instance Segmentation Benchmark)** evaluation pipeline. It is the first publicly available multi-neuron light microscopy dataset with pixel-wise annotations. -You can download the dataset on the official project page: +**Download the dataset:** [https://kainmueller-lab.github.io/fisbe/](https://kainmueller-lab.github.io/fisbe/) -👉 https://kainmueller-lab.github.io/fisbe/ +The benchmark supports 2D and 3D segmentations and computes a wide range of commonly used evaluation metrics (e.g., AP, F1, coverage). Crucially, it provides specialized error attribution for topological errors (False Merges, False Splits) relevant to filamentous structures. -The benchmark supports 2D and 3D segmentations and computes a wide range of commonly used evaluation -metrics (e.g., AP, F1, coverage, precision, recall). Additionally, it provides a visualization -of segmentation errors. +### Features +- **Standard Metrics:** AP, F1, Precision, Recall. +- **FISBe Metrics:** Greedy many-to-many matching for False Merges (FM) and False Splits (FS). +- **Flexibility:** Supports HDF5 (`.hdf`, `.h5`) and Zarr (`.zarr`) files. +- **Modes:** Run on single files, entire folders, or in stability analysis mode. +- **Partly Labeled Data:** Robust evaluation ignoring background conflicts for sparse Ground Truth. -Overview: -------------- -This toolkit provides: +--- -- Standard and FlyLight-specific evaluation metrics -- Error attribution (false merges, splits, FP/FN instances) -- Visualizations for neurons and nuclei -- Support for partially annotated datasets -- Coverage metrics (skeleton, dimension-based, overlap-based) -- Command-line and Python API usage +## Installation -------------- +The recommended way to install is using `uv` (fastest) or `micromamba`. -Installation: -------------- -The recommended way is to install it into your micromamba/python virtual environment. +### Option 1: Using `uv` (Fastest) ```bash -git clone https://github.com/Kainmueller-Lab/evaluate-instance-segmentation +# 1. Install uv (if not installed) +pip install uv + +# 2. Clone and install +git clone https://github.com/Kainmueller-Lab/evaluate-instance-segmentation.git cd evaluate-instance-segmentation +uv venv +uv pip install -e . +``` -micromamba create -n evalinstseg -f environment.yml +### Option 2: Using `micromamba` or `conda` + +```bash +micromamba create -n evalinstseg python=3.10 micromamba activate evalinstseg +git clone https://github.com/Kainmueller-Lab/evaluate-instance-segmentation.git +cd evaluate-instance-segmentation pip install -e . ``` -## Run Benchmark +## Usage: Command Line (CLI) +The `evalinstseg` command is automatically available after installation. -You can use this repository in two ways: +### 1. Evaluate a Single File +```bash +evalinstseg \ + --res_file tests/pred/R14A02-20180905.hdf \ + --res_key volumes/labels \ + --gt_file tests/gt/R14A02-20180905.zarr \ + --gt_key volumes/gt_instances \ + --out_dir tests/results \ + --app flylight +``` -1. As a Python package (via `evaluate_file` / `evaluate_volume`) -2. From the command line +### 2. 
Evaluate an Entire Folder +If you provide a directory path to `--res_file`, the tool will look for matching Ground Truth files in the `--gt_file` folder. Files are matched by name. -Example: ```bash evalinstseg \ - --res_file tests/pred/R14A02-20180905_65_A6.hdf \ - --res_key volumes/gmm_label_cleaned \ - --gt_file tests/gt/R14A02-20180905_65_A6.zarr \ + --res_file /path/to/predictions_folder \ + --res_key volumes/labels \ + --gt_file /path/to/ground_truth_folder \ --gt_key volumes/gt_instances \ - --out_dir tests/results \ + --out_dir /path/to/output_folder \ + --app flylight +``` + +### 3. Stability & Robustness Mode +Compute the **Mean ± Std** of metrics across exactly 3 different training runs (e.g., different random seeds). + +```bash +evalinstseg \ + --stability_mode \ + --run_dirs experiments/seed1 experiments/seed2 experiments/seed3 \ + --gt_file data/ground_truth_folder \ + --out_dir results/stability_report \ --app flylight ``` -By setting `--app flylight`, the pipeline automatically uses the default FlyLight benchmark configuration. - -You can also define custom configurations, including: -- localization criteria -- assignment strategy -- metric subsets - -Output: - -- evaluation metrics are written to toml-file and returned as dict - - -## Metrics Overview: -The evaluation computes metrics at multiple levels: per-threshold instance metrics, aggregated AP/F-scores, and global statistics - -### Instance-Level Metrics (per threshold confusion_matrix.th_*) -| Metric | Description | -| --------------------------- | ---------------------------------------- | -| **AP_TP** | True positives at threshold | -| **AP_FP** | False positives at threshold | -| **AP_FN** | False negatives at threshold | -| **precision** | TP / (TP + FP) | -| **recall** | TP / (TP + FN) | -| **AP** | Approximate AP proxy: precision × recall | -| **fscore** | Harmonic mean of precision and recall | -| *optional:* **false_split** | Number of false splits | -| *optional:* **false_merge** | Number of false merges | - -Metrics are computed for thresholds: -0.1, 0.2, ..., 0.9, 0.55, 0.65, 0.75, 0.85, 0.95. 
-
-### Aggregate Metrics
-| Metric | Description |
-| -------------- | ----------------------------------- |
-| **avAP** | Mean AP for thresholds ≥ 0.5 |
-| **avAP59** | AP averaged over thresholds 0.5–0.9 |
-| **avAP19** | AP averaged over thresholds 0.1–0.9 |
-| **avFscore** | Mean F-score for thresholds 0.1–0.9 |
-| **avFscore59** | Mean F-score for thresholds 0.5–0.9 |
-| **avFscore19** | Mean F-score for thresholds 0.1–0.9 |
-
-### General Metrics
-| Metric | Description |
-| ---------------------- | ------------------------------------------------ |
-| **Num GT** | Number of ground-truth instances |
-| **Num Pred** | Number of predicted instances |
-| **TP_05** | True positives at threshold 0.5 |
-| **TP_05_rel** | TP_05 / Num GT |
-| **TP_05_cldice** | clDice scores of matched pairs at threshold 0.5 |
-| **avg_TP_05_cldice** | Mean clDice over matched pairs at threshold 0.5 |
-
-### Optional General Metrics
-| Metric | Description |
-| ----------------------- | ----------- |
-| **avg_gt_skel_coverage** | Mean skeleton coverage over all GT instances |
-| **avg_tp_skel_coverage** | Mean skeleton coverage over TP GT instances (> 0.5) |
-| **avg_f1_cov_score** | 0.5 × avFscore19 + 0.5 × avg_gt_skel_coverage |
-| **FM** | Many-to-many false merge score (threshold `fm_thresh`) |
-| **FS** | Many-to-many false split score (threshold `fs_thresh`) |
-| **avg_gt_cov_dim** | Mean GT coverage for “dim” instances |
-| **avg_gt_cov_overlap** | Mean GT coverage for overlapping-instance regions |
+**Requirements:**
+- `--run_dirs`: Provide exactly 3 folders.
+- `--gt_file`: The folder containing Ground Truth files (filenames must match predictions).
+
+### 4. Partly Labeled Data
+If your ground truth is sparse (not fully dense), use the `--partly` flag. The exact false-positive counting rule is described in the section "Partly Labeled Data (`--partly`)" below.
+
+## Usage: Python Package
+You can integrate the benchmark directly into your Python scripts or notebooks.
+
+### Evaluate a File
+```python
+from evalinstseg import evaluate_file
+
+# Run evaluation
+metrics = evaluate_file(
+    res_file="tests/pred/sample_01.hdf",
+    gt_file="tests/gt/sample_01.zarr",
+    res_key="volumes/labels",
+    gt_key="volumes/gt_instances",
+    out_dir="output_folder",
+    ndim=3,
+    app="flylight",  # Applies default FISBe config
+    partly=False     # Set True for sparse GT
+)
+
+# Access metrics directly
+print("AP:", metrics['confusion_matrix']['avAP'])
+print("False Merges:", metrics['general']['FM'])
+```
+### Evaluate Raw Numpy Arrays
+If you already have the arrays loaded in memory:
+
+```python
+import numpy as np
+from evalinstseg import evaluate_volume
+
+pred_array = np.load(...)  # Shape: (Z, Y, X)
+gt_array = np.load(...)
+
+metrics = evaluate_volume(
+    gt_labels=gt_array,
+    pred_labels=pred_array,
+    ndim=3,
+    outFn="output_path_prefix",
+    localization_criterion="cldice",  # or 'iou'
+    assignment_strategy="greedy",
+    add_general_metrics=["false_merge", "false_split"]
+)
+```
+### Partly Labeled Data (`--partly`)
+Some samples contain sparse / incomplete GT annotations. In this setting, counting all unmatched predictions as false positives is not meaningful.
+
+When `--partly` is enabled, we approximate FP by counting only **unmatched predictions whose best match is a foreground GT instance** (based on the localization matrix used for evaluation, e.g. clPrecision for `cldice`).
+Unmatched predictions whose best match is **background** are ignored.
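+As a minimal sketch of this rule (the helper name `count_partly_fp` is hypothetical and not part of the package; it assumes an overlap matrix laid out as `precMat[gt_id, pred_id]` with id 0 = background), see below; the next paragraph states the same rule in words:
+
+```python
+import numpy as np
+
+def count_partly_fp(precMat, unmatched_pred_ids):
+    """Hypothetical illustration of the --partly FP rule, not the package API."""
+    # precMat[gt_id, pred_id]: overlap scores (e.g. clPrecision); index 0 = background
+    fp = 0
+    for pred_id in unmatched_pred_ids:
+        best_gt = int(np.argmax(precMat[:, pred_id]))  # GT label with maximal overlap
+        if best_gt > 0:  # best match is a foreground GT instance -> counted as FP
+            fp += 1
+        # best_gt == 0: best match is background -> ignored
+    return fp
+```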
+ +Concretely, we compute for each unmatched prediction the index of the GT label with maximal overlap score; it is counted as FP only if that index is > 0 (foreground), not 0 (background). + +--- + +## Metrics Explanation + +### 1. Standard Instance Metrics (TP/FP/FN, F-score, AP proxy) +These metrics are computed from a **one-to-one matching** between GT and prediction instances (Hungarian or greedy), using a chosen localization criterion (default for FlyLight is `cldice`). + +- **TP**: matched pairs above threshold +- **FP**: unmatched predictions (or, in `--partly`, only those whose best match is foreground) +- **FN**: unmatched GT instances +- **precision** = TP / (TP + FP) +- **recall** = TP / (TP + FN) +- **fscore** = 2 * precision * recall / (precision + recall) +- **AP**: we report a simple AP proxy `precision × recall` at each threshold and average it across thresholds (this is not COCO-style AP). + +### 2. FISBe Error Attribution (False Splits / False Merges) +False splits (FS) and false merges (FM) aim to quantify **instance topology errors** for long-range thin filamentous structures. + +We compute FS/FM using **greedy many-to-many matching with consumption**: +- Candidate GT–Pred pairs above threshold are processed in descending score order. +- After selecting a match, we update “available” pixels so that already explained structure is not matched again. +- FS counts when one GT is explained by multiple preds (excess preds per GT). +- FM counts when one pred explains multiple GTs (excess GTs per pred). + +This produces an explicit attribution of split/merge errors rather than only TP/FP/FN. + +### Metric Definitions + +#### Instance-Level (per threshold) +| Metric | Description | +| :--- | :--- | +| **AP_TP** | True Positives (1-to-1 match) | +| **AP_FP** | False Positives (unmatched preds; in `--partly`: only unmatched preds whose best match is foreground) | +| **AP_FN** | False Negatives (unmatched GT) | +| **precision** | TP / (TP + FP) | +| **recall** | TP / (TP + FN) | +| **fscore** | Harmonic mean of precision and recall | + +#### Global / FISBe +| Metric | Description | +| :--- | :--- | +| **avAP** | Mean AP proxy across thresholds ≥ 0.5 | +| **FM** | False Merges (many-to-many matching with consumption) | +| **FS** | False Splits (many-to-many matching with consumption) | +| **avg_gt_skel_coverage** | Mean skeleton coverage of GT instances by associated predictions (association via best-match mapping) | diff --git a/evalinstseg/compute.py b/evalinstseg/compute.py index a344195..689d325 100644 --- a/evalinstseg/compute.py +++ b/evalinstseg/compute.py @@ -7,7 +7,7 @@ get_centerline_overlap_single, get_centerline_overlap, ) -from .match import assign_labels, greedy_many_to_many_matching +from .match import assign_labels, greedy_many_to_many_matching, get_m2m_matches logger = logging.getLogger(__name__) @@ -160,34 +160,45 @@ def get_gt_coverage_overlap( return gt_ovlp, tp_05_ovlp, tp_05_rel_ovlp, gt_covs_ovlp, avg_cov_ovlp -def get_m2m_fm( - gt_labels, pred_labels, num_pred_labels, recallMat, fm_thresh, matches=None -): - # get false merges - if matches is None: - # call many-to-many matching based on clRecall - matches = greedy_many_to_many_matching( - gt_labels, pred_labels, recallMat, fm_thresh - ) +def compute_m2m_stats(matches, num_pred_labels): + '''Helper to compute false merges and false splits from many-to-many matches.''' + fm = 0 - if matches is not None: - fms = np.zeros(num_pred_labels) # without 0 background - for k, v in matches.items(): - for cv in v: - 
fms[cv - 1] += 1 - fms = np.maximum(fms - 1, np.zeros(num_pred_labels)) - fm = int(np.sum(fms)) - return fm, matches - - -def get_m2m_fs(gt_labels, pred_labels, recallMat, fs_thresh, matches=None): - # get false splits - if matches is None: - matches = greedy_many_to_many_matching( - gt_labels, pred_labels, recallMat, fs_thresh - ) fs = 0 if matches is not None: + # FS calculation for k, v in matches.items(): fs += max(0, len(v) - 1) - return fs, matches + + # FM calculation + if num_pred_labels > 0: + fms = np.zeros(num_pred_labels) # without 0 background + for k, v in matches.items(): + for cv in v: + fms[cv - 1] += 1 + fms = np.maximum(fms - 1, np.zeros(num_pred_labels)) + fm = int(np.sum(fms)) + + return fm, fs + + +def get_m2m_metrics(gt_labels, pred_labels, num_pred_labels, matchMat, thresh, overlaps=True): + """ + Compute false merge and false split metrics for any localization criterion using many-to-many matching. + + Args: + gt_labels: Ground truth labels + pred_labels: Predicted labels + num_pred_labels: Number of predicted labels + matchMat: Recall matrix for clDice (for IoU matrix appropriate m2m matching is needed) + thresh: Threshold for matching + overlaps: Whether to allow overlapping instances + + Returns: + Tuple of (false_merge, false_split, matches) + """ + matches = get_m2m_matches( + matchMat, thresh, gt_labels, pred_labels, overlaps + ) + fm, fs = compute_m2m_stats(matches, num_pred_labels) + return fm, fs, matches diff --git a/evalinstseg/evaluate.py b/evalinstseg/evaluate.py index 20af541..46b0fc7 100644 --- a/evalinstseg/evaluate.py +++ b/evalinstseg/evaluate.py @@ -21,8 +21,7 @@ get_gt_coverage, get_gt_coverage_dim, get_gt_coverage_overlap, - get_m2m_fm, - get_m2m_fs, + get_m2m_metrics, ) from .visualize import visualize_neurons, visualize_nuclei from .summarize import ( @@ -42,7 +41,7 @@ def evaluate_file( res_key=None, gt_key=None, suffix="", - localization_criterion="iou", # "iou", "cldice" + localization_criterion="cldice", # "iou", "cldice" assignment_strategy="greedy", # "hungarian", "greedy", "gt_0_5" add_general_metrics=[], visualize=False, @@ -99,10 +98,6 @@ def evaluate_file( remove_small_components, ) - # if from_scratch is set, overwrite existing evaluation files - # otherwise try to load precomputed metrics - # if check_for_metric is None, just check if matching file exists - # otherwise check if check_for_metric is contained within file if not from_scratch and len(glob.glob(outFn + ".toml")) > 0: with open(outFn + ".toml", "r") as tomlFl: metricsDict = toml.load(tomlFl) @@ -175,7 +170,7 @@ def evaluate_volume( pred_labels, ndim, outFn, - localization_criterion="iou", + localization_criterion="cldice", assignment_strategy="hungarian", evaluate_false_labels=False, add_general_metrics=[], @@ -221,6 +216,11 @@ def evaluate_volume( gt_labels, pred_labels, remove_small_components, foreground_only ) + # Check for overlapping instances + gt_overlaps = np.any(np.sum(gt_labels_rel > 0, axis=0) > 1) + pred_overlaps = np.any(np.sum(pred_labels_rel > 0, axis=0) > 1) + overlaps = gt_overlaps or pred_overlaps + logger.debug( "are there pixels with multiple instances?: " f"{np.sum(np.sum(gt_labels_rel > 0, axis=0) > 1)}" @@ -231,7 +231,7 @@ def evaluate_volume( num_gt_labels = int(np.max(gt_labels_rel)) num_matches = min(num_gt_labels, num_pred_labels) - # get localization criterion -> TODO: check: do we still need recallMat_wo_overlap? 
+ # get localization criterion locMat, recallMat, precMat, recallMat_wo_overlap = compute_localization_criterion( pred_labels_rel, gt_labels_rel, @@ -273,7 +273,6 @@ def evaluate_volume( gt_ind, num_pred_labels, num_gt_labels, - locMat, precMat, recallMat, th, @@ -391,21 +390,43 @@ def evaluate_volume( avg_f1_cov_score = 0.5 * avFscore19 + 0.5 * gt_skel_coverage metrics.addMetric(tblNameGen, "avg_f1_cov_score", avg_f1_cov_score) - # TODO: rename "false_merge" and "false_splits" to sth with many-to-many? - m2m_matches = None - if "false_merge" in add_general_metrics: - fm, m2m_matches = get_m2m_fm( - gt_labels_rel, pred_labels_rel, num_pred_labels, recallMat, fm_thresh - ) - metrics.addMetric("general", "FM", fm) - if "false_split" in add_general_metrics: - # if fm and fs thresh are different, reset matches, reuse otherwise - if fm_thresh != fs_thresh: - m2m_matches = None - fs, _ = get_m2m_fs( - gt_labels_rel, pred_labels_rel, recallMat, fs_thresh, m2m_matches - ) - metrics.addMetric("general", "FS", fs) + if "false_merge" in add_general_metrics or "false_split" in add_general_metrics: + if fm_thresh == fs_thresh: + # Optimized path: Same threshold, compute both at once + fm, fs, _ = get_m2m_metrics( + gt_labels_rel, + pred_labels_rel, + num_pred_labels, + recallMat, + fm_thresh, + overlaps=overlaps, + ) + if "false_merge" in add_general_metrics: + metrics.addMetric("general", "FM", fm) + if "false_split" in add_general_metrics: + metrics.addMetric("general", "FS", fs) + else: + # Different thresholds, compute separately + if "false_merge" in add_general_metrics: + fm, _, _ = get_m2m_metrics( + gt_labels_rel, + pred_labels_rel, + num_pred_labels, + recallMat, + fm_thresh, + overlaps=overlaps, + ) + metrics.addMetric("general", "FM", fm) + if "false_split" in add_general_metrics: + _, fs, _ = get_m2m_metrics( + gt_labels_rel, + pred_labels_rel, + num_pred_labels, + recallMat, + fs_thresh, + overlaps=overlaps, + ) + metrics.addMetric("general", "FS", fs) if "avg_gt_cov_dim" in add_general_metrics: gt_dim, tp_05_dim, tp_05_rel_dim, gt_covs_dim, avg_cov_dim = ( @@ -463,7 +484,13 @@ def main(): parser = argparse.ArgumentParser() # input output parser.add_argument( - "--res_file", nargs="+", type=str, help="path to result file", required=True + "--stability_mode", action="store_true", help="Run 3x stability evaluation" + ) + parser.add_argument( + "--run_dirs", nargs="+", type=str, help="List of 3 experiment directories" + ) + parser.add_argument( + "--res_file", nargs="+", type=str, help="path to result file" ) parser.add_argument( "--gt_file", @@ -603,193 +630,230 @@ def main(): logger.debug("arguments %s", tuple(sys.argv)) args = parser.parse_args() - # shortcut if res_file and gt_file contain folders - if len(args.res_file) == 1 and len(args.gt_file) == 1: - res_file = args.res_file[0] - gt_file = args.gt_file[0] - if (os.path.isdir(res_file) and not res_file.endswith(".zarr")) and ( - os.path.isdir(gt_file) and not gt_file.endswith(".zarr") + def get_gt_file(in_fn, gt_folder): + """Helper to get gt file corresponding to input result file.""" + out_fn = os.path.join( + gt_folder, os.path.basename(in_fn).split(".")[0] + ".zarr" + ) + return out_fn + + def _run_loop(res_files, gt_files, out_dirs, partly_list_loc): + """Core evaluation loop used in normal and stability mode.""" + + loop_samples = [] + loop_metrics = [] + for res_file, gt_file, partly, out_dir in zip( + res_files, gt_files, partly_list_loc, out_dirs ): - args.res_file = natsorted(glob.glob(res_file + "/*.hdf")) + if not 
os.path.exists(out_dir): os.makedirs(out_dir, exist_ok=True) + + sample_name = os.path.basename(res_file).split(".")[0] + logger.info("sample_name: %s", sample_name) + + metric_dict = evaluate_file( + res_file, + gt_file, + args.ndim, + out_dir, + res_key=args.res_key, + gt_key=args.gt_key, + suffix=args.suffix, + localization_criterion=args.localization_criterion, + assignment_strategy=args.assignment_strategy, + add_general_metrics=args.add_general_metrics, + visualize=args.visualize, + visualize_type=args.visualize_type, + partly=partly, + foreground_only=args.foreground_only, + remove_small_components=args.remove_small_components, + evaluate_false_labels=args.evaluate_false_labels, + fm_thresh=args.fm_thresh, + fs_thresh=args.fs_thresh, + from_scratch=args.from_scratch, + eval_dim=args.eval_dim, + debug=args.debug, + ) + loop_metrics.append(metric_dict) + loop_samples.append(sample_name) + print(f"Evaluated {sample_name}: {metric_dict}") + + return loop_metrics, loop_samples + + # Stability Mode (Wraps logic 3 times) + if args.stability_mode: + if not args.run_dirs or len(args.run_dirs) != 3: + raise ValueError("Stability mode requires exactly 3 directories passed to --run_dirs") + + stability_scores = [] + print("--- EVALUTE USING STABILITY MODE ---") + + for run_idx, run_dir in enumerate(args.run_dirs): + print(f"Processing Run {run_idx+1}: {run_dir}") + + # Auto-detect files for this run + run_res_files = natsorted(glob.glob(run_dir + "/*.hdf")) + if not run_res_files: + run_res_files = natsorted(glob.glob(run_dir + "/*.zarr")) + + # Assume gt_file is the PARENT folder + run_gt_files = [get_gt_file(fn, args.gt_file[0]) for fn in run_res_files] + run_out_dirs = [os.path.join(args.out_dir[0], f"seed_{run_idx+1}")] * len(run_res_files) + + # Run the inner loop + m_dicts, s_names = _run_loop(run_res_files, run_gt_files, run_out_dirs, [args.partly]*len(run_res_files)) + + # Aggregate just this run + metrics_full = {s: m for m, s in zip(m_dicts, s_names) if m is not None} + acc, _ = average_flylight_score_over_instances(s_names, metrics_full) + stability_scores.append(acc) + + # Print Average and Std Dev across runs + print("\n=== FISBe BENCHMARK RESULTS (Mean ± Std) ===") + if stability_scores: + for key in stability_scores[0].keys(): + values = [s[key] for s in stability_scores if key in s] + if len(values) == 3: + print(f"{key:<30}: {np.mean(values):.4f} ± {np.std(values):.4f}") + + # Normal Mode + else: + print("--- EVALUTE USING SINGLE DIR ---") + # shortcut if res_file and gt_file contain folders + if len(args.res_file) == 1 and len(args.gt_file) == 1: + res_file = args.res_file[0] + gt_file = args.gt_file[0] + if (os.path.isdir(res_file) and not res_file.endswith(".zarr")) and ( + os.path.isdir(gt_file) and not gt_file.endswith(".zarr") + ): + args.res_file = natsorted(glob.glob(res_file + "/*.hdf")) + args.gt_file = [get_gt_file(fn, gt_file) for fn in args.res_file] + + # check same length for result and gt files + assert len(args.res_file) == len(args.gt_file), ( + "Please check, not the same number of result and gt files" + ) + # set partly parameter for all samples if not done already + if len(args.res_file) > 1: + if args.partly_list is not None: + assert len(args.partly_list) == len(args.res_file), ( + "Please check, not the same number of result files " + "and partly_list values" + ) + partly_list = np.array(args.partly_list, dtype=bool) + else: + partly_list = [args.partly] * len(args.res_file) + else: + partly_list = [args.partly] - def get_gt_file(in_fn, gt_folder): - 
out_fn = os.path.join( - gt_folder, os.path.basename(in_fn).split(".")[0] + ".zarr" + # check out_dir + if len(args.res_file) > 1: + if len(args.out_dir) > 1: + assert len(args.res_file) == len(args.out_dir), ( + "Please check, number of input files and output folders should correspond" ) - assert ( - os.path.basename(in_fn).split(".")[0] - == os.path.basename(out_fn).split(".")[0] + outdir_list = args.out_dir + else: + outdir_list = args.out_dir * len(args.res_file) + else: + assert len(args.out_dir) == 1, "Please check number of output directories" + outdir_list = args.out_dir + # check output dir for summary + if args.summary_out_dir is None: + args.summary_out_dir = args.out_dir[0] + + if args.app is not None: + if args.app == "flylight": + print( + "Warning: parameter app is set and will overwrite parameters. " + "This might not be what you want." ) - return out_fn - - args.gt_file = [get_gt_file(fn, gt_file) for fn in args.res_file] - - # check same length for result and gt files - assert len(args.res_file) == len(args.gt_file), ( - "Please check, not the same number of result and gt files" - ) - # set partly parameter for all samples if not done already - if len(args.res_file) > 1: - if args.partly_list is not None: - assert len(args.partly_list) == len(args.res_file), ( - "Please check, not the same number of result files " - "and partly_list values" + args.ndim = 3 + args.localization_criterion = "cldice" + args.assignment_strategy = "greedy" + args.remove_small_components = 800 + # args.evaluate_false_labels = True + args.metric = "general.avg_f1_cov_score" + args.add_general_metrics = [ + "avg_gt_skel_coverage", + "avg_f1_cov_score", + "false_merge", + "false_split", + "avg_gt_cov_dim", + "avg_gt_cov_overlap", + ] + args.summary = [ + "general.Num GT", + "general.Num Pred", + "general.avg_f1_cov_score", + "confusion_matrix.avFscore", + "general.avg_gt_skel_coverage", + "confusion_matrix.th_0_5.AP_TP", + "confusion_matrix.th_0_5.AP_FP", + "confusion_matrix.th_0_5.AP_FN", + "general.FM", + "general.FS", + "general.TP_05", + "general.TP_05_rel", + "general.avg_TP_05_cldice", + "general.GT_dim", + "general.TP_05_dim", + "general.TP_05_rel_dim", + "general.avg_gt_cov_dim", + "general.GT_overlap", + "general.TP_05_overlap", + "general.TP_05_rel_overlap", + "general.avg_gt_cov_overlap", + "confusion_matrix.th_0_1.fscore", + "confusion_matrix.th_0_2.fscore", + "confusion_matrix.th_0_3.fscore", + "confusion_matrix.th_0_4.fscore", + "confusion_matrix.th_0_5.fscore", + "confusion_matrix.th_0_6.fscore", + "confusion_matrix.th_0_7.fscore", + "confusion_matrix.th_0_8.fscore", + "confusion_matrix.th_0_9.fscore", + ] + args.visualize_type = "neuron" + args.fm_thresh = 0.1 + args.fs_thresh = 0.05 + args.eval_dim = True + + metric_dicts, samples = _run_loop(args.res_file, args.gt_file, outdir_list, partly_list) + + # aggregate over instances + metrics_full = {} + acc_all_instances = None + for metric_dict, sample in zip(metric_dicts, samples): + if metric_dict is None: + continue + metrics_full[sample] = metric_dict + if len(np.unique(partly_list)) > 1: + print("averaging for combined") + # get average over instances for completely + samples = np.array(samples) + acc_cpt, acc_inst_cpt = average_flylight_score_over_instances( + samples[partly_list == False], metrics_full ) - partly_list = np.array(args.partly_list, dtype=bool) - else: - partly_list = [args.partly] * len(args.res_file) - else: - partly_list = [args.partly] - - # check out_dir - if len(args.res_file) > 1: - if len(args.out_dir) > 1: 
- assert len(args.res_file) == len(args.out_dir), ( - "Please check, number of input files and output folders should correspond" + acc_prt, acc_inst_prt = average_flylight_score_over_instances( + samples[partly_list == True], metrics_full ) - outdir_list = args.out_dir - else: - outdir_list = args.out_dir * len(args.res_file) - else: - assert len(args.out_dir) == 1, "Please check number of output directories" - outdir_list = args.out_dir - # check output dir for summary - if args.summary_out_dir is None: - args.summary_out_dir = args.out_dir[0] - - if args.app is not None: - if args.app == "flylight": - print( - "Warning: parameter app is set and will overwrite parameters. " - "This might not be what you want." + acc, acc_all_instances = average_sets( + acc_cpt, acc_inst_cpt, acc_prt, acc_inst_prt ) - args.ndim = 3 - args.localization_criterion = "cldice" - args.assignment_strategy = "greedy" - args.remove_small_components = 800 - # args.evaluate_false_labels = True - args.metric = "general.avg_f1_cov_score" - args.add_general_metrics = [ - "avg_gt_skel_coverage", - "avg_f1_cov_score", - "false_merge", - "false_split", - "avg_gt_cov_dim", - "avg_gt_cov_overlap", - ] - args.summary = [ - "general.Num GT", - "general.Num Pred", - "general.avg_f1_cov_score", - "confusion_matrix.avFscore", - "general.avg_gt_skel_coverage", - "confusion_matrix.th_0_5.AP_TP", - "confusion_matrix.th_0_5.AP_FP", - "confusion_matrix.th_0_5.AP_FN", - "general.FM", - "general.FS", - "general.TP_05", - "general.TP_05_rel", - "general.avg_TP_05_cldice", - "general.GT_dim", - "general.TP_05_dim", - "general.TP_05_rel_dim", - "general.avg_gt_cov_dim", - "general.GT_overlap", - "general.TP_05_overlap", - "general.TP_05_rel_overlap", - "general.avg_gt_cov_overlap", - "confusion_matrix.th_0_1.fscore", - "confusion_matrix.th_0_2.fscore", - "confusion_matrix.th_0_3.fscore", - "confusion_matrix.th_0_4.fscore", - "confusion_matrix.th_0_5.fscore", - "confusion_matrix.th_0_6.fscore", - "confusion_matrix.th_0_7.fscore", - "confusion_matrix.th_0_8.fscore", - "confusion_matrix.th_0_9.fscore", - ] - args.visualize_type = "neuron" - args.fm_thresh = 0.1 - args.fs_thresh = 0.05 - args.eval_dim = True - - samples = [] - metric_dicts = [] - for res_file, gt_file, partly, out_dir in zip( - args.res_file, args.gt_file, partly_list, outdir_list - ): - sample_name = os.path.basename(res_file).split(".")[0] - logger.info("sample_name: %s", sample_name) - logger.info("res_file: %s", res_file) - logger.info("gt_file: %s", gt_file) - logger.info("partly: %s", partly) - logger.info("localization: %s", args.localization_criterion) - logger.info("assignment: %s", args.assignment_strategy) - logger.info("from scratch: %s", args.from_scratch) - logger.info("add general metrics: %s", args.add_general_metrics) - - samples.append(os.path.basename(res_file).split(".")[0]) - metric_dict = evaluate_file( - res_file, - gt_file, - args.ndim, - out_dir, - res_key=args.res_key, - gt_key=args.gt_key, - suffix=args.suffix, - localization_criterion=args.localization_criterion, - assignment_strategy=args.assignment_strategy, - add_general_metrics=args.add_general_metrics, - visualize=args.visualize, - visualize_type=args.visualize_type, - partly=partly, - foreground_only=args.foreground_only, - remove_small_components=args.remove_small_components, - evaluate_false_labels=args.evaluate_false_labels, - fm_thresh=args.fm_thresh, - fs_thresh=args.fs_thresh, - from_scratch=args.from_scratch, - eval_dim=args.eval_dim, - debug=args.debug, - ) - 
metric_dicts.append(metric_dict) - print(metric_dict) - - # aggregate over instances - metrics_full = {} - acc_all_instances = None - for metric_dict, sample in zip(metric_dicts, samples): - if metric_dict is None: - continue - metrics_full[sample] = metric_dict - if len(np.unique(partly_list)) > 1: - print("averaging for combined") - # get average over instances for completely - samples = np.array(samples) - acc_cpt, acc_inst_cpt = average_flylight_score_over_instances( - samples[partly_list == False], metrics_full - ) - acc_prt, acc_inst_prt = average_flylight_score_over_instances( - samples[partly_list == True], metrics_full - ) - acc, acc_all_instances = average_sets( - acc_cpt, acc_inst_cpt, acc_prt, acc_inst_prt - ) - else: - acc, acc_all_instances = average_flylight_score_over_instances( - samples, metrics_full - ) - if args.summary: - summarize_metric_dict( - metric_dicts, - samples, - args.summary, - os.path.join(args.summary_out_dir, "summary.csv"), - agg_inst_dict=acc_all_instances, - ) + else: + acc, acc_all_instances = average_flylight_score_over_instances( + samples, metrics_full + ) + if args.summary: + summarize_metric_dict( + metric_dicts, + samples, + args.summary, + os.path.join(args.summary_out_dir, "summary.csv"), + agg_inst_dict=acc_all_instances, + ) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/evalinstseg/match.py b/evalinstseg/match.py index 2a4fa5a..798eb50 100644 --- a/evalinstseg/match.py +++ b/evalinstseg/match.py @@ -164,7 +164,7 @@ def greedy_many_to_many_matching(gt_labels, pred_labels, locMat, thresh, def get_false_labels( - tp_pred_ind, tp_gt_ind, num_pred_labels, num_gt_labels, locMat, + tp_pred_ind, tp_gt_ind, num_pred_labels, num_gt_labels, precMat, recallMat, thresh, recallMat_wo_overlap): # get false positive indices @@ -213,3 +213,31 @@ def get_false_labels( fp_ind, fn_ind, fs_ind, fm_pred_ind, fm_gt_ind, fm_count, fp_ind_only_bg) + +def get_m2m_matches(locMat, thresh, gt_labels=None, pred_labels=None, overlaps=True): + """Get many-to-many matches between gt and predicted labels. + If we have no overlaps, we can do easy matching based on thresholding the locMat.""" + + # If we have overlapping instances, we need to do expensive greedy many-to-many matching + if overlaps: + if gt_labels is None or pred_labels is None: + raise ValueError("gt_labels and pred_labels required when overlaps=True") + matches = greedy_many_to_many_matching(gt_labels, pred_labels, locMat, thresh) + if matches is not None: + # key and values are 0-based, convert to 1-based + matches = {k + 1: [v + 1 for v in val] for k, val in matches.items()} + return matches + else: + # Simple matching based on threshold + matches = {} + locFgMat = locMat[1:, 1:] # excluding background + rows, cols = np.nonzero(locFgMat > thresh) + for gt_idx, pred_idx in zip(rows, cols): + gt_id = gt_idx + 1 # 1-based IDs + pred_id = pred_idx + 1 + if gt_id not in matches: + matches[gt_id] = [pred_id] + else: + matches[gt_id].append(pred_id) + return matches +
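For illustration, a minimal usage sketch of the `get_m2m_matches` / `compute_m2m_stats` helpers introduced in this patch. The matrix contents below are toy values chosen to produce one false split and one false merge; the call signatures follow the code added above.

```python
import numpy as np
from evalinstseg.match import get_m2m_matches
from evalinstseg.compute import compute_m2m_stats

# Toy clRecall-style matrix locMat[gt_id, pred_id] with index 0 = background:
# GT1 is covered by pred1 and pred2 (one false split),
# pred3 covers both GT2 and GT3 (one false merge).
locMat = np.zeros((4, 4))
locMat[1, 1], locMat[1, 2] = 0.6, 0.5
locMat[2, 3], locMat[3, 3] = 0.7, 0.8

# No overlapping instances here, so the simple thresholding path is enough
matches = get_m2m_matches(locMat, thresh=0.1, overlaps=False)
# matches == {1: [1, 2], 2: [3], 3: [3]}  (1-based GT id -> list of 1-based pred ids)

fm, fs = compute_m2m_stats(matches, num_pred_labels=3)
# fm == 1 (pred3 explains two GT instances), fs == 1 (GT1 is explained by two predictions)
```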