diff --git a/embodichain/lab/sim/planners/base_planner.py b/embodichain/lab/sim/planners/base_planner.py index 5ace6c6d..ed70a3f9 100644 --- a/embodichain/lab/sim/planners/base_planner.py +++ b/embodichain/lab/sim/planners/base_planner.py @@ -17,10 +17,12 @@ import numpy as np from abc import ABC, abstractmethod from typing import Dict, List, Tuple, Union +import torch import matplotlib.pyplot as plt from embodichain.lab.sim.planners.utils import TrajectorySampleMethod from embodichain.utils import logger +from embodichain.lab.sim.planners.utils import PlanState class BasePlanner(ABC): @@ -34,22 +36,23 @@ class BasePlanner(ABC): max_constraints: Dictionary containing 'velocity' and 'acceleration' constraints """ - def __init__(self, dofs: int, max_constraints: Dict[str, List[float]]): - self.dofs = dofs - self.max_constraints = max_constraints + def __init__(self, **kwargs): + self.dofs = kwargs.get("dofs", None) + self.max_constraints = kwargs.get("max_constraints", None) + self.device = kwargs.get("device", torch.device("cpu")) @abstractmethod def plan( self, - current_state: Dict, - target_states: List[Dict], + current_state: PlanState, + target_states: list[PlanState], **kwargs, ) -> Tuple[ bool, - np.ndarray | None, - np.ndarray | None, - np.ndarray | None, - np.ndarray | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, float, ]: r"""Execute trajectory planning. @@ -64,10 +67,10 @@ def plan( Returns: Tuple of (success, positions, velocities, accelerations, times, duration): - success: bool, whether planning succeeded - - positions: np.ndarray (N, DOF), joint positions along trajectory - - velocities: np.ndarray (N, DOF), joint velocities along trajectory - - accelerations: np.ndarray (N, DOF), joint accelerations along trajectory - - times: np.ndarray (N,), time stamps for each point + - positions: torch.Tensor (N, DOF), joint positions along trajectory + - velocities: torch.Tensor (N, DOF), joint velocities along trajectory + - accelerations: torch.Tensor (N, DOF), joint accelerations along trajectory + - times: torch.Tensor (N,), time stamps for each point - duration: float, total trajectory duration """ logger.log_error("Subclasses must implement plan() method", NotImplementedError) diff --git a/embodichain/lab/sim/planners/motion_generator.py b/embodichain/lab/sim/planners/motion_generator.py index 156792cf..09668938 100644 --- a/embodichain/lab/sim/planners/motion_generator.py +++ b/embodichain/lab/sim/planners/motion_generator.py @@ -24,13 +24,7 @@ from embodichain.lab.sim.planners.utils import TrajectorySampleMethod from embodichain.lab.sim.objects.robot import Robot from embodichain.utils import logger - - -class PlannerType(Enum): - r"""Enumeration for different planner types.""" - - TOPPRA = "toppra" - """TOPPRA planner for time-optimal trajectory planning.""" +from embodichain.lab.sim.planners.utils import PlanState, MoveType, MovePart class MotionGenerator: @@ -51,12 +45,23 @@ class MotionGenerator: **kwargs: Additional arguments passed to planner initialization """ + _support_planner_dict = { + "toppra": ToppraPlanner, + } + + @classmethod + def register_planner_type(cls, name: str, planner_class): + """ + Register a new planner type. + """ + cls._support_planner_dict[name] = planner_class + def __init__( self, robot: Robot, uid: str, sim=None, - planner_type: Union[str, PlannerType] = "toppra", + planner_type: str = "toppra", default_velocity: float = 0.2, default_acceleration: float = 0.5, collision_margin: float = 0.01, @@ -65,45 +70,16 @@ def __init__( self.robot = robot self.sim = sim self.collision_margin = collision_margin - self.uid = uid + self.uid = uid # control part # Get robot DOF using get_joint_ids for specified control part (None for whole body) self.dof = len(robot.get_joint_ids(uid)) # Create planner based on planner_type - self.planner_type = self._parse_planner_type(planner_type) self.planner = self._create_planner( - self.planner_type, default_velocity, default_acceleration, **kwargs + planner_type, default_velocity, default_acceleration, **kwargs ) - def _parse_planner_type(self, planner_type: Union[str, PlannerType]) -> str: - r"""Parse planner type from string or enum. - - Args: - planner_type: Planner type as string or PlannerType enum - - Returns: - Planner type as string - """ - if isinstance(planner_type, PlannerType): - return planner_type.value - elif isinstance(planner_type, str): - planner_type_lower = planner_type.lower() - # Validate planner type - valid_types = [e.value for e in PlannerType] - if planner_type_lower not in valid_types: - logger.log_warning( - f"Unknown planner type '{planner_type}', using 'toppra'. " - f"Valid types: {valid_types}" - ) - return "toppra" - return planner_type_lower - else: - logger.log_error( - f"planner_type must be str or PlannerType, got {type(planner_type)}", - TypeError, - ) - def _create_planner( self, planner_type: str, @@ -123,18 +99,20 @@ def _create_planner( Planner instance """ # Get constraints from robot or use defaults - max_constraints = self._get_constraints( - default_velocity, default_acceleration, **kwargs - ) - - if planner_type == "toppra": - return ToppraPlanner(self.dof, max_constraints) - else: + planner_class = self._support_planner_dict.get(planner_type, None) + if planner_class is None: logger.log_error( - f"Unknown planner type '{planner_type}'. " - f"Supported types: {[e.value for e in PlannerType]}", + f"Unsupported planner type '{planner_type}'. " + f"Supported types: {[e for e in self._support_planner_dict.keys()]}", ValueError, ) + cfg = kwargs.copy() + cfg["dofs"] = self.dof + cfg["max_constraints"] = self._get_constraints( + default_velocity, default_acceleration, **kwargs + ) + cfg["robot"] = self.robot + return planner_class(**cfg) def _get_constraints( self, default_velocity: float, default_acceleration: float, **kwargs @@ -200,17 +178,17 @@ def _create_state_dict( def plan( self, - current_state: Dict, - target_states: List[Dict], + current_state: PlanState, + target_states: List[PlanState], sample_method: TrajectorySampleMethod = TrajectorySampleMethod.TIME, sample_interval: Union[float, int] = 0.01, **kwargs, ) -> Tuple[ bool, - np.ndarray | None, - np.ndarray | None, - np.ndarray | None, - np.ndarray | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, float, ]: r"""Plan trajectory without collision checking. @@ -219,11 +197,8 @@ def plan( velocity and acceleration constraints, but does not check for collisions. Args: - current_state: Dictionary containing current state: - - "position": Current joint positions (required) - - "velocity": Current joint velocities (optional, defaults to zeros) - - "acceleration": Current joint accelerations (optional, defaults to zeros) - target_states: List of target state dictionaries, each with same format as current_state + current_state: PlanState + target_states: List of PlanState sample_method: Sampling method (TIME or QUANTITY) sample_interval: Sampling interval (time in seconds for TIME method, or number of points for QUANTITY) **kwargs: Additional arguments @@ -231,28 +206,12 @@ def plan( Returns: Tuple of (success, positions, velocities, accelerations, times, duration): - success: bool, whether planning succeeded - - positions: np.ndarray (N, DOF), joint positions along trajectory - - velocities: np.ndarray (N, DOF), joint velocities along trajectory - - accelerations: np.ndarray (N, DOF), joint accelerations along trajectory - - times: np.ndarray (N,), time stamps for each point + - positions: torch.Tensor (N, DOF), joint positions along trajectory + - velocities: torch.Tensor (N, DOF), joint velocities along trajectory + - accelerations: torch.Tensor (N, DOF), joint accelerations along trajectory + - times: torch.Tensor (N,), time stamps for each point - duration: float, total trajectory duration """ - # Validate inputs - if len(current_state["position"]) != self.dof: - logger.log_warning( - f"Current state position dimension {len(current_state['position'])} " - f"does not match robot DOF {self.dof}" - ) - return False, None, None, None, None, 0.0 - - for i, target in enumerate(target_states): - if len(target["position"]) != self.dof: - logger.log_warning( - f"Target state {i} position dimension {len(target['position'])} " - f"does not match robot DOF {self.dof}" - ) - return False, None, None, None, None, 0.0 - # Plan trajectory using selected planner ( success, @@ -487,10 +446,24 @@ def calculate_point_allocations( self._create_state_dict(pos) for pos in interpolate_qpos_list[1:] ] + init_plan_state = PlanState( + move_type=MoveType.JOINT_MOVE, + move_part=MovePart.ALL, + qpos=current_state["position"], + qvel=current_state["velocity"], + qacc=current_state["acceleration"], + ) + target_plan_states = [] + for state in target_states: + plan_state = PlanState( + move_type=MoveType.JOINT_MOVE, qpos=state["position"] + ) + target_plan_states.append(plan_state) + # Plan trajectory using internal plan method success, positions, velocities, accelerations, times, duration = self.plan( - current_state=current_state, - target_states=target_states, + current_state=init_plan_state, + target_states=target_plan_states, sample_method=sample_method, sample_interval=sample_num, **kwargs, @@ -500,10 +473,7 @@ def calculate_point_allocations( logger.log_error("Failed to plan trajectory") # Convert positions to list - out_qpos_list = ( - positions.tolist() if isinstance(positions, np.ndarray) else positions - ) - + out_qpos_list = positions.to("cpu").numpy().tolist() out_qpos_list = ( torch.tensor(out_qpos_list) if not isinstance(out_qpos_list, torch.Tensor) diff --git a/embodichain/lab/sim/planners/toppra_planner.py b/embodichain/lab/sim/planners/toppra_planner.py index 5f2d0c09..47f3ecf1 100644 --- a/embodichain/lab/sim/planners/toppra_planner.py +++ b/embodichain/lab/sim/planners/toppra_planner.py @@ -18,6 +18,8 @@ from embodichain.utils import logger from embodichain.lab.sim.planners.utils import TrajectorySampleMethod from embodichain.lab.sim.planners.base_planner import BasePlanner +from embodichain.lab.sim.planners.utils import PlanState +import torch from typing import TYPE_CHECKING, Union, Tuple @@ -33,24 +35,25 @@ class ToppraPlanner(BasePlanner): - def __init__(self, dofs, max_constraints): + def __init__(self, **kwargs): r"""Initialize the TOPPRA trajectory planner. Args: dofs: Number of degrees of freedom max_constraints: Dictionary containing 'velocity' and 'acceleration' constraints """ - super().__init__(dofs, max_constraints) + super().__init__(**kwargs) # Create TOPPRA-specific constraint arrays (symmetric format) # This format is required by TOPPRA library + max_constraints = kwargs.get("max_constraints", None) self.vlims = np.array([[-v, v] for v in max_constraints["velocity"]]) self.alims = np.array([[-a, a] for a in max_constraints["acceleration"]]) def plan( self, - current_state: dict, - target_states: list[dict], + current_state: PlanState, + target_states: list[PlanState], **kwargs, ): r"""Execute trajectory planning. @@ -75,28 +78,25 @@ def plan( logger.log_error("At least 2 sample points required", ValueError) # Check waypoints - if len(current_state["position"]) != self.dofs: + if len(current_state.qpos) != self.dofs: logger.log_info("Current wayponit does not align") return False, None, None, None, None, None for target in target_states: - if len(target["position"]) != self.dofs: + if len(target.qpos) != self.dofs: logger.log_info("Target Wayponits does not align") return False, None, None, None, None, None if ( len(target_states) == 1 and np.sum( - np.abs( - np.array(target_states[0]["position"]) - - np.array(current_state["position"]) - ) + np.abs(np.array(target_states[0].qpos) - np.array(current_state.qpos)) ) < 1e-3 ): logger.log_info("Only two same waypoints, do not plan") return ( True, - np.array([current_state["position"], target_states[0]["position"]]), + np.array([current_state.qpos, target_states[0].qpos]), np.array([[0.0] * self.dofs, [0.0] * self.dofs]), np.array([[0.0] * self.dofs, [0.0] * self.dofs]), 0, @@ -104,11 +104,10 @@ def plan( ) # Build waypoints - waypoints = [np.array(current_state["position"])] + waypoints = [np.array(current_state.qpos)] for target in target_states: - waypoints.append(np.array(target["position"])) + waypoints.append(np.array(target.qpos)) waypoints = np.array(waypoints) - # Create spline interpolation # NOTE: Suitable for dense waypoints ss = np.linspace(0, 1, len(waypoints)) @@ -164,9 +163,11 @@ def plan( return ( True, - np.array(positions), - np.array(velocities), - np.array(accelerations), - ts, + torch.tensor(np.array(positions), dtype=torch.float32, device=self.device), + torch.tensor(np.array(velocities), dtype=torch.float32, device=self.device), + torch.tensor( + np.array(accelerations), dtype=torch.float32, device=self.device + ), + torch.tensor(ts, dtype=torch.float32, device=self.device), duration, ) diff --git a/embodichain/lab/sim/planners/utils.py b/embodichain/lab/sim/planners/utils.py index 9b31685f..2b5cd77a 100644 --- a/embodichain/lab/sim/planners/utils.py +++ b/embodichain/lab/sim/planners/utils.py @@ -17,6 +17,9 @@ from enum import Enum from typing import Union from embodichain.utils import logger +import torch +from enum import Enum +from dataclasses import dataclass class TrajectorySampleMethod(Enum): @@ -53,3 +56,32 @@ def from_str( def __str__(self): """Override string representation for better readability.""" return self.value.capitalize() + + +class MovePart(Enum): + LEFT = 0 + RIGHT = 1 + BOTH = 2 + TORSO = 3 + ALL = 4 + + +class MoveType(Enum): + TOOL = 0 + TCP_MOVE = 1 + JOINT_MOVE = 2 + SYNC = 3 + PAUSE = 4 + + +@dataclass +class PlanState: + move_type: MoveType = MoveType.PAUSE + move_part: MovePart = MovePart.LEFT + xpos: torch.Tensor = None + qpos: torch.Tensor = None + qacc: torch.Tensor = None + qvel: torch.Tensor = None + is_open: bool = True + is_world_coordinate: bool = True + pause_seconds: float = 0.0