diff --git a/crazyflow/sim/data.py b/crazyflow/sim/data.py index c6b07b6..f778840 100644 --- a/crazyflow/sim/data.py +++ b/crazyflow/sim/data.py @@ -262,3 +262,5 @@ class SimData: """Drone parameters.""" core: SimCore """Core parameters of the simulation.""" + plugins: dict[str, Array] = field(default_factory=dict) + """Arbitrary data for plugins to store state in the simulation.""" diff --git a/crazyflow/sim/visualize.py b/crazyflow/sim/visualize.py index 00c5d15..a0a6ee0 100644 --- a/crazyflow/sim/visualize.py +++ b/crazyflow/sim/visualize.py @@ -57,6 +57,9 @@ def draw_points(sim: Sim, points: NDArray, rgba: NDArray | None = None, size: fl return if sim.max_visual_geom < points.shape[0]: raise RuntimeError("Attempted to draw too many points. Try to increase Sim.max_visual_geom") + points = np.atleast_2d(points) + assert points.ndim == 2, f"Expected array of [N, 3] points, got Array of shape {points.shape}" + assert points.shape[-1] == 3, f"Points must be 3D, are {points.shape[-1]}" viewer = sim.viewer.viewer if rgba is None: rgba = np.array([1.0, 0, 0, 1]) diff --git a/examples/plugins.py b/examples/plugins.py new file mode 100644 index 0000000..5056d51 --- /dev/null +++ b/examples/plugins.py @@ -0,0 +1,83 @@ +"""Example of how to extend the simulation data with custom plugins. + +Here, we implement an action delay of 0.03s in the attitude control loop. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import jax.numpy as jnp +import numpy as np + +from crazyflow import Sim +from crazyflow.sim.visualize import draw_points + +if TYPE_CHECKING: + from crazyflow.sim.data import SimData + + +def control(t: float) -> np.ndarray: + cmd = np.zeros((1, 1, 13)) + cmd[..., :3] = [np.cos(t) - 1, np.sin(t), 0.2 * t] + return cmd + + +def action_delay(data: SimData) -> SimData: + """Delay state control actions.""" + queued_actions = data.plugins["queued_actions"] + next_action = queued_actions[0] + queued_actions = jnp.roll(queued_actions, shift=-1, axis=0) + queued_actions = queued_actions.at[-1].set(data.controls.attitude.staged_cmd) + data = data.replace( + controls=data.controls.replace( + attitude=data.controls.attitude.replace(staged_cmd=next_action) + ), + plugins=data.plugins | {"queued_actions": queued_actions}, + ) + return data + + +def main(): + sim = Sim(control="state") + states = [] + steps = 500 + for i in range(steps): + cmd = control(i / sim.control_freq) + sim.state_control(cmd) + sim.step(sim.freq // sim.control_freq) + sim.render(camera="track_cam:0") + states.append(sim.data.states.pos[0, 0]) + + # Delay settings + delay: float = 0.03 # seconds + delay_steps = int(delay * sim.data.controls.attitude.freq) + sim.reset() + + # Now we add our action delay into the simulation. We first insert the data we need into the + # plugins dict, then add our plugin function into the step pipeline, and finally rebuild the sim + # default data and step function to make sure our plugin is included and data persists across + # resets. + custom_data = {"queued_actions": jnp.zeros((delay_steps, 1, 1, 4))} + sim.data = sim.data.replace(plugins=sim.data.plugins | custom_data) + sim.step_pipeline = (action_delay,) + sim.step_pipeline + sim.build_default_data() + sim.build_step_fn() + + # Run the simulation again, this time with an action delay. The states should be different + delayed_states = [] + for i in range(steps): + cmd = control(i / sim.control_freq) + sim.state_control(cmd) + sim.step(sim.freq // sim.control_freq) + draw_points(sim, states[i], size=0.02) + sim.render() + delayed_states.append(sim.data.states.pos[0, 0]) + + position_differences = jnp.linalg.norm(jnp.array(states) - jnp.array(delayed_states), axis=-1) + print(f"Mean position difference: {jnp.mean(position_differences) * 100:.2f}cm") + sim.close() + + +if __name__ == "__main__": + main() diff --git a/tests/unit/test_plugins.py b/tests/unit/test_plugins.py new file mode 100644 index 0000000..ec38429 --- /dev/null +++ b/tests/unit/test_plugins.py @@ -0,0 +1,124 @@ +"""Unit tests for the simulation plugin system. + +Plugins are callables of the form ``fn(data: SimData) -> SimData`` that are inserted into +``sim.step_pipeline``. They can store arbitrary state in ``sim.data.plugins`` (a dict of JAX +arrays). Tests use a simple per-world step counter as a canonical plugin. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import jax.numpy as jnp +import pytest + +from crazyflow.sim import Sim + +if TYPE_CHECKING: + from crazyflow.sim.data import SimData + + +def counter_plugin(data: SimData) -> SimData: + """Increment plugins["counter"] by 1 on every simulation step.""" + return data.replace(plugins=data.plugins | {"counter": data.plugins["counter"] + 1}) + + +def accumulater_plugin(data: SimData) -> SimData: + """Accumulate plugins["counter"] on every simulation step.""" + return data.replace( + plugins=data.plugins + | {"accumulated": data.plugins["accumulated"] + data.plugins["counter"]} + ) + + +@pytest.mark.unit +def test_empty_by_default(): + """SimData.plugins is an empty dict when no plugins are registered.""" + sim = Sim() + assert sim.data.plugins == {}, f"Expected empty plugins dict, got {sim.data.plugins}" + sim.close() + + +@pytest.mark.unit +def test_builds(): + """Building patterns should not break with plugin data.""" + sim = Sim() + sim.data = sim.data.replace(plugins={"sentinel": jnp.array([42])}) + sim.build_default_data() + assert "sentinel" in sim.default_data.plugins, "Plugin data should be in default_data" + sim.build_reset_fn() + sim.build_step_fn() + + +@pytest.mark.unit +@pytest.mark.parametrize("n_worlds", [1, 3]) +def test_plugin_data_changes(n_worlds: int): + """Plugin data changes across simulation steps.""" + sim = Sim(n_worlds=n_worlds) + sim.data = sim.data.replace(plugins={"counter": jnp.zeros((n_worlds, 1), dtype=jnp.int32)}) + sim.step_pipeline = sim.step_pipeline + (counter_plugin,) + sim.build_step_fn() + n_steps = 7 + sim.step(n_steps) + assert jnp.all(sim.data.plugins["counter"] == n_steps), ( + f"Expected counter={n_steps}, got {sim.data.plugins['counter']}" + ) + sim.close() + + +@pytest.mark.unit +def test_plugin_resets(): + """sim.reset() restores the counter to its default value (0).""" + sim = Sim() + sim.data = sim.data.replace(plugins={"counter": jnp.zeros((1, 1), dtype=jnp.int32)}) + sim.step_pipeline = sim.step_pipeline + (counter_plugin,) + sim.build_default_data() + sim.build_step_fn() + sim.step(10) + assert jnp.all(sim.data.plugins["counter"] == 10), "Precondition: counter should be 10" + sim.reset() + assert jnp.all(sim.data.plugins["counter"] == 0), "Counter should be 0 after full reset" + sim.close() + + +@pytest.mark.unit +def test_plugin_masked_reset(): + """Masked reset only resets the plugin data for selected worlds.""" + n_worlds = 2 + sim = Sim(n_worlds=n_worlds) + sim.data = sim.data.replace(plugins={"counter": jnp.zeros((n_worlds, 1), dtype=jnp.int32)}) + sim.step_pipeline = sim.step_pipeline + (counter_plugin,) + sim.build_default_data() + sim.build_step_fn() + sim.step(10) + assert jnp.all(sim.data.plugins["counter"] == 10), "Precondition: both counters should be 10" + + mask = jnp.array([True, False]) # reset world 0 only + sim.reset(mask) + assert jnp.all(sim.data.plugins["counter"][0] == 0), "World 0 counter must be reset to 0" + assert jnp.all(sim.data.plugins["counter"][1] == 10), "World 1 counter must remain at 10" + sim.close() + + +@pytest.mark.unit +def test_chained_plugins(): + """Two chained plugins produce different results depending on their order in the pipeline.""" + sim = Sim() + sim.data = sim.data.replace( + plugins={ + "counter": jnp.zeros((1, 1), dtype=jnp.int32), + "accumulated": jnp.zeros((1, 1), dtype=jnp.int32), + } + ) + sim.step_pipeline = sim.step_pipeline + (counter_plugin, accumulater_plugin) + sim.build_default_data() + sim.build_step_fn() + n_steps = 3 + sim.step(n_steps) + assert int(sim.data.plugins["counter"][0, 0]) == n_steps, ( + f"Expected counter={n_steps}, got {sim.data.plugins['counter']}" + ) + assert int(sim.data.plugins["accumulated"][0, 0]) == sum(range(1, n_steps + 1)), ( + f"Expected accumulated={sum(range(1, n_steps + 1))}, got {sim.data.plugins['accumulated']}" + ) + sim.close() diff --git a/tests/unit/test_sim.py b/tests/unit/test_sim.py index a965e07..a548ddc 100644 --- a/tests/unit/test_sim.py +++ b/tests/unit/test_sim.py @@ -424,9 +424,9 @@ def assert_committed(obj0: Array | Any, path: str = "data"): assert_committed(item0, f"{path}[{i}]") elif isinstance(obj0, type(sim.data.core.device)): # Device objects pass # Devices themselves don't have committed attribute - elif isinstance(obj0, dict): # Handle dictionaries + elif isinstance(obj0, dict): for key, value0 in obj0.items(): - assert_committed(value0, f"{path}[{repr(key)}]") + assert_committed(value0, f"{path}[{key}]") else: raise TypeError(f"Could not handle type {type(obj0)} at {path}")