diff --git a/alignit/cameras/realsense.py b/alignit/cameras/realsense.py
new file mode 100644
index 0000000..8039384
--- /dev/null
+++ b/alignit/cameras/realsense.py
@@ -0,0 +1,583 @@
+# Adapted from https://github.com/huggingface/lerobot/tree/main
+# Changed: async_read now returns both RGB and depth images
+
+import logging
+import time
+from threading import Event, Lock, Thread
+from typing import Any
+
+import cv2
+import numpy as np
+
+try:
+    import pyrealsense2 as rs
+except Exception as e:
+    logging.info(f"Could not import realsense: {e}")
+
+from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
+
+from lerobot.cameras.camera import Camera
+from lerobot.configs import ColorMode
+from lerobot.utils import get_cv2_rotation
+from lerobot.cameras.realsense import RealSenseCameraConfig
+
+logger = logging.getLogger(__name__)
+
+
+class RealSenseCamera(Camera):
+    """
+    Manages interactions with Intel RealSense cameras for frame and depth recording.
+
+    This class provides an interface similar to `OpenCVCamera` but tailored for
+    RealSense devices, leveraging the `pyrealsense2` library. It uses the camera's
+    unique serial number for identification, offering more stability than device
+    indices, especially on Linux. It also supports capturing depth maps alongside
+    color frames.
+
+    Use the provided utility script to find available cameras and their default profiles:
+    ```bash
+    python -m lerobot.find_cameras realsense
+    ```
+
+    A `RealSenseCamera` instance requires a configuration object specifying the
+    camera's serial number or a unique device name. If using the name, ensure only
+    one camera with that name is connected.
+
+    The camera's default settings (FPS, resolution, color mode) from the stream
+    profile are used unless overridden in the configuration.
+
+    Example:
+    ```python
+    from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig
+    from lerobot.cameras import ColorMode, Cv2Rotation
+
+    # Basic usage with serial number
+    config = RealSenseCameraConfig(serial_number_or_name="0123456789")  # Replace with actual SN
+    camera = RealSenseCamera(config)
+    camera.connect()
+
+    # Read 1 frame synchronously
+    color_image = camera.read()
+    print(color_image.shape)
+
+    # Read 1 frame asynchronously (returns color image, depth map, acquisition time)
+    async_image, async_depth, acquired_at = camera.async_read()
+
+    # When done, properly disconnect the camera
+    camera.disconnect()
+
+    # Example with depth capture and custom settings
+    custom_config = RealSenseCameraConfig(
+        serial_number_or_name="0123456789",  # Replace with actual SN
+        fps=30,
+        width=1280,
+        height=720,
+        color_mode=ColorMode.BGR,  # Request BGR output
+        rotation=Cv2Rotation.NO_ROTATION,
+        use_depth=True
+    )
+    depth_camera = RealSenseCamera(custom_config)
+    depth_camera.connect()
+
+    # Read 1 depth frame
+    depth_map = depth_camera.read_depth()
+
+    # Example using a unique camera name
+    name_config = RealSenseCameraConfig(serial_number_or_name="Intel RealSense D435")  # If unique
+    name_camera = RealSenseCamera(name_config)
+    # ... connect, read, disconnect ...
+    ```
+    """
+
+    def __init__(self, config: RealSenseCameraConfig):
+        """
+        Initializes the RealSenseCamera instance.
+
+        Args:
+            config: The configuration settings for the camera.
+ """ + + super().__init__(config) + + self.config = config + + if config.serial_number_or_name.isdigit(): + self.serial_number = config.serial_number_or_name + else: + self.serial_number = self._find_serial_number_from_name( + config.serial_number_or_name + ) + + self.fps = config.fps + self.color_mode = config.color_mode + self.use_depth = config.use_depth + self.warmup_s = config.warmup_s + self.latest_frame_acquisition_time = None + self.rs_pipeline: rs.pipeline | None = None + self.rs_profile: rs.pipeline_profile | None = None + self.align = rs.align(rs.stream.color) + + self.thread: Thread | None = None + self.stop_event: Event | None = None + self.frame_lock: Lock = Lock() + self.latest_frame: np.ndarray | None = None + self.latest_frame_depth: np.ndarray | None = None + self.new_frame_event: Event = Event() + + self.rotation: int | None = get_cv2_rotation(config.rotation) + + if self.height and self.width: + self.capture_width, self.capture_height = self.width, self.height + if self.rotation in [ + cv2.ROTATE_90_CLOCKWISE, + cv2.ROTATE_90_COUNTERCLOCKWISE, + ]: + self.capture_width, self.capture_height = self.height, self.width + + def __str__(self) -> str: + return f"{self.__class__.__name__}({self.serial_number})" + + @property + def is_connected(self) -> bool: + """Checks if the camera pipeline is started and streams are active.""" + return self.rs_pipeline is not None and self.rs_profile is not None + + def connect(self, warmup: bool = True): + """ + Connects to the RealSense camera specified in the configuration. + + Initializes the RealSense pipeline, configures the required streams (color + and optionally depth), starts the pipeline, and validates the actual stream settings. + + Raises: + DeviceAlreadyConnectedError: If the camera is already connected. + ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique). + ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all. + RuntimeError: If the pipeline starts but fails to apply requested settings. + """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} is already connected.") + + self.rs_pipeline = rs.pipeline() + rs_config = rs.config() + self._configure_rs_pipeline_config(rs_config) + + try: + self.rs_profile = self.rs_pipeline.start(rs_config) + except RuntimeError as e: + self.rs_profile = None + self.rs_pipeline = None + raise ConnectionError( + f"Failed to open {self}." + "Run `python -m lerobot.find_cameras realsense` to find available cameras." + ) from e + + self._configure_capture_settings() + + if warmup: + time.sleep( + 1 + ) # NOTE(Steven): RS cameras need a bit of time to warm up before the first read. If we don't wait, the first read from the warmup will raise. + start_time = time.time() + while time.time() - start_time < self.warmup_s: + self.read() + time.sleep(0.1) + + logger.info(f"{self} connected.") + + @staticmethod + def find_cameras() -> list[dict[str, Any]]: + """ + Detects available Intel RealSense cameras connected to the system. + + Returns: + List[Dict[str, Any]]: A list of dictionaries, + where each dictionary contains 'type', 'id' (serial number), 'name', + firmware version, USB type, and other available specs, and the default profile properties (width, height, fps, format). + + Raises: + OSError: If pyrealsense2 is not installed. + ImportError: If pyrealsense2 is not installed. 
+ """ + found_cameras_info = [] + context = rs.context() + devices = context.query_devices() + + for device in devices: + camera_info = { + "name": device.get_info(rs.camera_info.name), + "type": "RealSense", + "id": device.get_info(rs.camera_info.serial_number), + "firmware_version": device.get_info(rs.camera_info.firmware_version), + "usb_type_descriptor": device.get_info( + rs.camera_info.usb_type_descriptor + ), + "physical_port": device.get_info(rs.camera_info.physical_port), + "product_id": device.get_info(rs.camera_info.product_id), + "product_line": device.get_info(rs.camera_info.product_line), + } + + # Get stream profiles for each sensor + sensors = device.query_sensors() + for sensor in sensors: + profiles = sensor.get_stream_profiles() + + for profile in profiles: + if profile.is_video_stream_profile() and profile.is_default(): + vprofile = profile.as_video_stream_profile() + stream_info = { + "stream_type": vprofile.stream_name(), + "format": vprofile.format().name, + "width": vprofile.width(), + "height": vprofile.height(), + "fps": vprofile.fps(), + } + camera_info["default_stream_profile"] = stream_info + + found_cameras_info.append(camera_info) + + return found_cameras_info + + def _find_serial_number_from_name(self, name: str) -> str: + """Finds the serial number for a given unique camera name.""" + camera_infos = self.find_cameras() + found_devices = [cam for cam in camera_infos if str(cam["name"]) == name] + + if not found_devices: + available_names = [cam["name"] for cam in camera_infos] + raise ValueError( + f"No RealSense camera found with name '{name}'. Available camera names: {available_names}" + ) + + if len(found_devices) > 1: + serial_numbers = [dev["serial_number"] for dev in found_devices] + raise ValueError( + f"Multiple RealSense cameras found with name '{name}'. " + f"Please use a unique serial number instead. Found SNs: {serial_numbers}" + ) + + serial_number = str(found_devices[0]["serial_number"]) + return serial_number + + def _configure_rs_pipeline_config(self, rs_config): + """Creates and configures the RealSense pipeline configuration object.""" + rs.config.enable_device(rs_config, self.serial_number) + + if self.width and self.height and self.fps: + rs_config.enable_stream( + rs.stream.color, + self.capture_width, + self.capture_height, + rs.format.rgb8, + self.fps, + ) + if self.use_depth: + rs_config.enable_stream( + rs.stream.depth, + self.capture_width, + self.capture_height, + rs.format.z16, + self.fps, + ) + else: + rs_config.enable_stream(rs.stream.color) + if self.use_depth: + rs_config.enable_stream(rs.stream.depth) + + def _configure_capture_settings(self) -> None: + """Sets fps, width, and height from device stream if not already configured. + + Uses the color stream profile to update unset attributes. Handles rotation by + swapping width/height when needed. Original capture dimensions are always stored. + + Raises: + DeviceNotConnectedError: If device is not connected. + """ + if not self.is_connected: + raise DeviceNotConnectedError( + f"Cannot validate settings for {self} as it is not connected." 
+ ) + + stream = self.rs_profile.get_stream(rs.stream.color).as_video_stream_profile() + + if self.fps is None: + self.fps = stream.fps() + + if self.width is None or self.height is None: + actual_width = int(round(stream.width())) + actual_height = int(round(stream.height())) + if self.rotation in [ + cv2.ROTATE_90_CLOCKWISE, + cv2.ROTATE_90_COUNTERCLOCKWISE, + ]: + self.width, self.height = actual_height, actual_width + self.capture_width, self.capture_height = actual_width, actual_height + else: + self.width, self.height = actual_width, actual_height + self.capture_width, self.capture_height = actual_width, actual_height + + def read_depth(self, timeout_ms: int = 200) -> np.ndarray: + """ + Reads a single frame (depth) synchronously from the camera. + + This is a blocking call. It waits for a coherent set of frames (depth) + from the camera hardware via the RealSense pipeline. + + Args: + timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms. + + Returns: + np.ndarray: The depth map as a NumPy array (height, width) + of type `np.uint16` (raw depth values in millimeters) and rotation. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If reading frames from the pipeline fails or frames are invalid. + """ + + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + if not self.use_depth: + raise RuntimeError( + f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}." + ) + + start_time = time.perf_counter() + + ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms) + + if not ret or frame is None: + raise RuntimeError(f"{self} read_depth failed (status={ret}).") + + depth_frame = frame.get_depth_frame() + depth_map = np.asanyarray(depth_frame.get_data()) + + depth_map_processed = self._postprocess_image(depth_map, depth_frame=True) + + read_duration_ms = (time.perf_counter() - start_time) * 1e3 + logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") + + return depth_map_processed + + def read( + self, color_mode: ColorMode | None = None, timeout_ms: int = 200 + ) -> np.ndarray: + """ + Reads a single frame (color) synchronously from the camera. + + This is a blocking call. It waits for a coherent set of frames (color) + from the camera hardware via the RealSense pipeline. + + Args: + timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms. + + Returns: + np.ndarray: The captured color frame as a NumPy array + (height, width, channels), processed according to `color_mode` and rotation. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If reading frames from the pipeline fails or frames are invalid. + ValueError: If an invalid `color_mode` is requested. 
+ """ + + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + start_time = time.perf_counter() + + ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms) + + if not ret or frame is None: + raise RuntimeError(f"{self} read failed (status={ret}).") + + color_frame = frame.get_color_frame() + color_image_raw = np.asanyarray(color_frame.get_data()) + + color_image_processed = self._postprocess_image(color_image_raw, color_mode) + + read_duration_ms = (time.perf_counter() - start_time) * 1e3 + logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") + + return color_image_processed + + def _postprocess_image( + self, + image: np.ndarray, + color_mode: ColorMode | None = None, + depth_frame: bool = False, + ) -> np.ndarray: + """ + Applies color conversion, dimension validation, and rotation to a raw color frame. + + Args: + image (np.ndarray): The raw image frame (expected RGB format from RealSense). + color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None, + uses the instance's default `self.color_mode`. + + Returns: + np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`. + + Raises: + ValueError: If the requested `color_mode` is invalid. + RuntimeError: If the raw frame dimensions do not match the configured + `width` and `height`. + """ + + if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR): + raise ValueError( + f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." + ) + + if depth_frame: + h, w = image.shape + else: + h, w, c = image.shape + + if c != 3: + raise RuntimeError( + f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR)." + ) + + if h != self.capture_height or w != self.capture_width: + raise RuntimeError( + f"{self} frame width={w} or height={h} do not match configured width={self.capture_width} or height={self.capture_height}." + ) + + processed_image = image + if self.color_mode == ColorMode.BGR: + processed_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + + if self.rotation in [ + cv2.ROTATE_90_CLOCKWISE, + cv2.ROTATE_90_COUNTERCLOCKWISE, + cv2.ROTATE_180, + ]: + processed_image = cv2.rotate(processed_image, self.rotation) + + return processed_image + + def _read_loop(self): + """ + Internal loop run by the background thread for asynchronous reading. + + On each iteration: + 1. Reads a color frame with 500ms timeout + 2. Stores result in latest_frame (thread-safe) + 3. Sets new_frame_event to notify listeners + + Stops on DeviceNotConnectedError, logs other errors and continues. 
+ """ + while not self.stop_event.is_set(): + try: + color_image = self.read(timeout_ms=500) + depth_image = self.read_depth(timeout_ms=500) + + with self.frame_lock: + self.latest_frame = color_image, depth_image + self.new_frame_event.set() + + except DeviceNotConnectedError: + break + except Exception as e: + logger.warning( + f"Error reading frame in background thread for {self}: {e}" + ) + + def _start_read_thread(self) -> None: + """Starts or restarts the background read thread if it's not running.""" + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=0.1) + if self.stop_event is not None: + self.stop_event.set() + + self.stop_event = Event() + self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop") + self.thread.daemon = True + self.thread.start() + + def _stop_read_thread(self): + """Signals the background read thread to stop and waits for it to join.""" + if self.stop_event is not None: + self.stop_event.set() + + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=2.0) + + self.thread = None + self.stop_event = None + + # NOTE(Steven): Missing implementation for depth for now + def async_read(self, timeout_ms: float = 200) -> np.ndarray: + """ + Reads the latest available frame data (color) asynchronously. + + This method retrieves the most recent color frame captured by the background + read thread. It does not block waiting for the camera hardware directly, + but may wait up to timeout_ms for the background thread to provide a frame. + + Args: + timeout_ms (float): Maximum time in milliseconds to wait for a frame + to become available. Defaults to 200ms (0.2 seconds). + + Returns: + np.ndarray: + The latest captured frame data (color image), processed according to configuration. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + TimeoutError: If no frame data becomes available within the specified timeout. + RuntimeError: If the background thread died unexpectedly or another error occurs. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + self._start_read_thread() + + if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): + thread_alive = self.thread is not None and self.thread.is_alive() + raise TimeoutError( + f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. " + f"Read thread alive: {thread_alive}." + ) + + with self.frame_lock: + frame = self.latest_frame + depth = self.latest_frame_depth + acquisition_time = ( + self.latest_frame_acquisition_time + ) # Retrieve the timestamp + self.new_frame_event.clear() + + if frame is None: + raise RuntimeError( + f"Internal error: Event set but no frame available for {self}." + ) + + return frame, depth, acquisition_time + + def disconnect(self): + """ + Disconnects from the camera, stops the pipeline, and cleans up resources. + + Stops the background read thread (if running) and stops the RealSense pipeline. + + Raises: + DeviceNotConnectedError: If the camera is already disconnected (pipeline not running). + """ + + if not self.is_connected and self.thread is None: + raise DeviceNotConnectedError( + f"Attempted to disconnect {self}, but it appears already disconnected." 
+ ) + + if self.thread is not None: + self._stop_read_thread() + + if self.rs_pipeline is not None: + self.rs_pipeline.stop() + self.rs_pipeline = None + self.rs_profile = None + + logger.info(f"{self} disconnected.") diff --git a/alignit/config.py b/alignit/config.py index df3b71c..15b8ff2 100644 --- a/alignit/config.py +++ b/alignit/config.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, field from typing import Optional, List + import numpy as np @@ -10,7 +11,7 @@ class DatasetConfig: """Configuration for dataset paths and loading.""" path: str = field( - default="./data/duck", metadata={"help": "Path to the dataset directory"} + default="./data/default", metadata={"help": "Path to the dataset directory"} ) @@ -46,6 +47,12 @@ class ModelConfig: default="alignnet_model.pth", metadata={"help": "Path to save/load trained model"}, ) + use_depth_input: bool = field( + default=True, metadata={"help": "Whether to use depth input for the model"} + ) + depth_hidden_dim: int = field( + default=128, metadata={"help": "Output dimension of depth CNN"} + ) @dataclass @@ -98,6 +105,13 @@ class RecordConfig: ang_tol_trajectory: float = field( default=0.05, metadata={"help": "Angular tolerance for trajectory servo"} ) + manual_height: float = field( + default=-0.05, metadata={"help": "Height above surface for manual movement"} + ) + world_z_offset: float = field( + default=-0.02, + metadata={"help": "World frame Z offset after manual positioning"}, + ) @dataclass @@ -106,7 +120,7 @@ class TrainConfig: dataset: DatasetConfig = field(default_factory=DatasetConfig) model: ModelConfig = field(default_factory=ModelConfig) - batch_size: int = field(default=8, metadata={"help": "Training batch size"}) + batch_size: int = field(default=4, metadata={"help": "Training batch size"}) learning_rate: float = field( default=1e-4, metadata={"help": "Learning rate for optimizer"} ) @@ -133,22 +147,31 @@ class InferConfig: metadata={"help": "Starting pose RPY angles"}, ) lin_tolerance: float = field( - default=2e-3, metadata={"help": "Linear tolerance for convergence (meters)"} + default=5e-3, metadata={"help": "Linear tolerance for convergence (meters)"} ) ang_tolerance: float = field( - default=2, metadata={"help": "Angular tolerance for convergence (degrees)"} + default=5, metadata={"help": "Angular tolerance for convergence (degrees)"} ) max_iterations: Optional[int] = field( - default=None, + default=20, metadata={"help": "Maximum iterations before stopping (None = infinite)"}, ) debug_output: bool = field( default=True, metadata={"help": "Print debug information during inference"} ) debouncing_count: int = field( - default=5, + default=20, metadata={"help": "Number of iterations within tolerance before stopping"}, ) + rotation_matrix_multiplier: int = field( + default=3, + metadata={ + "help": "Number of times to multiply the rotation matrix of relative action in order to speed up convergence" + }, + ) + manual_height: float = field( + default=0.08, metadata={"help": "Height above surface for manual movement"} + ) @dataclass diff --git a/alignit/infere.py b/alignit/infere.py index eb5810a..e533162 100644 --- a/alignit/infere.py +++ b/alignit/infere.py @@ -1,11 +1,11 @@ -import transforms3d as t3d -import numpy as np import time -import draccus -from alignit.config import InferConfig import torch +import transforms3d as t3d +import numpy as np +import draccus +from alignit.config import InferConfig from alignit.models.alignnet import AlignNet from alignit.utils.zhou import sixd_se3 from alignit.utils.tfs import 
print_pose, are_tfs_close @@ -19,7 +19,6 @@ def main(cfg: InferConfig): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") - # load model from file net = AlignNet( backbone_name=cfg.model.backbone, backbone_weights=cfg.model.backbone_weights, @@ -28,6 +27,7 @@ def main(cfg: InferConfig): vector_hidden_dim=cfg.model.vector_hidden_dim, output_dim=cfg.model.output_dim, feature_agg=cfg.model.feature_agg, + use_depth_input=cfg.model.use_depth_input, ) net.load_state_dict(torch.load(cfg.model.path, map_location=device)) net.to(device) @@ -35,42 +35,57 @@ def main(cfg: InferConfig): robot = XarmSim() - # Set initial pose from config start_pose = t3d.affines.compose( [0.23, 0, 0.25], t3d.euler.euler2mat(np.pi, 0, 0), [1, 1, 1] ) robot.servo_to_pose(start_pose, lin_tol=1e-2) - iteration = 0 iterations_within_tolerance = 0 ang_tol_rad = np.deg2rad(cfg.ang_tolerance) - try: while True: observation = robot.get_observation() - images = [observation["camera.rgb"].astype(np.float32) / 255.0] - - # Convert images to tensor and reshape from HWC to CHW format - images_tensor = ( - torch.from_numpy(np.array(images)) - .permute(0, 3, 1, 2) + rgb_image = observation["rgb"].astype(np.float32) / 255.0 + depth_image = observation["depth"].astype(np.float32) + print( + "Min/Max depth,mean (raw):", + observation["depth"].min(), + observation["depth"].max(), + observation["depth"].mean(), + ) + print( + "Min/Max depth,mean (scaled):", + depth_image.min(), + depth_image.max(), + depth_image.mean(), + ) + rgb_image_tensor = ( + torch.from_numpy(np.array(rgb_image)) + .permute(2, 0, 1) # (H, W, C) -> (C, H, W) .unsqueeze(0) .to(device) ) - if cfg.debug_output: - print(f"Max pixel value: {torch.max(images_tensor)}") + depth_image_tensor = ( + torch.from_numpy(np.array(depth_image)) + .unsqueeze(0) # Add channel dimension: (1, H, W) + .unsqueeze(0) # Add batch dimension: (1, 1, H, W) + .to(device) + ) + rgb_images_batch = rgb_image_tensor.unsqueeze(1) + depth_images_batch = depth_image_tensor.unsqueeze(1) - start = time.time() with torch.no_grad(): - relative_action = net(images_tensor) + relative_action = net(rgb_images_batch, depth_images=depth_images_batch) relative_action = relative_action.squeeze(0).cpu().numpy() relative_action = sixd_se3(relative_action) if cfg.debug_output: print_pose(relative_action) - # Check convergence + relative_action[:3, :3] = np.linalg.matrix_power( + relative_action[:3, :3], cfg.rotation_matrix_multiplier + ) if are_tfs_close( relative_action, lin_tol=cfg.lin_tolerance, ang_tol=ang_tol_rad ): @@ -78,10 +93,7 @@ def main(cfg: InferConfig): else: iterations_within_tolerance = 0 - if iterations_within_tolerance >= cfg.debouncing_count: - print("Alignment achieved - stopping.") - break - + print(relative_action) target_pose = robot.pose() @ relative_action iteration += 1 action = { @@ -89,14 +101,26 @@ def main(cfg: InferConfig): "gripper.pos": 1.0, } robot.send_action(action) - - # Check max iterations - if cfg.max_iterations and iteration >= cfg.max_iterations: + if iterations_within_tolerance >= cfg.max_iterations: print(f"Reached maximum iterations ({cfg.max_iterations}) - stopping.") + print("Moving robot to final pose.") + current_pose = robot.pose() + gripper_z_offset = np.array( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, cfg.manual_height], + [0, 0, 0, 1], + ] + ) + offset_pose = current_pose @ gripper_z_offset + robot.servo_to_pose(pose=offset_pose) + robot.close_gripper() + robot.gripper_off() + break time.sleep(10.0) - 
except KeyboardInterrupt: print("\nExiting...") diff --git a/alignit/models/alignnet.py b/alignit/models/alignnet.py index 6da726f..8f652bf 100644 --- a/alignit/models/alignnet.py +++ b/alignit/models/alignnet.py @@ -10,8 +10,10 @@ def __init__( backbone_name="efficientnet_b0", backbone_weights="DEFAULT", use_vector_input=True, + use_depth_input=True, fc_layers=[256, 128], vector_hidden_dim=64, + depth_hidden_dim=128, output_dim=7, feature_agg="mean", ): @@ -23,27 +25,39 @@ def __init__( :param vector_hidden_dim: output dim of the vector MLP :param output_dim: final output vector size :param feature_agg: 'mean' or 'max' across image views + :param use_depth_input: whether to accept depth input + :param depth_hidden_dim: output dim of the depth MLP """ super().__init__() self.use_vector_input = use_vector_input + self.use_depth_input = use_depth_input self.feature_agg = feature_agg - # CNN backbone self.backbone, self.image_feature_dim = self._build_backbone( backbone_name, backbone_weights ) - # Linear projection of image features self.image_fc = nn.Sequential( nn.Linear(self.image_feature_dim, fc_layers[0]), nn.ReLU() ) + if use_depth_input: + self.depth_cnn = nn.Sequential( + nn.Conv2d(1, 8, 3, padding=1), + nn.ReLU(), + nn.Conv2d(8, 16, 3, padding=1), + nn.ReLU(), + nn.AdaptiveAvgPool2d(1), + ) + self.depth_fc = nn.Sequential(nn.Linear(16, depth_hidden_dim), nn.ReLU()) + input_dim = fc_layers[0] + depth_hidden_dim + else: + input_dim = fc_layers[0] + # Optional vector input processing if use_vector_input: self.vector_fc = nn.Sequential(nn.Linear(1, vector_hidden_dim), nn.ReLU()) - input_dim = fc_layers[0] + vector_hidden_dim - else: - input_dim = fc_layers[0] + input_dim += vector_hidden_dim # Fully connected layers layers = [] @@ -81,10 +95,11 @@ def aggregate_image_features(self, feats): else: raise ValueError("Invalid aggregation type") - def forward(self, rgb_images, vector_inputs=None): + def forward(self, rgb_images, vector_inputs=None, depth_images=None): """ :param rgb_images: Tensor of shape (B, N, 3, H, W) :param vector_inputs: List of tensors of shape (L_i,) or None + :param depth_images: Tensor of shape (B, N, 1, H, W) or None :return: Tensor of shape (B, output_dim) """ B, N, C, H, W = rgb_images.shape @@ -93,6 +108,15 @@ def forward(self, rgb_images, vector_inputs=None): image_feats = self.aggregate_image_features(feats) image_feats = self.image_fc(image_feats) + features = [image_feats] + + if self.use_depth_input and depth_images is not None: + depth = depth_images.view(B * N, 1, H, W) + depth_feats = self.depth_cnn(depth).view(B, N, -1) + depth_feats = self.aggregate_image_features(depth_feats) + depth_feats = self.depth_fc(depth_feats) + features.append(depth_feats) + if self.use_vector_input and vector_inputs is not None: vec_feats = [] for vec in vector_inputs: @@ -100,8 +124,9 @@ def forward(self, rgb_images, vector_inputs=None): pooled = self.vector_fc(vec).mean(dim=0) # (D,) vec_feats.append(pooled) vec_feats = torch.stack(vec_feats, dim=0) - fused = torch.cat([image_feats, vec_feats], dim=1) - else: - fused = image_feats + features.append(vec_feats) + + fused = torch.cat(features, dim=1) + print("Fused shape:", fused.shape) return self.head(fused) # (B, output_dim) diff --git a/alignit/record.py b/alignit/record.py index 93fc7d5..00e897b 100644 --- a/alignit/record.py +++ b/alignit/record.py @@ -82,22 +82,17 @@ def main(cfg: RecordConfig): """Record alignment dataset using configuration parameters.""" robot = XarmSim() features = Features( - {"images": 
Sequence(Image()), "action": Sequence(Value("float32"))} + { + "images": Sequence(Image()), + "action": Sequence(Value("float32")), + "depth": Sequence(Image()), + } ) for episode in range(cfg.episodes): pose_start, pose_alignment_target = robot.reset() - - robot.servo_to_pose(pose_alignment_target, lin_tol=0.015, ang_tol=0.015) - - robot.servo_to_pose( - pose_alignment_target, - lin_tol=cfg.lin_tol_alignment, - ang_tol=cfg.ang_tol_alignment, - ) - trajectory = generate_spiral_trajectory(pose_start, cfg.trajectory) - + pose = robot.pose() frames = [] for pose in trajectory: robot.servo_to_pose( @@ -109,28 +104,33 @@ def main(cfg: RecordConfig): action_sixd = se3_sixd(action_pose) observation = robot.get_observation() + print(observation.keys()) frame = { - "images": [observation["camera.rgb"].copy()], + "images": [observation["rgb"].copy()], "action": action_sixd, + "depth": [observation["depth"].copy()], } frames.append(frame) - print(f"Episode {episode+1} completed with {len(frames)} frames.") episode_dataset = Dataset.from_list(frames, features=features) - if episode == 0: - combined_dataset = episode_dataset + + # 2. Load existing dataset if available + if os.path.exists(cfg.dataset.path): + existing_dataset = load_from_disk(cfg.dataset.path) + existing_dataset = existing_dataset.cast(features) + combined_dataset = concatenate_datasets([existing_dataset, episode_dataset]) else: - previous_dataset = load_from_disk(cfg.dataset.path) - previous_dataset = previous_dataset.cast(features) - combined_dataset = concatenate_datasets([previous_dataset, episode_dataset]) - del previous_dataset + combined_dataset = episode_dataset + # 3. Save to TEMPORARY location first (avoid self-overwrite) temp_path = f"{cfg.dataset.path}_temp" combined_dataset.save_to_disk(temp_path) + + # 4. 
Atomic replacement (only after successful save) if os.path.exists(cfg.dataset.path): - shutil.rmtree(cfg.dataset.path) - shutil.move(temp_path, cfg.dataset.path) + shutil.rmtree(cfg.dataset.path) # Remove old version + shutil.move(temp_path, cfg.dataset.path) # Move new version into place robot.disconnect() diff --git a/alignit/robots/xarm.py b/alignit/robots/xarm.py index 075c2cc..39c69fd 100644 --- a/alignit/robots/xarm.py +++ b/alignit/robots/xarm.py @@ -1,13 +1,14 @@ import time +import numpy as np +import transforms3d as t3d from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig from lerobot_xarm.xarm import Xarm as LeXarm from lerobot_xarm.config import XarmConfig -import numpy as np -import transforms3d as t3d from alignit.robots.robot import Robot from alignit.utils.tfs import are_tfs_close +from alignit.config import RecordConfig class Xarm(Robot): @@ -32,11 +33,17 @@ def _connect(self): def send_action(self, action): self.robot.send_action(action) + def get_intrinsics(self): + return self.camera.get_intrinsics() + def get_observation(self): - rgb_image = self.camera.read() + rgb_image, depth_image, acquisition_time = self.camera.async_read() + depth_array_clipped = np.clip(np.array(depth_image), a_min=0, a_max=1000) + depth_image = np.array(depth_array_clipped) / 1000.0 return { - "camera.rgb": rgb_image, + "rgb": rgb_image, + "depth": depth_image, } def disconnect(self): @@ -46,12 +53,33 @@ def servo_to_pose(self, pose, lin_tol=1e-3, ang_tol=1e-2): while not are_tfs_close(self.pose(), pose, lin_tol, ang_tol): action = { "pose": pose, - "gripper.pos": 1.0, # Optional: set gripper state (0.0=closed, 1.0=open) + "gripper.pos": 1.0, } self.send_action(action) - time.sleep(1.0 / 60.0) # Adjust frequency as needed + time.sleep(1.0 / 60.0) + + def close_gripper(self): + action = { + "pose": self.pose(), + "gripper.pos": 0.0, + } + self.send_action(action) - def reset(self): + def open_gripper(self): + action = { + "pose": self.pose(), + "gripper.pos": 1.0, + } + self.send_action(action) + + def gripper_off(self): + action = { + "pose": self.pose(), + "gripper.pos": 0.5, + } + self.send_action(action) + + def reset(self, cfg: RecordConfig): """ Reset routine: 1. 
Allows manual movement of the arm @@ -64,30 +92,26 @@ def reset(self): manual_height: Height above surface to maintain during manual movement (meters) world_z_offset: Additional Z offset in world frame after manual positioning (meters) """ - manual_height = -0.05 - world_z_offset = -0.02 self.robot.disconnect() input("Press Enter after positioning the arm...") self.robot.connect() current_pose = self.pose() gripper_z_offset = np.array( - [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, manual_height], [0, 0, 0, 1]] + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, cfg.manual_height], [0, 0, 0, 1]] ) offset_pose = current_pose @ gripper_z_offset self.servo_to_pose(pose=offset_pose) world_z_offset_mat = np.array( - [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, world_z_offset], [0, 0, 0, 1]] + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, cfg.world_z_offset], [0, 0, 0, 1]] ) final_pose = offset_pose @ world_z_offset_mat self.servo_to_pose(pose=final_pose) - + current_pose = self.pose() pose_start = current_pose @ t3d.affines.compose( - [0, 0, -0.090], t3d.euler.euler2mat(0, 0, 0), [1, 1, 1] - ) - pose_alignment_target = current_pose @ t3d.affines.compose( - [0, 0, -0.1], t3d.euler.euler2mat(0, 0, 0), [1, 1, 1] + [0, 0, -0.01], t3d.euler.euler2mat(0, 0, 0), [1, 1, 1] ) + pose_alignment_target = current_pose _, (position, _, _) = self.robot._arm.get_joint_states() for i in range(6): diff --git a/alignit/robots/xarmsim/__init__.py b/alignit/robots/xarmsim/__init__.py index c64e631..fbad295 100644 --- a/alignit/robots/xarmsim/__init__.py +++ b/alignit/robots/xarmsim/__init__.py @@ -125,10 +125,10 @@ def _set_object_pose(self, object_name, pose_matrix): self.data.qvel[qvel_adr : qvel_adr + 6] = 0 mj.mj_forward(self.model, self.data) - def _gripper_close(self): + def close_gripper(self): self._set_gripper_position(self.gripper_close_pos) - def _gripper_open(self): + def open_gripper(self): self._set_gripper_position(self.gripper_open_pos) def _set_gripper_position(self, pos, tolerance=1e-3, max_sim_steps=2000): @@ -191,7 +191,14 @@ def get_observation(self): name = mujoco.mj_id2name(self.model, mujoco.mjtObj.mjOBJ_CAMERA, i) self.renderer.update_scene(self.data, camera=name) image = self.renderer.render() - obs["camera." 
+ name] = image[:, :, ::-1] + self.renderer.enable_depth_rendering() + self.renderer.update_scene(self.data, camera=name) + image_depth = self.renderer.render() + self.renderer.disable_depth_rendering() + + # TODO: Handle multiple cameras + obs["rgb"] = image[:, :, ::-1] + obs["depth"] = image_depth return obs diff --git a/alignit/robots/xarmsim/ufactory_lite6/lite6_gripper_narrow.xml b/alignit/robots/xarmsim/ufactory_lite6/lite6_gripper_narrow.xml index 2281230..0e37df0 100644 --- a/alignit/robots/xarmsim/ufactory_lite6/lite6_gripper_narrow.xml +++ b/alignit/robots/xarmsim/ufactory_lite6/lite6_gripper_narrow.xml @@ -104,7 +104,7 @@ diaginertia="0.00016117 0.000118 0.00014455" /> - + diff --git a/alignit/robots/xarmsim/ufactory_lite6/lite6_gripper_wide.xml b/alignit/robots/xarmsim/ufactory_lite6/lite6_gripper_wide.xml deleted file mode 100644 index d25e408..0000000 --- a/alignit/robots/xarmsim/ufactory_lite6/lite6_gripper_wide.xml +++ /dev/null @@ -1,152 +0,0 @@ - - - - diff --git a/alignit/train.py b/alignit/train.py index 6f05a33..37b4204 100644 --- a/alignit/train.py +++ b/alignit/train.py @@ -6,6 +6,7 @@ from datasets import load_from_disk from torchvision import transforms import draccus +import numpy as np from alignit.config import TrainConfig from alignit.models.alignnet import AlignNet @@ -13,19 +14,20 @@ def collate_fn(batch): images = [item["images"] for item in batch] + depth_images = [item.get("depth", None) for item in batch] actions = [item["action"] for item in batch] - return {"images": images, "action": torch.tensor(actions, dtype=torch.float32)} + return { + "images": images, + "depth_images": depth_images, + "action": torch.tensor(actions, dtype=torch.float32), + } @draccus.wrap() def main(cfg: TrainConfig): """Train AlignNet model using configuration parameters.""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # Load the dataset from disk dataset = load_from_disk(cfg.dataset.path) - - # Create model using config parameters net = AlignNet( backbone_name=cfg.model.backbone, backbone_weights=cfg.model.backbone_weights, @@ -34,14 +36,12 @@ def main(cfg: TrainConfig): vector_hidden_dim=cfg.model.vector_hidden_dim, output_dim=cfg.model.output_dim, feature_agg=cfg.model.feature_agg, + use_depth_input=cfg.model.use_depth_input, ).to(device) - # Split dataset train_dataset = dataset.train_test_split( test_size=cfg.test_size, seed=cfg.random_seed ) - - # Create data loader train_loader = DataLoader( train_dataset["train"], batch_size=cfg.batch_size, @@ -56,31 +56,58 @@ def main(cfg: TrainConfig): for epoch in range(cfg.epochs): for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"): images = batch["images"] + depth_images_pil = batch["depth_images"] actions = batch["action"].to(device) - # Convert PIL Images to tensors and stack them properly - # images is a list of lists of PIL Images - batch_images = [] - transform = transforms.Compose([transforms.ToTensor()]) + batch_rgb_tensors = [] + rgb_transform = transforms.Compose( + [ + transforms.ToTensor(), + ] + ) for image_sequence in images: - tensor_sequence = [ - transform(img.convert("RGB")) for img in image_sequence + tensor_sequence_rgb = [ + rgb_transform(img.convert("RGB")) for img in image_sequence ] - stacked_tensors = torch.stack(tensor_sequence, dim=0) - batch_images.append(stacked_tensors) + stacked_tensors_rgb = torch.stack(tensor_sequence_rgb, dim=0) + batch_rgb_tensors.append(stacked_tensors_rgb) + + batch_rgb_tensors = torch.stack(batch_rgb_tensors, dim=0).to(device) - # Stack all 
batches to get shape (B, N, 3, H, W) - batch_images = torch.stack(batch_images, dim=0).to(device) + batch_depth_tensors = None + if cfg.model.use_depth_input: + batch_depth_tensors = [] + for depth_sequence in depth_images_pil: + if depth_sequence is None: + raise ValueError( + "Depth images expected but not found when use_depth_input=True" + ) + + depth_sequence_processed = [] + for d_img in depth_sequence: + depth_array = np.array(d_img) + depth_tensor = torch.from_numpy(depth_array).float() + depth_tensor = depth_tensor.unsqueeze(0) + depth_sequence_processed.append(depth_tensor) + + stacked_depth = torch.stack(depth_sequence_processed, dim=0) + batch_depth_tensors.append(stacked_depth) + + batch_depth_tensors = torch.stack(batch_depth_tensors, dim=0).to(device) optimizer.zero_grad() - outputs = net(batch_images) + if cfg.model.use_depth_input: + outputs = net(batch_rgb_tensors, depth_images=batch_depth_tensors) + else: + outputs = net(batch_rgb_tensors) + loss = criterion(outputs, actions) loss.backward() optimizer.step() + tqdm.write(f"Loss: {loss.item():.4f}") - # Save the trained model torch.save(net.state_dict(), cfg.model.path) tqdm.write(f"Model saved as {cfg.model.path}") diff --git a/alignit/visualize.py b/alignit/visualize.py index 86532e0..60d83b8 100644 --- a/alignit/visualize.py +++ b/alignit/visualize.py @@ -1,6 +1,5 @@ import gradio as gr import draccus - from alignit.utils.dataset import load_dataset from alignit.utils.zhou import sixd_se3 from alignit.utils.tfs import get_pose_str @@ -14,15 +13,20 @@ def visualize(cfg: VisualizeConfig): def get_data(index): item = dataset[index] image = item["images"][0] + depth = item["depth"][0] action_sixd = item["action"] action = sixd_se3(action_sixd) label = get_pose_str(action, degrees=True) - return image, label + return image, depth, label gr.Interface( fn=get_data, inputs=gr.Slider(0, len(dataset) - 1, step=1, label="Index", interactive=True), - outputs=[gr.Image(type="pil", label="Image"), gr.Text(label="Label")], + outputs=[ + gr.Image(type="pil", label="Image"), + gr.Image(type="pil", label="Depth Image"), + gr.Text(label="Label"), + ], title="Dataset Image Viewer", live=True, ).launch(share=cfg.share, server_name=cfg.server_name, server_port=cfg.server_port) diff --git a/benchmark/apriltag_benchmark.py b/benchmark/apriltag_benchmark.py new file mode 100644 index 0000000..9010310 --- /dev/null +++ b/benchmark/apriltag_benchmark.py @@ -0,0 +1,113 @@ +import cv2 +import numpy as np +from pupil_apriltags import Detector +from scipy.spatial.transform import Rotation as R + +from alignit.robots.xarm import Xarm + + +class AprilTagBenchmark: + def __init__(self, tag_size=0.06): + camera_params = robot.get_intrinsics() + self.detector = Detector(families="tag36h11") + self.tag_size = tag_size + self.camera_params = camera_params + + def detect_pose(self, image): + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + tags = self.detector.detect( + gray, + estimate_tag_pose=True, + camera_params=self.camera_params, + tag_size=self.tag_size, + ) + + if len(tags) == 0: + print("No tags detected") + return None + + tag = tags[0] + pose = np.eye(4) + pose[:3, :3] = tag.pose_R + pose[:3, 3] = tag.pose_t.flatten() + return pose + + def pose_difference(self, T_ideal, T_current, max_pos_distance=1.0): + + try: + # Convert and validate inputs + T1 = np.array(T_ideal, dtype=np.float32) + T2 = np.array(T_current, dtype=np.float32) + + if T1.shape != (4, 4) or T2.shape != (4, 4): + raise ValueError("Input matrices must be 4x4") + + # Extract 
components + p1 = T1[:3, 3] * 1000 # Ideal position + p2 = T2[:3, 3] * 1000 # Current position + R1 = T1[:3, :3] # Ideal rotation + R2 = T2[:3, :3] # Current rotation + + # Position difference (Euclidean distance) + pos_distance = np.linalg.norm(p2 - p1) # Convert to millimeters + pos_diff_pct = min(pos_distance / max_pos_distance, 1.0) * 100 + + # Rotation difference (angle in radians) + rel_rot = R1.T @ R2 + angle_rad = np.arccos(np.clip((np.trace(rel_rot) - 1) / 2, -1, 1)) + rot_diff_pct = (angle_rad / np.pi) * 100 + + # Euler angle differences + rpy1 = R.from_matrix(R1).as_euler("xyz") + rpy2 = R.from_matrix(R2).as_euler("xyz") + rpy_diff = np.abs(rpy2 - rpy1) + + return { + "position_diff_milimeters": pos_distance, + "position_diff%": pos_diff_pct, + "rotation_diff_rad": angle_rad, + "rotation_diff%": rot_diff_pct, + "rpy_diff_rad": rpy_diff, + "xyz_diff_milimeters": (p2 - p1).tolist(), + "combined_diff%": 0.5 * pos_diff_pct + 0.5 * rot_diff_pct, + } + + except Exception as e: + print(f"Pose comparison failed: {str(e)}") + return None + + +def pose_in_tag_frame(tag_pose_world, robot_pose_world): + """Convert robot pose from world frame to AprilTag's frame.""" + tag_pose_inv = np.linalg.inv(tag_pose_world) + return tag_pose_inv @ robot_pose_world + + +if __name__ == "__main__": + robot = Xarm() + detector = AprilTagBenchmark() + while True: + observation = robot.get_observation() + curr_pose = robot.pose() + rgb_image = observation["rgb"] + apriltag_pose = detector.detect_pose(rgb_image) + T_ideal_tag = np.array( + [ + [0.92529235, 0.37684695, 0.04266663, 0.25978898], + [0.37511194, -0.89279238, -0.24942494, 0.06284318], + [-0.05590259, 0.24679575, -0.96745375, -0.25993736], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=np.float32, + ) + + T_after_inference = pose_in_tag_frame(apriltag_pose, curr_pose) + result = detector.pose_difference(T_ideal_tag, T_after_inference) + if result: + print("xyz_diff_milimeters:", result["xyz_diff_milimeters"]) + print("position_diff_milimeters:", result["position_diff_milimeters"]) + print("rotation_diff_rad:", result["rotation_diff_rad"]) + print("rpy_diff_rad:", result["rpy_diff_rad"]) + print("combined_diff%:", result["combined_diff%"]) + else: + print("Error: Failed to get current pose estimate") diff --git a/tests/test_alignnet.py b/tests/test_alignnet.py index 77bcddb..94c4274 100644 --- a/tests/test_alignnet.py +++ b/tests/test_alignnet.py @@ -9,6 +9,7 @@ def test_alignnet_forward_shapes_cpu(): backbone_weights=None, use_vector_input=False, output_dim=7, + use_depth_input=False, ) model.eval() x = torch.randn(2, 3, 3, 64, 64) # B=2, N=3 views @@ -23,6 +24,7 @@ def test_alignnet_with_vector_input(): backbone_weights=None, use_vector_input=True, output_dim=7, + use_depth_input=False, ) model.eval() x = torch.randn(1, 2, 3, 64, 64) @@ -38,6 +40,7 @@ def test_alignnet_performance(): backbone_weights=None, use_vector_input=True, output_dim=7, + use_depth_input=False, ) model.eval() x = torch.randn(1, 3, 3, 224, 224) # B=1, N=3 views @@ -53,3 +56,37 @@ def test_alignnet_performance(): print(f"Forward pass took {elapsed_time_ms:.2f} ms") assert elapsed_time < 0.5 + + +def test_alignnet_with_depth_input(): + model = AlignNet( + backbone_name="resnet18", + backbone_weights=None, + use_vector_input=False, + use_depth_input=True, + output_dim=7, + ) + model.eval() + x = torch.randn(2, 3, 3, 64, 64) # RGB images + depth = torch.randn(2, 3, 1, 64, 64) # Depth images + with torch.no_grad(): + y = model(x, depth_images=depth) + assert y.shape == (2, 7) + + +def 
test_alignnet_with_all_inputs():
+    # New test for combined inputs
+    model = AlignNet(
+        backbone_name="efficientnet_b0",
+        backbone_weights=None,
+        use_vector_input=True,
+        use_depth_input=True,
+        output_dim=7,
+    )
+    model.eval()
+    x = torch.randn(1, 3, 3, 224, 224)  # RGB images
+    depth = torch.randn(1, 3, 1, 224, 224)  # Depth images
+    vecs = [torch.randn(5)]  # Vector inputs
+    with torch.no_grad():
+        y = model(x, vecs, depth)
+    assert y.shape == (1, 7)
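As a usage sketch (not part of the patch itself): this is how the modified `RealSenseCamera.async_read()` is meant to be consumed, mirroring `Xarm.get_observation()` above. The serial number is a placeholder, the import path assumes the class added in this patch, and `use_depth=True` is assumed so the depth slot of the returned tuple is populated.

```python
import numpy as np

from alignit.cameras.realsense import RealSenseCamera
from lerobot.cameras.realsense import RealSenseCameraConfig

config = RealSenseCameraConfig(serial_number_or_name="0123456789", use_depth=True)  # placeholder SN
camera = RealSenseCamera(config)
camera.connect()

# async_read() now returns (color, depth, acquisition_time) instead of a single array.
rgb_image, depth_image, acquired_at = camera.async_read()

# Same normalisation as Xarm.get_observation(): clip the raw uint16 depth (millimetres)
# to 1 m and rescale to [0, 1] so it matches what the model sees during training.
depth = np.clip(np.asarray(depth_image, dtype=np.float32), 0, 1000) / 1000.0

print(rgb_image.shape, depth.shape, acquired_at)
camera.disconnect()
```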
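A condensed sketch of the observation preprocessing that `infere.py` performs inline, assuming `observation["rgb"]` is an HWC uint8 image and `observation["depth"]` is already scaled as in `Xarm.get_observation()`; the helper name is illustrative, not part of the patch.

```python
import numpy as np
import torch


def observation_to_tensors(observation, device):
    """Turn a single {rgb, depth} observation into (B=1, N=1, C, H, W) batches."""
    rgb = observation["rgb"].astype(np.float32) / 255.0  # HWC in [0, 1]
    depth = observation["depth"].astype(np.float32)      # HW, already normalised

    rgb_batch = (
        torch.from_numpy(rgb)
        .permute(2, 0, 1)  # HWC -> CHW
        .unsqueeze(0)      # view dimension  -> (1, 3, H, W)
        .unsqueeze(0)      # batch dimension -> (1, 1, 3, H, W)
        .to(device)
    )
    depth_batch = (
        torch.from_numpy(depth)
        .unsqueeze(0)  # channel dimension -> (1, H, W)
        .unsqueeze(0)  # view dimension    -> (1, 1, H, W)
        .unsqueeze(0)  # batch dimension   -> (1, 1, 1, H, W)
        .to(device)
    )
    return rgb_batch, depth_batch
```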
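On the new `rotation_matrix_multiplier` in `infere.py`: raising a rotation matrix to the k-th power with `np.linalg.matrix_power` keeps the rotation axis and multiplies the rotation angle by k, which is why cubing the predicted relative rotation speeds up angular convergence. A quick check with illustrative angles:

```python
import numpy as np
import transforms3d as t3d

R = t3d.euler.euler2mat(0, 0, np.deg2rad(5))  # 5 degree rotation about z
R3 = np.linalg.matrix_power(R, 3)             # same axis, three times the angle

_, _, yaw = t3d.euler.mat2euler(R3)
print(np.rad2deg(yaw))  # ~15 degrees
```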
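A shape-level sketch of the new depth branch in `AlignNet`, using the same tensor layout as `tests/test_alignnet.py`: RGB views are `(B, N, 3, H, W)` and the matching depth views are `(B, N, 1, H, W)`. Backbone choice and sizes here are illustrative.

```python
import torch

from alignit.models.alignnet import AlignNet

net = AlignNet(
    backbone_name="resnet18",
    backbone_weights=None,  # no pretrained weights needed for a shape check
    use_vector_input=False,
    use_depth_input=True,
    output_dim=7,
)
net.eval()

B, N, H, W = 2, 1, 64, 64           # batch of 2, single camera view
rgb = torch.randn(B, N, 3, H, W)    # RGB views
depth = torch.randn(B, N, 1, H, W)  # one-channel depth views

with torch.no_grad():
    out = net(rgb, depth_images=depth)

assert out.shape == (B, 7)  # matches output_dim
```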
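The episode-append logic in `record.py` (load, cast, concatenate, save to a temporary directory, then swap) distilled into one helper; the function name is hypothetical. The temporary-directory step matters because `save_to_disk` cannot overwrite the directory the dataset was loaded from.

```python
import os
import shutil

from datasets import concatenate_datasets, load_from_disk


def append_episode(dataset_path, episode_dataset, features):
    """Append an episode to an on-disk dataset without clobbering it mid-write."""
    if os.path.exists(dataset_path):
        existing = load_from_disk(dataset_path).cast(features)
        combined = concatenate_datasets([existing, episode_dataset])
    else:
        combined = episode_dataset

    temp_path = f"{dataset_path}_temp"   # save elsewhere first ...
    combined.save_to_disk(temp_path)
    if os.path.exists(dataset_path):
        shutil.rmtree(dataset_path)      # ... then replace the old version
    shutil.move(temp_path, dataset_path)
```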
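The rotation error reported by `AprilTagBenchmark.pose_difference` is the geodesic angle between the two rotations, recovered from the trace of the relative rotation. A minimal numeric check of that formula with SciPy (already imported by the benchmark):

```python
import numpy as np
from scipy.spatial.transform import Rotation as R

R1 = R.from_euler("xyz", [0, 0, 10], degrees=True).as_matrix()
R2 = R.from_euler("xyz", [0, 0, 40], degrees=True).as_matrix()

rel = R1.T @ R2  # rotation taking the ideal frame to the current frame
angle_rad = np.arccos(np.clip((np.trace(rel) - 1) / 2, -1, 1))

print(np.rad2deg(angle_rad))    # ~30 degrees
print(angle_rad / np.pi * 100)  # the "rotation_diff%" scale used by the benchmark
```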