Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 97 additions & 7 deletions src/tether/safety/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class SafetyLimits:
effort_max: list[float] = field(default_factory=list)
workspace_min: list[float] = field(default_factory=lambda: [-1.0, -1.0, 0.0])
workspace_max: list[float] = field(default_factory=lambda: [1.0, 1.0, 1.5])
workspace_indices: list[int] = field(default_factory=list)

@classmethod
def from_urdf(cls, urdf_path: str | Path) -> SafetyLimits:
Expand Down Expand Up @@ -223,8 +224,22 @@ def from_embodiment_config(cls, cfg: Any, **kwargs) -> ActionGuard:
limits = SafetyLimits.from_embodiment_config(cfg)
return cls(limits=limits, **kwargs)

def check_single(self, action: np.ndarray) -> SafetyCheckResult:
"""Check a single action vector against safety limits."""
@staticmethod
def _clamp_value(value: float, lower: float, upper: float) -> float:
    """Return ``value`` limited to the range ``[lower, upper]`` as a float.

    The lower bound is applied first and the upper bound second, so if
    ``lower`` happens to exceed ``upper`` the result is ``upper``.
    """
    # Raise to the lower bound first ...
    raised = lower if lower > value else value
    # ... then cap at the upper bound.
    capped = upper if upper < raised else raised
    return float(capped)

def check_single(
self,
action: np.ndarray,
*,
previous_action: np.ndarray | None = None,
) -> SafetyCheckResult:
"""Check a single action vector against safety limits.

Position, effort, and explicit workspace bounds are single-action
checks. Velocity is a chunk-level delta check and only runs when the
caller provides ``previous_action``.
"""
start = time.perf_counter()
violations = []
clamped = False
Expand All @@ -233,17 +248,81 @@ def check_single(self, action: np.ndarray) -> SafetyCheckResult:

for i in range(num_joints):
# Position bounds
if action[i] < self.limits.position_min[i]:
violations.append(f"joint_{i} below min: {action[i]:.3f} < {self.limits.position_min[i]:.3f}")
if safe_action[i] < self.limits.position_min[i]:
violations.append(
f"joint_{i} below min: "
f"{safe_action[i]:.3f} < {self.limits.position_min[i]:.3f}"
)
if self.mode == "clamp":
safe_action[i] = self.limits.position_min[i]
clamped = True
elif action[i] > self.limits.position_max[i]:
violations.append(f"joint_{i} above max: {action[i]:.3f} > {self.limits.position_max[i]:.3f}")
elif safe_action[i] > self.limits.position_max[i]:
violations.append(
f"joint_{i} above max: "
f"{safe_action[i]:.3f} > {self.limits.position_max[i]:.3f}"
)
if self.mode == "clamp":
safe_action[i] = self.limits.position_max[i]
clamped = True

if i < len(self.limits.effort_max):
effort_limit = self.limits.effort_max[i]
if effort_limit > 0 and abs(float(safe_action[i])) > effort_limit:
violations.append(
f"joint_{i} effort limit: "
f"|{safe_action[i]:.3f}| > {effort_limit:.3f}"
)
if self.mode == "clamp":
safe_action[i] = self._clamp_value(
float(safe_action[i]), -effort_limit, effort_limit
)
clamped = True

if (
previous_action is not None
and not any("velocity limit" not in v for v in violations)
and i < len(self.limits.velocity_max)
):
velocity_limit = self.limits.velocity_max[i]
delta = float(safe_action[i] - previous_action[i])
if velocity_limit > 0 and abs(delta) > velocity_limit:
violations.append(
f"joint_{i} velocity limit: "
f"|delta {delta:.3f}| > {velocity_limit:.3f}"
)
if self.mode == "clamp":
safe_action[i] = float(previous_action[i]) + self._clamp_value(
delta, -velocity_limit, velocity_limit
)
clamped = True

for workspace_axis, action_idx in enumerate(self.limits.workspace_indices):
if action_idx < 0 or action_idx >= len(safe_action):
continue
if (
workspace_axis >= len(self.limits.workspace_min)
or workspace_axis >= len(self.limits.workspace_max)
):
continue
lower = self.limits.workspace_min[workspace_axis]
upper = self.limits.workspace_max[workspace_axis]
if safe_action[action_idx] < lower:
violations.append(
f"workspace_axis_{workspace_axis} below min: "
f"action[{action_idx}]={safe_action[action_idx]:.3f} < {lower:.3f}"
)
if self.mode == "clamp":
safe_action[action_idx] = lower
clamped = True
elif safe_action[action_idx] > upper:
violations.append(
f"workspace_axis_{workspace_axis} above max: "
f"action[{action_idx}]={safe_action[action_idx]:.3f} > {upper:.3f}"
)
if self.mode == "clamp":
safe_action[action_idx] = upper
clamped = True

if self.mode == "reject" and violations:
safe_action = np.zeros_like(action)

