Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/echo_mcp_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import asyncio

try:
from core.env_server.types import CallToolAction, ListToolsAction
from core.env_server.mcp_types import CallToolAction, ListToolsAction
except ImportError:
from openenv_core.env_server.types import CallToolAction, ListToolsAction
from openenv_core.env_server.mcp_types import CallToolAction, ListToolsAction

from envs.echo_env import EchoEnv

Expand All @@ -34,7 +34,7 @@ async def main():
list_action = ListToolsAction()
list_result = client.step(list_action)
for tool in list_result.observation.tools:
print(f" - {tool['name']}: {tool['description']}")
print(f" - {tool.name}: {tool.description}")
print()

# Call echo_message tool using step API
Expand Down
4 changes: 2 additions & 2 deletions examples/test_mcp_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
sys.path.insert(0, 'src')

from envs.echo_env.server.echo_environment import EchoEnvironment
from core.env_server.types import ListToolsAction, CallToolAction
from core.env_server.mcp_types import ListToolsAction, CallToolAction


async def main():
Expand All @@ -26,7 +26,7 @@ async def main():
print(f" - Has 'tools' attribute: {hasattr(obs, 'tools')}")
if hasattr(obs, "tools"):
print(f" - Number of tools: {len(obs.tools)}")
print(f" - Tool names: {[t['name'] for t in obs.tools]}")
print(f" - Tool names: {[t.name for t in obs.tools]}")
else:
print(" - ERROR: No 'tools' attribute!")
return False
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,7 @@ exclude_lines = [
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
]

[tool.pytest.ini_options]
pythonpath = ["src"]
asyncio_mode = "auto"
8 changes: 7 additions & 1 deletion src/core/env_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@
from .http_server import HTTPEnvServer, create_app, create_fastapi_app
from .interfaces import Environment, Message, ModelTokenizer, Transform
from .mcp_environment import MCPEnvironment
from .types import (
from .mcp_types import (
Action,
CallToolAction,
CallToolObservation,
ListToolsAction,
ListToolsObservation,
Observation,
State,
Tool,
ToolError,
ToolErrorType,
)
from .web_interface import create_web_interface_app, WebInterfaceManager

Expand All @@ -36,6 +39,9 @@
"ListToolsObservation",
"CallToolObservation",
"State",
"Tool",
"ToolError",
"ToolErrorType",
# Base transforms
"CompositeTransform",
"NullTransform",
Expand Down
2 changes: 1 addition & 1 deletion src/core/env_server/base_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""Base transform implementations for composing environment-specific transforms."""

from .interfaces import Transform
from .types import Observation
from .mcp_types import Observation


class CompositeTransform(Transform):
Expand Down
2 changes: 1 addition & 1 deletion src/core/env_server/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from .interfaces import Environment
from .mcp_environment import MCPEnvironment
from .types import Action, CallToolAction, ListToolsAction, Observation
from .mcp_types import Action, CallToolAction, ListToolsAction, Observation


class HTTPEnvServer:
Expand Down
2 changes: 1 addition & 1 deletion src/core/env_server/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from abc import ABC, abstractmethod
from typing import Any, Protocol, TypedDict

from .types import Action, Observation, State
from .mcp_types import Action, Observation, State


class Message(TypedDict):
Expand Down
28 changes: 20 additions & 8 deletions src/core/env_server/mcp_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any

from .interfaces import Environment
from .types import Action, CallToolAction, ListToolsAction, Observation, State
from .mcp_types import Action, CallToolAction, ListToolsAction, Observation, State


class MCPEnvironment(Environment):
Expand Down Expand Up @@ -125,7 +125,13 @@ async def _handle_mcp_action(self, action: Action) -> Observation:
Raises:
ValueError: If MCP client not configured or action type invalid
"""
from .types import CallToolObservation, ListToolsObservation
from .mcp_types import (
CallToolObservation,
ListToolsObservation,
Tool,
ToolError,
ToolErrorType,
)

if self.mcp_client is None:
raise ValueError("MCP client not configured for this environment")
Expand All @@ -136,11 +142,11 @@ async def _handle_mcp_action(self, action: Action) -> Observation:
return ListToolsObservation(
done=False,
tools=[
{
"name": tool.name,
"description": tool.description,
"inputSchema": tool.inputSchema,
}
Tool(
name=tool.name,
description=tool.description or "",
input_schema=tool.inputSchema,
)
for tool in tools
],
)
Expand All @@ -157,7 +163,13 @@ async def _handle_mcp_action(self, action: Action) -> Observation:
)
except Exception as e:
return CallToolObservation(
done=False, error=str(e), tool_name=action.tool_name
done=False,
result=None,
tool_name=action.tool_name,
error=ToolError(
error_type=ToolErrorType.EXECUTION_ERROR,
message=str(e),
),
)

