diff --git a/alignit/cameras/realsense.py b/alignit/cameras/realsense.py
new file mode 100644
index 0000000..8039384
--- /dev/null
+++ b/alignit/cameras/realsense.py
@@ -0,0 +1,583 @@
+# Based on https://github.com/huggingface/lerobot/tree/main
+# Changed: async_read() now returns the RGB image, the depth image, and the acquisition timestamp.
+
+import logging
+import time
+from threading import Event, Lock, Thread
+from typing import Any
+
+import cv2
+import numpy as np
+
+try:
+ import pyrealsense2 as rs
+except Exception as e:
+    logging.info(f"Could not import pyrealsense2: {e}")
+
+from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
+
+from lerobot.cameras.camera import Camera
+from lerobot.configs import ColorMode
+from lerobot.utils import get_cv2_rotation
+from lerobot.cameras.realsense import RealSenseCameraConfig
+
+logger = logging.getLogger(__name__)
+
+
+class RealSenseCamera(Camera):
+ """
+ Manages interactions with Intel RealSense cameras for frame and depth recording.
+
+ This class provides an interface similar to `OpenCVCamera` but tailored for
+ RealSense devices, leveraging the `pyrealsense2` library. It uses the camera's
+ unique serial number for identification, offering more stability than device
+ indices, especially on Linux. It also supports capturing depth maps alongside
+ color frames.
+
+    Use the provided utility script to list available cameras (serial numbers) and their default profiles:
+ ```bash
+ python -m lerobot.find_cameras realsense
+ ```
+
+ A `RealSenseCamera` instance requires a configuration object specifying the
+ camera's serial number or a unique device name. If using the name, ensure only
+ one camera with that name is connected.
+
+ The camera's default settings (FPS, resolution, color mode) from the stream
+ profile are used unless overridden in the configuration.
+
+ Example:
+ ```python
+ from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig
+ from lerobot.cameras import ColorMode, Cv2Rotation
+
+ # Basic usage with serial number
+ config = RealSenseCameraConfig(serial_number_or_name="0123456789") # Replace with actual SN
+ camera = RealSenseCamera(config)
+ camera.connect()
+
+ # Read 1 frame synchronously
+ color_image = camera.read()
+ print(color_image.shape)
+
+ # Read 1 frame asynchronously
+ async_image = camera.async_read()
+
+ # When done, properly disconnect the camera using
+ camera.disconnect()
+
+ # Example with depth capture and custom settings
+ custom_config = RealSenseCameraConfig(
+ serial_number_or_name="0123456789", # Replace with actual SN
+ fps=30,
+ width=1280,
+ height=720,
+ color_mode=ColorMode.BGR, # Request BGR output
+ rotation=Cv2Rotation.NO_ROTATION,
+ use_depth=True
+ )
+ depth_camera = RealSenseCamera(custom_config)
+ depth_camera.connect()
+
+ # Read 1 depth frame
+ depth_map = depth_camera.read_depth()
+
+ # Example using a unique camera name
+ name_config = RealSenseCameraConfig(serial_number_or_name="Intel RealSense D435") # If unique
+ name_camera = RealSenseCamera(name_config)
+ # ... connect, read, disconnect ...
+ ```
+ """
+
+ def __init__(self, config: RealSenseCameraConfig):
+ """
+ Initializes the RealSenseCamera instance.
+
+ Args:
+ config: The configuration settings for the camera.
+ """
+
+ super().__init__(config)
+
+ self.config = config
+
+ if config.serial_number_or_name.isdigit():
+ self.serial_number = config.serial_number_or_name
+ else:
+ self.serial_number = self._find_serial_number_from_name(
+ config.serial_number_or_name
+ )
+
+ self.fps = config.fps
+ self.color_mode = config.color_mode
+ self.use_depth = config.use_depth
+ self.warmup_s = config.warmup_s
+ self.latest_frame_acquisition_time = None
+ self.rs_pipeline: rs.pipeline | None = None
+ self.rs_profile: rs.pipeline_profile | None = None
+ self.align = rs.align(rs.stream.color)
+
+ self.thread: Thread | None = None
+ self.stop_event: Event | None = None
+ self.frame_lock: Lock = Lock()
+ self.latest_frame: np.ndarray | None = None
+ self.latest_frame_depth: np.ndarray | None = None
+ self.new_frame_event: Event = Event()
+
+ self.rotation: int | None = get_cv2_rotation(config.rotation)
+
+ if self.height and self.width:
+ self.capture_width, self.capture_height = self.width, self.height
+ if self.rotation in [
+ cv2.ROTATE_90_CLOCKWISE,
+ cv2.ROTATE_90_COUNTERCLOCKWISE,
+ ]:
+ self.capture_width, self.capture_height = self.height, self.width
+
+ def __str__(self) -> str:
+ return f"{self.__class__.__name__}({self.serial_number})"
+
+ @property
+ def is_connected(self) -> bool:
+ """Checks if the camera pipeline is started and streams are active."""
+ return self.rs_pipeline is not None and self.rs_profile is not None
+
+ def connect(self, warmup: bool = True):
+ """
+ Connects to the RealSense camera specified in the configuration.
+
+ Initializes the RealSense pipeline, configures the required streams (color
+ and optionally depth), starts the pipeline, and validates the actual stream settings.
+
+ Raises:
+ DeviceAlreadyConnectedError: If the camera is already connected.
+ ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique).
+ ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all.
+ RuntimeError: If the pipeline starts but fails to apply requested settings.
+ """
+ if self.is_connected:
+ raise DeviceAlreadyConnectedError(f"{self} is already connected.")
+
+ self.rs_pipeline = rs.pipeline()
+ rs_config = rs.config()
+ self._configure_rs_pipeline_config(rs_config)
+
+ try:
+ self.rs_profile = self.rs_pipeline.start(rs_config)
+ except RuntimeError as e:
+ self.rs_profile = None
+ self.rs_pipeline = None
+ raise ConnectionError(
+                f"Failed to open {self}. "
+                "Run `python -m lerobot.find_cameras realsense` to find available cameras."
+ ) from e
+
+ self._configure_capture_settings()
+
+        if warmup:
+            # NOTE(Steven): RS cameras need a moment to warm up before the
+            # first read; without this sleep, the first warmup read raises.
+            time.sleep(1)
+ start_time = time.time()
+ while time.time() - start_time < self.warmup_s:
+ self.read()
+ time.sleep(0.1)
+
+ logger.info(f"{self} connected.")
+
+ @staticmethod
+ def find_cameras() -> list[dict[str, Any]]:
+ """
+ Detects available Intel RealSense cameras connected to the system.
+
+ Returns:
+ List[Dict[str, Any]]: A list of dictionaries,
+ where each dictionary contains 'type', 'id' (serial number), 'name',
+ firmware version, USB type, and other available specs, and the default profile properties (width, height, fps, format).
+
+ Raises:
+            OSError: If the pyrealsense2 native library fails to load.
+            ImportError: If pyrealsense2 is not installed.
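+
+        Example (a minimal sketch):
+            ```python
+            for cam in RealSenseCamera.find_cameras():
+                print(cam["id"], cam["name"], cam.get("default_stream_profile"))
+            ```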
+ """
+ found_cameras_info = []
+ context = rs.context()
+ devices = context.query_devices()
+
+ for device in devices:
+ camera_info = {
+ "name": device.get_info(rs.camera_info.name),
+ "type": "RealSense",
+ "id": device.get_info(rs.camera_info.serial_number),
+ "firmware_version": device.get_info(rs.camera_info.firmware_version),
+ "usb_type_descriptor": device.get_info(
+ rs.camera_info.usb_type_descriptor
+ ),
+ "physical_port": device.get_info(rs.camera_info.physical_port),
+ "product_id": device.get_info(rs.camera_info.product_id),
+ "product_line": device.get_info(rs.camera_info.product_line),
+ }
+
+ # Get stream profiles for each sensor
+ sensors = device.query_sensors()
+ for sensor in sensors:
+ profiles = sensor.get_stream_profiles()
+
+ for profile in profiles:
+ if profile.is_video_stream_profile() and profile.is_default():
+ vprofile = profile.as_video_stream_profile()
+ stream_info = {
+ "stream_type": vprofile.stream_name(),
+ "format": vprofile.format().name,
+ "width": vprofile.width(),
+ "height": vprofile.height(),
+ "fps": vprofile.fps(),
+ }
+ camera_info["default_stream_profile"] = stream_info
+
+ found_cameras_info.append(camera_info)
+
+ return found_cameras_info
+
+ def _find_serial_number_from_name(self, name: str) -> str:
+ """Finds the serial number for a given unique camera name."""
+ camera_infos = self.find_cameras()
+ found_devices = [cam for cam in camera_infos if str(cam["name"]) == name]
+
+ if not found_devices:
+ available_names = [cam["name"] for cam in camera_infos]
+ raise ValueError(
+ f"No RealSense camera found with name '{name}'. Available camera names: {available_names}"
+ )
+
+ if len(found_devices) > 1:
+            serial_numbers = [dev["id"] for dev in found_devices]
+ raise ValueError(
+ f"Multiple RealSense cameras found with name '{name}'. "
+ f"Please use a unique serial number instead. Found SNs: {serial_numbers}"
+ )
+
+        serial_number = str(found_devices[0]["id"])
+ return serial_number
+
+ def _configure_rs_pipeline_config(self, rs_config):
+ """Creates and configures the RealSense pipeline configuration object."""
+        rs_config.enable_device(self.serial_number)
+
+ if self.width and self.height and self.fps:
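+            # Request the pre-rotation capture dimensions; rotation itself is
+            # applied later in _postprocess_image().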
+ rs_config.enable_stream(
+ rs.stream.color,
+ self.capture_width,
+ self.capture_height,
+ rs.format.rgb8,
+ self.fps,
+ )
+ if self.use_depth:
+ rs_config.enable_stream(
+ rs.stream.depth,
+ self.capture_width,
+ self.capture_height,
+ rs.format.z16,
+ self.fps,
+ )
+ else:
+ rs_config.enable_stream(rs.stream.color)
+ if self.use_depth:
+ rs_config.enable_stream(rs.stream.depth)
+
+ def _configure_capture_settings(self) -> None:
+ """Sets fps, width, and height from device stream if not already configured.
+
+ Uses the color stream profile to update unset attributes. Handles rotation by
+ swapping width/height when needed. Original capture dimensions are always stored.
+
+ Raises:
+ DeviceNotConnectedError: If device is not connected.
+ """
+ if not self.is_connected:
+ raise DeviceNotConnectedError(
+ f"Cannot validate settings for {self} as it is not connected."
+ )
+
+ stream = self.rs_profile.get_stream(rs.stream.color).as_video_stream_profile()
+
+ if self.fps is None:
+ self.fps = stream.fps()
+
+ if self.width is None or self.height is None:
+ actual_width = int(round(stream.width()))
+ actual_height = int(round(stream.height()))
+ if self.rotation in [
+ cv2.ROTATE_90_CLOCKWISE,
+ cv2.ROTATE_90_COUNTERCLOCKWISE,
+ ]:
+ self.width, self.height = actual_height, actual_width
+ self.capture_width, self.capture_height = actual_width, actual_height
+ else:
+ self.width, self.height = actual_width, actual_height
+ self.capture_width, self.capture_height = actual_width, actual_height
+
+ def read_depth(self, timeout_ms: int = 200) -> np.ndarray:
+ """
+ Reads a single frame (depth) synchronously from the camera.
+
+ This is a blocking call. It waits for a coherent set of frames (depth)
+ from the camera hardware via the RealSense pipeline.
+
+ Args:
+ timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
+
+ Returns:
+            np.ndarray: The depth map as a NumPy array (height, width) of type
+            `np.uint16` (raw depth values in millimeters), with the configured rotation applied.
+
+ Raises:
+ DeviceNotConnectedError: If the camera is not connected.
+ RuntimeError: If reading frames from the pipeline fails or frames are invalid.
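+
+        Example (a minimal sketch; assumes the RealSense default of millimeter units):
+            ```python
+            depth = camera.read_depth()
+            depth_m = depth.astype(np.float32) / 1000.0  # uint16 mm -> float32 meters
+            ```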
+ """
+
+ if not self.is_connected:
+ raise DeviceNotConnectedError(f"{self} is not connected.")
+ if not self.use_depth:
+ raise RuntimeError(
+ f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}."
+ )
+
+ start_time = time.perf_counter()
+
+ ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
+
+ if not ret or frame is None:
+ raise RuntimeError(f"{self} read_depth failed (status={ret}).")
+
+        # Align depth to the color stream so both images share the color
+        # camera's viewpoint and resolution.
+        frame = self.align.process(frame)
+        depth_frame = frame.get_depth_frame()
+        depth_map = np.asanyarray(depth_frame.get_data())
+
+ depth_map_processed = self._postprocess_image(depth_map, depth_frame=True)
+
+ read_duration_ms = (time.perf_counter() - start_time) * 1e3
+ logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
+
+ return depth_map_processed
+
+ def read(
+ self, color_mode: ColorMode | None = None, timeout_ms: int = 200
+ ) -> np.ndarray:
+ """
+ Reads a single frame (color) synchronously from the camera.
+
+ This is a blocking call. It waits for a coherent set of frames (color)
+ from the camera hardware via the RealSense pipeline.
+
+        Args:
+            color_mode (Optional[ColorMode]): Override the configured color mode
+                (RGB or BGR) for this read. Defaults to None (use `self.color_mode`).
+            timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms.
+
+ Returns:
+ np.ndarray: The captured color frame as a NumPy array
+ (height, width, channels), processed according to `color_mode` and rotation.
+
+ Raises:
+ DeviceNotConnectedError: If the camera is not connected.
+ RuntimeError: If reading frames from the pipeline fails or frames are invalid.
+ ValueError: If an invalid `color_mode` is requested.
+ """
+
+ if not self.is_connected:
+ raise DeviceNotConnectedError(f"{self} is not connected.")
+
+ start_time = time.perf_counter()
+
+ ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
+
+ if not ret or frame is None:
+ raise RuntimeError(f"{self} read failed (status={ret}).")
+
+ color_frame = frame.get_color_frame()
+ color_image_raw = np.asanyarray(color_frame.get_data())
+
+ color_image_processed = self._postprocess_image(color_image_raw, color_mode)
+
+ read_duration_ms = (time.perf_counter() - start_time) * 1e3
+ logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")
+
+ return color_image_processed
+
+ def _postprocess_image(
+ self,
+ image: np.ndarray,
+ color_mode: ColorMode | None = None,
+ depth_frame: bool = False,
+ ) -> np.ndarray:
+ """
+        Applies color conversion, dimension validation, and rotation to a raw frame.
+
+        Args:
+            image (np.ndarray): The raw frame (RGB from RealSense for color frames,
+                single-channel uint16 for depth frames).
+            color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None,
+                uses the instance's default `self.color_mode`. Ignored for depth frames.
+            depth_frame (bool): True when `image` is a single-channel depth map.
+
+ Returns:
+ np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`.
+
+ Raises:
+ ValueError: If the requested `color_mode` is invalid.
+ RuntimeError: If the raw frame dimensions do not match the configured
+ `width` and `height`.
+ """
+
+ if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR):
+ raise ValueError(
+ f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
+ )
+
+        if depth_frame:
+            h, w = image.shape
+        else:
+            h, w, c = image.shape
+            if c != 3:
+                raise RuntimeError(
+                    f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR)."
+                )
+
+        if h != self.capture_height or w != self.capture_width:
+            raise RuntimeError(
+                f"{self} frame width={w} or height={h} do not match configured width={self.capture_width} or height={self.capture_height}."
+            )
+
+        processed_image = image
+        requested_color_mode = self.color_mode if color_mode is None else color_mode
+        if not depth_frame and requested_color_mode == ColorMode.BGR:
+            processed_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+
+ if self.rotation in [
+ cv2.ROTATE_90_CLOCKWISE,
+ cv2.ROTATE_90_COUNTERCLOCKWISE,
+ cv2.ROTATE_180,
+ ]:
+ processed_image = cv2.rotate(processed_image, self.rotation)
+
+ return processed_image
+
+ def _read_loop(self):
+ """
+ Internal loop run by the background thread for asynchronous reading.
+
+ On each iteration:
+        1. Reads a color frame (and a depth frame if depth is enabled) with a 500 ms timeout
+        2. Stores the results and the acquisition time (thread-safe)
+        3. Sets new_frame_event to notify listeners
+
+        Stops on DeviceNotConnectedError, logs other errors and continues.
+        """
+        while not self.stop_event.is_set():
+            try:
+                color_image = self.read(timeout_ms=500)
+                # Only read depth when the stream is enabled; read_depth()
+                # raises otherwise and would flood the log with warnings.
+                depth_image = (
+                    self.read_depth(timeout_ms=500) if self.use_depth else None
+                )
+
+                with self.frame_lock:
+                    self.latest_frame = color_image
+                    self.latest_frame_depth = depth_image
+                    self.latest_frame_acquisition_time = time.time()
+                self.new_frame_event.set()
+
+ except DeviceNotConnectedError:
+ break
+ except Exception as e:
+ logger.warning(
+ f"Error reading frame in background thread for {self}: {e}"
+ )
+
+ def _start_read_thread(self) -> None:
+ """Starts or restarts the background read thread if it's not running."""
+        if self.thread is not None and self.thread.is_alive():
+            # Signal the existing thread to stop before waiting on it; joining
+            # without setting the event would simply time out.
+            if self.stop_event is not None:
+                self.stop_event.set()
+            self.thread.join(timeout=0.1)
+
+ self.stop_event = Event()
+ self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
+ self.thread.daemon = True
+ self.thread.start()
+
+ def _stop_read_thread(self):
+ """Signals the background read thread to stop and waits for it to join."""
+ if self.stop_event is not None:
+ self.stop_event.set()
+
+ if self.thread is not None and self.thread.is_alive():
+ self.thread.join(timeout=2.0)
+
+ self.thread = None
+ self.stop_event = None
+
+    def async_read(
+        self, timeout_ms: float = 200
+    ) -> tuple[np.ndarray, np.ndarray | None, float | None]:
+ """
+        Reads the latest available frame data (color and depth) asynchronously.
+
+        This method retrieves the most recent frames captured by the background
+        read thread. It does not block waiting for the camera hardware directly,
+        but may wait up to timeout_ms for the background thread to provide a frame.
+
+        Args:
+            timeout_ms (float): Maximum time in milliseconds to wait for a frame
+                to become available. Defaults to 200ms (0.2 seconds).
+
+        Returns:
+            tuple[np.ndarray, np.ndarray | None, float | None]:
+                The latest color image (processed according to configuration), the
+                matching depth map (None when depth is disabled), and the
+                acquisition timestamp from `time.time()`.
+
+ Raises:
+ DeviceNotConnectedError: If the camera is not connected.
+ TimeoutError: If no frame data becomes available within the specified timeout.
+ RuntimeError: If the background thread died unexpectedly or another error occurs.
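+
+        Example (a minimal sketch of the returned tuple):
+            ```python
+            color, depth, acquired_at = camera.async_read()
+            ```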
+ """
+ if not self.is_connected:
+ raise DeviceNotConnectedError(f"{self} is not connected.")
+
+ if self.thread is None or not self.thread.is_alive():
+ self._start_read_thread()
+
+ if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
+ thread_alive = self.thread is not None and self.thread.is_alive()
+ raise TimeoutError(
+ f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. "
+ f"Read thread alive: {thread_alive}."
+ )
+
+ with self.frame_lock:
+ frame = self.latest_frame
+ depth = self.latest_frame_depth
+            acquisition_time = self.latest_frame_acquisition_time
+ self.new_frame_event.clear()
+
+ if frame is None:
+ raise RuntimeError(
+ f"Internal error: Event set but no frame available for {self}."
+ )
+
+ return frame, depth, acquisition_time
+
+ def disconnect(self):
+ """
+ Disconnects from the camera, stops the pipeline, and cleans up resources.
+
+ Stops the background read thread (if running) and stops the RealSense pipeline.
+
+ Raises:
+ DeviceNotConnectedError: If the camera is already disconnected (pipeline not running).
+ """
+
+ if not self.is_connected and self.thread is None:
+ raise DeviceNotConnectedError(
+ f"Attempted to disconnect {self}, but it appears already disconnected."
+ )
+
+ if self.thread is not None:
+ self._stop_read_thread()
+
+ if self.rs_pipeline is not None:
+ self.rs_pipeline.stop()
+ self.rs_pipeline = None
+ self.rs_profile = None
+
+ logger.info(f"{self} disconnected.")
diff --git a/alignit/config.py b/alignit/config.py
index df3b71c..15b8ff2 100644
--- a/alignit/config.py
+++ b/alignit/config.py
@@ -2,6 +2,7 @@
from dataclasses import dataclass, field
from typing import Optional, List
+
import numpy as np
@@ -10,7 +11,7 @@ class DatasetConfig:
"""Configuration for dataset paths and loading."""
path: str = field(
- default="./data/duck", metadata={"help": "Path to the dataset directory"}
+ default="./data/default", metadata={"help": "Path to the dataset directory"}
)
@@ -46,6 +47,12 @@ class ModelConfig:
default="alignnet_model.pth",
metadata={"help": "Path to save/load trained model"},
)
+ use_depth_input: bool = field(
+ default=True, metadata={"help": "Whether to use depth input for the model"}
+ )
+ depth_hidden_dim: int = field(
+ default=128, metadata={"help": "Output dimension of depth CNN"}
+ )
@dataclass
@@ -98,6 +105,13 @@ class RecordConfig:
ang_tol_trajectory: float = field(
default=0.05, metadata={"help": "Angular tolerance for trajectory servo"}
)
+ manual_height: float = field(
+ default=-0.05, metadata={"help": "Height above surface for manual movement"}
+ )
+ world_z_offset: float = field(
+ default=-0.02,
+ metadata={"help": "World frame Z offset after manual positioning"},
+ )
@dataclass
@@ -106,7 +120,7 @@ class TrainConfig:
dataset: DatasetConfig = field(default_factory=DatasetConfig)
model: ModelConfig = field(default_factory=ModelConfig)
- batch_size: int = field(default=8, metadata={"help": "Training batch size"})
+ batch_size: int = field(default=4, metadata={"help": "Training batch size"})
learning_rate: float = field(
default=1e-4, metadata={"help": "Learning rate for optimizer"}
)
@@ -133,22 +147,31 @@ class InferConfig:
metadata={"help": "Starting pose RPY angles"},
)
lin_tolerance: float = field(
- default=2e-3, metadata={"help": "Linear tolerance for convergence (meters)"}
+ default=5e-3, metadata={"help": "Linear tolerance for convergence (meters)"}
)
ang_tolerance: float = field(
- default=2, metadata={"help": "Angular tolerance for convergence (degrees)"}
+ default=5, metadata={"help": "Angular tolerance for convergence (degrees)"}
)
max_iterations: Optional[int] = field(
- default=None,
+ default=20,
metadata={"help": "Maximum iterations before stopping (None = infinite)"},
)
debug_output: bool = field(
default=True, metadata={"help": "Print debug information during inference"}
)
debouncing_count: int = field(
- default=5,
+ default=20,
metadata={"help": "Number of iterations within tolerance before stopping"},
)
+ rotation_matrix_multiplier: int = field(
+ default=3,
+        metadata={
+            "help": "Power to which the rotation matrix of the relative action is raised to speed up convergence"
+        },
+ )
+ manual_height: float = field(
+ default=0.08, metadata={"help": "Height above surface for manual movement"}
+ )
@dataclass
diff --git a/alignit/infere.py b/alignit/infere.py
index eb5810a..e533162 100644
--- a/alignit/infere.py
+++ b/alignit/infere.py
@@ -1,11 +1,11 @@
-import transforms3d as t3d
-import numpy as np
import time
-import draccus
-from alignit.config import InferConfig
import torch
+import transforms3d as t3d
+import numpy as np
+import draccus
+from alignit.config import InferConfig
from alignit.models.alignnet import AlignNet
from alignit.utils.zhou import sixd_se3
from alignit.utils.tfs import print_pose, are_tfs_close
@@ -19,7 +19,6 @@ def main(cfg: InferConfig):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
- # load model from file
net = AlignNet(
backbone_name=cfg.model.backbone,
backbone_weights=cfg.model.backbone_weights,
@@ -28,6 +27,7 @@ def main(cfg: InferConfig):
vector_hidden_dim=cfg.model.vector_hidden_dim,
output_dim=cfg.model.output_dim,
feature_agg=cfg.model.feature_agg,
+ use_depth_input=cfg.model.use_depth_input,
)
net.load_state_dict(torch.load(cfg.model.path, map_location=device))
net.to(device)
@@ -35,42 +35,57 @@ def main(cfg: InferConfig):
robot = XarmSim()
- # Set initial pose from config
start_pose = t3d.affines.compose(
[0.23, 0, 0.25], t3d.euler.euler2mat(np.pi, 0, 0), [1, 1, 1]
)
robot.servo_to_pose(start_pose, lin_tol=1e-2)
-
iteration = 0
iterations_within_tolerance = 0
ang_tol_rad = np.deg2rad(cfg.ang_tolerance)
-
try:
while True:
observation = robot.get_observation()
- images = [observation["camera.rgb"].astype(np.float32) / 255.0]
-
- # Convert images to tensor and reshape from HWC to CHW format
- images_tensor = (
- torch.from_numpy(np.array(images))
- .permute(0, 3, 1, 2)
+ rgb_image = observation["rgb"].astype(np.float32) / 255.0
+ depth_image = observation["depth"].astype(np.float32)
+            if cfg.debug_output:
+                print(
+                    "Depth min/max/mean (raw):",
+                    observation["depth"].min(),
+                    observation["depth"].max(),
+                    observation["depth"].mean(),
+                )
+                print(
+                    "Depth min/max/mean (scaled):",
+                    depth_image.min(),
+                    depth_image.max(),
+                    depth_image.mean(),
+                )
+ rgb_image_tensor = (
+ torch.from_numpy(np.array(rgb_image))
+ .permute(2, 0, 1) # (H, W, C) -> (C, H, W)
.unsqueeze(0)
.to(device)
)
- if cfg.debug_output:
- print(f"Max pixel value: {torch.max(images_tensor)}")
+ depth_image_tensor = (
+ torch.from_numpy(np.array(depth_image))
+ .unsqueeze(0) # Add channel dimension: (1, H, W)
+ .unsqueeze(0) # Add batch dimension: (1, 1, H, W)
+ .to(device)
+ )
+ rgb_images_batch = rgb_image_tensor.unsqueeze(1)
+ depth_images_batch = depth_image_tensor.unsqueeze(1)
- start = time.time()
with torch.no_grad():
- relative_action = net(images_tensor)
+ relative_action = net(rgb_images_batch, depth_images=depth_images_batch)
relative_action = relative_action.squeeze(0).cpu().numpy()
relative_action = sixd_se3(relative_action)
if cfg.debug_output:
print_pose(relative_action)
- # Check convergence
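+            # Amplify the predicted rotation: raising a rotation matrix to the
+            # k-th power composes the same rotation k times (k times the angle
+            # about the same axis), which speeds up angular convergence.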
+ relative_action[:3, :3] = np.linalg.matrix_power(
+ relative_action[:3, :3], cfg.rotation_matrix_multiplier
+ )
if are_tfs_close(
relative_action, lin_tol=cfg.lin_tolerance, ang_tol=ang_tol_rad
):
@@ -78,10 +93,7 @@ def main(cfg: InferConfig):
else:
iterations_within_tolerance = 0
- if iterations_within_tolerance >= cfg.debouncing_count:
- print("Alignment achieved - stopping.")
- break
-
+            if cfg.debug_output:
+                print(relative_action)
target_pose = robot.pose() @ relative_action
iteration += 1
action = {
@@ -89,14 +101,26 @@ def main(cfg: InferConfig):
"gripper.pos": 1.0,
}
robot.send_action(action)
-
- # Check max iterations
- if cfg.max_iterations and iteration >= cfg.max_iterations:
+            if iterations_within_tolerance >= cfg.debouncing_count:
-                print(f"Reached maximum iterations ({cfg.max_iterations}) - stopping.")
+                print(f"Converged for {cfg.debouncing_count} iterations - stopping.")
+ print("Moving robot to final pose.")
+ current_pose = robot.pose()
+ gripper_z_offset = np.array(
+ [
+ [1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [0, 0, 1, cfg.manual_height],
+ [0, 0, 0, 1],
+ ]
+ )
+ offset_pose = current_pose @ gripper_z_offset
+ robot.servo_to_pose(pose=offset_pose)
+ robot.close_gripper()
+ robot.gripper_off()
+
break
time.sleep(10.0)
-
except KeyboardInterrupt:
print("\nExiting...")
diff --git a/alignit/models/alignnet.py b/alignit/models/alignnet.py
index 6da726f..8f652bf 100644
--- a/alignit/models/alignnet.py
+++ b/alignit/models/alignnet.py
@@ -10,8 +10,10 @@ def __init__(
backbone_name="efficientnet_b0",
backbone_weights="DEFAULT",
use_vector_input=True,
+ use_depth_input=True,
fc_layers=[256, 128],
vector_hidden_dim=64,
+ depth_hidden_dim=128,
output_dim=7,
feature_agg="mean",
):
@@ -23,27 +25,39 @@ def __init__(
:param vector_hidden_dim: output dim of the vector MLP
:param output_dim: final output vector size
:param feature_agg: 'mean' or 'max' across image views
+ :param use_depth_input: whether to accept depth input
+ :param depth_hidden_dim: output dim of the depth MLP
"""
super().__init__()
self.use_vector_input = use_vector_input
+ self.use_depth_input = use_depth_input
self.feature_agg = feature_agg
- # CNN backbone
self.backbone, self.image_feature_dim = self._build_backbone(
backbone_name, backbone_weights
)
- # Linear projection of image features
self.image_fc = nn.Sequential(
nn.Linear(self.image_feature_dim, fc_layers[0]), nn.ReLU()
)
+ if use_depth_input:
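+            # Lightweight depth encoder: two 3x3 convs followed by global
+            # average pooling give a 16-dim descriptor per view, which
+            # depth_fc projects to depth_hidden_dim.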
+ self.depth_cnn = nn.Sequential(
+ nn.Conv2d(1, 8, 3, padding=1),
+ nn.ReLU(),
+ nn.Conv2d(8, 16, 3, padding=1),
+ nn.ReLU(),
+ nn.AdaptiveAvgPool2d(1),
+ )
+ self.depth_fc = nn.Sequential(nn.Linear(16, depth_hidden_dim), nn.ReLU())
+ input_dim = fc_layers[0] + depth_hidden_dim
+ else:
+ input_dim = fc_layers[0]
+
# Optional vector input processing
if use_vector_input:
self.vector_fc = nn.Sequential(nn.Linear(1, vector_hidden_dim), nn.ReLU())
- input_dim = fc_layers[0] + vector_hidden_dim
- else:
- input_dim = fc_layers[0]
+ input_dim += vector_hidden_dim
# Fully connected layers
layers = []
@@ -81,10 +95,11 @@ def aggregate_image_features(self, feats):
else:
raise ValueError("Invalid aggregation type")
- def forward(self, rgb_images, vector_inputs=None):
+ def forward(self, rgb_images, vector_inputs=None, depth_images=None):
"""
:param rgb_images: Tensor of shape (B, N, 3, H, W)
:param vector_inputs: List of tensors of shape (L_i,) or None
+ :param depth_images: Tensor of shape (B, N, 1, H, W) or None
:return: Tensor of shape (B, output_dim)
"""
B, N, C, H, W = rgb_images.shape
@@ -93,6 +108,15 @@ def forward(self, rgb_images, vector_inputs=None):
image_feats = self.aggregate_image_features(feats)
image_feats = self.image_fc(image_feats)
+ features = [image_feats]
+
+ if self.use_depth_input and depth_images is not None:
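+            # Fold the view dimension into the batch, encode each depth map,
+            # then aggregate across views the same way as the RGB features.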
+ depth = depth_images.view(B * N, 1, H, W)
+ depth_feats = self.depth_cnn(depth).view(B, N, -1)
+ depth_feats = self.aggregate_image_features(depth_feats)
+ depth_feats = self.depth_fc(depth_feats)
+ features.append(depth_feats)
+
if self.use_vector_input and vector_inputs is not None:
vec_feats = []
for vec in vector_inputs:
@@ -100,8 +124,9 @@ def forward(self, rgb_images, vector_inputs=None):
pooled = self.vector_fc(vec).mean(dim=0) # (D,)
vec_feats.append(pooled)
vec_feats = torch.stack(vec_feats, dim=0)
- fused = torch.cat([image_feats, vec_feats], dim=1)
- else:
- fused = image_feats
+ features.append(vec_feats)
+
+ fused = torch.cat(features, dim=1)
return self.head(fused) # (B, output_dim)
diff --git a/alignit/record.py b/alignit/record.py
index 93fc7d5..00e897b 100644
--- a/alignit/record.py
+++ b/alignit/record.py
@@ -82,22 +82,17 @@ def main(cfg: RecordConfig):
"""Record alignment dataset using configuration parameters."""
robot = XarmSim()
features = Features(
- {"images": Sequence(Image()), "action": Sequence(Value("float32"))}
+ {
+ "images": Sequence(Image()),
+ "action": Sequence(Value("float32")),
+ "depth": Sequence(Image()),
+ }
)
for episode in range(cfg.episodes):
pose_start, pose_alignment_target = robot.reset()
-
- robot.servo_to_pose(pose_alignment_target, lin_tol=0.015, ang_tol=0.015)
-
- robot.servo_to_pose(
- pose_alignment_target,
- lin_tol=cfg.lin_tol_alignment,
- ang_tol=cfg.ang_tol_alignment,
- )
-
trajectory = generate_spiral_trajectory(pose_start, cfg.trajectory)
-
frames = []
for pose in trajectory:
robot.servo_to_pose(
@@ -109,28 +104,33 @@ def main(cfg: RecordConfig):
action_sixd = se3_sixd(action_pose)
observation = robot.get_observation()
frame = {
- "images": [observation["camera.rgb"].copy()],
+ "images": [observation["rgb"].copy()],
"action": action_sixd,
+ "depth": [observation["depth"].copy()],
}
frames.append(frame)
-
print(f"Episode {episode+1} completed with {len(frames)} frames.")
episode_dataset = Dataset.from_list(frames, features=features)
- if episode == 0:
- combined_dataset = episode_dataset
+
+        # Load the existing dataset if available and append the new episode
+ if os.path.exists(cfg.dataset.path):
+ existing_dataset = load_from_disk(cfg.dataset.path)
+ existing_dataset = existing_dataset.cast(features)
+ combined_dataset = concatenate_datasets([existing_dataset, episode_dataset])
else:
- previous_dataset = load_from_disk(cfg.dataset.path)
- previous_dataset = previous_dataset.cast(features)
- combined_dataset = concatenate_datasets([previous_dataset, episode_dataset])
- del previous_dataset
+ combined_dataset = episode_dataset
+        # Save to a temporary location first to avoid overwriting the dataset in use
temp_path = f"{cfg.dataset.path}_temp"
combined_dataset.save_to_disk(temp_path)
+
+        # Replace the old dataset only after the new one saved successfully
if os.path.exists(cfg.dataset.path):
- shutil.rmtree(cfg.dataset.path)
- shutil.move(temp_path, cfg.dataset.path)
+ shutil.rmtree(cfg.dataset.path) # Remove old version
+ shutil.move(temp_path, cfg.dataset.path) # Move new version into place
robot.disconnect()
diff --git a/alignit/robots/xarm.py b/alignit/robots/xarm.py
index 075c2cc..39c69fd 100644
--- a/alignit/robots/xarm.py
+++ b/alignit/robots/xarm.py
@@ -1,13 +1,14 @@
import time
+import numpy as np
+import transforms3d as t3d
-from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig
+from lerobot.cameras.realsense import RealSenseCameraConfig
+from alignit.cameras.realsense import RealSenseCamera
from lerobot_xarm.xarm import Xarm as LeXarm
from lerobot_xarm.config import XarmConfig
-import numpy as np
-import transforms3d as t3d
from alignit.robots.robot import Robot
from alignit.utils.tfs import are_tfs_close
+from alignit.config import RecordConfig
class Xarm(Robot):
@@ -32,11 +33,17 @@ def _connect(self):
def send_action(self, action):
self.robot.send_action(action)
+ def get_intrinsics(self):
+ return self.camera.get_intrinsics()
+
def get_observation(self):
- rgb_image = self.camera.read()
+ rgb_image, depth_image, acquisition_time = self.camera.async_read()
+        # Clip raw millimeter depth to 1 m and convert to meters in [0, 1].
+        depth_image = np.clip(np.asarray(depth_image), a_min=0, a_max=1000) / 1000.0
return {
- "camera.rgb": rgb_image,
+ "rgb": rgb_image,
+ "depth": depth_image,
}
def disconnect(self):
@@ -46,12 +53,33 @@ def servo_to_pose(self, pose, lin_tol=1e-3, ang_tol=1e-2):
while not are_tfs_close(self.pose(), pose, lin_tol, ang_tol):
action = {
"pose": pose,
- "gripper.pos": 1.0, # Optional: set gripper state (0.0=closed, 1.0=open)
+ "gripper.pos": 1.0,
}
self.send_action(action)
- time.sleep(1.0 / 60.0) # Adjust frequency as needed
+ time.sleep(1.0 / 60.0)
+
+ def close_gripper(self):
+ action = {
+ "pose": self.pose(),
+ "gripper.pos": 0.0,
+ }
+ self.send_action(action)
- def reset(self):
+ def open_gripper(self):
+ action = {
+ "pose": self.pose(),
+ "gripper.pos": 1.0,
+ }
+ self.send_action(action)
+
+ def gripper_off(self):
+ action = {
+ "pose": self.pose(),
+ "gripper.pos": 0.5,
+ }
+ self.send_action(action)
+
+ def reset(self, cfg: RecordConfig):
"""
Reset routine:
1. Allows manual movement of the arm
@@ -64,30 +92,26 @@ def reset(self):
manual_height: Height above surface to maintain during manual movement (meters)
world_z_offset: Additional Z offset in world frame after manual positioning (meters)
"""
- manual_height = -0.05
- world_z_offset = -0.02
self.robot.disconnect()
input("Press Enter after positioning the arm...")
self.robot.connect()
current_pose = self.pose()
gripper_z_offset = np.array(
- [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, manual_height], [0, 0, 0, 1]]
+ [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, cfg.manual_height], [0, 0, 0, 1]]
)
offset_pose = current_pose @ gripper_z_offset
self.servo_to_pose(pose=offset_pose)
world_z_offset_mat = np.array(
- [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, world_z_offset], [0, 0, 0, 1]]
+ [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, cfg.world_z_offset], [0, 0, 0, 1]]
)
final_pose = offset_pose @ world_z_offset_mat
self.servo_to_pose(pose=final_pose)
-
+ current_pose = self.pose()
pose_start = current_pose @ t3d.affines.compose(
- [0, 0, -0.090], t3d.euler.euler2mat(0, 0, 0), [1, 1, 1]
- )
- pose_alignment_target = current_pose @ t3d.affines.compose(
- [0, 0, -0.1], t3d.euler.euler2mat(0, 0, 0), [1, 1, 1]
+ [0, 0, -0.01], t3d.euler.euler2mat(0, 0, 0), [1, 1, 1]
)
+ pose_alignment_target = current_pose
_, (position, _, _) = self.robot._arm.get_joint_states()
for i in range(6):
diff --git a/alignit/robots/xarmsim/__init__.py b/alignit/robots/xarmsim/__init__.py
index c64e631..fbad295 100644
--- a/alignit/robots/xarmsim/__init__.py
+++ b/alignit/robots/xarmsim/__init__.py
@@ -125,10 +125,10 @@ def _set_object_pose(self, object_name, pose_matrix):
self.data.qvel[qvel_adr : qvel_adr + 6] = 0
mj.mj_forward(self.model, self.data)
- def _gripper_close(self):
+ def close_gripper(self):
self._set_gripper_position(self.gripper_close_pos)
- def _gripper_open(self):
+ def open_gripper(self):
self._set_gripper_position(self.gripper_open_pos)
def _set_gripper_position(self, pos, tolerance=1e-3, max_sim_steps=2000):
@@ -191,7 +191,14 @@ def get_observation(self):
name = mujoco.mj_id2name(self.model, mujoco.mjtObj.mjOBJ_CAMERA, i)
self.renderer.update_scene(self.data, camera=name)
image = self.renderer.render()
- obs["camera." + name] = image[:, :, ::-1]
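+            # Render a matching depth image; with depth rendering enabled the
+            # MuJoCo Renderer returns a float32 depth buffer in meters.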
+ self.renderer.enable_depth_rendering()
+ self.renderer.update_scene(self.data, camera=name)
+ image_depth = self.renderer.render()
+ self.renderer.disable_depth_rendering()
+
+ # TODO: Handle multiple cameras
+ obs["rgb"] = image[:, :, ::-1]
+ obs["depth"] = image_depth
return obs
diff --git a/alignit/robots/xarmsim/ufactory_lite6/lite6_gripper_narrow.xml b/alignit/robots/xarmsim/ufactory_lite6/lite6_gripper_narrow.xml
index 2281230..0e37df0 100644
--- a/alignit/robots/xarmsim/ufactory_lite6/lite6_gripper_narrow.xml
+++ b/alignit/robots/xarmsim/ufactory_lite6/lite6_gripper_narrow.xml
@@ -104,7 +104,7 @@
diaginertia="0.00016117 0.000118 0.00014455" />