Expand Down Expand Up @@ -296,10 +375,21 @@ def check(self, actions: np.ndarray) -> tuple[np.ndarray, list[SafetyCheckResult
chunk_clamped = True
else:
safe_actions = actions.copy()
previous_safe: np.ndarray | None = None
for i in range(len(actions)):
result = self.check_single(actions[i])
result = self.check_single(
actions[i],
previous_action=previous_safe,
)
results.append(result)
safe_actions[i] = np.array(result.safe_action)
if result.safe or (
result.violations
and all("velocity limit" in v for v in result.violations)
):
previous_safe = safe_actions[i]
else:
previous_safe = None
all_violations = [v for r in results for v in r.violations]
chunk_clamped = any(r.clamped for r in results)

Expand Down
87 changes: 83 additions & 4 deletions tests/test_guard.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
"""Tests for safety guardrails."""

import json
import tempfile
from pathlib import Path

import numpy as np
import pytest

from tether.safety.guard import ActionGuard, SafetyLimits, SafetyCheckResult
from tether.safety.guard import ActionGuard, SafetyLimits


class TestSafetyLimits:
Expand All @@ -32,9 +29,11 @@ def test_custom_limits(self):
position_max=[1.0, 2.0],
velocity_max=[1.5, 1.5],
effort_max=[50.0, 50.0],
workspace_indices=[0, 1],
)
assert limits.position_min[0] == -1.0
assert limits.position_max[1] == 2.0
assert limits.workspace_indices == [0, 1]


class TestActionGuard:
Expand Down Expand Up @@ -76,6 +75,86 @@ def test_reject_mode_zeros(self):
assert not result.safe
assert result.safe_action[0] == 0.0

def test_effort_limit_clamps_single_action(self):
    """Clamp mode pulls each value back to its per-joint ``effort_max``."""
    # Position and velocity bounds are wide so only effort can trigger.
    guard = ActionGuard(
        limits=SafetyLimits(
            joint_names=["j1", "j2"],
            position_min=[-10.0, -10.0],
            position_max=[10.0, 10.0],
            velocity_max=[10.0, 10.0],
            effort_max=[2.0, 4.0],
        ),
        mode="clamp",
    )
    over_effort = np.array([3.0, -6.0])

    result = guard.check_single(over_effort)

    assert not result.safe
    assert result.clamped
    # Each joint is clamped symmetrically to +/- its own effort limit.
    assert result.safe_action == [2.0, -4.0]
    effort_messages = [v for v in result.violations if "effort limit" in v]
    assert effort_messages

def test_velocity_limit_clamps_between_chunk_actions(self):
    """Step-to-step deltas within a chunk are clamped to ``velocity_max``."""
    guard = ActionGuard(
        limits=SafetyLimits(
            joint_names=["j1", "j2"],
            position_min=[-10.0, -10.0],
            position_max=[10.0, 10.0],
            velocity_max=[1.0, 0.5],
            effort_max=[50.0, 50.0],
        ),
        mode="clamp",
    )
    chunk = np.array([
        [0.0, 0.0],
        [5.0, -5.0],
        [2.0, -1.0],
    ])

    safe_actions, results = guard.check(chunk)

    # Row 1 is limited relative to row 0; row 2 is then within one
    # velocity step of the clamped row 1, so it passes unchanged.
    expected_rows = (
        [0.0, 0.0],
        [1.0, -0.5],
        [2.0, -1.0],
    )
    for step, expected in enumerate(expected_rows):
        np.testing.assert_allclose(safe_actions[step], expected)

    jumped = results[1]
    assert any("velocity limit" in v for v in jumped.violations)
    assert jumped.clamped
    assert results[2].safe

def test_workspace_limit_clamps_explicit_indices(self):
    """Workspace bounds apply only to the action slots in ``workspace_indices``."""
    # Workspace axis 0 maps to action[0], axis 1 maps to action[2];
    # action[1] is not listed and must be left alone.
    guard = ActionGuard(
        limits=SafetyLimits(
            joint_names=["x", "unused", "z"],
            position_min=[-10.0, -10.0, -10.0],
            position_max=[10.0, 10.0, 10.0],
            velocity_max=[10.0, 10.0, 10.0],
            effort_max=[50.0, 50.0, 50.0],
            workspace_min=[-1.0, 0.0],
            workspace_max=[1.0, 2.0],
            workspace_indices=[0, 2],
        ),
        mode="clamp",
    )
    outside_workspace = np.array([3.0, 9.0, -5.0])

    result = guard.check_single(outside_workspace)

    assert not result.safe
    assert result.clamped
    assert result.safe_action == [1.0, 9.0, 0.0]
    reported = result.violations
    assert any("workspace_axis_0 above max" in v for v in reported)
    assert any("workspace_axis_1 below min" in v for v in reported)

def test_workspace_limits_are_opt_in_for_joint_actions(self):
    """With empty ``workspace_indices`` the workspace bounds are not applied."""
    guard = ActionGuard(
        limits=SafetyLimits(
            joint_names=["j1", "j2", "j3"],
            position_min=[-10.0, -10.0, -10.0],
            position_max=[10.0, 10.0, 10.0],
            velocity_max=[10.0, 10.0, 10.0],
            effort_max=[50.0, 50.0, 50.0],
            workspace_min=[-1.0, -1.0, 0.0],
            workspace_max=[1.0, 1.0, 1.5],
            workspace_indices=[],
        ),
        mode="clamp",
    )
    # Every value lies outside the workspace box but inside joint limits.
    joint_action = np.array([3.0, -3.0, -3.0])

    result = guard.check_single(joint_action)

    assert result.safe
    assert result.safe_action == [3.0, -3.0, -3.0]

def test_batch_check(self):
guard = ActionGuard.default(num_joints=3)
actions = np.array([
Expand Down
Loading