Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes (17 additions & 3 deletions) in init2winit/projects/optlrschedule/notebook_utils/parquet_util.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -27,7 +27,8 @@ def load_parquet_file(
file_name: Optional[str] = None,
*,
sort_by: str = 'score',
ascending: bool = True
ascending: bool = True,
include_provenance: bool = False,
) -> pd.DataFrame:
"""Load a single parquet file and return it as a sorted DataFrame.

Expand All @@ -36,6 +37,9 @@ def load_parquet_file(
file_name (optional): File name string (default: 'results.parquet')
sort_by: Column to sort by (default: 'score')
ascending: Sort order (default: True)
include_provenance: Whether to record where the data was loaded from
(default: False). If True, adds a 'provenance' column to the DataFrame
whose value in every row is the loaded path as a string (str(path)); an
existing 'provenance' column, if any, is overwritten.

Returns:
pandas DataFrame
Expand All @@ -55,6 +59,8 @@ def load_parquet_file(
if sort_by in df.columns:
df.sort_values(by=sort_by, ascending=ascending, inplace=True)

if include_provenance:
df['provenance'] = str(path)
return df


Expand All @@ -63,7 +69,8 @@ def load_all_parquet_files(
file_name: Optional[str] = None,
*,
sort_by: str = 'score',
ascending: bool = True
ascending: bool = True,
include_provenance: bool = False,
) -> pd.DataFrame:
"""Load and merge all parquet files from different paths.

Expand All @@ -72,6 +79,9 @@ def load_all_parquet_files(
file_name (optional): File name string (default: 'results.parquet')
sort_by: Column to sort by (default: 'score').
ascending: Sort order (default: True).
include_provenance: Whether to record where each row was loaded from
(default: False). If True, the flag is forwarded to each per-path load, so
every loaded DataFrame carries a 'provenance' column holding, as a string,
the path its rows came from.

Returns:
Merged pandas DataFrame.
Expand All @@ -80,7 +90,11 @@ def load_all_parquet_files(

for path in paths:
df = load_parquet_file(
path, file_name, sort_by=sort_by, ascending=ascending
path,
file_name,
sort_by=sort_by,
ascending=ascending,
include_provenance=include_provenance,
)
if not df.empty:
dfs.append(df)
Expand Down