From fad15396c32054172c93c6a97e25ced9d74ca675 Mon Sep 17 00:00:00 2001 From: Gregory Zak Date: Thu, 5 Mar 2026 16:39:40 -0800 Subject: [PATCH] feat(adapters): add mcp_tool_interceptor() to AxonFlowLangGraphAdapter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a factory method that returns an async MCP tool interceptor ready for use with MultiServerMCPClient(tool_interceptors=[...]). The interceptor enforces AxonFlow input and output policies around every MCP tool call: mcp_check_input → handler() → mcp_check_output. Two design decisions incorporated: - Redacted output passthrough: when mcp_check_output returns redacted_data, the interceptor substitutes it for the original result rather than discarding it silently. - Pluggable connector type derivation: accepts an optional connector_type_fn callable to override the default "{server_name}.{tool_name}" mapping, allowing callers to match their AxonFlow policy connector type keys. Closes #107 feat: add mcp_check_input and mcp_check_output methods (#103) --- CHANGELOG.md | 12 ++ axonflow/adapters/__init__.py | 9 +- axonflow/adapters/langgraph.py | 89 +++++++++++- tests/test_langgraph_adapter.py | 238 ++++++++++++++++++++++++++++++++ 4 files changed, 346 insertions(+), 2 deletions(-) create mode 100644 tests/test_langgraph_adapter.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a18ed6f..8e124d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,18 @@ All notable changes to the AxonFlow Python SDK will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [3.9.0] - 2026-03-05 + +### Added + +- **`AxonFlowLangGraphAdapter.mcp_tool_interceptor()`**: Factory method returning an async callable ready for use with `MultiServerMCPClient(tool_interceptors=[...])`. Enforces AxonFlow input and output policies around every MCP tool call: `mcp_check_input → handler() → mcp_check_output`. Handles policy blocks and returns redacted output when `mcp_check_output` applies redaction. + - **`MCPInterceptorOptions`**: Configuration dataclass accepted by `mcp_tool_interceptor()` with two fields: + - `connector_type_fn`: Optional callable to override the default `"{server_name}.{tool_name}"` connector type mapping + - `operation`: Operation type forwarded to `mcp_check_input` (default: `"execute"`; use `"query"` for known read-only tool calls) + - `MCPInterceptorOptions` and `WorkflowApprovalRequiredError` are now exported from `axonflow.adapters` + +--- + ## [3.8.0] - 2026-03-03 ### Added diff --git a/axonflow/adapters/__init__.py b/axonflow/adapters/__init__.py index 7007da1..3a7c5ed 100644 --- a/axonflow/adapters/__init__.py +++ b/axonflow/adapters/__init__.py @@ -1,8 +1,15 @@ """AxonFlow adapters for external orchestrators.""" -from axonflow.adapters.langgraph import AxonFlowLangGraphAdapter, WorkflowBlockedError +from axonflow.adapters.langgraph import ( + AxonFlowLangGraphAdapter, + MCPInterceptorOptions, + WorkflowApprovalRequiredError, + WorkflowBlockedError, +) __all__ = [ "AxonFlowLangGraphAdapter", + "MCPInterceptorOptions", + "WorkflowApprovalRequiredError", "WorkflowBlockedError", ] diff --git a/axonflow/adapters/langgraph.py b/axonflow/adapters/langgraph.py index 6a088f2..c422124 100644 --- a/axonflow/adapters/langgraph.py +++ b/axonflow/adapters/langgraph.py @@ -33,8 +33,10 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Callable +from axonflow.exceptions import PolicyViolationError from axonflow.workflow import ( ApprovalStatus, CreateWorkflowRequest, @@ -82,6 +84,21 @@ def __init__( self.reason = reason +@dataclass +class MCPInterceptorOptions: + """Options for :meth:`AxonFlowLangGraphAdapter.mcp_tool_interceptor`. + + Attributes: + connector_type_fn: Optional callable that maps an MCP request to a + connector type string. Defaults to ``"{server_name}.{tool_name}"``. + operation: Operation type passed to ``mcp_check_input``. Defaults to + ``"execute"``. Set to ``"query"`` for known read-only tool calls. + """ + + connector_type_fn: Callable[[Any], str] | None = field(default=None) + operation: str = field(default="execute") + + class AxonFlowLangGraphAdapter: """Wraps LangGraph workflows with AxonFlow governance gates. @@ -490,6 +507,76 @@ async def wait_for_approval( msg = f"Approval timeout after {timeout}s for step {step_id}" raise TimeoutError(msg) + def mcp_tool_interceptor( + self, + options: MCPInterceptorOptions | None = None, + ) -> Callable: + """Return an async MCP tool interceptor for use with MultiServerMCPClient. + + The interceptor enforces AxonFlow input and output policies around every + MCP tool call. Pass the result directly to MultiServerMCPClient's + ``tool_interceptors`` parameter: + + Example: + >>> mcp_client = MultiServerMCPClient( + ... {"my-server": {"url": "...", "transport": "http"}}, + ... tool_interceptors=[adapter.mcp_tool_interceptor()], + ... ) + + With custom options: + + Example: + >>> opts = MCPInterceptorOptions( + ... connector_type_fn=lambda req: req.server_name, + ... operation="query", + ... ) + >>> tool_interceptors=[adapter.mcp_tool_interceptor(opts)] + + Args: + options: Optional :class:`MCPInterceptorOptions` controlling connector + type derivation and operation type. Uses defaults if not provided. + + Returns: + An async callable ``(request, handler) -> result`` suitable for + ``MultiServerMCPClient(tool_interceptors=[...])``. + """ + opts = options or MCPInterceptorOptions() + + def _default_connector_type(request: Any) -> str: + return f"{request.server_name}.{request.name}" + + resolve_connector_type = opts.connector_type_fn or _default_connector_type + + async def _interceptor(request: Any, handler: Callable) -> Any: + connector_type = resolve_connector_type(request) + statement = f"{connector_type}({request.args!r})" + + pre_check = await self.client.mcp_check_input( + connector_type=connector_type, + statement=statement, + operation=opts.operation, + parameters=request.args, + ) + if not pre_check.allowed: + raise PolicyViolationError(pre_check.block_reason or "Tool call blocked by policy") + + result = await handler(request) + + output_check = await self.client.mcp_check_output( + connector_type=connector_type, + message=f"{{result: {result!r}}}", + ) + if not output_check.allowed: + raise PolicyViolationError( + output_check.block_reason or "Tool result blocked by policy" + ) + if output_check.redacted_data is not None: + return output_check.redacted_data + + return result + + return _interceptor + async def __aenter__(self) -> AxonFlowLangGraphAdapter: """Context manager entry.""" return self diff --git a/tests/test_langgraph_adapter.py b/tests/test_langgraph_adapter.py new file mode 100644 index 0000000..3e1929f --- /dev/null +++ b/tests/test_langgraph_adapter.py @@ -0,0 +1,238 @@ +"""Unit tests for AxonFlowLangGraphAdapter.mcp_tool_interceptor().""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from axonflow import AxonFlow +from axonflow.adapters.langgraph import AxonFlowLangGraphAdapter, MCPInterceptorOptions +from axonflow.exceptions import PolicyViolationError +from axonflow.types import MCPCheckInputResponse, MCPCheckOutputResponse # noqa: TC001 + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_request( + server_name: str = "myserver", name: str = "mytool", args: dict | None = None +) -> Any: + req = MagicMock() + req.server_name = server_name + req.name = name + req.args = args or {"key": "value"} + return req + + +def _input_allowed() -> MCPCheckInputResponse: + return MCPCheckInputResponse(allowed=True, policies_evaluated=1) + + +def _input_blocked(reason: str = "Blocked by policy") -> MCPCheckInputResponse: + return MCPCheckInputResponse(allowed=False, block_reason=reason, policies_evaluated=1) + + +def _output_allowed(redacted_data: Any = None) -> MCPCheckOutputResponse: + return MCPCheckOutputResponse(allowed=True, policies_evaluated=1, redacted_data=redacted_data) + + +def _output_blocked(reason: str = "Output blocked") -> MCPCheckOutputResponse: + return MCPCheckOutputResponse(allowed=False, block_reason=reason, policies_evaluated=1) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestMCPToolInterceptor: + @pytest.fixture + def client(self) -> AxonFlow: + c = MagicMock(spec=AxonFlow) + c.mcp_check_input = AsyncMock(return_value=_input_allowed()) + c.mcp_check_output = AsyncMock(return_value=_output_allowed()) + return c + + @pytest.fixture + def adapter(self, client: AxonFlow) -> AxonFlowLangGraphAdapter: + return AxonFlowLangGraphAdapter(client, "test-workflow") + + # --- happy path --- + + @pytest.mark.asyncio + async def test_allowed_call_returns_handler_result( + self, adapter: AxonFlowLangGraphAdapter, client: AxonFlow + ) -> None: + handler = AsyncMock(return_value="tool-result") + request = _make_request() + + result = await adapter.mcp_tool_interceptor()(request, handler) + + assert result == "tool-result" + handler.assert_awaited_once_with(request) + + @pytest.mark.asyncio + async def test_connector_type_derived_from_request( + self, adapter: AxonFlowLangGraphAdapter, client: AxonFlow + ) -> None: + handler = AsyncMock(return_value="ok") + request = _make_request(server_name="srv", name="tool") + + await adapter.mcp_tool_interceptor()(request, handler) + + client.mcp_check_input.assert_awaited_once() + call_kwargs = client.mcp_check_input.call_args.kwargs + assert call_kwargs["connector_type"] == "srv.tool" + assert call_kwargs["parameters"] == request.args + + @pytest.mark.asyncio + async def test_same_connector_type_sent_to_check_output( + self, adapter: AxonFlowLangGraphAdapter, client: AxonFlow + ) -> None: + handler = AsyncMock(return_value="ok") + request = _make_request(server_name="srv", name="tool") + + await adapter.mcp_tool_interceptor()(request, handler) + + call_kwargs = client.mcp_check_output.call_args.kwargs + assert call_kwargs["connector_type"] == "srv.tool" + + # --- operation --- + + @pytest.mark.asyncio + async def test_default_operation_is_execute( + self, adapter: AxonFlowLangGraphAdapter, client: AxonFlow + ) -> None: + await adapter.mcp_tool_interceptor()(_make_request(), AsyncMock(return_value="ok")) + + assert client.mcp_check_input.call_args.kwargs["operation"] == "execute" + + @pytest.mark.asyncio + async def test_custom_operation_forwarded( + self, adapter: AxonFlowLangGraphAdapter, client: AxonFlow + ) -> None: + opts = MCPInterceptorOptions(operation="query") + await adapter.mcp_tool_interceptor(opts)(_make_request(), AsyncMock(return_value="ok")) + + assert client.mcp_check_input.call_args.kwargs["operation"] == "query" + + # --- input blocked --- + + @pytest.mark.asyncio + async def test_raises_on_blocked_input( + self, adapter: AxonFlowLangGraphAdapter, client: AxonFlow + ) -> None: + client.mcp_check_input.return_value = _input_blocked("SQL injection detected") + handler = AsyncMock() + + with pytest.raises(PolicyViolationError, match="SQL injection detected"): + await adapter.mcp_tool_interceptor()(MagicMock(), handler) + + handler.assert_not_awaited() + + @pytest.mark.asyncio + async def test_blocked_input_uses_fallback_message_when_no_reason( + self, adapter: AxonFlowLangGraphAdapter, client: AxonFlow + ) -> None: + client.mcp_check_input.return_value = MCPCheckInputResponse( + allowed=False, block_reason=None, policies_evaluated=1 + ) + with pytest.raises(PolicyViolationError, match="Tool call blocked by policy"): + await adapter.mcp_tool_interceptor()(MagicMock(), AsyncMock()) + + # --- output blocked --- + + @pytest.mark.asyncio + async def test_raises_on_blocked_output( + self, adapter: AxonFlowLangGraphAdapter, client: AxonFlow + ) -> None: + client.mcp_check_output.return_value = _output_blocked("Exfiltration limit exceeded") + handler = AsyncMock(return_value="data") + + with pytest.raises(PolicyViolationError, match="Exfiltration limit exceeded"): + await adapter.mcp_tool_interceptor()(MagicMock(), handler) + + @pytest.mark.asyncio + async def test_blocked_output_uses_fallback_message_when_no_reason( + self, adapter: AxonFlowLangGraphAdapter, client: AxonFlow + ) -> None: + client.mcp_check_output.return_value = MCPCheckOutputResponse( + allowed=False, block_reason=None, policies_evaluated=1 + ) + with pytest.raises(PolicyViolationError, match="Tool result blocked by policy"): + await adapter.mcp_tool_interceptor()(MagicMock(), AsyncMock(return_value="x")) + + # --- redacted output --- + + @pytest.mark.asyncio + async def test_returns_redacted_data_when_present( + self, adapter: AxonFlowLangGraphAdapter, client: AxonFlow + ) -> None: + client.mcp_check_output.return_value = _output_allowed(redacted_data="[REDACTED]") + handler = AsyncMock(return_value="sensitive-data") + + result = await adapter.mcp_tool_interceptor()(MagicMock(), handler) + + assert result == "[REDACTED]" + + @pytest.mark.asyncio + async def test_returns_original_result_when_no_redaction( + self, adapter: AxonFlowLangGraphAdapter, client: AxonFlow + ) -> None: + client.mcp_check_output.return_value = _output_allowed(redacted_data=None) + handler = AsyncMock(return_value="clean-data") + + result = await adapter.mcp_tool_interceptor()(MagicMock(), handler) + + assert result == "clean-data" + + # --- custom connector_type_fn --- + + @pytest.mark.asyncio + async def test_custom_connector_type_fn( + self, adapter: AxonFlowLangGraphAdapter, client: AxonFlow + ) -> None: + handler = AsyncMock(return_value="ok") + request = _make_request(server_name="srv", name="tool") + opts = MCPInterceptorOptions(connector_type_fn=lambda req: req.server_name) + + await adapter.mcp_tool_interceptor(opts)(request, handler) + + call_kwargs = client.mcp_check_input.call_args.kwargs + assert call_kwargs["connector_type"] == "srv" + + @pytest.mark.asyncio + async def test_custom_connector_type_fn_also_used_for_output_check( + self, adapter: AxonFlowLangGraphAdapter, client: AxonFlow + ) -> None: + handler = AsyncMock(return_value="ok") + request = _make_request(server_name="srv", name="tool") + opts = MCPInterceptorOptions(connector_type_fn=lambda req: "custom-type") + + await adapter.mcp_tool_interceptor(opts)(request, handler) + + assert client.mcp_check_output.call_args.kwargs["connector_type"] == "custom-type" + + # --- factory returns independent callables --- + + @pytest.mark.asyncio + async def test_each_call_to_factory_returns_independent_interceptor( + self, adapter: AxonFlowLangGraphAdapter, client: AxonFlow + ) -> None: + handler = AsyncMock(return_value="ok") + interceptor_a = adapter.mcp_tool_interceptor() + interceptor_b = adapter.mcp_tool_interceptor( + MCPInterceptorOptions(connector_type_fn=lambda r: "other") + ) + assert interceptor_a is not interceptor_b + + await interceptor_a(_make_request(), handler) + call_a = client.mcp_check_input.call_args_list[-1].kwargs["connector_type"] + + await interceptor_b(_make_request(), handler) + call_b = client.mcp_check_input.call_args_list[-1].kwargs["connector_type"] + + assert call_a != call_b