From 93d3dcff10db588ea5a2c2176c3979c27fd5e1e6 Mon Sep 17 00:00:00 2001 From: Jaycee Date: Tue, 2 Sep 2025 14:07:31 -0400 Subject: [PATCH 1/7] adjusted for filtering gt AND pred --- src/jabs_postprocess/cli/main.py | 28 +++- src/jabs_postprocess/compare_gt.py | 250 +++++++++++++++++++++++++++-- 2 files changed, 262 insertions(+), 16 deletions(-) diff --git a/src/jabs_postprocess/cli/main.py b/src/jabs_postprocess/cli/main.py index bab92e3..8afb616 100644 --- a/src/jabs_postprocess/cli/main.py +++ b/src/jabs_postprocess/cli/main.py @@ -154,7 +154,22 @@ def evaluate_ground_truth( ), filter_ground_truth: bool = typer.Option( False, - help="Apply filters to ground truth data (default is only to filter predictions)", + help=( + "Enable extra filtered outputs and apply stitch/filter to BOTH GT and predictions. " + "Use together with --stitch-value-filter and --filter-value-filter." + ), + ), + stitch_value_filter: Optional[int] = typer.Option( + None, + "--stitch-value-filter", + "--stitch_value_filter", + help="Stitch (frames) to use for filtered outputs", + ), + filter_value_filter: Optional[int] = typer.Option( + None, + "--filter-value-filter", + "--filter_value_filter", + help="Minimum bout (frames) to use for filtered outputs", ), trim_time: Optional[int] = typer.Option( None, @@ -177,6 +192,15 @@ def evaluate_ground_truth( f"Prediction folder does not exist: {prediction_folder}" ) + # Convert CLI options into the dict expected by the underlying function + filter_gt_dict: Optional[dict] = None + if filter_ground_truth: + if stitch_value_filter is None or filter_value_filter is None: + raise typer.BadParameter( + "When using --filter-ground-truth, you must also provide --stitch-value-filter and --filter-value-filter." + ) + filter_gt_dict = {"stitch": int(stitch_value_filter), "filter": int(filter_value_filter)} + # Call the refactored function with individual parameters compare_gt.evaluate_ground_truth( behavior=behavior, @@ -186,7 +210,7 @@ def evaluate_ground_truth( stitch_scan=stitch_scan, filter_scan=filter_scan, iou_thresholds=iou_thresholds, - filter_ground_truth=filter_ground_truth, + filter_ground_truth=filter_gt_dict, trim_time=trim_time, ) diff --git a/src/jabs_postprocess/compare_gt.py b/src/jabs_postprocess/compare_gt.py index b2322e1..9782198 100644 --- a/src/jabs_postprocess/compare_gt.py +++ b/src/jabs_postprocess/compare_gt.py @@ -27,7 +27,7 @@ def evaluate_ground_truth( stitch_scan: List[float] = None, filter_scan: List[float] = None, iou_thresholds: List[float] = None, - filter_ground_truth: bool = False, + filter_ground_truth: Optional[dict] = None, trim_time: Optional[int] = None, ): """Main function for evaluating ground truth annotations against classifier predictions. @@ -40,7 +40,9 @@ def evaluate_ground_truth( stitch_scan: List of stitching (time gaps in frames to merge bouts together) values to test filter_scan: List of filter (minimum duration in frames to consider real) values to test iou_thresholds: List of intersection over union thresholds to scan - filter_ground_truth: Apply filters to ground truth data (default is only to filter predictions) + filter_ground_truth: Optional dict specifying stitch/filter to apply to BOTH GT and predictions + for additional filtered outputs. Shape: {"stitch": int, "filter": int}. If None, no + additional filtered outputs are produced. 
scan_output: Output file to save the filter scan performance plot bout_output: Output file to save the resulting bout performance plot trim_time: Limit the duration in frames of videos for performance @@ -139,7 +141,7 @@ def evaluate_ground_truth( stitch_scan, filter_scan, iou_thresholds, - filter_ground_truth, + False, ) if performance_df.empty: @@ -258,17 +260,7 @@ def evaluate_ground_truth( winning_bout_df.to_csv(ouput_paths["bout_csv"], index=False) logging.info(f"Bout performance data saved to {ouput_paths['bout_csv']}") - melted_winning = pd.melt(winning_bout_df, id_vars=["threshold", "stitch", "filter"]) - - ( - p9.ggplot( - melted_winning[melted_winning["variable"].isin(["pr", "re", "f1"])], - p9.aes(x="threshold", y="value", color="variable"), - ) - + p9.geom_line() - + p9.theme_bw() - + p9.scale_y_continuous(limits=(0, 1)) - ).save(ouput_paths["bout_plot"], height=6, width=12, dpi=300) + _save_bout_curve_performance(winning_bout_df, ouput_paths["bout_plot"]) if ouput_paths["ethogram"] is not None: # Prepare data for ethogram plot @@ -348,6 +340,46 @@ def evaluate_ground_truth( f"No annotations found for behavior {behavior} to generate ethogram plot." ) + # New: filtered outputs if user provided filter settings to apply to both GT and Pred + if filter_ground_truth is not None: + if not isinstance(filter_ground_truth, dict) or not { + "stitch", + "filter", + }.issubset(set(filter_ground_truth.keys())): + logger.warning( + "filter_ground_truth must be a dict with keys {'stitch','filter'}. Skipping filtered outputs." + ) + else: + stitch_val = int(filter_ground_truth["stitch"]) + filter_val = int(filter_ground_truth["filter"]) + # 1) Filtered ethogram with 4 tracks + _save_filtered_ethogram( + all_annotations, + behavior, + stitch_val, + filter_val, + ouput_paths.get("ethogram_filtered"), + ) + # 2) Filtered curves CSV + plot over IoU thresholds + filtered_curve_df = generate_filtered_iou_curve( + all_annotations, + stitch_val, + filter_val, + np.round(iou_thresholds, 2), + ) + if filtered_curve_df is not None and len(filtered_curve_df) > 0: + if ouput_paths.get("bout_filtered_csv") is not None: + filtered_curve_df.to_csv(ouput_paths["bout_filtered_csv"], index=False) + logging.info( + f"Filtered curve performance saved to {ouput_paths['bout_filtered_csv']}" + ) + # Reuse the same curve plotting by adding fixed stitch/filter columns + filtered_curve_df["stitch"] = stitch_val + filtered_curve_df["filter"] = filter_val + _save_bout_curve_performance(filtered_curve_df, ouput_paths.get("bout_filtered_plot")) + else: + logger.warning("No filtered performance data available to save plots/CSV.") + def generate_iou_scan( all_annotations, @@ -480,4 +512,194 @@ def generate_output_paths(results_folder: Path): "ethogram": results_folder / "ethogram.png", "scan_plot": results_folder / "scan_performance.png", "bout_plot": results_folder / "bout_performance.png", + # New filtered outputs + "ethogram_filtered": results_folder / "ethogram_filtered.png", + "bout_filtered_plot": results_folder / "bout_performance_filtered.png", + "bout_filtered_csv": results_folder / "bout_performance_filtered.csv", } + + +def _save_bout_curve_performance(curve_df: pd.DataFrame, output_path: Optional[Path]): + """ + Saves the curve iou performance plot. 
+ + Args: + curve_df: Contains the curve performance data + output_path: Path to save the plot to + """ + if output_path is None: + return + melted_df = pd.melt(curve_df, id_vars=["threshold", "stitch", "filter"]) + ( + p9.ggplot( + melted_df[melted_df["variable"].isin(["pr", "re", "f1"])], + p9.aes(x="threshold", y="value", color="variable"), + ) + + p9.geom_line() + + p9.theme_bw() + + p9.scale_y_continuous(limits=(0, 1)) + ).save(output_path, height=6, width=12, dpi=300) + + +def _save_filtered_ethogram( + all_annotations: pd.DataFrame, + behavior: str, + stitch_val: int, + filter_val: int, + output_path: Optional[Path], +): + """ + Saves the filtered ethogram plot. + + Args: + all_annotations: Contains the ethogram data + behavior: The behavior to plot + stitch_val: The stitch value + filter_val: The filter value + output_path: Path to save the plot to + """ + + if output_path is None: + return + + # Grabbing raw and filtered bouts per animal/video for gt and pred + records = [] + for (cur_animal, cur_video), animal_df in all_annotations.groupby( + ["animal_idx", "video_name"] + ): + pr_df = animal_df[~animal_df["is_gt"]] + gt_df = animal_df[animal_df["is_gt"]] + if len(pr_df) == 0: + continue + pr_obj = Bouts(pr_df["start"], pr_df["duration"], pr_df["is_behavior"]) + gt_obj = Bouts(gt_df["start"], gt_df["duration"], gt_df["is_behavior"]) + + full_duration = int(pr_obj.starts[-1] + pr_obj.durations[-1]) + pr_obj.fill_to_size(full_duration, 0) + gt_obj.fill_to_size(full_duration, 0) + + settings = ClassifierSettings("", interpolate=0, stitch=stitch_val, min_bout=filter_val) + pr_fil = pr_obj.copy() + gt_fil = gt_obj.copy() + pr_fil.filter_by_settings(settings) + gt_fil.filter_by_settings(settings) + + # Helper to extend records from a Bouts object for is_behavior == 1 + def add_records_from_bouts(bouts_obj: Bouts, track_label: str): + starts = bouts_obj.starts + durations = bouts_obj.durations + values = bouts_obj.values + if starts is None or len(starts) == 0: + return + ends = starts + durations + for s, e, v in zip(starts, ends, values): + if v == 1: + records.append( + {"animal_idx": cur_animal, "video_name": cur_video, "start": int(s), "end": int(e), "track": track_label} + ) + + add_records_from_bouts(gt_obj, "GT Raw") + add_records_from_bouts(gt_fil, "GT Filtered") + add_records_from_bouts(pr_obj, "Pred Raw") + add_records_from_bouts(pr_fil, "Pred Filtered") + + if len(records) == 0: + logger.warning("No behavior bouts found to generate filtered ethogram.") + return + + df = pd.DataFrame.from_records(records) + df["animal_video_combo"] = df["animal_idx"].astype(str) + " | " + df["video_name"].astype(str) + + # Map track to vertical bands in the requested order: raw gt, filtered gt, raw pred, filtered pred + track_order = ["GT Raw", "GT Filtered", "Pred Raw", "Pred Filtered"] + track_to_idx = {label: idx for idx, label in enumerate(track_order[::-1])} # reverse so top is GT Raw + df["track_idx"] = df["track"].map(track_to_idx) + df["ymin"] = df["track_idx"].astype(float) + df["ymax"] = df["ymin"] + 0.9 + + num_unique_combos = len(df["animal_video_combo"].unique()) + + plot = ( + p9.ggplot(df) + + p9.geom_rect(p9.aes(xmin="start", xmax="end", ymin="ymin", ymax="ymax", fill="track")) + + p9.theme_bw() + + p9.facet_wrap("~animal_video_combo", ncol=1, scales="free_x") + + p9.scale_y_continuous( + breaks=[track_to_idx[t] + 0.45 for t in track_order[::-1]], + labels=track_order[::-1], + name="", + ) + + p9.scale_fill_brewer(type="qual", palette="Set1") + + p9.labs( + x="Frame", + 
fill="Track", + title=f"Ethogram (filtered) for behavior: {behavior}", + ) + + p9.expand_limits(x=0) + ) + + plot.save(output_path, height=2.0 * num_unique_combos + 2, width=12, dpi=300, limitsize=False, verbose=False) + logging.info(f"Filtered ethogram plot saved to {output_path}") + + +def generate_filtered_iou_curve( + all_annotations: pd.DataFrame, + stitch_val: int, + filter_val: int, + threshold_scan: np.ndarray, +) -> Optional[pd.DataFrame]: + """Compute PR/RE/F1 across IoU thresholds after applying a fixed stitch/filter to BOTH GT and predictions. + + Returns a DataFrame aggregated over animals/videos with columns: threshold,tp,fn,fp,pr,re,f1 + """ + threshold_scan = np.round(threshold_scan, 2) + settings = ClassifierSettings("", interpolate=0, stitch=stitch_val, min_bout=filter_val) + + perf_rows = [] + for (cur_animal, cur_video), animal_df in all_annotations.groupby(["animal_idx", "video_name"]): + pr_df = animal_df[~animal_df["is_gt"]] + if len(pr_df) == 0: + continue + gt_df = animal_df[animal_df["is_gt"]] + pr_obj = Bouts(pr_df["start"], pr_df["duration"], pr_df["is_behavior"]) + gt_obj = Bouts(gt_df["start"], gt_df["duration"], gt_df["is_behavior"]) + + full_duration = int(pr_obj.starts[-1] + pr_obj.durations[-1]) + pr_obj.fill_to_size(full_duration, 0) + gt_obj.fill_to_size(full_duration, 0) + + pr_fil = pr_obj.copy(); pr_fil.filter_by_settings(settings) + gt_fil = gt_obj.copy(); gt_fil.filter_by_settings(settings) + + # Handle empty-positive cases without calling compare_to (which expects non-empty arrays) + num_pr_pos = int(np.sum(pr_fil.values == 1)) + num_gt_pos = int(np.sum(gt_fil.values == 1)) + if num_pr_pos == 0 or num_gt_pos == 0: + for thr in threshold_scan: + if num_pr_pos == 0 and num_gt_pos == 0: + metrics = {"tp": 0, "fn": 0, "fp": 0, "pr": 0, "re": 0, "f1": 0} + elif num_pr_pos == 0 and num_gt_pos > 0: + metrics = {"tp": 0, "fn": num_gt_pos, "fp": 0, "pr": 0, "re": 0, "f1": 0} + else: # num_pr_pos > 0 and num_gt_pos == 0 + metrics = {"tp": 0, "fn": 0, "fp": num_pr_pos, "pr": 0, "re": 0, "f1": 0} + perf_rows.append( + {"animal": cur_animal, "video": cur_video, "threshold": thr, **metrics} + ) + continue + + int_mat, u_mat, iou_mat = gt_fil.compare_to(pr_fil) + for thr in threshold_scan: + metrics = Bouts.calculate_iou_metrics(iou_mat, thr) + perf_rows.append( + {"animal": cur_animal, "video": cur_video, "threshold": thr, **metrics} + ) + + if len(perf_rows) == 0: + return None + + df = pd.DataFrame(perf_rows) + df = df.groupby(["threshold"])[["tp", "fn", "fp"]].apply(np.sum).reset_index() + df["pr"] = df["tp"] / (df["tp"] + df["fp"]) if "tp" in df and "fp" in df else np.nan + df["re"] = df["tp"] / (df["tp"] + df["fn"]) if "tp" in df and "fn" in df else np.nan + df["f1"] = 2 * (df["pr"] * df["re"]) / (df["pr"] + df["re"]) + return df From 3d20cd9c11a06051cfed07d2e9996a0905dc54e0 Mon Sep 17 00:00:00 2001 From: Jaycee Date: Tue, 2 Sep 2025 15:20:37 -0400 Subject: [PATCH 2/7] linting error fix --- src/jabs_postprocess/compare_gt.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/jabs_postprocess/compare_gt.py b/src/jabs_postprocess/compare_gt.py index 9782198..72013c9 100644 --- a/src/jabs_postprocess/compare_gt.py +++ b/src/jabs_postprocess/compare_gt.py @@ -668,8 +668,10 @@ def generate_filtered_iou_curve( pr_obj.fill_to_size(full_duration, 0) gt_obj.fill_to_size(full_duration, 0) - pr_fil = pr_obj.copy(); pr_fil.filter_by_settings(settings) - gt_fil = gt_obj.copy(); gt_fil.filter_by_settings(settings) + pr_fil = pr_obj.copy() + 
pr_fil.filter_by_settings(settings) + gt_fil = gt_obj.copy() + gt_fil.filter_by_settings(settings) # Handle empty-positive cases without calling compare_to (which expects non-empty arrays) num_pr_pos = int(np.sum(pr_fil.values == 1)) From ae5b7f260180c8aeb27c529c2ffbbb98c4b68eb6 Mon Sep 17 00:00:00 2001 From: Jaycee Date: Tue, 2 Sep 2025 15:29:15 -0400 Subject: [PATCH 3/7] ruff linting --- src/jabs_postprocess/compare_gt.py | 78 ++++++++++++++++++++++++------ 1 file changed, 64 insertions(+), 14 deletions(-) diff --git a/src/jabs_postprocess/compare_gt.py b/src/jabs_postprocess/compare_gt.py index 72013c9..2cc7bf2 100644 --- a/src/jabs_postprocess/compare_gt.py +++ b/src/jabs_postprocess/compare_gt.py @@ -369,16 +369,22 @@ def evaluate_ground_truth( ) if filtered_curve_df is not None and len(filtered_curve_df) > 0: if ouput_paths.get("bout_filtered_csv") is not None: - filtered_curve_df.to_csv(ouput_paths["bout_filtered_csv"], index=False) + filtered_curve_df.to_csv( + ouput_paths["bout_filtered_csv"], index=False + ) logging.info( f"Filtered curve performance saved to {ouput_paths['bout_filtered_csv']}" ) # Reuse the same curve plotting by adding fixed stitch/filter columns filtered_curve_df["stitch"] = stitch_val filtered_curve_df["filter"] = filter_val - _save_bout_curve_performance(filtered_curve_df, ouput_paths.get("bout_filtered_plot")) + _save_bout_curve_performance( + filtered_curve_df, ouput_paths.get("bout_filtered_plot") + ) else: - logger.warning("No filtered performance data available to save plots/CSV.") + logger.warning( + "No filtered performance data available to save plots/CSV." + ) def generate_iou_scan( @@ -578,7 +584,9 @@ def _save_filtered_ethogram( pr_obj.fill_to_size(full_duration, 0) gt_obj.fill_to_size(full_duration, 0) - settings = ClassifierSettings("", interpolate=0, stitch=stitch_val, min_bout=filter_val) + settings = ClassifierSettings( + "", interpolate=0, stitch=stitch_val, min_bout=filter_val + ) pr_fil = pr_obj.copy() gt_fil = gt_obj.copy() pr_fil.filter_by_settings(settings) @@ -595,7 +603,13 @@ def add_records_from_bouts(bouts_obj: Bouts, track_label: str): for s, e, v in zip(starts, ends, values): if v == 1: records.append( - {"animal_idx": cur_animal, "video_name": cur_video, "start": int(s), "end": int(e), "track": track_label} + { + "animal_idx": cur_animal, + "video_name": cur_video, + "start": int(s), + "end": int(e), + "track": track_label, + } ) add_records_from_bouts(gt_obj, "GT Raw") @@ -608,11 +622,15 @@ def add_records_from_bouts(bouts_obj: Bouts, track_label: str): return df = pd.DataFrame.from_records(records) - df["animal_video_combo"] = df["animal_idx"].astype(str) + " | " + df["video_name"].astype(str) + df["animal_video_combo"] = ( + df["animal_idx"].astype(str) + " | " + df["video_name"].astype(str) + ) # Map track to vertical bands in the requested order: raw gt, filtered gt, raw pred, filtered pred track_order = ["GT Raw", "GT Filtered", "Pred Raw", "Pred Filtered"] - track_to_idx = {label: idx for idx, label in enumerate(track_order[::-1])} # reverse so top is GT Raw + track_to_idx = { + label: idx for idx, label in enumerate(track_order[::-1]) + } # reverse so top is GT Raw df["track_idx"] = df["track"].map(track_to_idx) df["ymin"] = df["track_idx"].astype(float) df["ymax"] = df["ymin"] + 0.9 @@ -621,7 +639,9 @@ def add_records_from_bouts(bouts_obj: Bouts, track_label: str): plot = ( p9.ggplot(df) - + p9.geom_rect(p9.aes(xmin="start", xmax="end", ymin="ymin", ymax="ymax", fill="track")) + + p9.geom_rect( + p9.aes(xmin="start", 
xmax="end", ymin="ymin", ymax="ymax", fill="track") + ) + p9.theme_bw() + p9.facet_wrap("~animal_video_combo", ncol=1, scales="free_x") + p9.scale_y_continuous( @@ -638,7 +658,14 @@ def add_records_from_bouts(bouts_obj: Bouts, track_label: str): + p9.expand_limits(x=0) ) - plot.save(output_path, height=2.0 * num_unique_combos + 2, width=12, dpi=300, limitsize=False, verbose=False) + plot.save( + output_path, + height=2.0 * num_unique_combos + 2, + width=12, + dpi=300, + limitsize=False, + verbose=False, + ) logging.info(f"Filtered ethogram plot saved to {output_path}") @@ -653,10 +680,14 @@ def generate_filtered_iou_curve( Returns a DataFrame aggregated over animals/videos with columns: threshold,tp,fn,fp,pr,re,f1 """ threshold_scan = np.round(threshold_scan, 2) - settings = ClassifierSettings("", interpolate=0, stitch=stitch_val, min_bout=filter_val) + settings = ClassifierSettings( + "", interpolate=0, stitch=stitch_val, min_bout=filter_val + ) perf_rows = [] - for (cur_animal, cur_video), animal_df in all_annotations.groupby(["animal_idx", "video_name"]): + for (cur_animal, cur_video), animal_df in all_annotations.groupby( + ["animal_idx", "video_name"] + ): pr_df = animal_df[~animal_df["is_gt"]] if len(pr_df) == 0: continue @@ -681,11 +712,30 @@ def generate_filtered_iou_curve( if num_pr_pos == 0 and num_gt_pos == 0: metrics = {"tp": 0, "fn": 0, "fp": 0, "pr": 0, "re": 0, "f1": 0} elif num_pr_pos == 0 and num_gt_pos > 0: - metrics = {"tp": 0, "fn": num_gt_pos, "fp": 0, "pr": 0, "re": 0, "f1": 0} + metrics = { + "tp": 0, + "fn": num_gt_pos, + "fp": 0, + "pr": 0, + "re": 0, + "f1": 0, + } else: # num_pr_pos > 0 and num_gt_pos == 0 - metrics = {"tp": 0, "fn": 0, "fp": num_pr_pos, "pr": 0, "re": 0, "f1": 0} + metrics = { + "tp": 0, + "fn": 0, + "fp": num_pr_pos, + "pr": 0, + "re": 0, + "f1": 0, + } perf_rows.append( - {"animal": cur_animal, "video": cur_video, "threshold": thr, **metrics} + { + "animal": cur_animal, + "video": cur_video, + "threshold": thr, + **metrics, + } ) continue From fdff842b906728a9f61c7664163f8325f6e0e9e9 Mon Sep 17 00:00:00 2001 From: Jaycee Date: Tue, 2 Sep 2025 15:37:07 -0400 Subject: [PATCH 4/7] more ruff linting --- src/jabs_postprocess/cli/main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/jabs_postprocess/cli/main.py b/src/jabs_postprocess/cli/main.py index 8afb616..0da040f 100644 --- a/src/jabs_postprocess/cli/main.py +++ b/src/jabs_postprocess/cli/main.py @@ -199,7 +199,10 @@ def evaluate_ground_truth( raise typer.BadParameter( "When using --filter-ground-truth, you must also provide --stitch-value-filter and --filter-value-filter." 
)
-    filter_gt_dict = {"stitch": int(stitch_value_filter), "filter": int(filter_value_filter)}
+    filter_gt_dict = {
+        "stitch": int(stitch_value_filter),
+        "filter": int(filter_value_filter),
+    }
 
     # Call the refactored function with individual parameters
     compare_gt.evaluate_ground_truth(

From ca9b6e34c163bacee2cdccbdae422c74b4fbc3a3 Mon Sep 17 00:00:00 2001
From: Jaycee
Date: Tue, 2 Sep 2025 16:19:11 -0400
Subject: [PATCH 5/7] docstring update

---
 src/jabs_postprocess/compare_gt.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/jabs_postprocess/compare_gt.py b/src/jabs_postprocess/compare_gt.py
index 2cc7bf2..25fdb67 100644
--- a/src/jabs_postprocess/compare_gt.py
+++ b/src/jabs_postprocess/compare_gt.py
@@ -41,8 +41,10 @@ def evaluate_ground_truth(
         filter_scan: List of filter (minimum duration in frames to consider real) values to test
         iou_thresholds: List of intersection over union thresholds to scan
         filter_ground_truth: Optional dict specifying stitch/filter to apply to BOTH GT and predictions
-            for additional filtered outputs. Shape: {"stitch": int, "filter": int}. If None, no
-            additional filtered outputs are produced.
+            for additional filtered outputs. If provided, the two matching CLI options must also 
+            be supplied: stitch_value_filter and filter_value_filter
+        stitch_value_filter: Stitch (frames) to use for filtered outputs (gt and pred)
+        filter_value_filter: Minimum bout (frames) to use for filtered outputs (gt and pred)
         scan_output: Output file to save the filter scan performance plot
         bout_output: Output file to save the resulting bout performance plot
         trim_time: Limit the duration in frames of videos for performance

From 8125259f31ead665e48ed89ddafae4f2d6ea72e7 Mon Sep 17 00:00:00 2001
From: Jaycee
Date: Wed, 3 Sep 2025 13:40:24 -0400
Subject: [PATCH 6/7] deleting redundant code

---
 src/jabs_postprocess/compare_gt.py | 148 +++++++--------
 1 file changed, 47 insertions(+), 101 deletions(-)

diff --git a/src/jabs_postprocess/compare_gt.py b/src/jabs_postprocess/compare_gt.py
index 25fdb67..5408a9a 100644
--- a/src/jabs_postprocess/compare_gt.py
+++ b/src/jabs_postprocess/compare_gt.py
@@ -1,8 +1,8 @@
 """Associated lines of code that deal with the comparison of predictions (from classify.py) and GT annotation (from a JABS project)."""
 
 import logging
-from typing import List, Optional
 from pathlib import Path
+from typing import List, Optional
 
 import numpy as np
 import pandas as pd
@@ -15,7 +15,6 @@
     JabsProject,
 )
 
-
 logger = logging.getLogger(__name__)
 
@@ -41,7 +40,7 @@ def evaluate_ground_truth(
         filter_scan: List of filter (minimum duration in frames to consider real) values to test
         iou_thresholds: List of intersection over union thresholds to scan
         filter_ground_truth: Optional dict specifying stitch/filter to apply to BOTH GT and predictions
-            for additional filtered outputs. If provided, the two matching CLI options must also 
+            for additional filtered outputs. 
If provided, the two matching CLI options must also
             be supplied: stitch_value_filter and filter_value_filter
         stitch_value_filter: Stitch (frames) to use for filtered outputs (gt and pred)
         filter_value_filter: Minimum bout (frames) to use for filtered outputs (gt and pred)
@@ -158,8 +157,6 @@ def evaluate_ground_truth(
     performance_df.to_csv(ouput_paths["scan_csv"], index=False)
     logging.info(f"Scan performance data saved to {ouput_paths['scan_csv']}")
 
-    _melted_df = pd.melt(performance_df, id_vars=["threshold", "stitch", "filter"])
-
     middle_threshold = np.sort(iou_thresholds)[int(np.floor(len(iou_thresholds) / 2))]
 
     # Create a copy to avoid SettingWithCopyWarning
@@ -363,11 +360,12 @@ def evaluate_ground_truth(
                 ouput_paths.get("ethogram_filtered"),
             )
             # 2) Filtered curves CSV + plot over IoU thresholds
-            filtered_curve_df = generate_filtered_iou_curve(
+            filtered_curve_df = generate_iou_scan(
                 all_annotations,
-                stitch_val,
-                filter_val,
+                [stitch_val],
+                [filter_val],
                 np.round(iou_thresholds, 2),
+                True,  # True for filter_ground_truth
             )
             if filtered_curve_df is not None and len(filtered_curve_df) > 0:
                 if ouput_paths.get("bout_filtered_csv") is not None:
@@ -377,9 +375,6 @@ def evaluate_ground_truth(
                     logging.info(
                         f"Filtered curve performance saved to {ouput_paths['bout_filtered_csv']}"
                     )
-                    # Reuse the same curve plotting by adding fixed stitch/filter columns
-                    filtered_curve_df["stitch"] = stitch_val
-                    filtered_curve_df["filter"] = filter_val
                     _save_bout_curve_performance(
                         filtered_curve_df, ouput_paths.get("bout_filtered_plot")
                     )
@@ -445,8 +440,47 @@ def generate_iou_scan(
             # Always apply filters to predictions
             cur_pr.filter_by_settings(cur_filter_settings)
 
+            # Handle cases with zero positives on either side without calling compare_to.
+            # This is necessary when filtering both gt and pred because then there can be
+            # gt videos that have all annotations filtered out as well.
+            num_pr_pos = int(np.sum(cur_pr.values == 1))
+            num_gt_pos = int(np.sum(cur_gt.values == 1))
+            if num_pr_pos == 0 or num_gt_pos == 0:
+                for cur_threshold in threshold_scan:
+                    if num_pr_pos == 0 and num_gt_pos == 0:
+                        metrics = {"tp": 0, "fn": 0, "fp": 0, "pr": 0, "re": 0, "f1": 0}
+                    elif num_pr_pos == 0 and num_gt_pos > 0:
+                        metrics = {
+                            "tp": 0,
+                            "fn": num_gt_pos,
+                            "fp": 0,
+                            "pr": 0,
+                            "re": 0,
+                            "f1": 0,
+                        }
+                    else:  # num_pr_pos > 0 and num_gt_pos == 0
+                        metrics = {
+                            "tp": 0,
+                            "fn": 0,
+                            "fp": num_pr_pos,
+                            "pr": 0,
+                            "re": 0,
+                            "f1": 0,
+                        }
+                    new_performance = {
+                        "animal": [cur_animal],
+                        "video": [cur_video],
+                        "stitch": [cur_stitch],
+                        "filter": [cur_filter],
+                        "threshold": [cur_threshold],
+                    }
+                    for key, val in metrics.items():
+                        new_performance[key] = [val]
+                    performance_df.append(pd.DataFrame(new_performance))
+                continue
+
             # Add iou metrics to the list
-            int_mat, u_mat, iou_mat = cur_gt.compare_to(cur_pr)
+            _, _, iou_mat = cur_gt.compare_to(cur_pr)
             for cur_threshold in threshold_scan:
                 new_performance = {
                     "animal": [cur_animal],
                     "video": [cur_video],
@@ -508,6 +542,7 @@ def generate_output_paths(results_folder: Path):
 
     Args:
         results_folder: Path to the folder where results will be saved.
+
     Returns:
         A dictionary with keys 'scan_csv', 'bout_csv', 'ethogram', 'scan_plot', and 'bout_plot' containing the respective output paths.
""" @@ -566,7 +601,6 @@ def _save_filtered_ethogram( filter_val: The filter value output_path: Path to save the plot to """ - if output_path is None: return @@ -669,91 +703,3 @@ def add_records_from_bouts(bouts_obj: Bouts, track_label: str): verbose=False, ) logging.info(f"Filtered ethogram plot saved to {output_path}") - - -def generate_filtered_iou_curve( - all_annotations: pd.DataFrame, - stitch_val: int, - filter_val: int, - threshold_scan: np.ndarray, -) -> Optional[pd.DataFrame]: - """Compute PR/RE/F1 across IoU thresholds after applying a fixed stitch/filter to BOTH GT and predictions. - - Returns a DataFrame aggregated over animals/videos with columns: threshold,tp,fn,fp,pr,re,f1 - """ - threshold_scan = np.round(threshold_scan, 2) - settings = ClassifierSettings( - "", interpolate=0, stitch=stitch_val, min_bout=filter_val - ) - - perf_rows = [] - for (cur_animal, cur_video), animal_df in all_annotations.groupby( - ["animal_idx", "video_name"] - ): - pr_df = animal_df[~animal_df["is_gt"]] - if len(pr_df) == 0: - continue - gt_df = animal_df[animal_df["is_gt"]] - pr_obj = Bouts(pr_df["start"], pr_df["duration"], pr_df["is_behavior"]) - gt_obj = Bouts(gt_df["start"], gt_df["duration"], gt_df["is_behavior"]) - - full_duration = int(pr_obj.starts[-1] + pr_obj.durations[-1]) - pr_obj.fill_to_size(full_duration, 0) - gt_obj.fill_to_size(full_duration, 0) - - pr_fil = pr_obj.copy() - pr_fil.filter_by_settings(settings) - gt_fil = gt_obj.copy() - gt_fil.filter_by_settings(settings) - - # Handle empty-positive cases without calling compare_to (which expects non-empty arrays) - num_pr_pos = int(np.sum(pr_fil.values == 1)) - num_gt_pos = int(np.sum(gt_fil.values == 1)) - if num_pr_pos == 0 or num_gt_pos == 0: - for thr in threshold_scan: - if num_pr_pos == 0 and num_gt_pos == 0: - metrics = {"tp": 0, "fn": 0, "fp": 0, "pr": 0, "re": 0, "f1": 0} - elif num_pr_pos == 0 and num_gt_pos > 0: - metrics = { - "tp": 0, - "fn": num_gt_pos, - "fp": 0, - "pr": 0, - "re": 0, - "f1": 0, - } - else: # num_pr_pos > 0 and num_gt_pos == 0 - metrics = { - "tp": 0, - "fn": 0, - "fp": num_pr_pos, - "pr": 0, - "re": 0, - "f1": 0, - } - perf_rows.append( - { - "animal": cur_animal, - "video": cur_video, - "threshold": thr, - **metrics, - } - ) - continue - - int_mat, u_mat, iou_mat = gt_fil.compare_to(pr_fil) - for thr in threshold_scan: - metrics = Bouts.calculate_iou_metrics(iou_mat, thr) - perf_rows.append( - {"animal": cur_animal, "video": cur_video, "threshold": thr, **metrics} - ) - - if len(perf_rows) == 0: - return None - - df = pd.DataFrame(perf_rows) - df = df.groupby(["threshold"])[["tp", "fn", "fp"]].apply(np.sum).reset_index() - df["pr"] = df["tp"] / (df["tp"] + df["fp"]) if "tp" in df and "fp" in df else np.nan - df["re"] = df["tp"] / (df["tp"] + df["fn"]) if "tp" in df and "fn" in df else np.nan - df["f1"] = 2 * (df["pr"] * df["re"]) / (df["pr"] + df["re"]) - return df From 71de9e03220422809736084310ea381ea0171927 Mon Sep 17 00:00:00 2001 From: Jaycee Date: Wed, 3 Sep 2025 15:25:20 -0400 Subject: [PATCH 7/7] trying to fix test outcome --- src/jabs_postprocess/compare_gt.py | 31 ++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/jabs_postprocess/compare_gt.py b/src/jabs_postprocess/compare_gt.py index 5408a9a..457a24c 100644 --- a/src/jabs_postprocess/compare_gt.py +++ b/src/jabs_postprocess/compare_gt.py @@ -498,20 +498,23 @@ def generate_iou_scan( logger.warning( "No valid ground truth and prediction pairs found for behavior across all files. 
Cannot generate performance metrics."
         )
-        # Return an empty DataFrame with expected columns to prevent downstream errors
-        return pd.DataFrame(
-            columns=[
-                "stitch",
-                "filter",
-                "threshold",
-                "tp",
-                "fn",
-                "fp",
-                "pr",
-                "re",
-                "f1",
-            ]
-        )
+        # Build the full stitch/filter/threshold grid with zeroed counts so
+        # downstream plotting/CSV code still sees the expected columns
+        rows = []
+        for s in stitch_scan:
+            for f in filter_scan:
+                for t in threshold_scan:
+                    rows.append({
+                        "stitch": s,
+                        "filter": f,
+                        "threshold": t,
+                        "tp": 0,
+                        "fn": 0,
+                        "fp": 0,
+                        "pr": np.nan,
+                        "re": np.nan,
+                        "f1": np.nan,
+                    })
+        return pd.DataFrame(rows)
 
     performance_df = pd.concat(performance_df)
 
     # Aggregate over animals
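Reviewer note (not part of the patches): a minimal sketch of the fallback frame that generate_iou_scan now returns when no valid GT/prediction pairs survive, so downstream plotting/CSV code always sees the same columns. The column names and the zero/NaN fill come from PATCH 7 above; the concrete scan values here are hypothetical.

    import numpy as np
    import pandas as pd

    stitch_scan, filter_scan, threshold_scan = [5, 10], [15], [0.5, 0.75]
    # One all-zero row per (stitch, filter, threshold) combination, with the
    # undefined precision/recall/F1 left as NaN rather than 0.
    rows = [
        {"stitch": s, "filter": f, "threshold": t,
         "tp": 0, "fn": 0, "fp": 0, "pr": np.nan, "re": np.nan, "f1": np.nan}
        for s in stitch_scan for f in filter_scan for t in threshold_scan
    ]
    fallback = pd.DataFrame(rows)
    assert len(fallback) == len(stitch_scan) * len(filter_scan) * len(threshold_scan)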