From 1a16fbb2384917482cbea6a5448130f334e17392 Mon Sep 17 00:00:00 2001 From: yuanhaonan Date: Fri, 6 Feb 2026 12:09:36 +0800 Subject: [PATCH 01/14] async collect buffer for VLA RL --- .../rl/basic/cart_pole/train_config.json | 2 +- configs/agents/rl/push_cube/train_config.json | 5 +- docs/source/tutorial/rl.rst | 2 +- embodichain/agents/rl/algo/base.py | 24 +- embodichain/agents/rl/algo/ppo.py | 311 +++++++++++++----- embodichain/agents/rl/buffer/__init__.py | 15 +- .../agents/rl/buffer/rollout_buffer.py | 227 ++++++++----- .../agents/rl/buffer/standard_buffer.py | 116 +++++++ embodichain/agents/rl/models/__init__.py | 27 +- embodichain/agents/rl/models/actor_critic.py | 156 +++++++-- embodichain/agents/rl/models/policy.py | 54 +-- embodichain/agents/rl/train.py | 50 ++- embodichain/agents/rl/utils/__init__.py | 10 +- .../agents/rl/utils/async_collector.py | 289 ++++++++++++++++ embodichain/agents/rl/utils/helper.py | 172 ++++++++-- embodichain/agents/rl/utils/trainer.py | 299 +++++++++++++---- pyproject.toml | 1 + tests/agents/test_rl.py | 2 +- 18 files changed, 1403 insertions(+), 359 deletions(-) create mode 100644 embodichain/agents/rl/buffer/standard_buffer.py create mode 100644 embodichain/agents/rl/utils/async_collector.py diff --git a/configs/agents/rl/basic/cart_pole/train_config.json b/configs/agents/rl/basic/cart_pole/train_config.json index 8412fe36..dddabe41 100644 --- a/configs/agents/rl/basic/cart_pole/train_config.json +++ b/configs/agents/rl/basic/cart_pole/train_config.json @@ -9,7 +9,7 @@ "gpu_id": 0, "num_envs": 64, "iterations": 1000, - "rollout_steps": 1024, + "buffer_size": 1024, "eval_freq": 2, "save_freq": 200, "use_wandb": false, diff --git a/configs/agents/rl/push_cube/train_config.json b/configs/agents/rl/push_cube/train_config.json index ae5026b2..07e6c961 100644 --- a/configs/agents/rl/push_cube/train_config.json +++ b/configs/agents/rl/push_cube/train_config.json @@ -9,11 +9,11 @@ "gpu_id": 0, "num_envs": 64, "iterations": 1000, - 
"rollout_steps": 1024, + "buffer_size": 1024, "enable_eval": true, "num_eval_envs": 16, "num_eval_episodes": 3, - "eval_freq": 2, + "eval_freq": 200, "save_freq": 200, "use_wandb": false, "wandb_project_name": "embodychain-push_cube", @@ -38,6 +38,7 @@ }, "policy": { "name": "actor_critic", + "action_dim": 8, "actor": { "type": "mlp", "network_cfg": { diff --git a/docs/source/tutorial/rl.rst b/docs/source/tutorial/rl.rst index cbc011b2..b0123c96 100644 --- a/docs/source/tutorial/rl.rst +++ b/docs/source/tutorial/rl.rst @@ -64,7 +64,7 @@ The ``runtime`` section controls experiment setup: - **cuda**: Whether to use GPU (default: true) - **headless**: Whether to run simulation in headless mode - **iterations**: Number of training iterations -- **rollout_steps**: Steps per rollout (e.g., 1024) +- **buffer_size**: Steps per rollout (e.g., 1024) - **eval_freq**: Frequency of evaluation (in steps) - **save_freq**: Frequency of checkpoint saving (in steps) - **use_wandb**: Whether to enable Weights & Biases logging (set in JSON config) diff --git a/embodichain/agents/rl/algo/base.py b/embodichain/agents/rl/algo/base.py index 8d74a918..501ef2e7 100644 --- a/embodichain/agents/rl/algo/base.py +++ b/embodichain/agents/rl/algo/base.py @@ -18,35 +18,29 @@ from typing import Dict, Any, Callable import torch +from tensordict import TensorDict class BaseAlgorithm: - """Base class for RL algorithms. + """Base class for RL algorithms following TorchRL conventions. - Algorithms must implement buffer initialization, rollout collection, and - policy update. Trainer depends only on this interface to remain - algorithm-agnostic. + Algorithms implement rollout collection and policy update using TensorDict. + No custom buffer classes - use TensorDict operations directly. 
""" device: torch.device - def initialize_buffer( - self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int - ) -> None: - """Initialize internal buffer(s) required by the algorithm.""" - raise NotImplementedError - def collect_rollout( self, env, policy, - obs: torch.Tensor, - num_steps: int, + tensordict: TensorDict, + buffer_size: int, on_step_callback: Callable | None = None, - ) -> Dict[str, Any]: - """Collect trajectories and return logging info (e.g., reward components).""" + ) -> TensorDict: + """Collect rollout and return TensorDict with batch_size=[T, N].""" raise NotImplementedError - def update(self) -> Dict[str, float]: + def update(self, rollout: TensorDict) -> Dict[str, float]: """Update policy using collected data and return training losses.""" raise NotImplementedError diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index f11fbe37..853e9868 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -15,14 +15,54 @@ # ---------------------------------------------------------------------------- import torch -from typing import Dict, Any, Tuple, Callable +from typing import Dict, Any, Callable -from embodichain.agents.rl.utils import AlgorithmCfg, flatten_dict_observation -from embodichain.agents.rl.buffer import RolloutBuffer +from tensordict import TensorDict +from embodichain.agents.rl.utils import AlgorithmCfg, compute_gae, dict_to_tensordict from embodichain.utils import configclass from .base import BaseAlgorithm +def _print_tensordict_tree(td, prefix="", is_last=True, name="TensorDict"): + """Recursively print TensorDict structure in tree format.""" + connector = "└── " if is_last else "├── " + + # Print current node + batch_info = ( + f"batch_size={list(td.batch_size)}" if hasattr(td, "batch_size") else "" + ) + device_info = f"device={td.device}" if hasattr(td, "device") else "" + meta_info = ", ".join(filter(None, [batch_info, device_info])) + 
print(f"{prefix}{connector}{name}: TensorDict ({meta_info})") + + # Prepare prefix for children + extension = " " if is_last else "│ " + new_prefix = prefix + extension + + # Get all keys + keys = sorted(td.keys()) if hasattr(td, "keys") else [] + + for i, key in enumerate(keys): + is_last_child = i == len(keys) - 1 + value = td[key] + + if isinstance(value, TensorDict): + # Recursively print nested TensorDict + _print_tensordict_tree(value, new_prefix, is_last_child, name=key) + elif isinstance(value, torch.Tensor): + # Print tensor info + child_connector = "└── " if is_last_child else "├── " + shape_str = "x".join(map(str, value.shape)) + dtype_str = str(value.dtype).replace("torch.", "") + print( + f"{new_prefix}{child_connector}{key}: Tensor([{shape_str}], {dtype_str})" + ) + else: + # Print other types + child_connector = "└── " if is_last_child else "├── " + print(f"{new_prefix}{child_connector}{key}: {type(value).__name__}") + + @configclass class PPOCfg(AlgorithmCfg): """Configuration for the PPO algorithm.""" @@ -34,126 +74,208 @@ class PPOCfg(AlgorithmCfg): class PPO(BaseAlgorithm): - """PPO algorithm operating via Policy and RolloutBuffer (algo-agnostic design).""" + """PPO algorithm using TensorDict for all data flow. + + Following TorchRL conventions: no custom buffer class, just TensorDict operations. + All data I/O uses TensorDict - no tensor fallback. + """ def __init__(self, cfg: PPOCfg, policy): self.cfg = cfg self.policy = policy self.device = torch.device(cfg.device) self.optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.learning_rate) - self.buffer: RolloutBuffer | None = None - # no per-rollout aggregation for dense logging - - def _compute_gae( - self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Internal method to compute GAE. 
Only called by collect_rollout.""" - T, N = rewards.shape - advantages = torch.zeros_like(rewards, device=self.device) - last_adv = torch.zeros(N, device=self.device) - for t in reversed(range(T)): - next_value = values[t + 1] if t < T - 1 else torch.zeros_like(values[0]) - not_done = (~dones[t]).float() - delta = rewards[t] + self.cfg.gamma * next_value * not_done - values[t] - last_adv = ( - delta + self.cfg.gamma * self.cfg.gae_lambda * not_done * last_adv - ) - advantages[t] = last_adv - returns = advantages + values - return advantages, returns - - def initialize_buffer( - self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int - ): - """Initialize the rollout buffer. Called by trainer before first rollout.""" - self.buffer = RolloutBuffer( - num_steps, num_envs, obs_dim, action_dim, self.device - ) def collect_rollout( self, env, policy, - obs: torch.Tensor, - num_steps: int, + tensordict: TensorDict, + buffer_size: int, on_step_callback: Callable | None = None, - ) -> Dict[str, Any]: - """Collect a rollout. Algorithm controls the data collection process.""" - if self.buffer is None: - raise RuntimeError( - "Buffer not initialized. Call initialize_buffer() first." - ) + ) -> TensorDict: + """Collect a rollout using TensorDict data flow. 
+ Args: + env: Environment to collect from + policy: Policy to use for action selection + tensordict: Initial TensorDict with "observation" key + buffer_size: Number of steps to collect + on_step_callback: Optional callback called after each step + + Returns: + TensorDict with batch_size=[T, N] containing full rollout data + """ policy.train() - self.buffer.step = 0 - current_obs = obs + current_td = tensordict + rollout_list = [] - for t in range(num_steps): - # Get action from policy - actions, log_prob, value = policy.get_action( - current_obs, deterministic=False - ) + for t in range(buffer_size): + # Policy forward: adds "action", "sample_log_prob", "value" to tensordict + policy.forward(current_td) - # Wrap action as dict for env processing + # Extract action for environment step + action = current_td["action"] action_type = getattr(env, "action_type", "delta_qpos") - action_dict = {action_type: actions} + action_dict = {action_type: action} - # Step environment - result = env.step(action_dict) - next_obs, reward, terminated, truncated, env_info = result + # Environment step - returns tuple (env returns dict, not TensorDict) + next_obs, reward, terminated, truncated, env_info = env.step(action_dict) + + # Convert env dict observation to TensorDict at boundary + next_obs_td = dict_to_tensordict(next_obs, self.device) + + # Build "next" TensorDict done = terminated | truncated - # Light dtype normalization - reward = reward.float() - done = done.bool() + next_obs_for_td = next_obs_td["observation"] + + # Ensure batch_size consistency - use next_obs_td's batch_size + batch_size = next_obs_td.batch_size[0] + + next_td = TensorDict( + { + "observation": next_obs_for_td, + "reward": ( + reward.float().unsqueeze(-1) + if reward.dim() == 1 + else reward.float() + ), + "done": ( + done.bool().unsqueeze(-1) if done.dim() == 1 else done.bool() + ), + "terminated": ( + terminated.bool().unsqueeze(-1) + if terminated.dim() == 1 + else terminated.bool() + ), + "truncated": 
( + truncated.bool().unsqueeze(-1) + if truncated.dim() == 1 + else truncated.bool() + ), + }, + batch_size=torch.Size([batch_size]), + device=self.device, + ) + + # Compute next value for GAE (bootstrap value) + with torch.no_grad(): + next_value_td = TensorDict( + {"observation": next_obs_for_td}, + batch_size=next_td.batch_size, + device=self.device, + ) + policy.get_value(next_value_td) + next_td["value"] = next_value_td["value"] - # Flatten dict observation from ObservationManager if needed - if isinstance(next_obs, dict): - next_obs = flatten_dict_observation(next_obs) + # Add "next" to current tensordict + current_td["next"] = next_td - # Add to buffer - self.buffer.add(current_obs, actions, reward, done, value, log_prob) + # Store complete transition + rollout_list.append(current_td.clone()) - # Dense logging is handled in Trainer.on_step via info; no aggregation here - # Call callback for statistics and logging + # Debug: Print TensorDict structure on first step + if len(rollout_list) == 1: + print("\n" + "=" * 80) + print("[DEBUG] Step 0 TensorDict Structure (Tree View)") + print("=" * 80) + _print_tensordict_tree(current_td, prefix="", is_last=True) + print("=" * 80 + "\n") + + # Callback for statistics and logging if on_step_callback is not None: - on_step_callback(current_obs, actions, reward, done, env_info, next_obs) + on_step_callback(current_td, env_info) + + # Prepare next iteration - use the converted TensorDict + current_td = next_obs_td - current_obs = next_obs + # Stack into [T, N, ...] 
TensorDict + rollout = torch.stack(rollout_list, dim=0) - # Compute advantages/returns and attach to buffer extras - adv, ret = self._compute_gae( - self.buffer.rewards, self.buffer.values, self.buffer.dones + print("\n" + "=" * 80) + print( + f"[DEBUG] Stacked Rollout (T={rollout.batch_size[0]}, N={rollout.batch_size[1]})" ) - self.buffer.set_extras({"advantages": adv, "returns": ret}) + print("=" * 80) + _print_tensordict_tree(rollout, prefix="", is_last=True) + print("=" * 80 + "\n") + + # Compute GAE advantages and returns + rollout = compute_gae( + rollout, gamma=self.cfg.gamma, gae_lambda=self.cfg.gae_lambda + ) + + return rollout + + def update(self, rollout: TensorDict) -> dict: + """Update the policy using collected rollout TensorDict (TorchRL style). - # No aggregated logging results; Trainer performs dense per-step logging - return {} + Args: + rollout: TensorDict with batch_size=[T, N] from collect_rollout() - def update(self) -> dict: - """Update the policy using the collected rollout buffer.""" - if self.buffer is None: - raise RuntimeError("Buffer not initialized. Call collect_rollout() first.") + Returns: + Dictionary of training metrics + """ + # Flatten to [T*N, ...] 
for training + flat_data = rollout.reshape(-1) + total_samples = flat_data.batch_size[0] - # Normalize advantages (optional, common default) - adv = self.buffer._extras.get("advantages") - adv = (adv - adv.mean()) / (adv.std() + 1e-8) + # Normalize advantages globally + advantages = flat_data["advantage"] + advantages_normalized = (advantages - advantages.mean()) / ( + advantages.std() + 1e-8 + ) + flat_data["advantage"] = advantages_normalized total_actor_loss = 0.0 total_value_loss = 0.0 total_entropy = 0.0 total_steps = 0 + total_clip_fraction = 0.0 + total_approx_kl = 0.0 + + for epoch in range(self.cfg.n_epochs): + # Shuffle data each epoch + indices = torch.randperm(total_samples, device=self.device) + shuffled_data = flat_data[indices] + + # Iterate over minibatches + num_minibatches = total_samples // self.cfg.batch_size + for i in range(num_minibatches): + start_idx = i * self.cfg.batch_size + end_idx = start_idx + self.cfg.batch_size + batch_td = shuffled_data[start_idx:end_idx] + + # Extract data from TensorDict batch + old_logprobs = batch_td["sample_log_prob"] + returns = batch_td["value_target"] + advantages = batch_td[ + "advantage" + ] # Note: advantages are already normalized globally before shuffling + + # Evaluate actions with current policy + self.policy.evaluate_actions(batch_td) - for _ in range(self.cfg.n_epochs): - for batch in self.buffer.iterate_minibatches(self.cfg.batch_size): - obs = batch["obs"] - actions = batch["actions"] - old_logprobs = batch["logprobs"] - returns = batch["returns"] - advantages = ( - (batch["advantages"] - adv.mean()) / (adv.std() + 1e-8) - ).detach() - - logprobs, entropy, values = self.policy.evaluate_actions(obs, actions) + # Get updated values + logprobs = batch_td["sample_log_prob"] + entropy = batch_td["entropy"] + values = batch_td["value"] + + # Ensure shapes match (squeeze if needed) + if old_logprobs.dim() > 1: + old_logprobs = old_logprobs.squeeze(-1) + if logprobs.dim() > 1: + logprobs = 
logprobs.squeeze(-1) + if values.dim() > 1: + values = values.squeeze(-1) + if returns.dim() > 1: + returns = returns.squeeze(-1) + if advantages.dim() > 1: + advantages = advantages.squeeze(-1) + if entropy.dim() > 1: + entropy = entropy.squeeze(-1) + + # PPO loss computation ratio = (logprobs - old_logprobs).exp() surr1 = ratio * advantages surr2 = ( @@ -166,6 +288,13 @@ def update(self) -> dict: value_loss = torch.nn.functional.mse_loss(values, returns) entropy_loss = -entropy.mean() + # Diagnostics + with torch.no_grad(): + clip_fraction = ( + ((ratio - 1.0).abs() > self.cfg.clip_coef).float().mean() + ) + approx_kl = ((ratio - 1.0) - (logprobs - old_logprobs)).mean() + loss = ( actor_loss + self.cfg.vf_coef * value_loss @@ -179,14 +308,18 @@ def update(self) -> dict: ) self.optimizer.step() - bs = obs.shape[0] + bs = batch_td.batch_size[0] total_actor_loss += actor_loss.item() * bs total_value_loss += value_loss.item() * bs total_entropy += (-entropy_loss.item()) * bs + total_clip_fraction += clip_fraction.item() * bs + total_approx_kl += approx_kl.item() * bs total_steps += bs return { "actor_loss": total_actor_loss / max(1, total_steps), "value_loss": total_value_loss / max(1, total_steps), "entropy": total_entropy / max(1, total_steps), + "clip_fraction": total_clip_fraction / max(1, total_steps), + "approx_kl": total_approx_kl / max(1, total_steps), } diff --git a/embodichain/agents/rl/buffer/__init__.py b/embodichain/agents/rl/buffer/__init__.py index 8e6f6392..17d3b4be 100644 --- a/embodichain/agents/rl/buffer/__init__.py +++ b/embodichain/agents/rl/buffer/__init__.py @@ -14,6 +14,17 @@ # limitations under the License. # ---------------------------------------------------------------------------- -from .rollout_buffer import RolloutBuffer +""" +Buffer module for RL training. 
-
+Provides two buffer implementations:
+- RolloutBuffer: Standard PPO buffer (single rollout, use and discard)
+- VLABuffer: VLA buffer (FIFO multi-rollout accumulation for slow inference)
+"""
+
+from .rollout_buffer import VLABuffer
+from .standard_buffer import RolloutBuffer
+
+__all__ = ["RolloutBuffer", "VLABuffer"]
+
+# NOTE: removed stale duplicate __all__ ("TensorDictRolloutBuffer" is not defined anywhere).
diff --git a/embodichain/agents/rl/buffer/rollout_buffer.py b/embodichain/agents/rl/buffer/rollout_buffer.py
index d99a8966..d08252ff 100644
--- a/embodichain/agents/rl/buffer/rollout_buffer.py
+++ b/embodichain/agents/rl/buffer/rollout_buffer.py
@@ -16,91 +16,160 @@
 from __future__ import annotations
 
-from typing import Dict, Iterator
-
 import torch
+from tensordict import TensorDict
+from typing import Optional
+
+class VLABuffer:
+    """FIFO rollout buffer for VLA RL with pre-allocated TensorDict storage.
 
