diff --git a/init2winit/projects/optlrschedule/notebook_utils/parquet_util.py b/init2winit/projects/optlrschedule/notebook_utils/parquet_util.py
index 7baca161..c5c1d56b 100644
--- a/init2winit/projects/optlrschedule/notebook_utils/parquet_util.py
+++ b/init2winit/projects/optlrschedule/notebook_utils/parquet_util.py
@@ -27,7 +27,8 @@ def load_parquet_file(
     file_name: Optional[str] = None,
     *,
     sort_by: str = 'score',
-    ascending: bool = True
+    ascending: bool = True,
+    include_provenance: bool = False,
 ) -> pd.DataFrame:
   """Load a single parquet file and return it as a sorted DataFrame.
 
@@ -36,6 +37,9 @@ def load_parquet_file(
     file_name (optional): File name string (default: 'results.parquet')
     sort_by: Column to sort by (default: 'score')
     ascending: Sort order (default: True)
+    include_provenance: Whether to include the provenance of the data in the
+      DataFrame (default: False). If set, adds a column 'provenance' to the
+      DataFrame with the path of the file.
 
   Returns:
     pandas DataFrame
@@ -55,6 +59,8 @@ def load_parquet_file(
   if sort_by in df.columns:
     df.sort_values(by=sort_by, ascending=ascending, inplace=True)
 
+  if include_provenance:
+    df['provenance'] = str(path)
   return df
 
 
@@ -63,7 +69,8 @@ def load_all_parquet_files(
     file_name: Optional[str] = None,
     *,
     sort_by: str = 'score',
-    ascending: bool = True
+    ascending: bool = True,
+    include_provenance: bool = False,
 ) -> pd.DataFrame:
   """Load and merge all parquet files from different paths.
 
@@ -72,6 +79,9 @@ def load_all_parquet_files(
     file_name (optional): File name string (default: 'results.parquet')
     sort_by: Column to sort by (default: 'score').
     ascending: Sort order (default: True).
+    include_provenance: Whether to include the provenance of the data in the
+      DataFrame (default: False). If set, adds a column 'provenance' to the
+      DataFrame with the path of the file each row came from.
 
   Returns:
     Merged pandas DataFrame.
@@ -80,7 +90,11 @@ def load_all_parquet_files(
   dfs = []
   for path in paths:
     df = load_parquet_file(
-        path, file_name, sort_by=sort_by, ascending=ascending
+        path,
+        file_name,
+        sort_by=sort_by,
+        ascending=ascending,
+        include_provenance=include_provenance,
    )
     if not df.empty:
       dfs.append(df)
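
Usage sketch (not part of the patch): with include_provenance=True, every row of the merged DataFrame records the path it was loaded from, which makes it easy to tell runs apart after merging. The directories below are hypothetical placeholders, and the 'score' column is assumed to exist per the default sort_by.

    from init2winit.projects.optlrschedule.notebook_utils import parquet_util

    # Hypothetical experiment directories; each is expected to contain a
    # 'results.parquet' file (the default file_name).
    paths = ['/tmp/exp_a', '/tmp/exp_b']

    df = parquet_util.load_all_parquet_files(
        paths,
        sort_by='score',
        ascending=True,
        include_provenance=True,
    )

    # The added 'provenance' column holds the originating path for each row,
    # so per-run summaries remain possible after the merge.
    print(df.groupby('provenance')['score'].min())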