diff --git a/src/jabs_postprocess/cli/main.py b/src/jabs_postprocess/cli/main.py
index bab92e3..0da040f 100644
--- a/src/jabs_postprocess/cli/main.py
+++ b/src/jabs_postprocess/cli/main.py
@@ -154,7 +154,22 @@ def evaluate_ground_truth(
     ),
     filter_ground_truth: bool = typer.Option(
         False,
-        help="Apply filters to ground truth data (default is only to filter predictions)",
+        help=(
+            "Enable extra filtered outputs and apply stitch/filter to BOTH GT and predictions. "
+            "Use together with --stitch-value-filter and --filter-value-filter."
+        ),
+    ),
+    stitch_value_filter: Optional[int] = typer.Option(
+        None,
+        "--stitch-value-filter",
+        "--stitch_value_filter",
+        help="Stitch (frames) to use for filtered outputs",
+    ),
+    filter_value_filter: Optional[int] = typer.Option(
+        None,
+        "--filter-value-filter",
+        "--filter_value_filter",
+        help="Minimum bout (frames) to use for filtered outputs",
     ),
     trim_time: Optional[int] = typer.Option(
         None,
@@ -177,6 +192,18 @@ def evaluate_ground_truth(
             f"Prediction folder does not exist: {prediction_folder}"
         )
 
+    # Convert CLI options into the dict expected by the underlying function
+    filter_gt_dict: Optional[dict] = None
+    if filter_ground_truth:
+        if stitch_value_filter is None or filter_value_filter is None:
+            raise typer.BadParameter(
+                "When using --filter-ground-truth, you must also provide --stitch-value-filter and --filter-value-filter."
+            )
+        filter_gt_dict = {
+            "stitch": int(stitch_value_filter),
+            "filter": int(filter_value_filter),
+        }
+
     # Call the refactored function with individual parameters
     compare_gt.evaluate_ground_truth(
         behavior=behavior,
@@ -186,7 +213,7 @@ def evaluate_ground_truth(
         stitch_scan=stitch_scan,
         filter_scan=filter_scan,
         iou_thresholds=iou_thresholds,
-        filter_ground_truth=filter_ground_truth,
+        filter_ground_truth=filter_gt_dict,
        trim_time=trim_time,
     )
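Note: the three new flags are meant to be used together. An illustrative invocation (the actual command name and the remaining required arguments depend on how this Typer app is installed; only the three filter flags below are defined by this diff):

    evaluate-ground-truth ... \
        --filter-ground-truth \
        --stitch-value-filter 30 \
        --filter-value-filter 15

Passing --filter-ground-truth without both numeric flags raises typer.BadParameter, per the validation block above.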
diff --git a/src/jabs_postprocess/compare_gt.py b/src/jabs_postprocess/compare_gt.py
index b2322e1..457a24c 100644
--- a/src/jabs_postprocess/compare_gt.py
+++ b/src/jabs_postprocess/compare_gt.py
@@ -1,8 +1,8 @@
 """Associated lines of code that deal with the comparison of predictions (from classify.py) and GT annotation (from a JABS project)."""
 
 import logging
-from typing import List, Optional
 from pathlib import Path
+from typing import List, Optional
 
 import numpy as np
 import pandas as pd
@@ -15,7 +15,6 @@
     JabsProject,
 )
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -27,7 +26,7 @@ def evaluate_ground_truth(
     stitch_scan: List[float] = None,
     filter_scan: List[float] = None,
     iou_thresholds: List[float] = None,
-    filter_ground_truth: bool = False,
+    filter_ground_truth: Optional[dict] = None,
     trim_time: Optional[int] = None,
 ):
     """Main function for evaluating ground truth annotations against classifier predictions.
@@ -40,7 +39,11 @@
         stitch_scan: List of stitching (time gaps in frames to merge bouts together) values to test
         filter_scan: List of filter (minimum duration in frames to consider real) values to test
         iou_thresholds: List of intersection over union thresholds to scan
-        filter_ground_truth: Apply filters to ground truth data (default is only to filter predictions)
+        filter_ground_truth: Optional dict specifying stitch/filter values to apply to BOTH GT and
+            predictions for additional filtered outputs. If provided, it must contain two keys:
+            'stitch': Stitch (frames) to use for the filtered outputs (gt and pred)
+            'filter': Minimum bout (frames) to use for the filtered outputs (gt and pred)
+            If either key is missing, the filtered outputs are skipped with a warning.
         scan_output: Output file to save the filter scan performance plot
         bout_output: Output file to save the resulting bout performance plot
         trim_time: Limit the duration in frames of videos for performance
@@ -139,7 +142,7 @@
         stitch_scan,
         filter_scan,
         iou_thresholds,
-        filter_ground_truth,
+        False,
     )
 
     if performance_df.empty:
@@ -154,8 +157,6 @@
     performance_df.to_csv(ouput_paths["scan_csv"], index=False)
     logging.info(f"Scan performance data saved to {ouput_paths['scan_csv']}")
 
-    _melted_df = pd.melt(performance_df, id_vars=["threshold", "stitch", "filter"])
-
     middle_threshold = np.sort(iou_thresholds)[int(np.floor(len(iou_thresholds) / 2))]
 
     # Create a copy to avoid SettingWithCopyWarning
@@ -258,17 +259,7 @@
         winning_bout_df.to_csv(ouput_paths["bout_csv"], index=False)
         logging.info(f"Bout performance data saved to {ouput_paths['bout_csv']}")
 
-    melted_winning = pd.melt(winning_bout_df, id_vars=["threshold", "stitch", "filter"])
-
-    (
-        p9.ggplot(
-            melted_winning[melted_winning["variable"].isin(["pr", "re", "f1"])],
-            p9.aes(x="threshold", y="value", color="variable"),
-        )
-        + p9.geom_line()
-        + p9.theme_bw()
-        + p9.scale_y_continuous(limits=(0, 1))
-    ).save(ouput_paths["bout_plot"], height=6, width=12, dpi=300)
+    _save_bout_curve_performance(winning_bout_df, ouput_paths["bout_plot"])
 
     if ouput_paths["ethogram"] is not None:
         # Prepare data for ethogram plot
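Note: the extracted `_save_bout_curve_performance` helper (defined at the bottom of this diff) melts the metrics frame before plotting. A minimal sketch of that reshaping, with made-up numbers, shows why the plot code then filters on the `variable` column:

    import pandas as pd

    curve_df = pd.DataFrame({
        "threshold": [0.5, 0.75],
        "stitch": [30, 30],
        "filter": [15, 15],
        "tp": [8, 6], "fn": [2, 4], "fp": [1, 2],
        "pr": [0.89, 0.75], "re": [0.80, 0.60], "f1": [0.84, 0.67],
    })
    melted = pd.melt(curve_df, id_vars=["threshold", "stitch", "filter"])
    # melt stacks ALL remaining columns (tp/fn/fp as well as pr/re/f1) into
    # 'variable'/'value' pairs, so the plot keeps only the three rate metrics:
    melted = melted[melted["variable"].isin(["pr", "re", "f1"])]
    print(melted)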
@@ -348,6 +339,50 @@
             f"No annotations found for behavior {behavior} to generate ethogram plot."
         )
 
+    # New: filtered outputs if the user provided filter settings to apply to both GT and Pred
+    if filter_ground_truth is not None:
+        if not isinstance(filter_ground_truth, dict) or not {
+            "stitch",
+            "filter",
+        }.issubset(set(filter_ground_truth.keys())):
+            logger.warning(
+                "filter_ground_truth must be a dict with keys {'stitch','filter'}. Skipping filtered outputs."
+            )
+        else:
+            stitch_val = int(filter_ground_truth["stitch"])
+            filter_val = int(filter_ground_truth["filter"])
+            # 1) Filtered ethogram with 4 tracks
+            _save_filtered_ethogram(
+                all_annotations,
+                behavior,
+                stitch_val,
+                filter_val,
+                ouput_paths.get("ethogram_filtered"),
+            )
+            # 2) Filtered curves CSV + plot over IoU thresholds
+            filtered_curve_df = generate_iou_scan(
+                all_annotations,
+                [stitch_val],
+                [filter_val],
+                np.round(iou_thresholds, 2),
+                True,  # filter_ground_truth: filter GT bouts as well as predictions
+            )
+            if filtered_curve_df is not None and len(filtered_curve_df) > 0:
+                if ouput_paths.get("bout_filtered_csv") is not None:
+                    filtered_curve_df.to_csv(
+                        ouput_paths["bout_filtered_csv"], index=False
+                    )
+                    logger.info(
+                        f"Filtered curve performance saved to {ouput_paths['bout_filtered_csv']}"
+                    )
+                _save_bout_curve_performance(
+                    filtered_curve_df, ouput_paths.get("bout_filtered_plot")
+                )
+            else:
+                logger.warning(
+                    "No filtered performance data available to save plots/CSV."
+                )
+
 
 def generate_iou_scan(
     all_annotations,
@@ -405,8 +440,47 @@
             # Always apply filters to predictions
             cur_pr.filter_by_settings(cur_filter_settings)
 
+            # Handle cases with zero positives on either side without calling compare_to.
+            # This is necessary when filtering both GT and predictions, since filtering
+            # can remove every annotation from a GT video.
+            num_pr_pos = int(np.sum(cur_pr.values == 1))
+            num_gt_pos = int(np.sum(cur_gt.values == 1))
+            if num_pr_pos == 0 or num_gt_pos == 0:
+                for cur_threshold in threshold_scan:
+                    if num_pr_pos == 0 and num_gt_pos == 0:
+                        metrics = {"tp": 0, "fn": 0, "fp": 0, "pr": 0, "re": 0, "f1": 0}
+                    elif num_pr_pos == 0 and num_gt_pos > 0:
+                        metrics = {
+                            "tp": 0,
+                            "fn": num_gt_pos,
+                            "fp": 0,
+                            "pr": 0,
+                            "re": 0,
+                            "f1": 0,
+                        }
+                    else:  # num_pr_pos > 0 and num_gt_pos == 0
+                        metrics = {
+                            "tp": 0,
+                            "fn": 0,
+                            "fp": num_pr_pos,
+                            "pr": 0,
+                            "re": 0,
+                            "f1": 0,
+                        }
+                    new_performance = {
+                        "animal": [cur_animal],
+                        "video": [cur_video],
+                        "stitch": [cur_stitch],
+                        "filter": [cur_filter],
+                        "threshold": [cur_threshold],
+                    }
+                    for key, val in metrics.items():
+                        new_performance[key] = [val]
+                    performance_df.append(pd.DataFrame(new_performance))
+                continue
+
             # Add iou metrics to the list
-            int_mat, u_mat, iou_mat = cur_gt.compare_to(cur_pr)
+            _, _, iou_mat = cur_gt.compare_to(cur_pr)
             for cur_threshold in threshold_scan:
                 new_performance = {
                     "animal": [cur_animal],
@@ -424,20 +498,23 @@
         logger.warning(
             "No valid ground truth and prediction pairs found for behavior across all files. Cannot generate performance metrics."
         )
-        # Return an empty DataFrame with expected columns to prevent downstream errors
-        return pd.DataFrame(
-            columns=[
-                "stitch",
-                "filter",
-                "threshold",
-                "tp",
-                "fn",
-                "fp",
-                "pr",
-                "re",
-                "f1",
-            ]
-        )
+        # Build a full stitch/filter/threshold grid (zero counts, NaN rates) to prevent downstream errors
+        rows = []
+        for s in stitch_scan:
+            for f in filter_scan:
+                for t in threshold_scan:
+                    rows.append({
+                        "stitch": s,
+                        "filter": f,
+                        "threshold": t,
+                        "tp": 0,
+                        "fn": 0,
+                        "fp": 0,
+                        "pr": np.nan,
+                        "re": np.nan,
+                        "f1": np.nan,
+                    })
+        return pd.DataFrame(rows)
     performance_df = pd.concat(performance_df)
 
     # Aggregate over animals
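Note: the zero-positive short-circuit above hard-codes a convention rather than computing rates from an empty IoU matrix. A standalone sketch of the three branches (bout counts invented; only called when at least one side is empty, as the guard ensures):

    # Mirrors the branch logic above: when either side has no behavior bouts,
    # pr/re/f1 are reported as 0 instead of dividing by zero.
    def edge_case_metrics(num_pr_pos: int, num_gt_pos: int) -> dict:
        if num_pr_pos == 0 and num_gt_pos == 0:
            return {"tp": 0, "fn": 0, "fp": 0, "pr": 0, "re": 0, "f1": 0}
        if num_pr_pos == 0:
            # no predictions: every GT bout is a miss
            return {"tp": 0, "fn": num_gt_pos, "fp": 0, "pr": 0, "re": 0, "f1": 0}
        # no GT: every predicted bout is spurious
        return {"tp": 0, "fn": 0, "fp": num_pr_pos, "pr": 0, "re": 0, "f1": 0}

    print(edge_case_metrics(0, 3))  # {'tp': 0, 'fn': 3, 'fp': 0, ...}
    print(edge_case_metrics(4, 0))  # {'tp': 0, 'fn': 0, 'fp': 4, ...}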
@@ -468,6 +545,7 @@ def generate_output_paths(results_folder: Path):
 
     Args:
         results_folder: Path to the folder where results will be saved.
+
     Returns:
-        A dictionary with keys 'scan_csv', 'bout_csv', 'ethogram', 'scan_plot', and 'bout_plot' containing the respective output paths.
+        A dictionary with keys 'scan_csv', 'bout_csv', 'ethogram', 'scan_plot', 'bout_plot', 'ethogram_filtered', 'bout_filtered_plot', and 'bout_filtered_csv' containing the respective output paths.
     """
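Note: the new entries are read with dict.get() and both new helpers no-op on a None path, so a missing key skips the filtered outputs instead of raising KeyError. A tiny illustration of that defensive pattern (names here are illustrative, not from the codebase):

    from pathlib import Path
    from typing import Optional

    paths = {"bout_plot": Path("bout_performance.png")}  # 'bout_filtered_plot' absent

    def save_plot(output_path: Optional[Path]) -> None:
        if output_path is None:  # same early-return guard as the helpers in this diff
            return
        print(f"would save to {output_path}")

    save_plot(paths.get("bout_filtered_plot"))  # silently skipped
    save_plot(paths.get("bout_plot"))           # would save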
palette="Set1") + + p9.labs( + x="Frame", + fill="Track", + title=f"Ethogram (filtered) for behavior: {behavior}", + ) + + p9.expand_limits(x=0) + ) + + plot.save( + output_path, + height=2.0 * num_unique_combos + 2, + width=12, + dpi=300, + limitsize=False, + verbose=False, + ) + logging.info(f"Filtered ethogram plot saved to {output_path}")