Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions solo/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typer
from typing import Optional


app = typer.Typer()

# Lazy-loaded commands to improve CLI startup performance
Expand Down Expand Up @@ -183,6 +184,17 @@ def setup_usb_cmd(
from solo.commands.setup_usb import setup_usb
setup_usb(auto_confirm=yes)

@app.command()
def audit(
file_path: str = typer.Argument(..., help="Path to the .parquet dataset file"),
threshold: float = typer.Option(0.5, help="Jump magnitude limit for jerky movement"),
plot: bool = typer.Option(False, "--plot", "-p", help="Visualize the movement jumps"),
):
"""
Audit a robotics dataset for kinematic glitches and stability.
"""
from .commands.audit import data as _audit
_audit(file_path=file_path, threshold=threshold, plot=plot)

if __name__ == "__main__":
app()
58 changes: 58 additions & 0 deletions solo/commands/audit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pandas as pd
import numpy as np
from pathlib import Path
from rich.console import Console
import matplotlib.pyplot as plt

console = Console()

def data(file_path: str, threshold: float = 0.5, plot: bool = False):
file_ptr = Path(file_path)

if not file_ptr.exists():
console.print(f"[red]❌ Error: File {file_path} not found![/red]")
return

console.print(f"[blue]🔍 Auditing {file_ptr.name}...[/blue]")

try:
df = pd.read_parquet(file_ptr)
# Assuming 'action' is the column with movement data
actions = np.stack(df['action'].values)
jumps = np.abs(np.diff(actions, axis=0))
# Take the maximum jump across all dimensions (x, y, z, etc.)
max_jumps_per_frame = jumps.max(axis=1)
absolute_max = max_jumps_per_frame.max()

console.print(f"✅ Loaded {len(df)} frames.")
console.print(f"📊 Sharpest movement: [bold]{absolute_max:.2f}[/bold]")

glitch_frame = np.argmax(max_jumps_per_frame)

if absolute_max > threshold:
console.print(f"[red]🚨 ALERT: Significant glitch detected at Frame {glitch_frame}![/red]")
else:
console.print("[green]✨ Data Quality: Smooth movement detected.[/green]")

# --- NEW PLOTTING LOGIC ---
if plot:
console.print("[yellow]📈 Generating visualization...[/yellow]")
plt.figure(figsize=(10, 5))
plt.plot(max_jumps_per_frame, label='Inter-frame Jump Magnitude', color='royalblue')
plt.axhline(y=threshold, color='red', linestyle='--', label='Threshold')

if absolute_max > threshold:
plt.annotate(f'GLITCH @ {glitch_frame}',
xy=(glitch_frame, absolute_max),
xytext=(glitch_frame + 500, absolute_max),
arrowprops=dict(facecolor='black', shrink=0.05))

plt.title(f"Kinematic Audit: {file_ptr.name}")
plt.xlabel("Frame Index")
plt.ylabel("Jump Magnitude")
plt.legend()
plt.grid(alpha=0.3)
plt.show()

except Exception as e:
console.print(f"[red]❌ Logic Error: {e}[/red]")