583 changes: 583 additions & 0 deletions alignit/cameras/realsense.py

Large diffs are not rendered by default.
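Since the new camera driver is not rendered here, nothing below is confirmed by the diff; purely as a hypothetical sketch, an RGB-D grab with the pyrealsense2 SDK typically looks like this:

# HYPOTHETICAL sketch -- alignit/cameras/realsense.py is not rendered above,
# so none of these names are confirmed by the diff. Assumes pyrealsense2.
import numpy as np
import pyrealsense2 as rs

pipeline = rs.pipeline()
config = rs.config()
config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)
config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30)
pipeline.start(config)
align = rs.align(rs.stream.color)  # align depth pixels to the color frame

try:
    frames = align.process(pipeline.wait_for_frames())
    rgb = np.asanyarray(frames.get_color_frame().get_data())    # (480, 640, 3) uint8
    depth = np.asanyarray(frames.get_depth_frame().get_data())  # (480, 640) uint16, device units
finally:
    pipeline.stop()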

35 changes: 29 additions & 6 deletions alignit/config.py
@@ -2,6 +2,7 @@

from dataclasses import dataclass, field
from typing import Optional, List

import numpy as np


@@ -10,7 +11,7 @@ class DatasetConfig:
"""Configuration for dataset paths and loading."""

path: str = field(
default="./data/duck", metadata={"help": "Path to the dataset directory"}
default="./data/default", metadata={"help": "Path to the dataset directory"}
)


@@ -46,6 +47,12 @@ class ModelConfig:
default="alignnet_model.pth",
metadata={"help": "Path to save/load trained model"},
)
use_depth_input: bool = field(
default=True, metadata={"help": "Whether to use depth input for the model"}
)
depth_hidden_dim: int = field(
default=128, metadata={"help": "Output dimension of the depth feature projection"}
)


@dataclass
@@ -98,6 +105,13 @@ class RecordConfig:
ang_tol_trajectory: float = field(
default=0.05, metadata={"help": "Angular tolerance for trajectory servo"}
)
manual_height: float = field(
default=-0.05, metadata={"help": "Height above surface for manual movement"}
)
world_z_offset: float = field(
default=-0.02,
metadata={"help": "World frame Z offset after manual positioning"},
)


@dataclass
@@ -106,7 +120,7 @@ class TrainConfig:

dataset: DatasetConfig = field(default_factory=DatasetConfig)
model: ModelConfig = field(default_factory=ModelConfig)
batch_size: int = field(default=8, metadata={"help": "Training batch size"})
batch_size: int = field(default=4, metadata={"help": "Training batch size"})
learning_rate: float = field(
default=1e-4, metadata={"help": "Learning rate for optimizer"}
)
@@ -133,22 +147,31 @@ class InferConfig:
metadata={"help": "Starting pose RPY angles"},
)
lin_tolerance: float = field(
default=2e-3, metadata={"help": "Linear tolerance for convergence (meters)"}
default=5e-3, metadata={"help": "Linear tolerance for convergence (meters)"}
)
ang_tolerance: float = field(
default=2, metadata={"help": "Angular tolerance for convergence (degrees)"}
default=5, metadata={"help": "Angular tolerance for convergence (degrees)"}
)
max_iterations: Optional[int] = field(
default=None,
default=20,
metadata={"help": "Maximum iterations before stopping (None = infinite)"},
)
debug_output: bool = field(
default=True, metadata={"help": "Print debug information during inference"}
)
debouncing_count: int = field(
default=5,
default=20,
metadata={"help": "Number of iterations within tolerance before stopping"},
)
rotation_matrix_multiplier: int = field(
default=3,
metadata={
"help": "Number of times to multiply the rotation matrix of relative action in order to speed up convergence"
},
)
manual_height: float = field(
default=0.08, metadata={"help": "Height above surface for manual movement"}
)


@dataclass
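Because these configs are draccus dataclasses, the new fields should be overridable from the command line. A sketch only, assuming draccus's pyrallis-style dotted flags (spellings unverified):

# e.g.  python -m alignit.infere --model.use_depth_input true --max_iterations 20
import draccus

from alignit.config import InferConfig


@draccus.wrap()
def main(cfg: InferConfig):
    print(cfg.model.use_depth_input, cfg.model.depth_hidden_dim, cfg.max_iterations)


if __name__ == "__main__":
    main()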
78 changes: 51 additions & 27 deletions alignit/infere.py
@@ -1,11 +1,11 @@
import transforms3d as t3d
import numpy as np
import time
import draccus
from alignit.config import InferConfig

import torch
import transforms3d as t3d
import numpy as np
import time  # still needed: time.time() and time.sleep() are called below
import draccus

from alignit.config import InferConfig
from alignit.models.alignnet import AlignNet
from alignit.utils.zhou import sixd_se3
from alignit.utils.tfs import print_pose, are_tfs_close
@@ -19,7 +19,6 @@ def main(cfg: InferConfig):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# load model from file
net = AlignNet(
backbone_name=cfg.model.backbone,
backbone_weights=cfg.model.backbone_weights,
@@ -28,75 +27,100 @@
vector_hidden_dim=cfg.model.vector_hidden_dim,
output_dim=cfg.model.output_dim,
feature_agg=cfg.model.feature_agg,
use_depth_input=cfg.model.use_depth_input,
)
net.load_state_dict(torch.load(cfg.model.path, map_location=device))
net.to(device)
net.eval()

robot = XarmSim()

# Set initial pose from config
start_pose = t3d.affines.compose(
[0.23, 0, 0.25], t3d.euler.euler2mat(np.pi, 0, 0), [1, 1, 1]
)
robot.servo_to_pose(start_pose, lin_tol=1e-2)

iteration = 0
iterations_within_tolerance = 0
ang_tol_rad = np.deg2rad(cfg.ang_tolerance)

try:
while True:
observation = robot.get_observation()
images = [observation["camera.rgb"].astype(np.float32) / 255.0]

# Convert images to tensor and reshape from HWC to CHW format
images_tensor = (
torch.from_numpy(np.array(images))
.permute(0, 3, 1, 2)
rgb_image = observation["rgb"].astype(np.float32) / 255.0
depth_image = observation["depth"].astype(np.float32)
print(
"Min/Max depth,mean (raw):",
observation["depth"].min(),
observation["depth"].max(),
observation["depth"].mean(),
)
print(
"Min/Max depth,mean (scaled):",
depth_image.min(),
depth_image.max(),
depth_image.mean(),
)
rgb_image_tensor = (
torch.from_numpy(np.array(rgb_image))
.permute(2, 0, 1) # (H, W, C) -> (C, H, W)
.unsqueeze(0)
.to(device)
)

if cfg.debug_output:
print(f"Max pixel value: {torch.max(images_tensor)}")
depth_image_tensor = (
torch.from_numpy(np.array(depth_image))
.unsqueeze(0) # Add channel dimension: (1, H, W)
.unsqueeze(0) # Add batch dimension: (1, 1, H, W)
.to(device)
)
rgb_images_batch = rgb_image_tensor.unsqueeze(1)  # add view dim: (B=1, N=1, C, H, W)
depth_images_batch = depth_image_tensor.unsqueeze(1)  # (B=1, N=1, 1, H, W)

start = time.time()
with torch.no_grad():
relative_action = net(images_tensor)
relative_action = net(rgb_images_batch, depth_images=depth_images_batch)
relative_action = relative_action.squeeze(0).cpu().numpy()
relative_action = sixd_se3(relative_action)

if cfg.debug_output:
print_pose(relative_action)

# Check convergence
relative_action[:3, :3] = np.linalg.matrix_power(
relative_action[:3, :3], cfg.rotation_matrix_multiplier
)
if are_tfs_close(
relative_action, lin_tol=cfg.lin_tolerance, ang_tol=ang_tol_rad
):
iterations_within_tolerance += 1
else:
iterations_within_tolerance = 0

if iterations_within_tolerance >= cfg.debouncing_count:
print("Alignment achieved - stopping.")
break

print(relative_action)
target_pose = robot.pose() @ relative_action
iteration += 1
action = {
"pose": target_pose,
"gripper.pos": 1.0,
}
robot.send_action(action)

# Check max iterations
if cfg.max_iterations and iteration >= cfg.max_iterations:
if cfg.max_iterations is not None and iteration >= cfg.max_iterations:
print(f"Reached maximum iterations ({cfg.max_iterations}) - stopping.")
print("Moving robot to final pose.")
current_pose = robot.pose()
gripper_z_offset = np.array(
[
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, cfg.manual_height],
[0, 0, 0, 1],
]
)
offset_pose = current_pose @ gripper_z_offset
robot.servo_to_pose(pose=offset_pose)
robot.close_gripper()
robot.gripper_off()

break

time.sleep(10.0)

except KeyboardInterrupt:
print("\nExiting...")

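The rotation_matrix_multiplier trick works because raising a rotation matrix to an integer power scales its rotation angle about the same axis, so np.linalg.matrix_power(R, 3) triples each predicted rotation step. A quick verification:

import numpy as np
import transforms3d as t3d

R = t3d.euler.euler2mat(0, 0, np.deg2rad(2.0))  # 2 degree yaw step
R3 = np.linalg.matrix_power(R, 3)
axis, angle = t3d.axangles.mat2axangle(R3)
print(axis, np.rad2deg(angle))  # same z-axis, ~6 degrees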
43 changes: 34 additions & 9 deletions alignit/models/alignnet.py
@@ -10,8 +10,10 @@ def __init__(
backbone_name="efficientnet_b0",
backbone_weights="DEFAULT",
use_vector_input=True,
use_depth_input=True,
fc_layers=[256, 128],
vector_hidden_dim=64,
depth_hidden_dim=128,
output_dim=7,
feature_agg="mean",
):
@@ -23,27 +25,39 @@ def __init__(
:param vector_hidden_dim: output dim of the vector MLP
:param output_dim: final output vector size
:param feature_agg: 'mean' or 'max' across image views
:param use_depth_input: whether to accept depth input
:param depth_hidden_dim: output dim of the depth feature projection
"""
super().__init__()
self.use_vector_input = use_vector_input
self.use_depth_input = use_depth_input
self.feature_agg = feature_agg

# CNN backbone
self.backbone, self.image_feature_dim = self._build_backbone(
backbone_name, backbone_weights
)

# Linear projection of image features
self.image_fc = nn.Sequential(
nn.Linear(self.image_feature_dim, fc_layers[0]), nn.ReLU()
)

if use_depth_input:
self.depth_cnn = nn.Sequential(
nn.Conv2d(1, 8, 3, padding=1),
nn.ReLU(),
nn.Conv2d(8, 16, 3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1),
)
self.depth_fc = nn.Sequential(nn.Linear(16, depth_hidden_dim), nn.ReLU())
input_dim = fc_layers[0] + depth_hidden_dim
else:
input_dim = fc_layers[0]

# Optional vector input processing
if use_vector_input:
self.vector_fc = nn.Sequential(nn.Linear(1, vector_hidden_dim), nn.ReLU())
input_dim = fc_layers[0] + vector_hidden_dim
else:
input_dim = fc_layers[0]
input_dim += vector_hidden_dim

# Fully connected layers
layers = []
@@ -81,10 +95,11 @@ def aggregate_image_features(self, feats):
else:
raise ValueError("Invalid aggregation type")

def forward(self, rgb_images, vector_inputs=None):
def forward(self, rgb_images, vector_inputs=None, depth_images=None):
"""
:param rgb_images: Tensor of shape (B, N, 3, H, W)
:param vector_inputs: List of tensors of shape (L_i,) or None
:param depth_images: Tensor of shape (B, N, 1, H, W) or None
:return: Tensor of shape (B, output_dim)
"""
B, N, C, H, W = rgb_images.shape
@@ -93,15 +108,25 @@ def forward(self, rgb_images, vector_inputs=None):
image_feats = self.aggregate_image_features(feats)
image_feats = self.image_fc(image_feats)

features = [image_feats]

if self.use_depth_input and depth_images is not None:
depth = depth_images.view(B * N, 1, H, W)
depth_feats = self.depth_cnn(depth).view(B, N, -1)
depth_feats = self.aggregate_image_features(depth_feats)
depth_feats = self.depth_fc(depth_feats)
features.append(depth_feats)

if self.use_vector_input and vector_inputs is not None:
vec_feats = []
for vec in vector_inputs:
vec = vec.unsqueeze(1) # (L, 1)
pooled = self.vector_fc(vec).mean(dim=0) # (D,)
vec_feats.append(pooled)
vec_feats = torch.stack(vec_feats, dim=0)
fused = torch.cat([image_feats, vec_feats], dim=1)
else:
fused = image_feats
features.append(vec_feats)

fused = torch.cat(features, dim=1)
print("Fused shape:", fused.shape)

return self.head(fused) # (B, output_dim)
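A minimal smoke test of the extended forward signature, with dummy tensors shaped per the docstring above. use_vector_input=False is passed because the head width is fixed at construction, while forward() skips the vector branch when vector_inputs is None; the default backbone weights may download on first use:

import torch

from alignit.models.alignnet import AlignNet

net = AlignNet(use_vector_input=False, use_depth_input=True).eval()
rgb = torch.rand(2, 1, 3, 224, 224)    # (B, N views, 3, H, W)
depth = torch.rand(2, 1, 1, 224, 224)  # (B, N, 1, H, W)
with torch.no_grad():
    out = net(rgb, depth_images=depth)
print(out.shape)  # torch.Size([2, 7])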