
Commit 56eb522

vertex-sdk-bot authored and copybara-github committed
feat: GenAI Client(evals) - Add support for local agent runs in agent eval
PiperOrigin-RevId: 829920374
1 parent 3eb38bf commit 56eb522
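
This commit teaches `client.evals.run_inference` to execute an agent locally through the ADK runtime (`LlmAgent`, `Runner`, `InMemorySessionService`) instead of calling a deployed agent engine. A minimal usage sketch, assuming a configured `vertexai.Client` and an ADK `LlmAgent`; the client setup, project, and agent names below are illustrative assumptions, not part of this diff:

import pandas as pd
import vertexai
from google.adk.agents import LlmAgent  # assumed ADK import path

# Hypothetical setup; the test below only exercises mocks of these objects.
client = vertexai.Client(project="my-project", location="us-central1")
agent = LlmAgent(name="demo_agent", model="gemini-2.0-flash")

# One row per prompt; session_inputs seeds the per-row user id and session state.
src = pd.DataFrame(
    {
        "prompt": ["agent prompt", "agent prompt 2"],
        "session_inputs": [
            {"user_id": "123", "state": {"a": "1"}},
            {"user_id": "456", "state": {"b": "2"}},
        ],
    }
)

# Runs the agent locally; the result frame gains intermediate_events and response columns.
result = client.evals.run_inference(agent=agent, src=src)
print(result.eval_dataset_df[["prompt", "response"]])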

File tree

5 files changed (+369 −57 lines)

tests/unit/vertexai/genai/test_evals.py

Lines changed: 141 additions & 0 deletions
@@ -1289,6 +1289,144 @@ def test_run_inference_with_agent_engine_with_response_column_raises_error(
             "'intermediate_events' or 'response' columns"
         ) in str(excinfo.value)

+    @mock.patch.object(_evals_utils, "EvalDatasetLoader")
+    @mock.patch("vertexai._genai._evals_common.InMemorySessionService")
+    @mock.patch("vertexai._genai._evals_common.Runner")
+    @mock.patch("vertexai._genai._evals_common.LlmAgent")
+    def test_run_inference_with_local_agent(
+        self,
+        mock_llm_agent,
+        mock_runner,
+        mock_session_service,
+        mock_eval_dataset_loader,
+    ):
+        mock_df = pd.DataFrame(
+            {
+                "prompt": ["agent prompt", "agent prompt 2"],
+                "session_inputs": [
+                    {
+                        "user_id": "123",
+                        "state": {"a": "1"},
+                    },
+                    {
+                        "user_id": "456",
+                        "state": {"b": "2"},
+                    },
+                ],
+            }
+        )
+        mock_eval_dataset_loader.return_value.load.return_value = mock_df.to_dict(
+            orient="records"
+        )
+
+        mock_agent_instance = mock.Mock()
+        mock_llm_agent.return_value = mock_agent_instance
+        mock_session_service.return_value.create_session = mock.AsyncMock()
+        mock_runner_instance = mock_runner.return_value
+        stream_run_return_value_1 = [
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "1",
+                    "content": {"parts": [{"text": "intermediate1"}]},
+                    "timestamp": 123,
+                    "author": "model",
+                }
+            ),
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "2",
+                    "content": {"parts": [{"text": "agent response"}]},
+                    "timestamp": 124,
+                    "author": "model",
+                }
+            ),
+        ]
+        stream_run_return_value_2 = [
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "3",
+                    "content": {"parts": [{"text": "intermediate2"}]},
+                    "timestamp": 125,
+                    "author": "model",
+                }
+            ),
+            mock.Mock(
+                model_dump=lambda: {
+                    "id": "4",
+                    "content": {"parts": [{"text": "agent response 2"}]},
+                    "timestamp": 126,
+                    "author": "model",
+                }
+            ),
+        ]
+
+        async def async_iterator(items):
+            for item in items:
+                yield item
+
+        mock_runner_instance.run_async.side_effect = [
+            async_iterator(stream_run_return_value_1),
+            async_iterator(stream_run_return_value_2),
+        ]
+
+        inference_result = self.client.evals.run_inference(
+            agent=mock_agent_instance,
+            src=mock_df,
+        )
+
+        mock_eval_dataset_loader.return_value.load.assert_called_once_with(mock_df)
+        assert mock_session_service.call_count == 2
+        mock_runner.assert_called_with(
+            agent=mock_agent_instance,
+            app_name="local agent run",
+            session_service=mock_session_service.return_value,
+        )
+        assert mock_runner.call_count == 2
+        assert mock_runner_instance.run_async.call_count == 2
+
+        expected_df = pd.DataFrame(
+            {
+                "prompt": ["agent prompt", "agent prompt 2"],
+                "session_inputs": [
+                    {
+                        "user_id": "123",
+                        "state": {"a": "1"},
+                    },
+                    {
+                        "user_id": "456",
+                        "state": {"b": "2"},
+                    },
+                ],
+                "intermediate_events": [
+                    [
+                        {
+                            "event_id": "1",
+                            "content": {"parts": [{"text": "intermediate1"}]},
+                            "creation_timestamp": 123,
+                            "author": "model",
+                        }
+                    ],
+                    [
+                        {
+                            "event_id": "3",
+                            "content": {"parts": [{"text": "intermediate2"}]},
+                            "creation_timestamp": 125,
+                            "author": "model",
+                        }
+                    ],
+                ],
+                "response": ["agent response", "agent response 2"],
+            }
+        )
+        pd.testing.assert_frame_equal(
+            inference_result.eval_dataset_df.sort_values(by="prompt").reset_index(
+                drop=True
+            ),
+            expected_df.sort_values(by="prompt").reset_index(drop=True),
+        )
+        assert inference_result.candidate_name is None
+        assert inference_result.gcs_source is None
+
     def test_run_inference_with_litellm_string_prompt_format(
         self,
         mock_api_client_fixture,
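
The `async_iterator` helper in the test above is needed because `Runner.run_async` yields events as an async generator, which a plain `mock.Mock` cannot emulate. A self-contained sketch of the pattern, using only `asyncio` and `unittest.mock` from the standard library (the event strings are placeholders):

import asyncio
from unittest import mock


async def async_iterator(items):
    # Wrap a plain list in an async generator, mimicking an async-streaming API.
    for item in items:
        yield item


runner = mock.Mock()
# side_effect hands out a fresh iterator per call, one per dataset row.
runner.run_async.side_effect = [
    async_iterator(["event-1", "event-2"]),
    async_iterator(["event-3", "event-4"]),
]


async def consume():
    first = [event async for event in runner.run_async()]
    second = [event async for event in runner.run_async()]
    print(first, second)  # ['event-1', 'event-2'] ['event-3', 'event-4']


asyncio.run(consume())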
@@ -1641,6 +1779,7 @@ def test_run_agent_internal_success(self, mock_run_agent):
         result_df = _evals_common._run_agent_internal(
             api_client=mock_api_client,
             agent_engine=mock_agent_engine,
+            agent=None,
             prompt_dataset=prompt_dataset,
         )

@@ -1671,6 +1810,7 @@ def test_run_agent_internal_error_response(self, mock_run_agent):
         result_df = _evals_common._run_agent_internal(
             api_client=mock_api_client,
             agent_engine=mock_agent_engine,
+            agent=None,
             prompt_dataset=prompt_dataset,
         )

@@ -1697,6 +1837,7 @@ def test_run_agent_internal_malformed_event(self, mock_run_agent):
         result_df = _evals_common._run_agent_internal(
             api_client=mock_api_client,
             agent_engine=mock_agent_engine,
+            agent=None,
             prompt_dataset=prompt_dataset,
         )
         assert "response" in result_df.columns
