Skip to content

Commit ae6d12f

Browse files
authored
Merge pull request #211 from kashif/browsergym
[browsergym] fixes
2 parents 9e4ff2d + d9dbbc8 commit ae6d12f

File tree

3 files changed

+31
-5
lines changed

3 files changed

+31
-5
lines changed

src/core/env_server/http_server.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313

1414
from __future__ import annotations
1515

16+
import asyncio
1617
import os
18+
from concurrent.futures import ThreadPoolExecutor
1719
from dataclasses import asdict
1820
from typing import Any, Dict, Type
1921

@@ -62,6 +64,9 @@ def __init__(
6264
self.env = env
6365
self.action_cls = action_cls
6466
self.observation_cls = observation_cls
67+
# Create thread pool for running sync code in async context
68+
# This is needed for environments using sync libraries (e.g., Playwright sync API)
69+
self._executor = ThreadPoolExecutor(max_workers=1)
6570

6671
def register_routes(self, app: Any) -> None:
6772
"""
@@ -78,20 +83,26 @@ def register_routes(self, app: Any) -> None:
7883
async def reset(request: Dict[str, Any] = Body(default={})) -> Dict[str, Any]:
7984
"""Reset endpoint - returns initial observation."""
8085
# TODO: Handle seed, episode_id from request if provided
81-
observation = self.env.reset()
86+
# Run sync environment code in thread pool to avoid blocking asyncio loop
87+
loop = asyncio.get_event_loop()
88+
observation = await loop.run_in_executor(self._executor, self.env.reset)
8289
return self._serialize_observation(observation)
8390

8491
@app.post("/step")
8592
async def step(request: Dict[str, Any]) -> Dict[str, Any]:
8693
"""Step endpoint - executes action and returns observation."""
87-
action_data = request.get("action", {})
94+
# Support both {"action": {...}} and direct action fields
95+
action_data = request.get("action", request)
8896
# TODO: Handle timeout_s, request_id, episode_id from request if provided
8997

9098
# Deserialize action
9199
action = self._deserialize_action(action_data)
92100

93-
# Execute step
94-
observation = self.env.step(action)
101+
# Execute step in thread pool to avoid blocking asyncio loop
102+
loop = asyncio.get_event_loop()
103+
observation = await loop.run_in_executor(
104+
self._executor, self.env.step, action
105+
)
95106

96107
# Return serialized observation
97108
return self._serialize_observation(observation)
@@ -147,6 +158,19 @@ def _serialize_observation(self, observation: Observation) -> Dict[str, Any]:
147158
"""
148159
obs_dict = asdict(observation)
149160

161+
# Convert numpy arrays to lists for JSON serialization
162+
def _convert_numpy(obj):
163+
"""Recursively convert numpy arrays to lists."""
164+
if hasattr(obj, '__array__'): # numpy array
165+
return obj.tolist()
166+
elif isinstance(obj, dict):
167+
return {k: _convert_numpy(v) for k, v in obj.items()}
168+
elif isinstance(obj, (list, tuple)):
169+
return type(obj)(_convert_numpy(item) for item in obj)
170+
return obj
171+
172+
obs_dict = _convert_numpy(obs_dict)
173+
150174
# Extract reward and done (these are part of StepResult on client side)
151175
reward = obs_dict.pop("reward", None)
152176
done = obs_dict.pop("done", False)

src/envs/browsergym_env/server/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@ browsergym-webarena>=0.2.0
55
gymnasium>=0.29.0
66
playwright>=1.40.0
77
Pillow>=10.0.0
8+
fastapi>=0.104.0
9+
uvicorn>=0.24.0

src/envs/browsergym_env/server/start.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@ cleanup() {
2525

2626
trap cleanup EXIT INT TERM
2727

28-
exec uvicorn envs.browsergym_env.server.app:app --host 0.0.0.0 --port "${BROWSERGYM_PORT}"
28+
exec python -m uvicorn envs.browsergym_env.server.app:app --host 0.0.0.0 --port "${BROWSERGYM_PORT}"
2929

0 commit comments

Comments
 (0)