Skip to content

Commit 82acaf2

Browse files
committed
feat: request and response models for reset and step endpoints
1 parent ff1bd7c commit 82acaf2

File tree

3 files changed

+258
-143
lines changed

3 files changed

+258
-143
lines changed

src/core/env_server/http_server.py

Lines changed: 63 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,24 @@
1414
from __future__ import annotations
1515

1616
import asyncio
17+
import inspect
1718
import os
1819
from concurrent.futures import ThreadPoolExecutor
19-
from typing import Any, Dict, Type, Optional
20+
from typing import Any, Dict, Optional, Type
2021

21-
from pydantic import ValidationError
2222
from fastapi import Body, FastAPI, HTTPException, status
23+
from pydantic import ValidationError
2324

2425
from .interfaces import Environment
25-
from .types import Action, Observation, State
26+
from .types import (
27+
Action,
28+
Observation,
29+
ResetRequest,
30+
ResetResponse,
31+
State,
32+
StepRequest,
33+
StepResponse,
34+
)
2635

2736

2837
class HTTPEnvServer:
@@ -81,21 +90,37 @@ def register_routes(self, app: Any) -> None:
8190
if not isinstance(app, FastAPI):
8291
raise TypeError("app must be a FastAPI instance")
8392

84-
@app.post("/reset")
85-
async def reset(request: Dict[str, Any] = Body(default={})) -> Dict[str, Any]:
93+
@app.post("/reset", response_model=ResetResponse)
94+
async def reset(
95+
request: ResetRequest = Body(default_factory=ResetRequest),
96+
) -> ResetResponse:
8697
"""Reset endpoint - returns initial observation."""
87-
# TODO: Handle seed, episode_id from request if provided
88-
# Run sync environment code in thread pool to avoid blocking asyncio loop
89-
loop = asyncio.get_event_loop()
90-
observation = await loop.run_in_executor(self._executor, self.env.reset)
91-
return self._serialize_observation(observation)
92-
93-
@app.post("/step")
94-
async def step(request: Dict[str, Any]) -> Dict[str, Any]:
98+
# Handle optional parameters
99+
kwargs = {}
100+
if request.seed is not None:
101+
kwargs["seed"] = request.seed
102+
if request.episode_id is not None:
103+
kwargs["episode_id"] = request.episode_id
104+
105+
# Pass arguments only if environment accepts them
106+
sig = inspect.signature(self.env.reset)
107+
valid_kwargs = {}
108+
109+
has_kwargs = any(
110+
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
111+
)
112+
113+
for k, v in kwargs.items():
114+
if k in sig.parameters or has_kwargs:
115+
valid_kwargs[k] = v
116+
117+
observation = self.env.reset(**valid_kwargs)
118+
return ResetResponse(**self._serialize_observation(observation))
119+
120+
@app.post("/step", response_model=StepResponse)
121+
async def step(request: StepRequest) -> StepResponse:
95122
"""Step endpoint - executes action and returns observation."""
96-
# Support both {"action": {...}} and direct action fields
97-
action_data = request.get("action", request)
98-
# TODO: Handle timeout_s, request_id, episode_id from request if provided
123+
action_data = request.action
99124

100125
# Deserialize action with Pydantic validation
101126
try:
@@ -106,20 +131,33 @@ async def step(request: Dict[str, Any]) -> Dict[str, Any]:
106131
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=e.errors()
107132
)
108133

