From ae7456f17c2fa656d75c9e8976f2b2d021bffa21 Mon Sep 17 00:00:00 2001 From: Victor Talpaert Date: Fri, 17 Apr 2026 17:44:42 +0200 Subject: [PATCH 1/2] add tests failing for road_traffic map type 2 --- tests/test_scenarios/test_road_traffic.py | 69 +++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 tests/test_scenarios/test_road_traffic.py diff --git a/tests/test_scenarios/test_road_traffic.py b/tests/test_scenarios/test_road_traffic.py new file mode 100644 index 00000000..c50a87eb --- /dev/null +++ b/tests/test_scenarios/test_road_traffic.py @@ -0,0 +1,69 @@ +# Copyright (c) ProrokLab. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from vmas import make_env + + +class TestRoadTraffic: + def setup_env(self, n_envs, device="cpu", **kwargs) -> None: + self.env = make_env( + scenario="road_traffic", + num_envs=n_envs, + device=device, + continuous_actions=True, + **kwargs, + ) + self.env.seed(0) + + def _seed_buffer(self, device): + """Seed initial_state_buffer with a real state and force it to always be used.""" + scenario = self.env.scenario + buf = scenario.initial_state_buffer + buf.add(scenario.state_buffer.get_latest(n=1)[0]) + buf.probability_use_recording = torch.tensor(1.0, device=device) + + @pytest.mark.parametrize("map_type", ["1", "2"]) + def test_map_type_runs(self, map_type, n_envs=4, n_steps=10): + self.setup_env(n_envs=n_envs, map_type=map_type) + self.env.reset() + for _ in range(n_steps): + actions = [ + torch.zeros(n_envs, agent.action.action_size) + for agent in self.env.agents + ] + obs, rews, dones, _ = self.env.step(actions) + if dones.any(): + for env_index, done in enumerate(dones): + if done: + self.env.reset_at(env_index) + + def test_map_type_2_reset_uses_buffer(self, n_envs=4): + self.setup_env(n_envs=n_envs, map_type="2") + self.env.reset() + actions = [ + torch.zeros(n_envs, agent.action.action_size) + for agent in self.env.agents + ] + self.env.step(actions) + self._seed_buffer(device="cpu") + self.env.reset_at(0) + + @pytest.mark.skipif( + not torch.cuda.is_available(), + reason="GPU required to reproduce road_traffic map_type=2 device bugs", + ) + def test_gpu_map_type_2_rand_device(self, n_envs=4): + self.setup_env(n_envs=n_envs, device="cuda", map_type="2") + self.env.reset() + actions = [ + torch.zeros(n_envs, agent.action.action_size, device="cuda") + for agent in self.env.agents + ] + self.env.step(actions) + self._seed_buffer(device="cuda") + self.env.reset_at(0) From a07fa17135367e0f2cf5faaf1d72a596aec136aa Mon Sep 17 00:00:00 2001 From: Victor Talpaert Date: Fri, 17 Apr 2026 17:45:04 +0200 Subject: [PATCH 2/2] fix failing tests for road_traffic map type 2 --- vmas/scenarios/road_traffic.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vmas/scenarios/road_traffic.py b/vmas/scenarios/road_traffic.py index 92a06d9c..d9b62067 100644 --- a/vmas/scenarios/road_traffic.py +++ b/vmas/scenarios/road_traffic.py @@ -948,7 +948,8 @@ def reset_world_at(self, env_index: int = None, agent_index: int = None): if ( (self.parameters.map_type == "2") and ( - torch.rand(1) < self.initial_state_buffer.probability_use_recording + torch.rand(1, device=self.world.device) + < self.initial_state_buffer.probability_use_recording ) and (self.initial_state_buffer.valid_size >= 1) ): @@ -1113,6 +1114,7 @@ def reset_init_state( agents[i_agent].set_pos(initial_state[i_agent, 0:2], batch_index=env_i) agents[i_agent].set_rot(initial_state[i_agent, 2], batch_index=env_i) agents[i_agent].set_vel(initial_state[i_agent, 3:5], batch_index=env_i) + return ref_path, path_id else: is_feasible_initial_position_found = False # Ramdomly generate initial states for each agent @@ -2300,7 +2302,7 @@ def done(self): is_collision_with_lanelets = self.collisions.with_lanelets.any(dim=-1) if self.parameters.map_type == "2": # Record into the initial state buffer - if torch.rand(1) > ( + if torch.rand(1, device=self.world.device) > ( 1 - self.initial_state_buffer.probability_record ): # Only a certain probability to record for env_collide in torch.where(is_collision_with_agents)[0]: