Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crazyflow/sim/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,5 @@ class SimData:
"""Drone parameters."""
core: SimCore
"""Core parameters of the simulation."""
plugins: dict[str, Array] = field(default_factory=dict)
"""Arbitrary data for plugins to store state in the simulation."""
3 changes: 3 additions & 0 deletions crazyflow/sim/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def draw_points(sim: Sim, points: NDArray, rgba: NDArray | None = None, size: fl
return
if sim.max_visual_geom < points.shape[0]:
raise RuntimeError("Attempted to draw too many points. Try to increase Sim.max_visual_geom")
points = np.atleast_2d(points)
assert points.ndim == 2, f"Expected array of [N, 3] points, got Array of shape {points.shape}"
assert points.shape[-1] == 3, f"Points must be 3D, are {points.shape[-1]}"
viewer = sim.viewer.viewer
if rgba is None:
rgba = np.array([1.0, 0, 0, 1])
Expand Down
83 changes: 83 additions & 0 deletions examples/plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Example of how to extend the simulation data with custom plugins.

Here, we implement an action delay of 0.03s in the attitude control loop.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import jax.numpy as jnp
import numpy as np

from crazyflow import Sim
from crazyflow.sim.visualize import draw_points

if TYPE_CHECKING:
from crazyflow.sim.data import SimData


def control(t: float) -> np.ndarray:
cmd = np.zeros((1, 1, 13))
cmd[..., :3] = [np.cos(t) - 1, np.sin(t), 0.2 * t]
return cmd


def action_delay(data: SimData) -> SimData:
"""Delay state control actions."""
queued_actions = data.plugins["queued_actions"]
next_action = queued_actions[0]
queued_actions = jnp.roll(queued_actions, shift=-1, axis=0)
queued_actions = queued_actions.at[-1].set(data.controls.attitude.staged_cmd)
data = data.replace(
controls=data.controls.replace(
attitude=data.controls.attitude.replace(staged_cmd=next_action)
),
plugins=data.plugins | {"queued_actions": queued_actions},
)
return data


def main():
sim = Sim(control="state")
states = []
steps = 500
for i in range(steps):
cmd = control(i / sim.control_freq)
sim.state_control(cmd)
sim.step(sim.freq // sim.control_freq)
sim.render(camera="track_cam:0")
states.append(sim.data.states.pos[0, 0])

# Delay settings
delay: float = 0.03 # seconds
delay_steps = int(delay * sim.data.controls.attitude.freq)
sim.reset()

# Now we add our action delay into the simulation. We first insert the data we need into the
# plugins dict, then add our plugin function into the step pipeline, and finally rebuild the sim
# default data and step function to make sure our plugin is included and data persists across
# resets.
custom_data = {"queued_actions": jnp.zeros((delay_steps, 1, 1, 4))}
sim.data = sim.data.replace(plugins=sim.data.plugins | custom_data)
sim.step_pipeline = (action_delay,) + sim.step_pipeline
sim.build_default_data()
sim.build_step_fn()

# Run the simulation again, this time with an action delay. The states should be different
delayed_states = []
for i in range(steps):
cmd = control(i / sim.control_freq)
sim.state_control(cmd)
sim.step(sim.freq // sim.control_freq)
draw_points(sim, states[i], size=0.02)
sim.render()
delayed_states.append(sim.data.states.pos[0, 0])

position_differences = jnp.linalg.norm(jnp.array(states) - jnp.array(delayed_states), axis=-1)
print(f"Mean position difference: {jnp.mean(position_differences) * 100:.2f}cm")
sim.close()


if __name__ == "__main__":
main()
124 changes: 124 additions & 0 deletions tests/unit/test_plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Unit tests for the simulation plugin system.

Plugins are callables of the form ``fn(data: SimData) -> SimData`` that are inserted into
``sim.step_pipeline``. They can store arbitrary state in ``sim.data.plugins`` (a dict of JAX
arrays). Tests use a simple per-world step counter as a canonical plugin.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import jax.numpy as jnp
import pytest

from crazyflow.sim import Sim

if TYPE_CHECKING:
from crazyflow.sim.data import SimData


def counter_plugin(data: SimData) -> SimData:
"""Increment plugins["counter"] by 1 on every simulation step."""
return data.replace(plugins=data.plugins | {"counter": data.plugins["counter"] + 1})


def accumulater_plugin(data: SimData) -> SimData:
"""Accumulate plugins["counter"] on every simulation step."""
return data.replace(
plugins=data.plugins
| {"accumulated": data.plugins["accumulated"] + data.plugins["counter"]}
)


@pytest.mark.unit
def test_empty_by_default():
"""SimData.plugins is an empty dict when no plugins are registered."""
sim = Sim()
assert sim.data.plugins == {}, f"Expected empty plugins dict, got {sim.data.plugins}"
sim.close()


@pytest.mark.unit
def test_builds():
"""Building patterns should not break with plugin data."""
sim = Sim()
sim.data = sim.data.replace(plugins={"sentinel": jnp.array([42])})
sim.build_default_data()
assert "sentinel" in sim.default_data.plugins, "Plugin data should be in default_data"
sim.build_reset_fn()
sim.build_step_fn()


@pytest.mark.unit
@pytest.mark.parametrize("n_worlds", [1, 3])
def test_plugin_data_changes(n_worlds: int):
"""Plugin data changes across simulation steps."""
sim = Sim(n_worlds=n_worlds)
sim.data = sim.data.replace(plugins={"counter": jnp.zeros((n_worlds, 1), dtype=jnp.int32)})
sim.step_pipeline = sim.step_pipeline + (counter_plugin,)
sim.build_step_fn()
n_steps = 7
sim.step(n_steps)
assert jnp.all(sim.data.plugins["counter"] == n_steps), (
f"Expected counter={n_steps}, got {sim.data.plugins['counter']}"
)
sim.close()


@pytest.mark.unit
def test_plugin_resets():
"""sim.reset() restores the counter to its default value (0)."""
sim = Sim()
sim.data = sim.data.replace(plugins={"counter": jnp.zeros((1, 1), dtype=jnp.int32)})
sim.step_pipeline = sim.step_pipeline + (counter_plugin,)
sim.build_default_data()
sim.build_step_fn()
sim.step(10)
assert jnp.all(sim.data.plugins["counter"] == 10), "Precondition: counter should be 10"
sim.reset()
assert jnp.all(sim.data.plugins["counter"] == 0), "Counter should be 0 after full reset"
sim.close()


@pytest.mark.unit
def test_plugin_masked_reset():
"""Masked reset only resets the plugin data for selected worlds."""
n_worlds = 2
sim = Sim(n_worlds=n_worlds)
sim.data = sim.data.replace(plugins={"counter": jnp.zeros((n_worlds, 1), dtype=jnp.int32)})
sim.step_pipeline = sim.step_pipeline + (counter_plugin,)
sim.build_default_data()
sim.build_step_fn()
sim.step(10)
assert jnp.all(sim.data.plugins["counter"] == 10), "Precondition: both counters should be 10"

mask = jnp.array([True, False]) # reset world 0 only
sim.reset(mask)
assert jnp.all(sim.data.plugins["counter"][0] == 0), "World 0 counter must be reset to 0"
assert jnp.all(sim.data.plugins["counter"][1] == 10), "World 1 counter must remain at 10"
sim.close()


@pytest.mark.unit
def test_chained_plugins():
"""Two chained plugins produce different results depending on their order in the pipeline."""
sim = Sim()
sim.data = sim.data.replace(
plugins={
"counter": jnp.zeros((1, 1), dtype=jnp.int32),
"accumulated": jnp.zeros((1, 1), dtype=jnp.int32),
}
)
sim.step_pipeline = sim.step_pipeline + (counter_plugin, accumulater_plugin)
sim.build_default_data()
sim.build_step_fn()
n_steps = 3
sim.step(n_steps)
assert int(sim.data.plugins["counter"][0, 0]) == n_steps, (
f"Expected counter={n_steps}, got {sim.data.plugins['counter']}"
)
assert int(sim.data.plugins["accumulated"][0, 0]) == sum(range(1, n_steps + 1)), (
f"Expected accumulated={sum(range(1, n_steps + 1))}, got {sim.data.plugins['accumulated']}"
)
sim.close()
4 changes: 2 additions & 2 deletions tests/unit/test_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,9 @@ def assert_committed(obj0: Array | Any, path: str = "data"):
assert_committed(item0, f"{path}[{i}]")
elif isinstance(obj0, type(sim.data.core.device)): # Device objects
pass # Devices themselves don't have committed attribute
elif isinstance(obj0, dict): # Handle dictionaries
elif isinstance(obj0, dict):
for key, value0 in obj0.items():
assert_committed(value0, f"{path}[{repr(key)}]")
assert_committed(value0, f"{path}[{key}]")
else:
raise TypeError(f"Could not handle type {type(obj0)} at {path}")

Expand Down
Loading