-class RolloutBuffer:
-    """On-device rollout buffer for on-policy algorithms.
+    Uses a single pre-allocated TensorDict with circular indexing for efficient
+    high-frequency transition writes. Designed for async VLA scenarios where
+    model inference is slow but training is fast.
 
-    Stores (obs, actions, rewards, dones, values, logprobs) over time.
-    After finalize(), exposes advantages/returns and minibatch iteration.
+    Key characteristics:
+    - Pre-allocated memory: Zero-copy writes via direct indexing
+    - FIFO eviction: Circular buffer automatically overwrites oldest data
+    - Transition-level storage: Each step is a separate entry
+    - High-frequency writes: Optimized for async collection (no TensorDict creation overhead)
+
+    Storage layout: Single TensorDict with shape [buffer_size, ...] 
""" - def __init__( - self, - num_steps: int, - num_envs: int, - obs_dim: int, - action_dim: int, - device: torch.device, - ): - self.num_steps = num_steps - self.num_envs = num_envs - self.obs_dim = obs_dim - self.action_dim = action_dim + def __init__(self, buffer_size: int, device: torch.device): + """Initialize VLA buffer with lazy allocation. + + Args: + buffer_size: Maximum number of transitions to store + device: Device to store tensors on + """ + self.buffer_size = buffer_size self.device = device + self.buffer: Optional[TensorDict] = None # Lazy init on first add + self.write_pos = 0 # Current write position (circular) + self.size = 0 # Current valid data count + self._total_added = 0 + self._initialized = False + + def _initialize_buffer(self, template: TensorDict) -> None: + """Initialize buffer structure from first transition template. + + Args: + template: First transition TensorDict to infer structure from + """ + if self._initialized: + return - T, N = num_steps, num_envs - self.obs = torch.zeros(T, N, obs_dim, dtype=torch.float32, device=device) - self.actions = torch.zeros(T, N, action_dim, dtype=torch.float32, device=device) - self.rewards = torch.zeros(T, N, dtype=torch.float32, device=device) - self.dones = torch.zeros(T, N, dtype=torch.bool, device=device) - self.values = torch.zeros(T, N, dtype=torch.float32, device=device) - self.logprobs = torch.zeros(T, N, dtype=torch.float32, device=device) - - self.step = 0 - # Container for algorithm-specific extra fields (e.g., advantages, returns) - self._extras: dict[str, torch.Tensor] = {} - - def add( - self, - obs: torch.Tensor, - action: torch.Tensor, - reward: torch.Tensor, - done: torch.Tensor, - value: torch.Tensor, - logprob: torch.Tensor, - ) -> None: - t = self.step - self.obs[t].copy_(obs) - self.actions[t].copy_(action) - self.rewards[t].copy_(reward) - self.dones[t].copy_(done) - self.values[t].copy_(value) - self.logprobs[t].copy_(logprob) - self.step += 1 - - def set_extras(self, 
extras: dict[str, torch.Tensor]) -> None: - """Attach algorithm-specific tensors (shape [T, N, ...]) for batching. - - Examples: - {"advantages": adv, "returns": ret} + # Pre-allocate buffer with buffer_size + # Template should be a single transition [key: shape] + self.buffer = template.expand(self.buffer_size).clone() + self._initialized = True + + def add(self, transition: TensorDict) -> None: + """Add a single transition to buffer (high-frequency async writes). + + Args: + transition: Single transition TensorDict (no batch dimension) """ - self._extras = extras or {} - - def iterate_minibatches(self, batch_size: int) -> Iterator[Dict[str, torch.Tensor]]: - T, N = self.num_steps, self.num_envs - total = T * N - indices = torch.randperm(total, device=self.device) - for start in range(0, total, batch_size): - idx = indices[start : start + batch_size] - t_idx = idx // N - n_idx = idx % N - batch = { - "obs": self.obs[t_idx, n_idx], - "actions": self.actions[t_idx, n_idx], - "rewards": self.rewards[t_idx, n_idx], - "dones": self.dones[t_idx, n_idx], - "values": self.values[t_idx, n_idx], - "logprobs": self.logprobs[t_idx, n_idx], - } - # Slice extras if present and shape aligned to [T, N, ...] - for name, tensor in self._extras.items(): - try: - batch[name] = tensor[t_idx, n_idx] - except Exception: - # Skip misaligned extras silently - continue - yield batch + # Lazy initialization on first add + if not self._initialized: + self._initialize_buffer(transition.to(self.device)) + + # Ensure transition is on correct device + transition = transition.to(self.device) + + # Direct index assignment (zero-copy write) + self.buffer[self.write_pos] = transition + + # Update circular index + self.write_pos = (self.write_pos + 1) % self.buffer_size + + # Update size (saturates at buffer_size) + self.size = min(self.size + 1, self.buffer_size) + self._total_added += 1 + + def add_batch(self, transitions: TensorDict) -> None: + """Add multiple transitions at once (batch write). 
+ + Args: + transitions: Batch of transitions with shape [batch_size, ...] + """ + batch_size = transitions.batch_size[0] + + # Lazy initialization + if not self._initialized: + self._initialize_buffer(transitions[0].to(self.device)) + + transitions = transitions.to(self.device) + + # Handle circular write + for i in range(batch_size): + self.buffer[self.write_pos] = transitions[i] + self.write_pos = (self.write_pos + 1) % self.buffer_size + self.size = min(self.size + 1, self.buffer_size) + self._total_added += 1 + + def get(self, flatten: bool = True) -> TensorDict: + """Get valid data from buffer. + + Args: + flatten: If True, return flattened [size, ...]. Currently only supports True. + + Returns: + TensorDict with batch_size=[size, ...] containing valid data + """ + if not self._initialized or self.size == 0: + raise ValueError("Buffer is empty") + + if not flatten: + raise NotImplementedError("Only flatten=True is supported for VLABuffer") + + # Return first 'size' elements (valid data) + # Note: Data is in insertion order up to write_pos, then wraps + if self.size < self.buffer_size: + # Buffer not yet full, data is [0:size] + return self.buffer[: self.size] + else: + # Buffer full, need to rearrange to maintain temporal order + # Oldest data is at write_pos, newest at write_pos-1 + indices = ( + torch.arange( + self.write_pos, + self.write_pos + self.buffer_size, + device=self.device, + ) + % self.buffer_size + ) + return self.buffer[indices] + + def clear(self) -> None: + """Clear buffer (reset pointers, keep pre-allocated memory).""" + self.write_pos = 0 + self.size = 0 + # Keep buffer allocated for reuse + + def __len__(self) -> int: + """Return current number of valid transitions.""" + return self.size + + def is_full(self) -> bool: + """Check if buffer is at full buffer_size.""" + return self.size >= self.buffer_size + + def get_num_rollouts(self) -> int: + """Return 1 (buffer stores transitions, not rollouts).""" + return 1 if self.size > 0 else 0 + + 
def get_stats(self) -> dict: + """Get buffer statistics for logging.""" + return { + "buffer_size": self.size, + "buffer_capacity": self.buffer_size, + "total_transitions": self.size, + "total_added": self._total_added, + "buffer_usage": ( + self.size / self.buffer_size if self.buffer_size > 0 else 0.0 + ), + "write_pos": self.write_pos, + } diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py new file mode 100644 index 00000000..200cc176 --- /dev/null +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -0,0 +1,116 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch +from tensordict import TensorDict +from typing import Optional + + +class RolloutBuffer: + """Standard on-policy rollout buffer for PPO (matches mainstream implementations). + + Unlike VLA buffer which accumulates multiple rollouts with FIFO eviction, + this buffer follows standard PPO pattern: + - Stores exactly ONE rollout at a time + - After training, buffer is cleared (on-policy: use once and discard) + - Simple and efficient for normal-sized models + + Interface compatible with VLABuffer for easy switching. 
+ """ + + def __init__(self, buffer_size: int, device: torch.device): + """Initialize standard rollout buffer. + + Args: + buffer_size: Not used (kept for interface compatibility) + device: Device to store tensors on + """ + self.device = device + self._rollout: Optional[TensorDict] = None + + def add(self, rollout: TensorDict) -> None: + """Add a rollout to buffer, replacing any existing rollout. + + Args: + rollout: TensorDict with batch_size=[T, N, ...] + """ + # Standard PPO: replace existing rollout (not accumulate) + self._rollout = rollout.to(self.device) + + def get(self, flatten: bool = True) -> TensorDict: + """Get rollout from buffer and clear it (standard PPO behavior). + + Args: + flatten: If True, flatten to [batch_size, ...]. + If False, return as [T, N, ...]. + + Returns: + TensorDict with rollout data + """ + if self._rollout is None: + raise ValueError("Buffer is empty") + + rollout = self._rollout + + # Clear after retrieval (on-policy: use once) + self._rollout = None + + if flatten: + # Flatten [T, N, ...] -> [T*N, ...] + return rollout.reshape(-1) + else: + return rollout + + def clear(self) -> None: + """Clear buffer.""" + self._rollout = None + + def is_full(self) -> bool: + """Check if buffer has a rollout ready for training. + + Returns: + True if buffer contains a rollout + """ + return self._rollout is not None + + def __len__(self) -> int: + """Return 1 if buffer has data, 0 otherwise.""" + return 1 if self._rollout is not None else 0 + + def get_num_rollouts(self) -> int: + """Return current number of rollouts in buffer (0 or 1).""" + return 1 if self._rollout is not None else 0 + + def get_num_transitions(self) -> int: + """Return total number of transitions stored.""" + if self._rollout is None: + return 0 + return self._rollout.batch_size[0] * self._rollout.batch_size[1] + + def get_stats(self) -> dict: + """Get buffer statistics for logging. 
+ + Returns: + Dict with buffer stats + """ + return { + "buffer_size": 1 if self._rollout is not None else 0, + "buffer_capacity": 1, + "total_transitions": self.get_num_transitions(), + "buffer_usage": 1.0 if self._rollout is not None else 0.0, + } diff --git a/embodichain/agents/rl/models/__init__.py b/embodichain/agents/rl/models/__init__.py index 669e2b33..1c5e70a6 100644 --- a/embodichain/agents/rl/models/__init__.py +++ b/embodichain/agents/rl/models/__init__.py @@ -18,7 +18,6 @@ from typing import Dict, Type import torch -from gymnasium import spaces from .actor_critic import ActorCritic from .policy import Policy @@ -44,13 +43,26 @@ def get_policy_class(name: str) -> Type[Policy] | None: def build_policy( policy_block: dict, - obs_space: spaces.Space, - action_space: spaces.Space, + action_dim: int, device: torch.device, actor: torch.nn.Module | None = None, critic: torch.nn.Module | None = None, ) -> Policy: - """Build policy strictly from json-like block: { name: ..., cfg: {...} }""" + """Build policy from json-like block. + + With TensorDict architecture, we only need action_dim. + Observations are handled via TensorDict structure. + + Args: + policy_block: Config dict with 'name' key + action_dim: Dimension of action space + device: Device to place policy on + actor: Actor network (required for actor_critic) + critic: Critic network (required for actor_critic) + + Returns: + Initialized Policy instance + """ name = policy_block["name"].lower() if name not in _POLICY_REGISTRY: available = ", ".join(get_registered_policy_names()) @@ -63,9 +75,12 @@ def build_policy( raise ValueError( "ActorCritic policy requires external 'actor' and 'critic' modules." 
) - return policy_cls(obs_space, action_space, device, actor=actor, critic=critic) + return policy_cls( + action_dim=action_dim, device=device, actor=actor, critic=critic + ) else: - return policy_cls(obs_space, action_space, device) + # Other policies should also use action_dim signature + return policy_cls(action_dim=action_dim, device=device) def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP: diff --git a/embodichain/agents/rl/models/actor_critic.py b/embodichain/agents/rl/models/actor_critic.py index 1c40043a..faf305cb 100644 --- a/embodichain/agents/rl/models/actor_critic.py +++ b/embodichain/agents/rl/models/actor_critic.py @@ -16,11 +16,12 @@ from __future__ import annotations -from typing import Dict, Any, Tuple +from typing import Dict, Any import torch import torch.nn as nn from torch.distributions.normal import Normal +from tensordict import TensorDict from .mlp import MLP from .policy import Policy @@ -28,31 +29,31 @@ class ActorCritic(Policy): """Actor-Critic with learnable log_std for Gaussian policy. - This is a placeholder implementation of the Policy interface that: - - Encapsulates MLP networks (actor + critic) that need to be trained by RL algorithms + Uses TensorDict for all data I/O following TorchRL conventions. + This implementation: + - Encapsulates MLP networks (actor + critic) trained by RL algorithms - Handles internal computation: MLP output → mean + learnable log_std → Normal distribution - - Provides a uniform interface for RL algorithms (PPO, SAC, etc.) + - Provides a uniform TensorDict-based interface for RL algorithms (PPO, SAC, etc.) This allows seamless swapping with other policy implementations (e.g., VLAPolicy) without modifying RL algorithm code. 
Implements: - - get_action(obs, deterministic=False) -> (action, log_prob, value) - - get_value(obs) - - evaluate_actions(obs, actions) -> (log_prob, entropy, value) + - forward(tensordict) -> tensordict (adds action, sample_log_prob, value) + - get_value(tensordict) -> tensordict (adds value) + - evaluate_actions(tensordict) -> tensordict (adds sample_log_prob, entropy, value) """ def __init__( self, - obs_space, - action_space, + action_dim: int, device: torch.device, actor: nn.Module, critic: nn.Module, ): super().__init__() - self.obs_dim = obs_space.shape[-1] - self.action_dim = action_space.shape[-1] + # Observation handling done via TensorDict - no need for obs_space + self.action_dim = action_dim self.device = device # Require external injection of actor and critic @@ -66,31 +67,128 @@ def __init__( self.log_std_min = -5.0 self.log_std_max = 2.0 + def _extract_obs_tensor(self, tensordict: TensorDict) -> torch.Tensor: + """Extract observation as flat tensor from TensorDict. + + For nested TensorDict observations, flattens all leaf tensors. + For plain tensor observations, returns as is. 
+ + Args: + tensordict: Input TensorDict with "observation" key + + Returns: + Flattened observation tensor + """ + obs = tensordict["observation"] + + # Handle nested TensorDict by collecting all leaf tensors + obs_list = [] + + def _collect(item): + # Duck typing: if it has keys(), treat as TensorDict + if hasattr(item, "keys"): + for key in sorted(item.keys()): + _collect(item[key]) + else: + # Leaf tensor + obs_list.append(item.flatten(start_dim=1)) + + _collect(obs) + + if len(obs_list) == 0: + raise ValueError("No tensors found in observation") + elif len(obs_list) == 1: + return obs_list[0] + else: + return torch.cat(obs_list, dim=-1) + @torch.no_grad() - def get_action( - self, obs: torch.Tensor, deterministic: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - mean = self.actor(obs) + def forward(self, tensordict: TensorDict) -> TensorDict: + """Forward pass: sample action and compute value (in-place modification). + + Args: + tensordict: Must contain "observation" key + + Returns: + Same tensordict with added keys: + - "action": Sampled action + - "sample_log_prob": Log probability of sampled action + - "value": Value estimate + - "loc": Distribution mean + - "scale": Distribution std + """ + obs_tensor = self._extract_obs_tensor(tensordict) + + # Actor forward + mean = self.actor(obs_tensor) log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) std = log_std.exp().expand(mean.shape[0], -1) dist = Normal(mean, std) - action = mean if deterministic else dist.sample() - log_prob = dist.log_prob(action).sum(dim=-1) - value = self.critic(obs).squeeze(-1) - return action, log_prob, value + + # Sample action (or use mean if deterministic mode set elsewhere) + # For now, always sample during forward; deterministic handled by setting std=0 externally if needed + action = dist.sample() + log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True) + + # Critic forward - keep shape [N, 1] for consistency with reward/done + value = 
self.critic(obs_tensor) + + # Add to tensordict (in-place) + tensordict["action"] = action + tensordict["sample_log_prob"] = log_prob + tensordict["value"] = value + tensordict["loc"] = mean + tensordict["scale"] = std + + return tensordict @torch.no_grad() - def get_value(self, obs: torch.Tensor) -> torch.Tensor: - return self.critic(obs).squeeze(-1) + def get_value(self, tensordict: TensorDict) -> TensorDict: + """Get value estimate for observations (in-place modification). + + Args: + tensordict: Must contain "observation" key + + Returns: + Same tensordict with added key: + - "value": Value estimate + """ + obs_tensor = self._extract_obs_tensor(tensordict) + value = self.critic(obs_tensor) # Keep shape [N, 1] + tensordict["value"] = value + return tensordict + + def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + """Evaluate actions for policy gradient computation (in-place modification). - def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - mean = self.actor(obs) + Args: + tensordict: Must contain "observation" and "action" keys + + Returns: + Same tensordict with added keys: + - "sample_log_prob": Log probability of actions + - "entropy": Entropy of action distribution + - "value": Value estimate + """ + obs_tensor = self._extract_obs_tensor(tensordict) + actions = tensordict["action"] + + # Actor forward + mean = self.actor(obs_tensor) log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) std = log_std.exp().expand(mean.shape[0], -1) dist = Normal(mean, std) - log_prob = dist.log_prob(actions).sum(dim=-1) - entropy = dist.entropy().sum(dim=-1) - value = self.critic(obs).squeeze(-1) - return log_prob, entropy, value + + # Evaluate given actions - keep shape [N, 1] for consistency + log_prob = dist.log_prob(actions).sum(dim=-1, keepdim=True) + entropy = dist.entropy().sum(dim=-1, keepdim=True) + + # Critic forward - keep shape [N, 1] + value = 
self.critic(obs_tensor) + + # Add to tensordict (in-place) + tensordict["sample_log_prob"] = log_prob + tensordict["entropy"] = entropy + tensordict["value"] = value + + return tensordict diff --git a/embodichain/agents/rl/models/policy.py b/embodichain/agents/rl/models/policy.py index cd21d0f7..68b909d1 100644 --- a/embodichain/agents/rl/models/policy.py +++ b/embodichain/agents/rl/models/policy.py @@ -19,13 +19,15 @@ This module defines an abstract Policy base class that all RL policies must inherit from. A Policy encapsulates the neural networks and exposes a uniform interface for RL algorithms (e.g., PPO, SAC) to interact with. + +All data I/O now uses TensorDict for structured, extensible data flow. """ from __future__ import annotations -from typing import Tuple from abc import ABC, abstractmethod import torch.nn as nn +from tensordict import TensorDict import torch @@ -37,6 +39,7 @@ class Policy(nn.Module, ABC): - Encapsulates neural networks that are trained by RL algorithms - Handles internal computations (e.g., network output → distribution) - Provides a uniform interface for algorithms (PPO, SAC, etc.) + - Uses TensorDict for all inputs and outputs (no tensor fallback) """ device: torch.device @@ -46,49 +49,54 @@ def __init__(self) -> None: super().__init__() @abstractmethod - def get_action( - self, obs: torch.Tensor, deterministic: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Sample an action from the policy. + def forward(self, tensordict: TensorDict) -> TensorDict: + """Forward pass that adds action to the input tensordict (in-place). + + This is the main inference method following TorchRL conventions. 
Args: - obs: Observation tensor of shape (batch_size, obs_dim) - deterministic: If True, return the mean action; otherwise sample + tensordict: Input TensorDict containing at minimum: + - "observation": Observation tensor or nested TensorDict Returns: - Tuple of (action, log_prob, value): - - action: Sampled action tensor of shape (batch_size, action_dim) - - log_prob: Log probability of the action, shape (batch_size,) - - value: Value estimate, shape (batch_size,) + The same TensorDict (modified in-place) with added fields: + - "action": Sampled action tensor + - "sample_log_prob": Log probability of the sampled action + - "value": Value estimate (optional, for actor-critic) + - "loc": Distribution mean (optional, for continuous actions) + - "scale": Distribution std (optional, for continuous actions) """ raise NotImplementedError @abstractmethod - def get_value(self, obs: torch.Tensor) -> torch.Tensor: + def get_value(self, tensordict: TensorDict) -> TensorDict: """Get value estimate for given observations. Args: - obs: Observation tensor of shape (batch_size, obs_dim) + tensordict: Input TensorDict containing: + - "observation": Observation data Returns: - Value estimate tensor of shape (batch_size,) + TensorDict with added field: + - "value": Value estimate tensor """ raise NotImplementedError @abstractmethod - def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: """Evaluate actions and compute log probabilities, entropy, and values. + Used during policy updates to recompute action probabilities. 
+ Args: - obs: Observation tensor of shape (batch_size, obs_dim) - actions: Action tensor of shape (batch_size, action_dim) + tensordict: Input TensorDict containing: + - "observation": Observation data + - "action": Actions to evaluate Returns: - Tuple of (log_prob, entropy, value): - - log_prob: Log probability of actions, shape (batch_size,) - - entropy: Entropy of the action distribution, shape (batch_size,) - - value: Value estimate, shape (batch_size,) + TensorDict with added fields: + - "sample_log_prob": Log probability of actions + - "entropy": Entropy of the action distribution + - "value": Value estimate """ raise NotImplementedError diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py index 0f766954..4da25442 100644 --- a/embodichain/agents/rl/train.py +++ b/embodichain/agents/rl/train.py @@ -64,7 +64,7 @@ def train_from_config(config_path: str): seed = int(trainer_cfg.get("seed", 1)) device_str = trainer_cfg.get("device", "cpu") iterations = int(trainer_cfg.get("iterations", 250)) - rollout_steps = int(trainer_cfg.get("rollout_steps", 2048)) + buffer_size = int(trainer_cfg.get("buffer_size", 2048)) enable_eval = bool(trainer_cfg.get("enable_eval", False)) eval_freq = int(trainer_cfg.get("eval_freq", 10000)) save_freq = int(trainer_cfg.get("save_freq", 50000)) @@ -175,13 +175,36 @@ def train_from_config(config_path: str): # Build Policy via registry policy_name = policy_block["name"] - # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic) - if policy_name.lower() == "actor_critic": - # Get observation dimension from flattened observation space - # flattened_observation_space returns Box space for RL training - obs_dim = env.flattened_observation_space.shape[-1] - action_dim = env.action_space.shape[-1] + # Get action_dim from config (required) + action_dim = policy_block.get("action_dim") + if action_dim is None: + raise ValueError( + "Missing 'action_dim' in policy config. 
" + "With TensorDict architecture, action dimension must be explicitly specified in config. " + 'Example: {"policy": {"name": "actor_critic", "action_dim": 7, ...}}' + ) + + # Infer obs_dim from environment sampling (no gym space dependency) + # Env returns dict, we process it to infer dimensions + sample_obs, _ = env.reset() + + # Get obs_dim by flattening observation structure (env returns dict) + obs_list = [] + + def _collect(item): + """Recursively collect tensors from dict or direct tensor.""" + if hasattr(item, "keys"): # It's a dict + for key in sorted(item.keys()): + _collect(item[key]) + else: # It's a Tensor + obs_list.append(item.flatten(start_dim=1)) + + _collect(sample_obs) + obs_dim = sum(t.shape[-1] for t in obs_list) + + # Build policy based on type + if policy_name.lower() == "actor_critic": actor_cfg = policy_block.get("actor") critic_cfg = policy_block.get("critic") if actor_cfg is None or critic_cfg is None: @@ -194,16 +217,13 @@ def train_from_config(config_path: str): policy = build_policy( policy_block, - env.flattened_observation_space, - env.action_space, - device, + action_dim=action_dim, + device=device, actor=actor, critic=critic, ) else: - policy = build_policy( - policy_block, env.flattened_observation_space, env.action_space, device - ) + policy = build_policy(policy_block, action_dim=action_dim, device=device) # Build Algorithm via factory algo_name = algo_block["name"].lower() @@ -254,7 +274,7 @@ def train_from_config(config_path: str): policy=policy, env=env, algorithm=algo, - num_steps=rollout_steps, + buffer_size=buffer_size, batch_size=algo_cfg["batch_size"], writer=writer, eval_freq=eval_freq if enable_eval else 0, # Disable eval if not enabled @@ -277,7 +297,7 @@ def train_from_config(config_path: str): f"Algorithm: {algo_name} (available: {get_registered_algo_names()})" ) - total_steps = int(iterations * rollout_steps * env.num_envs) + total_steps = int(iterations * buffer_size * env.num_envs) logger.log_info(f"Total steps: 
{total_steps} (iterations≈{iterations})") try: diff --git a/embodichain/agents/rl/utils/__init__.py b/embodichain/agents/rl/utils/__init__.py index e6f9e57a..f1b94eb1 100644 --- a/embodichain/agents/rl/utils/__init__.py +++ b/embodichain/agents/rl/utils/__init__.py @@ -15,9 +15,15 @@ # ---------------------------------------------------------------------------- from .config import AlgorithmCfg -from .helper import flatten_dict_observation +from .helper import dict_to_tensordict, mean_scalar, pack_log_dict, compute_gae +from .async_collector import AsyncCollector, AsyncCollectorStats __all__ = [ "AlgorithmCfg", - "flatten_dict_observation", + "dict_to_tensordict", + "mean_scalar", + "pack_log_dict", + "compute_gae", + "AsyncCollector", + "AsyncCollectorStats", ] diff --git a/embodichain/agents/rl/utils/async_collector.py b/embodichain/agents/rl/utils/async_collector.py new file mode 100644 index 00000000..ac9b4b4d --- /dev/null +++ b/embodichain/agents/rl/utils/async_collector.py @@ -0,0 +1,289 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import threading +import time +from typing import Callable, Optional +import torch +from tensordict import TensorDict +from collections import deque + +from .helper import dict_to_tensordict + + +class AsyncCollector: + """Asynchronous data collector for VLA RL scenarios. + + Runs in a background thread to continuously collect transitions while + the main thread performs model updates. Designed for scenarios where + model inference is slow (e.g., VLA) but training is fast. + + Key features: + - Background thread: Continuous data collection + - Thread-safe buffer: Lock-protected writes + - Step-level collection: Individual transitions added to buffer + - Episode statistics tracking: Rewards and lengths + + Usage: + collector = AsyncCollector(env, policy, buffer, device, ...) + collector.start() # Begin background collection + # ... main thread does training ... + collector.stop() # Stop collection + """ + + def __init__( + self, + env, + policy, + buffer, + device: torch.device, + on_step_callback: Optional[Callable] = None, + ): + """Initialize async collector. 
+ + Args: + env: Environment to collect from + policy: Policy for action selection + buffer: VLABuffer instance (shared with Trainer) + device: Device for tensor operations + on_step_callback: Optional callback(transition, env_info) called after each step + """ + self.env = env + self.policy = policy + self.buffer = buffer + self.device = device + self.on_step_callback = on_step_callback + + # Thread control + self._running = False + self._thread: Optional[threading.Thread] = None + self._lock = threading.Lock() + + # Episode statistics + self._episode_count = 0 + self._step_count = 0 + + # Initialize observation + obs_dict, _ = self.env.reset() + self.obs_tensordict = dict_to_tensordict(obs_dict, self.device) + + def start(self): + """Start background collection thread.""" + if self._running: + raise RuntimeError("Collector is already running") + + self._running = True + self._thread = threading.Thread(target=self._collect_loop, daemon=True) + self._thread.start() + print("[AsyncCollector] Background collection started") + + def stop(self): + """Stop background collection thread.""" + if not self._running: + return + + self._running = False + if self._thread is not None: + self._thread.join(timeout=5.0) + if self._thread.is_alive(): + print("[AsyncCollector] Warning: Thread did not stop cleanly") + + print( + f"[AsyncCollector] Stopped (collected {self._step_count} steps, {self._episode_count} episodes)" + ) + + def is_running(self) -> bool: + """Check if collector is currently running.""" + return self._running + + def get_stats(self) -> dict: + """Get collection statistics.""" + with self._lock: + return { + "steps_collected": self._step_count, + "episodes_collected": self._episode_count, + } + + def _collect_loop(self): + """Background thread main loop: continuously collect transitions. + + This method runs in a separate thread and continuously: + 1. Gets action from policy + 2. Steps environment + 3. Constructs transition TensorDict + 4. 
Adds to buffer (thread-safe) + 5. Updates statistics + """ + current_td = self.obs_tensordict + + while self._running: + try: + # Policy forward (no_grad for inference) + with torch.no_grad(): + self.policy.train() # Use stochastic policy + self.policy.forward(current_td) + + # Extract action + action = current_td["action"] + action_type = getattr(self.env, "action_type", "delta_qpos") + action_dict = {action_type: action} + + # Environment step + next_obs_dict, reward, terminated, truncated, env_info = self.env.step( + action_dict + ) + + # Convert observation to TensorDict + next_obs_td = dict_to_tensordict(next_obs_dict, self.device) + done = terminated | truncated + next_obs_for_td = next_obs_td["observation"] + batch_size = next_obs_td.batch_size[0] + + # Build "next" TensorDict + next_td = TensorDict( + { + "observation": next_obs_for_td, + "reward": ( + reward.float().unsqueeze(-1) + if reward.dim() == 1 + else reward.float() + ), + "done": ( + done.bool().unsqueeze(-1) + if done.dim() == 1 + else done.bool() + ), + "terminated": ( + terminated.bool().unsqueeze(-1) + if terminated.dim() == 1 + else terminated.bool() + ), + "truncated": ( + truncated.bool().unsqueeze(-1) + if truncated.dim() == 1 + else truncated.bool() + ), + }, + batch_size=torch.Size([batch_size]), + device=self.device, + ) + + # Compute next value for bootstrapping (GAE computation) + with torch.no_grad(): + next_value_td = TensorDict( + {"observation": next_obs_for_td}, + batch_size=next_td.batch_size, + device=self.device, + ) + self.policy.get_value(next_value_td) + next_td["value"] = next_value_td["value"] + + # Add "next" to current transition + current_td["next"] = next_td + + # Flatten transition for buffer (remove batch dimension for single-step storage) + # Current buffer expects transitions without batch dimension + # We need to add each parallel env's transition separately + for env_idx in range(batch_size): + transition = current_td[env_idx] # Extract single env's transition + 
+ # Thread-safe buffer write + with self._lock: + self.buffer.add(transition) + self._step_count += 1 + + # Callback for statistics + if self.on_step_callback is not None: + self.on_step_callback(current_td, env_info) + + # Handle episode termination + if done.any(): + with self._lock: + self._episode_count += done.sum().item() + + # Prepare next observation + current_td = next_obs_td + + except Exception as e: + print(f"[AsyncCollector] Error in collection loop: {e}") + import traceback + + traceback.print_exc() + break + + print("[AsyncCollector] Collection loop exited") + + +class AsyncCollectorStats: + """Helper class to track async collection statistics safely.""" + + def __init__(self, num_envs: int, device: torch.device): + self.device = device + self.num_envs = num_envs + + # Episode tracking (on device) + self.curr_ret = torch.zeros(num_envs, dtype=torch.float32, device=device) + self.curr_len = torch.zeros(num_envs, dtype=torch.int32, device=device) + + # Completed episodes (CPU) + self.ret_window = deque(maxlen=100) + self.len_window = deque(maxlen=100) + self._lock = threading.Lock() + + def update(self, reward: torch.Tensor, done: torch.Tensor): + """Update episode statistics (thread-safe). 
+ + Args: + reward: Reward tensor [N] or [N, 1] + done: Done tensor [N] or [N, 1] + """ + # Ensure correct shape + if reward.dim() > 1: + reward = reward.squeeze(-1) + if done.dim() > 1: + done = done.squeeze(-1) + + with self._lock: + # Update cumulative stats + self.curr_ret += reward + self.curr_len += 1 + + # Handle completed episodes + done_idx = torch.nonzero(done, as_tuple=False).squeeze(-1) + if done_idx.numel() > 0: + finished_ret = self.curr_ret[done_idx].detach().cpu().tolist() + finished_len = self.curr_len[done_idx].detach().cpu().tolist() + self.ret_window.extend(finished_ret) + self.len_window.extend(finished_len) + + # Reset for finished episodes + self.curr_ret[done_idx] = 0 + self.curr_len[done_idx] = 0 + + def get_avg_stats(self) -> tuple[float, float]: + """Get average episode return and length (thread-safe). + + Returns: + (avg_return, avg_length) or (nan, nan) if no episodes completed + """ + with self._lock: + if len(self.ret_window) == 0: + return float("nan"), float("nan") + return float(sum(self.ret_window) / len(self.ret_window)), float( + sum(self.len_window) / len(self.len_window) + ) diff --git a/embodichain/agents/rl/utils/helper.py b/embodichain/agents/rl/utils/helper.py index 3021a31f..17919144 100644 --- a/embodichain/agents/rl/utils/helper.py +++ b/embodichain/agents/rl/utils/helper.py @@ -14,39 +14,167 @@ # limitations under the License. # ---------------------------------------------------------------------------- -import torch +"""Helper utilities for RL training. +This module provides utility functions for RL algorithms. +""" -def flatten_dict_observation(input_dict: dict) -> torch.Tensor: - """ - Flatten hierarchical dict observations from ObservationManager. +import torch +import numpy as np +from tensordict import TensorDict - Recursively traverse nested dicts, collect all tensor values, - flatten each to (num_envs, -1), and concatenate in sorted key order. 
+ +def dict_to_tensordict(obs_dict: dict, device: torch.device) -> TensorDict: + """Convert nested dict observation to TensorDict recursively. Args: - input_dict: Nested dict structure, e.g. {"robot": {"qpos": tensor, "ee_pos": tensor}, "object": {...}} + obs_dict: Nested observation dictionary + device: Device to place tensors on Returns: - Concatenated flat tensor of shape (num_envs, total_dim) + TensorDict with nested structure preserved and "observation" key """ - obs_list = [] - def _collect_tensors(d, prefix=""): - """Recursively collect tensors from nested dicts in sorted order.""" - for key in sorted(d.keys()): - full_key = f"{prefix}/{key}" if prefix else key - value = d[key] + def _recursive_convert(d): + """Recursively convert dict to TensorDict-compatible structure.""" + result = {} + for key, value in d.items(): if isinstance(value, dict): - _collect_tensors(value, full_key) + # Recursively convert nested dicts + result[key] = _recursive_convert(value) elif isinstance(value, torch.Tensor): - # Flatten tensor to (num_envs, -1) shape - obs_list.append(value.flatten(start_dim=1)) + result[key] = value.to(device) + else: + result[key] = torch.tensor(value, device=device) + return result + + # Convert the observation dict structure + converted = _recursive_convert(obs_dict) + + # Infer batch_size from first tensor we find + def _get_first_tensor_batch_size(d): + """Find first tensor and get its batch dimension.""" + for value in d.values(): + if isinstance(value, torch.Tensor): + return value.shape[0] + elif isinstance(value, dict): + bs = _get_first_tensor_batch_size(value) + if bs is not None: + return bs + return None + + batch_size = _get_first_tensor_batch_size(converted) + if batch_size is None: + batch_size = 1 # Default if no tensors found + + # Wrap in TensorDict with explicit batch_size + obs_td = TensorDict(converted, batch_size=[batch_size], device=device) + + # Wrap observation in outer TensorDict with "observation" key + return 
TensorDict({"observation": obs_td}, batch_size=[batch_size], device=device) + + +def mean_scalar(x) -> float: + """Convert tensor or array to scalar float (mean if needed). + + Args: + x: Tensor, array, or scalar value + + Returns: + Float scalar value + """ + if hasattr(x, "detach"): + x = x.detach().cpu().numpy() + else: + x = np.asarray(x) + return float(np.mean(x)) + + +def pack_log_dict(prefix: str, data: dict) -> dict: + """Pack data dict into logging dict with prefix. + + Args: + prefix: Prefix for keys (e.g., "train", "eval") + data: Dictionary of values to pack + + Returns: + Dictionary with prefixed keys and scalar values + """ + if not isinstance(data, dict): + return {} + out = {} + for k, v in data.items(): + try: + out[f"{prefix}/{k}"] = mean_scalar(v) + except Exception: + continue + return out + + +def compute_gae( + rollout: TensorDict, + gamma: float, + gae_lambda: float, +) -> TensorDict: + """Compute Generalized Advantage Estimation (GAE) on rollout TensorDict. + + This follows the TorchRL convention where rollout has shape [T, N, ...]. + Computes advantage and value_target in-place and returns the modified TensorDict. 
+ + Args: + rollout: TensorDict with batch_size=[T, N] containing: + - "value": Tensor[T, N, 1] - state values + - "next": TensorDict with: + - "reward": Tensor[T, N, 1] + - "done": Tensor[T, N, 1] + - "value": Tensor[T, N, 1] - next state values (bootstrapped) + gamma: Discount factor + gae_lambda: GAE lambda parameter + + Returns: + TensorDict with added keys: + - "advantage": Tensor[T, N, 1] + - "value_target": Tensor[T, N, 1] + """ + T, N = rollout.batch_size[:2] + device = rollout.device + + # Extract tensors - shape [T, N, 1] + values = rollout["value"] + rewards = rollout["next"]["reward"] + dones = rollout["next"]["done"].float() + + # Bootstrap values: use next state value from rollout["next"]["value"] + # This is computed during collection by evaluating policy on next_obs + if "value" in rollout["next"]: + bootstrap_values = rollout["next"]["value"] + else: + # If not provided, assume 0 (terminal state) + bootstrap_values = torch.zeros_like(values) + + # Compute GAE advantages using backward iteration + # advantage[t] = delta[t] + (gamma * gae_lambda) * (1 - done[t]) * advantage[t+1] + # where delta[t] = reward[t] + gamma * (1 - done[t]) * V(s_{t+1}) - V(s_t) + # V(s_{t+1}) comes from bootstrap_values[t] which was computed on next_obs[t] + + advantages = torch.zeros_like(values) + gae = torch.zeros(N, 1, device=device) + + # Iterate backwards through time + for t in reversed(range(T)): + # Compute TD error (delta) + # bootstrap_values[t] is V(s_{t+1}), the value of the next state after action at t + delta = rewards[t] + gamma * bootstrap_values[t] * (1.0 - dones[t]) - values[t] + + # Compute GAE recursively + gae = delta + gamma * gae_lambda * (1.0 - dones[t]) * gae + advantages[t] = gae - _collect_tensors(input_dict) + # Compute value targets (for value function loss) + value_targets = advantages + values - if not obs_list: - raise ValueError("No tensors found in observation dict") + # Add to rollout TensorDict (in-place) + rollout["advantage"] = 
advantages + rollout["value_target"] = value_targets - result = torch.cat(obs_list, dim=-1) - return result + return rollout diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 40df6d74..a67d34c3 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -23,9 +23,10 @@ from torch.utils.tensorboard import SummaryWriter from collections import deque import wandb +from tensordict import TensorDict from embodichain.lab.gym.envs.managers.event_manager import EventManager -from .helper import flatten_dict_observation +from .helper import dict_to_tensordict, mean_scalar, pack_log_dict class Trainer: @@ -36,7 +37,7 @@ def __init__( policy, env, algorithm, - num_steps: int, + buffer_size: int, batch_size: int, writer: SummaryWriter | None, eval_freq: int, @@ -48,12 +49,14 @@ def __init__( event_cfg=None, eval_event_cfg=None, num_eval_episodes: int = 5, + # Buffer config: "standard" (default) or "vla" + buffer_type: str = "standard", ): self.policy = policy self.env = env self.eval_env = eval_env self.algorithm = algorithm - self.num_steps = num_steps + self.buffer_size = buffer_size self.batch_size = batch_size self.writer = writer self.eval_freq = eval_freq @@ -63,6 +66,29 @@ def __init__( self.use_wandb = use_wandb self.num_eval_episodes = num_eval_episodes + # Buffer setup + self.buffer_type = buffer_type + device = ( + algorithm.device + if hasattr(algorithm, "device") + else torch.device("cuda" if torch.cuda.is_available() else "cpu") + ) + + if buffer_type == "vla": + # VLA buffer: accumulate multiple rollouts with FIFO + from embodichain.agents.rl.buffer import VLABuffer + + self.buffer = VLABuffer(buffer_size=buffer_size, device=device) + elif buffer_type == "standard": + # Standard PPO buffer: single rollout, use and discard + from embodichain.agents.rl.buffer import RolloutBuffer + + self.buffer = RolloutBuffer(buffer_size=1, device=device) + else: + raise ValueError( + 
f"Unknown buffer_type: {buffer_type}. Use 'standard' or 'vla'." + ) + if event_cfg is not None: self.event_manager = EventManager(event_cfg, env=self.env) if eval_event_cfg is not None: @@ -75,85 +101,199 @@ def __init__( self.ret_window = deque(maxlen=100) self.len_window = deque(maxlen=100) - # initial obs (assume env returns torch tensors already on target device) + # Get initial observation from env (dict) and convert to TensorDict at boundary obs, _ = self.env.reset() - - # Initialize algorithm's buffer - # Flatten dict observations from ObservationManager to tensor for RL algorithms - if isinstance(obs, dict): - obs_tensor = flatten_dict_observation(obs) - obs_dim = obs_tensor.shape[-1] - num_envs = obs_tensor.shape[0] - # Store flattened observation for RL training - self.obs = obs_tensor - - action_space = getattr(self.env, "action_space", None) - action_dim = action_space.shape[-1] if action_space else None - if action_dim is None: - raise RuntimeError( - "Env must expose action_space with shape for buffer initialization." 
- ) - - # Algorithm manages its own buffer - self.algorithm.initialize_buffer(num_steps, num_envs, obs_dim, action_dim) + self.obs_tensordict = dict_to_tensordict(obs, self.device) + num_envs = self.obs_tensordict.batch_size[0] # episode stats tracked on device to avoid repeated CPU round-trips self.curr_ret = torch.zeros(num_envs, dtype=torch.float32, device=self.device) self.curr_len = torch.zeros(num_envs, dtype=torch.int32, device=self.device) # ---- lightweight helpers for dense logging ---- - @staticmethod - def _mean_scalar(x) -> float: - if hasattr(x, "detach"): - x = x.detach().cpu().numpy() - else: - x = np.asarray(x) - return float(np.mean(x)) - def _log_scalar_dict(self, prefix: str, data: dict): if not self.writer or not isinstance(data, dict): return for k, v in data.items(): try: self.writer.add_scalar( - f"{prefix}/{k}", self._mean_scalar(v), self.global_step + f"{prefix}/{k}", mean_scalar(v), self.global_step ) except Exception: continue - def _pack_log_dict(self, prefix: str, data: dict) -> dict: - if not isinstance(data, dict): - return {} - out = {} - for k, v in data.items(): - try: - out[f"{prefix}/{k}"] = self._mean_scalar(v) - except Exception: - continue - return out - def train(self, total_timesteps: int): print(f"Start training, total steps: {total_timesteps}") + print(f"Using {self.buffer_type} buffer") + + if self.buffer_type == "vla": + # VLA mode: Use async collector + self._train_async(total_timesteps) + else: + # Standard mode: Use synchronous collection + self._train_sync(total_timesteps) + + def _train_sync(self, total_timesteps: int): + """Synchronous training loop (standard PPO).""" while self.global_step < total_timesteps: - self._collect_rollout() - losses = self.algorithm.update() - self._log_train(losses) + rollout = self._collect_rollout() + + # Add rollout to buffer + self.buffer.add(rollout) + + # Train when buffer is full + if self.buffer.is_full(): + data = self.buffer.get(flatten=True) + losses = 
self.algorithm.update(data) + self._log_train(losses) + + # Evaluation if ( self.eval_freq > 0 and self.eval_env is not None and self.global_step % self.eval_freq == 0 ): self._eval_once(num_episodes=self.num_eval_episodes) + + # Checkpoint if self.global_step % self.save_freq == 0: self.save_checkpoint() + def _train_async(self, total_timesteps: int): + """Asynchronous training loop (VLA mode).""" + from .async_collector import AsyncCollector, AsyncCollectorStats + + # Create statistics tracker + num_envs = self.obs_tensordict.batch_size[0] + async_stats = AsyncCollectorStats(num_envs, self.device) + + # Create callback for async collector + def on_async_step(tensordict: TensorDict, env_info: dict): + """Callback for async collection statistics.""" + # Extract reward and done + reward = tensordict["next"]["reward"] + done = tensordict["next"]["done"] + + # Update statistics + async_stats.update(reward, done) + + # Update global step + num_envs = tensordict.batch_size[0] + self.global_step += num_envs + + # Log environment metrics + if isinstance(env_info, dict): + rewards_dict = env_info.get("rewards") + metrics_dict = env_info.get("metrics") + self._log_scalar_dict("rewards", rewards_dict) + self._log_scalar_dict("metrics", metrics_dict) + log_dict = {} + log_dict.update(pack_log_dict("rewards", rewards_dict)) + log_dict.update(pack_log_dict("metrics", metrics_dict)) + if log_dict and self.use_wandb: + wandb.log(log_dict, step=self.global_step) + + # Create and start async collector + collector = AsyncCollector( + env=self.env, + policy=self.policy, + buffer=self.buffer, + device=self.device, + on_step_callback=on_async_step, + ) + + print("[Trainer] Starting async collector...") + collector.start() + + # Training loop: wait for buffer to fill, then train + last_eval_step = 0 + last_save_step = 0 + update_count = 0 + + try: + while self.global_step < total_timesteps: + # Wait for buffer to fill + while not self.buffer.is_full(): + time.sleep(0.1) # Check every 
100ms + if not collector.is_running(): + raise RuntimeError("Async collector stopped unexpectedly") + + # Get data and train + data = self.buffer.get(flatten=True) + losses = self.algorithm.update(data) + + # Update episode statistics from async tracker + avg_ret, avg_len = async_stats.get_avg_stats() + if not np.isnan(avg_ret): + self.ret_window.append(avg_ret) + if not np.isnan(avg_len): + self.len_window.append(avg_len) + + # Log training + self._log_train(losses) + update_count += 1 + + # Clear buffer for next collection (optional, depends on policy staleness tolerance) + # For VLA, we might keep some data for stability + # self.buffer.clear() + + # Evaluation + if ( + self.eval_freq > 0 + and self.eval_env is not None + and self.global_step - last_eval_step >= self.eval_freq + ): + # Temporarily pause collection during eval + print("[Trainer] Pausing collection for evaluation...") + collector.stop() + self._eval_once(num_episodes=self.num_eval_episodes) + collector.start() + print("[Trainer] Resuming collection...") + last_eval_step = self.global_step + + # Checkpoint + if self.global_step - last_save_step >= self.save_freq: + self.save_checkpoint() + last_save_step = self.global_step + + # Log buffer and collector stats + buffer_stats = self.buffer.get_stats() + collector_stats = collector.get_stats() + print(f"[Trainer] Buffer: {buffer_stats}") + print(f"[Trainer] Collector: {collector_stats}") + + finally: + # Always stop collector when training ends + print("[Trainer] Stopping async collector...") + collector.stop() + print(f"[Trainer] Training completed ({update_count} updates)") + @torch.no_grad() - def _collect_rollout(self): - """Collect a rollout. Algorithm controls the data collection process.""" + def _collect_rollout(self) -> TensorDict: + """Collect a rollout. Algorithm controls the data collection process. 
+ + Returns: + TensorDict with batch_size=[T, N] containing full rollout + """ + + # Callback function for statistics and logging (uses TensorDict) + def on_step(tensordict: TensorDict, env_info: dict): + """Callback called at each step during rollout collection. + + Args: + tensordict: Complete transition TensorDict with "next" key + env_info: Environment info dict + """ + # Extract reward and done from next subdictionary + reward = tensordict["next"]["reward"] + done = tensordict["next"]["done"] + + # Squeeze if needed + if reward.dim() > 1: + reward = reward.squeeze(-1) + if done.dim() > 1: + done = done.squeeze(-1) - # Callback function for statistics and logging - def on_step(obs, actions, reward, done, info, next_obs): - """Callback called at each step during rollout collection.""" # Episode stats (stay on device; convert only when episode ends) self.curr_ret += reward self.curr_len += 1 @@ -166,31 +306,40 @@ def on_step(obs, actions, reward, done, info, next_obs): self.curr_ret[done_idx] = 0 self.curr_len[done_idx] = 0 - # Update global step and observation - # next_obs is already flattened in algorithm's collect_rollout - self.obs = next_obs - self.global_step += next_obs.shape[0] + # Update global step and observation TensorDict (already TensorDict from PPO) + next_obs = tensordict["next"]["observation"] + num_envs = next_obs.batch_size[0] + + # Prepare next tensordict + self.obs_tensordict = TensorDict( + {"observation": next_obs}, + batch_size=torch.Size([num_envs]), + device=self.device, + ) + self.global_step += num_envs - if isinstance(info, dict): - rewards_dict = info.get("rewards") - metrics_dict = info.get("metrics") + if isinstance(env_info, dict): + rewards_dict = env_info.get("rewards") + metrics_dict = env_info.get("metrics") self._log_scalar_dict("rewards", rewards_dict) self._log_scalar_dict("metrics", metrics_dict) log_dict = {} - log_dict.update(self._pack_log_dict("rewards", rewards_dict)) - log_dict.update(self._pack_log_dict("metrics", 
metrics_dict)) + log_dict.update(pack_log_dict("rewards", rewards_dict)) + log_dict.update(pack_log_dict("metrics", metrics_dict)) if log_dict and self.use_wandb: wandb.log(log_dict, step=self.global_step) - # Algorithm controls data collection - result = self.algorithm.collect_rollout( + # Algorithm controls data collection and returns TensorDict rollout + rollout = self.algorithm.collect_rollout( env=self.env, policy=self.policy, - obs=self.obs, - num_steps=self.num_steps, + tensordict=self.obs_tensordict, + buffer_size=self.buffer_size, on_step_callback=on_step, ) + return rollout + def _log_train(self, losses: Dict[str, float]): if self.writer: for k, v in losses.items(): @@ -243,10 +392,10 @@ def _eval_once(self, num_episodes: int = 5): episode_lengths = [] for _ in range(num_episodes): - # Reset and initialize episode tracking + # Reset and initialize episode tracking - env returns dict, convert at boundary obs, _ = self.eval_env.reset() - obs = flatten_dict_observation(obs) - num_envs = obs.shape[0] if obs.ndim == 2 else 1 + obs = dict_to_tensordict(obs, self.device) + num_envs = obs.batch_size[0] done_mask = torch.zeros(num_envs, dtype=torch.bool, device=self.device) cumulative_reward = torch.zeros( @@ -256,16 +405,22 @@ def _eval_once(self, num_episodes: int = 5): # Run episode until all environments complete while not done_mask.all(): - # Get deterministic actions from policy - actions, _, _ = self.policy.get_action(obs, deterministic=True) + # Get deterministic actions from policy (policy.forward modifies in-place) + # For deterministic eval, we can set a flag or use mean directly + # For now, use forward and extract action + obs_copy = obs.clone() + self.policy.forward(obs_copy) + actions = obs_copy["action"] + action_type = getattr(self.eval_env, "action_type", "delta_qpos") action_dict = {action_type: actions} - # Environment step - obs, reward, terminated, truncated, info = self.eval_env.step( + # Environment step - env returns dict, convert to 
TensorDict at boundary + next_obs, reward, terminated, truncated, info = self.eval_env.step( action_dict ) - obs = flatten_dict_observation(obs) if isinstance(obs, dict) else obs + next_obs = dict_to_tensordict(next_obs, self.device) + obs = next_obs # Update statistics only for still-running environments done = terminated | truncated diff --git a/pyproject.toml b/pyproject.toml index 0b4624d7..84328fc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "black==24.3.0", "fvcore", "h5py", + "tensordict>=0.5.0", # For TensorDict-based RL data structures ] [project.optional-dependencies] diff --git a/tests/agents/test_rl.py b/tests/agents/test_rl.py index d12cc10f..a4951fae 100644 --- a/tests/agents/test_rl.py +++ b/tests/agents/test_rl.py @@ -70,7 +70,7 @@ def setup_method(self): test_train_config = train_config.copy() test_train_config["trainer"]["gym_config"] = self.temp_gym_config_path test_train_config["trainer"]["iterations"] = 2 - test_train_config["trainer"]["rollout_steps"] = 32 + test_train_config["trainer"]["buffer_size"] = 32 test_train_config["trainer"]["eval_freq"] = 1000000 # Disable eval test_train_config["trainer"]["save_freq"] = 1000000 # Disable save test_train_config["trainer"]["headless"] = True From 26246a5348e60c9ef04fefee4c7c7a0b5d84bade Mon Sep 17 00:00:00 2001 From: yuanhaonan Date: Fri, 6 Feb 2026 14:43:47 +0800 Subject: [PATCH 02/14] vla policy wrapper --- configs/agents/rl/push_cube/train_config.json | 1 + .../agents/rl/vla_example/train_config.json | 70 +++++ docs/RL_TRAINING_FRAMEWORK.md | 270 ++++++++++++++++++ docs/VLA_INTEGRATION_GUIDE.md | 213 ++++++++++++++ embodichain/agents/rl/buffer/__init__.py | 4 +- .../agents/rl/buffer/standard_buffer.py | 5 +- .../{rollout_buffer.py => vla_buffer.py} | 0 embodichain/agents/rl/models/__init__.py | 7 + embodichain/agents/rl/models/vla_policy.py | 235 +++++++++++++++ embodichain/agents/rl/train.py | 8 + embodichain/agents/rl/utils/trainer.py | 24 +- 11 files 
changed, 820 insertions(+), 17 deletions(-) create mode 100644 configs/agents/rl/vla_example/train_config.json create mode 100644 docs/RL_TRAINING_FRAMEWORK.md create mode 100644 docs/VLA_INTEGRATION_GUIDE.md rename embodichain/agents/rl/buffer/{rollout_buffer.py => vla_buffer.py} (100%) create mode 100644 embodichain/agents/rl/models/vla_policy.py diff --git a/configs/agents/rl/push_cube/train_config.json b/configs/agents/rl/push_cube/train_config.json index 07e6c961..b7639c06 100644 --- a/configs/agents/rl/push_cube/train_config.json +++ b/configs/agents/rl/push_cube/train_config.json @@ -17,6 +17,7 @@ "save_freq": 200, "use_wandb": false, "wandb_project_name": "embodychain-push_cube", + "model_type": "standard", "events": { "eval": { "record_camera": { diff --git a/configs/agents/rl/vla_example/train_config.json b/configs/agents/rl/vla_example/train_config.json new file mode 100644 index 00000000..bc48b9f5 --- /dev/null +++ b/configs/agents/rl/vla_example/train_config.json @@ -0,0 +1,70 @@ +{ + "trainer": { + "exp_name": "vla_fine_tuning_ppo", + "gym_config": "configs/agents/rl/push_cube/gym_config.json", + "seed": 42, + "device": "cuda:0", + "headless": true, + "enable_rt": false, + "gpu_id": 0, + "num_envs": 32, + "iterations": 500, + "buffer_size": 2048, + "buffer_type": "vla", + "enable_eval": true, + "num_eval_envs": 8, + "num_eval_episodes": 3, + "eval_freq": 100, + "save_freq": 100, + "use_wandb": true, + "wandb_project_name": "embodychain-vla-training", + "model_type": "vla", + "events": { + "eval": { + "record_camera": { + "func": "record_camera_data_async", + "mode": "interval", + "interval_step": 1, + "params": { + "name": "main_cam", + "resolution": [640, 480], + "eye": [-1.4, 1.4, 2.0], + "target": [0, 0, 0], + "up": [0, 0, 1], + "intrinsics": [600, 600, 320, 240], + "save_path": "./outputs/videos/vla_eval" + } + } + } + } + }, + "policy": { + "name": "vla", + "action_dim": 7, + "vla_config": { + "model_path": "checkpoints/pretrained_vla_model.pth", 
+ "model_class": "vla_models.GPTVLAModel", + "model_config": { + "vision_encoder": "resnet50", + "language_model": "gpt2-medium", + "action_head_hidden_size": 512, + "freeze_vision_encoder": false, + "freeze_language_model": false + } + } + }, + "algorithm": { + "name": "ppo", + "cfg": { + "learning_rate": 1e-5, + "n_epochs": 4, + "batch_size": 2048, + "gamma": 0.99, + "gae_lambda": 0.95, + "clip_coef": 0.2, + "ent_coef": 0.001, + "vf_coef": 0.5, + "max_grad_norm": 1.0 + } + } +} diff --git a/docs/RL_TRAINING_FRAMEWORK.md b/docs/RL_TRAINING_FRAMEWORK.md new file mode 100644 index 00000000..800d62aa --- /dev/null +++ b/docs/RL_TRAINING_FRAMEWORK.md @@ -0,0 +1,270 @@ +# RL Training Framework + +## Overview + +Modern **TensorDict-based** RL training framework supporting standard PPO, asynchronous VLA training, and pretrained VLA model fine-tuning. + +**Key Features**: +- Pure TensorDict data flow +- Dual modes: Standard synchronous / VLA asynchronous +- Efficient buffers: Single-use / Pre-allocated circular +- VLA model integration: Load and fine-tune pretrained VLA models + +--- + +## Quick Start + +### 1. Configuration + +```json +{ + "trainer": { + "buffer_size": 2048, + "buffer_type": "standard", // "standard" or "vla" + "iterations": 500 + }, + "algorithm": { + "name": "ppo", + "cfg": { + "learning_rate": 3e-4, + "gamma": 0.99, + "n_epochs": 10, + "batch_size": 64 + } + } +} +``` + +### 2. 
Run Training + +```bash +python embodichain/agents/rl/train.py --config configs/agents/rl/my_config.json +``` + +--- + +## Training Modes + +### Standard Mode (Default) + +**Use Case**: Regular model training + +``` +Collect data (2048 steps) → Train model → Clear buffer → Repeat +``` + +**Configuration**: +```json +{"trainer": {"buffer_type": "standard"}} +``` + +**Characteristics**: Simple, stable, low memory usage + +--- + +### VLA Async Mode + +**Use Case**: Large models with slow inference (e.g., VLA models, >1 sec/step) + +``` +Background Thread: Continuously collect data → Write to buffer +Main Thread: Wait for buffer full → Train model → Repeat +``` + +**Configuration**: +```json +{"trainer": {"buffer_type": "vla"}} +``` + +**Characteristics**: +- ✅ Parallel collection & training, 2-3x speedup +- ✅ Pre-allocated memory, optimized for high-frequency writes +- ⚠️ Slightly stale data (acceptable for on-policy algorithms) + +--- + +## Buffer Explanation + +### RolloutBuffer (Standard) + +- **Storage**: One complete rollout [T, N, ...] +- **Behavior**: Add → Train once → Clear +- **Usage**: Standard PPO + +### VLABuffer (Async) + +- **Storage**: Circular buffer [buffer_size, ...] +- **Behavior**: Incremental add → Train when full → Old data overwritten +- **Usage**: VLA async collection + +**Circular Overwrite Example** (capacity=4): +``` +[T0, _, _, _] → [T0,T1, _, _] → [T0,T1,T2, _] → [T0,T1,T2,T3] (full) +→ [T4,T1,T2,T3] (T0 overwritten) → [T4,T5,T2,T3] (T1 overwritten) +``` + +--- + +## Core API + +### Trainer + +```python +from embodichain.agents.rl.utils import Trainer + +trainer = Trainer( + policy, env, algorithm, + buffer_size=2048, + buffer_type="standard", # or "vla" + batch_size=64, + ... 
+) +trainer.train(total_timesteps=1000000) +``` + +### Buffer Interface + +```python +# Add data +buffer.add(rollout) # Standard mode: complete rollout +buffer.add(transition) # VLA mode: single transition + +# Get data +data = buffer.get(flatten=True) # Returns [batch, ...] + +# Check status +if buffer.is_full(): + train() +``` + +--- + +## FAQ + +### When to use VLA mode? + +Use VLA mode when inference time > 100ms/step and GPU training is fast. + +### How to set buffer capacity? + +- Standard mode: `buffer_size` = steps per rollout (typically 2048) +- VLA mode: `buffer_size` = circular buffer capacity (recommended 2048-4096) + +### Will data be stale in async mode? + +Yes, slightly stale (up to buffer_size steps), but acceptable for PPO and other on-policy algorithms. Performance gain far outweighs staleness cost. + +--- + +## VLA Model Integration + +### Overview + +The framework supports loading and fine-tuning pretrained Vision-Language-Action (VLA) models. VLA models are loaded from checkpoints and wrapped in `VLAPolicy` to conform to the standard Policy interface. + +### VLA Model Requirements + +VLA model developers should implement a model class with the following interface: + +```python +class MyVLAModel(nn.Module): + def forward(self, observations: TensorDict) -> torch.Tensor: + """Generate actions from observations. + + Args: + observations: TensorDict with keys like "rgb", "depth", "proprio", "language" + Returns: + Action tensor [B, action_dim] + """ + + def get_value(self, observations: TensorDict) -> torch.Tensor: + """Get value estimate. + + Returns: + Value tensor [B, 1] + """ + + def evaluate_actions( + self, + observations: TensorDict, + actions: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Evaluate log probability and entropy. + + Returns: + (log_prob [B,], entropy [B,]) + """ +``` + +See [vla_policy.py](../embodichain/agents/rl/models/vla_policy.py) for detailed interface documentation (`VLAModelInterface`). 
+ +### Configuration Example + +```json +{ + "trainer": { + "buffer_type": "vla", + "buffer_size": 2048, + ... + }, + "policy": { + "name": "vla", + "action_dim": 7, + "vla_config": { + "model_path": "checkpoints/pretrained_vla_model.pth", + "model_class": "vla_models.GPTVLAModel", + "model_config": { + "vision_encoder": "resnet50", + "language_model": "gpt2-medium", + "freeze_vision_encoder": false + } + } + }, + "algorithm": { + "name": "ppo", + "cfg": { + "learning_rate": 1e-5, + ... + } + } +} +``` + +See [vla_example/train_config.json](../configs/agents/rl/vla_example/train_config.json) for complete example. + +### Implementation Guide for VLA Team + +1. **Implement VLA Model Class**: Create a model class conforming to `VLAModelInterface` +2. **Implement Checkpoint Loading**: Implement `load_vla_model()` function in [vla_policy.py](../embodichain/agents/rl/models/vla_policy.py) +3. **Test Integration**: Use example config to verify model loads and trains correctly + +The `load_vla_model()` function is currently a placeholder that raises `NotImplementedError` - VLA team should implement actual loading logic. + +--- + +## File Structure + +``` +embodichain/agents/rl/ +├── train.py # Entry point +├── algo/ppo.py # PPO algorithm +├── buffer/ +│ ├── standard_buffer.py # RolloutBuffer +│ └── vla_buffer.py # VLABuffer +├── models/ # Policy definitions +│ ├── policy.py # Policy base class +│ ├── actor_critic.py # Standard ActorCritic (from scratch) +│ ├── vla_policy.py # VLA model wrapper (pretrained) +│ └── ... 
+└── utils/ + ├── trainer.py # Training coordinator + └── async_collector.py # Async data collector +``` + +--- + +## References + +- [TensorDict Documentation](https://pytorch.org/tensordict/) +- [VLA Policy Interface](../embodichain/agents/rl/models/vla_policy.py) +- Example configs: `configs/agents/rl/` diff --git a/docs/VLA_INTEGRATION_GUIDE.md b/docs/VLA_INTEGRATION_GUIDE.md new file mode 100644 index 00000000..28dd063b --- /dev/null +++ b/docs/VLA_INTEGRATION_GUIDE.md @@ -0,0 +1,213 @@ +# VLA Model Integration Guide + +This guide explains how to integrate a VLA (Vision-Language-Action) model with the EmbodiChain RL training framework. + +## For VLA Model Developers + +### 1. Model Interface Requirements + +Your VLA model class must implement the following interface: + +```python +class YourVLAModel(nn.Module): + def __init__(self, **config): + """Initialize VLA model with configuration.""" + super().__init__() + # Your initialization code + + def forward(self, observations: TensorDict) -> torch.Tensor: + """Generate actions from observations. + + Args: + observations: TensorDict containing observation data + Expected keys may include: + - "rgb": RGB images [B, H, W, C] or [B, C, H, W] + - "depth": Depth images [B, H, W] + - "proprio": Proprioceptive state [B, proprio_dim] + - "language": Language tokens [B, seq_len] or raw strings + + Returns: + Action tensor [B, action_dim] + """ + # Your action generation code + pass + + def get_value(self, observations: TensorDict) -> torch.Tensor: + """Get value estimate for observations. + + Args: + observations: TensorDict containing observation data + + Returns: + Value tensor [B, 1] + """ + # Your value estimation code + pass + + def evaluate_actions( + self, + observations: TensorDict, + actions: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Evaluate log probability and entropy for observation-action pairs. 
+ + Args: + observations: TensorDict containing observation data + actions: Action tensor [B, action_dim] + + Returns: + Tuple of (log_prob [B,], entropy [B,]) + """ + # Your action evaluation code + pass +``` + +**Important Notes**: +- All methods must accept `TensorDict` for observations (not plain tensors) +- Handle missing observation keys gracefully (not all tasks provide all modalities) +- Your model should manage its own tokenization, preprocessing, and internal state +- Value head is required for PPO training (can be a simple MLP on top of your embeddings) + +### 2. Implement Checkpoint Loading + +Edit `embodichain/agents/rl/models/vla_policy.py` and implement the `load_vla_model()` function: + +```python +def load_vla_model( + model_path: str, + model_class: Optional[str] = None, + model_config: Optional[dict] = None, + device: torch.device = torch.device("cpu"), +) -> nn.Module: + """Load VLA model from checkpoint.""" + import importlib + + # Parse model class path + module_name, class_name = model_class.rsplit(".", 1) + module = importlib.import_module(module_name) + ModelClass = getattr(module, class_name) + + # Initialize model + model = ModelClass(**model_config) + + # Load checkpoint + checkpoint = torch.load(model_path, map_location=device) + model.load_state_dict(checkpoint["model_state_dict"]) + + # Move to device + model.to(device) + model.eval() # Start in eval mode (trainer will set to train) + + return model +``` + +Adapt this to your checkpoint format (may use different keys, compression, etc.). + +### 3. Configuration Format + +Create a training config JSON: + +```json +{ + "trainer": { + "buffer_type": "vla", + "buffer_size": 2048, + ... + }, + "policy": { + "name": "vla", + "action_dim": 7, + "vla_config": { + "model_path": "path/to/your/checkpoint.pth", + "model_class": "your_package.YourVLAModel", + "model_config": { + "vision_encoder": "resnet50", + "language_model": "gpt2", + "freeze_vision_encoder": false, + ... 
// your model-specific config + } + } + }, + "algorithm": { + "name": "ppo", + "cfg": { + "learning_rate": 1e-5, // Lower LR for fine-tuning + ... + } + } +} +``` + +### 4. Testing Your Integration + +```bash +# Run training with your VLA model +python embodichain/agents/rl/train.py --config configs/agents/rl/your_vla_config.json +``` + +Expected workflow: +1. `load_vla_model()` loads your pretrained checkpoint +2. `VLAPolicy` wraps your model and adapts it to Policy interface +3. RL trainer fine-tunes your model using PPO (or other algorithms) + +### 5. Tips for VLA Models + +**Value Head**: +- PPO requires value estimates +- Add a simple value head (e.g., linear layer) on your embeddings +- Can initialize randomly or pretrain with imitation learning + +**Action Distribution**: +- `evaluate_actions()` needs to compute log_prob and entropy +- For continuous actions: use Gaussian distribution (mean from your model, learnable std) +- For discrete actions: use Categorical distribution + +**Gradient Flow**: +- Set `freeze_vision_encoder: true` to only fine-tune action/value heads +- Set `freeze_language_model: true` if using large LMs (reduce memory) +- Or fine-tune entire model with lower learning rate + +**Memory Optimization**: +- VLA models are large - use `buffer_type: "vla"` with async collection +- Reduce `num_envs` if running out of memory +- Consider gradient checkpointing for very large models + +## For RL Framework Users + +If VLA model is already integrated, simply configure and run: + +```bash +python embodichain/agents/rl/train.py --config configs/agents/rl/vla_example/train_config.json +``` + +See [RL_TRAINING_FRAMEWORK.md](RL_TRAINING_FRAMEWORK.md) for general usage. 
+ +## Architecture Diagram + +``` +┌─────────────────────────────────────────────────────────────┐ +│ RL Training Framework │ +│ │ +│ ┌────────────┐ ┌──────────────┐ ┌──────────────┐ │ +│ │ Trainer │─────▶│ VLAPolicy │────▶│ VLA Model │ │ +│ │ │ │ (wrapper) │ │ (your code) │ │ +│ └────────────┘ └──────────────┘ └──────────────┘ │ +│ │ │ │ │ +│ │ │ │ │ +│ TensorDict TensorDict TensorDict │ +│ │ │ │ │ +│ ▼ ▼ ▼ │ +│ [obs, action, [forward, [vision, │ +│ reward, done] get_value, language, │ +│ evaluate_actions] action] │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Reference Implementation + +See `embodichain/agents/rl/models/vla_policy.py` for: +- `VLAModelInterface`: Protocol defining required methods +- `VLAPolicy`: Wrapper that adapts VLA model to Policy interface +- `load_vla_model()`: Checkpoint loading function (to be implemented) + +Example config: `configs/agents/rl/vla_example/train_config.json` diff --git a/embodichain/agents/rl/buffer/__init__.py b/embodichain/agents/rl/buffer/__init__.py index 17d3b4be..b68a7f49 100644 --- a/embodichain/agents/rl/buffer/__init__.py +++ b/embodichain/agents/rl/buffer/__init__.py @@ -22,9 +22,7 @@ - VLABuffer: VLA buffer (FIFO multi-rollout accumulation for slow inference) """ -from .rollout_buffer import VLABuffer +from .vla_buffer import VLABuffer from .standard_buffer import RolloutBuffer __all__ = ["RolloutBuffer", "VLABuffer"] - -__all__ = ["TensorDictRolloutBuffer"] diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py index 200cc176..eea45bf9 100644 --- a/embodichain/agents/rl/buffer/standard_buffer.py +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -37,9 +37,10 @@ def __init__(self, buffer_size: int, device: torch.device): """Initialize standard rollout buffer. 
Args: - buffer_size: Not used (kept for interface compatibility) + buffer_size: Buffer size from config (for interface compatibility with VLABuffer) device: Device to store tensors on """ + self.buffer_size = buffer_size self.device = device self._rollout: Optional[TensorDict] = None @@ -110,7 +111,7 @@ def get_stats(self) -> dict: """ return { "buffer_size": 1 if self._rollout is not None else 0, - "buffer_capacity": 1, + "buffer_capacity": self.buffer_size, "total_transitions": self.get_num_transitions(), "buffer_usage": 1.0 if self._rollout is not None else 0.0, } diff --git a/embodichain/agents/rl/buffer/rollout_buffer.py b/embodichain/agents/rl/buffer/vla_buffer.py similarity index 100% rename from embodichain/agents/rl/buffer/rollout_buffer.py rename to embodichain/agents/rl/buffer/vla_buffer.py diff --git a/embodichain/agents/rl/models/__init__.py b/embodichain/agents/rl/models/__init__.py index 1c5e70a6..e996d2d8 100644 --- a/embodichain/agents/rl/models/__init__.py +++ b/embodichain/agents/rl/models/__init__.py @@ -22,6 +22,7 @@ from .actor_critic import ActorCritic from .policy import Policy from .mlp import MLP +from .vla_policy import VLAPolicy, build_vla_policy, load_vla_model # In-module policy registry _POLICY_REGISTRY: Dict[str, Type[Policy]] = {} @@ -78,6 +79,8 @@ def build_policy( return policy_cls( action_dim=action_dim, device=device, actor=actor, critic=critic ) + elif name == "vla": + return build_vla_policy(policy_block, action_dim, device) else: # Other policies should also use action_dim signature return policy_cls(action_dim=action_dim, device=device) @@ -103,12 +106,16 @@ def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP: # default registrations register_policy("actor_critic", ActorCritic) +register_policy("vla", VLAPolicy) __all__ = [ "ActorCritic", + "VLAPolicy", "register_policy", "get_registered_policy_names", "build_policy", + "build_vla_policy", + "load_vla_model", "build_mlp_from_cfg", "get_policy_class", 
"Policy", diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py new file mode 100644 index 00000000..36e9754c --- /dev/null +++ b/embodichain/agents/rl/models/vla_policy.py @@ -0,0 +1,235 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""VLA Policy for RL training with pretrained models. + +This module provides VLAPolicy that inherits from Policy base class, +just like ActorCritic. VLAPolicy loads pretrained VLA model components +and exposes the same interface as other policies. +""" + +from __future__ import annotations + +from typing import Optional +import torch +import torch.nn as nn +from tensordict import TensorDict + +from .policy import Policy + + +class VLAPolicy(Policy): + """VLA Policy that loads pretrained vision-language-action models. + + Similar to ActorCritic, this class inherits from Policy and implements + the required methods. The difference is that VLAPolicy loads pretrained + model components instead of training from scratch. + + VLA model components are loaded by the VLA team's implementation and + should provide the necessary interfaces for action generation and value + estimation. 
+ """ + + def __init__( + self, + action_dim: int, + device: torch.device, + vla_model: nn.Module, + ): + """Initialize VLA policy with pretrained model. + + Args: + action_dim: Dimension of action space + device: Device to place policy on + vla_model: Pretrained VLA model (vision encoder, language model, + action head, value head, etc.) + """ + super().__init__() + self.action_dim = action_dim + self.device = device + + # Store VLA model + self.vla_model = vla_model + self.vla_model.to(self.device) + + @torch.no_grad() + def forward(self, tensordict: TensorDict) -> TensorDict: + """Forward pass: generate action and value from VLA model. + + Args: + tensordict: Must contain "observation" key with observation data + + Returns: + Same tensordict with added keys: + - "action": Sampled action + - "sample_log_prob": Log probability of action + - "value": Value estimate + """ + # VLA team should implement forward logic here + # This is a template - actual implementation depends on VLA model structure + obs = tensordict["observation"] + + # Example: VLA model generates action and value + action, log_prob, value = self.vla_model(obs) + + tensordict["action"] = action + tensordict["sample_log_prob"] = log_prob + tensordict["value"] = value.squeeze(-1) + + return tensordict + + @torch.no_grad() + def get_value(self, tensordict: TensorDict) -> TensorDict: + """Get value estimate from VLA model. + + Args: + tensordict: Must contain "observation" key + + Returns: + Same tensordict with added "value" key + """ + obs = tensordict["observation"] + + # VLA team implements value computation + value = self.vla_model.get_value(obs) + + tensordict["value"] = value.squeeze(-1) + return tensordict + + def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + """Evaluate actions using VLA model. 
+ + Args: + tensordict: Must contain: + - "observation": Observation data + - "action": Actions to evaluate + + Returns: + Same tensordict with added keys: + - "sample_log_prob": Log probability of actions + - "entropy": Entropy of action distribution + - "value": Value estimate + """ + obs = tensordict["observation"] + actions = tensordict["action"] + + # VLA team implements action evaluation + log_prob, entropy, value = self.vla_model.evaluate_actions(obs, actions) + + tensordict["sample_log_prob"] = log_prob + tensordict["entropy"] = entropy + tensordict["value"] = value.squeeze(-1) + + return tensordict + + +def load_vla_model( + model_path: str, + model_class: Optional[str] = None, + model_config: Optional[dict] = None, + device: torch.device = torch.device("cpu"), +) -> nn.Module: + """Load VLA model from checkpoint. + + This function should be implemented by the VLA team to load their + pretrained VLA model (vision encoder, language model, action head, etc.). + + The returned module should have methods: + - forward(obs) -> (action, log_prob, value) + - get_value(obs) -> value + - evaluate_actions(obs, actions) -> (log_prob, entropy, value) + + Args: + model_path: Path to checkpoint file + model_class: Fully qualified class name for VLA model + model_config: Configuration dict for model initialization + device: Device to load model on + + Returns: + Initialized VLA model module + + Example implementation by VLA team: + ```python + def load_vla_model(model_path, model_class, model_config, device): + import importlib + + # Import VLA model class + module_name, class_name = model_class.rsplit(".", 1) + module = importlib.import_module(module_name) + ModelClass = getattr(module, class_name) + + # Initialize model + model = ModelClass(**model_config) + + # Load checkpoint + checkpoint = torch.load(model_path, map_location=device) + model.load_state_dict(checkpoint["model_state_dict"]) + + model.to(device) + model.eval() + + return model + ``` + """ + raise 
NotImplementedError( + "load_vla_model() must be implemented. " + f"Model path: {model_path}, class: {model_class}, config: {model_config}" + ) + + +def build_vla_policy( + policy_block: dict, + action_dim: int, + device: torch.device, +) -> VLAPolicy: + """Build VLA policy from configuration. + + Args: + policy_block: Configuration dict + action_dim: Dimension of action space + device: Device to place policy on + + Returns: + Initialized VLAPolicy instance + """ + vla_config = policy_block.get("vla_config") + if vla_config is None: + raise ValueError("VLA policy requires 'vla_config' in policy block") + + model_path = vla_config.get("model_path") + if model_path is None: + raise ValueError("VLA config requires 'model_path'") + + model_class = vla_config.get("model_class") + model_config = vla_config.get("model_config", {}) + model_config["action_dim"] = action_dim + + # Load VLA model + vla_model = load_vla_model( + model_path=model_path, + model_class=model_class, + model_config=model_config, + device=device, + ) + + # Create VLAPolicy instance + policy = VLAPolicy( + action_dim=action_dim, + device=device, + vla_model=vla_model, + ) + + return policy diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py index 4da25442..a634f359 100644 --- a/embodichain/agents/rl/train.py +++ b/embodichain/agents/rl/train.py @@ -65,6 +65,7 @@ def train_from_config(config_path: str): device_str = trainer_cfg.get("device", "cpu") iterations = int(trainer_cfg.get("iterations", 250)) buffer_size = int(trainer_cfg.get("buffer_size", 2048)) + model_type = trainer_cfg.get("model_type", "standard") enable_eval = bool(trainer_cfg.get("enable_eval", False)) eval_freq = int(trainer_cfg.get("eval_freq", 10000)) save_freq = int(trainer_cfg.get("save_freq", 50000)) @@ -222,6 +223,12 @@ def _collect(item): actor=actor, critic=critic, ) + elif policy_name.lower() == "vla": + # VLA policy loads pretrained model from checkpoint + logger.info( + f"Loading VLA model from 
config: {policy_block.get('vla_config', {})}" + ) + policy = build_policy(policy_block, action_dim=action_dim, device=device) else: policy = build_policy(policy_block, action_dim=action_dim, device=device) @@ -286,6 +293,7 @@ def _collect(item): event_cfg=train_event_cfg, eval_event_cfg=eval_event_cfg if enable_eval else {}, num_eval_episodes=num_eval_episodes, + model_type=model_type, ) logger.log_info("Generic training initialized") diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index a67d34c3..f4f44641 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -49,8 +49,8 @@ def __init__( event_cfg=None, eval_event_cfg=None, num_eval_episodes: int = 5, - # Buffer config: "standard" (default) or "vla" - buffer_type: str = "standard", + # Model type: "standard" (default PPO) or "vla" + model_type: str = "standard", ): self.policy = policy self.env = env @@ -66,27 +66,27 @@ def __init__( self.use_wandb = use_wandb self.num_eval_episodes = num_eval_episodes - # Buffer setup - self.buffer_type = buffer_type + # Buffer setup (depends on model_type) + self.model_type = model_type device = ( algorithm.device if hasattr(algorithm, "device") else torch.device("cuda" if torch.cuda.is_available() else "cpu") ) - if buffer_type == "vla": - # VLA buffer: accumulate multiple rollouts with FIFO + if model_type == "vla": + # VLA model: accumulate multiple rollouts with FIFO buffer from embodichain.agents.rl.buffer import VLABuffer self.buffer = VLABuffer(buffer_size=buffer_size, device=device) - elif buffer_type == "standard": - # Standard PPO buffer: single rollout, use and discard + elif model_type == "standard": + # Standard PPO model: single rollout, use and discard from embodichain.agents.rl.buffer import RolloutBuffer - self.buffer = RolloutBuffer(buffer_size=1, device=device) + self.buffer = RolloutBuffer(buffer_size=buffer_size, device=device) else: raise ValueError( - f"Unknown 
buffer_type: {buffer_type}. Use 'standard' or 'vla'." + f"Unknown model_type: {model_type}. Use 'standard' or 'vla'." ) if event_cfg is not None: @@ -124,9 +124,9 @@ def _log_scalar_dict(self, prefix: str, data: dict): def train(self, total_timesteps: int): print(f"Start training, total steps: {total_timesteps}") - print(f"Using {self.buffer_type} buffer") + print(f"Model type: {self.model_type}") - if self.buffer_type == "vla": + if self.model_type == "vla": # VLA mode: Use async collector self._train_async(total_timesteps) else: From fd298f8da547e4122f8589d1c1e689efe3ff9964 Mon Sep 17 00:00:00 2001 From: yuanhaonan Date: Fri, 6 Feb 2026 15:51:02 +0800 Subject: [PATCH 03/14] update --- docs/RL_TRAINING_FRAMEWORK.md | 270 ---------------- docs/VLA_INTEGRATION_GUIDE.md | 213 ------------- docs/rl_training_guide.md | 292 ++++++++++++++++++ embodichain/agents/rl/ARCHITECTURE.md | 216 +++++++++++++ embodichain/agents/rl/algo/base.py | 26 +- embodichain/agents/rl/algo/ppo.py | 135 +------- embodichain/agents/rl/collector/__init__.py | 26 ++ .../{utils => collector}/async_collector.py | 29 +- embodichain/agents/rl/collector/base.py | 64 ++++ .../agents/rl/collector/sync_collector.py | 130 ++++++++ embodichain/agents/rl/models/actor_critic.py | 21 +- embodichain/agents/rl/models/vla_policy.py | 9 +- embodichain/agents/rl/utils/__init__.py | 3 - embodichain/agents/rl/utils/trainer.py | 248 +++++---------- 14 files changed, 873 insertions(+), 809 deletions(-) delete mode 100644 docs/RL_TRAINING_FRAMEWORK.md delete mode 100644 docs/VLA_INTEGRATION_GUIDE.md create mode 100644 docs/rl_training_guide.md create mode 100644 embodichain/agents/rl/ARCHITECTURE.md create mode 100644 embodichain/agents/rl/collector/__init__.py rename embodichain/agents/rl/{utils => collector}/async_collector.py (93%) create mode 100644 embodichain/agents/rl/collector/base.py create mode 100644 embodichain/agents/rl/collector/sync_collector.py diff --git a/docs/RL_TRAINING_FRAMEWORK.md 
b/docs/RL_TRAINING_FRAMEWORK.md deleted file mode 100644 index 800d62aa..00000000 --- a/docs/RL_TRAINING_FRAMEWORK.md +++ /dev/null @@ -1,270 +0,0 @@ -# RL Training Framework - -## Overview - -Modern **TensorDict-based** RL training framework supporting standard PPO, asynchronous VLA training, and pretrained VLA model fine-tuning. - -**Key Features**: -- Pure TensorDict data flow -- Dual modes: Standard synchronous / VLA asynchronous -- Efficient buffers: Single-use / Pre-allocated circular -- VLA model integration: Load and fine-tune pretrained VLA models - ---- - -## Quick Start - -### 1. Configuration - -```json -{ - "trainer": { - "buffer_size": 2048, - "buffer_type": "standard", // "standard" or "vla" - "iterations": 500 - }, - "algorithm": { - "name": "ppo", - "cfg": { - "learning_rate": 3e-4, - "gamma": 0.99, - "n_epochs": 10, - "batch_size": 64 - } - } -} -``` - -### 2. Run Training - -```bash -python embodichain/agents/rl/train.py --config configs/agents/rl/my_config.json -``` - ---- - -## Training Modes - -### Standard Mode (Default) - -**Use Case**: Regular model training - -``` -Collect data (2048 steps) → Train model → Clear buffer → Repeat -``` - -**Configuration**: -```json -{"trainer": {"buffer_type": "standard"}} -``` - -**Characteristics**: Simple, stable, low memory usage - ---- - -### VLA Async Mode - -**Use Case**: Large models with slow inference (e.g., VLA models, >1 sec/step) - -``` -Background Thread: Continuously collect data → Write to buffer -Main Thread: Wait for buffer full → Train model → Repeat -``` - -**Configuration**: -```json -{"trainer": {"buffer_type": "vla"}} -``` - -**Characteristics**: -- ✅ Parallel collection & training, 2-3x speedup -- ✅ Pre-allocated memory, optimized for high-frequency writes -- ⚠️ Slightly stale data (acceptable for on-policy algorithms) - ---- - -## Buffer Explanation - -### RolloutBuffer (Standard) - -- **Storage**: One complete rollout [T, N, ...] 
-- **Behavior**: Add → Train once → Clear -- **Usage**: Standard PPO - -### VLABuffer (Async) - -- **Storage**: Circular buffer [buffer_size, ...] -- **Behavior**: Incremental add → Train when full → Old data overwritten -- **Usage**: VLA async collection - -**Circular Overwrite Example** (capacity=4): -``` -[T0, _, _, _] → [T0,T1, _, _] → [T0,T1,T2, _] → [T0,T1,T2,T3] (full) -→ [T4,T1,T2,T3] (T0 overwritten) → [T4,T5,T2,T3] (T1 overwritten) -``` - ---- - -## Core API - -### Trainer - -```python -from embodichain.agents.rl.utils import Trainer - -trainer = Trainer( - policy, env, algorithm, - buffer_size=2048, - buffer_type="standard", # or "vla" - batch_size=64, - ... -) -trainer.train(total_timesteps=1000000) -``` - -### Buffer Interface - -```python -# Add data -buffer.add(rollout) # Standard mode: complete rollout -buffer.add(transition) # VLA mode: single transition - -# Get data -data = buffer.get(flatten=True) # Returns [batch, ...] - -# Check status -if buffer.is_full(): - train() -``` - ---- - -## FAQ - -### When to use VLA mode? - -Use VLA mode when inference time > 100ms/step and GPU training is fast. - -### How to set buffer capacity? - -- Standard mode: `buffer_size` = steps per rollout (typically 2048) -- VLA mode: `buffer_size` = circular buffer capacity (recommended 2048-4096) - -### Will data be stale in async mode? - -Yes, slightly stale (up to buffer_size steps), but acceptable for PPO and other on-policy algorithms. Performance gain far outweighs staleness cost. - ---- - -## VLA Model Integration - -### Overview - -The framework supports loading and fine-tuning pretrained Vision-Language-Action (VLA) models. VLA models are loaded from checkpoints and wrapped in `VLAPolicy` to conform to the standard Policy interface. 
- -### VLA Model Requirements - -VLA model developers should implement a model class with the following interface: - -```python -class MyVLAModel(nn.Module): - def forward(self, observations: TensorDict) -> torch.Tensor: - """Generate actions from observations. - - Args: - observations: TensorDict with keys like "rgb", "depth", "proprio", "language" - Returns: - Action tensor [B, action_dim] - """ - - def get_value(self, observations: TensorDict) -> torch.Tensor: - """Get value estimate. - - Returns: - Value tensor [B, 1] - """ - - def evaluate_actions( - self, - observations: TensorDict, - actions: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - """Evaluate log probability and entropy. - - Returns: - (log_prob [B,], entropy [B,]) - """ -``` - -See [vla_policy.py](../embodichain/agents/rl/models/vla_policy.py) for detailed interface documentation (`VLAModelInterface`). - -### Configuration Example - -```json -{ - "trainer": { - "buffer_type": "vla", - "buffer_size": 2048, - ... - }, - "policy": { - "name": "vla", - "action_dim": 7, - "vla_config": { - "model_path": "checkpoints/pretrained_vla_model.pth", - "model_class": "vla_models.GPTVLAModel", - "model_config": { - "vision_encoder": "resnet50", - "language_model": "gpt2-medium", - "freeze_vision_encoder": false - } - } - }, - "algorithm": { - "name": "ppo", - "cfg": { - "learning_rate": 1e-5, - ... - } - } -} -``` - -See [vla_example/train_config.json](../configs/agents/rl/vla_example/train_config.json) for complete example. - -### Implementation Guide for VLA Team - -1. **Implement VLA Model Class**: Create a model class conforming to `VLAModelInterface` -2. **Implement Checkpoint Loading**: Implement `load_vla_model()` function in [vla_policy.py](../embodichain/agents/rl/models/vla_policy.py) -3. 
**Test Integration**: Use example config to verify model loads and trains correctly - -The `load_vla_model()` function is currently a placeholder that raises `NotImplementedError` - VLA team should implement actual loading logic. - ---- - -## File Structure - -``` -embodichain/agents/rl/ -├── train.py # Entry point -├── algo/ppo.py # PPO algorithm -├── buffer/ -│ ├── standard_buffer.py # RolloutBuffer -│ └── vla_buffer.py # VLABuffer -├── models/ # Policy definitions -│ ├── policy.py # Policy base class -│ ├── actor_critic.py # Standard ActorCritic (from scratch) -│ ├── vla_policy.py # VLA model wrapper (pretrained) -│ └── ... -└── utils/ - ├── trainer.py # Training coordinator - └── async_collector.py # Async data collector -``` - ---- - -## References - -- [TensorDict Documentation](https://pytorch.org/tensordict/) -- [VLA Policy Interface](../embodichain/agents/rl/models/vla_policy.py) -- Example configs: `configs/agents/rl/` diff --git a/docs/VLA_INTEGRATION_GUIDE.md b/docs/VLA_INTEGRATION_GUIDE.md deleted file mode 100644 index 28dd063b..00000000 --- a/docs/VLA_INTEGRATION_GUIDE.md +++ /dev/null @@ -1,213 +0,0 @@ -# VLA Model Integration Guide - -This guide explains how to integrate a VLA (Vision-Language-Action) model with the EmbodiChain RL training framework. - -## For VLA Model Developers - -### 1. Model Interface Requirements - -Your VLA model class must implement the following interface: - -```python -class YourVLAModel(nn.Module): - def __init__(self, **config): - """Initialize VLA model with configuration.""" - super().__init__() - # Your initialization code - - def forward(self, observations: TensorDict) -> torch.Tensor: - """Generate actions from observations. 
- - Args: - observations: TensorDict containing observation data - Expected keys may include: - - "rgb": RGB images [B, H, W, C] or [B, C, H, W] - - "depth": Depth images [B, H, W] - - "proprio": Proprioceptive state [B, proprio_dim] - - "language": Language tokens [B, seq_len] or raw strings - - Returns: - Action tensor [B, action_dim] - """ - # Your action generation code - pass - - def get_value(self, observations: TensorDict) -> torch.Tensor: - """Get value estimate for observations. - - Args: - observations: TensorDict containing observation data - - Returns: - Value tensor [B, 1] - """ - # Your value estimation code - pass - - def evaluate_actions( - self, - observations: TensorDict, - actions: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - """Evaluate log probability and entropy for observation-action pairs. - - Args: - observations: TensorDict containing observation data - actions: Action tensor [B, action_dim] - - Returns: - Tuple of (log_prob [B,], entropy [B,]) - """ - # Your action evaluation code - pass -``` - -**Important Notes**: -- All methods must accept `TensorDict` for observations (not plain tensors) -- Handle missing observation keys gracefully (not all tasks provide all modalities) -- Your model should manage its own tokenization, preprocessing, and internal state -- Value head is required for PPO training (can be a simple MLP on top of your embeddings) - -### 2. 
Implement Checkpoint Loading - -Edit `embodichain/agents/rl/models/vla_policy.py` and implement the `load_vla_model()` function: - -```python -def load_vla_model( - model_path: str, - model_class: Optional[str] = None, - model_config: Optional[dict] = None, - device: torch.device = torch.device("cpu"), -) -> nn.Module: - """Load VLA model from checkpoint.""" - import importlib - - # Parse model class path - module_name, class_name = model_class.rsplit(".", 1) - module = importlib.import_module(module_name) - ModelClass = getattr(module, class_name) - - # Initialize model - model = ModelClass(**model_config) - - # Load checkpoint - checkpoint = torch.load(model_path, map_location=device) - model.load_state_dict(checkpoint["model_state_dict"]) - - # Move to device - model.to(device) - model.eval() # Start in eval mode (trainer will set to train) - - return model -``` - -Adapt this to your checkpoint format (may use different keys, compression, etc.). - -### 3. Configuration Format - -Create a training config JSON: - -```json -{ - "trainer": { - "buffer_type": "vla", - "buffer_size": 2048, - ... - }, - "policy": { - "name": "vla", - "action_dim": 7, - "vla_config": { - "model_path": "path/to/your/checkpoint.pth", - "model_class": "your_package.YourVLAModel", - "model_config": { - "vision_encoder": "resnet50", - "language_model": "gpt2", - "freeze_vision_encoder": false, - ... // your model-specific config - } - } - }, - "algorithm": { - "name": "ppo", - "cfg": { - "learning_rate": 1e-5, // Lower LR for fine-tuning - ... - } - } -} -``` - -### 4. Testing Your Integration - -```bash -# Run training with your VLA model -python embodichain/agents/rl/train.py --config configs/agents/rl/your_vla_config.json -``` - -Expected workflow: -1. `load_vla_model()` loads your pretrained checkpoint -2. `VLAPolicy` wraps your model and adapts it to Policy interface -3. RL trainer fine-tunes your model using PPO (or other algorithms) - -### 5. 
Tips for VLA Models - -**Value Head**: -- PPO requires value estimates -- Add a simple value head (e.g., linear layer) on your embeddings -- Can initialize randomly or pretrain with imitation learning - -**Action Distribution**: -- `evaluate_actions()` needs to compute log_prob and entropy -- For continuous actions: use Gaussian distribution (mean from your model, learnable std) -- For discrete actions: use Categorical distribution - -**Gradient Flow**: -- Set `freeze_vision_encoder: true` to only fine-tune action/value heads -- Set `freeze_language_model: true` if using large LMs (reduce memory) -- Or fine-tune entire model with lower learning rate - -**Memory Optimization**: -- VLA models are large - use `buffer_type: "vla"` with async collection -- Reduce `num_envs` if running out of memory -- Consider gradient checkpointing for very large models - -## For RL Framework Users - -If VLA model is already integrated, simply configure and run: - -```bash -python embodichain/agents/rl/train.py --config configs/agents/rl/vla_example/train_config.json -``` - -See [RL_TRAINING_FRAMEWORK.md](RL_TRAINING_FRAMEWORK.md) for general usage. 
- -## Architecture Diagram - -``` -┌─────────────────────────────────────────────────────────────┐ -│ RL Training Framework │ -│ │ -│ ┌────────────┐ ┌──────────────┐ ┌──────────────┐ │ -│ │ Trainer │─────▶│ VLAPolicy │────▶│ VLA Model │ │ -│ │ │ │ (wrapper) │ │ (your code) │ │ -│ └────────────┘ └──────────────┘ └──────────────┘ │ -│ │ │ │ │ -│ │ │ │ │ -│ TensorDict TensorDict TensorDict │ -│ │ │ │ │ -│ ▼ ▼ ▼ │ -│ [obs, action, [forward, [vision, │ -│ reward, done] get_value, language, │ -│ evaluate_actions] action] │ -└─────────────────────────────────────────────────────────────┘ -``` - -## Reference Implementation - -See `embodichain/agents/rl/models/vla_policy.py` for: -- `VLAModelInterface`: Protocol defining required methods -- `VLAPolicy`: Wrapper that adapts VLA model to Policy interface -- `load_vla_model()`: Checkpoint loading function (to be implemented) - -Example config: `configs/agents/rl/vla_example/train_config.json` diff --git a/docs/rl_training_guide.md b/docs/rl_training_guide.md new file mode 100644 index 00000000..3db3d072 --- /dev/null +++ b/docs/rl_training_guide.md @@ -0,0 +1,292 @@ +# RL Training Framework Guide + +TensorDict-based RL framework supporting standard PPO and asynchronous VLA training. 
+ +--- + +## Quick Start + +### Configuration + +```json +{ + "trainer": { + "buffer_size": 2048, + "model_type": "standard" // or "vla" + }, + "policy": {"name": "actor_critic"}, + "algorithm": { + "name": "ppo", + "cfg": { + "learning_rate": 3e-4, + "gamma": 0.99, + "n_epochs": 10, + "batch_size": 64 + } + } +} +``` + +### Run Training + +```bash +python embodichain/agents/rl/train.py --config configs/agents/rl/my_config.json +``` + +--- + +## Architecture + +``` +Trainer → Collector (sync/async) → Buffer (standard/vla) → Algorithm (PPO) +``` + +**Components**: +- **Collector**: Gather data from environment (SyncCollector / AsyncCollector) +- **Buffer**: Store transitions (RolloutBuffer / VLABuffer) +- **Algorithm**: Update policy (PPO) +- **Trainer**: Coordinate training loop + +--- + +## Training Modes + +### Standard Mode (Default) + +**For**: Normal models (<100ms inference/step) + +``` +SyncCollector → Collect 2048 steps → Train → Clear buffer → Repeat +``` + +**Config**: `{"trainer": {"model_type": "standard"}}` + +**Pros**: Simple, stable, low memory, no staleness + +### VLA Async Mode + +**For**: Large models (>1 sec inference/step) + +``` +Background: AsyncCollector → Continuously collect → VLABuffer +Main: Wait for buffer full → Train → Repeat +``` + +**Config**: `{"trainer": {"model_type": "vla"}}` + +**Pros**: 2-3x speedup via parallel collection +**Cons**: Data staleness, higher memory + +--- + +## Collectors + +### SyncCollector + +Collects complete rollout synchronously: + +```python +from embodichain.agents.rl.collector import SyncCollector + +collector = SyncCollector(env, policy, device, callback) +rollout = collector.collect(num_steps=2048) # [T, N, ...] +``` + +### AsyncCollector + +Runs in background thread: + +```python +from embodichain.agents.rl.collector import AsyncCollector + +collector = AsyncCollector(env, policy, buffer, device, callback) +collector.start() # Begin background collection +# ... buffer fills automatically ... 
+collector.stop() # Stop collection +``` + +--- + +## Buffers + +### RolloutBuffer (Standard) + +Single-use buffer: + +```python +from embodichain.agents.rl.buffer import RolloutBuffer + +buffer = RolloutBuffer(buffer_size=2048, device=device) +buffer.add(rollout) # [T, N, ...] +data = buffer.get(flatten=True) # [T*N, ...], auto-clears +``` + +### VLABuffer (Async) + +Circular FIFO buffer: + +```python +from embodichain.agents.rl.buffer import VLABuffer + +buffer = VLABuffer(buffer_size=4096, device=device) +buffer.add(transition) # Single step +data = buffer.get(flatten=True) # [buffer_size, ...] when full +``` + +**Circular behavior**: `[T0,T1,T2,T3]` → add T4 → `[T4,T1,T2,T3]` (T0 overwritten) + +--- + +## VLA Integration + +### 1. Implement Model + +```python +class MyVLAModel(nn.Module): + def forward(self, obs: TensorDict) -> TensorDict: + # Add 'action', 'sample_log_prob', 'value' + ... + def get_value(self, obs: TensorDict) -> TensorDict: + # Add 'value' + ... + def evaluate_actions(self, obs: TensorDict) -> TensorDict: + # Add 'sample_log_prob', 'entropy', 'value' + ... +``` + +### 2. Implement Loading + +Edit `embodichain/agents/rl/models/vla_policy.py`: + +```python +def load_vla_model(model_path, model_class, model_config, device): + model = MyVLAModel(**model_config) + model.load_state_dict(torch.load(model_path)) + return model.to(device) +``` + +### 3. Configure + +```json +{ + "trainer": {"model_type": "vla"}, + "policy": { + "name": "vla", + "vla_config": { + "model_path": "checkpoints/vla.pt", + "model_class": "MyVLAModel", + "model_config": {} + } + } +} +``` + +--- + +## Common APIs + +### Trainer + +```python +from embodichain.agents.rl.utils import Trainer + +trainer = Trainer( + policy, env, algorithm, + buffer_size=2048, + model_type="standard", # or "vla" + ... 
+)
+trainer.train(total_timesteps=1000000)
+```
+
+### Buffer Methods
+
+```python
+buffer.add(data)  # Add data
+data = buffer.get(flatten=True)  # Retrieve data
+buffer.is_full()  # Check ready status
+buffer.clear()  # Clear buffer
+buffer.get_stats()  # Statistics
+```
+
+### Algorithm
+
+```python
+from embodichain.agents.rl.algo import PPO, PPOCfg
+
+algorithm = PPO(PPOCfg(...), policy)
+losses = algorithm.update(rollout)  # Returns loss dict
+```
+
+---
+
+## FAQ
+
+**Q: When should I use VLA mode?**
+A: When inference takes >100ms/step and GPU training is fast
+
+**Q: How should I set the buffer size?**
+A: Standard: 2048-4096 (rollout size). VLA: 2048-4096 (buffer capacity)
+
+**Q: Does data staleness hurt training?**
+A: Only slightly; PPO tolerates mild staleness, and the 2-3x speedup far outweighs the small penalty
+
+**Q: How do I debug the data flow?**
+A: `buffer.get_stats()` or `_print_tensordict_tree(rollout)` in ppo.py
+
+---
+
+## Workflows
+
+### Standard
+
+```python
+collector = SyncCollector(env, policy, device, callback)
+while step < total:
+    rollout = collector.collect(num_steps=2048)
+    buffer.add(rollout)
+    data = buffer.get(flatten=True)
+    losses = algorithm.update(data)
+```
+
+### VLA
+
+```python
+collector = AsyncCollector(env, policy, buffer, device, callback)
+collector.start()
+while step < total:
+    while not buffer.is_full():
+        time.sleep(0.1)
+    data = buffer.get(flatten=True)
+    losses = algorithm.update(data)
+collector.stop()
+```
+
+---
+
+## File Structure
+
+```
+embodichain/agents/rl/
+├── train.py                    # Entry point
+├── algo/ppo.py                 # PPO algorithm
+├── buffer/
+│   ├── standard_buffer.py      # RolloutBuffer
+│   └── vla_buffer.py           # VLABuffer
+├── collector/
+│   ├── base.py                 # BaseCollector
+│   ├── sync_collector.py       # SyncCollector
+│   └── async_collector.py      # AsyncCollector
+├── models/
+│   ├── actor_critic.py         # Standard policy
+│   └── vla_policy.py           # VLA wrapper
+└── utils/trainer.py            # Training coordinator
+```
+
+---
+
+## References
+
+- [TensorDict Docs](https://pytorch.org/tensordict/)
+- [PPO Paper](https://arxiv.org/abs/1707.06347)
+- Example 
configs: `configs/agents/rl/` diff --git a/embodichain/agents/rl/ARCHITECTURE.md b/embodichain/agents/rl/ARCHITECTURE.md new file mode 100644 index 00000000..c83e2ff1 --- /dev/null +++ b/embodichain/agents/rl/ARCHITECTURE.md @@ -0,0 +1,216 @@ +# RL训练框架架构 + +## 总体流程 + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Trainer │ +│ (训练总协调者) │ +│ │ +│ ┌─────────────────┐ ┌──────────────────┐ │ +│ │ 初始化阶段 │ │ 训练循环 │ │ +│ │ │ │ │ │ +│ │ 1. 创建Policy │───────▶│ while epoch: │ │ +│ │ 2. 创建Algo │ │ ├─ 收集数据 │ │ +│ │ 3. 创建Collector│ │ ├─ 更新策略 │ │ +│ │ 4. 创建Env │ │ └─ 评估性能 │ │ +│ └─────────────────┘ └──────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────┘ + │ + ┌──────────────┼──────────────┐ + │ │ │ + ▼ ▼ ▼ + ┌──────────┐ ┌──────────┐ ┌──────────┐ + │ Collector│ │Algorithm │ │ Policy │ + └──────────┘ └──────────┘ └──────────┘ +``` + +## 核心组件 + +### 1. Trainer(训练器) +**职责**:总协调者,串联所有组件 +``` +训练循环: + for epoch in range(n_epochs): + ├─ rollout = collector.collect(n_steps) # 收集数据 + ├─ metrics = algorithm.update(rollout) # 更新策略 + └─ eval_reward = evaluate(policy) # 评估性能 +``` + +### 2. Collector(数据收集器) +**职责**:与环境交互,收集经验数据 + +``` +┌─────────────────────────────────────────────┐ +│ Collector 类型 │ +├─────────────────────────────────────────────┤ +│ │ +│ ┌──────────────────┐ ┌─────────────────┐ │ +│ │ SyncCollector │ │ AsyncCollector │ │ +│ │ (同步收集) │ │ (异步收集) │ │ +│ │ │ │ │ │ +│ │ 用于标准RL算法 │ │ 用于VLA模型 │ │ +│ │ - PPO │ │ - 后台持续收集 │ │ +│ │ - SAC │ │ - 独立线程 │ │ +│ └──────────────────┘ └─────────────────┘ │ +│ │ +└─────────────────────────────────────────────┘ + +工作流程: + obs = env.reset() + for step in range(n_steps): + ├─ policy.forward(obs, deterministic=False) # 采样动作 + ├─ next_obs, reward, done = env.step(action) + └─ 存储到 TensorDict: (obs, action, reward, done, value) + return rollout_tensordict # [T, N] 格式 +``` + +### 3. 
Algorithm(算法) +**职责**:策略更新逻辑 + +``` +┌─────────────────────────────────────────────┐ +│ Algorithm 类型 │ +├─────────────────────────────────────────────┤ +│ │ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ PPO │ │ SAC │ ... │ +│ │ │ │ │ │ +│ │ - GAE计算 │ │ - Q学习 │ │ +│ │ - Clip损失 │ │ - Soft更新 │ │ +│ │ - 价值损失 │ │ - 熵正则化 │ │ +│ └──────────────┘ └──────────────┘ │ +│ │ +└─────────────────────────────────────────────┘ + +工作流程: + def update(rollout: TensorDict) -> dict: + ├─ 计算优势函数 (GAE) + ├─ 多轮优化循环 + │ ├─ policy.evaluate_actions(batch) # 重新计算log_prob + │ ├─ 计算loss (clip + value + entropy) + │ └─ optimizer.step() + └─ return metrics +``` + +### 4. Policy(策略) +**职责**:神经网络,输出动作和价值 + +``` +┌─────────────────────────────────────────────┐ +│ Policy 类型 │ +├─────────────────────────────────────────────┤ +│ │ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ ActorCritic │ │ VLAPolicy │ │ +│ │ │ │ │ │ +│ │ - MLP网络 │ │ - 视觉语言 │ │ +│ │ - 高斯策略 │ │ - 预训练模型 │ │ +│ └──────────────┘ └──────────────┘ │ +│ │ +└─────────────────────────────────────────────┘ + +接口方法: + 1. forward(obs, deterministic=False) + ├─ 训练时:采样动作 (deterministic=False) + ├─ 评估时:确定性动作 (deterministic=True) + └─ 返回:action, log_prob, value + + 2. evaluate_actions(obs, action) + └─ 重新计算给定动作的log_prob和entropy + + 3. get_value(obs) + └─ 仅返回价值估计 +``` + +## 数据流动(TensorDict) + +``` +Environment ──▶ Collector ──▶ Algorithm ──▶ Policy + │ │ │ │ + │ TensorDict TensorDict Parameters + │ [T, N] [batch] Update + │ │ │ │ + └───────────────┴──────────────┴────────────┘ + +TensorDict 结构: +{ + "observation": Tensor or nested TensorDict, + "action": Tensor[T, N, action_dim], + "reward": Tensor[T, N, 1], + "done": Tensor[T, N, 1], + "value": Tensor[T, N, 1], + "sample_log_prob": Tensor[T, N, 1], + "advantage": Tensor[T, N, 1], # GAE计算后添加 + "return": Tensor[T, N, 1], # GAE计算后添加 +} +``` + +## 完整训练流程示例 + +```python +# 1. 初始化组件 +trainer = Trainer( + env=env, + policy=ActorCritic(...), + algorithm=PPO(...), +) + +# 2. 
创建Collector +collector = SyncCollector( + env=env, + policy=policy, + device=device, +) + +# 3. 训练循环 +for epoch in range(n_epochs): + + # 3.1 收集数据 + rollout = collector.collect( + n_steps=2048, + reset=True, + ) + # rollout: TensorDict[T=2048, N=num_envs] + + # 3.2 更新策略 + metrics = algorithm.update(rollout) + # metrics: {"loss": ..., "clip_frac": ..., ...} + + # 3.3 评估性能 + eval_reward = trainer.evaluate( + n_episodes=10, + deterministic=True, # 评估时使用确定性动作 + ) + + # 3.4 日志记录 + print(f"Epoch {epoch}: reward={eval_reward}, loss={metrics['loss']}") +``` + +## 关键设计原则 + +### 1. 职责分离 +- **Trainer**: 协调者,不涉及具体实现 +- **Collector**: 只负责数据收集,不做策略更新 +- **Algorithm**: 只负责策略更新,不做数据收集 +- **Policy**: 只负责网络前向,不涉及训练逻辑 + +### 2. 统一接口 +- 所有组件使用 **TensorDict** 进行数据传递 +- Policy暴露统一接口:`forward()`, `evaluate_actions()`, `get_value()` +- 易于切换不同实现(ActorCritic ↔ VLAPolicy) + +### 3. 灵活扩展 +- 添加新算法:继承 `BaseAlgorithm`,实现 `update()` +- 添加新策略:继承 `Policy`,实现三个抽象方法 +- 添加新收集器:继承 `BaseCollector`,实现 `collect()` + +### 4. 确定性评估 +```python +# 训练时(随机采样,探索) +policy.forward(obs, deterministic=False) # 使用 dist.sample() + +# 评估时(确定性,稳定) +policy.forward(obs, deterministic=True) # 使用 dist.mean +``` diff --git a/embodichain/agents/rl/algo/base.py b/embodichain/agents/rl/algo/base.py index 501ef2e7..06058f46 100644 --- a/embodichain/agents/rl/algo/base.py +++ b/embodichain/agents/rl/algo/base.py @@ -24,23 +24,21 @@ class BaseAlgorithm: """Base class for RL algorithms following TorchRL conventions. - Algorithms implement rollout collection and policy update using TensorDict. - No custom buffer classes - use TensorDict operations directly. + Algorithms implement policy updates using TensorDict. + Data collection is handled separately by Collector classes (SyncCollector/AsyncCollector). 
""" device: torch.device - def collect_rollout( - self, - env, - policy, - tensordict: TensorDict, - buffer_size: int, - on_step_callback: Callable | None = None, - ) -> TensorDict: - """Collect rollout and return TensorDict with batch_size=[T, N].""" - raise NotImplementedError - def update(self, rollout: TensorDict) -> Dict[str, float]: - """Update policy using collected data and return training losses.""" + """Update policy using collected rollout data. + + Args: + rollout: TensorDict containing collected rollout data from Collector + Expected batch_size format: [T, N] for on-policy algorithms + where T is trajectory length and N is number of environments + + Returns: + Dictionary of training metrics (losses, learning stats, etc.) + """ raise NotImplementedError diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index 853e9868..99671fa6 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -15,10 +15,8 @@ # ---------------------------------------------------------------------------- import torch -from typing import Dict, Any, Callable - from tensordict import TensorDict -from embodichain.agents.rl.utils import AlgorithmCfg, compute_gae, dict_to_tensordict +from embodichain.agents.rl.utils import AlgorithmCfg, compute_gae from embodichain.utils import configclass from .base import BaseAlgorithm @@ -75,9 +73,7 @@ class PPOCfg(AlgorithmCfg): class PPO(BaseAlgorithm): """PPO algorithm using TensorDict for all data flow. - - Following TorchRL conventions: no custom buffer class, just TensorDict operations. - All data I/O uses TensorDict - no tensor fallback. + Data collection is handled by Collector classes (SyncCollector/AsyncCollector). 
""" def __init__(self, cfg: PPOCfg, policy): @@ -86,136 +82,25 @@ def __init__(self, cfg: PPOCfg, policy): self.device = torch.device(cfg.device) self.optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.learning_rate) - def collect_rollout( - self, - env, - policy, - tensordict: TensorDict, - buffer_size: int, - on_step_callback: Callable | None = None, - ) -> TensorDict: - """Collect a rollout using TensorDict data flow. + def update(self, rollout: TensorDict) -> dict: + """Update the policy using collected rollout TensorDict (TorchRL style). Args: - env: Environment to collect from - policy: Policy to use for action selection - tensordict: Initial TensorDict with "observation" key - buffer_size: Number of steps to collect - on_step_callback: Optional callback called after each step + rollout: TensorDict with batch_size=[T, N] from collect_rollout() + OR [size] from VLA buffer Returns: - TensorDict with batch_size=[T, N] containing full rollout data + Dictionary of training metrics """ - policy.train() - current_td = tensordict - rollout_list = [] - - for t in range(buffer_size): - # Policy forward: adds "action", "sample_log_prob", "value" to tensordict - policy.forward(current_td) - - # Extract action for environment step - action = current_td["action"] - action_type = getattr(env, "action_type", "delta_qpos") - action_dict = {action_type: action} - - # Environment step - returns tuple (env returns dict, not TensorDict) - next_obs, reward, terminated, truncated, env_info = env.step(action_dict) - - # Convert env dict observation to TensorDict at boundary - next_obs_td = dict_to_tensordict(next_obs, self.device) - - # Build "next" TensorDict - done = terminated | truncated - next_obs_for_td = next_obs_td["observation"] - - # Ensure batch_size consistency - use next_obs_td's batch_size - batch_size = next_obs_td.batch_size[0] - - next_td = TensorDict( - { - "observation": next_obs_for_td, - "reward": ( - reward.float().unsqueeze(-1) - if reward.dim() == 1 - 
else reward.float() - ), - "done": ( - done.bool().unsqueeze(-1) if done.dim() == 1 else done.bool() - ), - "terminated": ( - terminated.bool().unsqueeze(-1) - if terminated.dim() == 1 - else terminated.bool() - ), - "truncated": ( - truncated.bool().unsqueeze(-1) - if truncated.dim() == 1 - else truncated.bool() - ), - }, - batch_size=torch.Size([batch_size]), - device=self.device, - ) - - # Compute next value for GAE (bootstrap value) - with torch.no_grad(): - next_value_td = TensorDict( - {"observation": next_obs_for_td}, - batch_size=next_td.batch_size, - device=self.device, - ) - policy.get_value(next_value_td) - next_td["value"] = next_value_td["value"] - - # Add "next" to current tensordict - current_td["next"] = next_td - - # Store complete transition - rollout_list.append(current_td.clone()) - - # Debug: Print TensorDict structure on first step - if len(rollout_list) == 1: - print("\n" + "=" * 80) - print("[DEBUG] Step 0 TensorDict Structure (Tree View)") - print("=" * 80) - _print_tensordict_tree(current_td, prefix="", is_last=True) - print("=" * 80 + "\n") - - # Callback for statistics and logging - if on_step_callback is not None: - on_step_callback(current_td, env_info) - - # Prepare next iteration - use the converted TensorDict - current_td = next_obs_td - - # Stack into [T, N, ...] 
TensorDict - rollout = torch.stack(rollout_list, dim=0) - - print("\n" + "=" * 80) - print( - f"[DEBUG] Stacked Rollout (T={rollout.batch_size[0]}, N={rollout.batch_size[1]})" - ) - print("=" * 80) - _print_tensordict_tree(rollout, prefix="", is_last=True) - print("=" * 80 + "\n") + # Ensure 2D format [T, N] for GAE computation + if len(rollout.batch_size) == 1: + rollout = rollout.unsqueeze(1) # [size] -> [size, 1] # Compute GAE advantages and returns rollout = compute_gae( rollout, gamma=self.cfg.gamma, gae_lambda=self.cfg.gae_lambda ) - return rollout - - def update(self, rollout: TensorDict) -> dict: - """Update the policy using collected rollout TensorDict (TorchRL style). - - Args: - rollout: TensorDict with batch_size=[T, N] from collect_rollout() - - Returns: - Dictionary of training metrics - """ # Flatten to [T*N, ...] for training flat_data = rollout.reshape(-1) total_samples = flat_data.batch_size[0] diff --git a/embodichain/agents/rl/collector/__init__.py b/embodichain/agents/rl/collector/__init__.py new file mode 100644 index 00000000..eede4937 --- /dev/null +++ b/embodichain/agents/rl/collector/__init__.py @@ -0,0 +1,26 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from .base import BaseCollector +from .sync_collector import SyncCollector +from .async_collector import AsyncCollector, AsyncCollectorStats + +__all__ = [ + "BaseCollector", + "SyncCollector", + "AsyncCollector", + "AsyncCollectorStats", +] diff --git a/embodichain/agents/rl/utils/async_collector.py b/embodichain/agents/rl/collector/async_collector.py similarity index 93% rename from embodichain/agents/rl/utils/async_collector.py rename to embodichain/agents/rl/collector/async_collector.py index ac9b4b4d..063c33a3 100644 --- a/embodichain/agents/rl/utils/async_collector.py +++ b/embodichain/agents/rl/collector/async_collector.py @@ -17,16 +17,16 @@ from __future__ import annotations import threading -import time from typing import Callable, Optional import torch from tensordict import TensorDict from collections import deque -from .helper import dict_to_tensordict +from ..utils.helper import dict_to_tensordict +from .base import BaseCollector -class AsyncCollector: +class AsyncCollector(BaseCollector): """Asynchronous data collector for VLA RL scenarios. 
Runs in a background thread to continuously collect transitions while @@ -63,11 +63,8 @@ def __init__( device: Device for tensor operations on_step_callback: Optional callback(transition, env_info) called after each step """ - self.env = env - self.policy = policy + super().__init__(env, policy, device, on_step_callback) self.buffer = buffer - self.device = device - self.on_step_callback = on_step_callback # Thread control self._running = False @@ -78,10 +75,6 @@ def __init__( self._episode_count = 0 self._step_count = 0 - # Initialize observation - obs_dict, _ = self.env.reset() - self.obs_tensordict = dict_to_tensordict(obs_dict, self.device) - def start(self): """Start background collection thread.""" if self._running: @@ -92,6 +85,20 @@ def start(self): self._thread.start() print("[AsyncCollector] Background collection started") + def collect(self, **kwargs) -> TensorDict: + """For AsyncCollector, data is collected continuously in background. + + This method is just for interface compatibility with BaseCollector. + Actual data retrieval happens through buffer.get(). + + Returns: + Empty TensorDict (not used in async mode) + """ + raise NotImplementedError( + "AsyncCollector collects data in background thread. " + "Use buffer.get() to retrieve data instead." + ) + def stop(self): """Stop background collection thread.""" if not self._running: diff --git a/embodichain/agents/rl/collector/base.py b/embodichain/agents/rl/collector/base.py new file mode 100644 index 00000000..3f49d1e0 --- /dev/null +++ b/embodichain/agents/rl/collector/base.py @@ -0,0 +1,64 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Callable, Optional +import torch +from tensordict import TensorDict + +from ..utils.helper import dict_to_tensordict + + +class BaseCollector(ABC): + """Abstract base class for data collectors. + + Defines the interface that all collectors must implement. + """ + + def __init__( + self, + env, + policy, + device: torch.device, + on_step_callback: Optional[Callable] = None, + ): + """Initialize base collector. + + Args: + env: Environment to collect from + policy: Policy for action selection + device: Device for tensor operations + on_step_callback: Optional callback(tensordict, env_info) called after each step + """ + self.env = env + self.policy = policy + self.device = device + self.on_step_callback = on_step_callback + + # Initialize observation + obs_dict, _ = self.env.reset() + self.obs_tensordict = dict_to_tensordict(obs_dict, self.device) + + @abstractmethod + def collect(self, **kwargs) -> TensorDict: + """Collect data from environment. + + Returns: + TensorDict with collected data + """ + pass diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py new file mode 100644 index 00000000..4136096f --- /dev/null +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -0,0 +1,130 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch +from tensordict import TensorDict + +from ..utils.helper import dict_to_tensordict +from .base import BaseCollector + + +class SyncCollector(BaseCollector): + """Synchronous data collector for standard RL training. + + Collects a complete rollout of specified length, then returns it. + Used with RolloutBuffer for standard PPO training. + + Usage: + collector = SyncCollector(env, policy, device) + rollout = collector.collect(num_steps=2048) + buffer.add(rollout) + """ + + def collect(self, num_steps: int) -> TensorDict: + """Collect a synchronous rollout. 
+ + Args: + num_steps: Number of steps to collect + + Returns: + TensorDict with batch_size=[T, N] containing full rollout + """ + self.policy.train() + current_td = self.obs_tensordict + rollout_list = [] + + for t in range(num_steps): + # Policy forward: adds "action", "sample_log_prob", "value" to tensordict + self.policy.forward(current_td) + + # Extract action for environment step + action = current_td["action"] + action_type = getattr(self.env, "action_type", "delta_qpos") + action_dict = {action_type: action} + + # Environment step - returns tuple (env returns dict, not TensorDict) + next_obs, reward, terminated, truncated, env_info = self.env.step( + action_dict + ) + + # Convert env dict observation to TensorDict at boundary + next_obs_td = dict_to_tensordict(next_obs, self.device) + + # Build "next" TensorDict + done = terminated | truncated + next_obs_for_td = next_obs_td["observation"] + + # Ensure batch_size consistency - use next_obs_td's batch_size + batch_size = next_obs_td.batch_size[0] + + next_td = TensorDict( + { + "observation": next_obs_for_td, + "reward": ( + reward.float().unsqueeze(-1) + if reward.dim() == 1 + else reward.float() + ), + "done": ( + done.bool().unsqueeze(-1) if done.dim() == 1 else done.bool() + ), + "terminated": ( + terminated.bool().unsqueeze(-1) + if terminated.dim() == 1 + else terminated.bool() + ), + "truncated": ( + truncated.bool().unsqueeze(-1) + if truncated.dim() == 1 + else truncated.bool() + ), + }, + batch_size=torch.Size([batch_size]), + device=self.device, + ) + + # Compute next value for GAE (bootstrap value) + with torch.no_grad(): + next_value_td = TensorDict( + {"observation": next_obs_for_td}, + batch_size=next_td.batch_size, + device=self.device, + ) + self.policy.get_value(next_value_td) + next_td["value"] = next_value_td["value"] + + # Add "next" to current tensordict + current_td["next"] = next_td + + # Store complete transition + rollout_list.append(current_td.clone()) + + # Callback for statistics 
and logging + if self.on_step_callback is not None: + self.on_step_callback(current_td, env_info) + + # Prepare next iteration - use the converted TensorDict + current_td = next_obs_td + + # Update observation for next collection + self.obs_tensordict = current_td + + # Stack into [T, N, ...] TensorDict + rollout = torch.stack(rollout_list, dim=0) + + return rollout diff --git a/embodichain/agents/rl/models/actor_critic.py b/embodichain/agents/rl/models/actor_critic.py index faf305cb..f404d41e 100644 --- a/embodichain/agents/rl/models/actor_critic.py +++ b/embodichain/agents/rl/models/actor_critic.py @@ -103,16 +103,19 @@ def _collect(item): return torch.cat(obs_list, dim=-1) @torch.no_grad() - def forward(self, tensordict: TensorDict) -> TensorDict: + def forward( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: """Forward pass: sample action and compute value (in-place modification). Args: tensordict: Must contain "observation" key + deterministic: If True, use mean instead of sampling Returns: Same tensordict with added keys: - - "action": Sampled action - - "sample_log_prob": Log probability of sampled action + - "action": Sampled or deterministic action + - "sample_log_prob": Log probability of action - "value": Value estimate - "loc": Distribution mean - "scale": Distribution std @@ -123,11 +126,17 @@ def forward(self, tensordict: TensorDict) -> TensorDict: mean = self.actor(obs_tensor) log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) std = log_std.exp().expand(mean.shape[0], -1) + dist = Normal(mean, std) - # Sample action (or use mean if deterministic mode set elsewhere) - # For now, always sample during forward; deterministic handled by setting std=0 externally if needed - action = dist.sample() + # Sample action or use mean + if deterministic: + action = mean + else: + dist = Normal(mean, std) + action = dist.sample() + + # Compute log probability log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True) # 
Critic forward - keep shape [N, 1] for consistency with reward/done diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py index 36e9754c..63bbeeab 100644 --- a/embodichain/agents/rl/models/vla_policy.py +++ b/embodichain/agents/rl/models/vla_policy.py @@ -66,15 +66,18 @@ def __init__( self.vla_model.to(self.device) @torch.no_grad() - def forward(self, tensordict: TensorDict) -> TensorDict: + def forward( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: """Forward pass: generate action and value from VLA model. Args: tensordict: Must contain "observation" key with observation data + deterministic: If True, use deterministic actions (passed to VLA model) Returns: Same tensordict with added keys: - - "action": Sampled action + - "action": Sampled or deterministic action - "sample_log_prob": Log probability of action - "value": Value estimate """ @@ -83,7 +86,7 @@ def forward(self, tensordict: TensorDict) -> TensorDict: obs = tensordict["observation"] # Example: VLA model generates action and value - action, log_prob, value = self.vla_model(obs) + action, log_prob, value = self.vla_model(obs, deterministic=deterministic) tensordict["action"] = action tensordict["sample_log_prob"] = log_prob diff --git a/embodichain/agents/rl/utils/__init__.py b/embodichain/agents/rl/utils/__init__.py index f1b94eb1..852cdaf9 100644 --- a/embodichain/agents/rl/utils/__init__.py +++ b/embodichain/agents/rl/utils/__init__.py @@ -16,7 +16,6 @@ from .config import AlgorithmCfg from .helper import dict_to_tensordict, mean_scalar, pack_log_dict, compute_gae -from .async_collector import AsyncCollector, AsyncCollectorStats __all__ = [ "AlgorithmCfg", @@ -24,6 +23,4 @@ "mean_scalar", "pack_log_dict", "compute_gae", - "AsyncCollector", - "AsyncCollectorStats", ] diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index f4f44641..1d8612d1 100644 --- 
a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -27,6 +27,7 @@ from embodichain.lab.gym.envs.managers.event_manager import EventManager from .helper import dict_to_tensordict, mean_scalar, pack_log_dict +from ..collector import SyncCollector, AsyncCollector class Trainer: @@ -101,12 +102,12 @@ def __init__( self.ret_window = deque(maxlen=100) self.len_window = deque(maxlen=100) - # Get initial observation from env (dict) and convert to TensorDict at boundary + # Initialize observation - will be used by collectors obs, _ = self.env.reset() self.obs_tensordict = dict_to_tensordict(obs, self.device) num_envs = self.obs_tensordict.batch_size[0] - # episode stats tracked on device to avoid repeated CPU round-trips + # Episode stats tracked on device to avoid repeated CPU round-trips self.curr_ret = torch.zeros(num_envs, dtype=torch.float32, device=self.device) self.curr_len = torch.zeros(num_envs, dtype=torch.int32, device=self.device) @@ -122,23 +123,84 @@ def _log_scalar_dict(self, prefix: str, data: dict): except Exception: continue + def _create_step_callback(self) -> Callable: + """Create step callback for collectors. 
+ + Returns: + Callback function compatible with both sync and async collectors + """ + + def on_step(tensordict: TensorDict, env_info: dict): + """Callback called at each step during rollout collection.""" + # Extract reward and done from next subdictionary + reward = tensordict["next"]["reward"] + done = tensordict["next"]["done"] + + # Squeeze if needed + if reward.dim() > 1: + reward = reward.squeeze(-1) + if done.dim() > 1: + done = done.squeeze(-1) + + # Episode stats (stay on device; convert only when episode ends) + self.curr_ret += reward + self.curr_len += 1 + done_idx = torch.nonzero(done, as_tuple=False).squeeze(-1) + if done_idx.numel() > 0: + finished_ret = self.curr_ret[done_idx].detach().cpu().tolist() + finished_len = self.curr_len[done_idx].detach().cpu().tolist() + self.ret_window.extend(finished_ret) + self.len_window.extend(finished_len) + self.curr_ret[done_idx] = 0 + self.curr_len[done_idx] = 0 + + # Log environment metrics + if isinstance(env_info, dict): + rewards_dict = env_info.get("rewards") + metrics_dict = env_info.get("metrics") + self._log_scalar_dict("rewards", rewards_dict) + self._log_scalar_dict("metrics", metrics_dict) + log_dict = {} + log_dict.update(pack_log_dict("rewards", rewards_dict)) + log_dict.update(pack_log_dict("metrics", metrics_dict)) + if log_dict and self.use_wandb: + wandb.log(log_dict, step=self.global_step) + + return on_step + def train(self, total_timesteps: int): print(f"Start training, total steps: {total_timesteps}") print(f"Model type: {self.model_type}") if self.model_type == "vla": - # VLA mode: Use async collector - self._train_async(total_timesteps) + collector = AsyncCollector( + env=self.env, + policy=self.policy, + buffer=self.buffer, + device=self.device, + on_step_callback=self._create_step_callback(), + ) + self._train_async(collector, total_timesteps) else: - # Standard mode: Use synchronous collection - self._train_sync(total_timesteps) + collector = SyncCollector( + env=self.env, + 
policy=self.policy, + device=self.device, + on_step_callback=self._create_step_callback(), + ) + self._train_sync(collector, total_timesteps) - def _train_sync(self, total_timesteps: int): + def _train_sync(self, collector: SyncCollector, total_timesteps: int): """Synchronous training loop (standard PPO).""" while self.global_step < total_timesteps: - rollout = self._collect_rollout() + # Collect rollout + rollout = collector.collect(num_steps=self.buffer_size) + + # Update global step (main thread only) + num_steps = rollout.batch_size[0] # T dimension + num_envs = rollout.batch_size[1] if len(rollout.batch_size) > 1 else 1 + self.global_step += num_steps * num_envs - # Add rollout to buffer self.buffer.add(rollout) # Train when buffer is full @@ -159,186 +221,46 @@ def _train_sync(self, total_timesteps: int): if self.global_step % self.save_freq == 0: self.save_checkpoint() - def _train_async(self, total_timesteps: int): + def _train_async(self, collector: AsyncCollector, total_timesteps: int): """Asynchronous training loop (VLA mode).""" - from .async_collector import AsyncCollector, AsyncCollectorStats - - # Create statistics tracker - num_envs = self.obs_tensordict.batch_size[0] - async_stats = AsyncCollectorStats(num_envs, self.device) - - # Create callback for async collector - def on_async_step(tensordict: TensorDict, env_info: dict): - """Callback for async collection statistics.""" - # Extract reward and done - reward = tensordict["next"]["reward"] - done = tensordict["next"]["done"] - - # Update statistics - async_stats.update(reward, done) - - # Update global step - num_envs = tensordict.batch_size[0] - self.global_step += num_envs - - # Log environment metrics - if isinstance(env_info, dict): - rewards_dict = env_info.get("rewards") - metrics_dict = env_info.get("metrics") - self._log_scalar_dict("rewards", rewards_dict) - self._log_scalar_dict("metrics", metrics_dict) - log_dict = {} - log_dict.update(pack_log_dict("rewards", rewards_dict)) - 
log_dict.update(pack_log_dict("metrics", metrics_dict)) - if log_dict and self.use_wandb: - wandb.log(log_dict, step=self.global_step) - - # Create and start async collector - collector = AsyncCollector( - env=self.env, - policy=self.policy, - buffer=self.buffer, - device=self.device, - on_step_callback=on_async_step, - ) - - print("[Trainer] Starting async collector...") collector.start() - - # Training loop: wait for buffer to fill, then train - last_eval_step = 0 - last_save_step = 0 - update_count = 0 + print("[Trainer] Async collector started") try: while self.global_step < total_timesteps: # Wait for buffer to fill while not self.buffer.is_full(): - time.sleep(0.1) # Check every 100ms + time.sleep(0.1) if not collector.is_running(): raise RuntimeError("Async collector stopped unexpectedly") # Get data and train data = self.buffer.get(flatten=True) - losses = self.algorithm.update(data) - # Update episode statistics from async tracker - avg_ret, avg_len = async_stats.get_avg_stats() - if not np.isnan(avg_ret): - self.ret_window.append(avg_ret) - if not np.isnan(avg_len): - self.len_window.append(avg_len) + # Update global step based on collected data (main thread only) + batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0 + self.global_step += batch_size - # Log training + losses = self.algorithm.update(data) self._log_train(losses) - update_count += 1 - # Clear buffer for next collection (optional, depends on policy staleness tolerance) - # For VLA, we might keep some data for stability - # self.buffer.clear() - - # Evaluation + # Evaluation (pause collector during eval) if ( self.eval_freq > 0 and self.eval_env is not None - and self.global_step - last_eval_step >= self.eval_freq + and self.global_step % self.eval_freq == 0 ): - # Temporarily pause collection during eval - print("[Trainer] Pausing collection for evaluation...") collector.stop() self._eval_once(num_episodes=self.num_eval_episodes) collector.start() - print("[Trainer] Resuming 
collection...") - last_eval_step = self.global_step # Checkpoint - if self.global_step - last_save_step >= self.save_freq: + if self.global_step % self.save_freq == 0: self.save_checkpoint() - last_save_step = self.global_step - - # Log buffer and collector stats - buffer_stats = self.buffer.get_stats() - collector_stats = collector.get_stats() - print(f"[Trainer] Buffer: {buffer_stats}") - print(f"[Trainer] Collector: {collector_stats}") finally: - # Always stop collector when training ends - print("[Trainer] Stopping async collector...") collector.stop() - print(f"[Trainer] Training completed ({update_count} updates)") - - @torch.no_grad() - def _collect_rollout(self) -> TensorDict: - """Collect a rollout. Algorithm controls the data collection process. - - Returns: - TensorDict with batch_size=[T, N] containing full rollout - """ - - # Callback function for statistics and logging (uses TensorDict) - def on_step(tensordict: TensorDict, env_info: dict): - """Callback called at each step during rollout collection. 
- - Args: - tensordict: Complete transition TensorDict with "next" key - env_info: Environment info dict - """ - # Extract reward and done from next subdictionary - reward = tensordict["next"]["reward"] - done = tensordict["next"]["done"] - - # Squeeze if needed - if reward.dim() > 1: - reward = reward.squeeze(-1) - if done.dim() > 1: - done = done.squeeze(-1) - - # Episode stats (stay on device; convert only when episode ends) - self.curr_ret += reward - self.curr_len += 1 - done_idx = torch.nonzero(done, as_tuple=False).squeeze(-1) - if done_idx.numel() > 0: - finished_ret = self.curr_ret[done_idx].detach().cpu().tolist() - finished_len = self.curr_len[done_idx].detach().cpu().tolist() - self.ret_window.extend(finished_ret) - self.len_window.extend(finished_len) - self.curr_ret[done_idx] = 0 - self.curr_len[done_idx] = 0 - - # Update global step and observation TensorDict (already TensorDict from PPO) - next_obs = tensordict["next"]["observation"] - num_envs = next_obs.batch_size[0] - - # Prepare next tensordict - self.obs_tensordict = TensorDict( - {"observation": next_obs}, - batch_size=torch.Size([num_envs]), - device=self.device, - ) - self.global_step += num_envs - - if isinstance(env_info, dict): - rewards_dict = env_info.get("rewards") - metrics_dict = env_info.get("metrics") - self._log_scalar_dict("rewards", rewards_dict) - self._log_scalar_dict("metrics", metrics_dict) - log_dict = {} - log_dict.update(pack_log_dict("rewards", rewards_dict)) - log_dict.update(pack_log_dict("metrics", metrics_dict)) - if log_dict and self.use_wandb: - wandb.log(log_dict, step=self.global_step) - - # Algorithm controls data collection and returns TensorDict rollout - rollout = self.algorithm.collect_rollout( - env=self.env, - policy=self.policy, - tensordict=self.obs_tensordict, - buffer_size=self.buffer_size, - on_step_callback=on_step, - ) - - return rollout + print("[Trainer] Async collector stopped") def _log_train(self, losses: Dict[str, float]): if self.writer: @@ 
-405,11 +327,9 @@ def _eval_once(self, num_episodes: int = 5): # Run episode until all environments complete while not done_mask.all(): - # Get deterministic actions from policy (policy.forward modifies in-place) - # For deterministic eval, we can set a flag or use mean directly - # For now, use forward and extract action + # Get deterministic actions for evaluation obs_copy = obs.clone() - self.policy.forward(obs_copy) + self.policy.forward(obs_copy, deterministic=True) actions = obs_copy["action"] action_type = getattr(self.eval_env, "action_type", "delta_qpos") From a4db0893387a8cf1947f2aaf1e07f5721ee17cbd Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 27 Feb 2026 05:01:38 +0000 Subject: [PATCH 04/14] Fix: clear buffer after get --- embodichain/agents/rl/utils/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 1d8612d1..e0104b62 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -236,6 +236,7 @@ def _train_async(self, collector: AsyncCollector, total_timesteps: int): # Get data and train data = self.buffer.get(flatten=True) + self.buffer.clear() # Must clear to avoid retraining on same data # Update global step based on collected data (main thread only) batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0 From c8c0919c0c5c606a4be4520730bd26913775c902 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 27 Feb 2026 05:06:05 +0000 Subject: [PATCH 05/14] Fix: add lock for get/is_full --- embodichain/agents/rl/buffer/vla_buffer.py | 136 +++++++++++---------- 1 file changed, 74 insertions(+), 62 deletions(-) diff --git a/embodichain/agents/rl/buffer/vla_buffer.py b/embodichain/agents/rl/buffer/vla_buffer.py index d08252ff..cdab9f56 100644 --- a/embodichain/agents/rl/buffer/vla_buffer.py +++ b/embodichain/agents/rl/buffer/vla_buffer.py @@ -16,6 
+16,7 @@ from __future__ import annotations +import threading import torch from tensordict import TensorDict from typing import Optional @@ -51,6 +52,7 @@ def __init__(self, buffer_size: int, device: torch.device): self.size = 0 # Current valid data count self._total_added = 0 self._initialized = False + self._lock = threading.Lock() # Thread-safe: main thread reads, collector writes def _initialize_buffer(self, template: TensorDict) -> None: """Initialize buffer structure from first transition template. @@ -72,22 +74,23 @@ def add(self, transition: TensorDict) -> None: Args: transition: Single transition TensorDict (no batch dimension) """ - # Lazy initialization on first add - if not self._initialized: - self._initialize_buffer(transition.to(self.device)) + with self._lock: + # Lazy initialization on first add + if not self._initialized: + self._initialize_buffer(transition.to(self.device)) - # Ensure transition is on correct device - transition = transition.to(self.device) + # Ensure transition is on correct device + transition = transition.to(self.device) - # Direct index assignment (zero-copy write) - self.buffer[self.write_pos] = transition + # Direct index assignment (zero-copy write) + self.buffer[self.write_pos] = transition - # Update circular index - self.write_pos = (self.write_pos + 1) % self.buffer_size + # Update circular index + self.write_pos = (self.write_pos + 1) % self.buffer_size - # Update size (saturates at buffer_size) - self.size = min(self.size + 1, self.buffer_size) - self._total_added += 1 + # Update size (saturates at buffer_size) + self.size = min(self.size + 1, self.buffer_size) + self._total_added += 1 def add_batch(self, transitions: TensorDict) -> None: """Add multiple transitions at once (batch write). @@ -95,23 +98,24 @@ def add_batch(self, transitions: TensorDict) -> None: Args: transitions: Batch of transitions with shape [batch_size, ...] 
""" - batch_size = transitions.batch_size[0] + with self._lock: + batch_size = transitions.batch_size[0] - # Lazy initialization - if not self._initialized: - self._initialize_buffer(transitions[0].to(self.device)) + # Lazy initialization + if not self._initialized: + self._initialize_buffer(transitions[0].to(self.device)) - transitions = transitions.to(self.device) + transitions = transitions.to(self.device) - # Handle circular write - for i in range(batch_size): - self.buffer[self.write_pos] = transitions[i] - self.write_pos = (self.write_pos + 1) % self.buffer_size - self.size = min(self.size + 1, self.buffer_size) - self._total_added += 1 + # Handle circular write + for i in range(batch_size): + self.buffer[self.write_pos] = transitions[i] + self.write_pos = (self.write_pos + 1) % self.buffer_size + self.size = min(self.size + 1, self.buffer_size) + self._total_added += 1 def get(self, flatten: bool = True) -> TensorDict: - """Get valid data from buffer. + """Get valid data from buffer (thread-safe). Args: flatten: If True, return flattened [size, ...]. Currently only supports True. @@ -119,57 +123,65 @@ def get(self, flatten: bool = True) -> TensorDict: Returns: TensorDict with batch_size=[size, ...] 
containing valid data """ - if not self._initialized or self.size == 0: - raise ValueError("Buffer is empty") - - if not flatten: - raise NotImplementedError("Only flatten=True is supported for VLABuffer") - - # Return first 'size' elements (valid data) - # Note: Data is in insertion order up to write_pos, then wraps - if self.size < self.buffer_size: - # Buffer not yet full, data is [0:size] - return self.buffer[: self.size] - else: - # Buffer full, need to rearrange to maintain temporal order - # Oldest data is at write_pos, newest at write_pos-1 - indices = ( - torch.arange( - self.write_pos, - self.write_pos + self.buffer_size, - device=self.device, + with self._lock: + if not self._initialized or self.size == 0: + raise ValueError("Buffer is empty") + + if not flatten: + raise NotImplementedError( + "Only flatten=True is supported for VLABuffer" + ) + + # Return first 'size' elements (valid data) + # Note: Data is in insertion order up to write_pos, then wraps + if self.size < self.buffer_size: + # Buffer not yet full, data is [0:size] + return self.buffer[: self.size].clone() + else: + # Buffer full, need to rearrange to maintain temporal order + # Oldest data is at write_pos, newest at write_pos-1 + indices = ( + torch.arange( + self.write_pos, + self.write_pos + self.buffer_size, + device=self.device, + ) + % self.buffer_size ) - % self.buffer_size - ) - return self.buffer[indices] + return self.buffer[indices].clone() def clear(self) -> None: """Clear buffer (reset pointers, keep pre-allocated memory).""" - self.write_pos = 0 - self.size = 0 - # Keep buffer allocated for reuse + with self._lock: + self.write_pos = 0 + self.size = 0 + # Keep buffer allocated for reuse def __len__(self) -> int: """Return current number of valid transitions.""" - return self.size + with self._lock: + return self.size def is_full(self) -> bool: """Check if buffer is at full buffer_size.""" - return self.size >= self.buffer_size + with self._lock: + return self.size >= 
self.buffer_size def get_num_rollouts(self) -> int: """Return 1 (buffer stores transitions, not rollouts).""" - return 1 if self.size > 0 else 0 + with self._lock: + return 1 if self.size > 0 else 0 def get_stats(self) -> dict: """Get buffer statistics for logging.""" - return { - "buffer_size": self.size, - "buffer_capacity": self.buffer_size, - "total_transitions": self.size, - "total_added": self._total_added, - "buffer_usage": ( - self.size / self.buffer_size if self.buffer_size > 0 else 0.0 - ), - "write_pos": self.write_pos, - } + with self._lock: + return { + "buffer_size": self.size, + "buffer_capacity": self.buffer_size, + "total_transitions": self.size, + "total_added": self._total_added, + "buffer_usage": ( + self.size / self.buffer_size if self.buffer_size > 0 else 0.0 + ), + "write_pos": self.write_pos, + } From 71a56276135cd0b48a1172e5807672159d0ba095 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 27 Feb 2026 05:08:48 +0000 Subject: [PATCH 06/14] Fix: protect thread safety for on_step callback in Trainer --- embodichain/agents/rl/utils/trainer.py | 49 ++++++++++++++------------ 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index e0104b62..b6ebccb6 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -16,8 +16,9 @@ from __future__ import annotations -from typing import Dict, Any, Tuple, Callable +import threading import time +from typing import Dict, Any, Tuple, Callable import numpy as np import torch from torch.utils.tensorboard import SummaryWriter @@ -101,6 +102,7 @@ def __init__( self.start_time = time.time() self.ret_window = deque(maxlen=100) self.len_window = deque(maxlen=100) + self._stats_lock = threading.Lock() # Protects curr_ret, curr_len, ret_window, len_window (async mode) # Initialize observation - will be used by collectors obs, _ = self.env.reset() @@ -142,17 
+144,18 @@ def on_step(tensordict: TensorDict, env_info: dict): if done.dim() > 1: done = done.squeeze(-1) - # Episode stats (stay on device; convert only when episode ends) - self.curr_ret += reward - self.curr_len += 1 - done_idx = torch.nonzero(done, as_tuple=False).squeeze(-1) - if done_idx.numel() > 0: - finished_ret = self.curr_ret[done_idx].detach().cpu().tolist() - finished_len = self.curr_len[done_idx].detach().cpu().tolist() - self.ret_window.extend(finished_ret) - self.len_window.extend(finished_len) - self.curr_ret[done_idx] = 0 - self.curr_len[done_idx] = 0 + # Episode stats (thread-safe for async mode: collector writes, main reads) + with self._stats_lock: + self.curr_ret += reward + self.curr_len += 1 + done_idx = torch.nonzero(done, as_tuple=False).squeeze(-1) + if done_idx.numel() > 0: + finished_ret = self.curr_ret[done_idx].detach().cpu().tolist() + finished_len = self.curr_len[done_idx].detach().cpu().tolist() + self.ret_window.extend(finished_ret) + self.len_window.extend(finished_len) + self.curr_ret[done_idx] = 0 + self.curr_len[done_idx] = 0 # Log environment metrics if isinstance(env_info, dict): @@ -264,28 +267,30 @@ def _train_async(self, collector: AsyncCollector, total_timesteps: int): print("[Trainer] Async collector stopped") def _log_train(self, losses: Dict[str, float]): + # Snapshot episode stats under lock (async mode: main reads, collector writes) + with self._stats_lock: + ret_list = list(self.ret_window) + len_list = list(self.len_window) + + avgR = float(np.mean(ret_list)) if ret_list else float("nan") + avgL = float(np.mean(len_list)) if len_list else float("nan") + if self.writer: for k, v in losses.items(): self.writer.add_scalar(f"train/{k}", v, self.global_step) elapsed = max(1e-6, time.time() - self.start_time) sps = self.global_step / elapsed self.writer.add_scalar("charts/SPS", sps, self.global_step) - if len(self.ret_window) > 0: + if ret_list: self.writer.add_scalar( - "charts/episode_reward_avg_100", - 
float(np.mean(self.ret_window)), - self.global_step, + "charts/episode_reward_avg_100", avgR, self.global_step ) - if len(self.len_window) > 0: + if len_list: self.writer.add_scalar( - "charts/episode_length_avg_100", - float(np.mean(self.len_window)), - self.global_step, + "charts/episode_length_avg_100", avgL, self.global_step ) # console sps = self.global_step / max(1e-6, time.time() - self.start_time) - avgR = np.mean(self.ret_window) if len(self.ret_window) > 0 else float("nan") - avgL = np.mean(self.len_window) if len(self.len_window) > 0 else float("nan") print( f"[train] step={self.global_step} sps={sps:.0f} avgReward(100)={avgR:.3f} avgLength(100)={avgL:.1f}" ) From e73a9cc81bad836e6a9caa658ea117989e305802 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 27 Feb 2026 05:10:59 +0000 Subject: [PATCH 07/14] Fix: removed repeated creation of normal distribution in ActorCritic --- embodichain/agents/rl/models/actor_critic.py | 1 - 1 file changed, 1 deletion(-) diff --git a/embodichain/agents/rl/models/actor_critic.py b/embodichain/agents/rl/models/actor_critic.py index f404d41e..99c8c0e9 100644 --- a/embodichain/agents/rl/models/actor_critic.py +++ b/embodichain/agents/rl/models/actor_critic.py @@ -133,7 +133,6 @@ def forward( if deterministic: action = mean else: - dist = Normal(mean, std) action = dist.sample() # Compute log probability From 5dc18bf947287c5b45a8940c61f082857d79d9b2 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 27 Feb 2026 05:23:05 +0000 Subject: [PATCH 08/14] Fix: include final smaller batch --- embodichain/agents/rl/algo/ppo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index 99671fa6..44c5a483 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -125,10 +125,10 @@ def update(self, rollout: TensorDict) -> dict: shuffled_data = 
flat_data[indices] # Iterate over minibatches - num_minibatches = total_samples // self.cfg.batch_size + num_minibatches = (total_samples + self.cfg.batch_size - 1) // self.cfg.batch_size for i in range(num_minibatches): start_idx = i * self.cfg.batch_size - end_idx = start_idx + self.cfg.batch_size + end_idx = min(start_idx + self.cfg.batch_size, total_samples) batch_td = shuffled_data[start_idx:end_idx] # Extract data from TensorDict batch From da2eeeb5a080612e787b593998925db54df8285f Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 27 Feb 2026 05:30:30 +0000 Subject: [PATCH 09/14] Remove buffer_type in train_config --- configs/agents/rl/vla_example/train_config.json | 1 - 1 file changed, 1 deletion(-) diff --git a/configs/agents/rl/vla_example/train_config.json b/configs/agents/rl/vla_example/train_config.json index bc48b9f5..50b9f9c6 100644 --- a/configs/agents/rl/vla_example/train_config.json +++ b/configs/agents/rl/vla_example/train_config.json @@ -10,7 +10,6 @@ "num_envs": 32, "iterations": 500, "buffer_size": 2048, - "buffer_type": "vla", "enable_eval": true, "num_eval_envs": 8, "num_eval_episodes": 3, From b5dfd0529bf7662b40a16caf9267aabd58a76c75 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 27 Feb 2026 06:01:59 +0000 Subject: [PATCH 10/14] Sign unification --- embodichain/agents/rl/collector/base.py | 6 +++++- embodichain/agents/rl/collector/sync_collector.py | 6 ++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/embodichain/agents/rl/collector/base.py b/embodichain/agents/rl/collector/base.py index 3f49d1e0..07854ddb 100644 --- a/embodichain/agents/rl/collector/base.py +++ b/embodichain/agents/rl/collector/base.py @@ -55,9 +55,13 @@ def __init__( self.obs_tensordict = dict_to_tensordict(obs_dict, self.device) @abstractmethod - def collect(self, **kwargs) -> TensorDict: + def collect(self, num_steps: int | None = None, **kwargs) -> TensorDict: """Collect data from 
environment. + Args: + num_steps: Number of steps to collect (required by SyncCollector, + ignored by AsyncCollector which collects continuously). + Returns: TensorDict with collected data """ diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index 4136096f..b1c41bc0 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -35,15 +35,17 @@ class SyncCollector(BaseCollector): buffer.add(rollout) """ - def collect(self, num_steps: int) -> TensorDict: + def collect(self, num_steps: int | None = None, **kwargs) -> TensorDict: """Collect a synchronous rollout. Args: - num_steps: Number of steps to collect + num_steps: Number of steps to collect (required) Returns: TensorDict with batch_size=[T, N] containing full rollout """ + if num_steps is None: + raise TypeError("SyncCollector.collect() requires num_steps") self.policy.train() current_td = self.obs_tensordict rollout_list = [] From 8f380834bce07f9c197069bb127ecc528d842d66 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Mon, 2 Mar 2026 03:01:11 +0000 Subject: [PATCH 11/14] Fix: correctly compute gae and set batch_size=[N, T] for vla_buffer --- .../agents/rl/vla_example/train_config.json | 1 + embodichain/agents/rl/algo/ppo.py | 10 +- .../agents/rl/buffer/standard_buffer.py | 6 +- embodichain/agents/rl/buffer/vla_buffer.py | 186 +++++++----------- .../agents/rl/collector/async_collector.py | 172 ++++++++-------- .../agents/rl/collector/sync_collector.py | 7 +- embodichain/agents/rl/utils/helper.py | 92 ++++----- embodichain/agents/rl/utils/trainer.py | 42 ++-- 8 files changed, 236 insertions(+), 280 deletions(-) diff --git a/configs/agents/rl/vla_example/train_config.json b/configs/agents/rl/vla_example/train_config.json index 50b9f9c6..87583f38 100644 --- a/configs/agents/rl/vla_example/train_config.json +++ b/configs/agents/rl/vla_example/train_config.json @@ 
-60,6 +60,7 @@ "batch_size": 2048, "gamma": 0.99, "gae_lambda": 0.95, + "rollout_time_first": false, "clip_coef": 0.2, "ent_coef": 0.001, "vf_coef": 0.5, diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index 44c5a483..5aa730ed 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -69,6 +69,7 @@ class PPOCfg(AlgorithmCfg): clip_coef: float = 0.2 ent_coef: float = 0.01 vf_coef: float = 0.5 + rollout_time_first: bool = False class PPO(BaseAlgorithm): @@ -96,9 +97,14 @@ def update(self, rollout: TensorDict) -> dict: if len(rollout.batch_size) == 1: rollout = rollout.unsqueeze(1) # [size] -> [size, 1] - # Compute GAE advantages and returns + # GAE layout: use config (default False = [N, T] batch-first) + time_first = self.cfg.rollout_time_first + rollout = compute_gae( - rollout, gamma=self.cfg.gamma, gae_lambda=self.cfg.gae_lambda + rollout, + gamma=self.cfg.gamma, + gae_lambda=self.cfg.gae_lambda, + time_first=time_first, ) # Flatten to [T*N, ...] for training diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py index eea45bf9..a838e1b2 100644 --- a/embodichain/agents/rl/buffer/standard_buffer.py +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -57,8 +57,8 @@ def get(self, flatten: bool = True) -> TensorDict: """Get rollout from buffer and clear it (standard PPO behavior). Args: - flatten: If True, flatten to [batch_size, ...]. - If False, return as [T, N, ...]. + flatten: If True, flatten to [N*T, ...]. + If False, return as [N, T, ...] (batch-first). Returns: TensorDict with rollout data @@ -72,7 +72,7 @@ def get(self, flatten: bool = True) -> TensorDict: self._rollout = None if flatten: - # Flatten [T, N, ...] -> [T*N, ...] + # Flatten [N, T, ...] -> [N*T, ...] 
return rollout.reshape(-1) else: return rollout diff --git a/embodichain/agents/rl/buffer/vla_buffer.py b/embodichain/agents/rl/buffer/vla_buffer.py index cdab9f56..7555920a 100644 --- a/embodichain/agents/rl/buffer/vla_buffer.py +++ b/embodichain/agents/rl/buffer/vla_buffer.py @@ -23,165 +23,119 @@ class VLABuffer: - """FIFO rollout buffer for VLA RL with pre-allocated TensorDict storage. + """Rollout buffer for VLA RL with (B, T) batch-first layout. - Uses a single pre-allocated TensorDict with circular indexing for efficient - high-frequency transition writes. Designed for async VLA scenarios where - model inference is slow but training is fast. + Stores complete rollouts to ensure correct GAE computation (GAE requires + sequential timesteps within the same trajectory). Async collector accumulates + T steps per env, then adds the full rollout. Key characteristics: - - Pre-allocated memory: Zero-copy writes via direct indexing - - FIFO eviction: Circular buffer automatically overwrites oldest data - - Transition-level storage: Each step is a separate entry - - High-frequency writes: Optimized for async collection (no TensorDict creation overhead) + - Rollout-level storage: Collect full rollout [T, N] before adding + - Batch-first layout: Stores and returns [N, T, ...] for VLA training + - Thread-safe: Async collector writes, main thread reads + - Single rollout: When full, one rollout ready for training - Storage layout: Single TensorDict with shape [buffer_size, ...] + Storage layout: [N, T, ...] - batch (env) first, time second. """ - def __init__(self, buffer_size: int, device: torch.device): - """Initialize VLA buffer with lazy allocation. + def __init__( + self, + buffer_size: int, + device: torch.device, + num_envs: int, + ): + """Initialize VLA buffer. 
Args: - buffer_size: Maximum number of transitions to store - device: Device to store tensors on + buffer_size: Total transitions per rollout (T * N) + device: Device for tensors + num_envs: Number of parallel environments (N) """ self.buffer_size = buffer_size self.device = device - self.buffer: Optional[TensorDict] = None # Lazy init on first add - self.write_pos = 0 # Current write position (circular) - self.size = 0 # Current valid data count - self._total_added = 0 - self._initialized = False - self._lock = threading.Lock() # Thread-safe: main thread reads, collector writes + self.num_envs = num_envs + self.rollout_length = buffer_size // num_envs # T + if self.rollout_length * num_envs != buffer_size: + raise ValueError( + f"buffer_size ({buffer_size}) must be divisible by num_envs ({num_envs})" + ) - def _initialize_buffer(self, template: TensorDict) -> None: - """Initialize buffer structure from first transition template. + self._rollout: Optional[TensorDict] = None # [N, T, ...] + self._lock = threading.Lock() - Args: - template: First transition TensorDict to infer structure from - """ - if self._initialized: - return - - # Pre-allocate buffer with buffer_size - # Template should be a single transition [key: shape] - self.buffer = template.expand(self.buffer_size).clone() - self._initialized = True + def add_rollout(self, rollout: TensorDict) -> None: + """Add a complete rollout. Fixed layout: [N, T] (batch-first). - def add(self, transition: TensorDict) -> None: - """Add a single transition to buffer (high-frequency async writes). + GAE requires same-trajectory timesteps; we only accept full rollouts. Args: - transition: Single transition TensorDict (no batch dimension) + rollout: TensorDict with batch_size=[N, T, ...] 
""" with self._lock: - # Lazy initialization on first add - if not self._initialized: - self._initialize_buffer(transition.to(self.device)) - - # Ensure transition is on correct device - transition = transition.to(self.device) - - # Direct index assignment (zero-copy write) - self.buffer[self.write_pos] = transition - - # Update circular index - self.write_pos = (self.write_pos + 1) % self.buffer_size - - # Update size (saturates at buffer_size) - self.size = min(self.size + 1, self.buffer_size) - self._total_added += 1 + if rollout.batch_size[0] != self.num_envs or rollout.batch_size[1] != self.rollout_length: + raise ValueError( + f"Rollout shape {rollout.batch_size} does not match " + f"expected (N={self.num_envs}, T={self.rollout_length})" + ) + self._rollout = rollout.to(self.device) def add_batch(self, transitions: TensorDict) -> None: - """Add multiple transitions at once (batch write). - - Args: - transitions: Batch of transitions with shape [batch_size, ...] - """ - with self._lock: - batch_size = transitions.batch_size[0] - - # Lazy initialization - if not self._initialized: - self._initialize_buffer(transitions[0].to(self.device)) - - transitions = transitions.to(self.device) - - # Handle circular write - for i in range(batch_size): - self.buffer[self.write_pos] = transitions[i] - self.write_pos = (self.write_pos + 1) % self.buffer_size - self.size = min(self.size + 1, self.buffer_size) - self._total_added += 1 + """Deprecated: Use add_rollout. Batch must be a complete rollout [N, T].""" + if len(transitions.batch_size) >= 2: + self.add_rollout(transitions) + else: + raise NotImplementedError( + "VLABuffer requires full rollout. Use add_rollout(rollout) with [N, T]." + ) def get(self, flatten: bool = True) -> TensorDict: - """Get valid data from buffer (thread-safe). + """Get rollout from buffer (thread-safe). Args: - flatten: If True, return flattened [size, ...]. Currently only supports True. + flatten: If True, flatten to [N*T, ...] 
for minibatch sampling. Returns: - TensorDict with batch_size=[size, ...] containing valid data + TensorDict with batch_size=[N, T] or [N*T] when flatten=True """ with self._lock: - if not self._initialized or self.size == 0: + if self._rollout is None: raise ValueError("Buffer is empty") - if not flatten: - raise NotImplementedError( - "Only flatten=True is supported for VLABuffer" - ) + rollout = self._rollout + self._rollout = None - # Return first 'size' elements (valid data) - # Note: Data is in insertion order up to write_pos, then wraps - if self.size < self.buffer_size: - # Buffer not yet full, data is [0:size] - return self.buffer[: self.size].clone() - else: - # Buffer full, need to rearrange to maintain temporal order - # Oldest data is at write_pos, newest at write_pos-1 - indices = ( - torch.arange( - self.write_pos, - self.write_pos + self.buffer_size, - device=self.device, - ) - % self.buffer_size - ) - return self.buffer[indices].clone() + if flatten: + return rollout.reshape(-1) + return rollout def clear(self) -> None: - """Clear buffer (reset pointers, keep pre-allocated memory).""" + """Clear buffer.""" with self._lock: - self.write_pos = 0 - self.size = 0 - # Keep buffer allocated for reuse + self._rollout = None def __len__(self) -> int: - """Return current number of valid transitions.""" + """Return 1 if has rollout, 0 otherwise.""" with self._lock: - return self.size + return 1 if self._rollout is not None else 0 def is_full(self) -> bool: - """Check if buffer is at full buffer_size.""" + """True when one complete rollout is ready.""" with self._lock: - return self.size >= self.buffer_size + return self._rollout is not None def get_num_rollouts(self) -> int: - """Return 1 (buffer stores transitions, not rollouts).""" + """Return 1 if has rollout, 0 otherwise.""" with self._lock: - return 1 if self.size > 0 else 0 + return 1 if self._rollout is not None else 0 def get_stats(self) -> dict: - """Get buffer statistics for logging.""" + """Get 
buffer statistics.""" with self._lock: - return { - "buffer_size": self.size, - "buffer_capacity": self.buffer_size, - "total_transitions": self.size, - "total_added": self._total_added, - "buffer_usage": ( - self.size / self.buffer_size if self.buffer_size > 0 else 0.0 - ), - "write_pos": self.write_pos, - } + has_data = self._rollout is not None + return { + "buffer_size": self.rollout_length * self.num_envs, + "rollout_length": self.rollout_length, + "num_envs": self.num_envs, + "layout": "batch_first", + "has_rollout": has_data, + } diff --git a/embodichain/agents/rl/collector/async_collector.py b/embodichain/agents/rl/collector/async_collector.py index 063c33a3..f75cc235 100644 --- a/embodichain/agents/rl/collector/async_collector.py +++ b/embodichain/agents/rl/collector/async_collector.py @@ -127,104 +127,94 @@ def get_stats(self) -> dict: } def _collect_loop(self): - """Background thread main loop: continuously collect transitions. - - This method runs in a separate thread and continuously: - 1. Gets action from policy - 2. Steps environment - 3. Constructs transition TensorDict - 4. Adds to buffer (thread-safe) - 5. Updates statistics + """Background thread main loop: collect full rollout, then add to buffer. + + GAE requires sequential timesteps within the same trajectory. We accumulate + T steps (one rollout) locally, then add the complete rollout to buffer. + This ensures correct per-env trajectory ordering for GAE computation. 
""" + rollout_length = self.buffer.rollout_length current_td = self.obs_tensordict while self._running: try: - # Policy forward (no_grad for inference) - with torch.no_grad(): - self.policy.train() # Use stochastic policy - self.policy.forward(current_td) - - # Extract action - action = current_td["action"] - action_type = getattr(self.env, "action_type", "delta_qpos") - action_dict = {action_type: action} - - # Environment step - next_obs_dict, reward, terminated, truncated, env_info = self.env.step( - action_dict - ) - - # Convert observation to TensorDict - next_obs_td = dict_to_tensordict(next_obs_dict, self.device) - done = terminated | truncated - next_obs_for_td = next_obs_td["observation"] - batch_size = next_obs_td.batch_size[0] - - # Build "next" TensorDict - next_td = TensorDict( - { - "observation": next_obs_for_td, - "reward": ( - reward.float().unsqueeze(-1) - if reward.dim() == 1 - else reward.float() - ), - "done": ( - done.bool().unsqueeze(-1) - if done.dim() == 1 - else done.bool() - ), - "terminated": ( - terminated.bool().unsqueeze(-1) - if terminated.dim() == 1 - else terminated.bool() - ), - "truncated": ( - truncated.bool().unsqueeze(-1) - if truncated.dim() == 1 - else truncated.bool() - ), - }, - batch_size=torch.Size([batch_size]), - device=self.device, - ) - - # Compute next value for bootstrapping (GAE computation) - with torch.no_grad(): - next_value_td = TensorDict( - {"observation": next_obs_for_td}, - batch_size=next_td.batch_size, + rollout_list = [] + + for t in range(rollout_length): + # Policy forward (no_grad for inference) + with torch.no_grad(): + self.policy.train() + self.policy.forward(current_td) + + action = current_td["action"] + action_type = getattr(self.env, "action_type", "delta_qpos") + action_dict = {action_type: action} + + next_obs_dict, reward, terminated, truncated, env_info = self.env.step( + action_dict + ) + + next_obs_td = dict_to_tensordict(next_obs_dict, self.device) + done = terminated | truncated + 
next_obs_for_td = next_obs_td["observation"] + batch_size = next_obs_td.batch_size[0] + + next_td = TensorDict( + { + "observation": next_obs_for_td, + "reward": ( + reward.float().unsqueeze(-1) + if reward.dim() == 1 + else reward.float() + ), + "done": ( + done.bool().unsqueeze(-1) + if done.dim() == 1 + else done.bool() + ), + "terminated": ( + terminated.bool().unsqueeze(-1) + if terminated.dim() == 1 + else terminated.bool() + ), + "truncated": ( + truncated.bool().unsqueeze(-1) + if truncated.dim() == 1 + else truncated.bool() + ), + }, + batch_size=torch.Size([batch_size]), device=self.device, ) - self.policy.get_value(next_value_td) - next_td["value"] = next_value_td["value"] - - # Add "next" to current transition - current_td["next"] = next_td - - # Flatten transition for buffer (remove batch dimension for single-step storage) - # Current buffer expects transitions without batch dimension - # We need to add each parallel env's transition separately - for env_idx in range(batch_size): - transition = current_td[env_idx] # Extract single env's transition - - # Thread-safe buffer write - with self._lock: - self.buffer.add(transition) - self._step_count += 1 - - # Callback for statistics - if self.on_step_callback is not None: - self.on_step_callback(current_td, env_info) - - # Handle episode termination - if done.any(): - with self._lock: - self._episode_count += done.sum().item() - - # Prepare next observation - current_td = next_obs_td + + with torch.no_grad(): + next_value_td = TensorDict( + {"observation": next_obs_for_td}, + batch_size=next_td.batch_size, + device=self.device, + ) + self.policy.get_value(next_value_td) + next_td["value"] = next_value_td["value"] + + current_td["next"] = next_td + rollout_list.append(current_td.clone()) + + if self.on_step_callback is not None: + self.on_step_callback(current_td, env_info) + + if done.any(): + with self._lock: + self._episode_count += done.sum().item() + + current_td = next_obs_td + + # Stack along dim=1: 
list of [N,...] -> [N, T, ...] (batch-first) + rollout = torch.stack(rollout_list, dim=1) + self.obs_tensordict = current_td + + with self._lock: + self.buffer.add_rollout(rollout) + self._step_count += rollout.batch_size[0] * rollout.batch_size[1] except Exception as e: print(f"[AsyncCollector] Error in collection loop: {e}") diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index b1c41bc0..e2815d3c 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -42,7 +42,7 @@ def collect(self, num_steps: int | None = None, **kwargs) -> TensorDict: num_steps: Number of steps to collect (required) Returns: - TensorDict with batch_size=[T, N] containing full rollout + TensorDict with batch_size=[N, T] (batch-first) containing full rollout """ if num_steps is None: raise TypeError("SyncCollector.collect() requires num_steps") @@ -126,7 +126,6 @@ def collect(self, num_steps: int | None = None, **kwargs) -> TensorDict: # Update observation for next collection self.obs_tensordict = current_td - # Stack into [T, N, ...] TensorDict - rollout = torch.stack(rollout_list, dim=0) - + # Stack along dim=1: list of [N,...] -> [N, T, ...] (batch-first) + rollout = torch.stack(rollout_list, dim=1) return rollout diff --git a/embodichain/agents/rl/utils/helper.py b/embodichain/agents/rl/utils/helper.py index 17919144..1f0d46ee 100644 --- a/embodichain/agents/rl/utils/helper.py +++ b/embodichain/agents/rl/utils/helper.py @@ -115,66 +115,68 @@ def compute_gae( rollout: TensorDict, gamma: float, gae_lambda: float, + time_first: bool = True, ) -> TensorDict: """Compute Generalized Advantage Estimation (GAE) on rollout TensorDict. - This follows the TorchRL convention where rollout has shape [T, N, ...]. - Computes advantage and value_target in-place and returns the modified TensorDict. + Supports two layouts: + - time_first=True (default): [T, N, ...] 
- TorchRL convention + - time_first=False: [N, T, ...] - batch-first, matches VLA training convention + + GAE requires sequential timesteps within the same trajectory. Both layouts + ensure correct per-env trajectory ordering. Args: - rollout: TensorDict with batch_size=[T, N] containing: - - "value": Tensor[T, N, 1] - state values - - "next": TensorDict with: - - "reward": Tensor[T, N, 1] - - "done": Tensor[T, N, 1] - - "value": Tensor[T, N, 1] - next state values (bootstrapped) + rollout: TensorDict with batch_size=[T, N] or [N, T] containing: + - "value": state values + - "next": TensorDict with "reward", "done", "value" (bootstrapped) gamma: Discount factor gae_lambda: GAE lambda parameter + time_first: If True, rollout is [T, N]; if False, rollout is [N, T] Returns: - TensorDict with added keys: - - "advantage": Tensor[T, N, 1] - - "value_target": Tensor[T, N, 1] + TensorDict with added keys: "advantage", "value_target" """ - T, N = rollout.batch_size[:2] device = rollout.device - # Extract tensors - shape [T, N, 1] - values = rollout["value"] - rewards = rollout["next"]["reward"] - dones = rollout["next"]["done"].float() - - # Bootstrap values: use next state value from rollout["next"]["value"] - # This is computed during collection by evaluating policy on next_obs - if "value" in rollout["next"]: - bootstrap_values = rollout["next"]["value"] + if time_first: + # [T, N, ...] 
+ T, N = rollout.batch_size[:2] + values = rollout["value"] + rewards = rollout["next"]["reward"] + dones = rollout["next"]["done"].float() + if "value" in rollout["next"]: + bootstrap_values = rollout["next"]["value"] + else: + bootstrap_values = torch.zeros_like(values) + + advantages = torch.zeros_like(values) + gae = torch.zeros(N, 1, device=device) + + for t in reversed(range(T)): + delta = rewards[t] + gamma * bootstrap_values[t] * (1.0 - dones[t]) - values[t] + gae = delta + gamma * gae_lambda * (1.0 - dones[t]) * gae + advantages[t] = gae else: - # If not provided, assume 0 (terminal state) - bootstrap_values = torch.zeros_like(values) - - # Compute GAE advantages using backward iteration - # advantage[t] = delta[t] + (gamma * gae_lambda) * (1 - done[t]) * advantage[t+1] - # where delta[t] = reward[t] + gamma * (1 - done[t]) * V(s_{t+1}) - V(s_t) - # V(s_{t+1}) comes from bootstrap_values[t] which was computed on next_obs[t] - - advantages = torch.zeros_like(values) - gae = torch.zeros(N, 1, device=device) + # [N, T, ...] 
- batch-first + N, T = rollout.batch_size[:2] + values = rollout["value"] + rewards = rollout["next"]["reward"] + dones = rollout["next"]["done"].float() + if "value" in rollout["next"]: + bootstrap_values = rollout["next"]["value"] + else: + bootstrap_values = torch.zeros_like(values) + + advantages = torch.zeros_like(values) + gae = torch.zeros(N, 1, device=device) + + for t in reversed(range(T)): + delta = rewards[:, t] + gamma * bootstrap_values[:, t] * (1.0 - dones[:, t]) - values[:, t] + gae = delta + gamma * gae_lambda * (1.0 - dones[:, t]) * gae + advantages[:, t] = gae - # Iterate backwards through time - for t in reversed(range(T)): - # Compute TD error (delta) - # bootstrap_values[t] is V(s_{t+1}), the value of the next state after action at t - delta = rewards[t] + gamma * bootstrap_values[t] * (1.0 - dones[t]) - values[t] - - # Compute GAE recursively - gae = delta + gamma * gae_lambda * (1.0 - dones[t]) * gae - advantages[t] = gae - - # Compute value targets (for value function loss) value_targets = advantages + values - - # Add to rollout TensorDict (in-place) rollout["advantage"] = advantages rollout["value_target"] = value_targets - return rollout diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index b6ebccb6..4814173f 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -76,11 +76,18 @@ def __init__( else torch.device("cuda" if torch.cuda.is_available() else "cpu") ) + # Initialize observation and get num_envs (needed for VLA buffer) + obs, _ = env.reset() + self.obs_tensordict = dict_to_tensordict(obs, device) + num_envs = self.obs_tensordict.batch_size[0] + if model_type == "vla": - # VLA model: accumulate multiple rollouts with FIFO buffer + # VLA model: rollout-level buffer with (B,T) layout for correct GAE from embodichain.agents.rl.buffer import VLABuffer - self.buffer = VLABuffer(buffer_size=buffer_size, device=device) + self.buffer = VLABuffer( + 
buffer_size=buffer_size, device=device, num_envs=num_envs + ) elif model_type == "standard": # Standard PPO model: single rollout, use and discard from embodichain.agents.rl.buffer import RolloutBuffer @@ -104,11 +111,6 @@ def __init__( self.len_window = deque(maxlen=100) self._stats_lock = threading.Lock() # Protects curr_ret, curr_len, ret_window, len_window (async mode) - # Initialize observation - will be used by collectors - obs, _ = self.env.reset() - self.obs_tensordict = dict_to_tensordict(obs, self.device) - num_envs = self.obs_tensordict.batch_size[0] - # Episode stats tracked on device to avoid repeated CPU round-trips self.curr_ret = torch.zeros(num_envs, dtype=torch.float32, device=self.device) self.curr_len = torch.zeros(num_envs, dtype=torch.int32, device=self.device) @@ -199,16 +201,16 @@ def _train_sync(self, collector: SyncCollector, total_timesteps: int): # Collect rollout rollout = collector.collect(num_steps=self.buffer_size) - # Update global step (main thread only) - num_steps = rollout.batch_size[0] # T dimension - num_envs = rollout.batch_size[1] if len(rollout.batch_size) > 1 else 1 - self.global_step += num_steps * num_envs + # Update global step (rollout is [N, T]) + num_envs = rollout.batch_size[0] + num_steps = rollout.batch_size[1] if len(rollout.batch_size) > 1 else 1 + self.global_step += num_envs * num_steps self.buffer.add(rollout) - # Train when buffer is full + # Train when buffer is full (pass [N, T] for correct GAE) if self.buffer.is_full(): - data = self.buffer.get(flatten=True) + data = self.buffer.get(flatten=False) losses = self.algorithm.update(data) self._log_train(losses) @@ -237,13 +239,15 @@ def _train_async(self, collector: AsyncCollector, total_timesteps: int): if not collector.is_running(): raise RuntimeError("Async collector stopped unexpectedly") - # Get data and train - data = self.buffer.get(flatten=True) - self.buffer.clear() # Must clear to avoid retraining on same data + # Get data (flatten=False so PPO gets 
[N, T] for GAE) + data = self.buffer.get(flatten=False) + self.buffer.clear() # get() already clears; clear() is redundant - # Update global step based on collected data (main thread only) - batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0 - self.global_step += batch_size + # Update global step (data is [N, T]) + if len(data.batch_size) >= 2: + self.global_step += data.batch_size[0] * data.batch_size[1] + else: + self.global_step += data.batch_size[0] if data.batch_size else 0 losses = self.algorithm.update(data) self._log_train(losses) From c8140cff441461639e9f4099046304e2a186cb81 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Mon, 9 Mar 2026 16:05:39 +0000 Subject: [PATCH 12/14] Move policy adaptor from dexechain to embodichain --- .../rl/stack_bowls/train_config_grpo.json | 61 ++ embodichain/agents/rl/buffer/vla_buffer.py | 5 +- .../agents/rl/collector/async_collector.py | 13 +- embodichain/agents/rl/collector/base.py | 14 + .../agents/rl/collector/sync_collector.py | 11 +- embodichain/agents/rl/models/vla_policy.py | 757 ++++++++++++++---- embodichain/agents/rl/train.py | 21 + embodichain/agents/rl/utils/helper.py | 14 +- embodichain/agents/rl/utils/trainer.py | 17 +- 9 files changed, 740 insertions(+), 173 deletions(-) create mode 100644 configs/agents/rl/stack_bowls/train_config_grpo.json diff --git a/configs/agents/rl/stack_bowls/train_config_grpo.json b/configs/agents/rl/stack_bowls/train_config_grpo.json new file mode 100644 index 00000000..78e9121f --- /dev/null +++ b/configs/agents/rl/stack_bowls/train_config_grpo.json @@ -0,0 +1,61 @@ +{ + "trainer": { + "exp_name": "stack_bowls_vla_grpo", + "gym_config": "/root/workspace/research/embodichain/configs/gym/stack_bowls/gym_config.json", + "seed": 42, + "device": "cuda:0", + "headless": true, + "enable_rt": false, + "gpu_id": 0, + "num_envs": 8, + "iterations": 1000, + "buffer_size": 64, + "enable_eval": false, + "eval_freq": 0, + "save_freq": 100, + 
"use_wandb": false, + "wandb_project_name": "embodychain-stack_bowls", + "import_modules": [ + "dexechain.lab.gym.envs.tasks.tableware.stack_bowls_v1.stack_bowls" + ], + "model_type": "standard" + }, + "policy": { + "name": "vla", + "action_dim": 14, + "vla_config": { + "model_path": "/root/workspace/output/stack_bowls/checkpoint-19000", + "instruction": "Stack the bowls.", + "inference_horizon": 32, + "action_std_init": 0.01, + "robot_type": "CobotMagic", + "gripper_open_value": 0.05, + "gripper_closed_value": 0.0, + "action_key_order": [ + "left_armdelta_qpos", + "left_eefgripper", + "right_armdelta_qpos", + "right_eefgripper" + ], + "model_config": { + "torch_dtype": "float32" + } + } + }, + "algorithm": { + "name": "grpo", + "cfg": { + "learning_rate": 1e-5, + "n_epochs": 4, + "batch_size": 256, + "gamma": 0.99, + "clip_coef": 0.2, + "ent_coef": 0.001, + "kl_coef": 0.02, + "group_size": 4, + "eps": 1e-8, + "max_grad_norm": 1.0, + "truncate_at_first_done": true + } + } +} diff --git a/embodichain/agents/rl/buffer/vla_buffer.py b/embodichain/agents/rl/buffer/vla_buffer.py index 7555920a..4d254ba0 100644 --- a/embodichain/agents/rl/buffer/vla_buffer.py +++ b/embodichain/agents/rl/buffer/vla_buffer.py @@ -72,7 +72,10 @@ def add_rollout(self, rollout: TensorDict) -> None: rollout: TensorDict with batch_size=[N, T, ...] 
""" with self._lock: - if rollout.batch_size[0] != self.num_envs or rollout.batch_size[1] != self.rollout_length: + if ( + rollout.batch_size[0] != self.num_envs + or rollout.batch_size[1] != self.rollout_length + ): raise ValueError( f"Rollout shape {rollout.batch_size} does not match " f"expected (N={self.num_envs}, T={self.rollout_length})" diff --git a/embodichain/agents/rl/collector/async_collector.py b/embodichain/agents/rl/collector/async_collector.py index f75cc235..8c4b1785 100644 --- a/embodichain/agents/rl/collector/async_collector.py +++ b/embodichain/agents/rl/collector/async_collector.py @@ -146,12 +146,15 @@ def _collect_loop(self): self.policy.train() self.policy.forward(current_td) - action = current_td["action"] - action_type = getattr(self.env, "action_type", "delta_qpos") - action_dict = {action_type: action} + action = ( + current_td["env_action"] + if "env_action" in current_td.keys() + else current_td["action"] + ) + env_action = self._format_env_action(action) - next_obs_dict, reward, terminated, truncated, env_info = self.env.step( - action_dict + next_obs_dict, reward, terminated, truncated, env_info = ( + self.env.step(env_action) ) next_obs_td = dict_to_tensordict(next_obs_dict, self.device) diff --git a/embodichain/agents/rl/collector/base.py b/embodichain/agents/rl/collector/base.py index 07854ddb..9664f2e2 100644 --- a/embodichain/agents/rl/collector/base.py +++ b/embodichain/agents/rl/collector/base.py @@ -49,11 +49,25 @@ def __init__( self.policy = policy self.device = device self.on_step_callback = on_step_callback + if hasattr(self.policy, "bind_env"): + self.policy.bind_env(self.env) # Initialize observation obs_dict, _ = self.env.reset() self.obs_tensordict = dict_to_tensordict(obs_dict, self.device) + def _format_env_action(self, action: torch.Tensor): + """Format policy action for the target environment. + + When an ActionManager is configured, the environment expects a mapping + keyed by the active action term name. 
Otherwise, the environment expects + a raw tensor that is applied directly as joint-space command. + """ + action_manager = getattr(self.env, "action_manager", None) + if action_manager is not None: + return {action_manager.action_type: action} + return action + @abstractmethod def collect(self, num_steps: int | None = None, **kwargs) -> TensorDict: """Collect data from environment. diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index e2815d3c..fca615d8 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -55,13 +55,16 @@ def collect(self, num_steps: int | None = None, **kwargs) -> TensorDict: self.policy.forward(current_td) # Extract action for environment step - action = current_td["action"] - action_type = getattr(self.env, "action_type", "delta_qpos") - action_dict = {action_type: action} + action = ( + current_td["env_action"] + if "env_action" in current_td.keys() + else current_td["action"] + ) + env_action = self._format_env_action(action) # Environment step - returns tuple (env returns dict, not TensorDict) next_obs, reward, terminated, truncated, env_info = self.env.step( - action_dict + env_action ) # Convert env dict observation to TensorDict at boundary diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py index 63bbeeab..616398ec 100644 --- a/embodichain/agents/rl/models/vla_policy.py +++ b/embodichain/agents/rl/models/vla_policy.py @@ -1,5 +1,5 @@ # ---------------------------------------------------------------------------- -# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,128 +14,615 @@ # limitations under the License. 
# ---------------------------------------------------------------------------- -"""VLA Policy for RL training with pretrained models. - -This module provides VLAPolicy that inherits from Policy base class, -just like ActorCritic. VLAPolicy loads pretrained VLA model components -and exposes the same interface as other policies. -""" - from __future__ import annotations +import math +import sys +from pathlib import Path from typing import Optional + import torch import torch.nn as nn +import torch.nn.functional as F from tensordict import TensorDict +from torch.distributions.normal import Normal from .policy import Policy +__all__ = ["VLAPolicy", "load_vla_model", "build_vla_policy"] -class VLAPolicy(Policy): - """VLA Policy that loads pretrained vision-language-action models. - - Similar to ActorCritic, this class inherits from Policy and implements - the required methods. The difference is that VLAPolicy loads pretrained - model components instead of training from scratch. - VLA model components are loaded by the VLA team's implementation and - should provide the necessary interfaces for action generation and value - estimation. - """ +class VLAPolicy(Policy): + """Wrap a pretrained DexForceVLA model with the RL Policy interface.""" def __init__( self, action_dim: int, device: torch.device, vla_model: nn.Module, - ): - """Initialize VLA policy with pretrained model. - - Args: - action_dim: Dimension of action space - device: Device to place policy on - vla_model: Pretrained VLA model (vision encoder, language model, - action head, value head, etc.) 
- """ + instruction: str = "Stack the bowls.", + inference_horizon: int = 32, + action_std_init: float = 0.02, + robot_type: str = "CobotMagic", + gripper_open_value: float = 0.05, + gripper_closed_value: float = 0.0, + action_key_order: Optional[list[str]] = None, + ) -> None: super().__init__() - self.action_dim = action_dim self.device = device - - # Store VLA model - self.vla_model = vla_model - self.vla_model.to(self.device) + self.instruction = instruction + self.inference_horizon = inference_horizon + self.robot_type = robot_type + self.gripper_open_value = gripper_open_value + self.gripper_closed_value = gripper_closed_value + + self.vla_model = vla_model.to(self.device) + self.vla_model.eval() + + self._workspace_root = Path(__file__).resolve().parents[5] + self._dexechain_root = self._workspace_root / "embodichain" + if str(self._workspace_root) not in sys.path: + sys.path.append(str(self._workspace_root)) + if str(self._dexechain_root) not in sys.path: + sys.path.append(str(self._dexechain_root)) + + from dexechain.data.data_engine.indices_unifier import ( # pyright: ignore[reportMissingImports] + ActionIndicesGenerator, + ) + from dexechain.data.enum import ( # pyright: ignore[reportMissingImports] + ActionMode, + ControlParts, + EefNormalizer, + EndEffector, + JointType, + Modality, + ) + from dexechain.data.global_mapping import GlobalMapping # pyright: ignore[reportMissingImports] + from dexechain.lab.gym.utils.gym_utils import ( # pyright: ignore[reportMissingImports] + get_pk_serial_chain_from_robot_type, + ) + from dexechain.lab.gym.utils.misc import ( # pyright: ignore[reportMissingImports] + _data_key_to_control_part, + ) + from dexechain.utils.utility import get_right_name # pyright: ignore[reportMissingImports] + + self.ActionMode = ActionMode + self.ControlParts = ControlParts + self.EefNormalizer = EefNormalizer + self.EndEffector = EndEffector + self.JointType = JointType + self.Modality = Modality + self._data_key_to_control_part = 
_data_key_to_control_part + self._get_right_name = get_right_name + + self.indices_generator = ActionIndicesGenerator(self.vla_model.arm_dofs) + self.global_mapping = GlobalMapping(self.vla_model.arm_dofs) + self.pk_chain = get_pk_serial_chain_from_robot_type(self.robot_type) + + self.state_history_len = int(self.vla_model.state_history_len) + self.img_history_size = int(self.vla_model.img_history_size) + self.state_token_dim = int(self.vla_model.state_token_dim) + self.camera_used = list(getattr(self.vla_model, "camera_used", [])) + self.action_key_order = self._resolve_action_key_order(action_key_order) + self.action_dim = sum( + len(self.indices_generator.get([key])) for key in self.action_key_order + ) + if action_dim != self.action_dim: + raise ValueError( + f"Configured action_dim={action_dim} does not match decoded VLA " + f"action_dim={self.action_dim} for keys {self.action_key_order}." + ) + self.full_action_indices = self.indices_generator.get(self.vla_model.output) + + self.log_std = nn.Parameter( + torch.full( + (self.action_dim,), + float(math.log(max(action_std_init, 1e-6))), + device=self.device, + ) + ) + self.log_std_min = -5.0 + self.log_std_max = 2.0 + critic_input_dim = self.state_history_len * self.state_token_dim + self.value_head = nn.Sequential( + nn.Linear(critic_input_dim, 256), + nn.ReLU(), + nn.Linear(256, 1), + ).to(self.device) + + self._runtime_env = None + self._runtime_robot = None + self._state_history: torch.Tensor | None = None + self._image_history: torch.Tensor | None = None + + def bind_env(self, env) -> None: + self._runtime_env = env + if env is None: + self._runtime_robot = None + return + try: + self._runtime_robot = env.get_wrapper_attr("robot") + except Exception: + self._runtime_robot = None + + def _resolve_action_key_order( + self, action_key_order: Optional[list[str]] + ) -> list[str]: + output_keys = list(self.vla_model.output) + if action_key_order: + return [key for key in action_key_order if key in output_keys] + 
+ preferred_order = [ + self.ControlParts.LEFT_ARM.value + + self.ActionMode.RELATIVE.value + + self.JointType.QPOS.value, + self.ControlParts.LEFT_ARM.value + self.JointType.QPOS.value, + self.ControlParts.LEFT_EEF.value + self.EndEffector.GRIPPER.value, + self.ControlParts.RIGHT_ARM.value + + self.ActionMode.RELATIVE.value + + self.JointType.QPOS.value, + self.ControlParts.RIGHT_ARM.value + self.JointType.QPOS.value, + self.ControlParts.RIGHT_EEF.value + self.EndEffector.GRIPPER.value, + ] + resolved = [key for key in preferred_order if key in output_keys] + if not resolved: + raise ValueError(f"No supported VLA outputs found in {output_keys}") + return resolved + + def _fit_state_value( + self, key: str, value: torch.Tensor | object, dtype: torch.dtype + ) -> torch.Tensor: + tensor = ( + value.to(self.device, dtype=dtype) + if isinstance(value, torch.Tensor) + else torch.as_tensor(value, device=self.device, dtype=dtype) + ) + if tensor.dim() == 1: + tensor = tensor.unsqueeze(0) + + target_width = len(self.global_mapping.get_indices([key])) + if tensor.shape[-1] != target_width: + if target_width == 1: + tensor = tensor.mean(dim=-1, keepdim=True) + elif tensor.shape[-1] > target_width: + tensor = tensor[..., :target_width] + else: + raise ValueError( + f"State '{key}' width {tensor.shape[-1]} cannot fit target width {target_width}." 
+ ) + return tensor + + def _normalize_gripper(self, qpos: torch.Tensor, key: str) -> torch.Tensor: + if self._runtime_robot is not None: + normalized = self.EefNormalizer.normalize_cobotmagic_gripper( + qpos, key, is_action=False, robot=self._runtime_robot + ) + return self._fit_state_value(key, normalized, qpos.dtype).clamp(0.0, 1.0) + + if qpos.dim() >= 2 and qpos.shape[-1] > 1: + qpos = qpos.mean(dim=-1, keepdim=True) + denom = max(self.gripper_open_value - self.gripper_closed_value, 1e-6) + normalized = 1.0 - (qpos - self.gripper_closed_value) / denom + return self._fit_state_value(key, normalized.clamp(0.0, 1.0), qpos.dtype) + + def _resolve_camera_image( + self, sensor_obs: TensorDict, camera_name: str + ) -> torch.Tensor | None: + if camera_name in sensor_obs: + return sensor_obs[camera_name]["color"][..., :3].to(self.device) + + for base_camera_name in sensor_obs.keys(): + if ( + self._get_right_name(base_camera_name) == camera_name + and "color_right" in sensor_obs[base_camera_name] + ): + return sensor_obs[base_camera_name]["color_right"][..., :3].to( + self.device + ) + + return None + + def _resize_camera_image(self, image: torch.Tensor) -> torch.Tensor: + target_size = int(getattr(self.vla_model, "img_size", 0) or 0) + if target_size <= 0: + return image + if image.shape[-3:-1] == (target_size, target_size): + return image + + resized = F.interpolate( + image.permute(0, 3, 1, 2).float(), + size=(target_size, target_size), + mode="bilinear", + align_corners=False, + ) + return resized.permute(0, 2, 3, 1).to(dtype=image.dtype) + + def _extract_current_images(self, observation: TensorDict) -> torch.Tensor: + sensor_obs = observation["sensor"] + images = [] + for camera_name in self.camera_used: + image = self._resolve_camera_image(sensor_obs, camera_name) + if image is None: + raise KeyError(f"Camera '{camera_name}' not found in observation.") + images.append(self._resize_camera_image(image)) + return torch.stack(images, dim=1) + + def _split_qpos(self, 
qpos: torch.Tensor) -> dict[str, torch.Tensor]: + arm_dofs_per_side = self.vla_model.arm_dofs // 2 + eef_dofs_total = qpos.shape[-1] - self.vla_model.arm_dofs + eef_dofs_per_side = max(eef_dofs_total // 2, 0) + + left_arm_end = arm_dofs_per_side + left_eef_end = left_arm_end + eef_dofs_per_side + right_arm_end = left_eef_end + arm_dofs_per_side + + return { + self.ControlParts.LEFT_ARM.value + self.JointType.QPOS.value: qpos[ + :, :left_arm_end + ], + self.ControlParts.LEFT_EEF.value + self.EndEffector.GRIPPER.value: qpos[ + :, left_arm_end:left_eef_end + ], + self.ControlParts.RIGHT_ARM.value + self.JointType.QPOS.value: qpos[ + :, left_eef_end:right_arm_end + ], + self.ControlParts.RIGHT_EEF.value + self.EndEffector.GRIPPER.value: qpos[ + :, right_arm_end: + ], + } + + def _build_state_vector( + self, observation: TensorDict + ) -> tuple[torch.Tensor, torch.Tensor]: + qpos = observation["robot"][self.JointType.QPOS.value].to(self.device) + qpos_chunks = self._split_qpos(qpos) + state_entries: dict[str, torch.Tensor] = {} + + if self._runtime_env is not None and self._runtime_robot is not None: + control_parts = ( + self._runtime_env.metadata.get("dataset", {}) + .get("robot_meta", {}) + .get("control_parts", []) + ) + if not control_parts: + control_parts = [ + self.ControlParts.LEFT_ARM.value, + self.ControlParts.LEFT_EEF.value, + self.ControlParts.RIGHT_ARM.value, + self.ControlParts.RIGHT_EEF.value, + ] + for key in self.vla_model.state_meta: + part = self._data_key_to_control_part( + robot=self._runtime_robot, + control_parts=control_parts, + data_key=key, + ) + if part is None: + continue + indices = self._runtime_robot.get_joint_ids(part, remove_mimic=True) + qpos_data = qpos[:, indices] + if self.EndEffector.GRIPPER.value in key: + state_entries[key] = self._normalize_gripper(qpos_data, key) + else: + normalized = self.EefNormalizer.normalize_eef( + qpos_data, part, robot=self._runtime_robot + ) + state_entries[key] = self._fit_state_value(key, normalized, 
qpos.dtype) + else: + state_entries = { + self.ControlParts.LEFT_ARM.value + + self.JointType.QPOS.value: qpos_chunks[ + self.ControlParts.LEFT_ARM.value + self.JointType.QPOS.value + ], + self.ControlParts.RIGHT_ARM.value + + self.JointType.QPOS.value: qpos_chunks[ + self.ControlParts.RIGHT_ARM.value + self.JointType.QPOS.value + ], + self.ControlParts.LEFT_EEF.value + + self.EndEffector.GRIPPER.value: self._normalize_gripper( + qpos_chunks[ + self.ControlParts.LEFT_EEF.value + self.EndEffector.GRIPPER.value + ], + self.ControlParts.LEFT_EEF.value + self.EndEffector.GRIPPER.value, + ), + self.ControlParts.RIGHT_EEF.value + + self.EndEffector.GRIPPER.value: self._normalize_gripper( + qpos_chunks[ + self.ControlParts.RIGHT_EEF.value + + self.EndEffector.GRIPPER.value + ], + self.ControlParts.RIGHT_EEF.value + self.EndEffector.GRIPPER.value, + ), + } + + if self.pk_chain is not None: + from dexechain.lab.gym.utils.gym_utils import ( # pyright: ignore[reportMissingImports] + map_qpos_to_eef_pose, + ) + + arm_dofs_per_side = self.vla_model.arm_dofs // 2 + arm_qpos = torch.cat( + [ + qpos_chunks[ + self.ControlParts.LEFT_ARM.value + self.JointType.QPOS.value + ], + qpos_chunks[ + self.ControlParts.RIGHT_ARM.value + self.JointType.QPOS.value + ], + ], + dim=-1, + ) + eef_pose_dict = map_qpos_to_eef_pose( + self.pk_chain, + arm_qpos.to("cpu"), + control_parts=[ + self.ControlParts.LEFT_ARM.value, + self.ControlParts.RIGHT_ARM.value, + ], + control_ids=[ + list(range(0, arm_dofs_per_side)), + list(range(arm_dofs_per_side, arm_dofs_per_side * 2)), + ], + ) + eef_pose_dict = { + key: value.to(self.device, dtype=qpos.dtype) + if isinstance(value, torch.Tensor) + else torch.as_tensor(value, device=self.device, dtype=qpos.dtype) + for key, value in eef_pose_dict.items() + } + state_entries.update(eef_pose_dict) + + state_vector = torch.zeros( + (qpos.shape[0], self.state_token_dim), + device=self.device, + dtype=qpos.dtype, + ) + state_indicator = torch.zeros_like(state_vector) 
+ for key in self.vla_model.state_meta: + if key not in state_entries: + continue + indices = self.global_mapping.get_indices([key]) + state_vector[:, indices] = state_entries[key] + state_indicator[:, indices] = 1 + return state_vector, state_indicator + + def _roll_history( + self, + history: torch.Tensor | None, + current: torch.Tensor, + history_len: int, + ) -> torch.Tensor: + if history is None or history.shape[0] != current.shape[0]: + return current.unsqueeze(1).repeat( + [1, history_len] + [1] * (current.dim() - 1) + ) + if history_len == 1: + return current.unsqueeze(1) + return torch.cat([history[:, 1:], current.unsqueeze(1)], dim=1) + + def _build_policy_context( + self, + observation: TensorDict, + update_history: bool, + cached_context: TensorDict | None = None, + ) -> tuple[dict[str, torch.Tensor | list[str]], torch.Tensor, TensorDict]: + current_state, current_state_indicator = self._build_state_vector(observation) + current_images = self._extract_current_images(observation) + + if cached_context is not None: + state_history = cached_context["state_history"].to(self.device) + image_history = cached_context["image_history"].to(self.device) + else: + state_history = self._roll_history( + self._state_history, current_state, self.state_history_len + ) + image_history = self._roll_history( + self._image_history, current_images, self.img_history_size + ) + if update_history: + self._state_history = state_history.detach().clone() + self._image_history = image_history.detach().clone() + + state_indicator = current_state_indicator.unsqueeze(1).repeat( + 1, state_history.shape[1], 1 + ) + action_indicator = torch.zeros( + ( + current_state.shape[0], + self.inference_horizon, + self.state_token_dim, + ), + device=self.device, + dtype=current_state.dtype, + ) + action_indicator[:, :, self.full_action_indices] = 1 + + batch = { + self.Modality.IMAGES.value: image_history, + self.Modality.STATES.value: state_history, + self.Modality.STATE_INDICATOR.value: 
state_indicator, + self.Modality.ACTION_INDICATOR.value: action_indicator, + "instruction": [self.instruction] * current_state.shape[0], + } + critic_input = state_history.reshape(state_history.shape[0], -1).float() + context = TensorDict( + { + "state_history": state_history.detach(), + "image_history": image_history.detach(), + }, + batch_size=[current_state.shape[0]], + device=self.device, + ) + return batch, critic_input, context + + def _predict_chunk_actions( + self, batch: dict[str, torch.Tensor | list[str]] + ) -> torch.Tensor: + self.vla_model.eval() + data = self.vla_model.brain_infer( + batch, + action_mask=batch[self.Modality.ACTION_INDICATOR.value], + precomp_lang_embed=True, + use_fix_aug=False, + ) + data = self.vla_model._compute_priviliges(data) + data = self.vla_model._compute_adaptors(data) + data = self.vla_model.cerebellum(data, None) + + from dexechain.agents.dexforce_vla.models.utils import ( # pyright: ignore[reportMissingImports] + post_process, + ) + + data = post_process( + data, + is_training=False, + **self.vla_model.global_collection, + ) + return data[self.Modality.ACTIONS.value] + + def _decode_first_action( + self, trajectory: torch.Tensor, observation: TensorDict + ) -> torch.Tensor: + first_step = trajectory[:, 0] + current_qpos = observation["robot"][self.JointType.QPOS.value].to(self.device) + qpos_chunks = self._split_qpos(current_qpos) + decoded_parts: list[torch.Tensor] = [] + + for key in self.action_key_order: + indices = self.indices_generator.get([key]) + value = first_step[:, indices] + if ( + self.ActionMode.RELATIVE.value in key + and self.JointType.QPOS.value in key + ): + if key.startswith(self.ControlParts.LEFT_ARM.value): + value = ( + qpos_chunks[ + self.ControlParts.LEFT_ARM.value + self.JointType.QPOS.value + ] + + value + ) + elif key.startswith(self.ControlParts.RIGHT_ARM.value): + value = ( + qpos_chunks[ + self.ControlParts.RIGHT_ARM.value + self.JointType.QPOS.value + ] + + value + ) + elif 
self.EndEffector.GRIPPER.value in key: + value = self.gripper_closed_value + ( + 1.0 - value + ) * (self.gripper_open_value - self.gripper_closed_value) + decoded_parts.append(value) + + if not decoded_parts: + raise ValueError( + f"No action keys could be decoded from model outputs: {self.vla_model.output}" + ) + return torch.cat(decoded_parts, dim=-1).to(self.device) + + def _expand_env_action( + self, action: torch.Tensor, observation: TensorDict + ) -> torch.Tensor: + expanded_parts: list[torch.Tensor] = [] + offset = 0 + for key in self.action_key_order: + width = len(self.indices_generator.get([key])) + value = action[:, offset : offset + width] + offset += width + + if ( + self._runtime_robot is not None + and self.EndEffector.GRIPPER.value in key + and value.shape[-1] == 1 + ): + value = self.EefNormalizer.denormalize_cobotmagic_gripper( + value, key, robot=self._runtime_robot + ) + value = ( + value.to(self.device, dtype=action.dtype) + if isinstance(value, torch.Tensor) + else torch.as_tensor(value, device=self.device, dtype=action.dtype) + ) + if value.dim() == 1: + value = value.unsqueeze(0) + control_part = key.replace(self.EndEffector.GRIPPER.value, "") + target_dim = len( + self._runtime_robot.get_joint_ids( + control_part, remove_mimic=False + ) + ) + if target_dim > value.shape[-1]: + value = value.repeat(1, target_dim) + + expanded_parts.append(value) + + return torch.cat(expanded_parts, dim=-1).to(self.device) + + def _action_stats( + self, + mean_action: torch.Tensor, + deterministic: bool, + provided_action: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) + std = log_std.exp().expand(mean_action.shape[0], -1) + dist = Normal(mean_action, std) + if provided_action is not None: + action = provided_action + elif deterministic: + action = mean_action + else: + action = dist.rsample() + log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True) + 
entropy = dist.entropy().sum(dim=-1, keepdim=True) + return action, log_prob, entropy @torch.no_grad() def forward( self, tensordict: TensorDict, deterministic: bool = False ) -> TensorDict: - """Forward pass: generate action and value from VLA model. - - Args: - tensordict: Must contain "observation" key with observation data - deterministic: If True, use deterministic actions (passed to VLA model) - - Returns: - Same tensordict with added keys: - - "action": Sampled or deterministic action - - "sample_log_prob": Log probability of action - - "value": Value estimate - """ - # VLA team should implement forward logic here - # This is a template - actual implementation depends on VLA model structure - obs = tensordict["observation"] - - # Example: VLA model generates action and value - action, log_prob, value = self.vla_model(obs, deterministic=deterministic) - + observation = tensordict["observation"] + batch, critic_input, context = self._build_policy_context( + observation, update_history=True + ) + trajectory = self._predict_chunk_actions(batch) + mean_action = self._decode_first_action(trajectory, observation) + action, log_prob, _ = self._action_stats(mean_action, deterministic) tensordict["action"] = action + tensordict["env_action"] = self._expand_env_action(action, observation) tensordict["sample_log_prob"] = log_prob - tensordict["value"] = value.squeeze(-1) - + tensordict["value"] = self.value_head(critic_input) + tensordict["policy_context"] = context + tensordict["loc"] = mean_action + tensordict["scale"] = self.log_std.clamp( + self.log_std_min, self.log_std_max + ).exp().expand_as(mean_action) return tensordict @torch.no_grad() def get_value(self, tensordict: TensorDict) -> TensorDict: - """Get value estimate from VLA model. 
- - Args: - tensordict: Must contain "observation" key - - Returns: - Same tensordict with added "value" key - """ - obs = tensordict["observation"] - - # VLA team implements value computation - value = self.vla_model.get_value(obs) - - tensordict["value"] = value.squeeze(-1) + _, critic_input, _ = self._build_policy_context( + tensordict["observation"], update_history=False + ) + tensordict["value"] = self.value_head(critic_input) return tensordict def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: - """Evaluate actions using VLA model. - - Args: - tensordict: Must contain: - - "observation": Observation data - - "action": Actions to evaluate - - Returns: - Same tensordict with added keys: - - "sample_log_prob": Log probability of actions - - "entropy": Entropy of action distribution - - "value": Value estimate - """ - obs = tensordict["observation"] + observation = tensordict["observation"] actions = tensordict["action"] - - # VLA team implements action evaluation - log_prob, entropy, value = self.vla_model.evaluate_actions(obs, actions) - + context = tensordict.get("policy_context", None) + batch, critic_input, _ = self._build_policy_context( + observation, update_history=False, cached_context=context + ) + trajectory = self._predict_chunk_actions(batch) + mean_action = self._decode_first_action(trajectory, observation) + _, log_prob, entropy = self._action_stats( + mean_action, deterministic=False, provided_action=actions + ) tensordict["sample_log_prob"] = log_prob tensordict["entropy"] = entropy - tensordict["value"] = value.squeeze(-1) - + tensordict["value"] = self.value_head(critic_input) return tensordict @@ -145,69 +632,37 @@ def load_vla_model( model_config: Optional[dict] = None, device: torch.device = torch.device("cpu"), ) -> nn.Module: - """Load VLA model from checkpoint. - - This function should be implemented by the VLA team to load their - pretrained VLA model (vision encoder, language model, action head, etc.). 
- - The returned module should have methods: - - forward(obs) -> (action, log_prob, value) - - get_value(obs) -> value - - evaluate_actions(obs, actions) -> (log_prob, entropy, value) - - Args: - model_path: Path to checkpoint file - model_class: Fully qualified class name for VLA model - model_config: Configuration dict for model initialization - device: Device to load model on - - Returns: - Initialized VLA model module - - Example implementation by VLA team: - ```python - def load_vla_model(model_path, model_class, model_config, device): - import importlib - - # Import VLA model class + """Load a pretrained DexForceVLA-compatible model.""" + workspace_root = Path(__file__).resolve().parents[5] + dexechain_root = workspace_root / "embodichain" + if str(workspace_root) not in sys.path: + sys.path.append(str(workspace_root)) + if str(dexechain_root) not in sys.path: + sys.path.append(str(dexechain_root)) + + model_config = {} if model_config is None else dict(model_config) + torch_dtype_name = model_config.pop("torch_dtype", "float32") + weight_dtype = getattr(torch, torch_dtype_name) + + if model_class: module_name, class_name = model_class.rsplit(".", 1) - module = importlib.import_module(module_name) - ModelClass = getattr(module, class_name) + module = __import__(module_name, fromlist=[class_name]) + model_cls = getattr(module, class_name) + return model_cls.from_pretrained(model_path, dtype=weight_dtype).to(device) - # Initialize model - model = ModelClass(**model_config) - - # Load checkpoint - checkpoint = torch.load(model_path, map_location=device) - model.load_state_dict(checkpoint["model_state_dict"]) - - model.to(device) - model.eval() - - return model - ``` - """ - raise NotImplementedError( - "load_vla_model() must be implemented. 
" - f"Model path: {model_path}, class: {model_class}, config: {model_config}" + from dexechain.agents.dexforce_vla.models.dexforcevla_runner import ( # pyright: ignore[reportMissingImports] + DexForceVLA, ) + return DexForceVLA.from_pretrained(model_path, dtype=weight_dtype).to(device) + def build_vla_policy( policy_block: dict, action_dim: int, device: torch.device, ) -> VLAPolicy: - """Build VLA policy from configuration. - - Args: - policy_block: Configuration dict - action_dim: Dimension of action space - device: Device to place policy on - - Returns: - Initialized VLAPolicy instance - """ + """Build a VLAPolicy from configuration.""" vla_config = policy_block.get("vla_config") if vla_config is None: raise ValueError("VLA policy requires 'vla_config' in policy block") @@ -216,23 +671,21 @@ def build_vla_policy( if model_path is None: raise ValueError("VLA config requires 'model_path'") - model_class = vla_config.get("model_class") - model_config = vla_config.get("model_config", {}) - model_config["action_dim"] = action_dim - - # Load VLA model vla_model = load_vla_model( model_path=model_path, - model_class=model_class, - model_config=model_config, + model_class=vla_config.get("model_class"), + model_config=dict(vla_config.get("model_config", {})), device=device, ) - - # Create VLAPolicy instance - policy = VLAPolicy( + return VLAPolicy( action_dim=action_dim, device=device, vla_model=vla_model, + instruction=vla_config.get("instruction", "Stack the bowls."), + inference_horizon=int(vla_config.get("inference_horizon", 32)), + action_std_init=float(vla_config.get("action_std_init", 0.02)), + robot_type=vla_config.get("robot_type", "CobotMagic"), + gripper_open_value=float(vla_config.get("gripper_open_value", 0.05)), + gripper_closed_value=float(vla_config.get("gripper_closed_value", 0.0)), + action_key_order=vla_config.get("action_key_order"), ) - - return policy diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py index 
fda61261..93682940 100644 --- a/embodichain/agents/rl/train.py +++ b/embodichain/agents/rl/train.py @@ -15,7 +15,9 @@ # ---------------------------------------------------------------------------- import argparse +import importlib import os +import sys import time from pathlib import Path @@ -77,6 +79,8 @@ def train_from_config(config_path: str): gpu_id = int(trainer_cfg.get("gpu_id", 0)) num_envs = trainer_cfg.get("num_envs", None) wandb_project_name = trainer_cfg.get("wandb_project_name", "embodychain-generic") + filter_dataset_saving = bool(trainer_cfg.get("filter_dataset_saving", True)) + import_modules = list(trainer_cfg.get("import_modules", [])) # Device if not isinstance(device_str, str): @@ -132,15 +136,30 @@ def train_from_config(config_path: str): if use_wandb: wandb.init(project=wandb_project_name, name=exp_name, config=cfg_json) + workspace_root = Path(__file__).resolve().parents[3] + dexechain_root = workspace_root / "embodichain" + if str(workspace_root) not in sys.path: + sys.path.append(str(workspace_root)) + if str(dexechain_root) not in sys.path: + sys.path.append(str(dexechain_root)) + for module_name in import_modules: + importlib.import_module(module_name) + gym_config_path = Path(trainer_cfg["gym_config"]) logger.log_info(f"Current working directory: {Path.cwd()}") gym_config_data = load_json(str(gym_config_path)) + if filter_dataset_saving: + gym_config_data = deepcopy(gym_config_data) + gym_config_data.get("env", {}).pop("dataset", None) gym_env_cfg = config_to_cfg( gym_config_data, manager_modules=DEFAULT_MANAGER_MODULES ) if num_envs is not None: gym_env_cfg.num_envs = int(num_envs) + gym_env_cfg.filter_dataset_saving = filter_dataset_saving + if filter_dataset_saving: + gym_env_cfg.init_rollout_buffer = False # Ensure sim configuration mirrors runtime overrides if gym_env_cfg.sim_cfg is None: @@ -171,6 +190,8 @@ def train_from_config(config_path: str): eval_gym_env_cfg = deepcopy(gym_env_cfg) eval_gym_env_cfg.num_envs = num_eval_envs 
eval_gym_env_cfg.sim_cfg.headless = True + eval_gym_env_cfg.filter_dataset_saving = True + eval_gym_env_cfg.init_rollout_buffer = False eval_env = build_env(gym_config_data["id"], base_env_cfg=eval_gym_env_cfg) logger.log_info( f"Evaluation environment created (num_envs={num_eval_envs}, headless=True)" diff --git a/embodichain/agents/rl/utils/helper.py b/embodichain/agents/rl/utils/helper.py index d88e391e..113792e1 100644 --- a/embodichain/agents/rl/utils/helper.py +++ b/embodichain/agents/rl/utils/helper.py @@ -23,7 +23,7 @@ from tensordict import TensorDict -def dict_to_tensordict(obs_dict: dict, device: torch.device) -> TensorDict: +def dict_to_tensordict(obs_dict: dict | TensorDict, device: torch.device) -> TensorDict: """Convert a nested observation dict into a TensorDict. Args: @@ -34,10 +34,10 @@ def dict_to_tensordict(obs_dict: dict, device: torch.device) -> TensorDict: TensorDict with an outer ``"observation"`` key. """ - def _recursive_convert(data: dict) -> dict: + def _recursive_convert(data: dict | TensorDict) -> dict: result = {} for key, value in data.items(): - if isinstance(value, dict): + if isinstance(value, (dict, TensorDict)): result[key] = _recursive_convert(value) elif isinstance(value, torch.Tensor): result[key] = value.to(device) @@ -55,6 +55,14 @@ def _get_first_tensor_batch_size(data: dict) -> int | None: return batch_size return None + if isinstance(obs_dict, TensorDict): + obs_td = obs_dict.to(device) + return TensorDict( + {"observation": obs_td}, + batch_size=obs_td.batch_size, + device=device, + ) + converted = _recursive_convert(obs_dict) batch_size = _get_first_tensor_batch_size(converted) if batch_size is None: diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 4b00be88..491040c0 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -322,6 +322,8 @@ def _eval_once(self, num_episodes: int = 5): num_episodes: Number of episodes to evaluate 
""" self.policy.eval() + if hasattr(self.policy, "bind_env"): + self.policy.bind_env(self.eval_env) episode_returns = [] episode_lengths = [] @@ -342,18 +344,17 @@ def _eval_once(self, num_episodes: int = 5): # Get deterministic actions for evaluation obs_copy = obs.clone() self.policy.forward(obs_copy, deterministic=True) - actions = obs_copy["action"] - am = getattr(self.eval_env, "action_manager", None) - action_type = ( - am.action_type - if am - else getattr(self.eval_env, "action_type", "delta_qpos") + actions = ( + obs_copy["env_action"] + if "env_action" in obs_copy.keys() + else obs_copy["action"] ) - action_dict = {action_type: actions} + am = getattr(self.eval_env, "action_manager", None) + env_action = {am.action_type: actions} if am else actions # Environment step - env returns dict, convert to TensorDict at boundary next_obs, reward, terminated, truncated, info = self.eval_env.step( - action_dict + env_action ) next_obs = dict_to_tensordict(next_obs, self.device) obs = next_obs From d8c30e1669a0fcc2b64ff2d7fbfa514af85f9ef9 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Mon, 9 Mar 2026 16:08:27 +0000 Subject: [PATCH 13/14] reformate files --- embodichain/agents/rl/models/vla_policy.py | 64 ++++++++++++---------- 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py index 616398ec..448ec93d 100644 --- a/embodichain/agents/rl/models/vla_policy.py +++ b/embodichain/agents/rl/models/vla_policy.py @@ -77,14 +77,18 @@ def __init__( JointType, Modality, ) - from dexechain.data.global_mapping import GlobalMapping # pyright: ignore[reportMissingImports] + from dexechain.data.global_mapping import ( + GlobalMapping, + ) # pyright: ignore[reportMissingImports] from dexechain.lab.gym.utils.gym_utils import ( # pyright: ignore[reportMissingImports] get_pk_serial_chain_from_robot_type, ) from dexechain.lab.gym.utils.misc import ( # pyright: 
ignore[reportMissingImports] _data_key_to_control_part, ) - from dexechain.utils.utility import get_right_name # pyright: ignore[reportMissingImports] + from dexechain.utils.utility import ( + get_right_name, + ) # pyright: ignore[reportMissingImports] self.ActionMode = ActionMode self.ControlParts = ControlParts @@ -257,18 +261,14 @@ def _split_qpos(self, qpos: torch.Tensor) -> dict[str, torch.Tensor]: right_arm_end = left_eef_end + arm_dofs_per_side return { - self.ControlParts.LEFT_ARM.value + self.JointType.QPOS.value: qpos[ - :, :left_arm_end - ], - self.ControlParts.LEFT_EEF.value + self.EndEffector.GRIPPER.value: qpos[ - :, left_arm_end:left_eef_end - ], - self.ControlParts.RIGHT_ARM.value + self.JointType.QPOS.value: qpos[ - :, left_eef_end:right_arm_end - ], - self.ControlParts.RIGHT_EEF.value + self.EndEffector.GRIPPER.value: qpos[ - :, right_arm_end: - ], + self.ControlParts.LEFT_ARM.value + + self.JointType.QPOS.value: qpos[:, :left_arm_end], + self.ControlParts.LEFT_EEF.value + + self.EndEffector.GRIPPER.value: qpos[:, left_arm_end:left_eef_end], + self.ControlParts.RIGHT_ARM.value + + self.JointType.QPOS.value: qpos[:, left_eef_end:right_arm_end], + self.ControlParts.RIGHT_EEF.value + + self.EndEffector.GRIPPER.value: qpos[:, right_arm_end:], } def _build_state_vector( @@ -307,7 +307,9 @@ def _build_state_vector( normalized = self.EefNormalizer.normalize_eef( qpos_data, part, robot=self._runtime_robot ) - state_entries[key] = self._fit_state_value(key, normalized, qpos.dtype) + state_entries[key] = self._fit_state_value( + key, normalized, qpos.dtype + ) else: state_entries = { self.ControlParts.LEFT_ARM.value @@ -321,7 +323,8 @@ def _build_state_vector( self.ControlParts.LEFT_EEF.value + self.EndEffector.GRIPPER.value: self._normalize_gripper( qpos_chunks[ - self.ControlParts.LEFT_EEF.value + self.EndEffector.GRIPPER.value + self.ControlParts.LEFT_EEF.value + + self.EndEffector.GRIPPER.value ], self.ControlParts.LEFT_EEF.value + 
self.EndEffector.GRIPPER.value, ), @@ -365,9 +368,11 @@ def _build_state_vector( ], ) eef_pose_dict = { - key: value.to(self.device, dtype=qpos.dtype) - if isinstance(value, torch.Tensor) - else torch.as_tensor(value, device=self.device, dtype=qpos.dtype) + key: ( + value.to(self.device, dtype=qpos.dtype) + if isinstance(value, torch.Tensor) + else torch.as_tensor(value, device=self.device, dtype=qpos.dtype) + ) for key, value in eef_pose_dict.items() } state_entries.update(eef_pose_dict) @@ -505,14 +510,15 @@ def _decode_first_action( elif key.startswith(self.ControlParts.RIGHT_ARM.value): value = ( qpos_chunks[ - self.ControlParts.RIGHT_ARM.value + self.JointType.QPOS.value + self.ControlParts.RIGHT_ARM.value + + self.JointType.QPOS.value ] + value ) elif self.EndEffector.GRIPPER.value in key: - value = self.gripper_closed_value + ( - 1.0 - value - ) * (self.gripper_open_value - self.gripper_closed_value) + value = self.gripper_closed_value + (1.0 - value) * ( + self.gripper_open_value - self.gripper_closed_value + ) decoded_parts.append(value) if not decoded_parts: @@ -548,9 +554,7 @@ def _expand_env_action( value = value.unsqueeze(0) control_part = key.replace(self.EndEffector.GRIPPER.value, "") target_dim = len( - self._runtime_robot.get_joint_ids( - control_part, remove_mimic=False - ) + self._runtime_robot.get_joint_ids(control_part, remove_mimic=False) ) if target_dim > value.shape[-1]: value = value.repeat(1, target_dim) @@ -595,9 +599,11 @@ def forward( tensordict["value"] = self.value_head(critic_input) tensordict["policy_context"] = context tensordict["loc"] = mean_action - tensordict["scale"] = self.log_std.clamp( - self.log_std_min, self.log_std_max - ).exp().expand_as(mean_action) + tensordict["scale"] = ( + self.log_std.clamp(self.log_std_min, self.log_std_max) + .exp() + .expand_as(mean_action) + ) return tensordict @torch.no_grad() From 7ac335a931a27b3419c73c3db617f2fafed3afa3 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> 
Date: Mon, 9 Mar 2026 16:28:58 +0000 Subject: [PATCH 14/14] Step 32 instead of 1 --- .../agents/rl/collector/async_collector.py | 2 + .../agents/rl/collector/sync_collector.py | 2 + embodichain/agents/rl/models/vla_policy.py | 198 ++++++++++++++---- embodichain/agents/rl/utils/trainer.py | 6 +- 4 files changed, 164 insertions(+), 44 deletions(-) diff --git a/embodichain/agents/rl/collector/async_collector.py b/embodichain/agents/rl/collector/async_collector.py index 8c4b1785..465accb9 100644 --- a/embodichain/agents/rl/collector/async_collector.py +++ b/embodichain/agents/rl/collector/async_collector.py @@ -160,6 +160,8 @@ def _collect_loop(self): next_obs_td = dict_to_tensordict(next_obs_dict, self.device) done = terminated | truncated next_obs_for_td = next_obs_td["observation"] + if hasattr(self.policy, "reset_envs"): + self.policy.reset_envs(done, next_obs_for_td) batch_size = next_obs_td.batch_size[0] next_td = TensorDict( diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index fca615d8..ecce819a 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -73,6 +73,8 @@ def collect(self, num_steps: int | None = None, **kwargs) -> TensorDict: # Build "next" TensorDict done = terminated | truncated next_obs_for_td = next_obs_td["observation"] + if hasattr(self.policy, "reset_envs"): + self.policy.reset_envs(done, next_obs_for_td) # Ensure batch_size consistency - use next_obs_td's batch_size batch_size = next_obs_td.batch_size[0] diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py index 448ec93d..8f1fd80f 100644 --- a/embodichain/agents/rl/models/vla_policy.py +++ b/embodichain/agents/rl/models/vla_policy.py @@ -77,18 +77,14 @@ def __init__( JointType, Modality, ) - from dexechain.data.global_mapping import ( - GlobalMapping, - ) # pyright: ignore[reportMissingImports] + from 
dexechain.data.global_mapping import GlobalMapping # pyright: ignore[reportMissingImports] from dexechain.lab.gym.utils.gym_utils import ( # pyright: ignore[reportMissingImports] get_pk_serial_chain_from_robot_type, ) from dexechain.lab.gym.utils.misc import ( # pyright: ignore[reportMissingImports] _data_key_to_control_part, ) - from dexechain.utils.utility import ( - get_right_name, - ) # pyright: ignore[reportMissingImports] + from dexechain.utils.utility import get_right_name # pyright: ignore[reportMissingImports] self.ActionMode = ActionMode self.ControlParts = ControlParts @@ -138,6 +134,9 @@ def __init__( self._runtime_robot = None self._state_history: torch.Tensor | None = None self._image_history: torch.Tensor | None = None + self._cached_chunk: torch.Tensor | None = None + self._cached_chunk_context: TensorDict | None = None + self._cached_chunk_step: torch.Tensor | None = None def bind_env(self, env) -> None: self._runtime_env = env @@ -149,6 +148,49 @@ def bind_env(self, env) -> None: except Exception: self._runtime_robot = None + def _reset_chunk_cache(self, env_mask: torch.Tensor | None = None) -> None: + if env_mask is None: + self._cached_chunk = None + self._cached_chunk_context = None + self._cached_chunk_step = None + return + if self._cached_chunk_step is not None: + self._cached_chunk_step[env_mask] = self.inference_horizon + + @torch.no_grad() + def reset_envs( + self, done_mask: torch.Tensor, next_observation: TensorDict | None = None + ) -> None: + if done_mask.dim() > 1: + done_mask = done_mask.squeeze(-1) + done_mask = done_mask.to(device=self.device, dtype=torch.bool) + if not done_mask.any(): + return + + self._reset_chunk_cache(done_mask) + + if next_observation is None: + if self._state_history is not None: + self._state_history[done_mask] = 0 + if self._image_history is not None: + self._image_history[done_mask] = 0 + return + + current_state, _ = self._build_state_vector(next_observation) + current_images = 
self._extract_current_images(next_observation) + reset_state_history = current_state.unsqueeze(1).repeat(1, self.state_history_len, 1) + reset_image_history = current_images.unsqueeze(1).repeat(1, self.img_history_size, 1, 1, 1, 1) + + if self._state_history is None or self._state_history.shape[0] != current_state.shape[0]: + self._state_history = reset_state_history + else: + self._state_history[done_mask] = reset_state_history[done_mask] + + if self._image_history is None or self._image_history.shape[0] != current_images.shape[0]: + self._image_history = reset_image_history + else: + self._image_history[done_mask] = reset_image_history[done_mask] + def _resolve_action_key_order( self, action_key_order: Optional[list[str]] ) -> list[str]: @@ -261,14 +303,18 @@ def _split_qpos(self, qpos: torch.Tensor) -> dict[str, torch.Tensor]: right_arm_end = left_eef_end + arm_dofs_per_side return { - self.ControlParts.LEFT_ARM.value - + self.JointType.QPOS.value: qpos[:, :left_arm_end], - self.ControlParts.LEFT_EEF.value - + self.EndEffector.GRIPPER.value: qpos[:, left_arm_end:left_eef_end], - self.ControlParts.RIGHT_ARM.value - + self.JointType.QPOS.value: qpos[:, left_eef_end:right_arm_end], - self.ControlParts.RIGHT_EEF.value - + self.EndEffector.GRIPPER.value: qpos[:, right_arm_end:], + self.ControlParts.LEFT_ARM.value + self.JointType.QPOS.value: qpos[ + :, :left_arm_end + ], + self.ControlParts.LEFT_EEF.value + self.EndEffector.GRIPPER.value: qpos[ + :, left_arm_end:left_eef_end + ], + self.ControlParts.RIGHT_ARM.value + self.JointType.QPOS.value: qpos[ + :, left_eef_end:right_arm_end + ], + self.ControlParts.RIGHT_EEF.value + self.EndEffector.GRIPPER.value: qpos[ + :, right_arm_end: + ], } def _build_state_vector( @@ -307,9 +353,7 @@ def _build_state_vector( normalized = self.EefNormalizer.normalize_eef( qpos_data, part, robot=self._runtime_robot ) - state_entries[key] = self._fit_state_value( - key, normalized, qpos.dtype - ) + state_entries[key] = 
self._fit_state_value(key, normalized, qpos.dtype) else: state_entries = { self.ControlParts.LEFT_ARM.value @@ -323,8 +367,7 @@ def _build_state_vector( self.ControlParts.LEFT_EEF.value + self.EndEffector.GRIPPER.value: self._normalize_gripper( qpos_chunks[ - self.ControlParts.LEFT_EEF.value - + self.EndEffector.GRIPPER.value + self.ControlParts.LEFT_EEF.value + self.EndEffector.GRIPPER.value ], self.ControlParts.LEFT_EEF.value + self.EndEffector.GRIPPER.value, ), @@ -368,11 +411,9 @@ def _build_state_vector( ], ) eef_pose_dict = { - key: ( - value.to(self.device, dtype=qpos.dtype) - if isinstance(value, torch.Tensor) - else torch.as_tensor(value, device=self.device, dtype=qpos.dtype) - ) + key: value.to(self.device, dtype=qpos.dtype) + if isinstance(value, torch.Tensor) + else torch.as_tensor(value, device=self.device, dtype=qpos.dtype) for key, value in eef_pose_dict.items() } state_entries.update(eef_pose_dict) @@ -460,6 +501,20 @@ def _build_policy_context( ) return batch, critic_input, context + def _slice_batch( + self, batch: dict[str, torch.Tensor | list[str]], mask: torch.Tensor + ) -> dict[str, torch.Tensor | list[str]]: + mask_list = mask.tolist() + sliced: dict[str, torch.Tensor | list[str]] = {} + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + sliced[key] = value[mask] + elif isinstance(value, list): + sliced[key] = [item for item, keep in zip(value, mask_list) if keep] + else: + sliced[key] = value + return sliced + def _predict_chunk_actions( self, batch: dict[str, torch.Tensor | list[str]] ) -> torch.Tensor: @@ -485,17 +540,16 @@ def _predict_chunk_actions( ) return data[self.Modality.ACTIONS.value] - def _decode_first_action( - self, trajectory: torch.Tensor, observation: TensorDict + def _decode_action_step( + self, step_action: torch.Tensor, observation: TensorDict ) -> torch.Tensor: - first_step = trajectory[:, 0] current_qpos = observation["robot"][self.JointType.QPOS.value].to(self.device) qpos_chunks = 
self._split_qpos(current_qpos) decoded_parts: list[torch.Tensor] = [] for key in self.action_key_order: indices = self.indices_generator.get([key]) - value = first_step[:, indices] + value = step_action[:, indices] if ( self.ActionMode.RELATIVE.value in key and self.JointType.QPOS.value in key @@ -510,15 +564,14 @@ def _decode_first_action( elif key.startswith(self.ControlParts.RIGHT_ARM.value): value = ( qpos_chunks[ - self.ControlParts.RIGHT_ARM.value - + self.JointType.QPOS.value + self.ControlParts.RIGHT_ARM.value + self.JointType.QPOS.value ] + value ) elif self.EndEffector.GRIPPER.value in key: - value = self.gripper_closed_value + (1.0 - value) * ( - self.gripper_open_value - self.gripper_closed_value - ) + value = self.gripper_closed_value + ( + 1.0 - value + ) * (self.gripper_open_value - self.gripper_closed_value) decoded_parts.append(value) if not decoded_parts: @@ -527,6 +580,11 @@ def _decode_first_action( ) return torch.cat(decoded_parts, dim=-1).to(self.device) + def _decode_first_action( + self, trajectory: torch.Tensor, observation: TensorDict + ) -> torch.Tensor: + return self._decode_action_step(trajectory[:, 0], observation) + def _expand_env_action( self, action: torch.Tensor, observation: TensorDict ) -> torch.Tensor: @@ -554,7 +612,9 @@ def _expand_env_action( value = value.unsqueeze(0) control_part = key.replace(self.EndEffector.GRIPPER.value, "") target_dim = len( - self._runtime_robot.get_joint_ids(control_part, remove_mimic=False) + self._runtime_robot.get_joint_ids( + control_part, remove_mimic=False + ) ) if target_dim > value.shape[-1]: value = value.repeat(1, target_dim) @@ -590,20 +650,65 @@ def forward( batch, critic_input, context = self._build_policy_context( observation, update_history=True ) - trajectory = self._predict_chunk_actions(batch) - mean_action = self._decode_first_action(trajectory, observation) + batch_size = observation.batch_size[0] + + if ( + self._cached_chunk is None + or self._cached_chunk_context is None + or 
self._cached_chunk_step is None + or self._cached_chunk.shape[0] != batch_size + ): + self._cached_chunk = None + self._cached_chunk_context = None + self._cached_chunk_step = torch.full( + (batch_size,), + self.inference_horizon, + device=self.device, + dtype=torch.long, + ) + + refresh_mask = self._cached_chunk_step >= self.inference_horizon + if refresh_mask.any(): + refresh_batch = self._slice_batch(batch, refresh_mask) + refresh_trajectory = self._predict_chunk_actions(refresh_batch) + refresh_context = context[refresh_mask] + + if self._cached_chunk is None: + chunk_shape = (batch_size,) + tuple(refresh_trajectory.shape[1:]) + self._cached_chunk = torch.zeros( + chunk_shape, + device=refresh_trajectory.device, + dtype=refresh_trajectory.dtype, + ) + if self._cached_chunk_context is None: + self._cached_chunk_context = context.clone() + + self._cached_chunk[refresh_mask] = refresh_trajectory + self._cached_chunk_context["state_history"][refresh_mask] = refresh_context[ + "state_history" + ] + self._cached_chunk_context["image_history"][refresh_mask] = refresh_context[ + "image_history" + ] + self._cached_chunk_step[refresh_mask] = 0 + + step_indices = self._cached_chunk_step.clone() + raw_step_actions = self._cached_chunk[ + torch.arange(batch_size, device=self.device), step_indices + ] + mean_action = self._decode_action_step(raw_step_actions, observation) action, log_prob, _ = self._action_stats(mean_action, deterministic) tensordict["action"] = action tensordict["env_action"] = self._expand_env_action(action, observation) tensordict["sample_log_prob"] = log_prob tensordict["value"] = self.value_head(critic_input) - tensordict["policy_context"] = context + tensordict["policy_context"] = self._cached_chunk_context.clone() + tensordict["chunk_step_idx"] = step_indices.unsqueeze(-1) tensordict["loc"] = mean_action - tensordict["scale"] = ( - self.log_std.clamp(self.log_std_min, self.log_std_max) - .exp() - .expand_as(mean_action) - ) + tensordict["scale"] = 
self.log_std.clamp( + self.log_std_min, self.log_std_max + ).exp().expand_as(mean_action) + self._cached_chunk_step = self._cached_chunk_step + 1 return tensordict @torch.no_grad() @@ -622,7 +727,14 @@ def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: observation, update_history=False, cached_context=context ) trajectory = self._predict_chunk_actions(batch) - mean_action = self._decode_first_action(trajectory, observation) + if "chunk_step_idx" in tensordict.keys(): + step_idx = tensordict["chunk_step_idx"].squeeze(-1).long() + step_action = trajectory[ + torch.arange(trajectory.shape[0], device=trajectory.device), step_idx + ] + mean_action = self._decode_action_step(step_action, observation) + else: + mean_action = self._decode_first_action(trajectory, observation) _, log_prob, entropy = self._action_stats( mean_action, deterministic=False, provided_action=actions ) diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 491040c0..47e76e2f 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -324,6 +324,8 @@ def _eval_once(self, num_episodes: int = 5): self.policy.eval() if hasattr(self.policy, "bind_env"): self.policy.bind_env(self.eval_env) + if hasattr(self.policy, "_reset_chunk_cache"): + self.policy._reset_chunk_cache() episode_returns = [] episode_lengths = [] @@ -357,10 +359,12 @@ def _eval_once(self, num_episodes: int = 5): env_action ) next_obs = dict_to_tensordict(next_obs, self.device) + done = terminated | truncated + if hasattr(self.policy, "reset_envs"): + self.policy.reset_envs(done, next_obs["observation"]) obs = next_obs # Update statistics only for still-running environments - done = terminated | truncated still_running = ~done_mask cumulative_reward[still_running] += reward[still_running].float() step_count[still_running] += 1