109-
# Execute step in thread pool to avoid blocking asyncio loop
110-
loop = asyncio.get_event_loop()
111-
observation = await loop.run_in_executor(
112-
self._executor, self.env.step, action
134+
# Handle optional parameters
135+
kwargs = {}
136+
if request.timeout_s is not None:
137+
kwargs["timeout_s"] = request.timeout_s
138+
139+
# Pass arguments only if environment accepts them
140+
sig = inspect.signature(self.env.step)
141+
valid_kwargs = {}
142+
143+
has_kwargs = any(
144+
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
113145
)
114146

147+
for k, v in kwargs.items():
148+
if k in sig.parameters or has_kwargs:
149+
valid_kwargs[k] = v
150+
151+
# Execute step
152+
observation = self.env.step(action, **valid_kwargs)
153+
115154
# Return serialized observation
116-
return self._serialize_observation(observation)
155+
return StepResponse(**self._serialize_observation(observation))
117156

118-
@app.get("/state")
119-
async def get_state() -> Dict[str, Any]:
157+
@app.get("/state", response_model=State)
158+
async def get_state() -> State:
120159
"""State endpoint - returns current environment state."""
121-
state: State = self.env.state
122-
return state.model_dump()
160+
return self.env.state
123161

124162
@app.get("/health")
125163
async def health() -> Dict[str, str]:

src/core/env_server/interfaces.py

Lines changed: 128 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,118 +1,128 @@
1-
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# All rights reserved.
3-
#
4-
# This source code is licensed under the BSD-style license found in the
5-
# LICENSE file in the root directory of this source tree.
6-
7-
from abc import ABC, abstractmethod
8-
from typing import Any, Protocol, TypedDict
9-
10-
from .types import Action, Observation, State
11-
12-
13-
class Message(TypedDict):
14-
"""A message in a conversation.
15-
16-
Compatible with Huggingface chat template format.
17-
"""
18-
19-
role: str
20-
content: str
21-
22-
23-
class ModelTokenizer(Protocol):
24-
"""Protocol for tokenizers that support chat templates.
25-
26-
This protocol defines the interface that tokenizers must implement
27-
to work with chat-based environments. It's compatible with
28-
Huggingface transformers tokenizers.
29-
"""
30-
31-
def apply_chat_template(
32-
self,
33-
conversation: list[Message],
34-
tokenize: bool = True,
35-
return_tensors: str | None = None,
36-
**kwargs: Any,
37-
) -> Any:
38-
"""Apply a chat template to format and optionally tokenize a conversation.
39-
40-
Args:
41-
conversation: List of message dictionaries with 'role' and 'content'
42-
tokenize: Whether to tokenize the output
43-
return_tensors: Format for returned tensors ('pt' for PyTorch)
44-
**kwargs: Additional arguments
45-
46-
Returns:
47-
Formatted and optionally tokenized conversation
48-
"""
49-
...
50-
51-
def decode(
52-
self, token_ids: Any, skip_special_tokens: bool = False, **kwargs: Any
53-
) -> str:
54-
"""Decode token IDs back to text.
55-
56-
Args:
57-
token_ids: Token IDs to decode
58-
skip_special_tokens: Whether to skip special tokens in output
59-
**kwargs: Additional arguments
60-
61-
Returns:
62-
Decoded text string
63-
"""
64-
...
65-
66-
67-
class Transform(ABC):
68-
"""Transform observations to add rewards, metrics, or other modifications.
69-
70-
Transforms follow the TorchRL pattern where they take an observation
71-
and return a (potentially modified) observation. This allows for
72-
flexible reward computation and observation augmentation.
73-
"""
74-
75-
@abstractmethod
76-
def __call__(self, observation: Observation) -> Observation:
77-
"""Transform an observation.
78-
79-
Args:
80-
observation: The input observation
81-
82-
Returns:
83-
The transformed observation
84-
"""
85-
pass
86-
87-
88-
class Environment(ABC):
89-
"""Base class for all environment servers following Gym/Gymnasium API.
90-
91-
Args:
92-
transform: Optional transform to apply to observations
93-
"""
94-
95-
def __init__(self, transform: Transform | None = None):
96-
self.transform = transform
97-
98-
@abstractmethod
99-
def reset(self) -> Observation:
100-
"""Reset the environment and return initial observation."""
101-
pass
102-
103-
@abstractmethod
104-
def step(self, action: Action) -> Observation:
105-
"""Take a step in the environment."""
106-
pass
107-
108-
@property
109-
@abstractmethod
110-
def state(self) -> State:
111-
"""Get the current environment state."""
112-
pass
113-
114-
def _apply_transform(self, observation: Observation) -> Observation:
115-
"""Apply transform if one is provided."""
116-
if self.transform is not None:
117-
return self.transform(observation)
118-
return observation
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from abc import ABC, abstractmethod
8+
from typing import Any, Optional, Protocol, TypedDict
9+
10+
from .types import Action, Observation, State
11+
12+
13+
class Message(TypedDict):
14+
"""A message in a conversation.
15+
16+
Compatible with Huggingface chat template format.
17+
"""
18+
19+
role: str
20+
content: str
21+
22+
23+
class ModelTokenizer(Protocol):
24+
"""Protocol for tokenizers that support chat templates.
25+
26+
This protocol defines the interface that tokenizers must implement
27+
to work with chat-based environments. It's compatible with
28+
Huggingface transformers tokenizers.
29+
"""
30+
31+
def apply_chat_template(
32+
self,
33+
conversation: list[Message],
34+
tokenize: bool = True,
35+
return_tensors: str | None = None,
36+
**kwargs: Any,
37+
) -> Any:
38+
"""Apply a chat template to format and optionally tokenize a conversation.
39+
40+
Args:
41+
conversation: List of message dictionaries with 'role' and 'content'
42+
tokenize: Whether to tokenize the output
43+
return_tensors: Format for returned tensors ('pt' for PyTorch)
44+
**kwargs: Additional arguments
45+
46+
Returns:
47+
Formatted and optionally tokenized conversation
48+
"""
49+
...
50+
51+
def decode(
52+
self, token_ids: Any, skip_special_tokens: bool = False, **kwargs: Any
53+
) -> str:
54+
"""Decode token IDs back to text.
55+
56+
Args:
57+
token_ids: Token IDs to decode
58+
skip_special_tokens: Whether to skip special tokens in output
59+
**kwargs: Additional arguments
60+
61+
Returns:
62+
Decoded text string
63+
"""
64+
...
65+
66+
67+
class Transform(ABC):
68+
"""Transform observations to add rewards, metrics, or other modifications.
69+
70+
Transforms follow the TorchRL pattern where they take an observation
71+
and return a (potentially modified) observation. This allows for
72+
flexible reward computation and observation augmentation.
73+
"""
74+
75+
@abstractmethod
76+
def __call__(self, observation: Observation) -> Observation:
77+
"""Transform an observation.
78+
79+
Args:
80+
observation: The input observation
81+
82+
Returns:
83+
The transformed observation
84+
"""
85+
pass
86+
87+
88+
class Environment(ABC):
89+
"""Base class for all environment servers following Gym/Gymnasium API.
90+
91+
Args:
92+
transform: Optional transform to apply to observations
93+
"""
94+
95+
def __init__(self, transform: Transform | None = None):
96+
self.transform = transform
97+
98+
@abstractmethod
99+
def reset(
100+
self,
101+
seed: Optional[int] = None,
102+
episode_id: Optional[str] = None,
103+
**kwargs: Any,
104+
) -> Observation:
105+
"""Reset the environment and return initial observation."""
106+
pass
107+
108+
@abstractmethod
109+
def step(
110+
self,
111+
action: Action,
112+
timeout_s: Optional[float] = None,
113+
**kwargs: Any,
114+
) -> Observation:
115+
"""Take a step in the environment."""
116+
pass
117+
118+
@property
119+
@abstractmethod
120+
def state(self) -> State:
121+
"""Get the current environment state."""
122+
pass
123+
124+
def _apply_transform(self, observation: Observation) -> Observation:
125+
"""Apply transform if one is provided."""
126+
if self.transform is not None:
127+
return self.transform(observation)
128+
return observation

0 commit comments

Comments
 (0)