Skip to content

Commit 04eb97b

Browse files
committed
feat: extra fields in reset and step request models for custom params
1 parent 82acaf2 commit 04eb97b

File tree

3 files changed

+250
-213
lines changed

3 files changed

+250
-213
lines changed

src/core/env_server/http_server.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,8 @@ async def reset(
9696
) -> ResetResponse:
9797
"""Reset endpoint - returns initial observation."""
9898
# Handle optional parameters
99-
kwargs = {}
100-
if request.seed is not None:
101-
kwargs["seed"] = request.seed
102-
if request.episode_id is not None:
103-
kwargs["episode_id"] = request.episode_id
99+
# Start with all fields from the request, including extra ones
100+
kwargs = request.model_dump(exclude_unset=True)
104101

105102
# Pass arguments only if environment accepts them
106103
sig = inspect.signature(self.env.reset)
@@ -132,9 +129,8 @@ async def step(request: StepRequest) -> StepResponse:
132129
)
133130

134131
# Handle optional parameters
135-
kwargs = {}
136-
if request.timeout_s is not None:
137-
kwargs["timeout_s"] = request.timeout_s
132+
# Start with all fields from the request, including extra ones, but exclude 'action'
133+
kwargs = request.model_dump(exclude_unset=True, exclude={'action'})
138134

139135
# Pass arguments only if environment accepts them
140136
sig = inspect.signature(self.env.step)

src/core/env_server/types.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class ResetRequest(BaseModel):
5656
"""Request model for environment reset."""
5757

5858
model_config = ConfigDict(
59-
extra="forbid",
59+
extra="allow", # Allow extra fields for custom reset parameters
6060
json_schema_extra={"examples": [{"seed": 42, "episode_id": "episode-001"}, {}]},
6161
)
6262

@@ -87,7 +87,15 @@ class ResetResponse(BaseModel):
8787
class StepRequest(BaseModel):
8888
"""Request model for environment step."""
8989

90-
model_config = ConfigDict(extra="forbid")
90+
model_config = ConfigDict(
91+
extra="allow", # Allow extra fields for custom step parameters
92+
json_schema_extra={
93+
"examples": [
94+
{"action": {"value": 1}, "timeout_s": 30.0},
95+
{"action": {"value": 1}, "render": True, "verbose": False},
96+
]
97+
},
98+
)
9199

92100
action: Dict[str, Any] = Field(
93101
...,

0 commit comments

Comments
 (0)