From 9a1db23f2f2d5206323767e12466a5ff44e7def8 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 1 Dec 2025 21:18:27 +0000 Subject: [PATCH 1/4] Initial plan From d38891a6db62f2aeafeda366a102c4a66497c585 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 1 Dec 2025 21:30:31 +0000 Subject: [PATCH 2/4] Implement Pankit's review comments: rename types.py, add typed Tool and ToolError classes Co-authored-by: Darktex <890615+Darktex@users.noreply.github.com> --- examples/echo_mcp_demo.py | 6 +- examples/test_mcp_integration.py | 4 +- pyproject.toml | 4 ++ src/core/env_server/__init__.py | 8 ++- src/core/env_server/base_transforms.py | 2 +- src/core/env_server/http_server.py | 2 +- src/core/env_server/interfaces.py | 2 +- src/core/env_server/mcp_environment.py | 28 +++++++--- .../env_server/{types.py => mcp_types.py} | 56 +++++++++++++++++-- src/core/env_server/web_interface.py | 2 +- src/core/tools/local_python_executor.py | 2 +- src/envs/browsergym_env/models.py | 2 +- src/envs/chat_env/client.py | 2 +- src/envs/chat_env/models.py | 2 +- src/envs/coding_env/server/python_executor.py | 2 +- src/envs/coding_env/server/transforms.py | 2 +- src/envs/echo_env/client.py | 50 ++++++++++++++--- src/envs/echo_env/server/app.py | 4 +- src/envs/echo_env/server/echo_environment.py | 2 +- src/envs/finrl_env/client.py | 2 +- src/envs/finrl_env/models.py | 2 +- .../finrl_env/server/finrl_environment.py | 2 +- src/envs/textarena_env/models.py | 2 +- .../templates/openenv_env/client.py | 2 +- .../templates/openenv_env/models.py | 2 +- .../server/__ENV_NAME___environment.py | 2 +- tests/core/mcp/test_mcp.py | 4 +- 27 files changed, 151 insertions(+), 49 deletions(-) rename src/core/env_server/{types.py => mcp_types.py} (59%) diff --git a/examples/echo_mcp_demo.py b/examples/echo_mcp_demo.py index 8f465447..5213faad 100644 --- a/examples/echo_mcp_demo.py +++ b/examples/echo_mcp_demo.py @@ -10,9 +10,9 @@ import asyncio try: - from core.env_server.types import CallToolAction, ListToolsAction + from core.env_server.mcp_types import CallToolAction, ListToolsAction except ImportError: - from openenv_core.env_server.types import CallToolAction, ListToolsAction + from openenv_core.env_server.mcp_types import CallToolAction, ListToolsAction from envs.echo_env import EchoEnv @@ -34,7 +34,7 @@ async def main(): list_action = ListToolsAction() list_result = client.step(list_action) for tool in list_result.observation.tools: - print(f" - {tool['name']}: {tool['description']}") + print(f" - {tool.name}: {tool.description}") print() # Call echo_message tool using step API diff --git a/examples/test_mcp_integration.py b/examples/test_mcp_integration.py index 5b239353..932cb87d 100644 --- a/examples/test_mcp_integration.py +++ b/examples/test_mcp_integration.py @@ -6,7 +6,7 @@ sys.path.insert(0, 'src') from envs.echo_env.server.echo_environment import EchoEnvironment -from core.env_server.types import ListToolsAction, CallToolAction +from core.env_server.mcp_types import ListToolsAction, CallToolAction async def main(): @@ -26,7 +26,7 @@ async def main(): print(f" - Has 'tools' attribute: {hasattr(obs, 'tools')}") if hasattr(obs, "tools"): print(f" - Number of tools: {len(obs.tools)}") - print(f" - Tool names: {[t['name'] for t in obs.tools]}") + print(f" - Tool names: {[t.name for t in obs.tools]}") else: print(" - ERROR: No 'tools' attribute!") return False diff --git a/pyproject.toml b/pyproject.toml index 37d7400a..b9674b03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,3 +55,7 @@ exclude_lines = [ "if __name__ == .__main__.:", "if TYPE_CHECKING:", ] + +[tool.pytest.ini_options] +pythonpath = ["src"] +asyncio_mode = "auto" diff --git a/src/core/env_server/__init__.py b/src/core/env_server/__init__.py index a08e26c5..c9df2fbb 100644 --- a/src/core/env_server/__init__.py +++ b/src/core/env_server/__init__.py @@ -10,7 +10,7 @@ from .http_server import HTTPEnvServer, create_app, create_fastapi_app from .interfaces import Environment, Message, ModelTokenizer, Transform from .mcp_environment import MCPEnvironment -from .types import ( +from .mcp_types import ( Action, CallToolAction, CallToolObservation, @@ -18,6 +18,9 @@ ListToolsObservation, Observation, State, + Tool, + ToolError, + ToolErrorType, ) from .web_interface import create_web_interface_app, WebInterfaceManager @@ -36,6 +39,9 @@ "ListToolsObservation", "CallToolObservation", "State", + "Tool", + "ToolError", + "ToolErrorType", # Base transforms "CompositeTransform", "NullTransform", diff --git a/src/core/env_server/base_transforms.py b/src/core/env_server/base_transforms.py index d8165e3d..8d236f2f 100644 --- a/src/core/env_server/base_transforms.py +++ b/src/core/env_server/base_transforms.py @@ -7,7 +7,7 @@ """Base transform implementations for composing environment-specific transforms.""" from .interfaces import Transform -from .types import Observation +from .mcp_types import Observation class CompositeTransform(Transform): diff --git a/src/core/env_server/http_server.py b/src/core/env_server/http_server.py index 15748394..d159c1c9 100644 --- a/src/core/env_server/http_server.py +++ b/src/core/env_server/http_server.py @@ -23,7 +23,7 @@ from .interfaces import Environment from .mcp_environment import MCPEnvironment -from .types import Action, CallToolAction, ListToolsAction, Observation +from .mcp_types import Action, CallToolAction, ListToolsAction, Observation class HTTPEnvServer: diff --git a/src/core/env_server/interfaces.py b/src/core/env_server/interfaces.py index caa2d76d..8d4d7119 100644 --- a/src/core/env_server/interfaces.py +++ b/src/core/env_server/interfaces.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from typing import Any, Protocol, TypedDict -from .types import Action, Observation, State +from .mcp_types import Action, Observation, State class Message(TypedDict): diff --git a/src/core/env_server/mcp_environment.py b/src/core/env_server/mcp_environment.py index d0064d0d..67716b92 100644 --- a/src/core/env_server/mcp_environment.py +++ b/src/core/env_server/mcp_environment.py @@ -17,7 +17,7 @@ from typing import Any from .interfaces import Environment -from .types import Action, CallToolAction, ListToolsAction, Observation, State +from .mcp_types import Action, CallToolAction, ListToolsAction, Observation, State class MCPEnvironment(Environment): @@ -125,7 +125,13 @@ async def _handle_mcp_action(self, action: Action) -> Observation: Raises: ValueError: If MCP client not configured or action type invalid """ - from .types import CallToolObservation, ListToolsObservation + from .mcp_types import ( + CallToolObservation, + ListToolsObservation, + Tool, + ToolError, + ToolErrorType, + ) if self.mcp_client is None: raise ValueError("MCP client not configured for this environment") @@ -136,11 +142,11 @@ async def _handle_mcp_action(self, action: Action) -> Observation: return ListToolsObservation( done=False, tools=[ - { - "name": tool.name, - "description": tool.description, - "inputSchema": tool.inputSchema, - } + Tool( + name=tool.name, + description=tool.description or "", + input_schema=tool.inputSchema, + ) for tool in tools ], ) @@ -157,7 +163,13 @@ async def _handle_mcp_action(self, action: Action) -> Observation: ) except Exception as e: return CallToolObservation( - done=False, error=str(e), tool_name=action.tool_name + done=False, + result=None, + tool_name=action.tool_name, + error=ToolError( + error_type=ToolErrorType.EXECUTION_ERROR, + message=str(e), + ), ) else: diff --git a/src/core/env_server/types.py b/src/core/env_server/mcp_types.py similarity index 59% rename from src/core/env_server/types.py rename to src/core/env_server/mcp_types.py index 7f58973e..8c51d3fe 100644 --- a/src/core/env_server/types.py +++ b/src/core/env_server/mcp_types.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass, field +from enum import Enum from typing import Any, Dict, List, Optional, Union @@ -12,6 +13,46 @@ Scalar = Union[int, float, bool] +class ToolErrorType(Enum): + """Types of errors that can occur during tool execution.""" + + INVALID_ARGUMENTS = "invalid_arguments" + TOOL_NOT_FOUND = "tool_not_found" + TRANSPORT_ERROR = "transport_error" + EXECUTION_ERROR = "execution_error" + TIMEOUT = "timeout" + INTERNAL_ERROR = "internal_error" + + +@dataclass +class ToolError: + """ + Structured error information for tool call failures. + + This captures errors at the infrastructure/transport level, not errors + that are part of the tool's normal result (those should be in the result field). + """ + + error_type: ToolErrorType + message: str + details: Optional[Dict[str, Any]] = None + + +@dataclass +class Tool: + """ + Strongly typed representation of an MCP tool. + + Follows the MCP specification for tool definitions with JSON Schema + for input/output validation. + """ + + name: str + description: str + input_schema: Dict[str, Any] + output_schema: Optional[Dict[str, Any]] = None + + @dataclass(kw_only=True) class Action: """Base class for all environment actions.""" @@ -57,10 +98,11 @@ class ListToolsObservation(Observation): """ Observation returned from ListToolsAction. - Contains the list of available tools with their schemas. + Contains the list of available tools with their schemas, following + the MCP specification format. """ - tools: List[Dict[str, Any]] = field(default_factory=list) + tools: List[Tool] = field(default_factory=list) @dataclass(kw_only=True) @@ -68,12 +110,14 @@ class CallToolObservation(Observation): """ Observation returned from CallToolAction. - Contains the result of calling a tool, or an error if the call failed. + Contains the result of calling a tool. The result field contains the tool's + output (including any tool-level errors). The error field is used only for + infrastructure-level errors (invalid args, transport issues, etc.). """ - result: Optional[Any] = None - error: Optional[str] = None - tool_name: Optional[str] = None + tool_name: str + result: Any + error: Optional[ToolError] = None @dataclass diff --git a/src/core/env_server/web_interface.py b/src/core/env_server/web_interface.py index 3c36aa1d..41e78a61 100644 --- a/src/core/env_server/web_interface.py +++ b/src/core/env_server/web_interface.py @@ -25,7 +25,7 @@ from pydantic import BaseModel from .interfaces import Environment -from .types import Action, Observation, State, EnvironmentMetadata +from .mcp_types import Action, Observation, State, EnvironmentMetadata def load_environment_metadata(env: Environment, env_name: Optional[str] = None) -> EnvironmentMetadata: diff --git a/src/core/tools/local_python_executor.py b/src/core/tools/local_python_executor.py index 1ebcf6b6..76365ae0 100644 --- a/src/core/tools/local_python_executor.py +++ b/src/core/tools/local_python_executor.py @@ -28,7 +28,7 @@ from smolagents import LocalPythonExecutor -from core.env_server.types import CodeExecResult +from core.env_server.mcp_types import CodeExecResult logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) diff --git a/src/envs/browsergym_env/models.py b/src/envs/browsergym_env/models.py index 1c68cef6..163b4453 100644 --- a/src/envs/browsergym_env/models.py +++ b/src/envs/browsergym_env/models.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from typing import List, Optional -from openenv_core.env_server.types import Action, Observation, State +from openenv_core.env_server.mcp_types import Action, Observation, State @dataclass(kw_only=True) diff --git a/src/envs/chat_env/client.py b/src/envs/chat_env/client.py index 96e5927f..0de21860 100644 --- a/src/envs/chat_env/client.py +++ b/src/envs/chat_env/client.py @@ -17,7 +17,7 @@ from core.client_types import StepResult from core.env_server.interfaces import Message -from core.env_server.types import State +from core.env_server.mcp_types import State from core.http_env_client import HTTPEnvClient from .models import ChatAction, ChatObservation, ChatState diff --git a/src/envs/chat_env/models.py b/src/envs/chat_env/models.py index 321565ed..6323eb0b 100644 --- a/src/envs/chat_env/models.py +++ b/src/envs/chat_env/models.py @@ -16,7 +16,7 @@ import torch from core.env_server.interfaces import Message -from core.env_server.types import Action, Observation, State +from core.env_server.mcp_types import Action, Observation, State @dataclass diff --git a/src/envs/coding_env/server/python_executor.py b/src/envs/coding_env/server/python_executor.py index 17b6ecc1..b8002384 100644 --- a/src/envs/coding_env/server/python_executor.py +++ b/src/envs/coding_env/server/python_executor.py @@ -27,7 +27,7 @@ from smolagents import LocalPythonExecutor -from openenv_core.env_server.types import CodeExecResult +from openenv_core.env_server.mcp_types import CodeExecResult logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) diff --git a/src/envs/coding_env/server/transforms.py b/src/envs/coding_env/server/transforms.py index ee5a1c4b..a990164d 100644 --- a/src/envs/coding_env/server/transforms.py +++ b/src/envs/coding_env/server/transforms.py @@ -11,7 +11,7 @@ from openenv_core.env_server.base_transforms import CompositeTransform from openenv_core.env_server.interfaces import Transform -from openenv_core.env_server.types import Observation +from openenv_core.env_server.mcp_types import Observation from coding_env.models import CodeObservation diff --git a/src/envs/echo_env/client.py b/src/envs/echo_env/client.py index e49e2185..3508529e 100644 --- a/src/envs/echo_env/client.py +++ b/src/envs/echo_env/client.py @@ -15,22 +15,28 @@ try: from core.client_types import StepResult - from core.env_server.types import ( + from core.env_server.mcp_types import ( CallToolAction, CallToolObservation, ListToolsObservation, Observation, State, + Tool, + ToolError, + ToolErrorType, ) from core.http_env_client import HTTPEnvClient except ImportError: from openenv_core.client_types import StepResult - from openenv_core.env_server.types import ( + from openenv_core.env_server.mcp_types import ( CallToolAction, CallToolObservation, ListToolsObservation, Observation, State, + Tool, + ToolError, + ToolErrorType, ) from openenv_core.http_env_client import HTTPEnvClient @@ -43,7 +49,7 @@ class EchoEnv(HTTPEnvClient[CallToolAction, Observation]): methods to interact with it using MCP actions. Example: - >>> from core.env_server.types import CallToolAction + >>> from core.env_server.mcp_types import CallToolAction >>> # Connect to a running server >>> client = EchoEnv(base_url="http://localhost:8000") >>> result = client.reset() @@ -54,7 +60,7 @@ class EchoEnv(HTTPEnvClient[CallToolAction, Observation]): >>> print(result.observation.result) # {"echoed_message": "Hello!"} Example with Docker: - >>> from core.env_server.types import CallToolAction + >>> from core.env_server.mcp_types import CallToolAction >>> # Automatically start container and connect >>> client = EchoEnv.from_docker_image("echo-env:latest") >>> result = client.reset() @@ -92,20 +98,50 @@ def _parse_result(self, payload: Dict) -> StepResult[Observation]: # Create appropriate typed observation based on fields present if "tools" in obs_data: + # Parse tools into Tool objects + tools = [] + for tool_data in obs_data.get("tools", []): + if isinstance(tool_data, dict): + tools.append( + Tool( + name=tool_data.get("name", ""), + description=tool_data.get("description", ""), + input_schema=tool_data.get("input_schema", {}), + output_schema=tool_data.get("output_schema"), + ) + ) + elif isinstance(tool_data, Tool): + tools.append(tool_data) + observation = ListToolsObservation( done=obs_data.get("done", False), reward=obs_data.get("reward"), metadata=obs_data.get("metadata", {}), - tools=obs_data.get("tools", []), + tools=tools, ) elif "result" in obs_data or "error" in obs_data or "tool_name" in obs_data: + # Parse error into ToolError if present + error = None + error_data = obs_data.get("error") + if error_data: + if isinstance(error_data, dict): + error = ToolError( + error_type=ToolErrorType( + error_data.get("error_type", "internal_error") + ), + message=error_data.get("message", str(error_data)), + details=error_data.get("details"), + ) + elif isinstance(error_data, ToolError): + error = error_data + observation = CallToolObservation( done=obs_data.get("done", False), reward=obs_data.get("reward"), metadata=obs_data.get("metadata", {}), result=obs_data.get("result"), - error=obs_data.get("error"), - tool_name=obs_data.get("tool_name"), + error=error, + tool_name=obs_data.get("tool_name", ""), ) else: observation = Observation( diff --git a/src/envs/echo_env/server/app.py b/src/envs/echo_env/server/app.py index 4f34dc14..b7a7618c 100644 --- a/src/envs/echo_env/server/app.py +++ b/src/envs/echo_env/server/app.py @@ -23,12 +23,12 @@ try: from core.env_server import create_app - from core.env_server.types import Action, Observation + from core.env_server.mcp_types import Action, Observation from .echo_environment import EchoEnvironment except ImportError: from openenv_core.env_server import create_app - from openenv_core.env_server.types import Action, Observation + from openenv_core.env_server.mcp_types import Action, Observation from server.echo_environment import EchoEnvironment diff --git a/src/envs/echo_env/server/echo_environment.py b/src/envs/echo_env/server/echo_environment.py index e652adcd..b51104ce 100644 --- a/src/envs/echo_env/server/echo_environment.py +++ b/src/envs/echo_env/server/echo_environment.py @@ -30,7 +30,7 @@ class EchoEnvironment(MCPEnvironment): Example: >>> from envs.echo_env.server import EchoEnvironment >>> from core.env_server import create_fastapi_app - >>> from core.env_server.types import Action, Observation + >>> from core.env_server.mcp_types import Action, Observation >>> >>> env = EchoEnvironment() >>> app = create_fastapi_app(env, Action, Observation) diff --git a/src/envs/finrl_env/client.py b/src/envs/finrl_env/client.py index 0b6468ae..2dcaac0a 100644 --- a/src/envs/finrl_env/client.py +++ b/src/envs/finrl_env/client.py @@ -15,7 +15,7 @@ from core.client_types import StepResult -from core.env_server.types import State +from core.env_server.mcp_types import State from core.http_env_client import HTTPEnvClient from .models import FinRLAction, FinRLObservation diff --git a/src/envs/finrl_env/models.py b/src/envs/finrl_env/models.py index d841c0c8..0b96d202 100644 --- a/src/envs/finrl_env/models.py +++ b/src/envs/finrl_env/models.py @@ -13,7 +13,7 @@ from dataclasses import dataclass, field -from core.env_server.types import Action, Observation +from core.env_server.mcp_types import Action, Observation @dataclass(kw_only=True) diff --git a/src/envs/finrl_env/server/finrl_environment.py b/src/envs/finrl_env/server/finrl_environment.py index 6cae2dba..88f9d501 100644 --- a/src/envs/finrl_env/server/finrl_environment.py +++ b/src/envs/finrl_env/server/finrl_environment.py @@ -14,7 +14,7 @@ import numpy as np from core.env_server.interfaces import Environment -from core.env_server.types import State +from core.env_server.mcp_types import State from ..models import FinRLAction, FinRLObservation diff --git a/src/envs/textarena_env/models.py b/src/envs/textarena_env/models.py index 4fea2c17..f5acd50c 100644 --- a/src/envs/textarena_env/models.py +++ b/src/envs/textarena_env/models.py @@ -11,7 +11,7 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Optional -from core.env_server.types import Action, Observation, State +from core.env_server.mcp_types import Action, Observation, State @dataclass diff --git a/src/openenv_cli/templates/openenv_env/client.py b/src/openenv_cli/templates/openenv_env/client.py index 34d35267..b2f3fb3e 100644 --- a/src/openenv_cli/templates/openenv_env/client.py +++ b/src/openenv_cli/templates/openenv_env/client.py @@ -14,7 +14,7 @@ from typing import Any, Dict from openenv_core.client_types import StepResult -from openenv_core.env_server.types import State +from openenv_core.env_server.mcp_types import State from openenv_core.http_env_client import HTTPEnvClient from .models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation diff --git a/src/openenv_cli/templates/openenv_env/models.py b/src/openenv_cli/templates/openenv_env/models.py index c2e40616..a3ec4847 100644 --- a/src/openenv_cli/templates/openenv_env/models.py +++ b/src/openenv_cli/templates/openenv_env/models.py @@ -12,7 +12,7 @@ from dataclasses import dataclass -from openenv_core.env_server.types import Action, Observation +from openenv_core.env_server.mcp_types import Action, Observation @dataclass(kw_only=True) diff --git a/src/openenv_cli/templates/openenv_env/server/__ENV_NAME___environment.py b/src/openenv_cli/templates/openenv_env/server/__ENV_NAME___environment.py index 63df6c01..4391847e 100644 --- a/src/openenv_cli/templates/openenv_env/server/__ENV_NAME___environment.py +++ b/src/openenv_cli/templates/openenv_env/server/__ENV_NAME___environment.py @@ -14,7 +14,7 @@ from uuid import uuid4 from openenv_core.env_server.interfaces import Environment -from openenv_core.env_server.types import State +from openenv_core.env_server.mcp_types import State from models import __ENV_CLASS_NAME__Action, __ENV_CLASS_NAME__Observation diff --git a/tests/core/mcp/test_mcp.py b/tests/core/mcp/test_mcp.py index cb1b3d6c..f8409b41 100644 --- a/tests/core/mcp/test_mcp.py +++ b/tests/core/mcp/test_mcp.py @@ -6,7 +6,7 @@ import pytest -from core.env_server.types import CallToolAction, ListToolsAction +from core.env_server.mcp_types import CallToolAction, ListToolsAction from fastmcp import Client, FastMCP @@ -51,7 +51,7 @@ async def test_echo_env_mcp_integration(): assert not obs.done assert hasattr(obs, "tools") assert len(obs.tools) == 1 - assert obs.tools[0]["name"] == "echo_message" + assert obs.tools[0].name == "echo_message" # Test CallToolAction call_action = CallToolAction( From a6181d1b138e70b6846d7b891789b8aa496b56dc Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 1 Dec 2025 21:32:12 +0000 Subject: [PATCH 3/4] Fix ToolErrorType enum value error handling in client.py Co-authored-by: Darktex <890615+Darktex@users.noreply.github.com> --- src/envs/echo_env/client.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/envs/echo_env/client.py b/src/envs/echo_env/client.py index 3508529e..3daa4e92 100644 --- a/src/envs/echo_env/client.py +++ b/src/envs/echo_env/client.py @@ -125,10 +125,13 @@ def _parse_result(self, payload: Dict) -> StepResult[Observation]: error_data = obs_data.get("error") if error_data: if isinstance(error_data, dict): + error_type_str = error_data.get("error_type", "internal_error") + try: + error_type = ToolErrorType(error_type_str) + except ValueError: + error_type = ToolErrorType.INTERNAL_ERROR error = ToolError( - error_type=ToolErrorType( - error_data.get("error_type", "internal_error") - ), + error_type=error_type, message=error_data.get("message", str(error_data)), details=error_data.get("details"), ) From 4843088f2f1d67e3c22985deef811fc6c89080de Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 1 Dec 2025 22:02:36 +0000 Subject: [PATCH 4/4] Add comprehensive tests for MCP types (Tool, ToolError, observations) Co-authored-by: Darktex <890615+Darktex@users.noreply.github.com> --- tests/conftest.py | 24 ++ tests/core/__init__.py | 0 tests/core/mcp/__init__.py | 0 .../test_mcp_integration.py} | 5 +- tests/core/test_mcp_module/test_mcp_types.py | 285 ++++++++++++++++++ 5 files changed, 312 insertions(+), 2 deletions(-) create mode 100644 tests/conftest.py delete mode 100644 tests/core/__init__.py delete mode 100644 tests/core/mcp/__init__.py rename tests/core/{mcp/test_mcp.py => test_mcp_module/test_mcp_integration.py} (91%) create mode 100644 tests/core/test_mcp_module/test_mcp_types.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..093193ff --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Pytest configuration for OpenEnv tests. + +This file adds the src directory to sys.path so that tests can import +core, envs, and other modules from the src directory. + +NOTE: Do not create __init__.py files in test directories that have +the same name as source directories (e.g., tests/core/) to avoid +import conflicts. +""" + +import sys +from pathlib import Path + +# Add src to path for tests to find core and envs modules +_src_path = str(Path(__file__).resolve().parent.parent / "src") +if _src_path not in sys.path: + sys.path.insert(0, _src_path) diff --git a/tests/core/__init__.py b/tests/core/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/core/mcp/__init__.py b/tests/core/mcp/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/core/mcp/test_mcp.py b/tests/core/test_mcp_module/test_mcp_integration.py similarity index 91% rename from tests/core/mcp/test_mcp.py rename to tests/core/test_mcp_module/test_mcp_integration.py index f8409b41..f303ad5f 100644 --- a/tests/core/mcp/test_mcp.py +++ b/tests/core/test_mcp_module/test_mcp_integration.py @@ -33,8 +33,9 @@ async def add(a: int, b: int) -> int: # Test call_tool result = await client.call_tool("add", {"a": 5, "b": 3}) - # FastMCP returns the raw result from the function - assert result == 8 + # FastMCP wraps results in CallToolResult, access via .data + result_data = result.data if hasattr(result, "data") else result + assert result_data == 8 @pytest.mark.asyncio diff --git a/tests/core/test_mcp_module/test_mcp_types.py b/tests/core/test_mcp_module/test_mcp_types.py new file mode 100644 index 00000000..6c182690 --- /dev/null +++ b/tests/core/test_mcp_module/test_mcp_types.py @@ -0,0 +1,285 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for MCP types: Tool, ToolError, ToolErrorType, and related observations. + +These tests validate the strongly typed MCP types work correctly. +""" + +import pytest +from dataclasses import asdict + +from core.env_server.mcp_types import ( + Tool, + ToolError, + ToolErrorType, + ListToolsObservation, + CallToolObservation, + ListToolsAction, + CallToolAction, +) + + +class TestTool: + """Tests for the Tool dataclass.""" + + def test_tool_creation_basic(self): + """Test basic Tool creation with required fields.""" + tool = Tool( + name="test_tool", + description="A test tool", + input_schema={"type": "object", "properties": {"arg1": {"type": "string"}}}, + ) + assert tool.name == "test_tool" + assert tool.description == "A test tool" + assert tool.input_schema == {"type": "object", "properties": {"arg1": {"type": "string"}}} + assert tool.output_schema is None + + def test_tool_creation_with_output_schema(self): + """Test Tool creation with optional output_schema.""" + tool = Tool( + name="echo", + description="Echo a message", + input_schema={"type": "object", "properties": {"message": {"type": "string"}}}, + output_schema={"type": "object", "properties": {"echoed": {"type": "string"}}}, + ) + assert tool.name == "echo" + assert tool.output_schema == {"type": "object", "properties": {"echoed": {"type": "string"}}} + + def test_tool_serialization(self): + """Test that Tool can be serialized via asdict.""" + tool = Tool( + name="add", + description="Add two numbers", + input_schema={"type": "object"}, + ) + serialized = asdict(tool) + assert serialized == { + "name": "add", + "description": "Add two numbers", + "input_schema": {"type": "object"}, + "output_schema": None, + } + + def test_tool_equality(self): + """Test Tool equality comparison.""" + tool1 = Tool(name="tool", description="desc", input_schema={}) + tool2 = Tool(name="tool", description="desc", input_schema={}) + tool3 = Tool(name="other", description="desc", input_schema={}) + + assert tool1 == tool2 + assert tool1 != tool3 + + +class TestToolError: + """Tests for the ToolError dataclass.""" + + def test_tool_error_creation(self): + """Test basic ToolError creation.""" + error = ToolError( + error_type=ToolErrorType.EXECUTION_ERROR, + message="Something went wrong", + ) + assert error.error_type == ToolErrorType.EXECUTION_ERROR + assert error.message == "Something went wrong" + assert error.details is None + + def test_tool_error_with_details(self): + """Test ToolError with additional details.""" + error = ToolError( + error_type=ToolErrorType.INVALID_ARGUMENTS, + message="Missing required argument", + details={"missing_field": "arg1", "provided_fields": ["arg2"]}, + ) + assert error.error_type == ToolErrorType.INVALID_ARGUMENTS + assert error.details == {"missing_field": "arg1", "provided_fields": ["arg2"]} + + def test_tool_error_serialization(self): + """Test that ToolError can be serialized via asdict.""" + error = ToolError( + error_type=ToolErrorType.TIMEOUT, + message="Request timed out", + ) + serialized = asdict(error) + assert serialized == { + "error_type": ToolErrorType.TIMEOUT, + "message": "Request timed out", + "details": None, + } + + +class TestToolErrorType: + """Tests for the ToolErrorType enum.""" + + def test_all_error_types_exist(self): + """Test all expected error types are defined.""" + expected_types = [ + "INVALID_ARGUMENTS", + "TOOL_NOT_FOUND", + "TRANSPORT_ERROR", + "EXECUTION_ERROR", + "TIMEOUT", + "INTERNAL_ERROR", + ] + for error_type in expected_types: + assert hasattr(ToolErrorType, error_type) + + def test_error_type_values(self): + """Test error type string values.""" + assert ToolErrorType.INVALID_ARGUMENTS.value == "invalid_arguments" + assert ToolErrorType.TOOL_NOT_FOUND.value == "tool_not_found" + assert ToolErrorType.TRANSPORT_ERROR.value == "transport_error" + assert ToolErrorType.EXECUTION_ERROR.value == "execution_error" + assert ToolErrorType.TIMEOUT.value == "timeout" + assert ToolErrorType.INTERNAL_ERROR.value == "internal_error" + + def test_error_type_from_value(self): + """Test creating ToolErrorType from string value.""" + assert ToolErrorType("execution_error") == ToolErrorType.EXECUTION_ERROR + assert ToolErrorType("internal_error") == ToolErrorType.INTERNAL_ERROR + + +class TestListToolsObservation: + """Tests for the ListToolsObservation dataclass.""" + + def test_list_tools_observation_empty(self): + """Test ListToolsObservation with empty tools list.""" + obs = ListToolsObservation() + assert obs.done is False + assert obs.tools == [] + + def test_list_tools_observation_with_tools(self): + """Test ListToolsObservation with typed Tool objects.""" + tools = [ + Tool(name="tool1", description="First tool", input_schema={}), + Tool(name="tool2", description="Second tool", input_schema={"type": "object"}), + ] + obs = ListToolsObservation(tools=tools) + + assert len(obs.tools) == 2 + assert obs.tools[0].name == "tool1" + assert obs.tools[1].name == "tool2" + assert isinstance(obs.tools[0], Tool) + assert isinstance(obs.tools[1], Tool) + + def test_list_tools_observation_serialization(self): + """Test ListToolsObservation serialization.""" + tool = Tool(name="echo", description="Echo tool", input_schema={}) + obs = ListToolsObservation(tools=[tool]) + serialized = asdict(obs) + + assert "tools" in serialized + assert len(serialized["tools"]) == 1 + assert serialized["tools"][0]["name"] == "echo" + + +class TestCallToolObservation: + """Tests for the CallToolObservation dataclass.""" + + def test_call_tool_observation_success(self): + """Test CallToolObservation for a successful tool call.""" + obs = CallToolObservation( + tool_name="echo", + result={"message": "Hello!"}, + ) + assert obs.tool_name == "echo" + assert obs.result == {"message": "Hello!"} + assert obs.error is None + assert obs.done is False + + def test_call_tool_observation_with_error(self): + """Test CallToolObservation with a ToolError.""" + error = ToolError( + error_type=ToolErrorType.TOOL_NOT_FOUND, + message="Tool 'missing' not found", + ) + obs = CallToolObservation( + tool_name="missing", + result=None, + error=error, + ) + assert obs.tool_name == "missing" + assert obs.result is None + assert obs.error is not None + assert obs.error.error_type == ToolErrorType.TOOL_NOT_FOUND + + def test_call_tool_observation_serialization(self): + """Test CallToolObservation serialization.""" + obs = CallToolObservation( + tool_name="add", + result={"sum": 10}, + ) + serialized = asdict(obs) + + assert serialized["tool_name"] == "add" + assert serialized["result"] == {"sum": 10} + assert serialized["error"] is None + + +class TestActions: + """Tests for MCP action types.""" + + def test_list_tools_action(self): + """Test ListToolsAction creation.""" + action = ListToolsAction() + assert action.metadata == {} + + def test_call_tool_action(self): + """Test CallToolAction creation.""" + action = CallToolAction( + tool_name="echo", + parameters={"message": "Hello"}, + ) + assert action.tool_name == "echo" + assert action.parameters == {"message": "Hello"} + + def test_call_tool_action_empty_parameters(self): + """Test CallToolAction with default empty parameters.""" + action = CallToolAction(tool_name="no_args_tool") + assert action.tool_name == "no_args_tool" + assert action.parameters == {} + + +@pytest.mark.asyncio +async def test_mcp_environment_returns_typed_tools(): + """Test that MCPEnvironment returns properly typed Tool objects.""" + from envs.echo_env.server.echo_environment import EchoEnvironment + + env = EchoEnvironment() + + # Get tools via ListToolsAction + list_action = ListToolsAction() + obs = await env._handle_mcp_action(list_action) + + # Verify we get Tool objects, not dicts + assert len(obs.tools) > 0 + for tool in obs.tools: + assert isinstance(tool, Tool) + assert isinstance(tool.name, str) + assert isinstance(tool.description, str) + assert isinstance(tool.input_schema, dict) + + +@pytest.mark.asyncio +async def test_mcp_environment_error_handling(): + """Test that MCPEnvironment properly handles errors with ToolError.""" + from envs.echo_env.server.echo_environment import EchoEnvironment + + env = EchoEnvironment() + + # Try calling a non-existent tool + action = CallToolAction( + tool_name="nonexistent_tool", + parameters={}, + ) + obs = await env._handle_mcp_action(action) + + # Should have an error with ToolError type + assert obs.error is not None + assert isinstance(obs.error, ToolError) + assert isinstance(obs.error.error_type, ToolErrorType) + assert obs.error.message # Should have an error message