Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions backend/app/api/wecom.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@

import base64
import hashlib
import re
import socket
import struct
import time
import uuid
import xml.etree.ElementTree as ET

from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from loguru import logger
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -138,6 +135,7 @@ async def configure_wecom_channel(
existing.verification_token = token
existing.extra_config = extra_config
existing.is_configured = True
existing.is_connected = False
await db.flush()
config_out = ChannelConfigOut.model_validate(existing)
else:
Expand All @@ -150,22 +148,26 @@ async def configure_wecom_channel(
verification_token=token,
extra_config=extra_config,
is_configured=True,
is_connected=False,
)
db.add(config)
await db.flush()
config_out = ChannelConfigOut.model_validate(config)

# Auto-start WebSocket client if bot credentials provided
if has_ws_mode:
try:
from app.services.wecom_stream import wecom_stream_manager
import asyncio
try:
from app.services.wecom_stream import wecom_stream_manager
import asyncio

if has_ws_mode:
asyncio.create_task(
wecom_stream_manager.start_client(agent_id, bot_id, bot_secret)
)
logger.info(f"[WeCom] WebSocket client start triggered for agent {agent_id}")
except Exception as e:
logger.error(f"[WeCom] Failed to start WebSocket client: {e}")
else:
asyncio.create_task(wecom_stream_manager.stop_client(agent_id))
logger.info(f"[WeCom] WebSocket client stop triggered for agent {agent_id}")
except Exception as e:
logger.error(f"[WeCom] Failed to update WebSocket client state: {e}")

return config_out

Expand All @@ -186,7 +188,15 @@ async def get_wecom_channel(
config = result.scalar_one_or_none()
if not config:
raise HTTPException(status_code=404, detail="WeCom not configured")
return ChannelConfigOut.model_validate(config)

config_out = ChannelConfigOut.model_validate(config)
if (config.extra_config or {}).get("connection_mode") == "websocket":
from app.services.wecom_stream import wecom_stream_manager

config_out.is_connected = wecom_stream_manager.status().get(str(agent_id), False)
else:
config_out.is_connected = False
return config_out


@router.get("/agents/{agent_id}/wecom-channel/webhook-url")
Expand Down Expand Up @@ -227,6 +237,9 @@ async def delete_wecom_channel(
config = result.scalar_one_or_none()
if not config:
raise HTTPException(status_code=404, detail="WeCom not configured")
from app.services.wecom_stream import wecom_stream_manager

await wecom_stream_manager.stop_client(agent_id)
await db.delete(config)


Expand Down Expand Up @@ -300,8 +313,6 @@ async def wecom_event_webhook(

token = config.verification_token or ""
encoding_aes_key = config.encrypt_key or ""
corp_id = config.app_id or ""

# Parse encrypted XML body
try:
root = ET.fromstring(body_bytes)
Expand Down Expand Up @@ -457,7 +468,6 @@ async def _process_wecom_text(
kf_msg_id: str = None,
):
"""Process an incoming WeCom text message and reply."""
import json
import httpx
from datetime import datetime, timezone
from sqlalchemy import select as _select
Expand All @@ -476,7 +486,6 @@ async def _process_wecom_text(
if not agent_obj:
logger.warning(f"[WeCom] Agent {agent_id} not found")
return
creator_id = agent_obj.creator_id
ctx_size = agent_obj.context_window_size if agent_obj else 20

conv_id = f"wecom_p2p_{from_user}"
Expand Down
78 changes: 63 additions & 15 deletions backend/app/services/wecom_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,47 @@
from app.models.channel_config import ChannelConfig


def _disable_wecom_sdk_proxy() -> None:
"""Force the WeCom SDK websocket path to bypass system proxies."""
import wecom_aibot_sdk.ws as sdk_ws

if getattr(sdk_ws.websockets.connect, "__clawith_no_proxy_patch__", False):
return

original_connect = sdk_ws.websockets.connect

def connect_no_proxy(*args, **kwargs):
kwargs.setdefault("proxy", None)
return original_connect(*args, **kwargs)

connect_no_proxy.__clawith_no_proxy_patch__ = True
sdk_ws.websockets.connect = connect_no_proxy


def _extract_wecom_sender_id(body: dict) -> str:
sender = body.get("from")
if isinstance(sender, dict):
sender_id = sender.get("user_id") or sender.get("userid")
if sender_id:
return str(sender_id).strip()
return str(body.get("from_userid") or body.get("userid") or "").strip()


def _extract_wecom_chat_type(body: dict) -> str:
return str(body.get("chattype") or body.get("chat_type") or "single").strip().lower()


def _extract_wecom_chat_id(body: dict) -> str:
return str(body.get("chatid") or body.get("chat_id") or "").strip()


def _build_wecom_conv_id(sender_id: str, chat_id: str, chat_type: str) -> str:
normalized_type = (chat_type or "single").strip().lower()
if normalized_type in {"group", "groupchat", "group_chat"} and chat_id:
return f"wecom_group_{chat_id}"
return f"wecom_p2p_{sender_id}"


class WeComStreamManager:
"""Manages WeCom AI Bot WebSocket clients for all agents."""

Expand Down Expand Up @@ -63,6 +104,7 @@ async def _run_client(
return

try:
_disable_wecom_sdk_proxy()
client = WSClient({
"bot_id": bot_id,
"secret": bot_secret,
Expand All @@ -80,13 +122,25 @@ async def on_text(frame):
if not user_text:
return

sender = body.get("from", {})
sender_id = sender.get("user_id", "") or sender.get("userid", "")
chat_id = body.get("chatid", "")
chat_type = body.get("chat_type", "single")
sender_id = _extract_wecom_sender_id(body)
chat_id = _extract_wecom_chat_id(body)
chat_type = _extract_wecom_chat_type(body)
msg_id = str(body.get("msgid", "")).strip()

if not sender_id:
logger.error(
"[WeCom Stream] Missing sender_id, skip message",
extra={
"msgid": msg_id,
"chat_type": chat_type,
"chat_id": chat_id,
"body_keys": sorted(body.keys()),
},
)
return

logger.info(
f"[WeCom Stream] Text from {sender_id}: {user_text[:80]}"
f"[WeCom Stream] Text from {sender_id} ({chat_type}, chat_id={chat_id or '-'})"
)

# Process message and get reply
Expand Down Expand Up @@ -121,8 +175,7 @@ async def on_text(frame):
async def on_image(frame):
try:
body = frame.body or {}
sender = body.get("from", {})
sender_id = sender.get("user_id", "") or sender.get("userid", "")
sender_id = _extract_wecom_sender_id(body)
logger.info(f"[WeCom Stream] Image message from {sender_id} (not yet handled)")
stream_id = generate_req_id("stream")
await client.reply_stream(
Expand All @@ -137,8 +190,7 @@ async def on_image(frame):
async def on_file(frame):
try:
body = frame.body or {}
sender = body.get("from", {})
sender_id = sender.get("user_id", "") or sender.get("userid", "")
sender_id = _extract_wecom_sender_id(body)
logger.info(f"[WeCom Stream] File message from {sender_id} (not yet handled)")
stream_id = generate_req_id("stream")
await client.reply_stream(
Expand Down Expand Up @@ -216,7 +268,7 @@ async def start_all(self):
async with async_session() as db:
result = await db.execute(
select(ChannelConfig).where(
ChannelConfig.is_configured == True,
ChannelConfig.is_configured,
ChannelConfig.channel_type == "wecom",
)
)
Expand Down Expand Up @@ -274,11 +326,7 @@ async def _process_wecom_stream_message(
return "Agent not found"
ctx_size = agent_obj.context_window_size or 20

# Conversation ID: differentiate single chat vs group chat
if chat_type == "group" and chat_id:
conv_id = f"wecom_group_{chat_id}"
else:
conv_id = f"wecom_p2p_{sender_id}"
conv_id = _build_wecom_conv_id(sender_id, chat_id, chat_type)

# Find or create platform user
wc_username = f"wecom_{sender_id}"
Expand Down
139 changes: 139 additions & 0 deletions backend/tests/test_wecom_channel_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import uuid
from datetime import UTC, datetime
from types import SimpleNamespace

import pytest

from app.api import wecom as wecom_api
from app.models.channel_config import ChannelConfig
from app.models.user import User


class DummyResult:
def __init__(self, value=None):
self._value = value

def scalar_one_or_none(self):
return self._value


class RecordingDB:
def __init__(self, responses=None):
self.responses = list(responses or [])
self.deleted = []
self.flushed = False

async def execute(self, statement):
if self.responses:
return self.responses.pop(0)
return DummyResult()

def add(self, _obj):
return None

async def flush(self):
self.flushed = True

async def delete(self, obj):
self.deleted.append(obj)


def make_user(**overrides):
values = {
"id": uuid.uuid4(),
"username": "alice",
"email": "alice@example.com",
"password_hash": "old-hash",
"display_name": "Alice",
"role": "member",
"tenant_id": uuid.uuid4(),
"is_active": True,
}
values.update(overrides)
return User(**values)


def make_channel(agent_id: uuid.UUID, *, connection_mode: str = "websocket") -> ChannelConfig:
return ChannelConfig(
id=uuid.uuid4(),
agent_id=agent_id,
channel_type="wecom",
app_id="corp_id",
app_secret="secret",
is_configured=True,
is_connected=False,
extra_config={"connection_mode": connection_mode, "bot_id": "bot_123", "bot_secret": "secret_123"},
created_at=datetime.now(UTC),
)


@pytest.mark.asyncio
async def test_get_wecom_channel_reports_runtime_websocket_status(monkeypatch):
agent_id = uuid.uuid4()
config = make_channel(agent_id, connection_mode="websocket")
db = RecordingDB([DummyResult(config)])

async def fake_check_agent_access(_db, _user, _agent_id):
return object(), None

class FakeManager:
def status(self):
return {str(agent_id): True}

monkeypatch.setattr(wecom_api, "check_agent_access", fake_check_agent_access)
monkeypatch.setattr("app.services.wecom_stream.wecom_stream_manager", FakeManager())

result = await wecom_api.get_wecom_channel(
agent_id=agent_id,
current_user=make_user(),
db=db,
)

assert result.is_connected is True


@pytest.mark.asyncio
async def test_get_wecom_channel_marks_webhook_mode_disconnected(monkeypatch):
agent_id = uuid.uuid4()
config = make_channel(agent_id, connection_mode="webhook")
db = RecordingDB([DummyResult(config)])

async def fake_check_agent_access(_db, _user, _agent_id):
return object(), None

monkeypatch.setattr(wecom_api, "check_agent_access", fake_check_agent_access)

result = await wecom_api.get_wecom_channel(
agent_id=agent_id,
current_user=make_user(),
db=db,
)

assert result.is_connected is False


@pytest.mark.asyncio
async def test_delete_wecom_channel_stops_runtime_client(monkeypatch):
agent_id = uuid.uuid4()
config = make_channel(agent_id)
db = RecordingDB([DummyResult(config)])
stop_calls = []

async def fake_check_agent_access(_db, _user, _agent_id):
return SimpleNamespace(creator_id=creator.id), None

async def fake_stop_client(aid):
stop_calls.append(aid)

creator = make_user()
monkeypatch.setattr(wecom_api, "check_agent_access", fake_check_agent_access)
monkeypatch.setattr("app.services.wecom_stream.wecom_stream_manager.stop_client", fake_stop_client)

await wecom_api.delete_wecom_channel(
agent_id=agent_id,
current_user=creator,
db=db,
)

assert stop_calls == [agent_id]
assert db.deleted == [config]
Loading