Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,19 @@ async def _append_new_message_to_session(
new_message.parts[i] = types.Part(
text=f'Uploaded file: {file_name}. It is saved into artifacts'
)

if self._has_duplicate_user_event_for_invocation(
session=session,
invocation_id=invocation_context.invocation_id,
new_message=new_message,
state_delta=state_delta,
):
logger.info(
'Skipping duplicate user event append for invocation_id=%s',
invocation_context.invocation_id,
)
return

# Appends only. We do not yield the event because it's not from the model.
if state_delta:
event = Event(
Expand All @@ -918,6 +931,25 @@ async def _append_new_message_to_session(

await self.session_service.append_event(session=session, event=event)

def _has_duplicate_user_event_for_invocation(
self,
*,
session: Session,
invocation_id: str,
new_message: types.Content,
state_delta: Optional[dict[str, Any]],
) -> bool:
expected_state_delta = state_delta or {}
for event in session.events:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better performance, especially in sessions with a long history, it's advisable to search for duplicates starting from the end of the event list. A duplicate event from a retry is likely to be recent, so iterating in reverse will find it more quickly.

Suggested change
for event in session.events:
for event in reversed(session.events):

if event.invocation_id != invocation_id or event.author != 'user':
continue
if (
event.content == new_message
and event.actions.state_delta == expected_state_delta
):
return True
return False

async def run_live(
self,
*,
Expand Down
145 changes: 145 additions & 0 deletions tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,151 @@ def _infer_agent_origin(
assert event.content.parts[0].text == "Test LLM response"


@pytest.mark.asyncio
async def test_append_new_message_to_session_skips_duplicate_retry_message():
session_service = InMemorySessionService()
runner = Runner(
app_name="test_app",
agent=MockLlmAgent("root_agent"),
session_service=session_service,
artifact_service=InMemoryArtifactService(),
)
session = await session_service.create_session(
app_name="test_app",
user_id="test_user",
)
user_message = types.Content(
role="user",
parts=[types.Part(text="retry message")],
)
invocation_context = runner._new_invocation_context(
session,
invocation_id="inv-retry",
new_message=user_message,
run_config=RunConfig(),
)

await runner._append_new_message_to_session(
session=session,
new_message=user_message,
invocation_context=invocation_context,
)
await runner._append_new_message_to_session(
session=session,
new_message=user_message,
invocation_context=invocation_context,
)

matched_events = [
event
for event in session.events
if event.author == "user"
and event.invocation_id == "inv-retry"
and event.content == user_message
]
assert len(matched_events) == 1


@pytest.mark.asyncio
async def test_append_new_message_to_session_keeps_non_duplicate_messages():
session_service = InMemorySessionService()
runner = Runner(
app_name="test_app",
agent=MockLlmAgent("root_agent"),
session_service=session_service,
artifact_service=InMemoryArtifactService(),
)
session = await session_service.create_session(
app_name="test_app",
user_id="test_user",
)
invocation_context = runner._new_invocation_context(
session,
invocation_id="inv-retry",
new_message=types.Content(role="user", parts=[types.Part(text="first")]),
run_config=RunConfig(),
)
first_message = types.Content(role="user", parts=[types.Part(text="first")])
second_message = types.Content(role="user", parts=[types.Part(text="second")])

await runner._append_new_message_to_session(
session=session,
new_message=first_message,
invocation_context=invocation_context,
)
await runner._append_new_message_to_session(
session=session,
new_message=second_message,
invocation_context=invocation_context,
)

matched_events = [
event
for event in session.events
if event.author == "user" and event.invocation_id == "inv-retry"
]
assert len(matched_events) == 2


@pytest.mark.asyncio
async def test_append_new_message_to_session_state_delta_deduping():
session_service = InMemorySessionService()
runner = Runner(
app_name="test_app",
agent=MockLlmAgent("root_agent"),
session_service=session_service,
artifact_service=InMemoryArtifactService(),
)
session = await session_service.create_session(
app_name="test_app",
user_id="test_user",
)
user_message = types.Content(role="user", parts=[types.Part(text="same message")])
invocation_context = runner._new_invocation_context(
session,
invocation_id="inv-state-delta",
new_message=user_message,
run_config=RunConfig(),
)

await runner._append_new_message_to_session(
session=session,
new_message=user_message,
invocation_context=invocation_context,
state_delta={"attempt": 1},
)
await runner._append_new_message_to_session(
session=session,
new_message=user_message,
invocation_context=invocation_context,
state_delta={"attempt": 1},
)
await runner._append_new_message_to_session(
session=session,
new_message=user_message,
invocation_context=invocation_context,
state_delta={"attempt": 2},
)
await runner._append_new_message_to_session(
session=session,
new_message=user_message,
invocation_context=invocation_context,
state_delta=None,
)

matched_events = [
event
for event in session.events
if event.author == "user"
and event.invocation_id == "inv-state-delta"
and event.content == user_message
]
assert len(matched_events) == 3
assert matched_events[0].actions.state_delta == {"attempt": 1}
assert matched_events[1].actions.state_delta == {"attempt": 2}
assert matched_events[2].actions.state_delta == {}


@pytest.mark.asyncio
async def test_rewind_auto_create_session_on_missing_session():
"""When auto_create_session=True, rewind should create session if missing.
Expand Down