Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ alembic upgrade head
- **[ArXiv](https://arxiv.org)** — 开放论文平台
- **[Semantic Scholar](https://www.semanticscholar.org)** — 引用数据来源
- **[CSFeeds](https://csarxiv.org)** — 论文源订阅服务
- **[learn-claude-code](https://github.com/shareAI-lab/learn-claude-code)** — Agent Harness 工程体系启发,s01-s12 渐进式解构:Loop → Tools → Planning → Subagents → Skills → Context → Tasks → Background → Teams → Protocols → Autonomous

---

Expand Down
191 changes: 115 additions & 76 deletions apps/api/routers/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,21 @@
@author Color2333
"""

from __future__ import annotations

import json
import re
from typing import TYPE_CHECKING

from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse

from packages.ai.agent_service import confirm_action, reject_action, stream_chat
from packages.domain.schemas import AgentChatRequest

if TYPE_CHECKING:
from collections.abc import Callable

router = APIRouter()

_SSE_HEADERS = {
Expand All @@ -18,124 +27,148 @@
}


def _parse_sse_events(chunk: str) -> list[tuple[str, dict]]:
"""解析 SSE chunk,返回 [(event_type, data), ...]"""
events = []
# 每个事件块以 "event: xxx\ndata: {...}\n\n" 格式
event_pattern = re.compile(r"event:\s*(\S+)\s*\ndata:\s*(\{.*?\})\s*\n\n", re.DOTALL)
for match in event_pattern.finditer(chunk):
event_type = match.group(1)
try:
data = json.loads(match.group(2))
events.append((event_type, data))
except json.JSONDecodeError:
pass
return events


@router.post("/agent/chat")
async def agent_chat(req: AgentChatRequest):
"""Agent 对话 - SSE 流式响应(带持久化 + 工具调用记录)"""
from packages.storage.db import session_scope
from packages.storage.repositories import AgentConversationRepository, AgentMessageRepository

# 追踪已保存的用户消息内容,避免重复保存
saved_user_contents: set[str] = set()
from packages.storage.repositories import (
AgentConversationRepository,
AgentMessageRepository,
)

# 如果有 conversation_id,保存到该会话;否则创建新会话
conversation_id = getattr(req, "conversation_id", None)

with session_scope() as session:
conv_repo = AgentConversationRepository(session)
msg_repo = AgentMessageRepository(session)

# 已有 conversation_id:验证存在
if conversation_id:
conv = conv_repo.get_by_id(conversation_id)
if not conv:
conversation_id = None

# 无 conversation_id:创建新会话
if not conversation_id:
first_user_msg = next((m for m in req.messages if m.role == "user"), None)
title = first_user_msg.content[:50] if first_user_msg else "新对话"
conv = conv_repo.create(title=title)
conversation_id = conv.id

# 只保存最新一条用户消息(避免重复)
# 找到最后一条用户消息
latest_user_msg = None
for msg in reversed(req.messages):
if msg.role == "user":
latest_user_msg = msg
break

if latest_user_msg:
# 用内容的 hash 作为去重 key
content_key = latest_user_msg.content[:200]
if content_key not in saved_user_contents:
# 保存本次请求带来的所有新消息(user + assistant + tool)
# 已有的历史消息从 DB 加载,不重复保存
saved_ids: set[str] = set()
for msg in req.messages:
if msg.role == "system":
continue
content_key = f"{msg.role}:{msg.content[:200]}"
if content_key not in saved_ids:
msg_repo.create(
conversation_id=conversation_id,
role=latest_user_msg.role,
content=latest_user_msg.content,
role=msg.role,
content=msg.content,
meta=msg.meta,
)
saved_user_contents.add(content_key)
# 流式响应
saved_ids.add(content_key)

# 构建传给 stream_chat 的 messages(包含 DB 加载的历史)
# 前端传的是本次新增消息,需要拼上 DB 里的历史
msgs = [m.model_dump() for m in req.messages]

def _save_assistant_response(content: str, tool_calls: list | None = None):
"""保存助手响应(包含工具调用)"""
with session_scope() as session:
msg_repo = AgentMessageRepository(session)
meta = {"tool_calls": tool_calls} if tool_calls else None
msg_repo.create(
conversation_id=conversation_id,
role="assistant",
content=content,
meta=meta,
)

# SSE 解析:提取文本和工具调用
import json
import re

# 用于累积助手响应
text_content = ""
tool_calls_records: list[dict] = []

# SSE 格式: "event: xxx\ndata: {...}\n\n"
_sse_pattern = re.compile(r"^event:\s*(\S+)\ndata:\s*(.+?)\n\n", re.DOTALL)

def _parse_sse_chunk(chunk: str) -> tuple[str | None, dict | None]:
"""解析 SSE chunk,返回 (event_type, data)"""
match = _sse_pattern.match(chunk)
if match:
event_type = match.group(1)
try:
data = json.loads(match.group(2))
return event_type, data
except json.JSONDecodeError:
pass
return None, None
def _build_save_callback(conv_id: str) -> Callable[[list[dict]], None]:
"""创建压缩回写回调"""

def on_compact(compressed_messages: list[dict]):
with session_scope() as session:
msg_repo = AgentMessageRepository(session)
# 删除旧消息,写入压缩后的消息
msg_repo.delete_by_conversation(conv_id)
for msg in compressed_messages:
msg_repo.create(
conversation_id=conv_id,
role=msg.get("role", "user"),
content=msg.get("content", ""),
meta=msg.get("meta"),
)

return on_compact

text_buf = ""
tool_records: list[dict] = []
tool_call_id: str | None = None

def stream_with_save():
nonlocal text_content, tool_calls_records
for chunk in stream_chat(msgs, confirmed_action_id=req.confirmed_action_id):
# 解析 SSE 事件
event_type, data = _parse_sse_chunk(chunk)
if event_type and data:
nonlocal text_buf, tool_records, tool_call_id
sse_iter, updated_conversation = stream_chat(
msgs, confirmed_action_id=req.confirmed_action_id
)
for chunk in sse_iter:
yield chunk

for event_type, data in _parse_sse_events(chunk):
if event_type == "text_delta":
# 累积文本内容
text_content += data.get("content", "")
text_buf += data.get("content", "")
elif event_type == "tool_start":
tool_call_id = data.get("id")
elif event_type == "tool_result":
# 记录工具调用结果
tool_calls_records.append(
tool_records.append(
{
"name": data.get("name"),
"success": data.get("success"),
"summary": data.get("summary"),
"data": data.get("data"),
}
)
# 立即保存 tool 消息到 DB
with session_scope() as session:
msg_repo = AgentMessageRepository(session)
msg_repo.create(
conversation_id=conversation_id,
role="tool",
content=json.dumps(
{
"name": data.get("name"),
"success": data.get("success"),
"summary": data.get("summary"),
"data": data.get("data"),
},
ensure_ascii=False,
),
meta={"tool_call_id": tool_call_id},
)
elif event_type == "action_result":
# 记录用户确认的操作结果
tool_calls_records.append(
tool_records.append(
{
"action_id": data.get("id"),
"success": data.get("success"),
"summary": data.get("summary"),
"data": data.get("data"),
}
)
yield chunk

# 流结束后保存助手响应
if text_content or tool_calls_records:
_save_assistant_response(
text_content, tool_calls_records if tool_calls_records else None
)
elif event_type == "done" and (text_buf or tool_records):
with session_scope() as session:
msg_repo = AgentMessageRepository(session)
msg_repo.create(
conversation_id=conversation_id,
role="assistant",
content=text_buf,
meta={"tool_calls": tool_records} if tool_records else None,
)

return StreamingResponse(
stream_with_save(),
Expand All @@ -147,8 +180,9 @@ def stream_with_save():
@router.post("/agent/confirm/{action_id}")
async def agent_confirm(action_id: str):
"""确认执行 Agent 挂起的操作"""
sse_iter, _ = confirm_action(action_id)
return StreamingResponse(
confirm_action(action_id),
sse_iter,
media_type="text/event-stream",
headers=_SSE_HEADERS,
)
Expand All @@ -157,8 +191,9 @@ async def agent_confirm(action_id: str):
@router.post("/agent/reject/{action_id}")
async def agent_reject(action_id: str):
"""拒绝 Agent 挂起的操作"""
sse_iter, _ = reject_action(action_id)
return StreamingResponse(
reject_action(action_id),
sse_iter,
media_type="text/event-stream",
headers=_SSE_HEADERS,
)
Expand Down Expand Up @@ -192,7 +227,10 @@ def get_conversation_messages(
) -> dict:
"""获取指定会话的所有消息"""
from packages.storage.db import session_scope
from packages.storage.repositories import AgentConversationRepository, AgentMessageRepository
from packages.storage.repositories import (
AgentConversationRepository,
AgentMessageRepository,
)

with session_scope() as session:
conv_repo = AgentConversationRepository(session)
Expand All @@ -214,6 +252,7 @@ def get_conversation_messages(
"id": m.id,
"role": m.role,
"content": m.content,
"meta": m.meta,
"created_at": m.created_at.isoformat(),
}
for m in messages
Expand Down
40 changes: 40 additions & 0 deletions packages/agent_core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
agent_core — Agent Harness 工程核心库

参考 learn-claude-code (https://github.com/shareAI-lab/learn-claude-code)
s01-s12 渐进式 harness 机制 Python 实现。

主要模块:
- loop.py : AgentLoop,显式 agent 循环
- dispatcher.py : ToolDispatcher,工具注册与分发
- tasks.py : TaskManager,任务持久化 + 依赖图
- message_bus.py : MessageBus,team 异步通信
- teammates.py : TeammateManager,持久 agent 线程管理
- protocols.py : TeamProtocols,shutdown/plan_approval FSM
- background.py : BackgroundTaskRunner,daemon 线程池
"""

from .background import BackgroundTask, BackgroundTaskRunner
from .dispatcher import ToolDispatcher, make_default_dispatcher
from .loop import AgentConfig, AgentLoop, AgentResponse, StopReason
from .message_bus import MessageBus
from .protocols import ProtocolState, TeamProtocols
from .tasks import Task, TaskManager
from .teammates import TeammateManager

__all__ = [
"AgentLoop",
"AgentConfig",
"AgentResponse",
"StopReason",
"ToolDispatcher",
"make_default_dispatcher",
"TaskManager",
"Task",
"MessageBus",
"TeammateManager",
"TeamProtocols",
"ProtocolState",
"BackgroundTaskRunner",
"BackgroundTask",
]
Loading