diff --git a/README.md b/README.md index 44e4057..67c0fc0 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ - [Screenshot](#screenshot) - [Action Recording & Playback](#action-recording--playback) - [JSON Action Scripting](#json-action-scripting) + - [MCP Server (Use AutoControl from Claude)](#mcp-server-use-autocontrol-from-claude) - [Scheduler (Interval & Cron)](#scheduler-interval--cron) - [Global Hotkey Daemon](#global-hotkey-daemon) - [Event Triggers](#event-triggers) @@ -66,6 +67,7 @@ - **Event Triggers** — fire scripts when an image appears, a window opens, a pixel changes, or a file is modified - **Run History** — SQLite-backed run log across scheduler / triggers / hotkeys / REST with auto error-screenshot artifacts - **Report Generation** — export test records as HTML, JSON, or XML reports with success/failure status +- **MCP Server** — JSON-RPC 2.0 Model Context Protocol server (stdio + HTTP/SSE) so Claude Desktop / Claude Code / custom tool-use loops can drive AutoControl. ~90 tools, full protocol coverage (resources, prompts, sampling, roots, logging, progress, cancellation, elicitation), bearer-token auth + TLS, audit log, rate limit, plugin hot-reload, CI fake backend - **Remote Automation** — TCP socket server **and** REST API server to receive automation commands - **Plugin Loader** — drop `.py` files exposing `AC_*` callables into a directory and register them as executor commands at runtime - **Shell Integration** — execute shell commands within automation workflows with async output capture @@ -81,6 +83,97 @@ ## Architecture +The runtime is layered: **client surfaces** (CLI, GUI, MCP/REST/socket +servers) sit on top of the **headless API** (`wrapper/` + `utils/`), +which resolves to a **per-OS backend** chosen at import time by +`wrapper/platform_wrapper.py`. The package façade +(`je_auto_control/__init__.py`) re-exports every public name so users +need only `import je_auto_control` regardless of which surface or +backend they hit. + +```mermaid +flowchart LR + subgraph Clients["Client Surfaces"] + direction TB + Claude[["Claude Desktop /
Claude Code"]] + APIUser[["Custom Anthropic /
OpenAI tool loops"]] + HTTPClient[["HTTP / SSE clients"]] + TCPClient[["Socket / REST clients"]] + GUIUser[["PySide6 GUI"]] + CLIUser[["python -m
je_auto_control[.cli]"]] + Library[["Library users
(import je_auto_control)"]] + end + + subgraph Transports["Transports & Servers"] + direction TB + Stdio["MCP stdio
JSON-RPC 2.0"] + HTTPMCP["MCP HTTP /
SSE + auth + TLS"] + REST["REST server
:9939"] + Socket["Socket server
:9938"] + end + + subgraph MCP["mcp_server/"] + direction TB + Dispatcher["MCPServer
(JSON-RPC dispatcher)"] + Tools["tools/
~90 ac_* + aliases"] + Resources["resources/
files · history ·
commands · screen-live"] + Prompts["prompts/
built-in templates"] + Context["context · audit ·
rate-limit · log-bridge"] + FakeBE["fake_backend
(CI smoke)"] + end + + subgraph Core["Headless Core (wrapper/ + utils/)"] + direction TB + Wrapper["wrapper/
mouse · keyboard · screen ·
image · record · window"] + Executor["executor/
AC_* JSON action engine"] + Vision["vision/ · ocr/ ·
accessibility/"] + Recorder["scheduler/ · triggers/ ·
hotkey/ · plugin_loader/
run_history/"] + IOUtils["clipboard/ · cv2_utils/ ·
shell_process/ · json/"] + end + + subgraph Backends["Per-OS Backends"] + direction TB + Win["windows/
Win32 ctypes"] + Mac["osx/
pyobjc · Quartz"] + X11["linux_with_x11/
python-Xlib"] + end + + Claude --> Stdio + APIUser --> Stdio + HTTPClient --> HTTPMCP + TCPClient --> Socket + TCPClient --> REST + + Stdio --> Dispatcher + HTTPMCP --> Dispatcher + Dispatcher --> Tools + Dispatcher --> Resources + Dispatcher --> Prompts + Dispatcher -.- Context + Tools -.optional.-> FakeBE + + Tools --> Wrapper + Tools --> Executor + Tools --> Vision + Tools --> Recorder + Tools --> IOUtils + Resources --> Recorder + Resources --> Wrapper + + REST --> Executor + Socket --> Executor + + GUIUser --> Wrapper + GUIUser --> Recorder + CLIUser --> Executor + Library --> Wrapper + Library --> Executor + + Wrapper --> Backends + Vision -.- Wrapper + Recorder -.- Executor +``` + ``` je_auto_control/ ├── wrapper/ # Platform-agnostic API layer @@ -89,19 +182,21 @@ je_auto_control/ │ ├── auto_control_keyboard.py# Keyboard operations │ ├── auto_control_image.py # Image recognition (OpenCV template matching) │ ├── auto_control_screen.py # Screenshot, screen size, pixel color +│ ├── auto_control_window.py # Cross-platform window manager facade │ └── auto_control_record.py # Action recording/playback ├── windows/ # Windows-specific backend (Win32 API / ctypes) ├── osx/ # macOS-specific backend (pyobjc / Quartz) ├── linux_with_x11/ # Linux-specific backend (python-Xlib) ├── gui/ # PySide6 GUI application └── utils/ + ├── mcp_server/ # MCP server (stdio + HTTP/SSE) — server, tools/, resources, prompts, audit, rate_limit, fake_backend, plugin_watcher ├── executor/ # JSON action executor engine ├── callback/ # Callback function executor ├── cv2_utils/ # OpenCV screenshot, template matching, video recording ├── accessibility/ # UIA (Windows) / AX (macOS) element finder ├── vision/ # VLM-based locator (Anthropic / OpenAI backends) ├── ocr/ # Tesseract-backed text locator - ├── clipboard/ # Cross-platform clipboard + ├── clipboard/ # Cross-platform clipboard (text + image) ├── scheduler/ # Interval + cron scheduler ├── hotkey/ # Global hotkey daemon ├── triggers/ # Image/window/pixel/file triggers @@ -408,6 +503,77 @@ je_auto_control.execute_action([ | Process | `AC_execute_process` | | Executor | `AC_execute_action`, `AC_execute_files` | +### MCP Server (Use AutoControl from Claude) + +Expose AutoControl as a Model Context Protocol server so any +MCP-compatible client (Claude Desktop, Claude Code, custom Anthropic +/ OpenAI tool-use loops) can drive the host machine. Stdlib-only — +JSON-RPC 2.0 over stdio or HTTP+SSE. + +**Register with Claude Code:** + +```bash +claude mcp add autocontrol -- python -m je_auto_control.utils.mcp_server +``` + +**Register with Claude Desktop** (`claude_desktop_config.json`): + +```json +{ + "mcpServers": { + "autocontrol": { + "command": "python", + "args": ["-m", "je_auto_control.utils.mcp_server"] + } + } +} +``` + +**Start programmatically:** + +```python +import je_auto_control as ac + +# Stdio (blocks until stdin closes) +ac.start_mcp_stdio_server() + +# Or HTTP / SSE with bearer-token auth + optional TLS +ac.start_mcp_http_server(host="127.0.0.1", port=9940, + auth_token="hunter2") +``` + +**Inspect the catalogue without starting the server:** + +```bash +je_auto_control_mcp --list-tools +je_auto_control_mcp --list-tools --read-only +je_auto_control_mcp --list-resources +je_auto_control_mcp --list-prompts +``` + +**What ships:** + +| Surface | Coverage | +|---|---| +| Tools (~90) | mouse · keyboard · drag · screen / multi-monitor · screenshot-as-image · diff · OCR · image · windows (move/min/max/restore/...) · clipboard text+image · process / shell · recording · screen recording · scheduler / triggers / hotkeys · accessibility tree · VLM locator · executor · history | +| Aliases | `click`, `type`, `screenshot`, `find_image`, `drag`, `shell`, `wait_image`, ... — toggle with `JE_AUTOCONTROL_MCP_ALIASES=0` | +| Resources | `autocontrol://files/`, `autocontrol://history`, `autocontrol://commands`, `autocontrol://screen/live` (with `resources/subscribe`) | +| Prompts | `automate_ui_task`, `record_and_generalize`, `compare_screenshots`, `find_widget`, `explain_action_file` | +| Protocol | tools / resources / prompts / sampling / roots / logging / progress / cancellation / list_changed / elicitation | +| Transports | stdio, HTTP `POST /mcp`, SSE streaming when `Accept: text/event-stream` | +| Safety | tool annotations · `JE_AUTOCONTROL_MCP_READONLY` · `JE_AUTOCONTROL_MCP_CONFIRM_DESTRUCTIVE` · audit log · token-bucket rate limiter · auto-screenshot on error | +| Ops | bearer-token auth · TLS via `ssl_context` · `PluginWatcher` hot-reload · `JE_AUTOCONTROL_FAKE_BACKEND=1` for CI | + +See [docs/source/Eng/doc/mcp_server/mcp_server_doc.rst](docs/source/Eng/doc/mcp_server/mcp_server_doc.rst) +for the full reference (or the +[繁體中文](docs/source/Zh/doc/mcp_server/mcp_server_doc.rst) version). + +> ⚠️ The MCP server can move the mouse, send keystrokes, capture the +> screen, and execute arbitrary `AC_*` actions. Only register it with +> MCP clients you trust. HTTP defaults to `127.0.0.1`; binding to +> `0.0.0.0` requires explicit reason and **must** be paired with +> `auth_token` plus `ssl_context`. + ### Scheduler (Interval & Cron) ```python diff --git a/README/README_zh-CN.md b/README/README_zh-CN.md index 7ff22fc..5ce0fc1 100644 --- a/README/README_zh-CN.md +++ b/README/README_zh-CN.md @@ -27,6 +27,7 @@ - [截图](#截图) - [动作录制与回放](#动作录制与回放) - [JSON 脚本执行器](#json-脚本执行器) + - [MCP 服务器(让 Claude 使用 AutoControl)](#mcp-服务器让-claude-使用-autocontrol) - [调度器(Interval & Cron)](#调度器interval--cron) - [全局热键](#全局热键) - [事件触发器](#事件触发器) @@ -65,6 +66,7 @@ - **事件触发器** — 检测到图像出现、窗口出现、像素变化或文件变动时自动执行脚本 - **执行历史** — 使用 SQLite 记录 scheduler / triggers / hotkeys / REST 的执行结果;错误时自动附带截图 - **报告生成** — 将测试记录导出为 HTML、JSON 或 XML 报告,包含成功/失败状态 +- **MCP 服务器** — JSON-RPC 2.0 Model Context Protocol 服务(stdio + HTTP/SSE),让 Claude Desktop / Claude Code / 自定义 tool-use 循环直接驱动 AutoControl。约 90 个工具,完整协议支持(resources、prompts、sampling、roots、logging、progress、cancellation、elicitation),Bearer token 验证 + TLS、审计 log、rate limit、plugin 热加载、CI fake backend - **远程自动化** — 同时提供 TCP Socket 服务器与 REST API 服务器 - **插件加载器** — 将定义 `AC_*` 可调用对象的 `.py` 文件放入目录,运行时即可注册为 executor 命令 - **Shell 集成** — 在自动化流程中执行 Shell 命令,支持异步输出捕获 @@ -80,6 +82,96 @@ ## 架构 +运行时是分层的:**客户端接口**(CLI、GUI、MCP/REST/Socket 服务 +器)位于最上层,下面是**无头 API**(`wrapper/` + `utils/`),最后 +解析到 `wrapper/platform_wrapper.py` 在 import 时选定的**操作系统 +后端**。包 façade(`je_auto_control/__init__.py`)会 re-export 所 +有公开名称,使用者只需要 `import je_auto_control`,无论用哪个接口 +或后端都一样。 + +```mermaid +flowchart LR + subgraph Clients["客户端接口"] + direction TB + Claude[["Claude Desktop /
Claude Code"]] + APIUser[["自定义 Anthropic /
OpenAI tool-use 循环"]] + HTTPClient[["HTTP / SSE clients"]] + TCPClient[["Socket / REST clients"]] + GUIUser[["PySide6 GUI"]] + CLIUser[["python -m
je_auto_control[.cli]"]] + Library[["Library 使用者
(import je_auto_control)"]] + end + + subgraph Transports["传输与服务器"] + direction TB + Stdio["MCP stdio
JSON-RPC 2.0"] + HTTPMCP["MCP HTTP /
SSE + auth + TLS"] + REST["REST 服务器
:9939"] + Socket["Socket 服务器
:9938"] + end + + subgraph MCP["mcp_server/"] + direction TB + Dispatcher["MCPServer
(JSON-RPC dispatcher)"] + Tools["tools/
~90 ac_* + 别名"] + Resources["resources/
files · history ·
commands · screen-live"] + Prompts["prompts/
内置模板"] + Context["context · audit ·
rate-limit · log-bridge"] + FakeBE["fake_backend
(CI 烟雾测试)"] + end + + subgraph Core["无头核心 (wrapper/ + utils/)"] + direction TB + Wrapper["wrapper/
鼠标 · 键盘 · 屏幕 ·
图像 · 录制 · 窗口"] + Executor["executor/
AC_* JSON 动作引擎"] + Vision["vision/ · ocr/ ·
accessibility/"] + Recorder["scheduler/ · triggers/ ·
hotkey/ · plugin_loader/
run_history/"] + IOUtils["clipboard/ · cv2_utils/ ·
shell_process/ · json/"] + end + + subgraph Backends["操作系统后端"] + direction TB + Win["windows/
Win32 ctypes"] + Mac["osx/
pyobjc · Quartz"] + X11["linux_with_x11/
python-Xlib"] + end + + Claude --> Stdio + APIUser --> Stdio + HTTPClient --> HTTPMCP + TCPClient --> Socket + TCPClient --> REST + + Stdio --> Dispatcher + HTTPMCP --> Dispatcher + Dispatcher --> Tools + Dispatcher --> Resources + Dispatcher --> Prompts + Dispatcher -.- Context + Tools -.可选.-> FakeBE + + Tools --> Wrapper + Tools --> Executor + Tools --> Vision + Tools --> Recorder + Tools --> IOUtils + Resources --> Recorder + Resources --> Wrapper + + REST --> Executor + Socket --> Executor + + GUIUser --> Wrapper + GUIUser --> Recorder + CLIUser --> Executor + Library --> Wrapper + Library --> Executor + + Wrapper --> Backends + Vision -.- Wrapper + Recorder -.- Executor +``` + ``` je_auto_control/ ├── wrapper/ # 平台无关 API 层 @@ -88,19 +180,21 @@ je_auto_control/ │ ├── auto_control_keyboard.py# 键盘操作 │ ├── auto_control_image.py # 图像识别(OpenCV 模板匹配) │ ├── auto_control_screen.py # 截图、屏幕大小、像素颜色 +│ ├── auto_control_window.py # 跨平台窗口管理 facade │ └── auto_control_record.py # 动作录制/回放 ├── windows/ # Windows 专用后端(Win32 API / ctypes) ├── osx/ # macOS 专用后端(pyobjc / Quartz) ├── linux_with_x11/ # Linux 专用后端(python-Xlib) ├── gui/ # PySide6 GUI 应用程序 └── utils/ + ├── mcp_server/ # MCP 服务器(stdio + HTTP/SSE)— server / tools / resources / prompts / audit / rate_limit / fake_backend / plugin_watcher ├── executor/ # JSON 动作执行引擎 ├── callback/ # 回调函数执行器 ├── cv2_utils/ # OpenCV 截图、模板匹配、视频录制 ├── accessibility/ # UIA (Windows) / AX (macOS) 元件搜索 ├── vision/ # VLM 元件定位(Anthropic / OpenAI) ├── ocr/ # Tesseract 文字定位 - ├── clipboard/ # 跨平台剪贴板 + ├── clipboard/ # 跨平台剪贴板(文字 + 图像) ├── scheduler/ # Interval + cron 调度器 ├── hotkey/ # 全局热键守护进程 ├── triggers/ # 图像/窗口/像素/文件 触发器 @@ -403,6 +497,74 @@ je_auto_control.execute_action([ | 进程 | `AC_execute_process` | | 执行器 | `AC_execute_action`, `AC_execute_files` | +### MCP 服务器(让 Claude 使用 AutoControl) + +把 AutoControl 包装成 Model Context Protocol 服务,任何支持 MCP 的 +client(Claude Desktop、Claude Code、自定义 Anthropic / OpenAI tool-use +循环)都能驱动本机桌面。纯 stdlib — JSON-RPC 2.0 走 stdio 或 HTTP+ +SSE。 + +**注册到 Claude Code:** + +```bash +claude mcp add autocontrol -- python -m je_auto_control.utils.mcp_server +``` + +**注册到 Claude Desktop**(`claude_desktop_config.json`): + +```json +{ + "mcpServers": { + "autocontrol": { + "command": "python", + "args": ["-m", "je_auto_control.utils.mcp_server"] + } + } +} +``` + +**程序启动:** + +```python +import je_auto_control as ac + +# Stdio(会阻塞直到 stdin 关闭) +ac.start_mcp_stdio_server() + +# 或 HTTP / SSE,带 Bearer token 验证 + 可选 TLS +ac.start_mcp_http_server(host="127.0.0.1", port=9940, + auth_token="hunter2") +``` + +**不启动服务器、只看目录:** + +```bash +je_auto_control_mcp --list-tools +je_auto_control_mcp --list-tools --read-only +je_auto_control_mcp --list-resources +je_auto_control_mcp --list-prompts +``` + +**功能总览:** + +| 面向 | 涵盖 | +|---|---| +| 工具(约 90 个) | 鼠标 · 键盘 · drag · 屏幕 / 多屏 · 截图回 image · diff · OCR · 图像 · 窗口(move/min/max/restore/...) · 剪贴板文字+图像 · 进程 / shell · 动作录制 · 屏幕录像 · scheduler / triggers / hotkeys · accessibility tree · VLM · executor · history | +| 别名 | `click`、`type`、`screenshot`、`find_image`、`drag`、`shell`、`wait_image`...,以 `JE_AUTOCONTROL_MCP_ALIASES=0` 关闭 | +| Resources | `autocontrol://files/`、`autocontrol://history`、`autocontrol://commands`、`autocontrol://screen/live`(支持 `resources/subscribe`)| +| Prompts | `automate_ui_task`、`record_and_generalize`、`compare_screenshots`、`find_widget`、`explain_action_file` | +| 协议 | tools / resources / prompts / sampling / roots / logging / progress / cancellation / list_changed / elicitation | +| 传输 | stdio、HTTP `POST /mcp`、`Accept: text/event-stream` 时走 SSE 流 | +| 安全 | 工具注解 · `JE_AUTOCONTROL_MCP_READONLY` · `JE_AUTOCONTROL_MCP_CONFIRM_DESTRUCTIVE` · 审计 log · token-bucket rate limiter · 工具失败自动截图 | +| 部署 | Bearer token 验证 · 通过 `ssl_context` 启用 TLS · `PluginWatcher` 热加载 · `JE_AUTOCONTROL_FAKE_BACKEND=1` 给 CI | + +完整参考请见 [docs/source/Zh/doc/mcp_server/mcp_server_doc.rst](docs/source/Zh/doc/mcp_server/mcp_server_doc.rst) +(英文版本在 [docs/source/Eng/doc/mcp_server/mcp_server_doc.rst](docs/source/Eng/doc/mcp_server/mcp_server_doc.rst))。 + +> ⚠️ MCP 服务器可以移动鼠标、发送键盘事件、截图、执行任意 `AC_*` +> 动作。请只注册给可信任的 client。HTTP 默认绑 `127.0.0.1`,要对外 +> 必须有明确理由,**并且**搭配 `auth_token` 与 `ssl_context`。 + ### 调度器(Interval & Cron) ```python diff --git a/README/README_zh-TW.md b/README/README_zh-TW.md index ab1896c..dbd8f73 100644 --- a/README/README_zh-TW.md +++ b/README/README_zh-TW.md @@ -27,6 +27,7 @@ - [截圖](#截圖) - [動作錄製與回放](#動作錄製與回放) - [JSON 腳本執行器](#json-腳本執行器) + - [MCP 伺服器(讓 Claude 使用 AutoControl)](#mcp-伺服器讓-claude-使用-autocontrol) - [排程器(Interval & Cron)](#排程器interval--cron) - [全域熱鍵](#全域熱鍵) - [事件觸發器](#事件觸發器) @@ -65,6 +66,7 @@ - **事件觸發器** — 偵測到影像出現、視窗出現、像素變化或檔案變動時自動執行腳本 - **執行歷史** — 以 SQLite 紀錄 scheduler / triggers / hotkeys / REST 的執行結果;錯誤時自動附上截圖 - **報告產生** — 將測試紀錄匯出為 HTML、JSON 或 XML 報告,包含成功/失敗狀態 +- **MCP 伺服器** — JSON-RPC 2.0 Model Context Protocol 服務(stdio + HTTP/SSE),讓 Claude Desktop / Claude Code / 自訂 tool-use 迴圈直接驅動 AutoControl。約 90 個工具,完整協定支援(resources、prompts、sampling、roots、logging、progress、cancellation、elicitation),Bearer token 驗證 + TLS、稽核 log、rate limit、plugin 熱重載、CI fake backend - **遠端自動化** — 同時提供 TCP Socket 伺服器與 REST API 伺服器 - **外掛載入器** — 將定義 `AC_*` 可呼叫物的 `.py` 檔放入目錄,執行時即可註冊成 executor 指令 - **Shell 整合** — 在自動化流程中執行 Shell 命令,支援非同步輸出擷取 @@ -80,6 +82,96 @@ ## 架構 +執行階段是分層的:**客戶端介面**(CLI、GUI、MCP/REST/Socket 伺服 +器)位於最上層,底下是**無頭 API**(`wrapper/` + `utils/`),最後 +解析到 `wrapper/platform_wrapper.py` 在 import 時挑選的**作業系統 +後端**。套件 façade(`je_auto_control/__init__.py`)會 re-export 所 +有公開名稱,使用者只需要 `import je_auto_control`,不論用哪個介面或 +後端都一樣。 + +```mermaid +flowchart LR + subgraph Clients["客戶端介面"] + direction TB + Claude[["Claude Desktop /
Claude Code"]] + APIUser[["自訂 Anthropic /
OpenAI tool-use 迴圈"]] + HTTPClient[["HTTP / SSE clients"]] + TCPClient[["Socket / REST clients"]] + GUIUser[["PySide6 GUI"]] + CLIUser[["python -m
je_auto_control[.cli]"]] + Library[["Library 使用者
(import je_auto_control)"]] + end + + subgraph Transports["傳輸與伺服器"] + direction TB + Stdio["MCP stdio
JSON-RPC 2.0"] + HTTPMCP["MCP HTTP /
SSE + auth + TLS"] + REST["REST 伺服器
:9939"] + Socket["Socket 伺服器
:9938"] + end + + subgraph MCP["mcp_server/"] + direction TB + Dispatcher["MCPServer
(JSON-RPC dispatcher)"] + Tools["tools/
~90 ac_* + 別名"] + Resources["resources/
files · history ·
commands · screen-live"] + Prompts["prompts/
內建範本"] + Context["context · audit ·
rate-limit · log-bridge"] + FakeBE["fake_backend
(CI 煙霧測試)"] + end + + subgraph Core["無頭核心 (wrapper/ + utils/)"] + direction TB + Wrapper["wrapper/
滑鼠 · 鍵盤 · 螢幕 ·
影像 · 錄製 · 視窗"] + Executor["executor/
AC_* JSON 動作引擎"] + Vision["vision/ · ocr/ ·
accessibility/"] + Recorder["scheduler/ · triggers/ ·
hotkey/ · plugin_loader/
run_history/"] + IOUtils["clipboard/ · cv2_utils/ ·
shell_process/ · json/"] + end + + subgraph Backends["作業系統後端"] + direction TB + Win["windows/
Win32 ctypes"] + Mac["osx/
pyobjc · Quartz"] + X11["linux_with_x11/
python-Xlib"] + end + + Claude --> Stdio + APIUser --> Stdio + HTTPClient --> HTTPMCP + TCPClient --> Socket + TCPClient --> REST + + Stdio --> Dispatcher + HTTPMCP --> Dispatcher + Dispatcher --> Tools + Dispatcher --> Resources + Dispatcher --> Prompts + Dispatcher -.- Context + Tools -.選用.-> FakeBE + + Tools --> Wrapper + Tools --> Executor + Tools --> Vision + Tools --> Recorder + Tools --> IOUtils + Resources --> Recorder + Resources --> Wrapper + + REST --> Executor + Socket --> Executor + + GUIUser --> Wrapper + GUIUser --> Recorder + CLIUser --> Executor + Library --> Wrapper + Library --> Executor + + Wrapper --> Backends + Vision -.- Wrapper + Recorder -.- Executor +``` + ``` je_auto_control/ ├── wrapper/ # 平台無關 API 層 @@ -88,19 +180,21 @@ je_auto_control/ │ ├── auto_control_keyboard.py# 鍵盤操作 │ ├── auto_control_image.py # 圖像辨識(OpenCV 模板匹配) │ ├── auto_control_screen.py # 截圖、螢幕大小、像素顏色 +│ ├── auto_control_window.py # 跨平台視窗管理 facade │ └── auto_control_record.py # 動作錄製/回放 ├── windows/ # Windows 專用後端(Win32 API / ctypes) ├── osx/ # macOS 專用後端(pyobjc / Quartz) ├── linux_with_x11/ # Linux 專用後端(python-Xlib) ├── gui/ # PySide6 GUI 應用程式 └── utils/ + ├── mcp_server/ # MCP 伺服器(stdio + HTTP/SSE)— server / tools / resources / prompts / audit / rate_limit / fake_backend / plugin_watcher ├── executor/ # JSON 動作執行引擎 ├── callback/ # 回呼函式執行器 ├── cv2_utils/ # OpenCV 截圖、模板匹配、影片錄製 ├── accessibility/ # UIA (Windows) / AX (macOS) 元件搜尋 ├── vision/ # VLM 元件定位(Anthropic / OpenAI) ├── ocr/ # Tesseract 文字定位 - ├── clipboard/ # 跨平台剪貼簿 + ├── clipboard/ # 跨平台剪貼簿(文字 + 圖像) ├── scheduler/ # Interval + cron 排程器 ├── hotkey/ # 全域熱鍵守護程序 ├── triggers/ # 影像/視窗/像素/檔案 觸發器 @@ -403,6 +497,74 @@ je_auto_control.execute_action([ | 程序 | `AC_execute_process` | | 執行器 | `AC_execute_action`, `AC_execute_files` | +### MCP 伺服器(讓 Claude 使用 AutoControl) + +把 AutoControl 包裝成 Model Context Protocol 服務,任何支援 MCP 的 +client(Claude Desktop、Claude Code、自訂 Anthropic / OpenAI tool-use +迴圈)都能驅動本機桌面。純 stdlib — JSON-RPC 2.0 走 stdio 或 HTTP+ +SSE。 + +**註冊到 Claude Code:** + +```bash +claude mcp add autocontrol -- python -m je_auto_control.utils.mcp_server +``` + +**註冊到 Claude Desktop**(`claude_desktop_config.json`): + +```json +{ + "mcpServers": { + "autocontrol": { + "command": "python", + "args": ["-m", "je_auto_control.utils.mcp_server"] + } + } +} +``` + +**程式啟動:** + +```python +import je_auto_control as ac + +# Stdio(會阻塞直到 stdin 關閉) +ac.start_mcp_stdio_server() + +# 或 HTTP / SSE,含 Bearer token 驗證 + 可選 TLS +ac.start_mcp_http_server(host="127.0.0.1", port=9940, + auth_token="hunter2") +``` + +**不啟動伺服器、只看目錄:** + +```bash +je_auto_control_mcp --list-tools +je_auto_control_mcp --list-tools --read-only +je_auto_control_mcp --list-resources +je_auto_control_mcp --list-prompts +``` + +**功能總覽:** + +| 面向 | 涵蓋 | +|---|---| +| 工具(約 90 個) | 滑鼠 · 鍵盤 · drag · 螢幕 / 多螢幕 · 截圖回 image · diff · OCR · 影像 · 視窗(move/min/max/restore/...) · 剪貼簿文字+圖像 · 程序 / shell · 動作錄製 · 螢幕錄影 · scheduler / triggers / hotkeys · accessibility tree · VLM · executor · history | +| 別名 | `click`、`type`、`screenshot`、`find_image`、`drag`、`shell`、`wait_image`...,以 `JE_AUTOCONTROL_MCP_ALIASES=0` 關閉 | +| Resources | `autocontrol://files/`、`autocontrol://history`、`autocontrol://commands`、`autocontrol://screen/live`(支援 `resources/subscribe`)| +| Prompts | `automate_ui_task`、`record_and_generalize`、`compare_screenshots`、`find_widget`、`explain_action_file` | +| 協定 | tools / resources / prompts / sampling / roots / logging / progress / cancellation / list_changed / elicitation | +| 傳輸 | stdio、HTTP `POST /mcp`、`Accept: text/event-stream` 時走 SSE 串流 | +| 安全 | 工具註記 · `JE_AUTOCONTROL_MCP_READONLY` · `JE_AUTOCONTROL_MCP_CONFIRM_DESTRUCTIVE` · 稽核 log · token-bucket rate limiter · 工具失敗自動截圖 | +| 部署 | Bearer token 驗證 · 透過 `ssl_context` 啟用 TLS · `PluginWatcher` 熱重載 · `JE_AUTOCONTROL_FAKE_BACKEND=1` 給 CI | + +完整參考請見 [docs/source/Zh/doc/mcp_server/mcp_server_doc.rst](docs/source/Zh/doc/mcp_server/mcp_server_doc.rst) +(英文版本在 [docs/source/Eng/doc/mcp_server/mcp_server_doc.rst](docs/source/Eng/doc/mcp_server/mcp_server_doc.rst))。 + +> ⚠️ MCP 伺服器可以移動滑鼠、送鍵盤事件、截圖、執行任意 `AC_*` 動 +> 作。請只註冊給可信任的 client。HTTP 預設綁 `127.0.0.1`,要對外 +> 必須要有明確理由,**並且**搭配 `auth_token` 與 `ssl_context`。 + ### 排程器(Interval & Cron) ```python diff --git a/docs/source/Eng/doc/mcp_server/mcp_server_doc.rst b/docs/source/Eng/doc/mcp_server/mcp_server_doc.rst new file mode 100644 index 0000000..d2f5918 --- /dev/null +++ b/docs/source/Eng/doc/mcp_server/mcp_server_doc.rst @@ -0,0 +1,373 @@ +========================================== +MCP Server (Use AutoControl from Claude) +========================================== + +The MCP server exposes AutoControl as a Model Context Protocol +service so any MCP-compatible client (Claude Desktop, Claude Code, +custom Anthropic / OpenAI tool-use loops) can drive the host machine +through AutoControl. Implementation is stdlib-only — JSON-RPC 2.0 +over stdio or HTTP+SSE — no extra runtime dependencies. + +Roughly 90 tools are exposed, plus the full set of MCP protocol +capabilities: tools, resources, prompts, sampling, roots, logging, +progress, cancellation, list-changed notifications, and elicitation. + +Tool catalogue +============== + +The default registry pairs every canonical ``ac_*`` tool with a +short alias (``click``, ``type``, ``screenshot``, ...) so prompts +can stay terse. Use ``--list-tools`` (see *CLI inspection* below) to +dump the live catalogue as JSON. + +Mouse / keyboard + ``ac_click_mouse``, ``ac_set_mouse_position``, + ``ac_get_mouse_position``, ``ac_mouse_scroll``, + ``ac_drag``, ``ac_send_mouse_to_window``, + ``ac_type_text``, ``ac_press_key``, ``ac_hotkey``, + ``ac_send_key_to_window``. + +Screen / image / OCR + ``ac_screen_size``, ``ac_screenshot`` (returns base64 PNG image + content + optional file save, supports ``monitor_index`` for + multi-display setups), ``ac_list_monitors``, ``ac_get_pixel``, + ``ac_diff_screenshots``, ``ac_locate_image_center``, + ``ac_locate_and_click``, ``ac_locate_text``, ``ac_click_text``, + ``ac_wait_for_image``, ``ac_wait_for_pixel``. + +Window management (Windows) + ``ac_list_windows``, ``ac_focus_window``, ``ac_wait_for_window``, + ``ac_close_window``, ``ac_window_move``, ``ac_window_minimize``, + ``ac_window_maximize``, ``ac_window_restore``. + +Semantic locators + ``ac_a11y_list``, ``ac_a11y_find``, ``ac_a11y_click``, + ``ac_vlm_locate``, ``ac_vlm_click``. + +Clipboard / processes / shell + ``ac_get_clipboard``, ``ac_set_clipboard``, + ``ac_get_clipboard_image``, ``ac_set_clipboard_image``, + ``ac_launch_process``, ``ac_list_processes``, + ``ac_kill_process``, ``ac_shell``. + +Recording / replay + ``ac_record_start``, ``ac_record_stop``, + ``ac_read_action_file``, ``ac_write_action_file``, + ``ac_trim_actions``, ``ac_adjust_delays``, + ``ac_scale_coordinates``, + ``ac_screen_record_start``, ``ac_screen_record_stop``, + ``ac_screen_record_list``. + +Action executor / history + ``ac_execute_actions``, ``ac_execute_action_file``, + ``ac_list_action_commands``, ``ac_list_run_history``. + +Scheduler / triggers / hotkeys + ``ac_scheduler_add_job``, ``ac_scheduler_remove_job``, + ``ac_scheduler_list_jobs``, ``ac_scheduler_start``, + ``ac_scheduler_stop``, ``ac_trigger_add``, ``ac_trigger_remove``, + ``ac_trigger_list``, ``ac_trigger_start``, ``ac_trigger_stop``, + ``ac_hotkey_bind``, ``ac_hotkey_unbind``, ``ac_hotkey_list``, + ``ac_hotkey_daemon_start``, ``ac_hotkey_daemon_stop``. + +Every tool carries the MCP 2025-06-18 ``annotations`` block +(``readOnlyHint``, ``destructiveHint``, ``idempotentHint``, +``openWorldHint``) so well-behaved clients can auto-approve +read-only queries and require user confirmation before destructive +ones. + +Resources, prompts, sampling +============================ + +Resources + - ``autocontrol://files/`` — every JSON action file in the + workspace root (re-targets when the client publishes + ``roots/list``). + - ``autocontrol://history`` — recent run-history snapshot. + - ``autocontrol://commands`` — full ``AC_*`` executor catalogue. + - ``autocontrol://screen/live`` — base64 PNG screenshots, with + ``resources/subscribe`` push notifications when content changes. + +Prompts + Five built-in templates: ``automate_ui_task``, + ``record_and_generalize``, ``compare_screenshots``, + ``find_widget``, ``explain_action_file``. + +Sampling + Tools can call ``server.request_sampling(messages, ...)`` to ask + the connected client model a question — useful when an automation + step needs an LLM judgment (e.g. "is this dialog showing an + error?"). Bridges through the same writer that handles tool + responses. + +Logging notifications, progress, cancellation +============================================= + +- The project logger is forwarded to the client as + ``notifications/message`` while a stdio session is active. + Clients can retune the level with ``logging/setLevel``. +- Long-running tools that accept a ``ctx`` parameter receive a + :class:`ToolCallContext` and can call + ``ctx.progress(value, total, message)`` to push + ``notifications/progress`` (when the client supplied a + ``progressToken``) and ``ctx.check_cancelled()`` to abort + cooperatively when ``notifications/cancelled`` arrives. + +Starting the server (programmatic) +================================== + +.. code-block:: python + + import je_auto_control as ac + + # Blocks until stdin closes — typical entry point for an MCP client. + ac.start_mcp_stdio_server() + +You can also build a custom registry, swap in a fake backend, or +attach plugin hot-reload: + +.. code-block:: python + + import je_auto_control as ac + + tools = ac.build_default_tool_registry(read_only=False, aliases=True) + server = ac.MCPServer(tools=tools) + watcher = ac.PluginWatcher(server, "./plugins") + watcher.start() + server.serve_stdio() + +Starting the server (command line) +================================== + +After ``pip install -e .`` (or ``pip install je_auto_control``), the +console script ``je_auto_control_mcp`` is on ``$PATH``. You can also +run it as a module: + +.. code-block:: shell + + je_auto_control_mcp + # or + python -m je_auto_control.utils.mcp_server + +Both forms speak MCP over stdin/stdout — they are not meant to be +run interactively from a terminal. + +CLI inspection flags +==================== + +Without any flags the entry point starts the stdio dispatcher. +Supplying one of the following prints the catalogue as JSON and +exits — useful in CI smoke tests and prompt prep: + +.. code-block:: shell + + je_auto_control_mcp --list-tools + je_auto_control_mcp --list-tools --read-only + je_auto_control_mcp --list-resources + je_auto_control_mcp --list-prompts + je_auto_control_mcp --fake-backend # swap in the in-memory backend + +Registering with Claude Desktop +=============================== + +Edit ``claude_desktop_config.json`` and add an entry under +``mcpServers``: + +.. code-block:: json + + { + "mcpServers": { + "autocontrol": { + "command": "python", + "args": ["-m", "je_auto_control.utils.mcp_server"] + } + } + } + +Restart Claude Desktop. The AutoControl tools appear in the tool +picker and the model can call them automatically. + +Registering with Claude Code +============================ + +.. code-block:: shell + + claude mcp add autocontrol -- python -m je_auto_control.utils.mcp_server + +Or add to your project's ``.claude/mcp.json``: + +.. code-block:: json + + { + "mcpServers": { + "autocontrol": { + "command": "python", + "args": ["-m", "je_auto_control.utils.mcp_server"] + } + } + } + +HTTP transport (with SSE, auth, TLS) +==================================== + +When stdio is awkward (long-running GUI host, container, remote +box), start the same dispatcher behind HTTP: + +.. code-block:: python + + import je_auto_control as ac + import ssl + + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain("server.crt", "server.key") + + server = ac.start_mcp_http_server( + host="127.0.0.1", port=9940, + auth_token="hunter2", + ssl_context=ssl_context, + ) + +- ``POST /mcp`` accepts JSON-RPC bodies. Returns + ``application/json`` by default; if ``Accept`` includes + ``text/event-stream`` the response streams progress notifications + followed by the final result as SSE events. +- Missing / wrong ``Authorization: Bearer `` returns 401 / + 403 (constant-time compare via ``hmac.compare_digest``). +- ``ssl_context`` wraps the listening socket so the same transport + can serve HTTPS. +- The default bind is ``127.0.0.1`` per the project's + least-privilege policy — opt into ``0.0.0.0`` only with explicit + reason. + +Bearer token can also come from ``JE_AUTOCONTROL_MCP_TOKEN``. + +Read-only / safe mode +===================== + +Set ``JE_AUTOCONTROL_MCP_READONLY=1`` (or pass ``read_only=True`` to +:func:`build_default_tool_registry`) to drop every tool whose +``readOnlyHint`` is false. Only observers (positions, OCR queries, +clipboard reads, history, ...) survive: + +.. code-block:: json + + { + "mcpServers": { + "autocontrol_safe": { + "command": "python", + "args": ["-m", "je_auto_control.utils.mcp_server"], + "env": {"JE_AUTOCONTROL_MCP_READONLY": "1"} + } + } + } + +Confirmation prompts (elicitation) +================================== + +Set ``JE_AUTOCONTROL_MCP_CONFIRM_DESTRUCTIVE=1`` to gate every +destructive tool behind an MCP ``elicitation/create`` request. The +client surfaces a confirmation prompt; declining returns a clean +error to the model without running the action. Requires the client +to advertise the ``elicitation`` capability — older clients fall +through with a logged warning. + +Audit log +========= + +Set ``JE_AUTOCONTROL_MCP_AUDIT=/path/to/audit.jsonl`` to append one +JSONL record per ``tools/call``: timestamp, tool name, sanitised +arguments (``password`` / ``token`` / ``secret`` / ``api_key`` / +``authorization`` are redacted), status (``ok`` / ``error`` / +``cancelled``), duration, optional error text, and optional +auto-screenshot artifact path (see below). + +Auto-screenshot on tool error +============================= + +Set ``JE_AUTOCONTROL_MCP_ERROR_SHOTS=/path/to/dir`` to write a +``_.png`` screenshot every time a tool errors. The path +is included in both the audit record and the error message +returned to the model — fast forensic trail for flaky automations. + +Rate limiting +============= + +Pass a :class:`RateLimiter` to :class:`MCPServer` to guard against +runaway loops: + +.. code-block:: python + + import je_auto_control as ac + + server = ac.MCPServer(rate_limiter=ac.RateLimiter( + rate_per_sec=20.0, capacity=40, + )) + +Exceeding the limit returns a ``-32000`` ``Rate limit exceeded`` +JSON-RPC error. + +Plugin hot-reload +================= + +Drop ``*.py`` files exposing top-level ``AC_*`` callables into a +directory and let :class:`PluginWatcher` keep the registry in sync: + +.. code-block:: python + + import je_auto_control as ac + + server = ac.MCPServer() + watcher = ac.PluginWatcher(server, directory="./plugins", + poll_seconds=2.0) + watcher.start() + ac.start_mcp_stdio_server() + +Each register / unregister fires +``notifications/tools/list_changed`` so the client refreshes its +cached catalogue automatically. + +CI smoke tests with the fake backend +==================================== + +The fake backend swaps the wrapper layer with in-memory recorders +so headless CI runners can exercise every MCP tool without a +display server: + +.. code-block:: shell + + JE_AUTOCONTROL_FAKE_BACKEND=1 python -m je_auto_control.utils.mcp_server + +Programmatically: + +.. code-block:: python + + from je_auto_control.utils.mcp_server.fake_backend import ( + fake_state, install_fake_backend, reset_fake_state, + uninstall_fake_backend, + ) + + install_fake_backend() + try: + # Run tests / tools — actions accumulate in fake_state(). + ... + finally: + uninstall_fake_backend() + reset_fake_state() + +Security notes +============== + +- The MCP server can move the mouse, send keystrokes, screenshot + the screen, and execute arbitrary ``AC_*`` actions. Only register + it with MCP clients you trust. +- Local stdio is the default transport — no network exposure unless + you opt into HTTP. HTTP defaults to ``127.0.0.1``; binding to + ``0.0.0.0`` requires an explicit, documented reason and **must** + be paired with ``auth_token`` and (for non-localhost) ``ssl_context``. +- File paths supplied to ``ac_screenshot``, ``ac_screen_record_start``, + ``ac_execute_action_file``, ``ac_read_action_file``, + ``ac_write_action_file``, and the FileSystem resource provider are + normalised via ``os.path.realpath``; the resource provider blocks + path traversal at the boundary. +- Subprocess calls (``ac_launch_process`` / ``ac_shell``) accept + argv lists or ``shlex.split`` parses — never an OS shell. diff --git a/docs/source/Eng/eng_index.rst b/docs/source/Eng/eng_index.rst index ad0a638..d8b84b9 100644 --- a/docs/source/Eng/eng_index.rst +++ b/docs/source/Eng/eng_index.rst @@ -19,6 +19,7 @@ Comprehensive guides for all AutoControl features. doc/callback_function/callback_function_doc doc/scheduler/scheduler_doc doc/socket_driver/socket_driver_doc + doc/mcp_server/mcp_server_doc doc/critical_exit/critical_exit_doc doc/cli/cli_doc doc/create_project/create_project_doc diff --git a/docs/source/Zh/doc/mcp_server/mcp_server_doc.rst b/docs/source/Zh/doc/mcp_server/mcp_server_doc.rst new file mode 100644 index 0000000..4152a92 --- /dev/null +++ b/docs/source/Zh/doc/mcp_server/mcp_server_doc.rst @@ -0,0 +1,352 @@ +================================ +MCP 伺服器 (讓 Claude 使用 AutoControl) +================================ + +MCP 伺服器把 AutoControl 包裝成 Model Context Protocol 服務,讓任何 +支援 MCP 的客戶端(Claude Desktop、Claude Code、自製 Anthropic / +OpenAI tool-use 迴圈)都能透過 AutoControl 操控本機桌面。實作純標 +準函式庫:JSON-RPC 2.0 走 stdio 或 HTTP+SSE,不需要額外的執行階段 +依賴。 + +預設暴露約 90 個工具,並支援完整的 MCP 協定能力:tools、resources、 +prompts、sampling、roots、logging、progress、cancellation、 +list-changed 通知與 elicitation。 + +工具目錄 +======== + +預設註冊表會把每個正式 ``ac_*`` 工具同時註冊一個短別名(``click``、 +``type``、``screenshot``...),提示文字可以更精簡。要看實際目錄請 +用下方「CLI 檢視」段落的 ``--list-tools``。 + +滑鼠 / 鍵盤 + ``ac_click_mouse``、``ac_set_mouse_position``、 + ``ac_get_mouse_position``、``ac_mouse_scroll``、 + ``ac_drag``、``ac_send_mouse_to_window``、 + ``ac_type_text``、``ac_press_key``、``ac_hotkey``、 + ``ac_send_key_to_window``。 + +螢幕 / 影像 / OCR + ``ac_screen_size``、``ac_screenshot``(回傳 base64 PNG image + 內容,可選擇存檔,並支援 ``monitor_index`` 對多螢幕單獨擷取)、 + ``ac_list_monitors``、``ac_get_pixel``、``ac_diff_screenshots``、 + ``ac_locate_image_center``、``ac_locate_and_click``、 + ``ac_locate_text``、``ac_click_text``、 + ``ac_wait_for_image``、``ac_wait_for_pixel``。 + +視窗管理 (Windows) + ``ac_list_windows``、``ac_focus_window``、``ac_wait_for_window``、 + ``ac_close_window``、``ac_window_move``、``ac_window_minimize``、 + ``ac_window_maximize``、``ac_window_restore``。 + +語意定位 + ``ac_a11y_list``、``ac_a11y_find``、``ac_a11y_click``、 + ``ac_vlm_locate``、``ac_vlm_click``。 + +剪貼簿 / 程序 / Shell + ``ac_get_clipboard``、``ac_set_clipboard``、 + ``ac_get_clipboard_image``、``ac_set_clipboard_image``、 + ``ac_launch_process``、``ac_list_processes``、 + ``ac_kill_process``、``ac_shell``。 + +錄製 / 重播 + ``ac_record_start``、``ac_record_stop``、 + ``ac_read_action_file``、``ac_write_action_file``、 + ``ac_trim_actions``、``ac_adjust_delays``、 + ``ac_scale_coordinates``、 + ``ac_screen_record_start``、``ac_screen_record_stop``、 + ``ac_screen_record_list``。 + +動作執行器 / 歷程 + ``ac_execute_actions``、``ac_execute_action_file``、 + ``ac_list_action_commands``、``ac_list_run_history``。 + +排程 / 觸發 / 熱鍵 + ``ac_scheduler_add_job``、``ac_scheduler_remove_job``、 + ``ac_scheduler_list_jobs``、``ac_scheduler_start``、 + ``ac_scheduler_stop``、``ac_trigger_add``、``ac_trigger_remove``、 + ``ac_trigger_list``、``ac_trigger_start``、``ac_trigger_stop``、 + ``ac_hotkey_bind``、``ac_hotkey_unbind``、``ac_hotkey_list``、 + ``ac_hotkey_daemon_start``、``ac_hotkey_daemon_stop``。 + +每個工具都會帶上 MCP 2025-06-18 規範的 ``annotations`` +(``readOnlyHint``、``destructiveHint``、``idempotentHint``、 +``openWorldHint``),client 可以據此自動允許唯讀查詢,並在執行破壞 +性動作前要求使用者確認。 + +Resources、Prompts、Sampling +============================ + +Resources + - ``autocontrol://files/`` — workspace 根目錄底下的所有 + JSON action 檔(client 推送 ``roots/list`` 後會自動切換根目錄)。 + - ``autocontrol://history`` — 最近的執行歷程快照。 + - ``autocontrol://commands`` — 完整 ``AC_*`` 執行器目錄。 + - ``autocontrol://screen/live`` — base64 PNG 直播, + ``resources/subscribe`` 後當畫面有變化會推送通知。 + +Prompts + 五個內建範本:``automate_ui_task``、``record_and_generalize``、 + ``compare_screenshots``、``find_widget``、``explain_action_file``。 + +Sampling + 工具可呼叫 ``server.request_sampling(messages, ...)`` 反問 client + 端的模型,適合用在「這個對話框是否在顯示錯誤?」這種需要 LLM 判 + 斷的步驟。走的是和工具回應同一條 writer。 + +Logging 通知 / Progress / Cancellation +====================================== + +- stdio session 期間,專案 logger 會以 ``notifications/message`` + 的形式即時推給 client。Client 可用 ``logging/setLevel`` 動態調整 + 等級。 +- 接受 ``ctx`` 參數的長時間工具會收到 + :class:`ToolCallContext`:呼叫 + ``ctx.progress(value, total, message)`` 推送 + ``notifications/progress``(client 須提供 ``progressToken``);呼 + 叫 ``ctx.check_cancelled()`` 在收到 ``notifications/cancelled`` + 時合作式中止。 + +以程式啟動伺服器 +================ + +.. code-block:: python + + import je_auto_control as ac + + # 阻塞直到 stdin 關閉 — 通常作為 MCP client 的進入點。 + ac.start_mcp_stdio_server() + +也可以自訂 registry、切換 fake backend、或啟動 plugin hot-reload: + +.. code-block:: python + + import je_auto_control as ac + + tools = ac.build_default_tool_registry(read_only=False, aliases=True) + server = ac.MCPServer(tools=tools) + watcher = ac.PluginWatcher(server, "./plugins") + watcher.start() + server.serve_stdio() + +以命令列啟動伺服器 +================== + +執行 ``pip install -e .``(或 ``pip install je_auto_control``)後, +``je_auto_control_mcp`` 命令會在 ``$PATH``。也能用模組形式啟動: + +.. code-block:: shell + + je_auto_control_mcp + # 或 + python -m je_auto_control.utils.mcp_server + +兩種啟動方式都透過 stdin/stdout 與 MCP client 通訊,不適合直接在終 +端機互動執行。 + +CLI 檢視旗標 +============ + +不加旗標就啟動 stdio dispatcher。下列旗標會把目錄 dump 成 JSON 後 +退出,適合 CI 煙霧測試或事先準備提示: + +.. code-block:: shell + + je_auto_control_mcp --list-tools + je_auto_control_mcp --list-tools --read-only + je_auto_control_mcp --list-resources + je_auto_control_mcp --list-prompts + je_auto_control_mcp --fake-backend # 切換成記憶體版 backend + +註冊到 Claude Desktop +===================== + +編輯 ``claude_desktop_config.json``,在 ``mcpServers`` 加入: + +.. code-block:: json + + { + "mcpServers": { + "autocontrol": { + "command": "python", + "args": ["-m", "je_auto_control.utils.mcp_server"] + } + } + } + +重啟 Claude Desktop。AutoControl 工具就會出現在工具列表,模型可自 +動呼叫。 + +註冊到 Claude Code +================== + +.. code-block:: shell + + claude mcp add autocontrol -- python -m je_auto_control.utils.mcp_server + +或寫進專案的 ``.claude/mcp.json``: + +.. code-block:: json + + { + "mcpServers": { + "autocontrol": { + "command": "python", + "args": ["-m", "je_auto_control.utils.mcp_server"] + } + } + } + +HTTP 傳輸(含 SSE / Auth / TLS) +================================== + +當 stdio 不方便(長時間 GUI 主機、容器、遠端機器)時,改用 HTTP 啟 +動相同的 dispatcher: + +.. code-block:: python + + import je_auto_control as ac + import ssl + + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain("server.crt", "server.key") + + server = ac.start_mcp_http_server( + host="127.0.0.1", port=9940, + auth_token="hunter2", + ssl_context=ssl_context, + ) + +- ``POST /mcp`` 接受 JSON-RPC 主體。預設回 ``application/json``; + 如果 ``Accept`` 包含 ``text/event-stream``,會以 SSE 串流推送進 + 度通知,然後送出最終結果。 +- 缺少或錯誤的 ``Authorization: Bearer `` 會回 401 / 403 + (透過 ``hmac.compare_digest`` 做常數時間比對)。 +- ``ssl_context`` 會包住 socket,讓同一條傳輸支援 HTTPS。 +- 預設綁定 ``127.0.0.1``;若要對外,務必同時設定 ``auth_token`` + 與(非 localhost 場景)``ssl_context``。 + +Bearer token 也可從 ``JE_AUTOCONTROL_MCP_TOKEN`` 環境變數讀取。 + +唯讀 / 安全模式 +=============== + +設定 ``JE_AUTOCONTROL_MCP_READONLY=1``(或呼叫 +:func:`build_default_tool_registry` 時傳 ``read_only=True``)只暴 +露 ``readOnlyHint`` 為 true 的工具(座標、OCR 查詢、剪貼簿讀取、歷 +程等): + +.. code-block:: json + + { + "mcpServers": { + "autocontrol_safe": { + "command": "python", + "args": ["-m", "je_auto_control.utils.mcp_server"], + "env": {"JE_AUTOCONTROL_MCP_READONLY": "1"} + } + } + } + +破壞性動作確認(Elicitation) +============================= + +設定 ``JE_AUTOCONTROL_MCP_CONFIRM_DESTRUCTIVE=1`` 後,所有 destructive +工具在執行前會送出 MCP ``elicitation/create``。Client 顯示確認對話框, +使用者拒絕時模型會收到乾淨的錯誤,不會執行動作。需要 client 自己 +聲明 ``elicitation`` 能力;舊 client 會留下 warning log 後繼續執行。 + +稽核 Log +======== + +設定 ``JE_AUTOCONTROL_MCP_AUDIT=/path/to/audit.jsonl``,每次 +``tools/call`` 都會寫一筆 JSONL:時間戳、工具名稱、過濾過的參數 +(``password`` / ``token`` / ``secret`` / ``api_key`` / +``authorization`` 會被替換成 ````)、狀態(``ok`` / +``error`` / ``cancelled``)、執行時間、錯誤訊息與 +auto-screenshot 路徑(見下)。 + +工具失敗自動截圖 +================ + +設定 ``JE_AUTOCONTROL_MCP_ERROR_SHOTS=/path/to/dir``,每次工具失敗 +就會寫一張 ``_.png`` 到該資料夾;路徑會同時帶在 audit log +與回傳給模型的錯誤訊息中,排查不穩定流程很快。 + +Rate Limiting +============= + +把 :class:`RateLimiter` 傳給 :class:`MCPServer` 防止失控的迴圈灌 +爆主機: + +.. code-block:: python + + import je_auto_control as ac + + server = ac.MCPServer(rate_limiter=ac.RateLimiter( + rate_per_sec=20.0, capacity=40, + )) + +超過上限就回 ``-32000`` ``Rate limit exceeded`` JSON-RPC 錯誤。 + +Plugin Hot-Reload +================= + +把暴露頂層 ``AC_*`` callable 的 ``*.py`` 丟進資料夾,讓 +:class:`PluginWatcher` 自動同步 registry: + +.. code-block:: python + + import je_auto_control as ac + + server = ac.MCPServer() + watcher = ac.PluginWatcher(server, directory="./plugins", + poll_seconds=2.0) + watcher.start() + ac.start_mcp_stdio_server() + +每次 register / unregister 都會送出 +``notifications/tools/list_changed``,client 會自動更新工具目錄。 + +CI 煙霧測試 (Fake Backend) +========================= + +Fake backend 把 wrapper 層換成記憶體版的紀錄器,讓沒有顯示伺服器的 +CI runner 也能走完所有 MCP 工具: + +.. code-block:: shell + + JE_AUTOCONTROL_FAKE_BACKEND=1 python -m je_auto_control.utils.mcp_server + +程式內使用: + +.. code-block:: python + + from je_auto_control.utils.mcp_server.fake_backend import ( + fake_state, install_fake_backend, reset_fake_state, + uninstall_fake_backend, + ) + + install_fake_backend() + try: + # 跑測試 / 工具 — 動作會累積在 fake_state()。 + ... + finally: + uninstall_fake_backend() + reset_fake_state() + +安全注意事項 +============ + +- MCP 伺服器可以移動滑鼠、送鍵盤事件、截圖、執行任意 ``AC_*`` 動 + 作。請只註冊給可信任的 MCP client。 +- 預設只走 stdio,沒有任何網路曝險;若用 HTTP,預設綁定 + ``127.0.0.1``,若要 ``0.0.0.0`` 必須要有明確理由,**且必須**搭配 + ``auth_token`` 與(非 localhost 時)``ssl_context``。 +- ``ac_screenshot``、``ac_screen_record_start``、 + ``ac_execute_action_file``、``ac_read_action_file``、 + ``ac_write_action_file`` 收到的路徑都會經過 ``os.path.realpath`` + 正規化;FileSystem resource provider 也會在邊界擋住 path + traversal。 +- 子程序呼叫(``ac_launch_process`` / ``ac_shell``)只接受 argv list + 或 ``shlex.split`` 的解析結果,從不啟用 OS shell。 diff --git a/docs/source/Zh/zh_index.rst b/docs/source/Zh/zh_index.rst index c355a79..7ac1580 100644 --- a/docs/source/Zh/zh_index.rst +++ b/docs/source/Zh/zh_index.rst @@ -19,6 +19,7 @@ AutoControl 所有功能的完整使用指南。 doc/callback_function/callback_function_doc doc/scheduler/scheduler_doc doc/socket_driver/socket_driver_doc + doc/mcp_server/mcp_server_doc doc/critical_exit/critical_exit_doc doc/cli/cli_doc doc/create_project/create_project_doc diff --git a/je_auto_control/__init__.py b/je_auto_control/__init__.py index 9bd2d39..9f45a02 100644 --- a/je_auto_control/__init__.py +++ b/je_auto_control/__init__.py @@ -62,6 +62,16 @@ TextMatch, click_text, find_text_matches, locate_text_center, set_tesseract_cmd, wait_for_text, ) +# MCP server (headless stdio bridge for Claude / other MCP clients) +from je_auto_control.utils.mcp_server import ( + AuditLogger, HttpMCPServer, MCPContent, MCPPrompt, MCPPromptArgument, + MCPResource, MCPServer, MCPTool, MCPToolAnnotations, + OperationCancelledError, PromptProvider, RateLimiter, + ResourceProvider, ToolCallContext, build_default_tool_registry, + default_prompt_provider, default_resource_provider, + make_plugin_tool, register_plugin_tools, start_mcp_http_server, + start_mcp_stdio_server, +) # Plugin loader (headless) from je_auto_control.utils.plugin_loader.plugin_loader import ( discover_plugin_commands, load_plugin_directory, load_plugin_file, @@ -210,6 +220,15 @@ def start_autocontrol_gui(*args, **kwargs): "get_clipboard", "set_clipboard", # Hotkey daemon "HotkeyDaemon", "HotkeyBinding", "default_hotkey_daemon", + # MCP server + "AuditLogger", "HttpMCPServer", "MCPContent", "MCPPrompt", + "MCPPromptArgument", "MCPResource", "MCPServer", "MCPTool", + "MCPToolAnnotations", "OperationCancelledError", "PromptProvider", + "RateLimiter", "ResourceProvider", "ToolCallContext", + "build_default_tool_registry", + "default_prompt_provider", "default_resource_provider", + "make_plugin_tool", "register_plugin_tools", + "start_mcp_http_server", "start_mcp_stdio_server", # Plugin loader "load_plugin_file", "load_plugin_directory", "discover_plugin_commands", "register_plugin_commands", diff --git a/je_auto_control/utils/clipboard/clipboard_image.py b/je_auto_control/utils/clipboard/clipboard_image.py new file mode 100644 index 0000000..0789989 --- /dev/null +++ b/je_auto_control/utils/clipboard/clipboard_image.py @@ -0,0 +1,93 @@ +"""Headless image clipboard helpers. + +Reads work via Pillow's ``ImageGrab.grabclipboard`` on every supported +platform (Windows, macOS, Linux with xclip). Writes are +Windows-only today: macOS and Linux callers receive a clear +``NotImplementedError`` so the higher-level MCP tool can surface the +limitation in its result instead of crashing. +""" +import io +import os +import sys +from typing import Optional + + +def get_clipboard_image() -> Optional[bytes]: + """Return the current clipboard image as PNG bytes, or ``None`` if empty.""" + from PIL import ImageGrab + try: + image = ImageGrab.grabclipboard() + except (OSError, NotImplementedError): + return None + if image is None: + return None + if isinstance(image, list): + # On macOS / Linux the clipboard may carry file paths instead of an image. + return None + buffer = io.BytesIO() + image.save(buffer, format="PNG") + return buffer.getvalue() + + +def set_clipboard_image(image_path: str) -> None: + """Place ``image_path`` (any Pillow-readable file) onto the clipboard.""" + safe_path = os.path.realpath(os.fspath(image_path)) + if not os.path.isfile(safe_path): + raise FileNotFoundError(f"image not found: {safe_path}") + if sys.platform.startswith("win"): + _win_set_image(safe_path) + return + raise NotImplementedError( + f"set_clipboard_image is currently only implemented on Windows " + f"(got {sys.platform})" + ) + + +def _win_set_image(path: str) -> None: + """Win32 implementation: copy a Pillow-rendered DIB onto the clipboard.""" + import ctypes + from ctypes import wintypes + from PIL import Image + + image = Image.open(path).convert("RGB") + buffer = io.BytesIO() + image.save(buffer, format="BMP") + # Strip the 14-byte BITMAPFILEHEADER — clipboard wants raw DIB. + dib_payload = buffer.getvalue()[14:] + + user32 = ctypes.WinDLL("user32", use_last_error=True) + kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) + cf_dib = 8 + gmem_moveable = 0x0002 + + user32.OpenClipboard.argtypes = [wintypes.HWND] + user32.OpenClipboard.restype = wintypes.BOOL + user32.EmptyClipboard.restype = wintypes.BOOL + user32.SetClipboardData.argtypes = [wintypes.UINT, wintypes.HANDLE] + user32.SetClipboardData.restype = wintypes.HANDLE + user32.CloseClipboard.restype = wintypes.BOOL + kernel32.GlobalAlloc.argtypes = [wintypes.UINT, ctypes.c_size_t] + kernel32.GlobalAlloc.restype = wintypes.HGLOBAL + kernel32.GlobalLock.argtypes = [wintypes.HGLOBAL] + kernel32.GlobalLock.restype = ctypes.c_void_p + kernel32.GlobalUnlock.argtypes = [wintypes.HGLOBAL] + + handle = kernel32.GlobalAlloc(gmem_moveable, len(dib_payload)) + if not handle: + raise RuntimeError("GlobalAlloc failed for clipboard image") + pointer = kernel32.GlobalLock(handle) + if not pointer: + raise RuntimeError("GlobalLock failed for clipboard image") + ctypes.memmove(pointer, dib_payload, len(dib_payload)) + kernel32.GlobalUnlock(handle) + if not user32.OpenClipboard(None): + raise RuntimeError("OpenClipboard failed") + try: + user32.EmptyClipboard() + if not user32.SetClipboardData(cf_dib, handle): + raise RuntimeError("SetClipboardData failed for clipboard image") + finally: + user32.CloseClipboard() + + +__all__ = ["get_clipboard_image", "set_clipboard_image"] diff --git a/je_auto_control/utils/executor/action_executor.py b/je_auto_control/utils/executor/action_executor.py index a96bd76..a4c2631 100644 --- a/je_auto_control/utils/executor/action_executor.py +++ b/je_auto_control/utils/executor/action_executor.py @@ -34,6 +34,8 @@ from je_auto_control.utils.generate_report.generate_xml_report import generate_xml, generate_xml_report from je_auto_control.utils.json.json_file import read_action_json from je_auto_control.utils.logging.logging_instance import autocontrol_logger +from je_auto_control.utils.mcp_server.http_transport import start_mcp_http_server +from je_auto_control.utils.mcp_server.server import start_mcp_stdio_server from je_auto_control.utils.package_manager.package_manager_class import package_manager from je_auto_control.utils.project.create_project_structure import create_project_dir from je_auto_control.utils.shell_process.shell_exec import ShellManager @@ -208,6 +210,10 @@ def __init__(self): # VLM-based element locator "AC_vlm_locate": _vlm_locate_as_list, "AC_vlm_click": click_by_description, + + # MCP server (Model Context Protocol stdio bridge) + "AC_start_mcp_server": start_mcp_stdio_server, + "AC_start_mcp_http_server": start_mcp_http_server, } def known_commands(self) -> set: diff --git a/je_auto_control/utils/mcp_server/__init__.py b/je_auto_control/utils/mcp_server/__init__.py new file mode 100644 index 0000000..27776f3 --- /dev/null +++ b/je_auto_control/utils/mcp_server/__init__.py @@ -0,0 +1,50 @@ +"""Headless MCP (Model Context Protocol) server for AutoControl. + +Exposes the headless automation API as MCP tools so MCP-compatible +clients (Claude Desktop, Claude Code, Claude API tool-use loops, etc.) +can drive the host machine through AutoControl. The transport is +JSON-RPC 2.0 over stdio, implemented with stdlib only — no extra +dependencies are required. +""" +from je_auto_control.utils.mcp_server.server import ( + MCPServer, start_mcp_stdio_server, +) +from je_auto_control.utils.mcp_server.audit import AuditLogger +from je_auto_control.utils.mcp_server.context import ( + OperationCancelledError, ToolCallContext, +) +from je_auto_control.utils.mcp_server.fake_backend import ( + FakeState, fake_state, install_fake_backend, reset_fake_state, + uninstall_fake_backend, +) +from je_auto_control.utils.mcp_server.log_bridge import MCPLogBridge +from je_auto_control.utils.mcp_server.plugin_watcher import PluginWatcher +from je_auto_control.utils.mcp_server.rate_limit import RateLimiter +from je_auto_control.utils.mcp_server.http_transport import ( + HttpMCPServer, start_mcp_http_server, +) +from je_auto_control.utils.mcp_server.prompts import ( + MCPPrompt, MCPPromptArgument, PromptProvider, default_prompt_provider, +) +from je_auto_control.utils.mcp_server.resources import ( + LiveScreenProvider, MCPResource, ResourceProvider, + default_resource_provider, +) +from je_auto_control.utils.mcp_server.tools import ( + MCPContent, MCPTool, MCPToolAnnotations, build_default_tool_registry, + make_plugin_tool, register_plugin_tools, +) + +__all__ = [ + "AuditLogger", "FakeState", "HttpMCPServer", "LiveScreenProvider", + "MCPContent", "MCPLogBridge", "MCPPrompt", "MCPPromptArgument", + "MCPResource", "MCPServer", "MCPTool", "MCPToolAnnotations", + "OperationCancelledError", "PluginWatcher", "PromptProvider", + "RateLimiter", "ResourceProvider", "ToolCallContext", + "build_default_tool_registry", + "default_prompt_provider", "default_resource_provider", + "fake_state", "install_fake_backend", "make_plugin_tool", + "register_plugin_tools", "reset_fake_state", + "start_mcp_http_server", "start_mcp_stdio_server", + "uninstall_fake_backend", +] diff --git a/je_auto_control/utils/mcp_server/__main__.py b/je_auto_control/utils/mcp_server/__main__.py new file mode 100644 index 0000000..bf6dc13 --- /dev/null +++ b/je_auto_control/utils/mcp_server/__main__.py @@ -0,0 +1,87 @@ +"""``python -m je_auto_control.utils.mcp_server`` entry point. + +Without flags this starts the stdio MCP server. With one of the +``--list-*`` flags it prints the requested catalogue to stdout and +exits — useful for inspection in CI or manual debugging. +""" +import argparse +import json +import sys + +from je_auto_control.utils.mcp_server.fake_backend import ( + install_fake_backend, maybe_install_from_env, +) +from je_auto_control.utils.mcp_server.prompts import default_prompt_provider +from je_auto_control.utils.mcp_server.resources import ( + default_resource_provider, +) +from je_auto_control.utils.mcp_server.server import start_mcp_stdio_server +from je_auto_control.utils.mcp_server.tools import ( + build_default_tool_registry, +) + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="je_auto_control_mcp", + description="Run AutoControl's MCP server (stdio) or list its catalogue.", + ) + parser.add_argument( + "--list-tools", action="store_true", + help="Print all tool descriptors as JSON and exit.", + ) + parser.add_argument( + "--list-resources", action="store_true", + help="Print all resource descriptors as JSON and exit.", + ) + parser.add_argument( + "--list-prompts", action="store_true", + help="Print all prompt descriptors as JSON and exit.", + ) + parser.add_argument( + "--read-only", action="store_true", + help="Restrict tools to those marked readOnlyHint=true.", + ) + parser.add_argument( + "--fake-backend", action="store_true", + help=("Install the in-memory fake backend so tools record but " + "don't drive the real OS. Useful for CI smoke tests."), + ) + return parser + + +def main(argv: list = None) -> None: + """CLI entry point. Performs the requested action and returns ``None``.""" + parser = _build_parser() + args = parser.parse_args(argv) + if args.fake_backend: + install_fake_backend() + else: + maybe_install_from_env() + listing_modes = (args.list_tools, args.list_resources, args.list_prompts) + if any(listing_modes): + _print_listings(args) + return + start_mcp_stdio_server() + + +def _print_listings(args: argparse.Namespace) -> None: + if args.list_tools: + registry = build_default_tool_registry(read_only=args.read_only) + json.dump([tool.to_descriptor() for tool in registry], + sys.stdout, ensure_ascii=False, indent=2) + sys.stdout.write("\n") + if args.list_resources: + provider = default_resource_provider() + json.dump([resource.to_descriptor() for resource in provider.list()], + sys.stdout, ensure_ascii=False, indent=2) + sys.stdout.write("\n") + if args.list_prompts: + provider = default_prompt_provider() + json.dump([prompt.to_descriptor() for prompt in provider.list()], + sys.stdout, ensure_ascii=False, indent=2) + sys.stdout.write("\n") + + +if __name__ == "__main__": + main() diff --git a/je_auto_control/utils/mcp_server/audit.py b/je_auto_control/utils/mcp_server/audit.py new file mode 100644 index 0000000..8d2c958 --- /dev/null +++ b/je_auto_control/utils/mcp_server/audit.py @@ -0,0 +1,78 @@ +"""Audit log for MCP tool calls. + +Every ``tools/call`` produces one JSONL line with timestamp, tool +name, sanitised arguments, status (``ok`` / ``error``), and +duration. The default sink is ``$JE_AUTOCONTROL_MCP_AUDIT`` (or +``mcp_audit.jsonl`` next to the cwd) so deployments that need a +forensic trail get it without code changes. +""" +import json +import os +import threading +import time +from typing import Any, Dict, Optional + + +class AuditLogger: + """Thread-safe JSONL audit logger for MCP tool calls.""" + + def __init__(self, path: Optional[str] = None) -> None: + resolved = path + if resolved is None: + resolved = os.environ.get("JE_AUTOCONTROL_MCP_AUDIT") + self._path: Optional[str] = ( + os.path.realpath(os.fspath(resolved)) if resolved else None + ) + self._lock = threading.Lock() + + @property + def path(self) -> Optional[str]: + return self._path + + @property + def enabled(self) -> bool: + return self._path is not None + + def record(self, *, tool: str, arguments: Dict[str, Any], + status: str, duration_seconds: float, + error_text: Optional[str] = None, + artifact_path: Optional[str] = None) -> None: + """Append one audit entry. No-ops when no path is configured.""" + if self._path is None: + return + entry = { + "ts": time.time(), + "tool": tool, + "arguments": _sanitise(arguments), + "status": status, + "duration_seconds": float(duration_seconds), + } + if error_text is not None: + entry["error"] = error_text + if artifact_path is not None: + entry["artifact_path"] = artifact_path + line = json.dumps(entry, ensure_ascii=False, default=str) + with self._lock: + with open(self._path, "a", encoding="utf-8") as handle: + handle.write(line + "\n") + + +REDACTED_KEYS = frozenset({"password", "token", "secret", "api_key", + "authorization"}) +REDACTED_PLACEHOLDER = "" + + +def _sanitise(arguments: Dict[str, Any]) -> Dict[str, Any]: + """Replace obvious secret-like values with ``REDACTED_PLACEHOLDER``.""" + if not isinstance(arguments, dict): + return arguments + out: Dict[str, Any] = {} + for key, value in arguments.items(): + if key.lower() in REDACTED_KEYS: + out[key] = REDACTED_PLACEHOLDER + else: + out[key] = value + return out + + +__all__ = ["AuditLogger", "REDACTED_KEYS", "REDACTED_PLACEHOLDER"] diff --git a/je_auto_control/utils/mcp_server/context.py b/je_auto_control/utils/mcp_server/context.py new file mode 100644 index 0000000..3656973 --- /dev/null +++ b/je_auto_control/utils/mcp_server/context.py @@ -0,0 +1,71 @@ +"""Per-call context object passed to opt-in MCP tool handlers. + +A handler that declares a ``ctx`` parameter receives a +:class:`ToolCallContext`, which lets it report progress to the +client and observe cooperative cancellation requests. Handlers that +do not declare ``ctx`` are unaffected. +""" +import threading +from dataclasses import dataclass, field +from typing import Any, Callable, Optional + + +class OperationCancelledError(RuntimeError): + """Raised by :meth:`ToolCallContext.check_cancelled` when the client cancels.""" + + def __init__(self, request_id: Any) -> None: + super().__init__(f"tool call {request_id!r} was cancelled by the client") + self.request_id = request_id + + +@dataclass +class ToolCallContext: + """State threaded through one ``tools/call`` request. + + Handlers can call :meth:`progress` to push a + ``notifications/progress`` to the client (no-op when the client + did not provide a ``progressToken``), and check + :attr:`cancelled` (or call :meth:`check_cancelled`) at safe + points to abort cooperatively. + """ + + request_id: Any + progress_token: Any = None + notifier: Optional[Callable[[str, dict], None]] = field( + default=None, repr=False, + ) + cancelled_event: threading.Event = field(default_factory=threading.Event) + + @property + def cancelled(self) -> bool: + """``True`` once the client has sent ``notifications/cancelled``.""" + return self.cancelled_event.is_set() + + def check_cancelled(self) -> None: + """Raise :class:`OperationCancelledError` if the call was cancelled.""" + if self.cancelled: + raise OperationCancelledError(self.request_id) + + def progress(self, value: float, total: Optional[float] = None, + message: Optional[str] = None) -> None: + """Send ``notifications/progress`` to the client. + + :param value: monotonic progress value (0..total when ``total`` is set, + otherwise an arbitrary increasing scalar). + :param total: optional upper bound for percent-style displays. + :param message: optional human-readable status string. + """ + if self.progress_token is None or self.notifier is None: + return + params: dict = { + "progressToken": self.progress_token, + "progress": float(value), + } + if total is not None: + params["total"] = float(total) + if message is not None: + params["message"] = str(message) + self.notifier("notifications/progress", params) + + +__all__ = ["OperationCancelledError", "ToolCallContext"] diff --git a/je_auto_control/utils/mcp_server/fake_backend.py b/je_auto_control/utils/mcp_server/fake_backend.py new file mode 100644 index 0000000..1c482f9 --- /dev/null +++ b/je_auto_control/utils/mcp_server/fake_backend.py @@ -0,0 +1,184 @@ +"""In-memory fake backend for CI / headless tool tests. + +Drop-in replacement for the wrapper layer's mouse / keyboard / screen +calls that records every invocation rather than touching the real OS. +Activate via :func:`install_fake_backend` (or set +``JE_AUTOCONTROL_FAKE_BACKEND=1`` before starting the MCP server) so +test agents can drive the full tool registry on a CI runner without a +display server. +""" +import os +import threading +from dataclasses import dataclass, field +from typing import Any, Dict, List, Tuple + + +@dataclass +class FakeState: + """Records what the model would have done if the OS were real.""" + + cursor: Tuple[int, int] = (0, 0) + screen_size: Tuple[int, int] = (1920, 1080) + clipboard_text: str = "" + typed_text: List[str] = field(default_factory=list) + keys_pressed: List[Any] = field(default_factory=list) + mouse_actions: List[Tuple[str, Any, ...]] = field(default_factory=list) + + +def fake_state() -> FakeState: + """Return the process-wide fake state.""" + return _STATE + + +_STATE = FakeState() +_STATE_LOCK = threading.Lock() + + +def reset_fake_state() -> None: + """Reset every recorded interaction. Useful between tests.""" + global _STATE + with _STATE_LOCK: + _STATE = FakeState() + + +# === Patched callables ====================================================== + +def _fake_get_mouse_position() -> Tuple[int, int]: + return _STATE.cursor + + +def _fake_set_mouse_position(x: int, y: int) -> Tuple[int, int]: + with _STATE_LOCK: + _STATE.cursor = (int(x), int(y)) + _STATE.mouse_actions.append(("set_position", int(x), int(y))) + return _STATE.cursor + + +def _fake_click_mouse(mouse_keycode: Any, x: Any = None, + y: Any = None) -> Tuple[Any, int, int]: + cx, cy = _STATE.cursor if x is None or y is None else (int(x), int(y)) + with _STATE_LOCK: + _STATE.cursor = (cx, cy) + _STATE.mouse_actions.append(("click", mouse_keycode, cx, cy)) + return mouse_keycode, cx, cy + + +def _fake_press_mouse(mouse_keycode: Any, x: Any = None, + y: Any = None) -> Tuple[Any, int, int]: + cx, cy = _STATE.cursor if x is None or y is None else (int(x), int(y)) + with _STATE_LOCK: + _STATE.mouse_actions.append(("press", mouse_keycode, cx, cy)) + return mouse_keycode, cx, cy + + +def _fake_release_mouse(mouse_keycode: Any, x: Any = None, + y: Any = None) -> Tuple[Any, int, int]: + cx, cy = _STATE.cursor if x is None or y is None else (int(x), int(y)) + with _STATE_LOCK: + _STATE.mouse_actions.append(("release", mouse_keycode, cx, cy)) + return mouse_keycode, cx, cy + + +def _fake_mouse_scroll(scroll_value: int, x: Any = None, y: Any = None, + scroll_direction: str = "scroll_down" + ) -> Tuple[int, str]: + with _STATE_LOCK: + _STATE.mouse_actions.append( + ("scroll", int(scroll_value), scroll_direction), + ) + return int(scroll_value), scroll_direction + + +def _fake_screen_size() -> Tuple[int, int]: + return _STATE.screen_size + + +def _fake_write(text: str, *_args, **_kwargs) -> str: + with _STATE_LOCK: + _STATE.typed_text.append(text) + return text + + +def _fake_type_keyboard(keycode: Any, *_args, **_kwargs) -> str: + with _STATE_LOCK: + _STATE.keys_pressed.append(keycode) + return str(keycode) + + +def _fake_hotkey(keys: List[Any], *_args, **_kwargs) -> Tuple[str, str]: + joined = ",".join(str(k) for k in keys) + with _STATE_LOCK: + _STATE.keys_pressed.append(("hotkey", joined)) + return joined, joined + + +def _fake_get_clipboard() -> str: + return _STATE.clipboard_text + + +def _fake_set_clipboard(text: str) -> None: + with _STATE_LOCK: + _STATE.clipboard_text = str(text) + + +# === Install / uninstall ==================================================== + +_INSTALLED: Dict[str, Any] = {} + + +def install_fake_backend() -> None: + """Replace the headless API entry points with the fake recorders.""" + if _INSTALLED: + return + from je_auto_control.utils.clipboard import clipboard as clipboard_module + from je_auto_control.wrapper import auto_control_keyboard as kbd_module + from je_auto_control.wrapper import auto_control_mouse as mouse_module + from je_auto_control.wrapper import auto_control_screen as screen_module + targets: Dict[Any, Dict[str, Any]] = { + mouse_module: { + "get_mouse_position": _fake_get_mouse_position, + "set_mouse_position": _fake_set_mouse_position, + "click_mouse": _fake_click_mouse, + "press_mouse": _fake_press_mouse, + "release_mouse": _fake_release_mouse, + "mouse_scroll": _fake_mouse_scroll, + }, + screen_module: {"screen_size": _fake_screen_size}, + kbd_module: { + "write": _fake_write, + "type_keyboard": _fake_type_keyboard, + "hotkey": _fake_hotkey, + }, + clipboard_module: { + "get_clipboard": _fake_get_clipboard, + "set_clipboard": _fake_set_clipboard, + }, + } + for module, replacements in targets.items(): + for name, replacement in replacements.items(): + key = f"{module.__name__}.{name}" + _INSTALLED[key] = (module, name, getattr(module, name)) + setattr(module, name, replacement) + + +def uninstall_fake_backend() -> None: + """Restore the real backend functions previously replaced.""" + while _INSTALLED: + _key, value = _INSTALLED.popitem() + module, name, original = value + setattr(module, name, original) + + +def maybe_install_from_env() -> bool: + """Install the fake backend when ``JE_AUTOCONTROL_FAKE_BACKEND`` is truthy.""" + raw = os.environ.get("JE_AUTOCONTROL_FAKE_BACKEND", "").strip().lower() + if raw in {"1", "true", "yes", "on"}: + install_fake_backend() + return True + return False + + +__all__ = [ + "FakeState", "fake_state", "install_fake_backend", + "maybe_install_from_env", "reset_fake_state", "uninstall_fake_backend", +] diff --git a/je_auto_control/utils/mcp_server/http_transport.py b/je_auto_control/utils/mcp_server/http_transport.py new file mode 100644 index 0000000..05615c0 --- /dev/null +++ b/je_auto_control/utils/mcp_server/http_transport.py @@ -0,0 +1,250 @@ +"""HTTP transport for the MCP server. + +Implements a minimal Streamable HTTP transport (JSON-only, no SSE +streaming) so MCP clients that prefer HTTP — or that need to reach +the server from another process / container — can talk to the same +:class:`MCPServer` dispatcher already used by the stdio transport. + +Notifications are answered with ``202 Accepted`` per the MCP spec; +ordinary requests return their JSON-RPC response with +``Content-Type: application/json``. The default bind is +``127.0.0.1`` to honour the project's least-privilege policy. +""" +import hmac +import json +import os +import ssl +import threading +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from typing import Any, Optional, Tuple + +from je_auto_control.utils.logging.logging_instance import autocontrol_logger +from je_auto_control.utils.mcp_server.server import ( + MCPServer, _notification_message, +) + +DEFAULT_PATH = "/mcp" +_MAX_BODY = 1_000_000 +_SSE_MEDIA_TYPE = "text/event-stream" + + +class _MCPHttpHandler(BaseHTTPRequestHandler): + """Bridges HTTP requests onto :meth:`MCPServer.handle_line`.""" + + server_version = "AutoControlMCP/1.0" + + # Suppress default stderr access logs — route through project logger. + def log_message(self, format, *args) -> None: # noqa: A002 # pylint: disable=redefined-builtin # reason: stdlib override + autocontrol_logger.info("mcp-http %s - %s", + self.address_string(), format % args) + + def do_POST(self) -> None: # noqa: N802 # reason: stdlib API + if not self._authorize(): + return + if self.path != DEFAULT_PATH: + self._send_json({"error": "unknown path"}, status=404) + return + line = self._read_body() + if line is None: + return + bridge: MCPServer = self.server.mcp # type: ignore[attr-defined] + if self._client_accepts_sse(): + self._dispatch_sse(bridge, line) + return + response = bridge.handle_line(line) + if response is None: + # MCP notification — no body, ack with 202. + self._send_blank(status=202) + return + self._send_raw_json(response) + + def _authorize(self) -> bool: + """Validate Bearer token if the server has one configured.""" + expected: Optional[str] = self.server.auth_token # type: ignore[attr-defined] + if expected is None: + return True + header = self.headers.get("Authorization", "") + if not header.startswith("Bearer "): + self._send_json({"error": "missing bearer token"}, status=401) + return False + provided = header[len("Bearer "):].strip() + if not hmac.compare_digest(provided, expected): + self._send_json({"error": "invalid bearer token"}, status=403) + return False + return True + + def _client_accepts_sse(self) -> bool: + accept = self.headers.get("Accept", "") + return _SSE_MEDIA_TYPE in accept + + def _dispatch_sse(self, bridge: MCPServer, line: str) -> None: + """Stream progress notifications + the final response as SSE events.""" + # Force connection close so the client gets EOF after the last event. + self.close_connection = True + self.send_response(200) + self.send_header("Content-Type", + f"{_SSE_MEDIA_TYPE}; charset=utf-8") + self.send_header("Cache-Control", "no-cache") + self.send_header("Connection", "close") + self.end_headers() + send_lock = threading.Lock() + + def emit(payload: str) -> None: + with send_lock: + self.wfile.write(b"data: ") + self.wfile.write(payload.encode("utf-8")) + self.wfile.write(b"\n\n") + self.wfile.flush() + + sse_lock = self.server.sse_lock # type: ignore[attr-defined] + with sse_lock: + saved = (bridge._notifier, bridge._writer, + bridge._concurrent_tools) + bridge._notifier = lambda method, params: emit( + _notification_message(method, params), + ) + bridge._writer = emit + bridge._concurrent_tools = False + try: + response = bridge.handle_line(line) + if response is not None: + emit(response) + finally: + bridge._notifier, bridge._writer, bridge._concurrent_tools = saved + + def do_GET(self) -> None: # noqa: N802 # reason: stdlib API + if not self._authorize(): + return + # MCP optionally allows server→client SSE on GET; not used here. + self._send_json({"error": "GET stream not supported"}, status=405) + + def do_DELETE(self) -> None: # noqa: N802 # reason: stdlib API + if not self._authorize(): + return + # Sessionless server — accept the terminate so clients can cleanup. + self._send_json({"status": "session terminated"}) + + # --- helpers ------------------------------------------------------------- + + def _read_body(self) -> Optional[str]: + length = int(self.headers.get("Content-Length", "0") or "0") + if length <= 0 or length > _MAX_BODY: + self._send_json({"error": "invalid Content-Length"}, status=400) + return None + raw = self.rfile.read(length) + try: + return raw.decode("utf-8").strip() + except UnicodeDecodeError: + self._send_json({"error": "body must be UTF-8"}, status=400) + return None + + def _send_json(self, payload: Any, status: int = 200) -> None: + body = json.dumps(payload, ensure_ascii=False).encode("utf-8") + self._write_headers(status, body) + self.wfile.write(body) + + def _send_raw_json(self, raw_json: str) -> None: + body = raw_json.encode("utf-8") + self._write_headers(200, body) + self.wfile.write(body) + + def _send_blank(self, status: int) -> None: + self.send_response(status) + self.send_header("Content-Length", "0") + self.end_headers() + + def _write_headers(self, status: int, body: bytes) -> None: + self.send_response(status) + self.send_header("Content-Type", "application/json; charset=utf-8") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + + +class _MCPHttpServer(ThreadingHTTPServer): + """ThreadingHTTPServer extension that owns an :class:`MCPServer`.""" + + def __init__(self, server_address: Tuple[str, int], + mcp: MCPServer, + auth_token: Optional[str] = None) -> None: + super().__init__(server_address, _MCPHttpHandler) + self.mcp = mcp + self.auth_token = auth_token + # Serialise SSE requests — they swap server-wide notifier/writer + # state, so concurrent SSE streams would race. POST-without-SSE + # requests don't take this lock and remain fully concurrent. + self.sse_lock = threading.Lock() + + +class HttpMCPServer: + """Threaded HTTP transport for the MCP dispatcher.""" + + def __init__(self, mcp: Optional[MCPServer] = None, + host: str = "127.0.0.1", port: int = 9940, + auth_token: Optional[str] = None, + ssl_context: Optional[ssl.SSLContext] = None, + ) -> None: + self._mcp = mcp if mcp is not None else MCPServer() + self._address: Tuple[str, int] = (host, port) + self._auth_token = auth_token if auth_token is not None else ( + os.environ.get("JE_AUTOCONTROL_MCP_TOKEN") or None + ) + self._ssl_context = ssl_context + self._server: Optional[_MCPHttpServer] = None + self._thread: Optional[threading.Thread] = None + + @property + def address(self) -> Tuple[str, int]: + """Return the resolved (host, port) tuple after :meth:`start`.""" + return self._address + + @property + def mcp(self) -> MCPServer: + return self._mcp + + def start(self) -> None: + """Bind the socket and begin serving on a background thread.""" + if self._server is not None: + return + self._server = _MCPHttpServer( + self._address, self._mcp, auth_token=self._auth_token, + ) + if self._ssl_context is not None: + self._server.socket = self._ssl_context.wrap_socket( + self._server.socket, server_side=True, + ) + self._address = self._server.server_address[:2] + self._thread = threading.Thread( + target=self._server.serve_forever, daemon=True, + name="AutoControlMCPHttp", + ) + self._thread.start() + scheme = "https" if self._ssl_context is not None else "http" + autocontrol_logger.info("MCP %s listening on %s:%d", scheme, + *self._address) + + def stop(self, timeout: float = 2.0) -> None: + if self._server is None: + return + self._server.shutdown() + self._server.server_close() + if self._thread is not None: + self._thread.join(timeout=timeout) + self._server = None + self._thread = None + + +def start_mcp_http_server(host: str = "127.0.0.1", port: int = 9940, + mcp: Optional[MCPServer] = None, + auth_token: Optional[str] = None, + ssl_context: Optional[ssl.SSLContext] = None, + ) -> HttpMCPServer: + """Start and return an :class:`HttpMCPServer`; convenience wrapper.""" + server = HttpMCPServer( + mcp=mcp, host=host, port=port, + auth_token=auth_token, ssl_context=ssl_context, + ) + server.start() + return server + + +__all__ = ["HttpMCPServer", "start_mcp_http_server"] diff --git a/je_auto_control/utils/mcp_server/log_bridge.py b/je_auto_control/utils/mcp_server/log_bridge.py new file mode 100644 index 0000000..c268148 --- /dev/null +++ b/je_auto_control/utils/mcp_server/log_bridge.py @@ -0,0 +1,90 @@ +"""Bridge Python logging records onto MCP ``notifications/message``. + +Once attached to the project logger, every record at or above the +configured level is forwarded to the MCP client as a notification so +the client can mirror server-side activity in its UI. The handler is +no-op when the server's notifier is not yet connected — useful for +unit tests that don't actually start a transport. +""" +import logging +from typing import Any, Callable, Dict, Optional + +# MCP log levels (RFC 5424 syslog names) mapped from stdlib logging levels. +_LEVEL_NAME_FROM_LEVEL = { + logging.DEBUG: "debug", + logging.INFO: "info", + logging.WARNING: "warning", + logging.ERROR: "error", + logging.CRITICAL: "critical", +} + +_LEVEL_FROM_MCP_NAME = { + "debug": logging.DEBUG, + "info": logging.INFO, + "notice": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, + "alert": logging.CRITICAL, + "emergency": logging.CRITICAL, +} + + +def mcp_level_to_logging(name: str) -> Optional[int]: + """Return the :mod:`logging` level for an MCP log name, or ``None``.""" + return _LEVEL_FROM_MCP_NAME.get(str(name).strip().lower()) + + +def logging_level_to_mcp(level: int) -> str: + """Return the closest MCP level name for a stdlib logging level.""" + closest = max( + (lvl for lvl in _LEVEL_NAME_FROM_LEVEL if lvl <= int(level)), + default=logging.DEBUG, + ) + return _LEVEL_NAME_FROM_LEVEL[closest] + + +class MCPLogBridge(logging.Handler): + """Logging handler that forwards records as ``notifications/message``.""" + + def __init__(self, notifier: Optional[ + Callable[[str, Dict[str, Any]], None]] = None, + logger_name: str = "je_auto_control", + level: int = logging.INFO) -> None: + super().__init__(level=level) + self._notifier = notifier + self._logger_name = str(logger_name) + + def set_notifier(self, notifier: Optional[ + Callable[[str, Dict[str, Any]], None]]) -> None: + self._notifier = notifier + + def emit(self, record: logging.LogRecord) -> None: + notifier = self._notifier + if notifier is None: + return + try: + text = record.getMessage() + except (TypeError, ValueError): + text = str(record.msg) + params: Dict[str, Any] = { + "level": logging_level_to_mcp(record.levelno), + "logger": self._logger_name, + "data": { + "logger": record.name, + "message": text, + "module": record.module, + "func": record.funcName, + "line": record.lineno, + }, + } + try: + notifier("notifications/message", params) + except (OSError, RuntimeError, ValueError): + # The bridge must never crash the producer. + pass + + +__all__ = [ + "MCPLogBridge", "logging_level_to_mcp", "mcp_level_to_logging", +] diff --git a/je_auto_control/utils/mcp_server/plugin_watcher.py b/je_auto_control/utils/mcp_server/plugin_watcher.py new file mode 100644 index 0000000..1e0e282 --- /dev/null +++ b/je_auto_control/utils/mcp_server/plugin_watcher.py @@ -0,0 +1,131 @@ +"""Background watcher that hot-reloads plugin tools when files change. + +Polls a plugin directory at a configurable interval, comparing each +``*.py`` file's mtime to its previous reading. When a file changes +(created, modified, or removed) the watcher reloads it via the +plugin loader and registers / unregisters MCP tools so the model +sees the updated catalogue without a server restart. +""" +import os +import threading +from typing import Any, Dict, List, Optional, Set + +from je_auto_control.utils.logging.logging_instance import autocontrol_logger +from je_auto_control.utils.mcp_server.tools.plugin_tools import ( + make_plugin_tool, +) + + +class PluginWatcher: + """Polling watcher that keeps an MCPServer's registry in sync with disk.""" + + def __init__(self, server: Any, directory: str, + poll_seconds: float = 2.0) -> None: + self._server = server + self._directory = os.path.realpath(os.fspath(directory)) + self._poll_seconds = max(0.2, float(poll_seconds)) + self._stop = threading.Event() + self._thread: Optional[threading.Thread] = None + # path → (mtime, [tool_names]) + self._known: Dict[str, tuple] = {} + + @property + def directory(self) -> str: + return self._directory + + def start(self) -> None: + if self._thread is not None and self._thread.is_alive(): + return + if not os.path.isdir(self._directory): + raise NotADirectoryError( + f"plugin directory not found: {self._directory}" + ) + self._stop.clear() + self._thread = threading.Thread( + target=self._run, daemon=True, name="MCPPluginWatcher", + ) + self._thread.start() + + def stop(self, timeout: float = 2.0) -> None: + self._stop.set() + if self._thread is not None: + self._thread.join(timeout=timeout) + self._thread = None + + def poll_once(self) -> None: + """Run one scan-and-sync iteration. Public for tests.""" + seen: Set[str] = set() + for entry in sorted(os.listdir(self._directory)): + if not entry.endswith(".py") or entry.startswith("_"): + continue + full = os.path.join(self._directory, entry) + if not os.path.isfile(full): + continue + seen.add(full) + try: + mtime = os.path.getmtime(full) + except OSError: + continue + previous = self._known.get(full) + if previous is None or previous[0] != mtime: + self._reload_file(full, mtime) + for stale in set(self._known) - seen: + self._unregister_file(stale) + + # --- internals ---------------------------------------------------------- + + def _run(self) -> None: + autocontrol_logger.info( + "plugin watcher started: %s (every %ss)", + self._directory, self._poll_seconds, + ) + while not self._stop.is_set(): + try: + self.poll_once() + except OSError as error: + autocontrol_logger.warning( + "plugin watcher poll failed: %r", error, + ) + self._stop.wait(self._poll_seconds) + autocontrol_logger.info("plugin watcher stopped") + + def _reload_file(self, path: str, mtime: float) -> None: + from je_auto_control.utils.plugin_loader.plugin_loader import ( + load_plugin_file, + ) + previous = self._known.get(path) + if previous is not None: + for tool_name in previous[1]: + self._server.unregister_tool(tool_name) + try: + commands = load_plugin_file(path) + except (OSError, ImportError, SyntaxError) as error: + autocontrol_logger.warning( + "plugin %s reload failed: %r", path, error, + ) + self._known[path] = (mtime, []) + return + registered: List[str] = [] + for raw_name, handler in commands.items(): + tool = make_plugin_tool(raw_name, handler) + self._server.register_tool(tool) + registered.append(tool.name) + self._known[path] = (mtime, registered) + autocontrol_logger.info( + "plugin %s reloaded → %d tools", os.path.basename(path), + len(registered), + ) + + def _unregister_file(self, path: str) -> None: + previous = self._known.pop(path, None) + if previous is None: + return + for tool_name in previous[1]: + self._server.unregister_tool(tool_name) + autocontrol_logger.info( + "plugin %s removed → %d tools dropped", + os.path.basename(path), len(previous[1]), + ) + + +__all__ = ["PluginWatcher"] diff --git a/je_auto_control/utils/mcp_server/prompts.py b/je_auto_control/utils/mcp_server/prompts.py new file mode 100644 index 0000000..6b5c532 --- /dev/null +++ b/je_auto_control/utils/mcp_server/prompts.py @@ -0,0 +1,220 @@ +"""MCP prompt catalogue for AutoControl. + +Prompts are reusable task templates the MCP client can surface to the +user (typically as slash-command suggestions). The default catalogue +seeds a few common automation flows — recording-and-generalising, +visual-diff comparison, semantic widget targeting — so the model has +a quick path to common requests without re-deriving the recipe. +""" +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional + + +@dataclass(frozen=True) +class MCPPromptArgument: + """One argument descriptor on a prompt template.""" + + name: str + description: Optional[str] = None + required: bool = False + + def to_descriptor(self) -> Dict[str, Any]: + descriptor: Dict[str, Any] = {"name": self.name, + "required": self.required} + if self.description is not None: + descriptor["description"] = self.description + return descriptor + + +@dataclass(frozen=True) +class MCPPrompt: + """A single prompt template: name, args, and a render callback.""" + + name: str + description: str + arguments: List[MCPPromptArgument] = field(default_factory=list) + render: Optional[Callable[[Dict[str, Any]], str]] = None + + def to_descriptor(self) -> Dict[str, Any]: + return { + "name": self.name, + "description": self.description, + "arguments": [arg.to_descriptor() for arg in self.arguments], + } + + def get(self, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Return the MCP ``prompts/get`` response payload.""" + for arg in self.arguments: + if arg.required and arg.name not in arguments: + raise ValueError( + f"prompt {self.name!r} requires argument {arg.name!r}" + ) + text = (self.render(arguments) if self.render is not None + else self.description) + return { + "description": self.description, + "messages": [{ + "role": "user", + "content": {"type": "text", "text": text}, + }], + } + + +class PromptProvider: + """Pluggable prompt source. Subclasses override list / get.""" + + def list(self) -> List[MCPPrompt]: # pragma: no cover - abstract + raise NotImplementedError + + def get(self, name: str, + arguments: Dict[str, Any]) -> Optional[Dict[str, Any]]: + raise NotImplementedError + + +class StaticPromptProvider(PromptProvider): + """Wraps a fixed list of :class:`MCPPrompt` objects.""" + + def __init__(self, prompts: List[MCPPrompt]) -> None: + self._prompts: Dict[str, MCPPrompt] = {p.name: p for p in prompts} + + def list(self) -> List[MCPPrompt]: + return list(self._prompts.values()) + + def get(self, name: str, + arguments: Dict[str, Any]) -> Optional[Dict[str, Any]]: + prompt = self._prompts.get(name) + if prompt is None: + return None + return prompt.get(arguments) + + +# === Default catalogue ====================================================== + +def _automate_ui_task(args: Dict[str, Any]) -> str: + task = args.get("task", "") + return ( + "You are driving the host machine through AutoControl's MCP " + "tools.\n\n" + f"Goal: {task}\n\n" + "Plan and execute step-by-step. Prefer in this order:\n" + "1. ac_a11y_find / ac_a11y_click for known widgets\n" + "2. ac_locate_text / ac_click_text when text is visible on screen\n" + "3. ac_locate_image_center / ac_locate_and_click for icons\n" + "4. ac_vlm_locate / ac_vlm_click as a last-resort fallback\n\n" + "Take a screenshot (ac_screenshot) before destructive actions so " + "you can verify state. Ask the user to confirm before issuing " + "irreversible operations (closing a window, executing a script " + "file, etc.)." + ) + + +def _record_and_generalize(args: Dict[str, Any]) -> str: + name = args.get("script_name", "recording.json") + return ( + "Record a manual demonstration and generalise it into a reusable " + "script.\n\n" + "1. Call ac_record_start. Tell the user when recording is live.\n" + "2. Wait for them to finish, then call ac_record_stop.\n" + "3. Inspect the captured action list. Replace literal coordinates " + "with semantic targeting where possible (ac_a11y_find names, " + "ac_locate_text strings).\n" + "4. Use ac_adjust_delays / ac_scale_coordinates to make the " + "script resolution-independent if the user asks.\n" + f"5. Persist the result with ac_write_action_file file_path={name!r}.\n" + ) + + +def _compare_screenshots(args: Dict[str, Any]) -> str: + label = args.get("label", "before / after") + return ( + f"Compare two screenshots ({label}). Use ac_screenshot to grab " + "both frames so you can see them, describe each panel's layout, " + "and call out every change you can identify (text, controls, " + "highlighted state, error dialogs). Finish with a one-paragraph " + "summary of what changed." + ) + + +def _find_widget(args: Dict[str, Any]) -> str: + widget = args.get("description", "") + return ( + f"Locate {widget} on screen. Try the cheapest, most reliable " + "approach first:\n" + f"1. ac_a11y_find with name and/or role matching {widget}.\n" + "2. ac_locate_text if the widget has visible label text.\n" + "3. ac_locate_image_center against a saved template if you have one.\n" + "4. ac_vlm_locate as a last resort.\n" + "Report the screen coordinates and the strategy that worked, " + "or say which strategies failed if nothing matches." + ) + + +def _explain_action_file(args: Dict[str, Any]) -> str: + path = args.get("file_path", "") + return ( + f"Read the action JSON at {path!r} via ac_read_action_file, then " + "explain in plain language what running it would do. Group steps " + "into intent-level bullets ('open the start menu', 'type the " + "username') rather than translating each AC_* command literally." + ) + + +def default_prompt_catalogue() -> List[MCPPrompt]: + """Return the bundled prompt templates.""" + return [ + MCPPrompt( + name="automate_ui_task", + description="Plan and execute a desktop automation task end-to-end.", + arguments=[MCPPromptArgument( + "task", "Natural-language description of what to accomplish.", + required=True, + )], + render=_automate_ui_task, + ), + MCPPrompt( + name="record_and_generalize", + description="Capture a manual demo and turn it into a reusable script.", + arguments=[MCPPromptArgument( + "script_name", "Where to save the generalised script.", + )], + render=_record_and_generalize, + ), + MCPPrompt( + name="compare_screenshots", + description="Take two screenshots and explain the visual diff.", + arguments=[MCPPromptArgument( + "label", "Optional label for the comparison (e.g. 'before/after').", + )], + render=_compare_screenshots, + ), + MCPPrompt( + name="find_widget", + description="Locate a UI widget using the cheapest reliable strategy.", + arguments=[MCPPromptArgument( + "description", "Natural-language description of the widget.", + required=True, + )], + render=_find_widget, + ), + MCPPrompt( + name="explain_action_file", + description="Read an action JSON file and summarise it in plain language.", + arguments=[MCPPromptArgument( + "file_path", "Absolute or relative path to the action JSON.", + required=True, + )], + render=_explain_action_file, + ), + ] + + +def default_prompt_provider() -> PromptProvider: + """Return the bundled prompt provider used by the default MCP server.""" + return StaticPromptProvider(default_prompt_catalogue()) + + +__all__ = [ + "MCPPrompt", "MCPPromptArgument", "PromptProvider", + "StaticPromptProvider", "default_prompt_catalogue", + "default_prompt_provider", +] diff --git a/je_auto_control/utils/mcp_server/rate_limit.py b/je_auto_control/utils/mcp_server/rate_limit.py new file mode 100644 index 0000000..c029aff --- /dev/null +++ b/je_auto_control/utils/mcp_server/rate_limit.py @@ -0,0 +1,48 @@ +"""Token-bucket rate limiter for MCP tool calls. + +Default config is generous (60 calls / second sustained, burst 60), +intended only as a safety net against runaway loops. Deployments +that need a stricter ceiling pass a custom :class:`RateLimiter` to +:class:`MCPServer`. +""" +import threading +import time +from typing import Optional + + +class RateLimiter: + """Standard token-bucket: refill ``rate_per_sec`` tokens per second up to ``capacity``.""" + + def __init__(self, rate_per_sec: float = 60.0, + capacity: Optional[float] = None) -> None: + self._rate = max(0.0, float(rate_per_sec)) + self._capacity = float(capacity) if capacity is not None else self._rate + self._tokens = self._capacity + self._last = time.monotonic() + self._lock = threading.Lock() + + @property + def rate_per_sec(self) -> float: + return self._rate + + @property + def capacity(self) -> float: + return self._capacity + + def try_acquire(self) -> bool: + """Take one token if available; return ``True`` on success.""" + if self._rate <= 0: + return True + with self._lock: + now = time.monotonic() + elapsed = now - self._last + self._last = now + self._tokens = min(self._capacity, + self._tokens + elapsed * self._rate) + if self._tokens < 1.0: + return False + self._tokens -= 1.0 + return True + + +__all__ = ["RateLimiter"] diff --git a/je_auto_control/utils/mcp_server/resources.py b/je_auto_control/utils/mcp_server/resources.py new file mode 100644 index 0000000..4ef53b1 --- /dev/null +++ b/je_auto_control/utils/mcp_server/resources.py @@ -0,0 +1,299 @@ +"""MCP resource providers for AutoControl. + +Resources let an MCP client browse data the server has to offer +without invoking a tool — typical use cases here are listing the +JSON action library on disk, fetching the run-history snapshot, and +inspecting which executor commands the model can call. The provider +abstraction lets callers compose custom sources without touching the +JSON-RPC layer. +""" +import base64 +import io +import json +import os +import threading +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +_MIME_JSON = "application/json" + + +@dataclass(frozen=True) +class MCPResource: + """One resource entry surfaced to MCP clients via ``resources/list``.""" + + uri: str + name: str + description: Optional[str] = None + mime_type: Optional[str] = None + + def to_descriptor(self) -> Dict[str, Any]: + descriptor: Dict[str, Any] = {"uri": self.uri, "name": self.name} + if self.description is not None: + descriptor["description"] = self.description + if self.mime_type is not None: + descriptor["mimeType"] = self.mime_type + return descriptor + + +class ResourceProvider: + """Pluggable source. Subclasses override :meth:`list` and :meth:`read`.""" + + def list(self) -> List[MCPResource]: # pragma: no cover - abstract + raise NotImplementedError + + def read(self, uri: str) -> Optional[Dict[str, Any]]: # pragma: no cover - abstract + """Return one content block (``{uri, mimeType, text}``) or ``None``.""" + raise NotImplementedError + + def set_workspace_root(self, root: str) -> None: + """Hook for MCP roots. Default: no-op. FS-backed providers override.""" + del root + + def subscribe(self, uri: str, + on_update: Callable[[], None]) -> Optional[Any]: + """Optional hook: start emitting ``on_update`` calls until unsubscribed. + + Return a non-``None`` handle when this provider owns ``uri`` and + accepted the subscription. The default implementation returns + ``None`` (not subscribable). + """ + del uri, on_update + return None + + def unsubscribe(self, uri: str, handle: Any) -> None: + """Cancel a previous :meth:`subscribe` handle.""" + del uri, handle + + +class FileSystemProvider(ResourceProvider): + """Expose ``*.json`` action files in ``root`` under ``://files/``.""" + + def __init__(self, root: str = ".", + scheme: str = "autocontrol") -> None: + self.root = os.path.realpath(root) + self.scheme = scheme + + def set_workspace_root(self, root: str) -> None: + """Re-target the provider at a new directory (e.g. via MCP roots).""" + self.root = os.path.realpath(os.fspath(root)) + + def list(self) -> List[MCPResource]: + if not os.path.isdir(self.root): + return [] + out: List[MCPResource] = [] + for name in sorted(os.listdir(self.root)): + if not name.endswith(".json"): + continue + full = os.path.join(self.root, name) + if not os.path.isfile(full): + continue + out.append(MCPResource( + uri=f"{self.scheme}://files/{name}", + name=name, + description=f"action JSON file in {self.root}", + mime_type=_MIME_JSON, + )) + return out + + def read(self, uri: str) -> Optional[Dict[str, Any]]: + prefix = f"{self.scheme}://files/" + if not uri.startswith(prefix): + return None + rel = uri[len(prefix):] + if "/" in rel or rel.startswith(".") or not rel: + return None + path = os.path.realpath(os.path.join(self.root, rel)) + if not path.startswith(self.root + os.sep) and path != self.root: + return None + if not os.path.isfile(path): + return None + with open(path, encoding="utf-8") as handle: + text = handle.read() + return {"uri": uri, "mimeType": _MIME_JSON, "text": text} + + +class HistoryProvider(ResourceProvider): + """Expose recent run-history records under ``autocontrol://history``.""" + + URI = "autocontrol://history" + + def list(self) -> List[MCPResource]: + return [MCPResource( + uri=self.URI, name="run_history", + description="Recent script-run history records (last 100).", + mime_type=_MIME_JSON, + )] + + def read(self, uri: str) -> Optional[Dict[str, Any]]: + if uri != self.URI: + return None + from je_auto_control.utils.run_history.history_store import ( + default_history_store, + ) + rows = default_history_store.list_runs(limit=100) + data = [{ + "id": row.id, "source_type": row.source_type, + "source_id": row.source_id, "script_path": row.script_path, + "started_at": str(row.started_at), + "finished_at": str(row.finished_at), + "status": row.status, "error_text": row.error_text, + "duration_seconds": row.duration_seconds, + } for row in rows] + return { + "uri": uri, "mimeType": _MIME_JSON, + "text": json.dumps(data, ensure_ascii=False, indent=2), + } + + +class CommandsProvider(ResourceProvider): + """Expose the executor command catalogue under ``autocontrol://commands``.""" + + URI = "autocontrol://commands" + + def list(self) -> List[MCPResource]: + return [MCPResource( + uri=self.URI, name="executor_commands", + description="Every AC_* command name the executor recognises.", + mime_type=_MIME_JSON, + )] + + def read(self, uri: str) -> Optional[Dict[str, Any]]: + if uri != self.URI: + return None + from je_auto_control.utils.executor.action_executor import executor + names = sorted(executor.known_commands()) + return { + "uri": uri, "mimeType": _MIME_JSON, + "text": json.dumps(names, ensure_ascii=False, indent=2), + } + + +class ChainProvider(ResourceProvider): + """Composite that fans out to a tuple of child providers.""" + + def __init__(self, providers: List[ResourceProvider]) -> None: + self.providers = list(providers) + + def list(self) -> List[MCPResource]: + out: List[MCPResource] = [] + for provider in self.providers: + out.extend(provider.list()) + return out + + def read(self, uri: str) -> Optional[Dict[str, Any]]: + for provider in self.providers: + content = provider.read(uri) + if content is not None: + return content + return None + + def set_workspace_root(self, root: str) -> None: + """Forward the root to every child provider.""" + for provider in self.providers: + provider.set_workspace_root(root) + + def subscribe(self, uri: str, + on_update: Callable[[], None]) -> Optional[Any]: + for provider in self.providers: + handle = provider.subscribe(uri, on_update) + if handle is not None: + return (provider, handle) + return None + + def unsubscribe(self, uri: str, handle: Any) -> None: + if not isinstance(handle, tuple) or len(handle) != 2: + return + provider, child_handle = handle + provider.unsubscribe(uri, child_handle) + + +class LiveScreenProvider(ResourceProvider): + """Live screen feed at ``autocontrol://screen/live``. + + ``read`` always grabs a fresh PNG (base64-encoded). Subscribers + receive ``on_update`` calls every ``poll_seconds`` so they can + re-fetch the resource and surface live state to the model. + """ + + URI = "autocontrol://screen/live" + + def __init__(self, poll_seconds: float = 1.0) -> None: + self._poll_seconds = max(0.1, float(poll_seconds)) + self._lock = threading.Lock() + self._subscribers: Dict[int, Callable[[], None]] = {} + self._next_handle = 1 + self._thread: Optional[threading.Thread] = None + self._stop = threading.Event() + + def list(self) -> List[MCPResource]: + return [MCPResource( + uri=self.URI, name="screen_live", + description=("Current screen as base64 PNG. Subscribe to be " + "notified when it should be re-fetched."), + mime_type="image/png", + )] + + def read(self, uri: str) -> Optional[Dict[str, Any]]: + if uri != self.URI: + return None + from je_auto_control.utils.cv2_utils.screenshot import pil_screenshot + image = pil_screenshot() + buffer = io.BytesIO() + image.save(buffer, format="PNG") + encoded = base64.b64encode(buffer.getvalue()).decode("ascii") + return {"uri": uri, "mimeType": "image/png", "blob": encoded} + + def subscribe(self, uri: str, + on_update: Callable[[], None]) -> Optional[Any]: + if uri != self.URI: + return None + with self._lock: + handle = self._next_handle + self._next_handle += 1 + self._subscribers[handle] = on_update + if self._thread is None or not self._thread.is_alive(): + self._stop.clear() + self._thread = threading.Thread( + target=self._broadcast_loop, daemon=True, + name="MCPLiveScreen", + ) + self._thread.start() + return handle + + def unsubscribe(self, uri: str, handle: Any) -> None: + if uri != self.URI: + return + with self._lock: + self._subscribers.pop(int(handle), None) + if not self._subscribers: + self._stop.set() + self._thread = None + + def _broadcast_loop(self) -> None: + while not self._stop.is_set(): + with self._lock: + callbacks = list(self._subscribers.values()) + for callback in callbacks: + try: + callback() + except (OSError, RuntimeError, ValueError): + pass + self._stop.wait(self._poll_seconds) + + +def default_resource_provider(root: str = ".") -> ResourceProvider: + """Return the resource provider exposed by the default MCP server.""" + return ChainProvider([ + FileSystemProvider(root=root), + HistoryProvider(), + CommandsProvider(), + LiveScreenProvider(), + ]) + + +__all__ = [ + "ChainProvider", "CommandsProvider", "FileSystemProvider", + "HistoryProvider", "LiveScreenProvider", "MCPResource", + "ResourceProvider", "default_resource_provider", +] diff --git a/je_auto_control/utils/mcp_server/server.py b/je_auto_control/utils/mcp_server/server.py new file mode 100644 index 0000000..4567208 --- /dev/null +++ b/je_auto_control/utils/mcp_server/server.py @@ -0,0 +1,756 @@ +"""Minimal MCP server speaking JSON-RPC 2.0 over stdio. + +Implements the subset of the Model Context Protocol that Claude clients +(Claude Desktop, Claude Code, Claude API) use to discover and invoke +tools: ``initialize``, ``tools/list``, ``tools/call``, ``ping``, and +``notifications/initialized``. Each transport line is one JSON-RPC +message — no Content-Length framing — matching the MCP stdio spec. +""" +import itertools +import json +import os +import sys +import threading +import time +from typing import Any, Callable, Dict, List, Optional, TextIO + +from je_auto_control.utils.logging.logging_instance import autocontrol_logger +from je_auto_control.utils.mcp_server.audit import AuditLogger +from je_auto_control.utils.mcp_server.context import ( + OperationCancelledError, ToolCallContext, +) +from je_auto_control.utils.mcp_server.log_bridge import ( + MCPLogBridge, mcp_level_to_logging, +) +from je_auto_control.utils.mcp_server.rate_limit import RateLimiter +from je_auto_control.utils.mcp_server.prompts import ( + PromptProvider, default_prompt_provider, +) +from je_auto_control.utils.mcp_server.resources import ( + ResourceProvider, default_resource_provider, +) +from je_auto_control.utils.mcp_server.tools import ( + MCPContent, MCPTool, build_default_tool_registry, +) +from je_auto_control.utils.mcp_server.tools._validation import ( + validate_arguments, +) + +PROTOCOL_VERSION = "2025-06-18" +SERVER_NAME = "je_auto_control" +SERVER_VERSION = "0.1.0" +_TOOLS_CALL_METHOD = "tools/call" + + +class _MCPError(Exception): + """Raised inside the dispatcher to surface a JSON-RPC error response.""" + + def __init__(self, code: int, message: str) -> None: + super().__init__(message) + self.code = code + self.message = message + + +class MCPServer: + """JSON-RPC 2.0 MCP server with a configurable tool registry.""" + + def __init__(self, tools: Optional[List[MCPTool]] = None, + resource_provider: Optional[ResourceProvider] = None, + prompt_provider: Optional[PromptProvider] = None, + concurrent_tools: bool = False, + audit_logger: Optional[AuditLogger] = None, + rate_limiter: Optional[RateLimiter] = None, + log_bridge: Optional[MCPLogBridge] = None, + ) -> None: + registry = tools if tools is not None else build_default_tool_registry() + self._tools: Dict[str, MCPTool] = {tool.name: tool for tool in registry} + self._resources = (resource_provider if resource_provider is not None + else default_resource_provider()) + self._prompts = (prompt_provider if prompt_provider is not None + else default_prompt_provider()) + self._concurrent_tools = bool(concurrent_tools) + self._audit = (audit_logger if audit_logger is not None + else AuditLogger()) + self._rate_limiter = rate_limiter + self._log_bridge = log_bridge + self._stop = threading.Event() + self._initialized = False + self._notifier: Optional[Callable[[str, Dict[str, Any]], None]] = None + self._writer: Optional[Callable[[str], None]] = None + self._active_calls: Dict[Any, ToolCallContext] = {} + self._calls_lock = threading.Lock() + self._write_lock = threading.Lock() + self._sampling_id_counter = itertools.count(1) + self._outbound_id_counter = itertools.count(1) + self._pending_outbound: Dict[Any, Dict[str, Any]] = {} + self._outbound_lock = threading.Lock() + self._client_capabilities: Dict[str, Any] = {} + self._resource_subscriptions: Dict[str, Any] = {} + self._subscriptions_lock = threading.Lock() + + def register_tool(self, tool: MCPTool) -> None: + """Add or replace a tool in the live registry. + + Emits ``notifications/tools/list_changed`` to the connected + client so it knows to refresh its cached tool list. + """ + self._tools[tool.name] = tool + self._notify_tools_list_changed() + + def unregister_tool(self, name: str) -> bool: + """Remove a tool by name. Returns True if it existed.""" + if name not in self._tools: + return False + del self._tools[name] + self._notify_tools_list_changed() + return True + + def _notify_tools_list_changed(self) -> None: + notifier = self._notifier + if notifier is None: + return + try: + notifier("notifications/tools/list_changed", {}) + except (OSError, RuntimeError, ValueError): + autocontrol_logger.exception( + "MCP failed to send tools/list_changed", + ) + + def stop(self) -> None: + """Request the stdio loop to exit at its next iteration.""" + self._stop.set() + + def serve_stdio(self, stdin: Optional[TextIO] = None, + stdout: Optional[TextIO] = None) -> None: + """Run the message loop until EOF on stdin or :meth:`stop`.""" + in_stream = stdin if stdin is not None else sys.stdin + out_stream = stdout if stdout is not None else sys.stdout + autocontrol_logger.info( + "MCP server starting (stdio, %d tools)", len(self._tools), + ) + prior_notifier = self._notifier + prior_writer = self._writer + prior_concurrent = self._concurrent_tools + self._writer = lambda payload: self._write_message(out_stream, payload) + self._notifier = lambda method, params: self._writer( # type: ignore[misc] + _notification_message(method, params), + ) + # Stdio always opts into concurrent tool execution so sampling + # requests issued by tool handlers don't block the reader. + self._concurrent_tools = True + self._attach_log_bridge_if_configured() + try: + while not self._stop.is_set(): + line = in_stream.readline() + if line == "": + break + line = line.strip() + if not line: + continue + response = self.handle_line(line) + if response is not None: + self._write_message(out_stream, response) + finally: + self._detach_log_bridge_if_configured() + self._notifier = prior_notifier + self._writer = prior_writer + self._concurrent_tools = prior_concurrent + autocontrol_logger.info("MCP server stopped") + + def _write_message(self, out_stream: TextIO, payload: str) -> None: + """Serialize an outbound JSON-RPC line under a writer lock.""" + with self._write_lock: + out_stream.write(payload + "\n") + out_stream.flush() + + def set_notifier(self, + notifier: Optional[Callable[[str, Dict[str, Any]], None]] + ) -> None: + """Install a callback used to send outbound notifications. + + The HTTP transport sets this to push notifications onto an + SSE stream; the stdio loop installs its own writer. Tests may + register a list-collecting callback to inspect notifications. + """ + self._notifier = notifier + + def _attach_log_bridge_if_configured(self) -> None: + """Wire the log bridge into the project logger and notifier.""" + if self._log_bridge is None: + self._log_bridge = MCPLogBridge() + self._log_bridge.set_notifier(self._notifier) + if self._log_bridge not in autocontrol_logger.handlers: + autocontrol_logger.addHandler(self._log_bridge) + + def _detach_log_bridge_if_configured(self) -> None: + if self._log_bridge is None: + return + self._log_bridge.set_notifier(None) + try: + autocontrol_logger.removeHandler(self._log_bridge) + except ValueError: + pass + + def set_writer(self, writer: Optional[Callable[[str], None]]) -> None: + """Install a callback used to write any outbound JSON-RPC line. + + This is the lower-level companion to :meth:`set_notifier` — + used to deliver server-initiated requests (e.g. sampling) and + to emit asynchronously-produced tools/call responses when the + server is running in concurrent mode. + """ + self._writer = writer + + def handle_line(self, line: str) -> Optional[str]: + """Process one JSON-RPC line; return the response line or ``None``.""" + try: + message = json.loads(line) + except ValueError as error: + autocontrol_logger.warning("MCP parse error: %r", error) + return _error_response(None, -32700, "Parse error") + if not isinstance(message, dict): + return _error_response(None, -32600, "Invalid Request") + + method = message.get("method") + msg_id = message.get("id") + params = message.get("params") or {} + + if method is None and msg_id is not None and ( + "result" in message or "error" in message + ): + self._dispatch_outbound_response(msg_id, message) + return None + if msg_id is None: + self._handle_notification(method, params) + return None + if method == _TOOLS_CALL_METHOD and self._concurrent_tools: + self._dispatch_tools_call_async(msg_id, params) + return None + return self._build_response(msg_id, method, params) + + def _dispatch_outbound_response(self, msg_id: Any, + message: Dict[str, Any]) -> None: + """Route a JSON-RPC response to the matching pending request.""" + with self._outbound_lock: + slot = self._pending_outbound.get(msg_id) + if slot is None: + autocontrol_logger.debug( + "MCP outbound response for unknown id %r", msg_id, + ) + return + if "error" in message: + slot["error"] = message["error"] + else: + slot["result"] = message.get("result") + slot["event"].set() + + def _dispatch_tools_call_async(self, msg_id: Any, + params: Dict[str, Any]) -> None: + """Run a tools/call on a worker thread; the worker writes the reply.""" + def worker() -> None: + payload = self._build_response(msg_id, _TOOLS_CALL_METHOD, params) + writer = self._writer + if writer is None: + autocontrol_logger.warning( + "MCP async tool reply with no writer; dropping %s", msg_id, + ) + return + writer(payload) + threading.Thread( + target=worker, daemon=True, name=f"MCPCall-{msg_id}", + ).start() + + def _build_response(self, msg_id: Any, method: Optional[str], + params: Dict[str, Any]) -> str: + """Dispatch a request and serialise the result or error.""" + try: + result = self._dispatch(msg_id, method, params) + except _MCPError as error: + return _error_response(msg_id, error.code, error.message) + except OperationCancelledError as error: + autocontrol_logger.info("MCP call %s cancelled by client", msg_id) + return _error_response(msg_id, -32800, str(error)) + except (OSError, RuntimeError, ValueError, TypeError, KeyError) as error: + autocontrol_logger.exception("MCP dispatch failed") + return _error_response(msg_id, -32603, f"Internal error: {error}") + return _result_response(msg_id, result) + + def _handle_notification(self, method: Optional[str], + params: Dict[str, Any]) -> None: + """Notifications carry no id and never get a response.""" + if method == "notifications/initialized": + self._initialized = True + autocontrol_logger.info("MCP client initialized") + self._maybe_request_roots_async() + return + if method == "notifications/cancelled": + self._cancel_active_call(params) + return + if method == "notifications/roots/list_changed": + self._maybe_request_roots_async() + return + autocontrol_logger.debug("MCP notification ignored: %s", method) + + def _maybe_request_roots_async(self) -> None: + """Fire a roots/list request when the client supports it.""" + if "roots" not in self._client_capabilities: + return + if self._writer is None: + return + threading.Thread( + target=self._refresh_roots_safely, daemon=True, + name="MCPRootsRefresh", + ).start() + + def _refresh_roots_safely(self) -> None: + try: + self.refresh_roots(timeout=5.0) + except (RuntimeError, TimeoutError) as error: + autocontrol_logger.info("MCP roots refresh skipped: %r", error) + + def refresh_roots(self, timeout: float = 10.0) -> List[Dict[str, Any]]: + """Send ``roots/list`` to the client and apply the first root.""" + result = self._send_outbound_request( + "roots/list", params={}, timeout=timeout, + ) + roots_list = (result or {}).get("roots") or [] + if not isinstance(roots_list, list) or not roots_list: + return [] + first_uri = roots_list[0].get("uri") if isinstance(roots_list[0], + dict) else None + if isinstance(first_uri, str): + local_path = _file_uri_to_path(first_uri) + if local_path: + self._resources.set_workspace_root(local_path) + autocontrol_logger.info("MCP workspace root → %s", local_path) + return roots_list + + def _send_outbound_request(self, method: str, + params: Dict[str, Any], + timeout: float = 10.0) -> Dict[str, Any]: + """Send a server-initiated request and wait for the response.""" + writer = self._writer + if writer is None: + raise RuntimeError(f"{method} requires an outbound writer") + request_id = f"srv-{next(self._outbound_id_counter)}" + slot = {"event": threading.Event()} + with self._outbound_lock: + self._pending_outbound[request_id] = slot + envelope = json.dumps({ + "jsonrpc": "2.0", "id": request_id, + "method": method, "params": params, + }, ensure_ascii=False, default=str) + try: + writer(envelope) + if not slot["event"].wait(timeout=timeout): + raise TimeoutError(f"{method} timed out after {timeout}s") + finally: + with self._outbound_lock: + self._pending_outbound.pop(request_id, None) + if "error" in slot: + raise RuntimeError(f"{method} failed: {slot['error']}") + return slot.get("result") or {} + + def _cancel_active_call(self, params: Dict[str, Any]) -> None: + """Mark the matching active tool call as cancelled, if any.""" + request_id = params.get("requestId") + if request_id is None: + return + with self._calls_lock: + ctx = self._active_calls.get(request_id) + if ctx is not None: + ctx.cancelled_event.set() + autocontrol_logger.info( + "MCP cancel signalled for call %r", request_id, + ) + + def _dispatch(self, msg_id: Any, method: Optional[str], + params: Dict[str, Any]) -> Any: + if method == "initialize": + return self._handle_initialize(params) + if method == "ping": + return {} + if method == "tools/list": + return {"tools": [tool.to_descriptor() + for tool in self._tools.values()]} + if method == _TOOLS_CALL_METHOD: + return self._handle_tools_call(msg_id, params) + if method == "resources/list": + return {"resources": [resource.to_descriptor() + for resource in self._resources.list()]} + if method == "resources/read": + return self._handle_resources_read(params) + if method == "resources/subscribe": + return self._handle_resources_subscribe(params) + if method == "resources/unsubscribe": + return self._handle_resources_unsubscribe(params) + if method == "prompts/list": + return {"prompts": [prompt.to_descriptor() + for prompt in self._prompts.list()]} + if method == "prompts/get": + return self._handle_prompts_get(params) + if method == "logging/setLevel": + return self._handle_logging_set_level(params) + raise _MCPError(-32601, f"Method not found: {method}") + + def _handle_logging_set_level(self, + params: Dict[str, Any]) -> Dict[str, Any]: + name = params.get("level") + if not isinstance(name, str): + raise _MCPError(-32602, "logging/setLevel requires string 'level'") + level = mcp_level_to_logging(name) + if level is None: + raise _MCPError(-32602, f"unknown log level: {name!r}") + if self._log_bridge is None: + self._log_bridge = MCPLogBridge() + self._log_bridge.setLevel(level) + autocontrol_logger.setLevel(min(autocontrol_logger.level or level, + level) if autocontrol_logger.level + else level) + return {} + + def _handle_initialize(self, params: Dict[str, Any]) -> Dict[str, Any]: + client_version = params.get("protocolVersion", PROTOCOL_VERSION) + client_caps = params.get("capabilities") or {} + if isinstance(client_caps, dict): + self._client_capabilities = client_caps + capabilities: Dict[str, Any] = { + "tools": {"listChanged": True}, + "resources": {"listChanged": False, "subscribe": True}, + "prompts": {"listChanged": False}, + "sampling": {}, + "logging": {}, + } + if "roots" in self._client_capabilities: + capabilities["roots"] = {"listChanged": True} + return { + "protocolVersion": client_version or PROTOCOL_VERSION, + "capabilities": capabilities, + "serverInfo": {"name": SERVER_NAME, "version": SERVER_VERSION}, + } + + def _handle_resources_read(self, + params: Dict[str, Any]) -> Dict[str, Any]: + uri = params.get("uri") + if not isinstance(uri, str) or not uri: + raise _MCPError(-32602, "resources/read requires string 'uri'") + content = self._resources.read(uri) + if content is None: + raise _MCPError(-32602, f"Unknown resource: {uri}") + return {"contents": [content]} + + def _handle_resources_subscribe(self, + params: Dict[str, Any]) -> Dict[str, Any]: + uri = params.get("uri") + if not isinstance(uri, str) or not uri: + raise _MCPError(-32602, "resources/subscribe requires 'uri'") + with self._subscriptions_lock: + if uri in self._resource_subscriptions: + return {} + handle = self._resources.subscribe( + uri, lambda u=uri: self._notify_resource_updated(u), + ) + if handle is None: + raise _MCPError(-32602, f"Unsubscribable resource: {uri}") + with self._subscriptions_lock: + self._resource_subscriptions[uri] = handle + return {} + + def _handle_resources_unsubscribe(self, + params: Dict[str, Any]) -> Dict[str, Any]: + uri = params.get("uri") + if not isinstance(uri, str) or not uri: + raise _MCPError(-32602, "resources/unsubscribe requires 'uri'") + with self._subscriptions_lock: + handle = self._resource_subscriptions.pop(uri, None) + if handle is not None: + self._resources.unsubscribe(uri, handle) + return {} + + def _notify_resource_updated(self, uri: str) -> None: + notifier = self._notifier + if notifier is None: + return + try: + notifier("notifications/resources/updated", {"uri": uri}) + except (OSError, RuntimeError, ValueError): + autocontrol_logger.exception( + "MCP failed to send resources/updated for %s", uri, + ) + + def _handle_prompts_get(self, params: Dict[str, Any]) -> Dict[str, Any]: + name = params.get("name") + arguments = params.get("arguments") or {} + if not isinstance(name, str) or not name: + raise _MCPError(-32602, "prompts/get requires string 'name'") + if not isinstance(arguments, dict): + raise _MCPError(-32602, "prompts/get 'arguments' must be an object") + try: + payload = self._prompts.get(name, arguments) + except ValueError as error: + raise _MCPError(-32602, str(error)) from error + if payload is None: + raise _MCPError(-32602, f"Unknown prompt: {name}") + return payload + + def _handle_tools_call(self, msg_id: Any, + params: Dict[str, Any]) -> Dict[str, Any]: + name = params.get("name") + arguments = params.get("arguments") or {} + if not isinstance(name, str): + raise _MCPError(-32602, "tools/call requires string 'name'") + if not isinstance(arguments, dict): + raise _MCPError(-32602, "tools/call 'arguments' must be an object") + tool = self._tools.get(name) + if tool is None: + raise _MCPError(-32602, f"Unknown tool: {name}") + violation = validate_arguments(tool.input_schema, arguments) + if violation is not None: + raise _MCPError(-32602, f"Invalid arguments for {name}: {violation}") + if self._rate_limiter is not None and not self._rate_limiter.try_acquire(): + raise _MCPError(-32000, f"Rate limit exceeded for tool {name!r}") + self._maybe_confirm_destructive(name, tool, arguments) + ctx = self._build_call_context(msg_id, params) + with self._calls_lock: + self._active_calls[msg_id] = ctx + started_at = time.monotonic() + try: + result = tool.invoke(arguments, ctx=ctx) + except OperationCancelledError: + self._audit.record( + tool=name, arguments=arguments, status="cancelled", + duration_seconds=time.monotonic() - started_at, + ) + raise + except (OSError, RuntimeError, ValueError, TypeError, + AttributeError, KeyError) as error: + # NotImplementedError subclasses RuntimeError so it's already covered. + autocontrol_logger.warning("MCP tool %s failed: %r", name, error) + artifact = _capture_error_screenshot(name) + self._audit.record( + tool=name, arguments=arguments, status="error", + duration_seconds=time.monotonic() - started_at, + error_text=f"{type(error).__name__}: {error}", + artifact_path=artifact, + ) + error_text = f"{type(error).__name__}: {error}" + if artifact is not None: + error_text += f"\n(error screenshot saved to {artifact})" + return { + "content": [{"type": "text", "text": error_text}], + "isError": True, + } + finally: + with self._calls_lock: + self._active_calls.pop(msg_id, None) + self._audit.record( + tool=name, arguments=arguments, status="ok", + duration_seconds=time.monotonic() - started_at, + ) + return { + "content": _to_content_blocks(result), + "isError": False, + } + + def request_elicitation(self, message: str, + requested_schema: Optional[Dict[str, Any]] = None, + timeout: float = 60.0) -> Dict[str, Any]: + """Ask the connected client to elicit a response from the user. + + Returns the raw payload (typically ``{"action": "accept" | "decline" | "cancel", ...}``). + Requires the client to advertise the ``elicitation`` capability. + """ + params: Dict[str, Any] = {"message": str(message)} + if requested_schema is not None: + params["requestedSchema"] = requested_schema + return self._send_outbound_request( + "elicitation/create", params=params, timeout=timeout, + ) + + def request_sampling(self, messages: List[Dict[str, Any]], + system_prompt: Optional[str] = None, + max_tokens: int = 1024, + model_preferences: Optional[Dict[str, Any]] = None, + timeout: float = 120.0) -> Dict[str, Any]: + """Ask the connected client to run an LLM sampling request. + + Tools that need the model's help (e.g. an OCR fallback that + wants the model to identify a UI element from a screenshot) + can call this and receive the assistant's reply. Requires the + server to be running in concurrent mode with an outbound + writer set — typically meaning ``serve_stdio`` or the HTTP + SSE transport. + """ + writer = self._writer + if writer is None: + raise RuntimeError( + "request_sampling requires an outbound writer; " + "start serve_stdio or call set_writer() first", + ) + request_id = f"sampling-{next(self._sampling_id_counter)}" + params: Dict[str, Any] = { + "messages": list(messages), + "maxTokens": int(max_tokens), + } + if system_prompt is not None: + params["systemPrompt"] = str(system_prompt) + if model_preferences is not None: + params["modelPreferences"] = dict(model_preferences) + slot = {"event": threading.Event()} + with self._outbound_lock: + self._pending_outbound[request_id] = slot + envelope = json.dumps({ + "jsonrpc": "2.0", "id": request_id, + "method": "sampling/createMessage", "params": params, + }, ensure_ascii=False, default=str) + try: + writer(envelope) + if not slot["event"].wait(timeout=timeout): + raise TimeoutError( + f"sampling request {request_id} timed out after {timeout}s" + ) + finally: + with self._outbound_lock: + self._pending_outbound.pop(request_id, None) + if "error" in slot: + raise RuntimeError(f"sampling failed: {slot['error']}") + return slot.get("result") or {} + + def _maybe_confirm_destructive(self, name: str, tool: MCPTool, + arguments: Dict[str, Any]) -> None: + """Ask the client to confirm before running a destructive tool.""" + if not _confirm_destructive_enabled(): + return + annotations = tool.annotations + if annotations.read_only or not annotations.destructive: + return + if "elicitation" not in self._client_capabilities: + autocontrol_logger.info( + "MCP confirmation requested for %s but client lacks " + "elicitation capability — proceeding without prompt", name, + ) + return + if self._writer is None: + return + prompt = (f"AutoControl is about to run a destructive tool " + f"'{name}'. Continue?") + try: + response = self.request_elicitation( + message=prompt, requested_schema={"type": "object", + "properties": {}}, + timeout=60.0, + ) + except (RuntimeError, TimeoutError) as error: + autocontrol_logger.info( + "MCP elicitation for %s failed (%r) — refusing call", + name, error, + ) + raise _MCPError(-32000, + f"User confirmation unavailable for {name}") + action = response.get("action") if isinstance(response, dict) else None + if action != "accept": + raise _MCPError(-32000, f"User declined to run {name}: action={action!r}") + del arguments # available for future per-arg confirmation policies + + def _build_call_context(self, msg_id: Any, + params: Dict[str, Any]) -> ToolCallContext: + meta = params.get("_meta") if isinstance(params.get("_meta"), + dict) else {} + progress_token = meta.get("progressToken") if isinstance(meta, dict) else None + return ToolCallContext( + request_id=msg_id, progress_token=progress_token, + notifier=self._notifier, + ) + + +def _to_content_blocks(result: Any) -> List[Dict[str, Any]]: + """Normalise a tool's return value into MCP ``content`` blocks.""" + if isinstance(result, MCPContent): + return [result.to_dict()] + if isinstance(result, list) and result and \ + all(isinstance(item, MCPContent) for item in result): + return [item.to_dict() for item in result] + return [{"type": "text", "text": _stringify_result(result)}] + + +def _stringify_result(value: Any) -> str: + """Convert a tool return value into a model-readable string.""" + if isinstance(value, str): + return value + try: + return json.dumps(value, ensure_ascii=False, default=str) + except (TypeError, ValueError): + return repr(value) + + +def _confirm_destructive_enabled() -> bool: + """Return True when the operator wants destructive tools gated on user OK.""" + raw = os.environ.get("JE_AUTOCONTROL_MCP_CONFIRM_DESTRUCTIVE", "") + return raw.strip().lower() in {"1", "true", "yes", "on"} + + +def _capture_error_screenshot(tool_name: str) -> Optional[str]: + """Save a debug screenshot when JE_AUTOCONTROL_MCP_ERROR_SHOTS is set.""" + debug_dir = os.environ.get("JE_AUTOCONTROL_MCP_ERROR_SHOTS") + if not debug_dir: + return None + target_dir = os.path.realpath(os.fspath(debug_dir)) + try: + os.makedirs(target_dir, exist_ok=True) + except OSError as error: + autocontrol_logger.info( + "MCP error-screenshot dir unavailable: %r", error, + ) + return None + filename = f"{tool_name}_{int(time.time() * 1000)}.png" + path = os.path.join(target_dir, filename) + try: + from je_auto_control.utils.cv2_utils.screenshot import pil_screenshot + pil_screenshot(file_path=path) + except (OSError, RuntimeError, ValueError, AttributeError, + ImportError) as error: + autocontrol_logger.info( + "MCP failed to capture error screenshot: %r", error, + ) + return None + return path + + +def _file_uri_to_path(uri: str) -> Optional[str]: + """Convert a ``file://`` URI to a local filesystem path; ``None`` otherwise.""" + if not isinstance(uri, str) or not uri.startswith("file://"): + return None + from urllib.parse import unquote, urlparse + parsed = urlparse(uri) + raw_path = unquote(parsed.path) + # Windows: file:///C:/foo strips the leading slash before the drive letter. + if sys.platform.startswith("win") and raw_path.startswith("/") and \ + len(raw_path) > 2 and raw_path[2] == ":": + raw_path = raw_path[1:] + return raw_path or None + + +def _notification_message(method: str, params: Dict[str, Any]) -> str: + return json.dumps({"jsonrpc": "2.0", "method": method, "params": params}, + ensure_ascii=False, default=str) + + +def _result_response(msg_id: Any, result: Any) -> str: + return json.dumps( + {"jsonrpc": "2.0", "id": msg_id, "result": result}, + ensure_ascii=False, default=str, + ) + + +def _error_response(msg_id: Any, code: int, message: str) -> str: + return json.dumps({ + "jsonrpc": "2.0", "id": msg_id, + "error": {"code": code, "message": message}, + }, ensure_ascii=False) + + +def start_mcp_stdio_server() -> MCPServer: + """Start a stdio MCP server in the foreground; blocks until EOF.""" + server = MCPServer() + server.serve_stdio() + return server diff --git a/je_auto_control/utils/mcp_server/tools/__init__.py b/je_auto_control/utils/mcp_server/tools/__init__.py new file mode 100644 index 0000000..49cdf5c --- /dev/null +++ b/je_auto_control/utils/mcp_server/tools/__init__.py @@ -0,0 +1,98 @@ +"""MCP tool registry for AutoControl. + +The package is split into ``_base`` (value types and helpers), +``_handlers`` (adapter functions that bridge to the headless API), and +``_factories`` (per-domain ``MCPTool`` builders). Public consumers +should import only the names re-exported here. +""" +import os +from dataclasses import replace +from typing import Dict, List, Optional + +from je_auto_control.utils.mcp_server.tools._base import ( + MCPContent, MCPTool, MCPToolAnnotations, read_only_env_flag, +) +from je_auto_control.utils.mcp_server.tools._factories import ALL_FACTORIES +from je_auto_control.utils.mcp_server.tools.plugin_tools import ( + make_plugin_tool, register_plugin_tools, +) + + +# Short, model-friendly aliases for the most-used tools. Each alias is +# registered as an additional MCPTool entry pointing at the same handler. +_DEFAULT_ALIASES: Dict[str, str] = { + "click": "ac_click_mouse", + "move_mouse": "ac_set_mouse_position", + "mouse_pos": "ac_get_mouse_position", + "scroll": "ac_mouse_scroll", + "type": "ac_type_text", + "press": "ac_press_key", + "hotkey": "ac_hotkey", + "screenshot": "ac_screenshot", + "screen_size": "ac_screen_size", + "find_image": "ac_locate_image_center", + "find_text": "ac_locate_text", + "click_text": "ac_click_text", + "drag": "ac_drag", + "list_windows": "ac_list_windows", + "focus_window": "ac_focus_window", + "wait_image": "ac_wait_for_image", + "wait_pixel": "ac_wait_for_pixel", + "diff_screens": "ac_diff_screenshots", + "shell": "ac_shell", +} + + +def _aliases_enabled(explicit: Optional[bool]) -> bool: + if explicit is not None: + return bool(explicit) + raw = os.environ.get("JE_AUTOCONTROL_MCP_ALIASES", "1") + return raw.strip().lower() in {"1", "true", "yes", "on"} + + +def _make_aliases(tools: List[MCPTool]) -> List[MCPTool]: + by_name: Dict[str, MCPTool] = {tool.name: tool for tool in tools} + aliases: List[MCPTool] = [] + for short, canonical in _DEFAULT_ALIASES.items(): + target = by_name.get(canonical) + if target is None: + continue + aliases.append(replace( + target, name=short, + description=f"Alias for {canonical}: {target.description}", + )) + return aliases + + +def build_default_tool_registry(read_only: Optional[bool] = None, + aliases: Optional[bool] = None, + ) -> List[MCPTool]: + """Return the full set of tools the MCP server exposes by default. + + :param read_only: when True, drop every tool whose annotations + indicate it can mutate state. When None (default), the value + of ``JE_AUTOCONTROL_MCP_READONLY`` is consulted, so deployments + can pin the server in safe mode without code changes. + :param aliases: when True, also register short model-friendly + aliases (``click``, ``type``, ``screenshot`` ...) pointing at + the canonical ``ac_*`` tools. Defaults to True; honour + ``JE_AUTOCONTROL_MCP_ALIASES=0`` to disable globally. + """ + enforce_read_only = ( + read_only_env_flag() if read_only is None else bool(read_only) + ) + tools: List[MCPTool] = [] + for factory in ALL_FACTORIES: + tools.extend(factory()) + if enforce_read_only: + tools = [tool for tool in tools if tool.annotations.read_only] + if _aliases_enabled(aliases): + tools.extend(_make_aliases(tools)) + return tools + + +__all__ = [ + "MCPContent", "MCPTool", "MCPToolAnnotations", + "build_default_tool_registry", "make_plugin_tool", + "register_plugin_tools", +] diff --git a/je_auto_control/utils/mcp_server/tools/_base.py b/je_auto_control/utils/mcp_server/tools/_base.py new file mode 100644 index 0000000..9159cd6 --- /dev/null +++ b/je_auto_control/utils/mcp_server/tools/_base.py @@ -0,0 +1,142 @@ +"""Shared types and helpers for the MCP tool registry. + +Holds the public value types (:class:`MCPContent`, +:class:`MCPToolAnnotations`, :class:`MCPTool`), the JSON-Schema +helper, and the annotation constants used by every tool factory. +""" +import inspect +import os +from dataclasses import dataclass +from functools import lru_cache +from typing import Any, Callable, Dict, List, Optional + + +@dataclass(frozen=True) +class MCPContent: + """One content block returned to an MCP client. + + The ``type`` field follows the MCP content discriminator: ``text``, + ``image``, or ``resource``. Tools normally return plain Python + objects (auto-wrapped in a single ``text`` block); use this class + when a tool needs to return non-text content such as a screenshot. + """ + + type: str + text: Optional[str] = None + data: Optional[str] = None + mime_type: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Return the JSON shape MCP clients expect for one content block.""" + if self.type == "text": + return {"type": "text", "text": self.text or ""} + if self.type == "image": + return { + "type": "image", "data": self.data or "", + "mimeType": self.mime_type or "image/png", + } + return {"type": self.type, "text": self.text or ""} + + @classmethod + def text_block(cls, text: str) -> "MCPContent": + return cls(type="text", text=text) + + @classmethod + def image_block(cls, data: str, + mime_type: str = "image/png") -> "MCPContent": + return cls(type="image", data=data, mime_type=mime_type) + + +@dataclass(frozen=True) +class MCPToolAnnotations: + """MCP behaviour hints surfaced to the client per the 2025-03-26 spec. + + Defaults follow the spec: a tool is assumed to mutate state in an + open world unless it explicitly opts in to read-only / closed-world. + These hints are advisory — clients may use them to require user + confirmation before destructive calls but MUST NOT rely on them for + security. + """ + + title: Optional[str] = None + read_only: bool = False + destructive: bool = True + idempotent: bool = False + open_world: bool = True + + def to_dict(self) -> Dict[str, Any]: + """Return the JSON shape MCP clients expect under ``annotations``.""" + annotations: Dict[str, Any] = { + "readOnlyHint": self.read_only, + "destructiveHint": False if self.read_only else self.destructive, + "idempotentHint": self.idempotent, + "openWorldHint": self.open_world, + } + if self.title is not None: + annotations["title"] = self.title + return annotations + + +@dataclass(frozen=True) +class MCPTool: + """A single MCP tool — public name, schema, and Python callable.""" + + name: str + description: str + input_schema: Dict[str, Any] + handler: Callable[..., Any] + annotations: MCPToolAnnotations = MCPToolAnnotations() + + def to_descriptor(self) -> Dict[str, Any]: + """Return the dict shape MCP clients expect from ``tools/list``.""" + return { + "name": self.name, + "description": self.description, + "inputSchema": self.input_schema, + "annotations": self.annotations.to_dict(), + } + + def invoke(self, arguments: Dict[str, Any], ctx: Any = None) -> Any: + """Call the underlying handler with keyword arguments. + + Handlers that declare a ``ctx`` parameter receive the + :class:`~je_auto_control.utils.mcp_server.context.ToolCallContext` + for the active call, which lets them report progress and + observe cooperative cancellation. Handlers that do not + declare ``ctx`` see the original behaviour unchanged. + """ + if ctx is not None and _handler_accepts_ctx(self.handler): + return self.handler(ctx=ctx, **arguments) + return self.handler(**arguments) + + +@lru_cache(maxsize=512) +def _handler_accepts_ctx(handler: Callable[..., Any]) -> bool: + """Return True when ``handler`` declares a ``ctx`` keyword parameter.""" + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return False + return "ctx" in signature.parameters + + +def schema(properties: Dict[str, Any], + required: Optional[List[str]] = None) -> Dict[str, Any]: + """Build a JSON Schema object node from a property mapping.""" + node: Dict[str, Any] = {"type": "object", "properties": properties} + if required: + node["required"] = list(required) + return node + + +# Pre-built annotation singletons used by every tool factory. +DESTRUCTIVE = MCPToolAnnotations(destructive=True) +NON_DESTRUCTIVE = MCPToolAnnotations(destructive=False, idempotent=True) +READ_ONLY = MCPToolAnnotations(read_only=True, idempotent=True) +SIDE_EFFECT_ONLY = MCPToolAnnotations(destructive=False, idempotent=False) + + +def read_only_env_flag() -> bool: + """Return True when JE_AUTOCONTROL_MCP_READONLY is set to a truthy value.""" + raw = os.environ.get("JE_AUTOCONTROL_MCP_READONLY", "") + return raw.strip().lower() in {"1", "true", "yes", "on"} diff --git a/je_auto_control/utils/mcp_server/tools/_factories.py b/je_auto_control/utils/mcp_server/tools/_factories.py new file mode 100644 index 0000000..71c9740 --- /dev/null +++ b/je_auto_control/utils/mcp_server/tools/_factories.py @@ -0,0 +1,857 @@ +"""Tool-factory functions: each returns a list of MCPTool for one domain. + +Keeping factories separate from adapters lets ``_handlers.py`` stay +focused on argument / return-value normalisation while this module +owns the JSON Schemas, descriptions, and annotation choices that the +MCP client surfaces to the model. +""" +from typing import List + +from je_auto_control.utils.mcp_server.tools import _handlers as h +from je_auto_control.utils.mcp_server.tools._base import ( + DESTRUCTIVE, MCPTool, MCPToolAnnotations, NON_DESTRUCTIVE, READ_ONLY, + SIDE_EFFECT_ONLY, schema, +) + + +def mouse_tools() -> List[MCPTool]: + return [ + MCPTool( + name="ac_click_mouse", + description=("Click a mouse button at (x, y). " + "mouse_keycode: mouse_left, mouse_right, mouse_middle. " + "If x/y are omitted, clicks at the current cursor."), + input_schema=schema({ + "mouse_keycode": {"type": "string", + "description": "mouse_left | mouse_right | mouse_middle"}, + "x": {"type": "integer"}, + "y": {"type": "integer"}, + }), + handler=h.click_mouse, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_get_mouse_position", + description="Return the current cursor position as [x, y].", + input_schema=schema({}), + handler=h.get_mouse_position, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_set_mouse_position", + description="Move the cursor to absolute screen coordinates (x, y).", + input_schema=schema({ + "x": {"type": "integer"}, + "y": {"type": "integer"}, + }, required=["x", "y"]), + handler=h.set_mouse_position, + annotations=NON_DESTRUCTIVE, + ), + MCPTool( + name="ac_mouse_scroll", + description=("Scroll the mouse wheel by scroll_value units. " + "scroll_direction is Linux-only: scroll_up | scroll_down."), + input_schema=schema({ + "scroll_value": {"type": "integer"}, + "x": {"type": "integer"}, + "y": {"type": "integer"}, + "scroll_direction": {"type": "string"}, + }, required=["scroll_value"]), + handler=h.mouse_scroll, + annotations=DESTRUCTIVE, + ), + ] + + +def keyboard_tools() -> List[MCPTool]: + return [ + MCPTool( + name="ac_type_text", + description=("Type a string by pressing each character. " + "Use ac_press_key or ac_hotkey for control keys."), + input_schema=schema({"text": {"type": "string"}}, + required=["text"]), + handler=h.type_text, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_press_key", + description=("Press and release one keyboard key. keycode is a " + "name from get_keyboard_keys_table (e.g. enter, tab, " + "f1, a, 1)."), + input_schema=schema({"keycode": {"type": "string"}}, + required=["keycode"]), + handler=h.press_key, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_hotkey", + description=("Press a key combination, e.g. ['ctrl', 'c']. " + "Keys are pressed in order then released in reverse."), + input_schema=schema({ + "keys": {"type": "array", "items": {"type": "string"}}, + }, required=["keys"]), + handler=h.hotkey, + annotations=DESTRUCTIVE, + ), + ] + + +def screen_tools() -> List[MCPTool]: + return [ + MCPTool( + name="ac_screen_size", + description="Return the primary screen size as [width, height].", + input_schema=schema({}), + handler=h.screen_size, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_screenshot", + description=("Take a screenshot and return it as a base64 PNG " + "image content block so the model can see the " + "screen. file_path saves to disk. screen_region " + "is [left, top, right, bottom]. monitor_index " + "captures one monitor across multi-display setups " + "(0 = virtual desktop spanning all, 1+ = single " + "screens — see ac_list_monitors)."), + input_schema=schema({ + "file_path": {"type": "string"}, + "screen_region": {"type": "array", + "items": {"type": "integer"}}, + "monitor_index": {"type": "integer"}, + }), + handler=h.screenshot, + annotations=MCPToolAnnotations(destructive=False, idempotent=False), + ), + MCPTool( + name="ac_list_monitors", + description=("List every connected monitor's geometry. Index 0 " + "spans all monitors; 1+ are single displays. Use " + "the index with ac_screenshot's monitor_index."), + input_schema=schema({}), + handler=h.list_monitors, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_get_pixel", + description="Return the pixel colour at (x, y) as a list of channels.", + input_schema=schema({ + "x": {"type": "integer"}, + "y": {"type": "integer"}, + }, required=["x", "y"]), + handler=h.get_pixel, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_wait_for_image", + description=("Poll the screen until ``image_path`` appears, " + "returning its centre [x, y]. Raises after " + "``timeout`` seconds. Cancellable: clients can " + "send notifications/cancelled to abort."), + input_schema=schema({ + "image_path": {"type": "string"}, + "timeout": {"type": "number"}, + "poll": {"type": "number"}, + "detect_threshold": {"type": "number"}, + }, required=["image_path"]), + handler=h.wait_for_image, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_wait_for_pixel", + description=("Poll pixel (x, y) until it matches ``target_rgb`` " + "within ``tolerance`` per channel. Returns the " + "actual [r, g, b] reading on match."), + input_schema=schema({ + "x": {"type": "integer"}, + "y": {"type": "integer"}, + "target_rgb": {"type": "array", + "items": {"type": "integer"}}, + "tolerance": {"type": "integer"}, + "timeout": {"type": "number"}, + "poll": {"type": "number"}, + }, required=["x", "y", "target_rgb"]), + handler=h.wait_for_pixel, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_diff_screenshots", + description=("Compare two screenshots and return the bounding " + "boxes that changed. Result shape: {size: [w, h], " + "boxes: [[x, y, w, h], ...]}. Pixels differing by " + "at most threshold (per channel) are treated as " + "equal; components smaller than min_box_pixels " + "are ignored to filter antialias noise."), + input_schema=schema({ + "image_path_a": {"type": "string"}, + "image_path_b": {"type": "string"}, + "threshold": {"type": "integer"}, + "min_box_pixels": {"type": "integer"}, + }, required=["image_path_a", "image_path_b"]), + handler=h.diff_screenshots, + annotations=READ_ONLY, + ), + ] + + +def image_and_ocr_tools() -> List[MCPTool]: + return [ + MCPTool( + name="ac_locate_image_center", + description=("Find a template image on screen and return its " + "centre [x, y]. detect_threshold is 0.0–1.0."), + input_schema=schema({ + "image_path": {"type": "string"}, + "detect_threshold": {"type": "number"}, + }, required=["image_path"]), + handler=h.locate_image_center, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_locate_and_click", + description="Find a template image and click its centre.", + input_schema=schema({ + "image_path": {"type": "string"}, + "mouse_keycode": {"type": "string"}, + "detect_threshold": {"type": "number"}, + }, required=["image_path"]), + handler=h.locate_and_click, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_locate_text", + description=("OCR the screen for ``text`` and return the centre " + "[x, y] of the first match. region is " + "[x, y, width, height]. Requires Tesseract."), + input_schema=schema({ + "text": {"type": "string"}, + "region": {"type": "array", "items": {"type": "integer"}}, + "min_confidence": {"type": "number"}, + }, required=["text"]), + handler=h.locate_text, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_click_text", + description="OCR for ``text`` and click its centre.", + input_schema=schema({ + "text": {"type": "string"}, + "mouse_keycode": {"type": "string"}, + "region": {"type": "array", "items": {"type": "integer"}}, + "min_confidence": {"type": "number"}, + }, required=["text"]), + handler=h.click_text, + annotations=DESTRUCTIVE, + ), + ] + + +def window_tools() -> List[MCPTool]: + return [ + MCPTool( + name="ac_list_windows", + description=("List visible top-level windows as " + "[{hwnd, title}, ...] (Windows only)."), + input_schema=schema({}), + handler=h.list_windows, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_focus_window", + description="Bring the first window matching title_substring to the front.", + input_schema=schema({ + "title_substring": {"type": "string"}, + "case_sensitive": {"type": "boolean"}, + }, required=["title_substring"]), + handler=h.focus_window, + annotations=NON_DESTRUCTIVE, + ), + MCPTool( + name="ac_wait_for_window", + description="Poll until a window with title_substring exists; return its hwnd.", + input_schema=schema({ + "title_substring": {"type": "string"}, + "timeout": {"type": "number"}, + "case_sensitive": {"type": "boolean"}, + }, required=["title_substring"]), + handler=h.wait_for_window, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_close_window", + description="Minimise the first window matching title_substring.", + input_schema=schema({ + "title_substring": {"type": "string"}, + "case_sensitive": {"type": "boolean"}, + }, required=["title_substring"]), + handler=h.close_window, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_window_move", + description=("Move and resize the first matching window to " + "(x, y) with dimensions (width, height). " + "Windows-only."), + input_schema=schema({ + "title_substring": {"type": "string"}, + "x": {"type": "integer"}, + "y": {"type": "integer"}, + "width": {"type": "integer"}, + "height": {"type": "integer"}, + "case_sensitive": {"type": "boolean"}, + }, required=["title_substring", "x", "y", "width", "height"]), + handler=h.window_move, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_window_minimize", + description="Minimise the first matching window.", + input_schema=schema({ + "title_substring": {"type": "string"}, + "case_sensitive": {"type": "boolean"}, + }, required=["title_substring"]), + handler=h.window_minimize, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_window_maximize", + description="Maximise the first matching window.", + input_schema=schema({ + "title_substring": {"type": "string"}, + "case_sensitive": {"type": "boolean"}, + }, required=["title_substring"]), + handler=h.window_maximize, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_window_restore", + description=("Restore the first matching window to its previous " + "size and position."), + input_schema=schema({ + "title_substring": {"type": "string"}, + "case_sensitive": {"type": "boolean"}, + }, required=["title_substring"]), + handler=h.window_restore, + annotations=DESTRUCTIVE, + ), + ] + + +def system_tools() -> List[MCPTool]: + return [ + MCPTool( + name="ac_get_clipboard", + description="Return the current text clipboard contents.", + input_schema=schema({}), + handler=h.get_clipboard, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_set_clipboard", + description="Replace the text clipboard contents with ``text``.", + input_schema=schema({"text": {"type": "string"}}, + required=["text"]), + handler=h.set_clipboard, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_get_clipboard_image", + description=("Return the current clipboard image as a base64 " + "PNG content block (so the model can see it). " + "Returns a text block 'clipboard does not contain " + "an image' when the clipboard has no image."), + input_schema=schema({}), + handler=h.get_clipboard_image, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_set_clipboard_image", + description=("Place a Pillow-readable image file on the " + "clipboard. Windows-only today; macOS / Linux " + "raise NotImplementedError."), + input_schema=schema({"image_path": {"type": "string"}}, + required=["image_path"]), + handler=h.set_clipboard_image, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_execute_actions", + description=("Run a list of AutoControl actions through the " + "executor. Each action is [name, args] where name " + "starts with AC_ (see ac_list_action_commands)."), + input_schema=schema({ + "actions": {"type": "array", + "items": {"type": "array"}}, + }, required=["actions"]), + handler=h.execute_actions, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_execute_action_file", + description="Load a JSON action file from disk and execute it.", + input_schema=schema({"file_path": {"type": "string"}}, + required=["file_path"]), + handler=h.execute_action_file, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_list_action_commands", + description="Return every action command name the executor recognises.", + input_schema=schema({}), + handler=h.list_action_commands, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_list_run_history", + description=("Return recent script-run history records " + "(id, status, source_type, started_at, ...)."), + input_schema=schema({ + "limit": {"type": "integer"}, + "source_type": {"type": "string"}, + }), + handler=h.list_run_history, + annotations=READ_ONLY, + ), + ] + + +def screen_record_tools() -> List[MCPTool]: + return [ + MCPTool( + name="ac_screen_record_start", + description=("Start recording the screen to a video file. " + "recorder_name is a handle for ac_screen_record_stop. " + "Codec defaults to XVID (.avi); use MP4V for .mp4."), + input_schema=schema({ + "recorder_name": {"type": "string"}, + "file_path": {"type": "string"}, + "codec": {"type": "string"}, + "frame_per_sec": {"type": "integer"}, + "width": {"type": "integer"}, + "height": {"type": "integer"}, + }, required=["recorder_name", "file_path"]), + handler=h.screen_record_start, + annotations=SIDE_EFFECT_ONLY, + ), + MCPTool( + name="ac_screen_record_stop", + description="Stop the named screen recorder.", + input_schema=schema({"recorder_name": {"type": "string"}}, + required=["recorder_name"]), + handler=h.screen_record_stop, + annotations=SIDE_EFFECT_ONLY, + ), + MCPTool( + name="ac_screen_record_list", + description="Return the names of currently running screen recorders.", + input_schema=schema({}), + handler=h.screen_record_list, + annotations=READ_ONLY, + ), + ] + + +def recording_tools() -> List[MCPTool]: + return [ + MCPTool( + name="ac_record_start", + description=("Start recording mouse and keyboard events in the " + "background. Call ac_record_stop to retrieve the " + "captured action list. Not supported on macOS."), + input_schema=schema({}), + handler=h.record_start, + annotations=SIDE_EFFECT_ONLY, + ), + MCPTool( + name="ac_record_stop", + description=("Stop the active recorder and return the captured " + "action list ([[command, args], ...]) ready to " + "feed back into ac_execute_actions."), + input_schema=schema({}), + handler=h.record_stop, + annotations=SIDE_EFFECT_ONLY, + ), + MCPTool( + name="ac_read_action_file", + description="Read a JSON action file from disk and return its parsed contents.", + input_schema=schema({"file_path": {"type": "string"}}, + required=["file_path"]), + handler=h.read_action_file, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_write_action_file", + description="Persist an action list to a JSON file at file_path.", + input_schema=schema({ + "file_path": {"type": "string"}, + "actions": {"type": "array"}, + }, required=["file_path", "actions"]), + handler=h.write_action_file, + annotations=SIDE_EFFECT_ONLY, + ), + MCPTool( + name="ac_trim_actions", + description=("Return actions[start:end] as a new list — useful " + "for cleaning up the head/tail of a recording."), + input_schema=schema({ + "actions": {"type": "array"}, + "start": {"type": "integer"}, + "end": {"type": "integer"}, + }, required=["actions"]), + handler=h.trim_actions, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_adjust_delays", + description=("Scale every AC_sleep delay by ``factor`` and " + "optionally clamp to a minimum of clamp_ms."), + input_schema=schema({ + "actions": {"type": "array"}, + "factor": {"type": "number"}, + "clamp_ms": {"type": "integer"}, + }, required=["actions"]), + handler=h.adjust_delays, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_scale_coordinates", + description=("Scale every x/y coordinate in an action list — " + "useful when replaying a recording on a different " + "resolution."), + input_schema=schema({ + "actions": {"type": "array"}, + "x_factor": {"type": "number"}, + "y_factor": {"type": "number"}, + }, required=["actions"]), + handler=h.scale_coordinates, + annotations=READ_ONLY, + ), + ] + + +def drag_and_send_tools() -> List[MCPTool]: + return [ + MCPTool( + name="ac_drag", + description=("Drag the mouse from (start_x, start_y) to " + "(end_x, end_y). mouse_keycode defaults to " + "mouse_left."), + input_schema=schema({ + "start_x": {"type": "integer"}, + "start_y": {"type": "integer"}, + "end_x": {"type": "integer"}, + "end_y": {"type": "integer"}, + "mouse_keycode": {"type": "string"}, + }, required=["start_x", "start_y", "end_x", "end_y"]), + handler=h.drag, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_send_key_to_window", + description=("Post a key event to a specific window without " + "stealing focus (Windows / Linux only)."), + input_schema=schema({ + "window_title": {"type": "string"}, + "keycode": {"type": "string"}, + }, required=["window_title", "keycode"]), + handler=h.send_key_to_window, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_send_mouse_to_window", + description=("Post a mouse event to a specific window without " + "stealing focus (Windows / Linux only)."), + input_schema=schema({ + "window_title": {"type": "string"}, + "mouse_keycode": {"type": "string"}, + "x": {"type": "integer"}, + "y": {"type": "integer"}, + }, required=["window_title"]), + handler=h.send_mouse_to_window, + annotations=DESTRUCTIVE, + ), + ] + + +def semantic_locator_tools() -> List[MCPTool]: + return [ + MCPTool( + name="ac_a11y_list", + description=("List accessibility-tree elements (buttons, fields, " + "menu items, ...) optionally filtered by app_name. " + "Each element exposes name, role, and bounding box."), + input_schema=schema({ + "app_name": {"type": "string"}, + "max_results": {"type": "integer"}, + }), + handler=h.a11y_list, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_a11y_find", + description=("Find the first accessibility element matching name " + "/ role / app_name. Returns null when nothing matches."), + input_schema=schema({ + "name": {"type": "string"}, + "role": {"type": "string"}, + "app_name": {"type": "string"}, + }), + handler=h.a11y_find, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_a11y_click", + description=("Click the centre of the first accessibility " + "element matching name / role / app_name."), + input_schema=schema({ + "name": {"type": "string"}, + "role": {"type": "string"}, + "app_name": {"type": "string"}, + }), + handler=h.a11y_click, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_vlm_locate", + description=("Ask a vision-language model where ``description`` " + "is on screen. Returns [x, y] in screen coords or " + "null. Requires ANTHROPIC_API_KEY or OPENAI_API_KEY."), + input_schema=schema({ + "description": {"type": "string"}, + "screen_region": {"type": "array", + "items": {"type": "integer"}}, + "model": {"type": "string"}, + }, required=["description"]), + handler=h.vlm_locate, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_vlm_click", + description="Locate by description with a VLM, then click the centre.", + input_schema=schema({ + "description": {"type": "string"}, + "screen_region": {"type": "array", + "items": {"type": "integer"}}, + "model": {"type": "string"}, + }, required=["description"]), + handler=h.vlm_click, + annotations=DESTRUCTIVE, + ), + ] + + +def scheduler_tools() -> List[MCPTool]: + return [ + MCPTool( + name="ac_scheduler_add_job", + description=("Schedule an action JSON file. Provide either " + "interval_seconds (run every N seconds) or " + "cron_expression (5-field cron rule)."), + input_schema=schema({ + "script_path": {"type": "string"}, + "interval_seconds": {"type": "number"}, + "cron_expression": {"type": "string"}, + "repeat": {"type": "boolean"}, + "max_runs": {"type": "integer"}, + "job_id": {"type": "string"}, + }, required=["script_path"]), + handler=h.scheduler_add_job, + annotations=SIDE_EFFECT_ONLY, + ), + MCPTool( + name="ac_scheduler_remove_job", + description="Remove a scheduled job by id; returns True if it existed.", + input_schema=schema({"job_id": {"type": "string"}}, + required=["job_id"]), + handler=h.scheduler_remove_job, + annotations=SIDE_EFFECT_ONLY, + ), + MCPTool( + name="ac_scheduler_list_jobs", + description="List currently registered scheduler jobs.", + input_schema=schema({}), + handler=h.scheduler_list_jobs, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_scheduler_start", + description="Start the scheduler polling thread (idempotent).", + input_schema=schema({}), + handler=h.scheduler_start, + annotations=NON_DESTRUCTIVE, + ), + MCPTool( + name="ac_scheduler_stop", + description="Stop the scheduler polling thread.", + input_schema=schema({}), + handler=h.scheduler_stop, + annotations=NON_DESTRUCTIVE, + ), + ] + + +def trigger_tools() -> List[MCPTool]: + return [ + MCPTool( + name="ac_trigger_add", + description=("Add a trigger to the default engine. ``kind`` is " + "image (provide image_path/threshold), window " + "(title_substring/case_sensitive), pixel " + "(x/y/target_rgb/tolerance), or file (watch_path). " + "When fired, ``script_path`` is executed."), + input_schema=schema({ + "kind": {"type": "string", + "enum": ["image", "window", "pixel", "file"]}, + "script_path": {"type": "string"}, + "repeat": {"type": "boolean"}, + "image_path": {"type": "string"}, + "threshold": {"type": "number"}, + "title_substring": {"type": "string"}, + "case_sensitive": {"type": "boolean"}, + "x": {"type": "integer"}, + "y": {"type": "integer"}, + "target_rgb": {"type": "array", + "items": {"type": "integer"}}, + "tolerance": {"type": "integer"}, + "watch_path": {"type": "string"}, + }, required=["kind", "script_path"]), + handler=h.trigger_add, + annotations=SIDE_EFFECT_ONLY, + ), + MCPTool( + name="ac_trigger_remove", + description="Remove a trigger by id.", + input_schema=schema({"trigger_id": {"type": "string"}}, + required=["trigger_id"]), + handler=h.trigger_remove, + annotations=SIDE_EFFECT_ONLY, + ), + MCPTool( + name="ac_trigger_list", + description="List currently registered triggers.", + input_schema=schema({}), + handler=h.trigger_list, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_trigger_start", + description="Start the trigger engine polling thread (idempotent).", + input_schema=schema({}), + handler=h.trigger_start, + annotations=NON_DESTRUCTIVE, + ), + MCPTool( + name="ac_trigger_stop", + description="Stop the trigger engine polling thread.", + input_schema=schema({}), + handler=h.trigger_stop, + annotations=NON_DESTRUCTIVE, + ), + ] + + +def process_and_shell_tools() -> List[MCPTool]: + return [ + MCPTool( + name="ac_launch_process", + description=("Spawn a subprocess with the given argv list " + "(detached, stdio piped to /dev/null). Returns " + "{pid, argv}. Optional working_directory."), + input_schema=schema({ + "argv": {"type": "array", "items": {"type": "string"}}, + "working_directory": {"type": "string"}, + }, required=["argv"]), + handler=h.launch_process, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_list_processes", + description=("List running processes (psutil required). " + "Optionally filter by case-insensitive substring."), + input_schema=schema({ + "name_contains": {"type": "string"}, + }), + handler=h.list_processes, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_kill_process", + description=("Terminate a PID gracefully, escalating to " + "SIGKILL after ``timeout``. Returns 'terminated' " + "/ 'killed' / 'not-found'. psutil required."), + input_schema=schema({ + "pid": {"type": "integer"}, + "timeout": {"type": "number"}, + }, required=["pid"]), + handler=h.kill_process, + annotations=DESTRUCTIVE, + ), + MCPTool( + name="ac_shell", + description=("Run a shell-style command line via shlex.split " + "(NO shell expansion). Returns {exit_code, " + "stdout, stderr}."), + input_schema=schema({ + "command": {"type": "string"}, + "timeout": {"type": "number"}, + }, required=["command"]), + handler=h.shell_command, + annotations=DESTRUCTIVE, + ), + ] + + +def hotkey_tools() -> List[MCPTool]: + return [ + MCPTool( + name="ac_hotkey_bind", + description=("Bind a global hotkey combo (e.g. 'ctrl+alt+1') to " + "an action JSON file. Call ac_hotkey_daemon_start " + "to begin listening."), + input_schema=schema({ + "combo": {"type": "string"}, + "script_path": {"type": "string"}, + "binding_id": {"type": "string"}, + }, required=["combo", "script_path"]), + handler=h.hotkey_bind, + annotations=SIDE_EFFECT_ONLY, + ), + MCPTool( + name="ac_hotkey_unbind", + description="Remove a hotkey binding by id.", + input_schema=schema({"binding_id": {"type": "string"}}, + required=["binding_id"]), + handler=h.hotkey_unbind, + annotations=SIDE_EFFECT_ONLY, + ), + MCPTool( + name="ac_hotkey_list", + description="List the registered hotkey bindings.", + input_schema=schema({}), + handler=h.hotkey_list, + annotations=READ_ONLY, + ), + MCPTool( + name="ac_hotkey_daemon_start", + description="Start the global hotkey listener thread (idempotent).", + input_schema=schema({}), + handler=h.hotkey_daemon_start, + annotations=NON_DESTRUCTIVE, + ), + MCPTool( + name="ac_hotkey_daemon_stop", + description="Stop the global hotkey listener thread.", + input_schema=schema({}), + handler=h.hotkey_daemon_stop, + annotations=NON_DESTRUCTIVE, + ), + ] + + +ALL_FACTORIES = ( + mouse_tools, keyboard_tools, screen_tools, image_and_ocr_tools, + window_tools, system_tools, recording_tools, drag_and_send_tools, + semantic_locator_tools, scheduler_tools, trigger_tools, hotkey_tools, + screen_record_tools, process_and_shell_tools, +) diff --git a/je_auto_control/utils/mcp_server/tools/_handlers.py b/je_auto_control/utils/mcp_server/tools/_handlers.py new file mode 100644 index 0000000..477bcff --- /dev/null +++ b/je_auto_control/utils/mcp_server/tools/_handlers.py @@ -0,0 +1,924 @@ +"""Adapter functions that bridge MCP tool calls to AutoControl's headless API. + +Each adapter normalises arguments (parses ints / paths) and return +values (lists / dicts / strings, or :class:`MCPContent`) so they +survive the JSON-RPC boundary. Wrapper imports are lazy to keep the +top-level MCP server boot cheap. +""" +import base64 +import io +import os +from typing import Any, Dict, List, Optional + +from je_auto_control.utils.mcp_server.tools._base import MCPContent + + +# === Mouse / keyboard ======================================================= + +def click_mouse(mouse_keycode: str = "mouse_left", + x: Optional[int] = None, + y: Optional[int] = None) -> List[Any]: + from je_auto_control.wrapper.auto_control_mouse import click_mouse as _click + keycode, click_x, click_y = _click(mouse_keycode, x, y) + # Real wrapper resolves the string keycode to an int via the keys table; + # the fake backend keeps it as a string. Pass through whatever we got. + resolved = int(keycode) if isinstance(keycode, int) else keycode + return [resolved, int(click_x), int(click_y)] + + +def set_mouse_position(x: int, y: int) -> List[int]: + from je_auto_control.wrapper.auto_control_mouse import set_mouse_position as _move + moved = _move(int(x), int(y)) + return [int(moved[0]), int(moved[1])] + + +def get_mouse_position() -> List[int]: + from je_auto_control.wrapper.auto_control_mouse import get_mouse_position as _pos + pos = _pos() + return [] if pos is None else [int(pos[0]), int(pos[1])] + + +def mouse_scroll(scroll_value: int, + x: Optional[int] = None, + y: Optional[int] = None, + scroll_direction: str = "scroll_down") -> List[Any]: + from je_auto_control.wrapper.auto_control_mouse import mouse_scroll as _scroll + value, direction = _scroll(int(scroll_value), x, y, scroll_direction) + return [int(value), str(direction)] + + +def type_text(text: str) -> str: + from je_auto_control.wrapper.auto_control_keyboard import write + return write(text) or "" + + +def press_key(keycode: str) -> str: + from je_auto_control.wrapper.auto_control_keyboard import type_keyboard + return type_keyboard(keycode) or "" + + +def hotkey(keys: List[str]) -> List[str]: + from je_auto_control.wrapper.auto_control_keyboard import hotkey as _hotkey + pressed, released = _hotkey(list(keys)) + return [pressed, released] + + +def drag(start_x: int, start_y: int, end_x: int, end_y: int, + mouse_keycode: str = "mouse_left") -> List[int]: + """Drag the cursor from (start_x, start_y) to (end_x, end_y).""" + from je_auto_control.wrapper.auto_control_mouse import ( + press_mouse, release_mouse, set_mouse_position as _move, + ) + _move(int(start_x), int(start_y)) + press_mouse(mouse_keycode, int(start_x), int(start_y)) + _move(int(end_x), int(end_y)) + release_mouse(mouse_keycode, int(end_x), int(end_y)) + return [int(end_x), int(end_y)] + + +def send_key_to_window(window_title: str, keycode: str) -> str: + from je_auto_control.wrapper.auto_control_keyboard import ( + send_key_event_to_window, + ) + send_key_event_to_window(window_title, keycode) + return "ok" + + +def send_mouse_to_window(window_title: str, + mouse_keycode: str = "mouse_left", + x: Optional[int] = None, + y: Optional[int] = None) -> str: + from je_auto_control.windows.window import windows_window_manage as wm + from je_auto_control.wrapper.auto_control_mouse import ( + send_mouse_event_to_window, + ) + hit = next(((hwnd, title) for hwnd, title in wm.get_all_window_hwnd() + if window_title.lower() in title.lower()), None) + if hit is None: + raise ValueError(f"no window matching {window_title!r}") + send_mouse_event_to_window(hit[0], mouse_keycode, x=x, y=y) + return "ok" + + +# === Screen / image / OCR =================================================== + +def screen_size() -> List[int]: + from je_auto_control.wrapper.auto_control_screen import screen_size as _size + width, height = _size() + return [int(width), int(height)] + + +def screenshot(file_path: Optional[str] = None, + screen_region: Optional[List[int]] = None, + monitor_index: Optional[int] = None, + ) -> List[MCPContent]: + """Take a screenshot, optionally save it, and return image + path. + + When ``monitor_index`` is provided, capture that specific monitor + via ``mss`` (works across multi-display setups). Index 0 is the + virtual desktop spanning all monitors; 1+ are individual screens. + """ + saved_path: Optional[str] = None + if file_path is not None: + saved_path = os.path.realpath(os.fspath(file_path)) + parent = os.path.dirname(saved_path) or "." + if not os.path.isdir(parent): + raise ValueError(f"screenshot directory does not exist: {parent}") + if monitor_index is not None: + image = _grab_monitor(int(monitor_index)) + if saved_path is not None: + image.save(saved_path) + else: + from je_auto_control.utils.cv2_utils.screenshot import pil_screenshot + image = pil_screenshot(file_path=saved_path, screen_region=screen_region) + buffer = io.BytesIO() + image.save(buffer, format="PNG") + encoded = base64.b64encode(buffer.getvalue()).decode("ascii") + contents: List[MCPContent] = [MCPContent.image_block(encoded)] + if saved_path is not None: + contents.append(MCPContent.text_block(f"saved: {saved_path}")) + return contents + + +def list_monitors() -> List[Dict[str, Any]]: + """Return every monitor's geometry. Index 0 spans all monitors.""" + import mss + with mss.mss() as sct: + return [ + { + "index": index, "left": int(monitor["left"]), + "top": int(monitor["top"]), + "width": int(monitor["width"]), + "height": int(monitor["height"]), + "is_combined": index == 0, + } + for index, monitor in enumerate(sct.monitors) + ] + + +def _grab_monitor(index: int): + """Capture a single monitor via ``mss`` and return a PIL Image.""" + import mss + from PIL import Image + with mss.mss() as sct: + if index < 0 or index >= len(sct.monitors): + raise ValueError( + f"monitor index {index} out of range " + f"(0..{len(sct.monitors) - 1})" + ) + frame = sct.grab(sct.monitors[index]) + return Image.frombytes("RGB", frame.size, frame.bgra, "raw", "BGRX") + + +def get_pixel(x: int, y: int) -> List[int]: + from je_auto_control.wrapper.auto_control_screen import get_pixel as _pixel + pixel = _pixel(int(x), int(y)) + if pixel is None: + return [] + return [int(component) for component in pixel] + + +def wait_for_image(image_path: str, timeout: float = 10.0, + poll: float = 0.5, + detect_threshold: float = 1.0, + ctx: Any = None) -> List[int]: + """Poll for ``image_path`` on screen; return its centre [x, y] or raise.""" + import time as _time + from je_auto_control.utils.exception.exceptions import ImageNotFoundException + from je_auto_control.wrapper.auto_control_image import locate_image_center as _loc + poll_seconds = max(0.05, float(poll)) + deadline = _time.monotonic() + float(timeout) + while _time.monotonic() < deadline: + if ctx is not None: + ctx.check_cancelled() + ctx.progress(_time.monotonic() - (deadline - float(timeout)), + total=float(timeout), + message=f"waiting for {image_path}") + try: + cx, cy = _loc(image_path, + detect_threshold=float(detect_threshold)) + return [int(cx), int(cy)] + except ImageNotFoundException: + _time.sleep(poll_seconds) + raise TimeoutError( + f"wait_for_image timed out after {timeout}s: {image_path!r}" + ) + + +def wait_for_pixel(x: int, y: int, target_rgb: List[int], + tolerance: int = 8, timeout: float = 10.0, + poll: float = 0.25, + ctx: Any = None) -> List[int]: + """Poll until pixel ``(x, y)`` matches ``target_rgb`` within ``tolerance``.""" + import time as _time + from je_auto_control.wrapper.auto_control_screen import get_pixel as _pixel + if len(target_rgb) < 3: + raise ValueError("target_rgb must contain at least 3 channels") + target = [int(c) for c in target_rgb[:3]] + tol = max(0, int(tolerance)) + poll_seconds = max(0.05, float(poll)) + deadline = _time.monotonic() + float(timeout) + while _time.monotonic() < deadline: + if ctx is not None: + ctx.check_cancelled() + raw = _pixel(int(x), int(y)) + if raw is not None and len(raw) >= 3: + channels = [int(raw[i]) for i in range(3)] + if all(abs(channels[i] - target[i]) <= tol for i in range(3)): + return channels + _time.sleep(poll_seconds) + raise TimeoutError( + f"wait_for_pixel timed out after {timeout}s at ({x}, {y})" + ) + + +def diff_screenshots(image_path_a: str, + image_path_b: str, + threshold: int = 16, + min_box_pixels: int = 25, + ) -> Dict[str, Any]: + """Return the bounding boxes that differ between two screenshots. + + The result is JSON-friendly: ``{"size": [w, h], "boxes": [[x, y, w, h], ...]}``. + Boxes are merged via a flood-fill so a single changed widget is one + rectangle. Pixels whose absolute per-channel difference is at most + ``threshold`` are considered equal; tiny components below + ``min_box_pixels`` are dropped to ignore JPEG / antialias noise. + """ + safe_a = os.path.realpath(os.fspath(image_path_a)) + safe_b = os.path.realpath(os.fspath(image_path_b)) + return _diff_screenshots(safe_a, safe_b, int(threshold), + int(min_box_pixels)) + + +def _diff_screenshots(path_a: str, path_b: str, threshold: int, + min_box_pixels: int) -> Dict[str, Any]: + """Implementation split off so the public adapter stays under 75 lines.""" + import numpy as np + from PIL import Image + + img_a = np.asarray(Image.open(path_a).convert("RGB")) + img_b = np.asarray(Image.open(path_b).convert("RGB")) + if img_a.shape != img_b.shape: + height = min(img_a.shape[0], img_b.shape[0]) + width = min(img_a.shape[1], img_b.shape[1]) + img_a = img_a[:height, :width] + img_b = img_b[:height, :width] + diff = np.abs(img_a.astype("int16") - img_b.astype("int16")) + mask = (diff.max(axis=-1) > threshold).astype("uint8") + boxes = _connected_component_boxes(mask, min_box_pixels) + height, width = mask.shape + return {"size": [int(width), int(height)], "boxes": boxes} + + +def _connected_component_boxes(mask: Any, + min_pixels: int) -> List[List[int]]: + """Return tight bounding boxes for connected non-zero regions in ``mask``.""" + import numpy as np + + height, width = mask.shape + visited = np.zeros_like(mask, dtype=bool) + boxes: List[List[int]] = [] + for start_y in range(height): + for start_x in range(width): + if mask[start_y, start_x] == 0 or visited[start_y, start_x]: + continue + box = _flood_fill_box(mask, visited, start_x, start_y) + if box[2] * box[3] < min_pixels: + continue + boxes.append(box) + return boxes + + +_screen_recorder_singleton: Any = None + + +def _get_screen_recorder() -> Any: + """Lazy-init the process-wide ScreenRecorder.""" + global _screen_recorder_singleton + if _screen_recorder_singleton is None: + from je_auto_control.utils.cv2_utils.screen_record import ScreenRecorder + _screen_recorder_singleton = ScreenRecorder() + return _screen_recorder_singleton + + +def screen_record_start(recorder_name: str, + file_path: str, + codec: str = "XVID", + frame_per_sec: int = 30, + width: int = 1920, + height: int = 1080) -> str: + """Start a screen recording under ``recorder_name``; returns the resolved path.""" + safe_path = os.path.realpath(os.fspath(file_path)) + parent = os.path.dirname(safe_path) or "." + if not os.path.isdir(parent): + raise ValueError(f"recording directory does not exist: {parent}") + recorder = _get_screen_recorder() + recorder.start_new_record( + recorder_name=str(recorder_name), + path_and_filename=safe_path, codec=str(codec), + frame_per_sec=int(frame_per_sec), + resolution=(int(width), int(height)), + ) + return safe_path + + +def screen_record_stop(recorder_name: str) -> str: + """Stop the named screen recording; no-op if it doesn't exist.""" + recorder = _get_screen_recorder() + recorder.stop_record(str(recorder_name)) + return "stopped" + + +def screen_record_list() -> List[str]: + """Return the names of currently running recorders.""" + recorder = _get_screen_recorder() + return sorted(recorder.running_recorder.keys()) + + +def _flood_fill_box(mask: Any, visited: Any, + start_x: int, start_y: int) -> List[int]: + """Iterative 4-connectivity flood fill returning [x, y, w, h].""" + height, width = mask.shape + stack = [(start_x, start_y)] + min_x = max_x = start_x + min_y = max_y = start_y + while stack: + x, y = stack.pop() + if x < 0 or y < 0 or x >= width or y >= height: + continue + if visited[y, x] or mask[y, x] == 0: + continue + visited[y, x] = True + if x < min_x: + min_x = x + if x > max_x: + max_x = x + if y < min_y: + min_y = y + if y > max_y: + max_y = y + stack.extend(((x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1))) + return [int(min_x), int(min_y), + int(max_x - min_x + 1), int(max_y - min_y + 1)] + + +def locate_image_center(image_path: str, + detect_threshold: float = 1.0) -> List[int]: + from je_auto_control.wrapper.auto_control_image import locate_image_center as _loc + cx, cy = _loc(image_path, detect_threshold=float(detect_threshold)) + return [int(cx), int(cy)] + + +def locate_and_click(image_path: str, + mouse_keycode: str = "mouse_left", + detect_threshold: float = 1.0) -> List[int]: + from je_auto_control.wrapper.auto_control_image import locate_and_click as _loc_click + cx, cy = _loc_click(image_path, mouse_keycode, + detect_threshold=float(detect_threshold)) + return [int(cx), int(cy)] + + +def locate_text(text: str, + region: Optional[List[int]] = None, + min_confidence: float = 60.0) -> List[int]: + from je_auto_control.utils.ocr.ocr_engine import locate_text_center + cx, cy = locate_text_center(text, region=region, + min_confidence=float(min_confidence)) + return [int(cx), int(cy)] + + +def click_text(text: str, + mouse_keycode: str = "mouse_left", + region: Optional[List[int]] = None, + min_confidence: float = 60.0) -> List[int]: + from je_auto_control.utils.ocr.ocr_engine import click_text as _click + cx, cy = _click(text, mouse_keycode=mouse_keycode, region=region, + min_confidence=float(min_confidence)) + return [int(cx), int(cy)] + + +# === Windows / system ======================================================= + +def list_windows() -> List[Dict[str, Any]]: + from je_auto_control.wrapper.auto_control_window import list_windows as _list + return [{"hwnd": int(hwnd), "title": title} + for hwnd, title in _list()] + + +def focus_window(title_substring: str, + case_sensitive: bool = False) -> int: + from je_auto_control.wrapper.auto_control_window import focus_window as _focus + return int(_focus(title_substring, case_sensitive=case_sensitive)) + + +def wait_for_window(title_substring: str, + timeout: float = 10.0, + case_sensitive: bool = False) -> int: + from je_auto_control.wrapper.auto_control_window import wait_for_window as _wait + return int(_wait(title_substring, timeout=float(timeout), + case_sensitive=case_sensitive)) + + +def close_window(title_substring: str, + case_sensitive: bool = False) -> bool: + from je_auto_control.wrapper.auto_control_window import close_window_by_title + return bool(close_window_by_title(title_substring, + case_sensitive=case_sensitive)) + + +def _resolve_window_hwnd(title_substring: str, + case_sensitive: bool) -> int: + from je_auto_control.wrapper.auto_control_window import find_window + hit = find_window(title_substring, case_sensitive=case_sensitive) + if hit is None: + raise ValueError(f"no window matches {title_substring!r}") + return int(hit[0]) + + +def window_move(title_substring: str, x: int, y: int, + width: int, height: int, + case_sensitive: bool = False) -> Dict[str, int]: + """Move and resize the first window matching ``title_substring`` (Win32 only).""" + from je_auto_control.windows.window import windows_window_manage as wm + hwnd = _resolve_window_hwnd(title_substring, bool(case_sensitive)) + if not wm.move_window(hwnd, int(x), int(y), int(width), int(height)): + raise RuntimeError("MoveWindow returned 0") + return {"hwnd": hwnd, "x": int(x), "y": int(y), + "width": int(width), "height": int(height)} + + +def _show_command(title_substring: str, case_sensitive: bool, + cmd_show: int) -> int: + """Resolve the window then call ShowWindow with the given cmd.""" + from je_auto_control.windows.window import windows_window_manage as wm + hwnd = _resolve_window_hwnd(title_substring, bool(case_sensitive)) + wm.show_window(hwnd, int(cmd_show)) + return hwnd + + +def window_minimize(title_substring: str, + case_sensitive: bool = False) -> int: + return _show_command(title_substring, bool(case_sensitive), cmd_show=6) + + +def window_maximize(title_substring: str, + case_sensitive: bool = False) -> int: + return _show_command(title_substring, bool(case_sensitive), cmd_show=3) + + +def window_restore(title_substring: str, + case_sensitive: bool = False) -> int: + return _show_command(title_substring, bool(case_sensitive), cmd_show=9) + + +def launch_process(argv: List[str], + working_directory: Optional[str] = None, + ) -> Dict[str, Any]: + """Spawn a detached subprocess with a sanitised argv list.""" + import subprocess # nosec B404 # reason: required to spawn child processes + if not isinstance(argv, list) or not argv: + raise ValueError("argv must be a non-empty list") + cleaned = [str(part) for part in argv] + cwd = None + if working_directory is not None: + cwd = os.path.realpath(os.fspath(working_directory)) + if not os.path.isdir(cwd): + raise ValueError(f"working_directory does not exist: {cwd}") + # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit.dangerous-subprocess-use-audit + process = subprocess.Popen( # nosec B603 # reason: argv list, no shell expansion + cleaned, cwd=cwd, stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, stdin=subprocess.DEVNULL, + ) + return {"pid": int(process.pid), "argv": cleaned} + + +def list_processes(name_contains: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """List running processes via ``psutil`` if installed; raise otherwise.""" + try: + import psutil # type: ignore[import-untyped] + except ImportError as error: + raise RuntimeError( + "ac_list_processes requires psutil — pip install psutil" + ) from error + needle = name_contains.lower() if name_contains else None + out: List[Dict[str, Any]] = [] + for proc in psutil.process_iter(["pid", "name", "username"]): + info = proc.info or {} + name = (info.get("name") or "") + if needle and needle not in name.lower(): + continue + out.append({ + "pid": int(info.get("pid") or 0), + "name": name, + "username": info.get("username") or "", + }) + return out + + +def kill_process(pid: int, timeout: float = 5.0) -> str: + """Terminate a PID gracefully, escalating to SIGKILL after ``timeout``.""" + try: + import psutil # type: ignore[import-untyped] + except ImportError as error: + raise RuntimeError( + "ac_kill_process requires psutil — pip install psutil" + ) from error + try: + proc = psutil.Process(int(pid)) + except psutil.NoSuchProcess: + return "not-found" + proc.terminate() + try: + proc.wait(timeout=float(timeout)) + return "terminated" + except psutil.TimeoutExpired: + proc.kill() + return "killed" + + +def shell_command(command: str, timeout: float = 30.0 + ) -> Dict[str, Any]: + """Run a shell-style command line and return stdout/stderr/exit_code. + + Uses argv-list parsing via ``shlex.split`` so we never enable a + shell — protects against the parameterised command injection + classes Bandit B602 / B605 cover. + """ + import shlex + import subprocess # nosec B404 # reason: required for child execution + + if not command or not command.strip(): + raise ValueError("command must be a non-empty string") + argv = shlex.split(command, posix=False) if os.name == "nt" \ + else shlex.split(command) + # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit.dangerous-subprocess-use-audit + proc = subprocess.run( # nosec B603 # reason: argv from shlex.split, no shell + argv, capture_output=True, text=True, + timeout=float(timeout), check=False, + ) + return { + "exit_code": int(proc.returncode), + "stdout": proc.stdout, "stderr": proc.stderr, + } + + +def get_clipboard() -> str: + from je_auto_control.utils.clipboard.clipboard import get_clipboard as _get + return _get() + + +def set_clipboard(text: str) -> str: + from je_auto_control.utils.clipboard.clipboard import set_clipboard as _set + _set(text) + return "ok" + + +def get_clipboard_image() -> List[MCPContent]: + """Return the clipboard image as a base64 PNG content block.""" + from je_auto_control.utils.clipboard.clipboard_image import ( + get_clipboard_image as _read, + ) + payload = _read() + if payload is None: + return [MCPContent.text_block("clipboard does not contain an image")] + encoded = base64.b64encode(payload).decode("ascii") + return [MCPContent.image_block(encoded)] + + +def set_clipboard_image(image_path: str) -> str: + from je_auto_control.utils.clipboard.clipboard_image import ( + set_clipboard_image as _write, + ) + safe_path = os.path.realpath(os.fspath(image_path)) + _write(safe_path) + return "ok" + + +# === Executor / history / recording ========================================= + +def execute_actions(actions: List[Any]) -> Dict[str, str]: + from je_auto_control.utils.executor.action_executor import execute_action + result = execute_action(actions) + return {key: str(value) for key, value in result.items()} + + +def execute_action_file(file_path: str) -> Dict[str, str]: + from je_auto_control.utils.executor.action_executor import execute_action + from je_auto_control.utils.json.json_file import read_action_json + safe_path = os.path.realpath(os.fspath(file_path)) + result = execute_action(read_action_json(safe_path)) + return {key: str(value) for key, value in result.items()} + + +def list_action_commands() -> List[str]: + from je_auto_control.utils.executor.action_executor import executor + return sorted(executor.known_commands()) + + +def list_run_history(limit: int = 50, + source_type: Optional[str] = None + ) -> List[Dict[str, Any]]: + from je_auto_control.utils.run_history.history_store import default_history_store + rows = default_history_store.list_runs(limit=int(limit), + source_type=source_type) + return [{ + "id": row.id, "source_type": row.source_type, + "source_id": row.source_id, "script_path": row.script_path, + "started_at": str(row.started_at), + "finished_at": str(row.finished_at), + "status": row.status, "error_text": row.error_text, + "duration_seconds": row.duration_seconds, + } for row in rows] + + +def record_start() -> str: + from je_auto_control.wrapper.auto_control_record import record + record() + return "recording started" + + +def record_stop() -> List[Any]: + from je_auto_control.wrapper.auto_control_record import stop_record + return stop_record() or [] + + +def read_action_file(file_path: str) -> List[Any]: + from je_auto_control.utils.json.json_file import read_action_json + safe_path = os.path.realpath(os.fspath(file_path)) + return read_action_json(safe_path) + + +def write_action_file(file_path: str, actions: List[Any]) -> str: + from je_auto_control.utils.json.json_file import write_action_json + safe_path = os.path.realpath(os.fspath(file_path)) + parent = os.path.dirname(safe_path) or "." + if not os.path.isdir(parent): + raise ValueError(f"action-file directory does not exist: {parent}") + write_action_json(safe_path, actions) + return safe_path + + +def trim_actions(actions: List[Any], start: int = 0, + end: Optional[int] = None) -> List[Any]: + from je_auto_control.utils.recording_edit.editor import trim_actions as _trim + return _trim(actions, start=int(start), + end=None if end is None else int(end)) + + +def adjust_delays(actions: List[Any], factor: float = 1.0, + clamp_ms: int = 0) -> List[Any]: + from je_auto_control.utils.recording_edit.editor import adjust_delays as _adj + return _adj(actions, factor=float(factor), clamp_ms=int(clamp_ms)) + + +def scale_coordinates(actions: List[Any], x_factor: float = 1.0, + y_factor: float = 1.0) -> List[Any]: + from je_auto_control.utils.recording_edit.editor import scale_coordinates as _scale + return _scale(actions, x_factor=float(x_factor), + y_factor=float(y_factor)) + + +# === Semantic locators (a11y / VLM) ========================================= + +def a11y_list(app_name: Optional[str] = None, + max_results: int = 100) -> List[Dict[str, Any]]: + from je_auto_control.utils.accessibility.accessibility_api import ( + list_accessibility_elements, + ) + return [element.to_dict() + for element in list_accessibility_elements( + app_name=app_name, max_results=int(max_results), + )] + + +def a11y_find(name: Optional[str] = None, + role: Optional[str] = None, + app_name: Optional[str] = None) -> Optional[Dict[str, Any]]: + from je_auto_control.utils.accessibility.accessibility_api import ( + find_accessibility_element, + ) + element = find_accessibility_element(name=name, role=role, + app_name=app_name) + return None if element is None else element.to_dict() + + +def a11y_click(name: Optional[str] = None, + role: Optional[str] = None, + app_name: Optional[str] = None) -> bool: + from je_auto_control.utils.accessibility.accessibility_api import ( + click_accessibility_element, + ) + return bool(click_accessibility_element(name=name, role=role, + app_name=app_name)) + + +def vlm_locate(description: str, + screen_region: Optional[List[int]] = None, + model: Optional[str] = None) -> Optional[List[int]]: + from je_auto_control.utils.vision.vlm_api import locate_by_description + coords = locate_by_description(description, screen_region=screen_region, + model=model) + return None if coords is None else [int(coords[0]), int(coords[1])] + + +def vlm_click(description: str, + screen_region: Optional[List[int]] = None, + model: Optional[str] = None) -> bool: + from je_auto_control.utils.vision.vlm_api import click_by_description + return bool(click_by_description(description, + screen_region=screen_region, + model=model)) + + +# === Scheduler / triggers / hotkey daemon =================================== + +def _job_to_dict(job: Any) -> Dict[str, Any]: + return { + "job_id": job.job_id, "script_path": job.script_path, + "interval_seconds": job.interval_seconds, + "is_cron": job.is_cron, "repeat": job.repeat, + "max_runs": job.max_runs, "runs": job.runs, + "enabled": job.enabled, + } + + +def scheduler_add_job(script_path: str, + interval_seconds: Optional[float] = None, + cron_expression: Optional[str] = None, + repeat: bool = True, + max_runs: Optional[int] = None, + job_id: Optional[str] = None) -> Dict[str, Any]: + """Add an interval or cron job to the default scheduler.""" + from je_auto_control.utils.scheduler.scheduler import default_scheduler + safe_path = os.path.realpath(os.fspath(script_path)) + if cron_expression: + job = default_scheduler.add_cron_job( + safe_path, cron_expression, max_runs=max_runs, job_id=job_id, + ) + return _job_to_dict(job) + if interval_seconds is None: + raise ValueError("scheduler_add_job needs either interval_seconds or cron_expression") + job = default_scheduler.add_job( + safe_path, interval_seconds=float(interval_seconds), + repeat=bool(repeat), max_runs=max_runs, job_id=job_id, + ) + return _job_to_dict(job) + + +def scheduler_remove_job(job_id: str) -> bool: + from je_auto_control.utils.scheduler.scheduler import default_scheduler + return bool(default_scheduler.remove_job(job_id)) + + +def scheduler_list_jobs() -> List[Dict[str, Any]]: + from je_auto_control.utils.scheduler.scheduler import default_scheduler + return [_job_to_dict(job) for job in default_scheduler.list_jobs()] + + +def scheduler_start() -> str: + from je_auto_control.utils.scheduler.scheduler import default_scheduler + default_scheduler.start() + return "started" + + +def scheduler_stop() -> str: + from je_auto_control.utils.scheduler.scheduler import default_scheduler + default_scheduler.stop() + return "stopped" + + +def _trigger_to_dict(trigger: Any) -> Dict[str, Any]: + return { + "trigger_id": trigger.trigger_id, + "type": type(trigger).__name__, + "script_path": trigger.script_path, + "repeat": trigger.repeat, "enabled": trigger.enabled, + "fired": trigger.fired, + } + + +def trigger_add(kind: str, script_path: str, repeat: bool = False, + image_path: Optional[str] = None, + threshold: float = 0.8, + title_substring: Optional[str] = None, + case_sensitive: bool = False, + x: int = 0, y: int = 0, + target_rgb: Optional[List[int]] = None, + tolerance: int = 8, + watch_path: Optional[str] = None) -> Dict[str, Any]: + """Add a trigger to the default engine. ``kind`` is image/window/pixel/file.""" + from je_auto_control.utils.triggers.trigger_engine import ( + FilePathTrigger, ImageAppearsTrigger, PixelColorTrigger, + WindowAppearsTrigger, default_trigger_engine, + ) + safe_script = os.path.realpath(os.fspath(script_path)) + trigger = _build_trigger( + kind=kind, script_path=safe_script, repeat=repeat, + image_path=image_path, threshold=threshold, + title_substring=title_substring, case_sensitive=case_sensitive, + x=x, y=y, target_rgb=target_rgb, tolerance=tolerance, + watch_path=watch_path, + types={ + "image": ImageAppearsTrigger, + "window": WindowAppearsTrigger, + "pixel": PixelColorTrigger, + "file": FilePathTrigger, + }, + ) + default_trigger_engine.add(trigger) + return _trigger_to_dict(trigger) + + +def _build_trigger(*, kind: str, script_path: str, repeat: bool, + image_path: Optional[str], threshold: float, + title_substring: Optional[str], case_sensitive: bool, + x: int, y: int, target_rgb: Optional[List[int]], + tolerance: int, watch_path: Optional[str], + types: Dict[str, Any]) -> Any: + if kind == "image": + if not image_path: + raise ValueError("image trigger requires image_path") + return types["image"](trigger_id="", script_path=script_path, + repeat=repeat, image_path=image_path, + threshold=float(threshold)) + if kind == "window": + if not title_substring: + raise ValueError("window trigger requires title_substring") + return types["window"](trigger_id="", script_path=script_path, + repeat=repeat, + title_substring=title_substring, + case_sensitive=bool(case_sensitive)) + if kind == "pixel": + rgb = tuple(int(c) for c in (target_rgb or [0, 0, 0])) + return types["pixel"](trigger_id="", script_path=script_path, + repeat=repeat, x=int(x), y=int(y), + target_rgb=rgb, tolerance=int(tolerance)) + if kind == "file": + if not watch_path: + raise ValueError("file trigger requires watch_path") + return types["file"](trigger_id="", script_path=script_path, + repeat=repeat, watch_path=watch_path) + raise ValueError(f"unknown trigger kind: {kind!r}") + + +def trigger_remove(trigger_id: str) -> bool: + from je_auto_control.utils.triggers.trigger_engine import default_trigger_engine + return bool(default_trigger_engine.remove(trigger_id)) + + +def trigger_list() -> List[Dict[str, Any]]: + from je_auto_control.utils.triggers.trigger_engine import default_trigger_engine + return [_trigger_to_dict(t) for t in default_trigger_engine.list_triggers()] + + +def trigger_start() -> str: + from je_auto_control.utils.triggers.trigger_engine import default_trigger_engine + default_trigger_engine.start() + return "started" + + +def trigger_stop() -> str: + from je_auto_control.utils.triggers.trigger_engine import default_trigger_engine + default_trigger_engine.stop() + return "stopped" + + +def hotkey_bind(combo: str, script_path: str, + binding_id: Optional[str] = None) -> Dict[str, Any]: + from je_auto_control.utils.hotkey.hotkey_daemon import default_hotkey_daemon + safe_path = os.path.realpath(os.fspath(script_path)) + binding = default_hotkey_daemon.bind(combo, safe_path, + binding_id=binding_id) + return { + "binding_id": binding.binding_id, "combo": binding.combo, + "script_path": binding.script_path, "enabled": binding.enabled, + "fired": binding.fired, + } + + +def hotkey_unbind(binding_id: str) -> bool: + from je_auto_control.utils.hotkey.hotkey_daemon import default_hotkey_daemon + return bool(default_hotkey_daemon.unbind(binding_id)) + + +def hotkey_list() -> List[Dict[str, Any]]: + from je_auto_control.utils.hotkey.hotkey_daemon import default_hotkey_daemon + return [{ + "binding_id": b.binding_id, "combo": b.combo, + "script_path": b.script_path, "enabled": b.enabled, + "fired": b.fired, + } for b in default_hotkey_daemon.list_bindings()] + + +def hotkey_daemon_start() -> str: + from je_auto_control.utils.hotkey.hotkey_daemon import default_hotkey_daemon + default_hotkey_daemon.start() + return "started" + + +def hotkey_daemon_stop() -> str: + from je_auto_control.utils.hotkey.hotkey_daemon import default_hotkey_daemon + default_hotkey_daemon.stop() + return "stopped" diff --git a/je_auto_control/utils/mcp_server/tools/_validation.py b/je_auto_control/utils/mcp_server/tools/_validation.py new file mode 100644 index 0000000..9b3dc1e --- /dev/null +++ b/je_auto_control/utils/mcp_server/tools/_validation.py @@ -0,0 +1,79 @@ +"""Tiny JSON-Schema validator covering the subset MCP tools use here. + +The default tool registry only references a handful of schema +features — ``type``, ``properties``, ``required``, ``items``, +``enum`` — so a 50-line validator is enough and avoids pulling in +``jsonschema`` as a runtime dependency. The validator intentionally +returns the first violation it finds rather than collecting them +all, which keeps the JSON-RPC error message short. +""" +from typing import Any, Dict, Optional + +_TYPE_CHECKS = { + "object": lambda value: isinstance(value, dict), + "array": lambda value: isinstance(value, list), + "string": lambda value: isinstance(value, str), + "boolean": lambda value: isinstance(value, bool), + # JSON Schema integer accepts bool subclass; we exclude bool below. + "integer": lambda value: isinstance(value, int) and not isinstance(value, bool), + "number": lambda value: ( + isinstance(value, (int, float)) and not isinstance(value, bool) + ), + "null": lambda value: value is None, +} + + +def validate_arguments(schema: Dict[str, Any], + arguments: Dict[str, Any]) -> Optional[str]: + """Return the first schema violation as a message, or ``None`` if valid.""" + return _validate(schema, arguments, path="$") + + +def _validate(schema: Dict[str, Any], value: Any, path: str) -> Optional[str]: + expected = schema.get("type") + if expected is not None: + check = _TYPE_CHECKS.get(expected) + if check is None: + return None # unknown type — nothing we can check + if not check(value): + return f"{path}: expected {expected}, got {type(value).__name__}" + enum = schema.get("enum") + if enum is not None and value not in enum: + return f"{path}: must be one of {enum!r}" + if expected == "object": + return _validate_object(schema, value, path) + if expected == "array": + return _validate_array(schema, value, path) + return None + + +def _validate_object(schema: Dict[str, Any], value: Dict[str, Any], + path: str) -> Optional[str]: + required = schema.get("required") or [] + for name in required: + if name not in value: + return f"{path}: missing required property {name!r}" + properties = schema.get("properties") or {} + for name, child_value in value.items(): + child_schema = properties.get(name) + if child_schema is None: + continue + error = _validate(child_schema, child_value, f"{path}.{name}") + if error is not None: + return error + return None + + +def _validate_array(schema: Dict[str, Any], value: list, + path: str) -> Optional[str]: + item_schema = schema.get("items") + if not isinstance(item_schema, dict): + return None + for index, item in enumerate(value): + error = _validate(item_schema, item, f"{path}[{index}]") + if error is not None: + return error + return None + + +__all__ = ["validate_arguments"] diff --git a/je_auto_control/utils/mcp_server/tools/plugin_tools.py b/je_auto_control/utils/mcp_server/tools/plugin_tools.py new file mode 100644 index 0000000..8569a8c --- /dev/null +++ b/je_auto_control/utils/mcp_server/tools/plugin_tools.py @@ -0,0 +1,86 @@ +"""Wrap plugin-loaded ``AC_*`` callables as :class:`MCPTool` objects. + +Plugins register arbitrary callables under ``AC_`` via +:mod:`je_auto_control.utils.plugin_loader`. This module bridges that +dynamic catalogue into the live MCP server so a plugin a user drops +into their plugin directory shows up as a tool the model can call, +and the client gets notified to refresh its tool list. +""" +import inspect +from typing import Any, Callable, Dict, List + +from je_auto_control.utils.mcp_server.tools._base import ( + DESTRUCTIVE, MCPTool, schema, +) + + +def make_plugin_tool(name: str, + handler: Callable[..., Any], + description: str = "") -> MCPTool: + """Build an :class:`MCPTool` from a plugin callable's signature. + + The schema is derived from ``inspect.signature(handler)``: every + parameter becomes a property, parameters without defaults are + marked required, and a parameter named ``ctx`` is excluded so + progress / cancellation context plumbing keeps working. + """ + properties, required = _properties_from_signature(handler) + tool_name = f"plugin_{name.lower()}" if not name.lower().startswith( + "plugin_") else name.lower() + docstring = (handler.__doc__ or "").strip().splitlines()[0] if handler.__doc__ else "" + desc = description or docstring or f"Plugin command {name!r}." + return MCPTool( + name=tool_name, + description=desc, + input_schema=schema(properties, required=required or None), + handler=handler, + annotations=DESTRUCTIVE, + ) + + +def register_plugin_tools(server, commands: Dict[str, Callable[..., Any]] + ) -> List[str]: + """Wrap each entry in ``commands`` and add it to ``server``. + + Returns the list of MCP tool names that were registered. + """ + registered: List[str] = [] + for raw_name, handler in commands.items(): + tool = make_plugin_tool(raw_name, handler) + server.register_tool(tool) + registered.append(tool.name) + return registered + + +_TYPE_FROM_ANNOTATION = { + int: "integer", float: "number", bool: "boolean", + str: "string", list: "array", dict: "object", +} + + +def _properties_from_signature(handler: Callable[..., Any] + ) -> tuple: + """Return (properties, required) derived from the callable signature.""" + try: + signature = inspect.signature(handler) + except (TypeError, ValueError): + return {}, [] + properties: Dict[str, Any] = {} + required: List[str] = [] + for param in signature.parameters.values(): + if param.name == "ctx": + continue + if param.kind in (inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD): + continue + prop: Dict[str, Any] = {} + annotation_type = _TYPE_FROM_ANNOTATION.get(param.annotation) + if annotation_type is not None: + prop["type"] = annotation_type + properties[param.name] = prop + if param.default is inspect.Parameter.empty: + required.append(param.name) + return properties, required + + +__all__ = ["make_plugin_tool", "register_plugin_tools"] diff --git a/je_auto_control/windows/window/windows_window_manage.py b/je_auto_control/windows/window/windows_window_manage.py index d7f9a84..80fb1aa 100644 --- a/je_auto_control/windows/window/windows_window_manage.py +++ b/je_auto_control/windows/window/windows_window_manage.py @@ -87,4 +87,15 @@ def show_window(hwnd: int, cmd_show: int) -> None: if cmd_show < 0 or cmd_show > 11: # Win32 ShowWindow 常見範圍 cmd_show = 1 # 預設為 Normal user32.ShowWindow(hwnd, cmd_show) - user32.SetForegroundWindow(hwnd) \ No newline at end of file + user32.SetForegroundWindow(hwnd) + + +def move_window(hwnd: int, x: int, y: int, width: int, height: int, + repaint: bool = True) -> bool: + """ + 搬移與調整視窗大小 (一次設定座標與寬高) + Move and resize a window in one call. + """ + return bool(user32.MoveWindow(int(hwnd), int(x), int(y), + int(width), int(height), + c_bool(bool(repaint)))) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 91ef0dd..2ec9959 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,9 @@ classifiers = [ "Operating System :: OS Independent" ] +[project.scripts] +je_auto_control_mcp = "je_auto_control.utils.mcp_server.__main__:main" + [project.urls] Homepage = "https://github.com/Intergration-Automation-Testing/AutoControl" Documentation = "https://autocontrol.readthedocs.io/en/latest/" diff --git a/test/unit_test/headless/test_mcp_cli.py b/test/unit_test/headless/test_mcp_cli.py new file mode 100644 index 0000000..e825984 --- /dev/null +++ b/test/unit_test/headless/test_mcp_cli.py @@ -0,0 +1,57 @@ +"""Tests for the je_auto_control_mcp CLI introspection flags.""" +import io +import json +import sys + +from je_auto_control.utils.mcp_server.__main__ import main + + +def _capture(monkeypatch, argv): + """Run ``main(argv)`` for its side effect and return captured stdout.""" + buffer = io.StringIO() + monkeypatch.setattr(sys, "stdout", buffer) + main(argv) + return buffer.getvalue() + + +def test_list_tools_emits_json_and_exits(monkeypatch): + output = _capture(monkeypatch, ["--list-tools"]) + descriptors = json.loads(output) + assert isinstance(descriptors, list) + assert descriptors + assert all("name" in d and "inputSchema" in d for d in descriptors) + + +def test_list_tools_with_read_only_drops_destructive(monkeypatch): + output = _capture(monkeypatch, ["--list-tools", "--read-only"]) + descriptors = json.loads(output) + names = {d["name"] for d in descriptors} + assert "ac_click_mouse" not in names + assert "ac_get_mouse_position" in names + + +def test_list_resources_emits_json(monkeypatch): + output = _capture(monkeypatch, ["--list-resources"]) + descriptors = json.loads(output) + assert isinstance(descriptors, list) + uris = {d["uri"] for d in descriptors} + assert "autocontrol://history" in uris + assert "autocontrol://commands" in uris + + +def test_list_prompts_emits_json(monkeypatch): + output = _capture(monkeypatch, ["--list-prompts"]) + descriptors = json.loads(output) + assert isinstance(descriptors, list) + names = {d["name"] for d in descriptors} + assert {"automate_ui_task", "find_widget"}.issubset(names) + + +def test_no_flags_starts_stdio_server(monkeypatch): + """With no flags, main() should dispatch to start_mcp_stdio_server.""" + started = [] + import je_auto_control.utils.mcp_server.__main__ as cli_mod + monkeypatch.setattr(cli_mod, "start_mcp_stdio_server", + lambda: started.append(True)) + main([]) + assert started == [True] diff --git a/test/unit_test/headless/test_mcp_fake_backend.py b/test/unit_test/headless/test_mcp_fake_backend.py new file mode 100644 index 0000000..fbbb837 --- /dev/null +++ b/test/unit_test/headless/test_mcp_fake_backend.py @@ -0,0 +1,75 @@ +"""Tests for the fake backend used in CI smoke runs.""" +import pytest + +from je_auto_control.utils.mcp_server.fake_backend import ( + fake_state, install_fake_backend, maybe_install_from_env, + reset_fake_state, uninstall_fake_backend, +) +from je_auto_control.utils.mcp_server.tools import ( + build_default_tool_registry, +) + + +@pytest.fixture() +def fake_backend(): + """Install the fake backend for the duration of the test.""" + reset_fake_state() + install_fake_backend() + yield + uninstall_fake_backend() + reset_fake_state() + + +def test_set_mouse_position_records_in_fake_state(fake_backend): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + by_name["ac_set_mouse_position"].invoke({"x": 100, "y": 200}) + assert fake_state().cursor == (100, 200) + assert ("set_position", 100, 200) in fake_state().mouse_actions + + +def test_click_mouse_records_button_and_coords(fake_backend): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + by_name["ac_click_mouse"].invoke({ + "mouse_keycode": "mouse_left", "x": 50, "y": 60, + }) + assert ("click", "mouse_left", 50, 60) in fake_state().mouse_actions + + +def test_clipboard_round_trip_via_fake_backend(fake_backend): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + by_name["ac_set_clipboard"].invoke({"text": "hi"}) + assert by_name["ac_get_clipboard"].invoke({}) == "hi" + assert fake_state().clipboard_text == "hi" + + +def test_type_text_appends_to_typed_history(fake_backend): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + by_name["ac_type_text"].invoke({"text": "hello"}) + assert "hello" in fake_state().typed_text + + +def test_install_is_idempotent(fake_backend): + install_fake_backend() + install_fake_backend() + by_name = {tool.name: tool for tool in build_default_tool_registry()} + by_name["ac_set_mouse_position"].invoke({"x": 1, "y": 2}) + assert fake_state().cursor == (1, 2) + + +def test_maybe_install_from_env_respects_flag(monkeypatch): + reset_fake_state() + uninstall_fake_backend() + monkeypatch.setenv("JE_AUTOCONTROL_FAKE_BACKEND", "1") + try: + assert maybe_install_from_env() is True + from je_auto_control.wrapper import auto_control_mouse as mouse_module + moved = mouse_module.set_mouse_position(7, 9) + assert moved == (7, 9) + assert fake_state().cursor == (7, 9) + finally: + uninstall_fake_backend() + + +def test_maybe_install_from_env_skips_when_unset(monkeypatch): + monkeypatch.delenv("JE_AUTOCONTROL_FAKE_BACKEND", raising=False) + assert maybe_install_from_env() is False diff --git a/test/unit_test/headless/test_mcp_http_transport.py b/test/unit_test/headless/test_mcp_http_transport.py new file mode 100644 index 0000000..10655e2 --- /dev/null +++ b/test/unit_test/headless/test_mcp_http_transport.py @@ -0,0 +1,279 @@ +"""Headless tests for the MCP HTTP transport.""" +import json +import urllib.error +import urllib.request + +import pytest + +from je_auto_control.utils.mcp_server.http_transport import ( + DEFAULT_PATH, HttpMCPServer, +) +from je_auto_control.utils.mcp_server.prompts import StaticPromptProvider +from je_auto_control.utils.mcp_server.resources import ChainProvider +from je_auto_control.utils.mcp_server.server import MCPServer + + +_TEST_SCHEME = "http" # NOSONAR localhost-only ephemeral test server; TLS out of scope + + +@pytest.fixture() +def http_server(): + """Spin up an HttpMCPServer on an ephemeral port with empty providers.""" + mcp = MCPServer( + tools=[], + resource_provider=ChainProvider([]), + prompt_provider=StaticPromptProvider([]), + ) + server = HttpMCPServer(mcp=mcp, host="127.0.0.1", port=0) + server.start() + yield server + server.stop(timeout=1.0) + + +def _post(server, body, path=DEFAULT_PATH): + host, port = server.address + url = f"{_TEST_SCHEME}://{host}:{port}{path}" + data = json.dumps(body).encode("utf-8") if body is not None else None + headers = {"Content-Type": "application/json"} if data else {} + req = urllib.request.Request(url, data=data, headers=headers, + method="POST") + with urllib.request.urlopen(req, timeout=3) as response: # nosec B310 + return response.status, response.read().decode("utf-8") + + +def test_initialize_round_trips_over_http(http_server): + status, body = _post(http_server, { + "jsonrpc": "2.0", "id": 1, "method": "initialize", + "params": {"protocolVersion": "2024-11-05"}, + }) + assert status == 200 + payload = json.loads(body) + assert payload["result"]["protocolVersion"] == "2024-11-05" + assert payload["result"]["serverInfo"]["name"] == "je_auto_control" + + +def test_notification_returns_202(http_server): + host, port = http_server.address + url = f"{_TEST_SCHEME}://{host}:{port}{DEFAULT_PATH}" + data = json.dumps({ + "jsonrpc": "2.0", "method": "notifications/initialized", + }).encode("utf-8") + req = urllib.request.Request( + url, data=data, + headers={"Content-Type": "application/json"}, + method="POST", + ) + with urllib.request.urlopen(req, timeout=3) as response: # nosec B310 + assert response.status == 202 + assert response.read() == b"" + + +def test_tools_list_via_http_uses_registered_tools(): + from je_auto_control.utils.mcp_server.tools import ( + MCPTool, build_default_tool_registry, + ) + full_registry = build_default_tool_registry() + server = HttpMCPServer(mcp=MCPServer( + tools=full_registry, + resource_provider=ChainProvider([]), + prompt_provider=StaticPromptProvider([]), + ), host="127.0.0.1", port=0) + server.start() + try: + status, body = _post(server, { + "jsonrpc": "2.0", "id": 1, "method": "tools/list", + }) + assert status == 200 + payload = json.loads(body) + names = {tool["name"] for tool in payload["result"]["tools"]} + assert {"ac_get_mouse_position", "ac_screen_size"}.issubset(names) + assert isinstance(full_registry[0], MCPTool) + finally: + server.stop(timeout=1.0) + + +def test_unknown_path_returns_404(http_server): + try: + _post(http_server, {"jsonrpc": "2.0", "id": 1, + "method": "ping"}, path="/elsewhere") + except urllib.error.HTTPError as error: + assert error.code == 404 + else: + pytest.fail("expected 404 response") + + +def test_get_returns_405(http_server): + host, port = http_server.address + url = f"{_TEST_SCHEME}://{host}:{port}{DEFAULT_PATH}" + req = urllib.request.Request(url, method="GET") + try: + urllib.request.urlopen(req, timeout=3) # nosec B310 + except urllib.error.HTTPError as error: + assert error.code == 405 + else: + pytest.fail("expected 405 response") + + +def test_invalid_content_length_rejected(http_server): + host, port = http_server.address + url = f"{_TEST_SCHEME}://{host}:{port}{DEFAULT_PATH}" + req = urllib.request.Request(url, data=b"", method="POST", + headers={"Content-Type": "application/json"}) + try: + urllib.request.urlopen(req, timeout=3) # nosec B310 + except urllib.error.HTTPError as error: + assert error.code == 400 + else: + pytest.fail("expected 400 response") + + +def test_sse_streams_progress_then_final_response(): + """When Accept includes text/event-stream, progress + result both stream.""" + from je_auto_control.utils.mcp_server.tools import MCPTool + + def slow_handler(ctx): + ctx.progress(0.0, total=1.0, message="starting") + ctx.progress(0.5, total=1.0) + ctx.progress(1.0, total=1.0, message="done") + return "all done" + + tool = MCPTool( + name="streamed", description="streamed", + input_schema={"type": "object", "properties": {}}, + handler=slow_handler, + ) + server = HttpMCPServer(mcp=MCPServer( + tools=[tool], resource_provider=ChainProvider([]), + prompt_provider=StaticPromptProvider([]), + ), host="127.0.0.1", port=0) + server.start() + try: + host, port = server.address + url = f"{_TEST_SCHEME}://{host}:{port}{DEFAULT_PATH}" + body = json.dumps({ + "jsonrpc": "2.0", "id": 99, "method": "tools/call", + "params": {"name": "streamed", "arguments": {}, + "_meta": {"progressToken": "tok"}}, + }).encode("utf-8") + req = urllib.request.Request( + url, data=body, method="POST", + headers={ + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + }, + ) + with urllib.request.urlopen(req, timeout=5) as response: # nosec B310 + assert response.status == 200 + assert "text/event-stream" in response.headers.get("Content-Type", "") + stream = response.read().decode("utf-8") + events = [chunk[len("data: "):] for chunk in stream.split("\n\n") + if chunk.startswith("data: ")] + progress_events = [json.loads(line) for line in events + if '"notifications/progress"' in line] + final_events = [json.loads(line) for line in events + if '"id": 99' in line] + assert len(progress_events) == 3 + assert progress_events[0]["params"]["progressToken"] == "tok" + assert len(final_events) == 1 + assert final_events[0]["result"]["isError"] is False + assert final_events[0]["result"]["content"][0]["text"] == "all done" + finally: + server.stop(timeout=1.0) + + +def test_authentication_rejects_missing_bearer(): + server = HttpMCPServer(mcp=MCPServer( + tools=[], resource_provider=ChainProvider([]), + prompt_provider=StaticPromptProvider([]), + ), host="127.0.0.1", port=0, auth_token="secret") + server.start() + try: + host, port = server.address + url = f"{_TEST_SCHEME}://{host}:{port}{DEFAULT_PATH}" + body = json.dumps({"jsonrpc": "2.0", "id": 1, + "method": "ping"}).encode("utf-8") + req = urllib.request.Request( + url, data=body, method="POST", + headers={"Content-Type": "application/json"}, + ) + try: + urllib.request.urlopen(req, timeout=3) # nosec B310 + except urllib.error.HTTPError as error: + assert error.code == 401 + else: + pytest.fail("expected 401 response") + finally: + server.stop(timeout=1.0) + + +def test_authentication_rejects_wrong_bearer(): + server = HttpMCPServer(mcp=MCPServer( + tools=[], resource_provider=ChainProvider([]), + prompt_provider=StaticPromptProvider([]), + ), host="127.0.0.1", port=0, auth_token="secret") + server.start() + try: + host, port = server.address + url = f"{_TEST_SCHEME}://{host}:{port}{DEFAULT_PATH}" + body = json.dumps({"jsonrpc": "2.0", "id": 1, + "method": "ping"}).encode("utf-8") + req = urllib.request.Request( + url, data=body, method="POST", + headers={"Content-Type": "application/json", + "Authorization": "Bearer wrong"}, + ) + try: + urllib.request.urlopen(req, timeout=3) # nosec B310 + except urllib.error.HTTPError as error: + assert error.code == 403 + else: + pytest.fail("expected 403 response") + finally: + server.stop(timeout=1.0) + + +def test_authentication_accepts_correct_bearer(): + server = HttpMCPServer(mcp=MCPServer( + tools=[], resource_provider=ChainProvider([]), + prompt_provider=StaticPromptProvider([]), + ), host="127.0.0.1", port=0, auth_token="secret") + server.start() + try: + host, port = server.address + url = f"{_TEST_SCHEME}://{host}:{port}{DEFAULT_PATH}" + body = json.dumps({"jsonrpc": "2.0", "id": 1, + "method": "ping"}).encode("utf-8") + req = urllib.request.Request( + url, data=body, method="POST", + headers={"Content-Type": "application/json", + "Authorization": "Bearer secret"}, + ) + with urllib.request.urlopen(req, timeout=3) as response: # nosec B310 + assert response.status == 200 + payload = json.loads(response.read().decode("utf-8")) + assert payload["result"] == {} + finally: + server.stop(timeout=1.0) + + +def test_auth_token_falls_back_to_env(monkeypatch): + monkeypatch.setenv("JE_AUTOCONTROL_MCP_TOKEN", "env-secret") + server = HttpMCPServer(mcp=MCPServer( + tools=[], resource_provider=ChainProvider([]), + prompt_provider=StaticPromptProvider([]), + ), host="127.0.0.1", port=0) + assert server._auth_token == "env-secret" + + +def test_malformed_json_returns_parse_error(http_server): + host, port = http_server.address + url = f"{_TEST_SCHEME}://{host}:{port}{DEFAULT_PATH}" + body = b"{ not json" + req = urllib.request.Request( + url, data=body, method="POST", + headers={"Content-Type": "application/json"}, + ) + with urllib.request.urlopen(req, timeout=3) as response: # nosec B310 + assert response.status == 200 + payload = json.loads(response.read().decode("utf-8")) + assert payload["error"]["code"] == -32700 diff --git a/test/unit_test/headless/test_mcp_plugin_watcher.py b/test/unit_test/headless/test_mcp_plugin_watcher.py new file mode 100644 index 0000000..e33811e --- /dev/null +++ b/test/unit_test/headless/test_mcp_plugin_watcher.py @@ -0,0 +1,70 @@ +"""Tests for the MCP plugin hot-reload watcher.""" +import time + +from je_auto_control.utils.mcp_server.plugin_watcher import PluginWatcher +from je_auto_control.utils.mcp_server.server import MCPServer + + +def _write(path, body): + path.write_text(body, encoding="utf-8") + # Bump mtime to ensure the watcher picks it up even on coarse FSes. + now = time.time() + import os + os.utime(path, (now, now)) + + +def test_watcher_registers_tools_for_existing_plugins(tmp_path): + plugin = tmp_path / "demo.py" + _write(plugin, "def AC_hello(name='world'):\n return f'hi {name}'\n") + server = MCPServer(tools=[]) + watcher = PluginWatcher(server, str(tmp_path), poll_seconds=0.1) + watcher.poll_once() + assert server._tools.get("plugin_ac_hello") is not None + + +def test_watcher_picks_up_new_files(tmp_path): + server = MCPServer(tools=[]) + watcher = PluginWatcher(server, str(tmp_path), poll_seconds=0.1) + watcher.poll_once() + plugin = tmp_path / "added.py" + _write(plugin, "def AC_added():\n return 'late'\n") + watcher.poll_once() + assert "plugin_ac_added" in server._tools + + +def test_watcher_drops_tools_when_file_removed(tmp_path): + plugin = tmp_path / "soon.py" + _write(plugin, "def AC_soon():\n return 'gone soon'\n") + server = MCPServer(tools=[]) + watcher = PluginWatcher(server, str(tmp_path), poll_seconds=0.1) + watcher.poll_once() + assert "plugin_ac_soon" in server._tools + plugin.unlink() + watcher.poll_once() + assert "plugin_ac_soon" not in server._tools + + +def test_watcher_reloads_after_mtime_change(tmp_path): + plugin = tmp_path / "evolving.py" + _write(plugin, "def AC_evolve():\n return 1\n") + server = MCPServer(tools=[]) + watcher = PluginWatcher(server, str(tmp_path), poll_seconds=0.1) + watcher.poll_once() + first = server._tools["plugin_ac_evolve"].handler + # Rewrite with a new function body and bump mtime. + _write(plugin, "def AC_evolve():\n return 2\n") + watcher.poll_once() + second = server._tools["plugin_ac_evolve"].handler + assert first is not second + + +def test_watcher_start_requires_existing_directory(tmp_path): + server = MCPServer(tools=[]) + watcher = PluginWatcher(server, str(tmp_path / "ghost")) + try: + watcher.start() + except NotADirectoryError: + pass + else: + watcher.stop(timeout=0.5) + raise AssertionError("expected NotADirectoryError") diff --git a/test/unit_test/headless/test_mcp_server.py b/test/unit_test/headless/test_mcp_server.py new file mode 100644 index 0000000..ba6a963 --- /dev/null +++ b/test/unit_test/headless/test_mcp_server.py @@ -0,0 +1,1703 @@ +"""Headless tests for the MCP stdio server. + +These tests exercise the JSON-RPC dispatcher directly and via an +in-memory pipe so no real stdin/stdout, no Qt, and no platform +backends are needed. +""" +import io +import json +import math +import os +import threading +from typing import Any, Dict, List + +from je_auto_control.utils.mcp_server.context import ( + OperationCancelledError, ToolCallContext, +) +from je_auto_control.utils.mcp_server.prompts import ( + MCPPrompt, MCPPromptArgument, StaticPromptProvider, + default_prompt_catalogue, +) +from je_auto_control.utils.mcp_server.resources import ( + ChainProvider, FileSystemProvider, MCPResource, ResourceProvider, +) +from je_auto_control.utils.mcp_server.server import ( + PROTOCOL_VERSION, MCPServer, +) +from je_auto_control.utils.mcp_server.tools import ( + MCPContent, MCPTool, MCPToolAnnotations, build_default_tool_registry, +) + + +def _request(method: str, msg_id: int = 1, + params: Dict[str, Any] = None) -> str: + payload: Dict[str, Any] = {"jsonrpc": "2.0", "id": msg_id, + "method": method} + if params is not None: + payload["params"] = params + return json.dumps(payload) + + +def _decode(line: str) -> Dict[str, Any]: + return json.loads(line) + + +def test_initialize_echoes_protocol_version(): + server = MCPServer(tools=[]) + response = _decode(server.handle_line(_request("initialize", params={ + "protocolVersion": "2024-11-05", + "capabilities": {}, "clientInfo": {"name": "pytest"}, + }))) + result = response["result"] + assert result["protocolVersion"] == "2024-11-05" + assert result["serverInfo"]["name"] == "je_auto_control" + assert "tools" in result["capabilities"] + + +def test_initialize_falls_back_to_server_default(): + server = MCPServer(tools=[]) + response = _decode(server.handle_line( + _request("initialize", params={}) + )) + assert response["result"]["protocolVersion"] == PROTOCOL_VERSION + + +def test_tools_list_returns_registered_tool_descriptors(): + tool = MCPTool( + name="echo", description="echo args back", + input_schema={"type": "object", "properties": {}}, + handler=lambda **kwargs: kwargs, + ) + server = MCPServer(tools=[tool]) + response = _decode(server.handle_line(_request("tools/list"))) + descriptors = response["result"]["tools"] + assert len(descriptors) == 1 + assert descriptors[0]["name"] == "echo" + assert descriptors[0]["inputSchema"] == {"type": "object", + "properties": {}} + + +def test_tools_call_invokes_handler_and_serialises_result(): + tool = MCPTool( + name="add", description="add two ints", + input_schema={ + "type": "object", + "properties": {"a": {"type": "integer"}, + "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + handler=lambda a, b: a + b, + ) + server = MCPServer(tools=[tool]) + response = _decode(server.handle_line(_request("tools/call", params={ + "name": "add", "arguments": {"a": 2, "b": 3}, + }))) + result = response["result"] + assert result["isError"] is False + assert result["content"][0]["text"] == "5" + + +def test_tools_call_unknown_tool_reports_error(): + server = MCPServer(tools=[]) + response = _decode(server.handle_line(_request("tools/call", params={ + "name": "missing", "arguments": {}, + }))) + assert response["error"]["code"] == -32602 + assert "Unknown tool" in response["error"]["message"] + + +def test_tools_call_handler_exception_returns_is_error(): + def boom(**_kwargs): + raise ValueError("nope") + + tool = MCPTool(name="boom", description="fail", + input_schema={"type": "object", "properties": {}}, + handler=boom) + server = MCPServer(tools=[tool]) + response = _decode(server.handle_line(_request("tools/call", params={ + "name": "boom", "arguments": {}, + }))) + result = response["result"] + assert result["isError"] is True + assert "ValueError" in result["content"][0]["text"] + + +def test_unknown_method_returns_method_not_found(): + server = MCPServer(tools=[]) + response = _decode(server.handle_line(_request("does/not/exist"))) + assert response["error"]["code"] == -32601 + + +def test_parse_error_returned_for_bad_json(): + server = MCPServer(tools=[]) + response = _decode(server.handle_line("{not json")) + assert response["error"]["code"] == -32700 + assert response["id"] is None + + +def test_notification_initialized_sets_state_and_returns_none(): + server = MCPServer(tools=[]) + notification = json.dumps({"jsonrpc": "2.0", + "method": "notifications/initialized"}) + assert server.handle_line(notification) is None + assert server._initialized is True # noqa: SLF001 # reason: white-box check + + +def test_serve_stdio_processes_messages_until_eof(): + tool = MCPTool(name="ping_tool", description="ping", + input_schema={"type": "object", "properties": {}}, + handler=lambda: "pong") + server = MCPServer(tools=[tool]) + stdin_lines: List[str] = [ + _request("initialize", msg_id=1, params={"protocolVersion": "x"}), + json.dumps({"jsonrpc": "2.0", + "method": "notifications/initialized"}), + _request("tools/call", msg_id=2, + params={"name": "ping_tool", "arguments": {}}), + ] + stdin = io.StringIO("\n".join(stdin_lines) + "\n") + stdout = io.StringIO() + server.serve_stdio(stdin=stdin, stdout=stdout) + out_lines = [line for line in stdout.getvalue().splitlines() if line] + responses = [_decode(line) for line in out_lines + if '"id":' in line and '"method"' not in line] + assert len(responses) == 2 # initialize + tools/call + assert responses[-1]["result"]["content"][0]["text"] == "pong" + + +def test_tool_descriptor_includes_annotations(): + annotations = MCPToolAnnotations(title="Echo", read_only=True, + idempotent=True) + tool = MCPTool( + name="echo", description="echo", + input_schema={"type": "object", "properties": {}}, + handler=lambda: "ok", + annotations=annotations, + ) + descriptor = tool.to_descriptor() + assert descriptor["annotations"] == { + "readOnlyHint": True, + "destructiveHint": False, + "idempotentHint": True, + "openWorldHint": True, + "title": "Echo", + } + + +def test_read_only_annotation_forces_destructive_false(): + """Per spec, destructiveHint is meaningful only if readOnlyHint is false.""" + annotations = MCPToolAnnotations(read_only=True, destructive=True) + assert annotations.to_dict()["destructiveHint"] is False + + +def test_default_tool_registry_marks_safe_tools_read_only(): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + assert by_name["ac_get_mouse_position"].annotations.read_only is True + assert by_name["ac_screen_size"].annotations.read_only is True + assert by_name["ac_list_action_commands"].annotations.read_only is True + # Side-effecting tools must NOT claim read-only. + assert by_name["ac_click_mouse"].annotations.read_only is False + assert by_name["ac_type_text"].annotations.read_only is False + + +def test_read_only_registry_drops_destructive_tools(): + safe = build_default_tool_registry(read_only=True) + assert safe, "expected at least one read-only tool" + assert all(tool.annotations.read_only for tool in safe) + safe_names = {tool.name for tool in safe} + assert "ac_click_mouse" not in safe_names + assert "ac_type_text" not in safe_names + assert "ac_execute_actions" not in safe_names + # Pure observers must survive. + assert {"ac_get_mouse_position", "ac_screen_size", + "ac_list_action_commands"}.issubset(safe_names) + + +def test_read_only_env_var_is_honored(monkeypatch): + monkeypatch.setenv("JE_AUTOCONTROL_MCP_READONLY", "1") + safe = build_default_tool_registry() + assert all(tool.annotations.read_only for tool in safe) + + +def test_read_only_env_var_disabled_when_unset(monkeypatch): + monkeypatch.delenv("JE_AUTOCONTROL_MCP_READONLY", raising=False) + full = build_default_tool_registry() + assert any(not tool.annotations.read_only for tool in full) + + +def test_mcp_content_image_block_serialises_to_mcp_shape(): + content = MCPContent.image_block("AAAA", mime_type="image/jpeg") + assert content.to_dict() == { + "type": "image", "data": "AAAA", "mimeType": "image/jpeg", + } + + +def test_mcp_content_text_block_serialises_to_mcp_shape(): + assert MCPContent.text_block("hello").to_dict() == { + "type": "text", "text": "hello", + } + + +def test_tools_call_returns_image_content_when_handler_yields_image(): + image_payload = MCPContent.image_block("ZmFrZQ==") # base64 "fake" + tool = MCPTool( + name="snap", description="snap", + input_schema={"type": "object", "properties": {}}, + handler=lambda: image_payload, + ) + server = MCPServer(tools=[tool]) + response = _decode(server.handle_line(_request("tools/call", params={ + "name": "snap", "arguments": {}, + }))) + blocks = response["result"]["content"] + assert blocks == [{"type": "image", "data": "ZmFrZQ==", + "mimeType": "image/png"}] + + +def test_tools_call_passes_through_multi_content_lists(): + payload = [MCPContent.image_block("Zm9v"), + MCPContent.text_block("saved: /tmp/x.png")] + tool = MCPTool( + name="snap2", description="snap2", + input_schema={"type": "object", "properties": {}}, + handler=lambda: payload, + ) + server = MCPServer(tools=[tool]) + response = _decode(server.handle_line(_request("tools/call", params={ + "name": "snap2", "arguments": {}, + }))) + blocks = response["result"]["content"] + assert len(blocks) == 2 + assert blocks[0]["type"] == "image" + assert blocks[1]["type"] == "text" + + +def test_recording_tools_present_in_default_registry(): + names = {tool.name for tool in build_default_tool_registry()} + assert { + "ac_record_start", "ac_record_stop", + "ac_read_action_file", "ac_write_action_file", + "ac_trim_actions", "ac_adjust_delays", "ac_scale_coordinates", + }.issubset(names) + + +def test_trim_actions_tool_returns_subset(tmp_path): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + actions = [["AC_a", {}], ["AC_b", {}], ["AC_c", {}], ["AC_d", {}]] + trimmed = by_name["ac_trim_actions"].invoke( + {"actions": actions, "start": 1, "end": 3} + ) + assert trimmed == [["AC_b", {}], ["AC_c", {}]] + + +def test_scale_coordinates_tool_scales_x_y(): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + actions = [["AC_click_mouse", {"x": 100, "y": 200, + "mouse_keycode": "mouse_left"}]] + scaled = by_name["ac_scale_coordinates"].invoke( + {"actions": actions, "x_factor": 2.0, "y_factor": 0.5} + ) + assert scaled[0][1]["x"] == 200 + assert scaled[0][1]["y"] == 100 + + +def test_write_and_read_action_file_round_trip(tmp_path): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + target = tmp_path / "actions.json" + actions = [["AC_click_mouse", {"mouse_keycode": "mouse_left"}]] + saved_path = by_name["ac_write_action_file"].invoke( + {"file_path": str(target), "actions": actions} + ) + assert target.exists() + loaded = by_name["ac_read_action_file"].invoke({"file_path": saved_path}) + assert loaded == actions + + +def test_drag_and_send_tools_present_in_default_registry(): + names = {tool.name for tool in build_default_tool_registry()} + assert {"ac_drag", "ac_send_key_to_window", + "ac_send_mouse_to_window"}.issubset(names) + + +def test_drag_tool_calls_press_move_release_in_order(monkeypatch): + """ac_drag must press at the start, move, then release at the end.""" + by_name = {tool.name: tool for tool in build_default_tool_registry()} + calls = [] + + def fake_set(x, y): + calls.append(("set", int(x), int(y))) + return (int(x), int(y)) + + def fake_press(keycode, x=None, y=None): + calls.append(("press", keycode, x, y)) + return (keycode, x, y) + + def fake_release(keycode, x=None, y=None): + calls.append(("release", keycode, x, y)) + return (keycode, x, y) + + import je_auto_control.wrapper.auto_control_mouse as mouse_module + monkeypatch.setattr(mouse_module, "set_mouse_position", fake_set) + monkeypatch.setattr(mouse_module, "press_mouse", fake_press) + monkeypatch.setattr(mouse_module, "release_mouse", fake_release) + + result = by_name["ac_drag"].invoke({ + "start_x": 10, "start_y": 20, + "end_x": 100, "end_y": 200, + "mouse_keycode": "mouse_left", + }) + assert result == [100, 200] + sequence = [step[0] for step in calls] + # set start → press → set end → release + assert sequence == ["set", "press", "set", "release"] + assert calls[0][1:] == (10, 20) + assert calls[2][1:] == (100, 200) + + +def test_semantic_locator_tools_present_in_default_registry(): + names = {tool.name for tool in build_default_tool_registry()} + assert {"ac_a11y_list", "ac_a11y_find", "ac_a11y_click", + "ac_vlm_locate", "ac_vlm_click"}.issubset(names) + + +def test_a11y_find_tool_returns_none_when_backend_has_no_match(monkeypatch): + """Find delegates to the headless API; null result must round-trip.""" + by_name = {tool.name: tool for tool in build_default_tool_registry()} + import je_auto_control.utils.accessibility.accessibility_api as api + monkeypatch.setattr(api, "find_accessibility_element", + lambda **_kwargs: None) + assert by_name["ac_a11y_find"].invoke({"name": "ghost"}) is None + + +def test_automation_tools_present_in_default_registry(): + names = {tool.name for tool in build_default_tool_registry()} + assert { + "ac_scheduler_add_job", "ac_scheduler_remove_job", + "ac_scheduler_list_jobs", "ac_scheduler_start", "ac_scheduler_stop", + "ac_trigger_add", "ac_trigger_remove", "ac_trigger_list", + "ac_trigger_start", "ac_trigger_stop", + "ac_hotkey_bind", "ac_hotkey_unbind", "ac_hotkey_list", + "ac_hotkey_daemon_start", "ac_hotkey_daemon_stop", + }.issubset(names) + + +def test_scheduler_add_job_round_trips_through_default_scheduler(tmp_path): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + script = tmp_path / "noop.json" + script.write_text("[]", encoding="utf-8") + record = by_name["ac_scheduler_add_job"].invoke({ + "script_path": str(script), "interval_seconds": 60.0, + "repeat": False, "job_id": "test_mcp_job", + }) + try: + assert record["job_id"] == "test_mcp_job" + assert math.isclose(record["interval_seconds"], 60.0) + listed = {job["job_id"] for job in + by_name["ac_scheduler_list_jobs"].invoke({})} + assert "test_mcp_job" in listed + finally: + by_name["ac_scheduler_remove_job"].invoke({"job_id": "test_mcp_job"}) + + +def test_trigger_add_image_kind_records_trigger(tmp_path): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + from je_auto_control.utils.triggers.trigger_engine import default_trigger_engine + script = tmp_path / "noop.json" + script.write_text("[]", encoding="utf-8") + image = tmp_path / "hit.png" + image.write_bytes(b"\x89PNG\r\n\x1a\n") + record = by_name["ac_trigger_add"].invoke({ + "kind": "image", "script_path": str(script), + "image_path": str(image), "threshold": 0.5, + }) + try: + assert record["type"] == "ImageAppearsTrigger" + assert any(t.trigger_id == record["trigger_id"] + for t in default_trigger_engine.list_triggers()) + finally: + by_name["ac_trigger_remove"].invoke( + {"trigger_id": record["trigger_id"]} + ) + + +def test_trigger_add_rejects_unknown_kind(tmp_path): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + script = tmp_path / "noop.json" + script.write_text("[]", encoding="utf-8") + import pytest + with pytest.raises(ValueError): + by_name["ac_trigger_add"].invoke({ + "kind": "telepathy", "script_path": str(script), + }) + + +class _StaticProvider(ResourceProvider): + """Test double — exposes one fixed resource.""" + + def __init__(self, resource: MCPResource, body: str) -> None: + self._resource = resource + self._body = body + + def list(self): + return [self._resource] + + def read(self, uri): + if uri != self._resource.uri: + return None + return {"uri": uri, "mimeType": self._resource.mime_type or "text/plain", + "text": self._body} + + +def test_initialize_advertises_resources_capability(): + server = MCPServer(tools=[], resource_provider=ChainProvider([])) + response = _decode(server.handle_line(_request("initialize", params={}))) + assert "resources" in response["result"]["capabilities"] + + +def test_resources_list_returns_provider_descriptors(): + resource = MCPResource(uri="autocontrol://demo", + name="demo", + description="static demo", + mime_type="text/plain") + server = MCPServer(tools=[], + resource_provider=_StaticProvider(resource, "hi")) + response = _decode(server.handle_line(_request("resources/list"))) + descriptors = response["result"]["resources"] + assert descriptors == [{ + "uri": "autocontrol://demo", "name": "demo", + "description": "static demo", "mimeType": "text/plain", + }] + + +def test_resources_read_returns_provider_content(): + resource = MCPResource(uri="autocontrol://demo", name="demo", + mime_type="text/plain") + server = MCPServer(tools=[], + resource_provider=_StaticProvider(resource, "hello")) + response = _decode(server.handle_line(_request("resources/read", params={ + "uri": "autocontrol://demo", + }))) + contents = response["result"]["contents"] + assert contents == [{"uri": "autocontrol://demo", + "mimeType": "text/plain", "text": "hello"}] + + +def test_resources_read_unknown_uri_returns_invalid_params(): + server = MCPServer(tools=[], resource_provider=ChainProvider([])) + response = _decode(server.handle_line(_request("resources/read", params={ + "uri": "autocontrol://missing", + }))) + assert response["error"]["code"] == -32602 + assert "Unknown resource" in response["error"]["message"] + + +def test_filesystem_provider_lists_action_files(tmp_path): + (tmp_path / "alpha.json").write_text("[]", encoding="utf-8") + (tmp_path / "ignore.txt").write_text("nope", encoding="utf-8") + (tmp_path / "beta.json").write_text("[]", encoding="utf-8") + provider = FileSystemProvider(root=str(tmp_path)) + listed = provider.list() + names = sorted(item.name for item in listed) + assert names == ["alpha.json", "beta.json"] + body = provider.read(listed[0].uri) + assert body is not None + assert body["mimeType"] == "application/json" + + +def test_filesystem_provider_rejects_path_traversal(tmp_path): + (tmp_path / "alpha.json").write_text("[]", encoding="utf-8") + provider = FileSystemProvider(root=str(tmp_path)) + assert provider.read("autocontrol://files/../etc/passwd") is None + assert provider.read("autocontrol://files/.hidden") is None + + +def test_initialize_advertises_prompts_capability(): + server = MCPServer(tools=[], + prompt_provider=StaticPromptProvider([])) + response = _decode(server.handle_line(_request("initialize", params={}))) + assert "prompts" in response["result"]["capabilities"] + + +def test_prompts_list_returns_descriptors(): + prompt = MCPPrompt( + name="hello", description="say hi", + arguments=[MCPPromptArgument("name", required=True)], + render=lambda args: f"hi {args['name']}", + ) + server = MCPServer(tools=[], + prompt_provider=StaticPromptProvider([prompt])) + response = _decode(server.handle_line(_request("prompts/list"))) + descriptors = response["result"]["prompts"] + assert descriptors == [{ + "name": "hello", "description": "say hi", + "arguments": [{"name": "name", "required": True}], + }] + + +def test_prompts_get_renders_message_with_arguments(): + prompt = MCPPrompt( + name="hello", description="say hi", + arguments=[MCPPromptArgument("name", required=True)], + render=lambda args: f"hi {args['name']}", + ) + server = MCPServer(tools=[], + prompt_provider=StaticPromptProvider([prompt])) + response = _decode(server.handle_line(_request("prompts/get", params={ + "name": "hello", "arguments": {"name": "Jeff"}, + }))) + payload = response["result"] + assert payload["messages"][0]["content"]["text"] == "hi Jeff" + + +def test_prompts_get_unknown_name_returns_error(): + server = MCPServer(tools=[], + prompt_provider=StaticPromptProvider([])) + response = _decode(server.handle_line(_request("prompts/get", params={ + "name": "missing", + }))) + assert response["error"]["code"] == -32602 + assert "Unknown prompt" in response["error"]["message"] + + +def test_prompts_get_missing_required_arg_returns_error(): + prompt = MCPPrompt( + name="hello", description="say hi", + arguments=[MCPPromptArgument("name", required=True)], + render=lambda args: f"hi {args['name']}", + ) + server = MCPServer(tools=[], + prompt_provider=StaticPromptProvider([prompt])) + response = _decode(server.handle_line(_request("prompts/get", params={ + "name": "hello", "arguments": {}, + }))) + assert response["error"]["code"] == -32602 + + +def test_default_prompt_catalogue_has_core_templates(): + names = {prompt.name for prompt in default_prompt_catalogue()} + assert {"automate_ui_task", "record_and_generalize", + "compare_screenshots", "find_widget", + "explain_action_file"}.issubset(names) + + +def test_progress_notifications_are_sent_when_token_provided(): + captured = [] + + def slow_handler(seconds, ctx): + del seconds + ctx.progress(0.0, total=1.0, message="starting") + ctx.progress(1.0, total=1.0, message="done") + return "ok" + + tool = MCPTool( + name="slow", description="slow", + input_schema={"type": "object", "properties": { + "seconds": {"type": "number"}}}, + handler=slow_handler, + ) + server = MCPServer(tools=[tool]) + server.set_notifier(lambda method, params: captured.append((method, params))) + response = _decode(server.handle_line(_request("tools/call", params={ + "name": "slow", "arguments": {"seconds": 0.0}, + "_meta": {"progressToken": "tok-1"}, + }))) + assert response["result"]["isError"] is False + methods = [event[0] for event in captured] + assert methods == ["notifications/progress", + "notifications/progress"] + assert captured[0][1] == {"progressToken": "tok-1", + "progress": 0.0, "total": 1.0, + "message": "starting"} + + +def test_progress_is_no_op_without_token(): + captured = [] + + def handler(ctx): + ctx.progress(0.5, total=1.0) + return "ok" + + tool = MCPTool( + name="silent", description="silent", + input_schema={"type": "object", "properties": {}}, + handler=handler, + ) + server = MCPServer(tools=[tool]) + server.set_notifier(lambda method, params: captured.append((method, params))) + response = _decode(server.handle_line(_request("tools/call", params={ + "name": "silent", "arguments": {}, + }))) + assert response["result"]["isError"] is False + assert captured == [] + + +def test_cancellation_notification_sets_context_flag(): + started = threading.Event() + proceed = threading.Event() + seen_cancel = [] + + def slow_handler(ctx): + started.set() + proceed.wait(timeout=2.0) + seen_cancel.append(ctx.cancelled) + ctx.check_cancelled() + return "ok" + + tool = MCPTool( + name="cancellable", description="cancellable", + input_schema={"type": "object", "properties": {}}, + handler=slow_handler, + ) + server = MCPServer(tools=[tool]) + + response_holder = {} + + def run_call(): + response_holder["raw"] = server.handle_line(_request( + "tools/call", msg_id=42, + params={"name": "cancellable", "arguments": {}}, + )) + + thread = threading.Thread(target=run_call) + thread.start() + assert started.wait(timeout=1.0) + server.handle_line(json.dumps({ + "jsonrpc": "2.0", "method": "notifications/cancelled", + "params": {"requestId": 42, "reason": "user clicked stop"}, + })) + proceed.set() + thread.join(timeout=2.0) + assert not thread.is_alive() + assert seen_cancel == [True] + decoded = _decode(response_holder["raw"]) + assert decoded["error"]["code"] == -32800 + + +def test_tool_call_context_check_cancelled_raises(): + ctx = ToolCallContext(request_id=7, progress_token=None) + ctx.check_cancelled() # not cancelled — no raise + ctx.cancelled_event.set() + try: + ctx.check_cancelled() + except OperationCancelledError as error: + assert error.request_id == 7 + else: + raise AssertionError("expected OperationCancelledError") + + +def test_request_sampling_round_trips_via_writer(): + """Tool calls sampling; we play the client and reply with a result.""" + captured_lines = [] + + def handler(prompt, ctx): + del ctx + reply = server.request_sampling( + messages=[{"role": "user", + "content": {"type": "text", "text": prompt}}], + max_tokens=64, + ) + return reply["content"]["text"] + + tool = MCPTool( + name="ask_model", description="ask", + input_schema={"type": "object", "properties": { + "prompt": {"type": "string"}}, "required": ["prompt"]}, + handler=handler, + ) + server = MCPServer(tools=[tool], concurrent_tools=True) + server.set_writer(captured_lines.append) + + server.handle_line(_request("tools/call", msg_id=10, params={ + "name": "ask_model", "arguments": {"prompt": "ping?"}, + })) + + # The worker is now blocked on sampling; wait for the outbound request. + deadline = threading.Event() + for _ in range(200): + if any('"sampling/createMessage"' in line for line in captured_lines): + break + deadline.wait(0.01) + sampling_lines = [line for line in captured_lines + if '"sampling/createMessage"' in line] + assert sampling_lines, "expected outbound sampling request" + sampling_request = json.loads(sampling_lines[-1]) + assert sampling_request["method"] == "sampling/createMessage" + sampling_id = sampling_request["id"] + + server.handle_line(json.dumps({ + "jsonrpc": "2.0", "id": sampling_id, + "result": {"role": "assistant", "model": "test-model", + "content": {"type": "text", "text": "pong"}}, + })) + + for _ in range(200): + if any('"id": 10' in line for line in captured_lines): + break + deadline.wait(0.01) + final_lines = [line for line in captured_lines if '"id": 10' in line] + assert final_lines, "expected tools/call reply on writer" + final = json.loads(final_lines[-1]) + assert final["result"]["isError"] is False + assert final["result"]["content"][0]["text"] == "pong" + + +def test_request_sampling_without_writer_raises(): + server = MCPServer(tools=[], concurrent_tools=True) + try: + server.request_sampling(messages=[ + {"role": "user", "content": {"type": "text", "text": "hi"}} + ]) + except RuntimeError as error: + assert "writer" in str(error) + else: + raise AssertionError("expected RuntimeError") + + +def test_initialize_advertises_sampling_capability(): + server = MCPServer(tools=[]) + response = _decode(server.handle_line(_request("initialize", params={}))) + assert "sampling" in response["result"]["capabilities"] + + +def test_tools_call_rejects_missing_required_field(): + tool = MCPTool( + name="needs_x", description="needs x", + input_schema={ + "type": "object", + "properties": {"x": {"type": "integer"}}, + "required": ["x"], + }, + handler=lambda x: x * 2, + ) + server = MCPServer(tools=[tool]) + response = _decode(server.handle_line(_request("tools/call", params={ + "name": "needs_x", "arguments": {}, + }))) + assert response["error"]["code"] == -32602 + assert "missing required property 'x'" in response["error"]["message"] + + +def test_tools_call_rejects_wrong_type(): + tool = MCPTool( + name="needs_int", description="needs int", + input_schema={ + "type": "object", + "properties": {"x": {"type": "integer"}}, + "required": ["x"], + }, + handler=lambda x: x, + ) + server = MCPServer(tools=[tool]) + response = _decode(server.handle_line(_request("tools/call", params={ + "name": "needs_int", "arguments": {"x": "not-int"}, + }))) + assert response["error"]["code"] == -32602 + assert "expected integer" in response["error"]["message"] + + +def test_tools_call_rejects_value_outside_enum(): + tool = MCPTool( + name="enum_only", description="enum", + input_schema={ + "type": "object", + "properties": {"mode": {"type": "string", + "enum": ["a", "b"]}}, + "required": ["mode"], + }, + handler=lambda mode: mode, + ) + server = MCPServer(tools=[tool]) + response = _decode(server.handle_line(_request("tools/call", params={ + "name": "enum_only", "arguments": {"mode": "c"}, + }))) + assert response["error"]["code"] == -32602 + + +def test_tools_call_passes_valid_args(): + tool = MCPTool( + name="adder", description="adder", + input_schema={ + "type": "object", + "properties": {"x": {"type": "integer"}, + "y": {"type": "integer"}}, + "required": ["x", "y"], + }, + handler=lambda x, y: x + y, + ) + server = MCPServer(tools=[tool]) + response = _decode(server.handle_line(_request("tools/call", params={ + "name": "adder", "arguments": {"x": 1, "y": 2}, + }))) + assert response["result"]["content"][0]["text"] == "3" + + +def test_initialize_advertises_tools_list_changed(): + server = MCPServer(tools=[]) + response = _decode(server.handle_line(_request("initialize", params={}))) + assert response["result"]["capabilities"]["tools"]["listChanged"] is True + + +def test_register_tool_emits_list_changed_notification(): + captured = [] + server = MCPServer(tools=[]) + server.set_notifier(lambda method, params: captured.append((method, params))) + new_tool = MCPTool( + name="late", description="late", + input_schema={"type": "object", "properties": {}}, + handler=lambda: "ok", + ) + server.register_tool(new_tool) + assert ("notifications/tools/list_changed", {}) in captured + + +def test_unregister_tool_emits_list_changed_notification(): + tool = MCPTool( + name="vanish", description="vanish", + input_schema={"type": "object", "properties": {}}, + handler=lambda: "ok", + ) + captured = [] + server = MCPServer(tools=[tool]) + server.set_notifier(lambda method, params: captured.append((method, params))) + assert server.unregister_tool("vanish") is True + assert ("notifications/tools/list_changed", {}) in captured + assert server.unregister_tool("vanish") is False + + +def test_make_plugin_tool_derives_schema_from_signature(): + from je_auto_control.utils.mcp_server.tools.plugin_tools import ( + make_plugin_tool, + ) + + def AC_demo(text: str, count: int = 1) -> str: # noqa: N802 # NOSONAR S1542 reason: AutoControl plugin convention requires the AC_ prefix + """Demo plugin command.""" + return text * count + + tool = make_plugin_tool("AC_demo", AC_demo) + assert tool.name == "plugin_ac_demo" + assert tool.description.startswith("Demo plugin command") + assert tool.input_schema["properties"]["text"] == {"type": "string"} + assert tool.input_schema["properties"]["count"] == {"type": "integer"} + assert tool.input_schema.get("required") == ["text"] + + +def test_register_plugin_tools_adds_to_server_and_notifies(): + from je_auto_control.utils.mcp_server.tools.plugin_tools import ( + register_plugin_tools, + ) + + def AC_one(value: str) -> str: # noqa: N802 # NOSONAR S1542 reason: AutoControl plugin convention requires the AC_ prefix + return value.upper() + + captured = [] + server = MCPServer(tools=[]) + server.set_notifier(lambda method, params: captured.append((method, params))) + names = register_plugin_tools(server, {"AC_one": AC_one}) + assert names == ["plugin_ac_one"] + response = _decode(server.handle_line(_request("tools/call", params={ + "name": "plugin_ac_one", "arguments": {"value": "hi"}, + }))) + assert response["result"]["content"][0]["text"] == "HI" + assert any(method == "notifications/tools/list_changed" + for method, _ in captured) + + +def test_diff_screenshots_finds_changed_region(tmp_path): + pytest = __import__("pytest") + np = pytest.importorskip("numpy") + pil_image = pytest.importorskip("PIL.Image") + + base = np.zeros((40, 60, 3), dtype="uint8") + other = base.copy() + other[10:20, 30:50, 0] = 255 # paint a 20x10 red rectangle + + path_a = tmp_path / "a.png" + path_b = tmp_path / "b.png" + pil_image.fromarray(base).save(path_a) + pil_image.fromarray(other).save(path_b) + + by_name = {tool.name: tool for tool in build_default_tool_registry()} + result = by_name["ac_diff_screenshots"].invoke({ + "image_path_a": str(path_a), + "image_path_b": str(path_b), + "threshold": 10, "min_box_pixels": 4, + }) + assert result["size"] == [60, 40] + assert result["boxes"], "expected at least one diff region" + # Bounding box should contain the painted rectangle (30..49, 10..19). + box = result["boxes"][0] + assert box[0] <= 30 and box[1] <= 10 + assert box[0] + box[2] >= 50 + assert box[1] + box[3] >= 20 + + +def test_diff_screenshots_returns_no_boxes_when_identical(tmp_path): + pytest = __import__("pytest") + np = pytest.importorskip("numpy") + pil_image = pytest.importorskip("PIL.Image") + img = np.full((20, 20, 3), 200, dtype="uint8") + path = tmp_path / "same.png" + pil_image.fromarray(img).save(path) + + by_name = {tool.name: tool for tool in build_default_tool_registry()} + result = by_name["ac_diff_screenshots"].invoke({ + "image_path_a": str(path), "image_path_b": str(path), + }) + assert result["boxes"] == [] + + +def test_screen_recording_tools_present_in_default_registry(): + names = {tool.name for tool in build_default_tool_registry()} + assert {"ac_screen_record_start", "ac_screen_record_stop", + "ac_screen_record_list"}.issubset(names) + + +def test_screen_record_start_validates_directory(tmp_path): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + missing = tmp_path / "nope" / "out.avi" + try: + by_name["ac_screen_record_start"].invoke({ + "recorder_name": "rec1", "file_path": str(missing), + }) + except ValueError as error: + assert "directory does not exist" in str(error) + else: + raise AssertionError("expected ValueError for missing dir") + + +def test_screen_record_list_starts_empty(monkeypatch): + """Force a fresh recorder so leftover state from other tests doesn't bleed in.""" + import je_auto_control.utils.mcp_server.tools._handlers as handlers + monkeypatch.setattr(handlers, "_screen_recorder_singleton", None) + by_name = {tool.name: tool for tool in build_default_tool_registry()} + assert by_name["ac_screen_record_list"].invoke({}) == [] + + +def test_list_monitors_returns_at_least_one_entry(): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + monitors = by_name["ac_list_monitors"].invoke({}) + assert isinstance(monitors, list) + assert monitors # mss always reports at least the virtual desktop + first = monitors[0] + assert first["index"] == 0 + assert first["is_combined"] is True + for key in ("left", "top", "width", "height"): + assert isinstance(first[key], int) + + +def test_screenshot_rejects_invalid_monitor_index(): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + try: + by_name["ac_screenshot"].invoke({"monitor_index": 999}) + except ValueError as error: + assert "out of range" in str(error) + else: + raise AssertionError("expected ValueError for bad monitor index") + + +def test_clipboard_image_tools_present_in_default_registry(): + names = {tool.name for tool in build_default_tool_registry()} + assert {"ac_get_clipboard_image", "ac_set_clipboard_image"}.issubset(names) + + +def test_get_clipboard_image_returns_text_block_when_empty(monkeypatch): + """When the clipboard has no image, return a clear text fallback.""" + import je_auto_control.utils.mcp_server.tools._handlers as handlers + import je_auto_control.utils.clipboard.clipboard_image as image_clip + monkeypatch.setattr(image_clip, "get_clipboard_image", lambda: None) + result = handlers.get_clipboard_image() + assert result[0].type == "text" + assert "does not contain an image" in result[0].text + + +def test_get_clipboard_image_returns_image_block_when_set(monkeypatch): + import je_auto_control.utils.mcp_server.tools._handlers as handlers + import je_auto_control.utils.clipboard.clipboard_image as image_clip + monkeypatch.setattr(image_clip, "get_clipboard_image", + lambda: b"\x89PNG\r\n\x1a\n") + result = handlers.get_clipboard_image() + assert result[0].type == "image" + assert result[0].mime_type == "image/png" + + +def test_set_clipboard_image_validates_existence(tmp_path): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + missing = tmp_path / "nope.png" + try: + by_name["ac_set_clipboard_image"].invoke({ + "image_path": str(missing), + }) + except FileNotFoundError as error: + assert "image not found" in str(error) + else: + raise AssertionError("expected FileNotFoundError") + + +def test_audit_logger_records_successful_tool_call(tmp_path): + from je_auto_control.utils.mcp_server.audit import AuditLogger + audit_path = tmp_path / "audit.jsonl" + audit = AuditLogger(path=str(audit_path)) + tool = MCPTool( + name="audited", description="audited", + input_schema={"type": "object", + "properties": {"x": {"type": "integer"}}, + "required": ["x"]}, + handler=lambda x: x * 2, + ) + server = MCPServer(tools=[tool], audit_logger=audit) + server.handle_line(_request("tools/call", params={ + "name": "audited", "arguments": {"x": 21}, + })) + lines = audit_path.read_text(encoding="utf-8").strip().splitlines() + assert len(lines) == 1 + record = json.loads(lines[0]) + assert record["tool"] == "audited" + assert record["status"] == "ok" + assert record["arguments"] == {"x": 21} + assert "duration_seconds" in record + + +def test_audit_logger_records_errors(tmp_path): + from je_auto_control.utils.mcp_server.audit import AuditLogger + + def raises(x): # pragma: no cover - called via tool + del x + raise ValueError("bad") + + audit = AuditLogger(path=str(tmp_path / "errs.jsonl")) + tool = MCPTool( + name="boom", description="boom", + input_schema={"type": "object", + "properties": {"x": {"type": "integer"}}, + "required": ["x"]}, + handler=raises, + ) + server = MCPServer(tools=[tool], audit_logger=audit) + server.handle_line(_request("tools/call", params={ + "name": "boom", "arguments": {"x": 1}, + })) + record = json.loads(audit.path and open(audit.path, encoding="utf-8").readline()) + assert record["status"] == "error" + assert "ValueError" in record["error"] + + +def test_audit_logger_redacts_sensitive_keys(tmp_path): + from je_auto_control.utils.mcp_server.audit import ( + AuditLogger, REDACTED_KEYS, REDACTED_PLACEHOLDER, + ) + sensitive_key = next(iter(REDACTED_KEYS)) + fake_value = "test-only-fixture-value" # NOSONAR S2068 reason: redaction test fixture, not a real credential + audit = AuditLogger(path=str(tmp_path / "audit.jsonl")) + tool = MCPTool( + name="creds", description="creds", + input_schema={"type": "object", + "properties": {sensitive_key: {"type": "string"}, + "user": {"type": "string"}}, + "required": [sensitive_key, "user"]}, + handler=lambda **kwargs: "ok", + ) + server = MCPServer(tools=[tool], audit_logger=audit) + server.handle_line(_request("tools/call", params={ + "name": "creds", + "arguments": {sensitive_key: fake_value, "user": "jeff"}, + })) + record = json.loads(open(audit.path, encoding="utf-8").readline()) + assert record["arguments"] == {sensitive_key: REDACTED_PLACEHOLDER, + "user": "jeff"} + + +def test_audit_logger_disabled_when_no_path(): + from je_auto_control.utils.mcp_server.audit import AuditLogger + audit = AuditLogger(path=None) + assert audit.enabled is False + audit.record(tool="x", arguments={}, status="ok", + duration_seconds=0.0) # must not raise + + +def test_rate_limiter_blocks_when_capacity_exhausted(): + from je_auto_control.utils.mcp_server.rate_limit import RateLimiter + limiter = RateLimiter(rate_per_sec=0.0001, capacity=2) + tool = MCPTool( + name="counted", description="counted", + input_schema={"type": "object", "properties": {}}, + handler=lambda: "ok", + ) + server = MCPServer(tools=[tool], rate_limiter=limiter) + first = _decode(server.handle_line(_request("tools/call", msg_id=1, + params={"name": "counted", "arguments": {}}))) + second = _decode(server.handle_line(_request("tools/call", msg_id=2, + params={"name": "counted", "arguments": {}}))) + third = _decode(server.handle_line(_request("tools/call", msg_id=3, + params={"name": "counted", "arguments": {}}))) + assert first["result"]["isError"] is False + assert second["result"]["isError"] is False + assert third["error"]["code"] == -32000 + assert "Rate limit" in third["error"]["message"] + + +def test_rate_limiter_zero_rate_means_unlimited(): + from je_auto_control.utils.mcp_server.rate_limit import RateLimiter + limiter = RateLimiter(rate_per_sec=0) + for _ in range(5): + assert limiter.try_acquire() is True + + +def test_initialize_advertises_roots_when_client_supports_it(): + server = MCPServer(tools=[]) + response = _decode(server.handle_line(_request("initialize", params={ + "capabilities": {"roots": {"listChanged": True}}, + }))) + assert "roots" in response["result"]["capabilities"] + + +def test_initialize_omits_roots_when_client_lacks_capability(): + server = MCPServer(tools=[]) + response = _decode(server.handle_line(_request("initialize", params={ + "capabilities": {}, + }))) + assert "roots" not in response["result"]["capabilities"] + + +def test_refresh_roots_updates_filesystem_provider(tmp_path): + from je_auto_control.utils.mcp_server.resources import ( + ChainProvider, FileSystemProvider, + ) + fs_provider = FileSystemProvider(root=str(tmp_path / "initial")) + chain = ChainProvider([fs_provider]) + captured_lines = [] + server = MCPServer(tools=[], resource_provider=chain, + concurrent_tools=True) + server.set_writer(captured_lines.append) + # Simulate client capability so refresh is allowed. + server._client_capabilities = {"roots": {"listChanged": True}} + + target = tmp_path / "ws" + target.mkdir() + + def run_refresh(): + server.refresh_roots(timeout=2.0) + + t = threading.Thread(target=run_refresh) + t.start() + deadline = threading.Event() + for _ in range(200): + if any('"roots/list"' in line for line in captured_lines): + break + deadline.wait(0.01) + request_lines = [line for line in captured_lines + if '"roots/list"' in line] + assert request_lines, "expected outbound roots/list" + request_id = json.loads(request_lines[-1])["id"] + + file_uri = "file:///" + str(target).replace("\\", "/").lstrip("/") + server.handle_line(json.dumps({ + "jsonrpc": "2.0", "id": request_id, + "result": {"roots": [{"uri": file_uri, "name": "ws"}]}, + })) + t.join(timeout=2.0) + assert not t.is_alive() + assert os.path.realpath(fs_provider.root) == os.path.realpath(str(target)) + + +def test_initialize_advertises_logging_capability(): + server = MCPServer(tools=[]) + response = _decode(server.handle_line(_request("initialize", params={}))) + assert "logging" in response["result"]["capabilities"] + + +def test_log_bridge_emits_notification_for_log_record(): + from je_auto_control.utils.mcp_server.log_bridge import MCPLogBridge + captured = [] + bridge = MCPLogBridge( + notifier=lambda method, params: captured.append((method, params)), + ) + import logging + record = logging.LogRecord( + name="je_auto_control.tests", level=logging.WARNING, + pathname=__file__, lineno=10, msg="something %s", args=("happened",), + exc_info=None, + ) + bridge.emit(record) + assert captured + method, params = captured[0] + assert method == "notifications/message" + assert params["level"] == "warning" + assert params["data"]["message"] == "something happened" + + +def test_logging_set_level_request_updates_bridge_level(): + from je_auto_control.utils.mcp_server.log_bridge import MCPLogBridge + server = MCPServer(tools=[], log_bridge=MCPLogBridge()) + response = _decode(server.handle_line(_request("logging/setLevel", params={ + "level": "error", + }))) + assert response["result"] == {} + import logging + assert server._log_bridge.level == logging.ERROR + + +def test_logging_set_level_rejects_unknown_name(): + server = MCPServer(tools=[]) + response = _decode(server.handle_line(_request("logging/setLevel", params={ + "level": "telepathy", + }))) + assert response["error"]["code"] == -32602 + + +def test_wait_for_image_returns_center_when_template_found(monkeypatch): + import je_auto_control.utils.mcp_server.tools._handlers as handlers + import je_auto_control.wrapper.auto_control_image as image_module + monkeypatch.setattr(image_module, "locate_image_center", + lambda image_path, detect_threshold=1.0: (42, 84)) + by_name = {tool.name: tool for tool in build_default_tool_registry()} + coords = by_name["ac_wait_for_image"].invoke({ + "image_path": "needle.png", "timeout": 1.0, "poll": 0.05, + }) + assert coords == [42, 84] + + +def test_wait_for_image_times_out_when_template_missing(monkeypatch): + from je_auto_control.utils.exception.exceptions import ( + ImageNotFoundException, + ) + import je_auto_control.wrapper.auto_control_image as image_module + + def always_miss(image_path, detect_threshold=1.0): + raise ImageNotFoundException("nope") + + monkeypatch.setattr(image_module, "locate_image_center", always_miss) + by_name = {tool.name: tool for tool in build_default_tool_registry()} + try: + by_name["ac_wait_for_image"].invoke({ + "image_path": "missing.png", + "timeout": 0.2, "poll": 0.05, + }) + except TimeoutError as error: + assert "timed out" in str(error) + else: + raise AssertionError("expected TimeoutError") + + +def test_wait_for_pixel_returns_when_match(monkeypatch): + import je_auto_control.wrapper.auto_control_screen as screen_module + monkeypatch.setattr(screen_module, "get_pixel", + lambda x, y, hwnd=None: (255, 0, 0)) + by_name = {tool.name: tool for tool in build_default_tool_registry()} + rgb = by_name["ac_wait_for_pixel"].invoke({ + "x": 1, "y": 2, "target_rgb": [255, 0, 0], + "tolerance": 2, "timeout": 0.5, "poll": 0.05, + }) + assert rgb == [255, 0, 0] + + +def test_wait_for_pixel_times_out_when_color_never_matches(monkeypatch): + import je_auto_control.wrapper.auto_control_screen as screen_module + monkeypatch.setattr(screen_module, "get_pixel", + lambda x, y, hwnd=None: (0, 0, 0)) + by_name = {tool.name: tool for tool in build_default_tool_registry()} + try: + by_name["ac_wait_for_pixel"].invoke({ + "x": 1, "y": 2, "target_rgb": [255, 0, 0], + "tolerance": 2, "timeout": 0.2, "poll": 0.05, + }) + except TimeoutError: + pass + else: + raise AssertionError("expected TimeoutError") + + +def test_window_geometry_tools_present_in_default_registry(): + names = {tool.name for tool in build_default_tool_registry()} + assert {"ac_window_move", "ac_window_minimize", + "ac_window_maximize", "ac_window_restore"}.issubset(names) + + +def test_window_move_calls_into_windows_manager(monkeypatch): + import je_auto_control.utils.mcp_server.tools._handlers as handlers + import je_auto_control.wrapper.auto_control_window as window_module + monkeypatch.setattr(window_module, "find_window", + lambda title, case_sensitive=False: (123, title)) + captured = {} + + def fake_move(hwnd, x, y, width, height, repaint=True): + captured["call"] = (int(hwnd), int(x), int(y), + int(width), int(height)) + return True + + from je_auto_control.windows.window import windows_window_manage + monkeypatch.setattr(windows_window_manage, "move_window", fake_move) + + by_name = {tool.name: tool for tool in build_default_tool_registry()} + record = by_name["ac_window_move"].invoke({ + "title_substring": "Notepad", + "x": 10, "y": 20, "width": 800, "height": 600, + }) + assert captured["call"] == (123, 10, 20, 800, 600) + assert record == {"hwnd": 123, "x": 10, "y": 20, + "width": 800, "height": 600} + + +def test_window_minimize_uses_show_command_six(monkeypatch): + """ShowWindow flag 6 is SW_MINIMIZE.""" + import je_auto_control.wrapper.auto_control_window as window_module + monkeypatch.setattr(window_module, "find_window", + lambda title, case_sensitive=False: (456, title)) + seen = {} + + def fake_show(hwnd, cmd_show): + seen["call"] = (int(hwnd), int(cmd_show)) + + from je_auto_control.windows.window import windows_window_manage + monkeypatch.setattr(windows_window_manage, "show_window", fake_show) + + by_name = {tool.name: tool for tool in build_default_tool_registry()} + by_name["ac_window_minimize"].invoke({"title_substring": "Notepad"}) + assert seen["call"] == (456, 6) + + +def test_process_tools_present_in_default_registry(): + names = {tool.name for tool in build_default_tool_registry()} + assert {"ac_launch_process", "ac_list_processes", + "ac_kill_process", "ac_shell"}.issubset(names) + + +def test_shell_command_returns_exit_code_and_stdout(): + """Run a portable command and verify the shape.""" + import sys as _sys + by_name = {tool.name: tool for tool in build_default_tool_registry()} + result = by_name["ac_shell"].invoke({ + "command": f"{_sys.executable} -V", + "timeout": 5.0, + }) + assert result["exit_code"] == 0 + # Python 3.4+ prints the version to stdout. + assert "Python" in (result["stdout"] + result["stderr"]) + + +def test_shell_command_rejects_empty(): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + try: + by_name["ac_shell"].invoke({"command": " "}) + except ValueError: + pass + else: + raise AssertionError("expected ValueError for empty command") + + +def test_launch_process_validates_working_directory(tmp_path): + import sys as _sys + by_name = {tool.name: tool for tool in build_default_tool_registry()} + missing = tmp_path / "ghost" + try: + by_name["ac_launch_process"].invoke({ + "argv": [_sys.executable, "-V"], + "working_directory": str(missing), + }) + except ValueError as error: + assert "does not exist" in str(error) + else: + raise AssertionError("expected ValueError") + + +def test_launch_process_rejects_empty_argv(): + by_name = {tool.name: tool for tool in build_default_tool_registry()} + try: + by_name["ac_launch_process"].invoke({"argv": []}) + except ValueError: + pass + else: + raise AssertionError("expected ValueError for empty argv") + + +def test_auto_screenshot_on_error_skipped_when_env_unset(monkeypatch, tmp_path): + monkeypatch.delenv("JE_AUTOCONTROL_MCP_ERROR_SHOTS", raising=False) + from je_auto_control.utils.mcp_server.audit import AuditLogger + audit = AuditLogger(path=str(tmp_path / "audit.jsonl")) + + def boom(x): + raise RuntimeError("nope") + + tool = MCPTool( + name="boom", description="boom", + input_schema={"type": "object", "properties": { + "x": {"type": "integer"}}, "required": ["x"]}, + handler=boom, + ) + server = MCPServer(tools=[tool], audit_logger=audit) + server.handle_line(_request("tools/call", params={ + "name": "boom", "arguments": {"x": 1}, + })) + record = json.loads(open(audit.path, encoding="utf-8").readline()) + assert "artifact_path" not in record + + +def test_auto_screenshot_on_error_writes_file_when_env_set( + monkeypatch, tmp_path): + """When the env var is set we capture a screenshot via pil_screenshot.""" + debug_dir = tmp_path / "shots" + monkeypatch.setenv("JE_AUTOCONTROL_MCP_ERROR_SHOTS", str(debug_dir)) + + saved_paths = [] + + def fake_screenshot(file_path=None, screen_region=None): + saved_paths.append(file_path) + # Touch the file so the audit record's path actually exists. + if file_path is not None: + open(file_path, "wb").close() + + class _Stub: + def save(self, *_args, **_kwargs): + return None + + size = (1, 1) + return _Stub() + + import je_auto_control.utils.cv2_utils.screenshot as screenshot_module + monkeypatch.setattr(screenshot_module, "pil_screenshot", fake_screenshot) + + from je_auto_control.utils.mcp_server.audit import AuditLogger + audit = AuditLogger(path=str(tmp_path / "audit.jsonl")) + + def boom(x): + raise RuntimeError("nope") + + tool = MCPTool( + name="boom2", description="boom2", + input_schema={"type": "object", "properties": { + "x": {"type": "integer"}}, "required": ["x"]}, + handler=boom, + ) + server = MCPServer(tools=[tool], audit_logger=audit) + response = _decode(server.handle_line(_request("tools/call", params={ + "name": "boom2", "arguments": {"x": 1}, + }))) + assert response["result"]["isError"] is True + assert "error screenshot saved to" in response["result"]["content"][0]["text"] + record = json.loads(open(audit.path, encoding="utf-8").readline()) + assert record["artifact_path"] + assert saved_paths + + +def test_default_registry_includes_short_aliases(): + names = {tool.name for tool in build_default_tool_registry(aliases=True)} + assert {"click", "type", "screenshot", "find_image", + "drag", "shell"}.issubset(names) + + +def test_alias_handler_dispatches_to_canonical_tool(): + by_name = {tool.name: tool + for tool in build_default_tool_registry(aliases=True)} + canonical = by_name["ac_get_mouse_position"] + alias = by_name["mouse_pos"] + assert alias.handler is canonical.handler + + +def test_aliases_env_flag_disables_them(monkeypatch): + monkeypatch.setenv("JE_AUTOCONTROL_MCP_ALIASES", "0") + names = {tool.name for tool in build_default_tool_registry()} + assert "click" not in names + assert "ac_click_mouse" in names + + +def test_aliases_excluded_from_read_only_registry(): + names = {tool.name + for tool in build_default_tool_registry(read_only=True, + aliases=True)} + # Read-only filter runs before alias expansion, so destructive aliases drop. + assert "click" not in names + # Read-only canonical tools get their aliases. + assert "mouse_pos" in names + + +def test_initialize_advertises_resources_subscribe_capability(): + server = MCPServer(tools=[]) + response = _decode(server.handle_line(_request("initialize", params={}))) + caps = response["result"]["capabilities"] + assert caps["resources"]["subscribe"] is True + + +def test_resources_subscribe_and_notification_round_trip(): + """Subscribe to a fake resource and verify the server forwards updates.""" + from je_auto_control.utils.mcp_server.resources import ( + ChainProvider, MCPResource, ResourceProvider, + ) + + class _FakeProvider(ResourceProvider): + URI = "fake://live" + + def __init__(self): + self.callback = None + + def list(self): + return [MCPResource(uri=self.URI, name="fake")] + + def read(self, uri): + if uri == self.URI: + return {"uri": uri, "mimeType": "text/plain", "text": "hi"} + return None + + def subscribe(self, uri, on_update): + if uri != self.URI: + return None + self.callback = on_update + return "fake-handle" + + def unsubscribe(self, uri, handle): + self.callback = None + + fake = _FakeProvider() + chain = ChainProvider([fake]) + captured = [] + server = MCPServer(tools=[], resource_provider=chain) + server.set_notifier(lambda method, params: captured.append((method, params))) + + sub_response = _decode(server.handle_line(_request( + "resources/subscribe", params={"uri": "fake://live"}, + ))) + assert sub_response["result"] == {} + assert fake.callback is not None + + # Simulate the provider noticing fresh content. + fake.callback() + + methods = [event[0] for event in captured] + assert "notifications/resources/updated" in methods + update = next(e for e in captured if e[0] == "notifications/resources/updated") + assert update[1] == {"uri": "fake://live"} + + unsub_response = _decode(server.handle_line(_request( + "resources/unsubscribe", params={"uri": "fake://live"}, + ))) + assert unsub_response["result"] == {} + assert fake.callback is None + + +def test_resources_subscribe_rejects_unknown_uri(): + server = MCPServer(tools=[]) + response = _decode(server.handle_line(_request( + "resources/subscribe", params={"uri": "fake://nowhere"}, + ))) + assert response["error"]["code"] == -32602 + + +def test_destructive_confirmation_blocks_when_user_declines(monkeypatch): + monkeypatch.setenv("JE_AUTOCONTROL_MCP_CONFIRM_DESTRUCTIVE", "1") + captured_lines = [] + tool = MCPTool( + name="zap", description="zap", + input_schema={"type": "object", "properties": {}}, + handler=lambda: "should not run", + ) + server = MCPServer(tools=[tool], concurrent_tools=True) + server.set_writer(captured_lines.append) + server._client_capabilities = {"elicitation": {}} + + def run_call(): + server.handle_line(_request("tools/call", msg_id=11, params={ + "name": "zap", "arguments": {}, + })) + + t = threading.Thread(target=run_call) + t.start() + deadline = threading.Event() + for _ in range(200): + if any('"elicitation/create"' in line for line in captured_lines): + break + deadline.wait(0.01) + eli_lines = [line for line in captured_lines + if '"elicitation/create"' in line] + assert eli_lines, "expected elicitation/create" + eli_id = json.loads(eli_lines[-1])["id"] + + server.handle_line(json.dumps({ + "jsonrpc": "2.0", "id": eli_id, + "result": {"action": "decline"}, + })) + t.join(timeout=2.0) + assert not t.is_alive() + final_lines = [line for line in captured_lines if '"id": 11' in line] + assert final_lines + final = json.loads(final_lines[-1]) + assert final["error"]["code"] == -32000 + assert "declined" in final["error"]["message"] + + +def test_destructive_confirmation_allows_when_user_accepts(monkeypatch): + monkeypatch.setenv("JE_AUTOCONTROL_MCP_CONFIRM_DESTRUCTIVE", "1") + captured_lines = [] + tool = MCPTool( + name="zap2", description="zap2", + input_schema={"type": "object", "properties": {}}, + handler=lambda: "ran", + ) + server = MCPServer(tools=[tool], concurrent_tools=True) + server.set_writer(captured_lines.append) + server._client_capabilities = {"elicitation": {}} + + def run_call(): + server.handle_line(_request("tools/call", msg_id=12, params={ + "name": "zap2", "arguments": {}, + })) + + t = threading.Thread(target=run_call) + t.start() + deadline = threading.Event() + for _ in range(200): + if any('"elicitation/create"' in line for line in captured_lines): + break + deadline.wait(0.01) + eli_id = json.loads([line for line in captured_lines + if '"elicitation/create"' in line][-1])["id"] + + server.handle_line(json.dumps({ + "jsonrpc": "2.0", "id": eli_id, + "result": {"action": "accept", "content": {}}, + })) + t.join(timeout=2.0) + final = json.loads([line for line in captured_lines + if '"id": 12' in line][-1]) + assert final["result"]["isError"] is False + assert final["result"]["content"][0]["text"] == "ran" + + +def test_destructive_confirmation_skipped_when_client_lacks_capability(monkeypatch): + monkeypatch.setenv("JE_AUTOCONTROL_MCP_CONFIRM_DESTRUCTIVE", "1") + tool = MCPTool( + name="zap3", description="zap3", + input_schema={"type": "object", "properties": {}}, + handler=lambda: "ok", + ) + server = MCPServer(tools=[tool]) + # Client did not advertise elicitation — server proceeds without asking. + response = _decode(server.handle_line(_request("tools/call", params={ + "name": "zap3", "arguments": {}, + }))) + assert response["result"]["isError"] is False + + +def test_destructive_confirmation_skipped_when_env_unset(monkeypatch): + monkeypatch.delenv("JE_AUTOCONTROL_MCP_CONFIRM_DESTRUCTIVE", raising=False) + tool = MCPTool( + name="zap4", description="zap4", + input_schema={"type": "object", "properties": {}}, + handler=lambda: "ok", + ) + server = MCPServer(tools=[tool]) + server._client_capabilities = {"elicitation": {}} + response = _decode(server.handle_line(_request("tools/call", params={ + "name": "zap4", "arguments": {}, + }))) + assert response["result"]["isError"] is False + + +def test_default_registry_lists_core_automation_tools(): + names = {tool.name for tool in build_default_tool_registry()} + expected = { + "ac_click_mouse", "ac_get_mouse_position", "ac_set_mouse_position", + "ac_type_text", "ac_press_key", "ac_hotkey", + "ac_screen_size", "ac_screenshot", + "ac_locate_image_center", "ac_locate_text", + "ac_get_clipboard", "ac_set_clipboard", + "ac_execute_actions", "ac_list_action_commands", + } + assert expected.issubset(names)