From 04d43d017518b21d9b5e06578be25815e4248380 Mon Sep 17 00:00:00 2001 From: "George E. Dahl" Date: Fri, 28 Mar 2025 12:08:19 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 741606926 --- .../notebook_utils/parquet_util.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/init2winit/projects/optlrschedule/notebook_utils/parquet_util.py b/init2winit/projects/optlrschedule/notebook_utils/parquet_util.py index 7baca161..c5c1d56b 100644 --- a/init2winit/projects/optlrschedule/notebook_utils/parquet_util.py +++ b/init2winit/projects/optlrschedule/notebook_utils/parquet_util.py @@ -27,7 +27,8 @@ def load_parquet_file( file_name: Optional[str] = None, *, sort_by: str = 'score', - ascending: bool = True + ascending: bool = True, + include_provenance: bool = False, ) -> pd.DataFrame: """Load a single parquet file and return it as a sorted DataFrame. @@ -36,6 +37,9 @@ def load_parquet_file( file_name (optional): File name string (default: 'results.parquet') sort_by: Column to sort by (default: 'score') ascending: Sort order (default: True) + include_provenance: Whether to include the provenance of the data in the + DataFrame (default: False). If set, adds a column 'provenance' to the + DataFrame with the path of the file. Returns: pandas DataFrame @@ -55,6 +59,8 @@ def load_parquet_file( if sort_by in df.columns: df.sort_values(by=sort_by, ascending=ascending, inplace=True) + if include_provenance: + df['provenance'] = str(path) return df @@ -63,7 +69,8 @@ def load_all_parquet_files( file_name: Optional[str] = None, *, sort_by: str = 'score', - ascending: bool = True + ascending: bool = True, + include_provenance: bool = False, ) -> pd.DataFrame: """Load and merge all parquet files from different paths. @@ -72,6 +79,9 @@ def load_all_parquet_files( file_name (optional): File name string (default: 'results.parquet') sort_by: Column to sort by (default: 'score'). ascending: Sort order (default: True). + include_provenance: Whether to include the provenance of the data in the + DataFrame (default: False). If set, adds a column 'provenance' to the + DataFrame with the path of the file each row came from. Returns: Merged pandas DataFrame. @@ -80,7 +90,11 @@ def load_all_parquet_files( for path in paths: df = load_parquet_file( - path, file_name, sort_by=sort_by, ascending=ascending + path, + file_name, + sort_by=sort_by, + ascending=ascending, + include_provenance=include_provenance, ) if not df.empty: dfs.append(df)