Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tests/test_scenarios/test_road_traffic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) ProrokLab.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from vmas import make_env


class TestRoadTraffic:
def setup_env(self, n_envs, device="cpu", **kwargs) -> None:
self.env = make_env(
scenario="road_traffic",
num_envs=n_envs,
device=device,
continuous_actions=True,
**kwargs,
)
self.env.seed(0)

def _seed_buffer(self, device):
"""Seed initial_state_buffer with a real state and force it to always be used."""
scenario = self.env.scenario
buf = scenario.initial_state_buffer
buf.add(scenario.state_buffer.get_latest(n=1)[0])
buf.probability_use_recording = torch.tensor(1.0, device=device)

@pytest.mark.parametrize("map_type", ["1", "2"])
def test_map_type_runs(self, map_type, n_envs=4, n_steps=10):
self.setup_env(n_envs=n_envs, map_type=map_type)
self.env.reset()
for _ in range(n_steps):
actions = [
torch.zeros(n_envs, agent.action.action_size)
for agent in self.env.agents
]
obs, rews, dones, _ = self.env.step(actions)
if dones.any():
for env_index, done in enumerate(dones):
if done:
self.env.reset_at(env_index)

def test_map_type_2_reset_uses_buffer(self, n_envs=4):
self.setup_env(n_envs=n_envs, map_type="2")
self.env.reset()
actions = [
torch.zeros(n_envs, agent.action.action_size)
for agent in self.env.agents
]
self.env.step(actions)
self._seed_buffer(device="cpu")
self.env.reset_at(0)

@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="GPU required to reproduce road_traffic map_type=2 device bugs",
)
def test_gpu_map_type_2_rand_device(self, n_envs=4):
self.setup_env(n_envs=n_envs, device="cuda", map_type="2")
self.env.reset()
actions = [
torch.zeros(n_envs, agent.action.action_size, device="cuda")
for agent in self.env.agents
]
self.env.step(actions)
self._seed_buffer(device="cuda")
self.env.reset_at(0)
6 changes: 4 additions & 2 deletions vmas/scenarios/road_traffic.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,8 @@ def reset_world_at(self, env_index: int = None, agent_index: int = None):
if (
(self.parameters.map_type == "2")
and (
torch.rand(1) < self.initial_state_buffer.probability_use_recording
torch.rand(1, device=self.world.device)
< self.initial_state_buffer.probability_use_recording
)
and (self.initial_state_buffer.valid_size >= 1)
):
Expand Down Expand Up @@ -1113,6 +1114,7 @@ def reset_init_state(
agents[i_agent].set_pos(initial_state[i_agent, 0:2], batch_index=env_i)
agents[i_agent].set_rot(initial_state[i_agent, 2], batch_index=env_i)
agents[i_agent].set_vel(initial_state[i_agent, 3:5], batch_index=env_i)
return ref_path, path_id
else:
is_feasible_initial_position_found = False
# Ramdomly generate initial states for each agent
Expand Down Expand Up @@ -2300,7 +2302,7 @@ def done(self):
is_collision_with_lanelets = self.collisions.with_lanelets.any(dim=-1)

if self.parameters.map_type == "2": # Record into the initial state buffer
if torch.rand(1) > (
if torch.rand(1, device=self.world.device) > (
1 - self.initial_state_buffer.probability_record
): # Only a certain probability to record
for env_collide in torch.where(is_collision_with_agents)[0]:
Expand Down