1313
1414from __future__ import annotations
1515
16+ import asyncio
1617import os
18+ from concurrent .futures import ThreadPoolExecutor
1719from dataclasses import asdict
1820from typing import Any , Dict , Type
1921
@@ -62,6 +64,9 @@ def __init__(
6264 self .env = env
6365 self .action_cls = action_cls
6466 self .observation_cls = observation_cls
67+ # Create thread pool for running sync code in async context
68+ # This is needed for environments using sync libraries (e.g., Playwright sync API)
69+ self ._executor = ThreadPoolExecutor (max_workers = 1 )
6570
6671 def register_routes (self , app : Any ) -> None :
6772 """
@@ -78,20 +83,26 @@ def register_routes(self, app: Any) -> None:
7883 async def reset (request : Dict [str , Any ] = Body (default = {})) -> Dict [str , Any ]:
7984 """Reset endpoint - returns initial observation."""
8085 # TODO: Handle seed, episode_id from request if provided
81- observation = self .env .reset ()
86+ # Run sync environment code in thread pool to avoid blocking asyncio loop
87+ loop = asyncio .get_event_loop ()
88+ observation = await loop .run_in_executor (self ._executor , self .env .reset )
8289 return self ._serialize_observation (observation )
8390
8491 @app .post ("/step" )
8592 async def step (request : Dict [str , Any ]) -> Dict [str , Any ]:
8693 """Step endpoint - executes action and returns observation."""
87- action_data = request .get ("action" , {})
94+ # Support both {"action": {...}} and direct action fields
95+ action_data = request .get ("action" , request )
8896 # TODO: Handle timeout_s, request_id, episode_id from request if provided
8997
9098 # Deserialize action
9199 action = self ._deserialize_action (action_data )
92100
93- # Execute step
94- observation = self .env .step (action )
101+ # Execute step in thread pool to avoid blocking asyncio loop
102+ loop = asyncio .get_event_loop ()
103+ observation = await loop .run_in_executor (
104+ self ._executor , self .env .step , action
105+ )
95106
96107 # Return serialized observation
97108 return self ._serialize_observation (observation )
@@ -147,6 +158,19 @@ def _serialize_observation(self, observation: Observation) -> Dict[str, Any]:
147158 """
148159 obs_dict = asdict (observation )
149160
161+ # Convert numpy arrays to lists for JSON serialization
162+ def _convert_numpy (obj ):
163+ """Recursively convert numpy arrays to lists."""
164+ if hasattr (obj , '__array__' ): # numpy array
165+ return obj .tolist ()
166+ elif isinstance (obj , dict ):
167+ return {k : _convert_numpy (v ) for k , v in obj .items ()}
168+ elif isinstance (obj , (list , tuple )):
169+ return type (obj )(_convert_numpy (item ) for item in obj )
170+ return obj
171+
172+ obs_dict = _convert_numpy (obs_dict )
173+
150174 # Extract reward and done (these are part of StepResult on client side)
151175 reward = obs_dict .pop ("reward" , None )
152176 done = obs_dict .pop ("done" , False )
0 commit comments