Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,3 @@ REPORT_GENERATION_DB__QUERY__MODE="ro"

# Report Generation (all optional, defaults are in implementations/report_generation/env_vars.py)
REPORT_GENERATION_OUTPUT_PATH="..."
REPORT_GENERATION_LANGFUSE_PROJECT_NAME="..."
7 changes: 7 additions & 0 deletions aieng-eval-agents/aieng/agent_evals/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ class Configs(BaseSettings):
web_search_base_url: str | None = Field(default=None, description="Base URL for web search service.")
web_search_api_key: SecretStr | None = Field(default=None, description="API key for web search service.")

# === Report Generation ===
# Defaults are set in the implementations/report_generation/env_vars.py file
report_generation_output_path: str | None = Field(
default=None,
description="Path to the directory where the report generation agent will save the reports.",
)

# Validators for the SecretStr fields
@field_validator("langfuse_secret_key")
@classmethod
Expand Down
19 changes: 10 additions & 9 deletions aieng-eval-agents/aieng/agent_evals/report_generation/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
>>> agent = get_report_generation_agent(
>>> instructions=MAIN_AGENT_INSTRUCTIONS,
>>> reports_output_path=Path("reports/"),
>>> langfuse_project_name="Report Generation",
>>> )
"""

Expand All @@ -37,8 +36,8 @@
def get_report_generation_agent(
instructions: str,
reports_output_path: Path,
langfuse_project_name: str | None,
after_agent_callback: AfterAgentCallback | None = None,
langfuse_tracing: bool = True,
) -> Agent:
"""
Define the report generation agent.
Expand All @@ -49,20 +48,22 @@ def get_report_generation_agent(
The instructions for the agent.
reports_output_path : Path
The path to the reports output directory.
langfuse_project_name : str | None
The name of the Langfuse project to use for tracing.
after_agent_callback : AfterAgentCallback | None
after_agent_callback : AfterAgentCallback | None, optional
The callback function to be called after the agent has
finished executing.
finished executing. Default is None.
langfuse_tracing : bool, optional
Whether to enable Langfuse tracing. Default is True.

Returns
-------
agents.Agent
The report generation agent.
"""
agent_name = "ReportGenerationAgent"

# Setup langfuse tracing if project name is provided
if langfuse_project_name:
init_tracing(langfuse_project_name)
if langfuse_tracing:
init_tracing(service_name=agent_name)

# Get the client manager singleton instance
client_manager = AsyncClientManager.get_instance()
Expand All @@ -71,7 +72,7 @@ def get_report_generation_agent(

# Define an agent using Google ADK
return Agent(
name="ReportGenerationAgent",
name=agent_name,
model=client_manager.configs.default_worker_model,
instruction=instructions,
tools=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
>>> evaluate(
>>> dataset_name="OnlineRetailReportEval",
>>> reports_output_path=Path("reports/"),
>>> langfuse_project_name="Report Generation",
>>> )
"""

Expand Down Expand Up @@ -61,7 +60,6 @@ class EvaluatorResponse(BaseModel):
async def evaluate(
dataset_name: str,
reports_output_path: Path,
langfuse_project_name: str,
max_concurrency: int = 5,
) -> None:
"""Evaluate the report generation agent against a Langfuse dataset.
Expand Down Expand Up @@ -90,10 +88,7 @@ async def evaluate(
# Initialize the task for the report generation agent evaluation
# We need this task so we can pass parameters to the agent, since
# the agent has to be instantiated inside the task function
report_generation_task = ReportGenerationTask(
reports_output_path=reports_output_path,
langfuse_project_name=langfuse_project_name,
)
report_generation_task = ReportGenerationTask(reports_output_path=reports_output_path)

# Run the experiment with the agent task and evaluator
# against the dataset items
Expand All @@ -119,22 +114,15 @@ async def evaluate(
class ReportGenerationTask:
"""Define a task for the the report generation agent."""

def __init__(
self,
reports_output_path: Path,
langfuse_project_name: str,
):
def __init__(self, reports_output_path: Path):
"""Initialize the task for an report generation agent evaluation.

Parameters
----------
reports_output_path : Path
The path to the reports output directory.
langfuse_project_name : str
The name of the Langfuse project to use for tracing.
"""
self.reports_output_path = reports_output_path
self.langfuse_project_name = langfuse_project_name

async def run(self, *, item: LocalExperimentItem | DatasetItemClient, **kwargs: dict[str, Any]) -> EvaluationOutput:
"""Run the report generation agent against an item from a Langfuse dataset.
Expand All @@ -154,7 +142,6 @@ async def run(self, *, item: LocalExperimentItem | DatasetItemClient, **kwargs:
report_generation_agent = get_report_generation_agent(
instructions=MAIN_AGENT_INSTRUCTIONS,
reports_output_path=self.reports_output_path,
langfuse_project_name=self.langfuse_project_name,
)
# Handle both TypedDict and class access patterns
item_input = item["input"] if isinstance(item, dict) else item.input
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Tests for the offline evaluation of the report generation agent."""

from pathlib import Path
from unittest.mock import ANY, Mock, patch

import pytest
from aieng.agent_evals.report_generation.evaluation.offline import (
evaluate,
final_result_evaluator,
trajectory_evaluator,
)


@patch("aieng.agent_evals.report_generation.evaluation.offline.AsyncClientManager.get_instance")
@patch("aieng.agent_evals.report_generation.evaluation.offline.DbManager.get_instance")
@pytest.mark.asyncio
async def test_evaluate(mock_db_manager_instance, mock_async_client_manager_instance):
"""Test the evaluate function."""
test_dataset_name = "test_dataset"
test_reports_output_path = Path("reports/")
test_max_concurrency = 5

mock_result = Mock()
mock_dataset = Mock()
mock_dataset.run_experiment.return_value = mock_result
mock_langfuse_client = Mock()
mock_langfuse_client.get_dataset.return_value = mock_dataset
mock_async_client_manager_instance.return_value = Mock()
mock_async_client_manager_instance.return_value.langfuse_client = mock_langfuse_client

mock_db_manager_instance.return_value = Mock()

await evaluate(
dataset_name=test_dataset_name,
reports_output_path=test_reports_output_path,
max_concurrency=test_max_concurrency,
)

mock_dataset.run_experiment.assert_called_once_with(
name="Evaluate Report Generation Agent",
description="Evaluate the Report Generation Agent with data from Langfuse",
task=ANY,
evaluators=[final_result_evaluator, trajectory_evaluator],
max_concurrency=test_max_concurrency,
)

task = mock_dataset.run_experiment.call_args_list[0][1]["task"]
assert task.__name__ == "run"
assert task.__self__.__class__.__name__ == "ReportGenerationTask"
assert task.__self__.reports_output_path == test_reports_output_path

mock_db_manager_instance.return_value.close.assert_called_once()
mock_async_client_manager_instance.return_value.close.assert_called_once()
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""Tests for the online evaluation of the report generation agent."""

from unittest.mock import Mock, patch

import pytest
from aieng.agent_evals.report_generation.evaluation.online import report_final_response_score


@patch("aieng.agent_evals.report_generation.evaluation.online.AsyncClientManager.get_instance")
def test_report_final_response_score_positive_score(mock_async_client_manager_instance):
"""Test the report_final_response_score function with a positive score."""
test_string_match = "string-to-match"
test_trace_id = "test_trace_id"

mock_langfuse_client = Mock()
mock_langfuse_client.get_current_trace_id.return_value = test_trace_id

mock_async_client_manager_instance.return_value = Mock()
mock_async_client_manager_instance.return_value.langfuse_client = mock_langfuse_client

mock_event = Mock()
mock_event.is_final_response.return_value = True
mock_event.content = Mock()
mock_event.content.parts = [
Mock(text=f"test_final_response_text {test_string_match} test_final_response_text"),
]

report_final_response_score(mock_event, string_match=test_string_match)

mock_langfuse_client.create_score.assert_called_once_with(
name="Valid Final Response",
value=1,
trace_id=test_trace_id,
comment="Final response contains the string match.",
metadata={
"final_response": mock_event.content.parts[0].text,
"string_match": test_string_match,
},
)
mock_langfuse_client.flush.assert_called_once()


@patch("aieng.agent_evals.report_generation.evaluation.online.AsyncClientManager.get_instance")
def test_report_final_response_score_negative_score(mock_async_client_manager_instance):
"""Test the report_final_response_score function with a negative score."""
test_string_match = "string-to-match"
test_trace_id = "test_trace_id"

mock_langfuse_client = Mock()
mock_langfuse_client.get_current_trace_id.return_value = test_trace_id

mock_async_client_manager_instance.return_value = Mock()
mock_async_client_manager_instance.return_value.langfuse_client = mock_langfuse_client

mock_event = Mock()
mock_event.is_final_response.return_value = True
mock_event.content = Mock()
mock_event.content.parts = [
Mock(text="test_final_response_text test_final_response_text"),
]

report_final_response_score(mock_event, string_match=test_string_match)

mock_langfuse_client.create_score.assert_called_once_with(
name="Valid Final Response",
value=0,
trace_id=test_trace_id,
comment="Final response does not contains the string match.",
metadata={
"final_response": mock_event.content.parts[0].text,
"string_match": test_string_match,
},
)
mock_langfuse_client.flush.assert_called_once()


@patch("aieng.agent_evals.report_generation.evaluation.online.AsyncClientManager.get_instance")
def test_report_final_response_invalid(mock_async_client_manager_instance):
"""Test the report_final_response_score function with a negative score."""
test_string_match = "string-to-match"
test_trace_id = "test_trace_id"

mock_langfuse_client = Mock()
mock_langfuse_client.get_current_trace_id.return_value = test_trace_id

mock_async_client_manager_instance.return_value = Mock()
mock_async_client_manager_instance.return_value.langfuse_client = mock_langfuse_client

mock_event = Mock()
mock_event.is_final_response.return_value = True
mock_event.content = Mock()
mock_event.content.parts = [Mock(text=None)]

report_final_response_score(mock_event, string_match=test_string_match)

mock_langfuse_client.create_score.assert_called_once_with(
name="Valid Final Response",
value=0,
trace_id=test_trace_id,
comment="Final response not found in the event",
metadata={
"string_match": test_string_match,
},
)
mock_langfuse_client.flush.assert_called_once()


def test_report_final_response_not_final_response():
    """A non-final event must be rejected with a ValueError."""
    non_final_event = Mock()
    non_final_event.is_final_response.return_value = False

    with pytest.raises(ValueError, match="Event is not a final response"):
        report_final_response_score(non_final_event)


@patch("aieng.agent_evals.report_generation.evaluation.online.AsyncClientManager.get_instance")
def test_report_final_response_langfuse_trace_id_none(mock_async_client_manager_instance):
"""Test raising an error when the Langfuse trace ID is None."""
mock_langfuse_client = Mock()
mock_langfuse_client.get_current_trace_id.return_value = None

mock_async_client_manager_instance.return_value = Mock()
mock_async_client_manager_instance.return_value.langfuse_client = mock_langfuse_client

with pytest.raises(ValueError, match="Langfuse trace ID is None."):
report_final_response_score(Mock())
Loading