else:
Expand Down
56 changes: 50 additions & 6 deletions src/core/env_server/types.py → src/core/env_server/mcp_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,54 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union


# Type aliases
Scalar = Union[int, float, bool]


class ToolErrorType(Enum):
    """Types of errors that can occur during tool execution.

    Each member's value is a lowercase snake_case string, so a member can be
    looked up from its string form, e.g. ``ToolErrorType("timeout")``.
    """

    # NOTE(review): the meanings of the individual members are implied by
    # their names only; confirm the intended mapping from failure causes to
    # members against the code paths that construct tool errors.
    INVALID_ARGUMENTS = "invalid_arguments"
    TOOL_NOT_FOUND = "tool_not_found"
    TRANSPORT_ERROR = "transport_error"
    EXECUTION_ERROR = "execution_error"
    TIMEOUT = "timeout"
    INTERNAL_ERROR = "internal_error"


@dataclass
class ToolError:
    """
    Structured description of a failed tool call.

    Intended for failures at the infrastructure/transport level. Errors that
    belong to a tool's normal output are not represented here; those should be
    carried in the result field instead.

    Attributes:
        error_type: Category of the failure.
        message: Human-readable description of the failure.
        details: Optional mapping with extra context; ``None`` when absent.
    """

    error_type: ToolErrorType
    message: str
    details: Optional[Dict[str, Any]] = field(default=None)


@dataclass
class Tool:
    """
    Typed record describing a single MCP tool.

    Mirrors the MCP specification's tool definition, where JSON Schema
    documents describe the tool's input and (optionally) its output.

    Attributes:
        name: Tool name.
        description: Tool description text.
        input_schema: JSON Schema mapping for the tool's input.
        output_schema: JSON Schema mapping for the tool's output, or ``None``
            when no output schema is given.
    """

    name: str
    description: str
    input_schema: Dict[str, Any]
    output_schema: Optional[Dict[str, Any]] = field(default=None)


@dataclass(kw_only=True)
class Action:
"""Base class for all environment actions."""
Expand Down Expand Up @@ -57,23 +98,26 @@ class ListToolsObservation(Observation):
"""
Observation returned from ListToolsAction.

Contains the list of available tools with their schemas.
Contains the list of available tools with their schemas, following
the MCP specification format.
"""

tools: List[Dict[str, Any]] = field(default_factory=list)
tools: List[Tool] = field(default_factory=list)


@dataclass(kw_only=True)
class CallToolObservation(Observation):
"""
Observation returned from CallToolAction.

Contains the result of calling a tool, or an error if the call failed.
Contains the result of calling a tool. The result field contains the tool's
output (including any tool-level errors). The error field is used only for
infrastructure-level errors (invalid args, transport issues, etc.).
"""

result: Optional[Any] = None
error: Optional[str] = None
tool_name: Optional[str] = None
tool_name: str
result: Any
error: Optional[ToolError] = None


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion src/core/env_server/web_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pydantic import BaseModel

from .interfaces import Environment
from .types import Action, Observation, State, EnvironmentMetadata
from .mcp_types import Action, Observation, State, EnvironmentMetadata


def load_environment_metadata(env: Environment, env_name: Optional[str] = None) -> EnvironmentMetadata:
Expand Down
2 changes: 1 addition & 1 deletion src/core/tools/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from smolagents import LocalPythonExecutor

from core.env_server.types import CodeExecResult
from core.env_server.mcp_types import CodeExecResult

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down
2 changes: 1 addition & 1 deletion src/envs/browsergym_env/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dataclasses import dataclass
from typing import List, Optional

from openenv_core.env_server.types import Action, Observation, State
from openenv_core.env_server.mcp_types import Action, Observation, State


@dataclass(kw_only=True)
Expand Down
2 changes: 1 addition & 1 deletion src/envs/chat_env/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from core.client_types import StepResult

from core.env_server.interfaces import Message
from core.env_server.types import State
from core.env_server.mcp_types import State
from core.http_env_client import HTTPEnvClient

from .models import ChatAction, ChatObservation, ChatState
Expand Down
2 changes: 1 addition & 1 deletion src/envs/chat_env/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch

from core.env_server.interfaces import Message
from core.env_server.types import Action, Observation, State
from core.env_server.mcp_types import Action, Observation, State


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion src/envs/coding_env/server/python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from smolagents import LocalPythonExecutor

from openenv_core.env_server.types import CodeExecResult
from openenv_core.env_server.mcp_types import CodeExecResult

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down
2 changes: 1 addition & 1 deletion src/envs/coding_env/server/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from openenv_core.env_server.base_transforms import CompositeTransform
from openenv_core.env_server.interfaces import Transform
from openenv_core.env_server.types import Observation
from openenv_core.env_server.mcp_types import Observation

from coding_env.models import CodeObservation

Expand Down
Loading