Skip to content

Commit bc76f73

Browse files
authored
Merge pull request #116 from Serverless-Devs/dev-0609
Enhance MemoryConversation with async initialization and fire-and-for…
2 parents cb4466c + a25a479 commit bc76f73

2 files changed

Lines changed: 90 additions & 54 deletions

File tree

agentrun/memory_collection/memory_conversation.py

Lines changed: 72 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
55
"""
66

7+
import asyncio
78
import json
89
import os
910
from typing import (
@@ -76,6 +77,7 @@ def __init__(
7677
# 延迟初始化
7778
self._memory_store = None
7879
self._ots_client = None
80+
self._init_lock = asyncio.Lock()
7981

8082
@staticmethod
8183
def _default_user_id_extractor(req: Any) -> str:
@@ -153,10 +155,18 @@ def _default_agent_id_extractor(req: Any) -> str:
153155
return "default_agent"
154156

155157
async def _get_memory_store(self):
156-
"""获取或创建 AsyncMemoryStore 实例"""
158+
"""获取或创建 AsyncMemoryStore 实例(双检锁,并发安全)"""
157159
if self._memory_store is not None:
158160
return self._memory_store
159161

162+
async with self._init_lock:
163+
# 拿到锁后再检查一次,防止并发请求重复初始化
164+
if self._memory_store is not None:
165+
return self._memory_store
166+
return await self._init_memory_store()
167+
168+
async def _init_memory_store(self):
169+
"""内部初始化方法,由 _get_memory_store 在持锁状态下调用"""
160170
try:
161171
# 导入依赖
162172
from tablestore_for_agent_memory.base.base_memory_store import (
@@ -228,7 +238,7 @@ async def _get_memory_store(self):
228238
)
229239
await self._memory_store.init_table()
230240
await self._memory_store.init_search_index()
231-
logger.info(f"Tables and indexes initialized successfully")
241+
logger.info("Tables and indexes initialized successfully")
232242
except Exception as e:
233243
# 如果表已存在,会抛出异常,这是正常的
234244
logger.info(
@@ -384,10 +394,13 @@ async def wrap_invoke_agent(
384394
metadata={"agent_id": agent_id},
385395
)
386396

387-
try:
388-
await memory_store.put_session(session)
389-
except Exception as e:
390-
logger.error(f"Failed to save session: {e}", exc_info=True)
397+
async def _put_session_bg():
398+
try:
399+
await memory_store.put_session(session)
400+
except Exception as e:
401+
logger.error(f"Failed to save session: {e}", exc_info=True)
402+
403+
asyncio.create_task(_put_session_bg())
391404

392405
# 构建输入消息列表(包含所有历史消息)
393406
input_messages = []
@@ -465,57 +478,62 @@ async def wrap_invoke_agent(
465478
yield event
466479

467480
# 保存完整的对话轮次(输入 + 输出)
468-
# 只有当有文本内容或工具调用时才保存
481+
# 使用 fire-and-forget 避免阻塞流式响应关闭
469482
if agent_response_content or tool_calls or tool_results:
470-
try:
471-
# 构建助手响应消息
472-
assistant_message: Dict[str, Any] = {
473-
"role": "assistant",
474-
}
475-
476-
# 添加文本内容(如果有)
477-
if agent_response_content:
478-
assistant_message["content"] = agent_response_content
479-
else:
480-
# OpenAI 格式要求:如果有 tool_calls,content 可以为 null
481-
assistant_message["content"] = None
482-
483-
# 添加工具调用(如果有)
484-
if tool_calls:
485-
assistant_message["tool_calls"] = list(
486-
tool_calls.values()
483+
# 构建助手响应消息
484+
assistant_message: Dict[str, Any] = {
485+
"role": "assistant",
486+
}
487+
488+
if agent_response_content:
489+
assistant_message["content"] = agent_response_content
490+
else:
491+
assistant_message["content"] = None
492+
493+
if tool_calls:
494+
assistant_message["tool_calls"] = list(tool_calls.values())
495+
496+
output_messages = input_messages + [assistant_message]
497+
498+
if tool_results:
499+
output_messages.extend(tool_results)
500+
501+
conversation_message = Message(
502+
session_id=session_id,
503+
message_id=f"msg_{uuid.uuid4().hex[:16]}",
504+
content=json.dumps(output_messages, ensure_ascii=False),
505+
)
506+
507+
async def _save_conversation_bg(
508+
ms=memory_store,
509+
msg=conversation_message,
510+
sess=session,
511+
n_msgs=len(output_messages),
512+
text_len=len(agent_response_content),
513+
n_tc=len(tool_calls),
514+
n_tr=len(tool_results),
515+
):
516+
try:
517+
await ms.put_message(msg)
518+
sess.update_time = microseconds_timestamp()
519+
await ms.update_session(sess)
520+
logger.debug(
521+
"Saved conversation: %d messages,"
522+
" text length: %d chars,"
523+
" tool_calls: %d, tool_results: %d",
524+
n_msgs,
525+
text_len,
526+
n_tc,
527+
n_tr,
528+
)
529+
except Exception as e:
530+
logger.error(
531+
"Failed to save conversation: %s",
532+
e,
533+
exc_info=True,
487534
)
488535

489-
# 构建完整的消息列表
490-
output_messages = input_messages + [assistant_message]
491-
492-
# 添加工具执行结果(如果有)
493-
if tool_results:
494-
output_messages.extend(tool_results)
495-
496-
# 将完整的对话历史存储为一条消息
497-
# content 字段存储 JSON 格式的消息列表
498-
conversation_message = Message(
499-
session_id=session_id,
500-
message_id=f"msg_{uuid.uuid4().hex[:16]}",
501-
content=json.dumps(output_messages, ensure_ascii=False),
502-
)
503-
await memory_store.put_message(conversation_message)
504-
505-
# 更新 Session 时间
506-
session.update_time = microseconds_timestamp()
507-
await memory_store.update_session(session)
508-
509-
logger.debug(
510-
f"Saved conversation: {len(output_messages)} messages,"
511-
f" text length: {len(agent_response_content)} chars,"
512-
f" tool_calls: {len(tool_calls)}, tool_results:"
513-
f" {len(tool_results)}"
514-
)
515-
except Exception as e:
516-
logger.error(
517-
f"Failed to save conversation: {e}", exc_info=True
518-
)
536+
asyncio.create_task(_save_conversation_bg())
519537

520538
except Exception as e:
521539
logger.error(f"Error in agent handler: {e}", exc_info=True)

tests/unittests/memory_collection/test_memory_conversation.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for AgentRun Memory Conversation / AgentRun 记忆对话测试"""
22

