Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions maseval/core/callbacks/result_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class ResultLogger(BenchmarkCallback, ABC):
include_traces: Whether to include execution traces in logged results
include_config: Whether to include configuration in logged results
include_eval: Whether to include evaluation results in logged results
include_usage: Whether to include API usage data in logged results
validate_on_completion: Whether to validate all iterations were logged

Example:
Expand Down Expand Up @@ -62,7 +63,8 @@ def __init__(
include_traces: bool = True,
include_config: bool = True,
include_eval: bool = True,
include_task: bool = False,
include_task: bool = True,
include_usage: bool = True,
validate_on_completion: bool = True,
):
"""Initialize the result logger.
Expand All @@ -73,13 +75,15 @@ def __init__(
include_eval: If True, include evaluation results in logged results
include_task: If True, include task data (query, metadata, protocol)
in logged results
include_usage: If True, include API usage data in logged results
validate_on_completion: If True, validate all iterations were logged at end
"""
super().__init__()
self.include_traces = include_traces
self.include_config = include_config
self.include_eval = include_eval
self.include_task = include_task
self.include_usage = include_usage
self.validate_on_completion = validate_on_completion

# Tracking for validation
Expand Down Expand Up @@ -177,6 +181,9 @@ def _filter_report(self, report: Dict) -> Dict:
if self.include_eval and "eval" in report:
filtered["eval"] = report["eval"]

if self.include_usage and "usage" in report:
filtered["usage"] = report["usage"]

if self.include_task and "task" in report:
filtered["task"] = report["task"]

Expand Down Expand Up @@ -313,7 +320,8 @@ def __init__(
include_traces: bool = True,
include_config: bool = True,
include_eval: bool = True,
include_task: bool = False,
include_task: bool = True,
include_usage: bool = True,
validate_on_completion: bool = True,
):
"""Initialize the file logger.
Expand All @@ -332,13 +340,15 @@ def __init__(
include_eval: If True, include evaluation results in logged results
include_task: If True, include task data (query, metadata, protocol)
in logged results
include_usage: If True, include API usage data in logged results
validate_on_completion: If True, validate all iterations were logged
"""
super().__init__(
include_traces=include_traces,
include_config=include_config,
include_eval=include_eval,
include_task=include_task,
include_usage=include_usage,
validate_on_completion=validate_on_completion,
)

Expand Down Expand Up @@ -530,6 +540,7 @@ def _write_metadata(self) -> None:
"include_config": self.include_config,
"include_eval": self.include_eval,
"include_task": self.include_task,
"include_usage": self.include_usage,
"validation_enabled": self.validate_on_completion,
}

Expand Down
19 changes: 17 additions & 2 deletions tests/test_core/test_callbacks/test_result_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ def test_filter_report_includes_task_when_enabled(self):
assert filtered["task"]["query"] == "What is 2+2?"
assert filtered["task"]["metadata"] == {"difficulty": "easy"}

def test_filter_report_excludes_task_by_default(self):
"""Test that task data is excluded from filtered report by default."""
def test_filter_report_includes_task_by_default(self):
"""Test that task data is included in filtered report by default."""
logger = MockResultLogger()

report = {
Expand All @@ -209,6 +209,21 @@ def test_filter_report_excludes_task_by_default(self):

filtered = logger._filter_report(report)

assert "task" in filtered
assert filtered["task"]["query"] == "What is 2+2?"

def test_filter_report_excludes_task_when_disabled(self):
    """Verify that the "task" entry is dropped from the filtered report when include_task=False."""
    task_payload = {"query": "What is 2+2?", "metadata": {}, "protocol": {}}
    full_report = dict(task_id="task_0", repeat_idx=0, task=task_payload)

    # Task logging is switched off explicitly for this logger instance.
    result = MockResultLogger(include_task=False)._filter_report(full_report)

    assert "task" not in result

def test_filter_report_partial_included(self):
Expand Down
Loading