|
4 | 4 |
|
5 | 5 | """ |
6 | 6 |
|
| 7 | +import asyncio |
7 | 8 | import json |
8 | 9 | import os |
9 | 10 | from typing import ( |
@@ -76,6 +77,7 @@ def __init__( |
76 | 77 | # 延迟初始化 |
77 | 78 | self._memory_store = None |
78 | 79 | self._ots_client = None |
| 80 | + self._init_lock = asyncio.Lock() |
79 | 81 |
|
80 | 82 | @staticmethod |
81 | 83 | def _default_user_id_extractor(req: Any) -> str: |
@@ -153,10 +155,18 @@ def _default_agent_id_extractor(req: Any) -> str: |
153 | 155 | return "default_agent" |
154 | 156 |
|
155 | 157 | async def _get_memory_store(self): |
156 | | - """获取或创建 AsyncMemoryStore 实例""" |
| 158 | + """获取或创建 AsyncMemoryStore 实例(双检锁,并发安全)""" |
157 | 159 | if self._memory_store is not None: |
158 | 160 | return self._memory_store |
159 | 161 |
|
| 162 | + async with self._init_lock: |
| 163 | + # 拿到锁后再检查一次,防止并发请求重复初始化 |
| 164 | + if self._memory_store is not None: |
| 165 | + return self._memory_store |
| 166 | + return await self._init_memory_store() |
| 167 | + |
| 168 | + async def _init_memory_store(self): |
| 169 | + """内部初始化方法,由 _get_memory_store 在持锁状态下调用""" |
160 | 170 | try: |
161 | 171 | # 导入依赖 |
162 | 172 | from tablestore_for_agent_memory.base.base_memory_store import ( |
@@ -228,7 +238,7 @@ async def _get_memory_store(self): |
228 | 238 | ) |
229 | 239 | await self._memory_store.init_table() |
230 | 240 | await self._memory_store.init_search_index() |
231 | | - logger.info(f"Tables and indexes initialized successfully") |
| 241 | + logger.info("Tables and indexes initialized successfully") |
232 | 242 | except Exception as e: |
233 | 243 | # 如果表已存在,会抛出异常,这是正常的 |
234 | 244 | logger.info( |
@@ -384,10 +394,13 @@ async def wrap_invoke_agent( |
384 | 394 | metadata={"agent_id": agent_id}, |
385 | 395 | ) |
386 | 396 |
|
387 | | - try: |
388 | | - await memory_store.put_session(session) |
389 | | - except Exception as e: |
390 | | - logger.error(f"Failed to save session: {e}", exc_info=True) |
| 397 | + async def _put_session_bg(): |
| 398 | + try: |
| 399 | + await memory_store.put_session(session) |
| 400 | + except Exception as e: |
| 401 | + logger.error(f"Failed to save session: {e}", exc_info=True) |
| 402 | + |
| 403 | + asyncio.create_task(_put_session_bg()) |
391 | 404 |
|
392 | 405 | # 构建输入消息列表(包含所有历史消息) |
393 | 406 | input_messages = [] |
@@ -465,57 +478,62 @@ async def wrap_invoke_agent( |
465 | 478 | yield event |
466 | 479 |
|
467 | 480 | # 保存完整的对话轮次(输入 + 输出) |
468 | | - # 只有当有文本内容或工具调用时才保存 |
| 481 | + # 使用 fire-and-forget 避免阻塞流式响应关闭 |
469 | 482 | if agent_response_content or tool_calls or tool_results: |
470 | | - try: |
471 | | - # 构建助手响应消息 |
472 | | - assistant_message: Dict[str, Any] = { |
473 | | - "role": "assistant", |
474 | | - } |
475 | | - |
476 | | - # 添加文本内容(如果有) |
477 | | - if agent_response_content: |
478 | | - assistant_message["content"] = agent_response_content |
479 | | - else: |
480 | | - # OpenAI 格式要求:如果有 tool_calls,content 可以为 null |
481 | | - assistant_message["content"] = None |
482 | | - |
483 | | - # 添加工具调用(如果有) |
484 | | - if tool_calls: |
485 | | - assistant_message["tool_calls"] = list( |
486 | | - tool_calls.values() |
| 483 | + # 构建助手响应消息 |
| 484 | + assistant_message: Dict[str, Any] = { |
| 485 | + "role": "assistant", |
| 486 | + } |
| 487 | + |
| 488 | + if agent_response_content: |
| 489 | + assistant_message["content"] = agent_response_content |
| 490 | + else: |
| 491 | + assistant_message["content"] = None |
| 492 | + |
| 493 | + if tool_calls: |
| 494 | + assistant_message["tool_calls"] = list(tool_calls.values()) |
| 495 | + |
| 496 | + output_messages = input_messages + [assistant_message] |
| 497 | + |
| 498 | + if tool_results: |
| 499 | + output_messages.extend(tool_results) |
| 500 | + |
| 501 | + conversation_message = Message( |
| 502 | + session_id=session_id, |
| 503 | + message_id=f"msg_{uuid.uuid4().hex[:16]}", |
| 504 | + content=json.dumps(output_messages, ensure_ascii=False), |
| 505 | + ) |
| 506 | + |
| 507 | + async def _save_conversation_bg( |
| 508 | + ms=memory_store, |
| 509 | + msg=conversation_message, |
| 510 | + sess=session, |
| 511 | + n_msgs=len(output_messages), |
| 512 | + text_len=len(agent_response_content), |
| 513 | + n_tc=len(tool_calls), |
| 514 | + n_tr=len(tool_results), |
| 515 | + ): |
| 516 | + try: |
| 517 | + await ms.put_message(msg) |
| 518 | + sess.update_time = microseconds_timestamp() |
| 519 | + await ms.update_session(sess) |
| 520 | + logger.debug( |
| 521 | + "Saved conversation: %d messages," |
| 522 | + " text length: %d chars," |
| 523 | + " tool_calls: %d, tool_results: %d", |
| 524 | + n_msgs, |
| 525 | + text_len, |
| 526 | + n_tc, |
| 527 | + n_tr, |
| 528 | + ) |
| 529 | + except Exception as e: |
| 530 | + logger.error( |
| 531 | + "Failed to save conversation: %s", |
| 532 | + e, |
| 533 | + exc_info=True, |
487 | 534 | ) |
488 | 535 |
|
489 | | - # 构建完整的消息列表 |
490 | | - output_messages = input_messages + [assistant_message] |
491 | | - |
492 | | - # 添加工具执行结果(如果有) |
493 | | - if tool_results: |
494 | | - output_messages.extend(tool_results) |
495 | | - |
496 | | - # 将完整的对话历史存储为一条消息 |
497 | | - # content 字段存储 JSON 格式的消息列表 |
498 | | - conversation_message = Message( |
499 | | - session_id=session_id, |
500 | | - message_id=f"msg_{uuid.uuid4().hex[:16]}", |
501 | | - content=json.dumps(output_messages, ensure_ascii=False), |
502 | | - ) |
503 | | - await memory_store.put_message(conversation_message) |
504 | | - |
505 | | - # 更新 Session 时间 |
506 | | - session.update_time = microseconds_timestamp() |
507 | | - await memory_store.update_session(session) |
508 | | - |
509 | | - logger.debug( |
510 | | - f"Saved conversation: {len(output_messages)} messages," |
511 | | - f" text length: {len(agent_response_content)} chars," |
512 | | - f" tool_calls: {len(tool_calls)}, tool_results:" |
513 | | - f" {len(tool_results)}" |
514 | | - ) |
515 | | - except Exception as e: |
516 | | - logger.error( |
517 | | - f"Failed to save conversation: {e}", exc_info=True |
518 | | - ) |
| 536 | + asyncio.create_task(_save_conversation_bg()) |
519 | 537 |
|
520 | 538 | except Exception as e: |
521 | 539 | logger.error(f"Error in agent handler: {e}", exc_info=True) |
|
0 commit comments