3+
import asyncio
34
from unittest.mock import AsyncMock, MagicMock, Mock, patch
45

56
import pytest
@@ -8,6 +9,11 @@
89
from agentrun.server.model import AgentRequest, Message, MessageRole
910

1011

12+
async def _flush_bg_tasks():
13+
"""Let fire-and-forget background tasks complete before assertions."""
14+
await asyncio.sleep(0.05)
15+
16+
1117
@pytest.fixture
1218
def mock_memory_collection():
1319
"""Mock MemoryCollection"""
@@ -185,6 +191,9 @@ async def mock_agent(request: AgentRequest):
185191
# Verify results
186192
assert results == ["Hello", ", ", "world!"]
187193

194+
# Wait for fire-and-forget background tasks to complete
195+
await _flush_bg_tasks()
196+
188197
# Verify memory store calls
189198
assert mock_memory_store.put_session.called
190199
assert mock_memory_store.put_message.called
@@ -252,6 +261,9 @@ async def mock_agent(request: AgentRequest):
252261
async for event in memory.wrap_invoke_agent(request, mock_agent):
253262
results.append(event)
254263

264+
# Wait for fire-and-forget background tasks to complete
265+
await _flush_bg_tasks()
266+
255267
# Verify agent still responds
256268
assert results == ["Still works!"]
257269

@@ -339,6 +351,9 @@ async def mock_agent(request: AgentRequest):
339351
assert results[0] == "Let me search for that..."
340352
assert results[3] == "Based on the search, it's sunny today."
341353

354+
# Wait for fire-and-forget background tasks to complete
355+
await _flush_bg_tasks()
356+
342357
# Verify message was saved with tool calls
343358
assert mock_memory_store.put_message.called
344359
saved_message = mock_memory_store.put_message.call_args[0][0]
@@ -437,6 +452,9 @@ async def mock_agent(request: AgentRequest):
437452
# Verify all events were passed through
438453
assert len(results) == 4
439454

455+
# Wait for fire-and-forget background tasks to complete
456+
await _flush_bg_tasks()
457+
440458
# Verify message was saved with accumulated tool call
441459
assert mock_memory_store.put_message.called
442460
saved_message = mock_memory_store.put_message.call_args[0][0]

0 commit comments

Comments
 (0)