From b9cfd390fa2a16213fae14c47bc508dad555d56d Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 05:14:56 +0000 Subject: [PATCH 01/23] Feat: seperate buffer and collector --- .../rl/basic/cart_pole/train_config.json | 3 +- .../rl/basic/cart_pole/train_config_grpo.json | 3 +- configs/agents/rl/push_cube/train_config.json | 3 +- embodichain/agents/rl/algo/base.py | 26 +-- embodichain/agents/rl/algo/grpo.py | 186 ++++++------------ embodichain/agents/rl/algo/ppo.py | 151 ++++---------- embodichain/agents/rl/buffer/__init__.py | 2 +- .../agents/rl/buffer/rollout_buffer.py | 89 +-------- .../agents/rl/buffer/standard_buffer.py | 54 +++++ embodichain/agents/rl/collector/__init__.py | 20 ++ embodichain/agents/rl/collector/base.py | 37 ++++ .../agents/rl/collector/sync_collector.py | 118 +++++++++++ embodichain/agents/rl/models/__init__.py | 22 ++- embodichain/agents/rl/models/actor_critic.py | 64 +++--- embodichain/agents/rl/models/actor_only.py | 65 +++--- embodichain/agents/rl/models/policy.py | 42 ++-- embodichain/agents/rl/train.py | 41 ++-- embodichain/agents/rl/utils/__init__.py | 4 +- embodichain/agents/rl/utils/helper.py | 110 +++++++++-- embodichain/agents/rl/utils/trainer.py | 65 +++--- tests/agents/test_rl.py | 2 +- 21 files changed, 595 insertions(+), 512 deletions(-) create mode 100644 embodichain/agents/rl/buffer/standard_buffer.py create mode 100644 embodichain/agents/rl/collector/__init__.py create mode 100644 embodichain/agents/rl/collector/base.py create mode 100644 embodichain/agents/rl/collector/sync_collector.py diff --git a/configs/agents/rl/basic/cart_pole/train_config.json b/configs/agents/rl/basic/cart_pole/train_config.json index 06031b5e..f4e99372 100644 --- a/configs/agents/rl/basic/cart_pole/train_config.json +++ b/configs/agents/rl/basic/cart_pole/train_config.json @@ -9,7 +9,7 @@ "gpu_id": 0, "num_envs": 64, "iterations": 1000, - "rollout_steps": 1024, + "buffer_size": 1024, "eval_freq": 200, 
"save_freq": 200, "use_wandb": false, @@ -35,6 +35,7 @@ }, "policy": { "name": "actor_critic", + "action_dim": 2, "actor": { "type": "mlp", "network_cfg": { diff --git a/configs/agents/rl/basic/cart_pole/train_config_grpo.json b/configs/agents/rl/basic/cart_pole/train_config_grpo.json index 6625afe1..1caf6b0d 100644 --- a/configs/agents/rl/basic/cart_pole/train_config_grpo.json +++ b/configs/agents/rl/basic/cart_pole/train_config_grpo.json @@ -9,7 +9,7 @@ "gpu_id": 0, "num_envs": 64, "iterations": 1000, - "rollout_steps": 1024, + "buffer_size": 1024, "eval_freq": 200, "save_freq": 200, "use_wandb": true, @@ -36,6 +36,7 @@ }, "policy": { "name": "actor_only", + "action_dim": 2, "actor": { "type": "mlp", "network_cfg": { diff --git a/configs/agents/rl/push_cube/train_config.json b/configs/agents/rl/push_cube/train_config.json index b58e2c63..c79776b6 100644 --- a/configs/agents/rl/push_cube/train_config.json +++ b/configs/agents/rl/push_cube/train_config.json @@ -9,7 +9,7 @@ "gpu_id": 0, "num_envs": 64, "iterations": 1000, - "rollout_steps": 1024, + "buffer_size": 1024, "enable_eval": true, "num_eval_envs": 16, "num_eval_episodes": 3, @@ -38,6 +38,7 @@ }, "policy": { "name": "actor_critic", + "action_dim": 8, "actor": { "type": "mlp", "network_cfg": { diff --git a/embodichain/agents/rl/algo/base.py b/embodichain/agents/rl/algo/base.py index fcb3fc00..b1516472 100644 --- a/embodichain/agents/rl/algo/base.py +++ b/embodichain/agents/rl/algo/base.py @@ -16,37 +16,19 @@ from __future__ import annotations -from typing import Dict, Any, Callable +from typing import Dict import torch +from tensordict import TensorDict class BaseAlgorithm: """Base class for RL algorithms. - Algorithms must implement buffer initialization, rollout collection, and - policy update. Trainer depends only on this interface to remain - algorithm-agnostic. + Algorithms only implement policy updates over collected rollouts. 
""" device: torch.device - def initialize_buffer( - self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int - ) -> None: - """Initialize internal buffer(s) required by the algorithm.""" - raise NotImplementedError - - def collect_rollout( - self, - env, - policy, - obs: torch.Tensor, - num_steps: int, - on_step_callback: Callable | None = None, - ) -> Dict[str, Any]: - """Collect trajectories and return logging info (e.g., reward components).""" - raise NotImplementedError - - def update(self) -> Dict[str, float]: + def update(self, rollout: TensorDict) -> Dict[str, float]: """Update policy using collected data and return training losses.""" raise NotImplementedError diff --git a/embodichain/agents/rl/algo/grpo.py b/embodichain/agents/rl/algo/grpo.py index 4654ed54..4cef287c 100644 --- a/embodichain/agents/rl/algo/grpo.py +++ b/embodichain/agents/rl/algo/grpo.py @@ -16,14 +16,14 @@ from __future__ import annotations +import math from copy import deepcopy -from typing import Any, Callable, Dict +from typing import Dict import torch from tensordict import TensorDict -from embodichain.agents.rl.buffer import RolloutBuffer -from embodichain.agents.rl.utils import AlgorithmCfg, flatten_dict_observation +from embodichain.agents.rl.utils import AlgorithmCfg from embodichain.utils import configclass from .base import BaseAlgorithm @@ -38,15 +38,12 @@ class GRPOCfg(AlgorithmCfg): kl_coef: float = 0.02 group_size: int = 4 eps: float = 1e-8 - # Collect fresh groups every rollout instead of continuing from prior states. reset_every_rollout: bool = True - # If True, do not optimize steps after the first done in each environment - # during a rollout. This better matches "one completion per prompt". 
truncate_at_first_done: bool = True class GRPO(BaseAlgorithm): - """Group Relative Policy Optimization on top of RolloutBuffer.""" + """Group Relative Policy Optimization on top of TensorDict rollouts.""" def __init__(self, cfg: GRPOCfg, policy): if cfg.group_size < 2: @@ -57,9 +54,6 @@ def __init__(self, cfg: GRPOCfg, policy): self.policy = policy self.device = torch.device(cfg.device) self.optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.learning_rate) - self.buffer: RolloutBuffer | None = None - # Only create ref_policy when kl_coef > 0 (e.g. VLA fine-tuning). - # For from-scratch training (CartPole etc.), kl_coef=0 avoids the "tight band" problem. if self.cfg.kl_coef > 0.0: self.ref_policy = deepcopy(policy).to(self.device).eval() for param in self.ref_policy.parameters(): @@ -67,142 +61,75 @@ def __init__(self, cfg: GRPOCfg, policy): else: self.ref_policy = None - def initialize_buffer( - self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int - ) -> None: - if num_envs % self.cfg.group_size != 0: - raise ValueError( - f"GRPO requires num_envs divisible by group_size, got " - f"num_envs={num_envs}, group_size={self.cfg.group_size}." - ) - self.buffer = RolloutBuffer( - num_steps, num_envs, obs_dim, action_dim, self.device - ) - def _compute_step_returns_and_mask( self, rewards: torch.Tensor, dones: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute step-wise discounted returns R_t = r_t + gamma * R_{t+1} and mask. - - Solves causal + discount bias: each step's return only depends on future rewards. - Returns: - step_returns: shape [T, N], discounted return from step t onward. - seq_mask: shape [T, N], 1 for valid steps, 0 after first done (if truncate). 
- """ - t_steps, n_envs = rewards.shape + """Compute discounted returns and valid-step mask over `[N, T]` rollout.""" + n_envs, t_steps = rewards.shape seq_mask = torch.ones( - (t_steps, n_envs), dtype=torch.float32, device=self.device + (n_envs, t_steps), dtype=torch.float32, device=self.device ) step_returns = torch.zeros( - (t_steps, n_envs), dtype=torch.float32, device=self.device + (n_envs, t_steps), dtype=torch.float32, device=self.device ) alive = torch.ones(n_envs, dtype=torch.float32, device=self.device) for t in range(t_steps): - seq_mask[t] = alive + seq_mask[:, t] = alive if self.cfg.truncate_at_first_done: - alive = alive * (~dones[t]).float() + alive = alive * (~dones[:, t]).float() running_return = torch.zeros(n_envs, dtype=torch.float32, device=self.device) for t in reversed(range(t_steps)): running_return = ( - rewards[t] + self.cfg.gamma * running_return * (~dones[t]).float() + rewards[:, t] + self.cfg.gamma * running_return * (~dones[:, t]).float() ) - step_returns[t] = running_return + step_returns[:, t] = running_return return step_returns, seq_mask def _compute_step_group_advantages( self, step_returns: torch.Tensor, seq_mask: torch.Tensor ) -> torch.Tensor: - """Per-step group normalization with masked mean/std for variable-length sequences. - - When group members have different survival lengths, only compare against - peers still alive at that step (avoids dead envs' zeros dragging down the mean). 
- """ - t_steps, n_envs = step_returns.shape + """Normalize per-step returns within each environment group.""" + n_envs, t_steps = step_returns.shape group_size = self.cfg.group_size - returns_grouped = step_returns.view(t_steps, n_envs // group_size, group_size) - mask_grouped = seq_mask.view(t_steps, n_envs // group_size, group_size) + returns_grouped = step_returns.view(n_envs // group_size, group_size, t_steps) + mask_grouped = seq_mask.view(n_envs // group_size, group_size, t_steps) - valid_count = mask_grouped.sum(dim=2, keepdim=True) + valid_count = mask_grouped.sum(dim=1, keepdim=True) valid_count_safe = torch.clamp(valid_count, min=1.0) group_mean = (returns_grouped * mask_grouped).sum( - dim=2, keepdim=True + dim=1, keepdim=True ) / valid_count_safe diff_sq = ((returns_grouped - group_mean) ** 2) * mask_grouped - group_var = diff_sq.sum(dim=2, keepdim=True) / valid_count_safe + group_var = diff_sq.sum(dim=1, keepdim=True) / valid_count_safe group_std = torch.sqrt(group_var) - adv = (returns_grouped - group_mean) / (group_std + self.cfg.eps) - adv = adv.view(t_steps, n_envs) * seq_mask - return adv - - def collect_rollout( - self, - env, - policy, - obs: torch.Tensor, - num_steps: int, - on_step_callback: Callable | None = None, - ) -> Dict[str, Any]: - if self.buffer is None: - raise RuntimeError( - "Buffer not initialized. Call initialize_buffer() first." 
- ) - - policy.train() - self.buffer.step = 0 - current_obs = obs + advantages = (returns_grouped - group_mean) / (group_std + self.cfg.eps) + return advantages.view(n_envs, t_steps) * seq_mask - if self.cfg.reset_every_rollout: - current_obs, _ = env.reset() - if isinstance(current_obs, TensorDict): - current_obs = flatten_dict_observation(current_obs) - - for _ in range(num_steps): - actions, log_prob, _ = policy.get_action(current_obs, deterministic=False) - am = getattr(env, "action_manager", None) - action_type = ( - am.action_type if am else getattr(env, "action_type", "delta_qpos") - ) - action_dict = {action_type: actions} - next_obs, reward, terminated, truncated, env_info = env.step(action_dict) - done = (terminated | truncated).bool() - reward = reward.float() - - if isinstance(next_obs, TensorDict): - next_obs = flatten_dict_observation(next_obs) - - # GRPO does not use value function targets; store zeros in value slot. - value_placeholder = torch.zeros_like(reward) - self.buffer.add( - current_obs, actions, reward, done, value_placeholder, log_prob + def update(self, rollout: TensorDict) -> Dict[str, float]: + rollout = rollout.clone() + num_envs = rollout.batch_size[0] + if num_envs % self.cfg.group_size != 0: + raise ValueError( + f"GRPO requires num_envs divisible by group_size, got " + f"num_envs={num_envs}, group_size={self.cfg.group_size}." 
) - if on_step_callback is not None: - on_step_callback(current_obs, actions, reward, done, env_info, next_obs) - current_obs = next_obs - - step_returns, seq_mask = self._compute_step_returns_and_mask( - self.buffer.rewards, self.buffer.dones + rewards = rollout["next", "reward"].float() + dones = rollout["next", "done"].bool() + step_returns, seq_mask = self._compute_step_returns_and_mask(rewards, dones) + rollout["advantage"] = self._compute_step_group_advantages( + step_returns, seq_mask ) - advantages = self._compute_step_group_advantages(step_returns, seq_mask) - - self.buffer.set_extras( - { - "advantages": advantages, - "seq_mask": seq_mask, - "seq_return": step_returns, - } - ) - return {} + rollout["seq_mask"] = seq_mask + rollout["seq_return"] = step_returns - def update(self) -> Dict[str, float]: - if self.buffer is None: - raise RuntimeError("Buffer not initialized. Call collect_rollout() first.") + flat_rollout = rollout.reshape(math.prod(rollout.batch_size)) total_actor_loss = 0.0 total_entropy = 0.0 @@ -210,14 +137,14 @@ def update(self) -> Dict[str, float]: total_weight = 0.0 for _ in range(self.cfg.n_epochs): - for batch in self.buffer.iterate_minibatches(self.cfg.batch_size): - obs = batch["obs"] - actions = batch["actions"] - old_logprobs = batch["logprobs"] - advantages = batch["advantages"].detach() - seq_mask = batch["seq_mask"].float() - - logprobs, entropy, _ = self.policy.evaluate_actions(obs, actions) + for batch in self._iterate_minibatches(flat_rollout, self.cfg.batch_size): + old_logprobs = batch["sample_log_prob"].clone() + advantages = batch["advantage"].detach() + seq_mask_batch = batch["seq_mask"].float() + + eval_batch = self.policy.evaluate_actions(batch.clone()) + logprobs = eval_batch["sample_log_prob"] + entropy = eval_batch["entropy"] ratio = (logprobs - old_logprobs).exp() surr1 = ratio * advantages surr2 = ( @@ -226,20 +153,19 @@ def update(self) -> Dict[str, float]: ) * advantages ) - actor_num = -(torch.min(surr1, surr2) 
* seq_mask).sum() - denom = torch.clamp(seq_mask.sum(), min=1.0) + actor_num = -(torch.min(surr1, surr2) * seq_mask_batch).sum() + denom = torch.clamp(seq_mask_batch.sum(), min=1.0) actor_loss = actor_num / denom - entropy_loss = -(entropy * seq_mask).sum() / denom + entropy_loss = -(entropy * seq_mask_batch).sum() / denom if self.ref_policy is not None: with torch.no_grad(): - ref_logprobs, _, _ = self.ref_policy.evaluate_actions( - obs, actions - ) + ref_batch = self.ref_policy.evaluate_actions(batch.clone()) + ref_logprobs = ref_batch["sample_log_prob"] log_ref_over_pi = ref_logprobs - logprobs kl_per = torch.exp(log_ref_over_pi) - log_ref_over_pi - 1.0 - kl = (kl_per * seq_mask).sum() / denom + kl = (kl_per * seq_mask_batch).sum() / denom else: kl = torch.tensor(0.0, device=self.device) @@ -258,7 +184,7 @@ def update(self) -> Dict[str, float]: weight = float(denom.item()) total_actor_loss += actor_loss.item() * weight - masked_entropy = (entropy * seq_mask).sum() / denom + masked_entropy = (entropy * seq_mask_batch).sum() / denom total_entropy += masked_entropy.item() * weight total_kl += kl.item() * weight total_weight += weight @@ -268,3 +194,13 @@ def update(self) -> Dict[str, float]: "entropy": total_entropy / max(1.0, total_weight), "approx_ref_kl": total_kl / max(1.0, total_weight), } + + def _iterate_minibatches( + self, rollout: TensorDict, batch_size: int + ) -> list[TensorDict]: + total = rollout.batch_size[0] + indices = torch.randperm(total, device=self.device) + return [ + rollout[indices[start : start + batch_size]] + for start in range(0, total, batch_size) + ] diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index b1256ce0..b8d787b6 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -14,13 +14,13 @@ # limitations under the License. 
# ---------------------------------------------------------------------------- -import torch -from typing import Dict, Any, Tuple, Callable +import math +from typing import Dict +import torch from tensordict import TensorDict -from embodichain.agents.rl.utils import AlgorithmCfg, flatten_dict_observation -from embodichain.agents.rl.buffer import RolloutBuffer +from embodichain.agents.rl.utils import AlgorithmCfg, compute_gae from embodichain.utils import configclass from .base import BaseAlgorithm @@ -36,112 +36,24 @@ class PPOCfg(AlgorithmCfg): class PPO(BaseAlgorithm): - """PPO algorithm operating via Policy and RolloutBuffer (algo-agnostic design).""" + """PPO algorithm consuming TensorDict rollouts.""" def __init__(self, cfg: PPOCfg, policy): self.cfg = cfg self.policy = policy self.device = torch.device(cfg.device) self.optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.learning_rate) - self.buffer: RolloutBuffer | None = None # no per-rollout aggregation for dense logging - def _compute_gae( - self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Internal method to compute GAE. Only called by collect_rollout.""" - T, N = rewards.shape - advantages = torch.zeros_like(rewards, device=self.device) - last_adv = torch.zeros(N, device=self.device) - for t in reversed(range(T)): - next_value = values[t + 1] if t < T - 1 else torch.zeros_like(values[0]) - not_done = (~dones[t]).float() - delta = rewards[t] + self.cfg.gamma * next_value * not_done - values[t] - last_adv = ( - delta + self.cfg.gamma * self.cfg.gae_lambda * not_done * last_adv - ) - advantages[t] = last_adv - returns = advantages + values - return advantages, returns - - def initialize_buffer( - self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int - ): - """Initialize the rollout buffer. 
Called by trainer before first rollout.""" - self.buffer = RolloutBuffer( - num_steps, num_envs, obs_dim, action_dim, self.device - ) - - def collect_rollout( - self, - env, - policy, - obs: torch.Tensor, - num_steps: int, - on_step_callback: Callable | None = None, - ) -> Dict[str, Any]: - """Collect a rollout. Algorithm controls the data collection process.""" - if self.buffer is None: - raise RuntimeError( - "Buffer not initialized. Call initialize_buffer() first." - ) - - policy.train() - self.buffer.step = 0 - current_obs = obs - - for t in range(num_steps): - # Get action from policy - actions, log_prob, value = policy.get_action( - current_obs, deterministic=False - ) - - # Wrap action as dict for env processing - am = getattr(env, "action_manager", None) - action_type = ( - am.action_type if am else getattr(env, "action_type", "delta_qpos") - ) - action_dict = {action_type: actions} - - # Step environment - result = env.step(action_dict) - next_obs, reward, terminated, truncated, env_info = result - done = terminated | truncated - # Light dtype normalization - reward = reward.float() - done = done.bool() - - # Flatten TensorDict observation from ObservationManager if needed - if isinstance(next_obs, TensorDict): - next_obs = flatten_dict_observation(next_obs) - - # Add to buffer - self.buffer.add(current_obs, actions, reward, done, value, log_prob) - - # Dense logging is handled in Trainer.on_step via info; no aggregation here - # Call callback for statistics and logging - if on_step_callback is not None: - on_step_callback(current_obs, actions, reward, done, env_info, next_obs) - - current_obs = next_obs - - # Compute advantages/returns and attach to buffer extras - adv, ret = self._compute_gae( - self.buffer.rewards, self.buffer.values, self.buffer.dones - ) - self.buffer.set_extras({"advantages": adv, "returns": ret}) - - # No aggregated logging results; Trainer performs dense per-step logging - return {} - - def update(self) -> dict: - """Update the 
policy using the collected rollout buffer.""" - if self.buffer is None: - raise RuntimeError("Buffer not initialized. Call collect_rollout() first.") - - # Normalize advantages (optional, common default) - adv = self.buffer._extras.get("advantages") - adv = (adv - adv.mean()) / (adv.std() + 1e-8) + def update(self, rollout: TensorDict) -> Dict[str, float]: + """Update the policy using a collected rollout.""" + rollout = rollout.clone() + compute_gae(rollout, gamma=self.cfg.gamma, gae_lambda=self.cfg.gae_lambda) + flat_rollout = rollout.reshape(math.prod(rollout.batch_size)) + + advantages = flat_rollout["advantage"] + adv_mean = advantages.mean() + adv_std = advantages.std().clamp_min(1e-8) total_actor_loss = 0.0 total_value_loss = 0.0 @@ -149,23 +61,22 @@ def update(self) -> dict: total_steps = 0 for _ in range(self.cfg.n_epochs): - for batch in self.buffer.iterate_minibatches(self.cfg.batch_size): - obs = batch["obs"] - actions = batch["actions"] - old_logprobs = batch["logprobs"] - returns = batch["returns"] - advantages = ( - (batch["advantages"] - adv.mean()) / (adv.std() + 1e-8) - ).detach() - - logprobs, entropy, values = self.policy.evaluate_actions(obs, actions) + for batch in self._iterate_minibatches(flat_rollout, self.cfg.batch_size): + old_logprobs = batch["sample_log_prob"].clone() + returns = batch["return"].clone() + batch_advantages = ((batch["advantage"] - adv_mean) / adv_std).detach() + + eval_batch = self.policy.evaluate_actions(batch.clone()) + logprobs = eval_batch["sample_log_prob"] + entropy = eval_batch["entropy"] + values = eval_batch["value"] ratio = (logprobs - old_logprobs).exp() - surr1 = ratio * advantages + surr1 = ratio * batch_advantages surr2 = ( torch.clamp( ratio, 1.0 - self.cfg.clip_coef, 1.0 + self.cfg.clip_coef ) - * advantages + * batch_advantages ) actor_loss = -torch.min(surr1, surr2).mean() value_loss = torch.nn.functional.mse_loss(values, returns) @@ -184,7 +95,7 @@ def update(self) -> dict: ) self.optimizer.step() - bs 
= obs.shape[0] + bs = batch.batch_size[0] total_actor_loss += actor_loss.item() * bs total_value_loss += value_loss.item() * bs total_entropy += (-entropy_loss.item()) * bs @@ -195,3 +106,13 @@ def update(self) -> dict: "value_loss": total_value_loss / max(1, total_steps), "entropy": total_entropy / max(1, total_steps), } + + def _iterate_minibatches( + self, rollout: TensorDict, batch_size: int + ) -> list[TensorDict]: + total = rollout.batch_size[0] + indices = torch.randperm(total, device=self.device) + return [ + rollout[indices[start : start + batch_size]] + for start in range(0, total, batch_size) + ] diff --git a/embodichain/agents/rl/buffer/__init__.py b/embodichain/agents/rl/buffer/__init__.py index 5080d251..d90b2a06 100644 --- a/embodichain/agents/rl/buffer/__init__.py +++ b/embodichain/agents/rl/buffer/__init__.py @@ -14,6 +14,6 @@ # limitations under the License. # ---------------------------------------------------------------------------- -from .rollout_buffer import RolloutBuffer +from .standard_buffer import RolloutBuffer __all__ = ["RolloutBuffer"] diff --git a/embodichain/agents/rl/buffer/rollout_buffer.py b/embodichain/agents/rl/buffer/rollout_buffer.py index cbd66f4e..e2854261 100644 --- a/embodichain/agents/rl/buffer/rollout_buffer.py +++ b/embodichain/agents/rl/buffer/rollout_buffer.py @@ -16,91 +16,6 @@ from __future__ import annotations -from typing import Dict, Iterator +from .standard_buffer import RolloutBuffer -import torch - - -class RolloutBuffer: - """On-device rollout buffer for on-policy algorithms. - - Stores (obs, actions, rewards, dones, values, logprobs) over time. - After finalize(), exposes advantages/returns and minibatch iteration. 
- """ - - def __init__( - self, - num_steps: int, - num_envs: int, - obs_dim: int, - action_dim: int, - device: torch.device, - ): - self.num_steps = num_steps - self.num_envs = num_envs - self.obs_dim = obs_dim - self.action_dim = action_dim - self.device = device - - T, N = num_steps, num_envs - self.obs = torch.zeros(T, N, obs_dim, dtype=torch.float32, device=device) - self.actions = torch.zeros(T, N, action_dim, dtype=torch.float32, device=device) - self.rewards = torch.zeros(T, N, dtype=torch.float32, device=device) - self.dones = torch.zeros(T, N, dtype=torch.bool, device=device) - self.values = torch.zeros(T, N, dtype=torch.float32, device=device) - self.logprobs = torch.zeros(T, N, dtype=torch.float32, device=device) - - self.step = 0 - # Container for algorithm-specific extra fields (e.g., advantages, returns) - self._extras: dict[str, torch.Tensor] = {} - - def add( - self, - obs: torch.Tensor, - action: torch.Tensor, - reward: torch.Tensor, - done: torch.Tensor, - value: torch.Tensor, - logprob: torch.Tensor, - ) -> None: - t = self.step - self.obs[t].copy_(obs) - self.actions[t].copy_(action) - self.rewards[t].copy_(reward) - self.dones[t].copy_(done) - self.values[t].copy_(value) - self.logprobs[t].copy_(logprob) - self.step += 1 - - def set_extras(self, extras: dict[str, torch.Tensor]) -> None: - """Attach algorithm-specific tensors (shape [T, N, ...]) for batching. 
- - Examples: - {"advantages": adv, "returns": ret} - """ - self._extras = extras or {} - - def iterate_minibatches(self, batch_size: int) -> Iterator[Dict[str, torch.Tensor]]: - T, N = self.num_steps, self.num_envs - total = T * N - indices = torch.randperm(total, device=self.device) - for start in range(0, total, batch_size): - idx = indices[start : start + batch_size] - t_idx = idx // N - n_idx = idx % N - batch = { - "obs": self.obs[t_idx, n_idx], - "actions": self.actions[t_idx, n_idx], - "rewards": self.rewards[t_idx, n_idx], - "dones": self.dones[t_idx, n_idx], - "values": self.values[t_idx, n_idx], - "logprobs": self.logprobs[t_idx, n_idx], - } - # Slice extras if present and shape aligned to [T, N, ...] - for name, tensor in self._extras.items(): - try: - batch[name] = tensor[t_idx, n_idx] - except Exception: - # Skip misaligned extras silently - continue - yield batch +__all__ = ["RolloutBuffer"] diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py new file mode 100644 index 00000000..fe4528af --- /dev/null +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -0,0 +1,54 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import math + +from tensordict import TensorDict + +__all__ = ["RolloutBuffer"] + + +class RolloutBuffer: + """Single-rollout buffer backed by a TensorDict.""" + + def __init__(self) -> None: + self._rollout: TensorDict | None = None + + def add(self, rollout: TensorDict) -> None: + """Store a single rollout with batch shape `[num_envs, time]`.""" + if self._rollout is not None: + raise RuntimeError("RolloutBuffer already contains a rollout.") + self._rollout = rollout.clone() + + def get(self, flatten: bool = True) -> TensorDict: + """Return the stored rollout and clear the buffer.""" + if self._rollout is None: + raise RuntimeError("RolloutBuffer is empty.") + + rollout = self._rollout + self._rollout = None + + if not flatten: + return rollout + + total_batch = math.prod(rollout.batch_size) + return rollout.reshape(total_batch) + + def is_full(self) -> bool: + """Return whether a rollout is waiting to be consumed.""" + return self._rollout is not None diff --git a/embodichain/agents/rl/collector/__init__.py b/embodichain/agents/rl/collector/__init__.py new file mode 100644 index 00000000..683c0ad5 --- /dev/null +++ b/embodichain/agents/rl/collector/__init__.py @@ -0,0 +1,20 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .base import BaseCollector +from .sync_collector import SyncCollector + +__all__ = ["BaseCollector", "SyncCollector"] diff --git a/embodichain/agents/rl/collector/base.py b/embodichain/agents/rl/collector/base.py new file mode 100644 index 00000000..7c047aec --- /dev/null +++ b/embodichain/agents/rl/collector/base.py @@ -0,0 +1,37 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Callable + +from tensordict import TensorDict + +__all__ = ["BaseCollector"] + + +class BaseCollector(ABC): + """Base class for rollout collectors.""" + + @abstractmethod + def collect( + self, + num_steps: int, + on_step_callback: Callable[[TensorDict, dict], None] | None = None, + ) -> TensorDict: + """Collect a rollout and return it as a TensorDict.""" + raise NotImplementedError diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py new file mode 100644 index 00000000..5985897d --- /dev/null +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -0,0 +1,118 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Callable + +import torch +from tensordict import TensorDict + +from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation +from .base import BaseCollector + +__all__ = ["SyncCollector"] + + +class SyncCollector(BaseCollector): + """Synchronously collect rollouts from a vectorized environment.""" + + def __init__( + self, + env, + policy, + device: torch.device, + reset_every_rollout: bool = False, + ) -> None: + self.env = env + self.policy = policy + self.device = device + self.reset_every_rollout = reset_every_rollout + self.obs_td = self._reset_env() + + def collect( + self, + num_steps: int, + on_step_callback: Callable[[TensorDict, dict], None] | None = None, + ) -> TensorDict: + self.policy.train() + if self.reset_every_rollout: + self.obs_td = self._reset_env() + + rollout_steps: list[TensorDict] = [] + + for _ in range(num_steps): + obs_tensor = flatten_dict_observation(self.obs_td) + step_td = TensorDict( + {"observation": obs_tensor}, + batch_size=[obs_tensor.shape[0]], + device=self.device, + ) + self.policy.forward(step_td) + + next_obs, reward, terminated, truncated, env_info = self.env.step( + self._to_action_dict(step_td["action"]) + ) + next_obs_td = dict_to_tensordict(next_obs, self.device) + next_obs_tensor = flatten_dict_observation(next_obs_td) + done = (terminated | truncated).bool() + + step_td["next"] = TensorDict( + { + "observation": next_obs_tensor, + "reward": reward.float(), + "done": done, + "terminated": terminated.bool(), + "truncated": truncated.bool(), + }, + batch_size=step_td.batch_size, + device=self.device, + ) + rollout_steps.append(step_td.clone()) + + if on_step_callback is not None: + on_step_callback(step_td, env_info) + + self.obs_td = next_obs_td + + rollout = torch.stack(rollout_steps, dim=1) + self._attach_next_values(rollout) + return rollout + + def 
_attach_next_values(self, rollout: TensorDict) -> None: + """Populate `next.value` for GAE bootstrap.""" + next_values = torch.zeros_like(rollout["value"]) + next_values[:, :-1] = rollout["value"][:, 1:] + + last_next_td = TensorDict( + {"observation": rollout["next", "observation"][:, -1]}, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + self.policy.get_value(last_next_td) + next_values[:, -1] = last_next_td["value"] + rollout["next", "value"] = next_values + + def _reset_env(self) -> TensorDict: + obs, _ = self.env.reset() + return dict_to_tensordict(obs, self.device) + + def _to_action_dict(self, action: torch.Tensor) -> dict[str, torch.Tensor]: + am = getattr(self.env, "action_manager", None) + action_type = ( + am.action_type if am else getattr(self.env, "action_type", "delta_qpos") + ) + return {action_type: action} diff --git a/embodichain/agents/rl/models/__init__.py b/embodichain/agents/rl/models/__init__.py index 4b0c0a0b..ccbe7d92 100644 --- a/embodichain/agents/rl/models/__init__.py +++ b/embodichain/agents/rl/models/__init__.py @@ -18,7 +18,6 @@ from typing import Dict, Type import torch -from gymnasium import spaces from .actor_critic import ActorCritic from .actor_only import ActorOnly @@ -45,8 +44,8 @@ def get_policy_class(name: str) -> Type[Policy] | None: def build_policy( policy_block: dict, - obs_space: spaces.Space, - action_space: spaces.Space, + obs_dim: int, + action_dim: int, device: torch.device, actor: torch.nn.Module | None = None, critic: torch.nn.Module | None = None, @@ -64,13 +63,24 @@ def build_policy( raise ValueError( "ActorCritic policy requires external 'actor' and 'critic' modules." 
) - return policy_cls(obs_space, action_space, device, actor=actor, critic=critic) + return policy_cls( + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + actor=actor, + critic=critic, + ) elif name == "actor_only": if actor is None: raise ValueError("ActorOnly policy requires external 'actor' module.") - return policy_cls(obs_space, action_space, device, actor=actor) + return policy_cls( + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + actor=actor, + ) else: - return policy_cls(obs_space, action_space, device) + return policy_cls(obs_dim=obs_dim, action_dim=action_dim, device=device) def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP: diff --git a/embodichain/agents/rl/models/actor_critic.py b/embodichain/agents/rl/models/actor_critic.py index 35f9a961..f2002e04 100644 --- a/embodichain/agents/rl/models/actor_critic.py +++ b/embodichain/agents/rl/models/actor_critic.py @@ -16,11 +16,11 @@ from __future__ import annotations -from typing import Dict, Any, Tuple - import torch import torch.nn as nn from torch.distributions.normal import Normal +from tensordict import TensorDict + from .mlp import MLP from .policy import Policy @@ -36,23 +36,21 @@ class ActorCritic(Policy): This allows seamless swapping with other policy implementations (e.g., VLAPolicy) without modifying RL algorithm code. - Implements: - - get_action(obs, deterministic=False) -> (action, log_prob, value) - - get_value(obs) - - evaluate_actions(obs, actions) -> (log_prob, entropy, value) + Implements TensorDict-native interfaces while preserving `get_action()` + compatibility for evaluation and legacy call-sites. 
""" def __init__( self, - obs_space, - action_space, + obs_dim: int, + action_dim: int, device: torch.device, actor: nn.Module, critic: nn.Module, ): super().__init__() - self.obs_dim = obs_space.shape[-1] - self.action_dim = action_space.shape[-1] + self.obs_dim = obs_dim + self.action_dim = action_dim self.device = device # Require external injection of actor and critic @@ -66,31 +64,33 @@ def __init__( self.log_std_min = -5.0 self.log_std_max = 2.0 - @torch.no_grad() - def get_action( - self, obs: torch.Tensor, deterministic: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def _distribution(self, obs: torch.Tensor) -> Normal: mean = self.actor(obs) log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) std = log_std.exp().expand(mean.shape[0], -1) - dist = Normal(mean, std) + return Normal(mean, std) + + def forward( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + obs = tensordict["observation"] + dist = self._distribution(obs) + mean = dist.mean action = mean if deterministic else dist.sample() - log_prob = dist.log_prob(action).sum(dim=-1) - value = self.critic(obs).squeeze(-1) - return action, log_prob, value + tensordict["action"] = action + tensordict["sample_log_prob"] = dist.log_prob(action).sum(dim=-1) + tensordict["value"] = self.critic(obs).squeeze(-1) + return tensordict - @torch.no_grad() - def get_value(self, obs: torch.Tensor) -> torch.Tensor: - return self.critic(obs).squeeze(-1) + def get_value(self, tensordict: TensorDict) -> TensorDict: + tensordict["value"] = self.critic(tensordict["observation"]).squeeze(-1) + return tensordict - def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - mean = self.actor(obs) - log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) - std = log_std.exp().expand(mean.shape[0], -1) - dist = Normal(mean, std) - log_prob = dist.log_prob(actions).sum(dim=-1) - entropy = 
dist.entropy().sum(dim=-1) - value = self.critic(obs).squeeze(-1) - return log_prob, entropy, value + def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + obs = tensordict["observation"] + action = tensordict["action"] + dist = self._distribution(obs) + tensordict["sample_log_prob"] = dist.log_prob(action).sum(dim=-1) + tensordict["entropy"] = dist.entropy().sum(dim=-1) + tensordict["value"] = self.critic(obs).squeeze(-1) + return tensordict diff --git a/embodichain/agents/rl/models/actor_only.py b/embodichain/agents/rl/models/actor_only.py index c54fd515..e80109c4 100644 --- a/embodichain/agents/rl/models/actor_only.py +++ b/embodichain/agents/rl/models/actor_only.py @@ -16,11 +16,11 @@ from __future__ import annotations -from typing import Tuple - import torch import torch.nn as nn from torch.distributions.normal import Normal +from tensordict import TensorDict + from .policy import Policy @@ -33,14 +33,14 @@ class ActorOnly(Policy): def __init__( self, - obs_space, - action_space, + obs_dim: int, + action_dim: int, device: torch.device, actor: nn.Module, ): super().__init__() - self.obs_dim = obs_space.shape[-1] - self.action_dim = action_space.shape[-1] + self.obs_dim = obs_dim + self.action_dim = action_dim self.device = device self.actor = actor @@ -50,31 +50,40 @@ def __init__( self.log_std_min = -5.0 self.log_std_max = 2.0 - @torch.no_grad() - def get_action( - self, obs: torch.Tensor, deterministic: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def _distribution(self, obs: torch.Tensor) -> Normal: mean = self.actor(obs) log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) std = log_std.exp().expand(mean.shape[0], -1) - dist = Normal(mean, std) + return Normal(mean, std) + + def forward( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + obs = tensordict["observation"] + dist = self._distribution(obs) + mean = dist.mean action = mean if deterministic else dist.sample() - 
log_prob = dist.log_prob(action).sum(dim=-1) - value = torch.zeros(obs.shape[0], device=self.device, dtype=obs.dtype) - return action, log_prob, value + tensordict["action"] = action + tensordict["sample_log_prob"] = dist.log_prob(action).sum(dim=-1) + tensordict["value"] = torch.zeros( + obs.shape[0], device=self.device, dtype=obs.dtype + ) + return tensordict - @torch.no_grad() - def get_value(self, obs: torch.Tensor) -> torch.Tensor: - return torch.zeros(obs.shape[0], device=self.device, dtype=obs.dtype) + def get_value(self, tensordict: TensorDict) -> TensorDict: + obs = tensordict["observation"] + tensordict["value"] = torch.zeros( + obs.shape[0], device=self.device, dtype=obs.dtype + ) + return tensordict - def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - mean = self.actor(obs) - log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) - std = log_std.exp().expand(mean.shape[0], -1) - dist = Normal(mean, std) - log_prob = dist.log_prob(actions).sum(dim=-1) - entropy = dist.entropy().sum(dim=-1) - value = torch.zeros(obs.shape[0], device=self.device, dtype=obs.dtype) - return log_prob, entropy, value + def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + obs = tensordict["observation"] + action = tensordict["action"] + dist = self._distribution(obs) + tensordict["sample_log_prob"] = dist.log_prob(action).sum(dim=-1) + tensordict["entropy"] = dist.entropy().sum(dim=-1) + tensordict["value"] = torch.zeros( + obs.shape[0], device=self.device, dtype=obs.dtype + ) + return tensordict diff --git a/embodichain/agents/rl/models/policy.py b/embodichain/agents/rl/models/policy.py index 21c13a96..0e8faac3 100644 --- a/embodichain/agents/rl/models/policy.py +++ b/embodichain/agents/rl/models/policy.py @@ -23,11 +23,11 @@ from __future__ import annotations -from typing import Tuple from abc import ABC, abstractmethod import torch.nn as nn import torch +from tensordict 
import TensorDict class Policy(nn.Module, ABC): @@ -45,11 +45,10 @@ class Policy(nn.Module, ABC): def __init__(self) -> None: super().__init__() - @abstractmethod def get_action( self, obs: torch.Tensor, deterministic: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Sample an action from the policy. + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compatibility layer for tensor-only callers. Args: obs: Observation tensor of shape (batch_size, obs_dim) @@ -61,34 +60,41 @@ def get_action( - log_prob: Log probability of the action, shape (batch_size,) - value: Value estimate, shape (batch_size,) """ + td = TensorDict( + {"observation": obs}, + batch_size=[obs.shape[0]], + device=obs.device, + ) + td = self.forward(td, deterministic=deterministic) + return td["action"], td["sample_log_prob"], td["value"] + + @abstractmethod + def forward( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + """Write sampled actions and value estimates into the TensorDict.""" raise NotImplementedError @abstractmethod - def get_value(self, obs: torch.Tensor) -> torch.Tensor: - """Get value estimate for given observations. + def get_value(self, tensordict: TensorDict) -> TensorDict: + """Write value estimate for the given observations into the TensorDict. Args: - obs: Observation tensor of shape (batch_size, obs_dim) + tensordict: Input TensorDict containing `observation`. Returns: - Value estimate tensor of shape (batch_size,) + TensorDict with `value` populated. """ raise NotImplementedError @abstractmethod - def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Evaluate actions and compute log probabilities, entropy, and values. + def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + """Evaluate actions and write log prob, entropy, and values. 
Args: - obs: Observation tensor of shape (batch_size, obs_dim) - actions: Action tensor of shape (batch_size, action_dim) + tensordict: TensorDict containing `observation` and `action`. Returns: - Tuple of (log_prob, entropy, value): - - log_prob: Log probability of actions, shape (batch_size,) - - entropy: Entropy of the action distribution, shape (batch_size,) - - value: Value estimate, shape (batch_size,) + TensorDict with `sample_log_prob`, `entropy`, and `value` populated. """ raise NotImplementedError diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py index ec7c968e..62418ad0 100644 --- a/embodichain/agents/rl/train.py +++ b/embodichain/agents/rl/train.py @@ -29,6 +29,7 @@ from embodichain.agents.rl.models import build_policy, get_registered_policy_names from embodichain.agents.rl.models import build_mlp_from_cfg from embodichain.agents.rl.algo import build_algo, get_registered_algo_names +from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation from embodichain.agents.rl.utils.trainer import Trainer from embodichain.utils import logger from embodichain.lab.gym.envs.tasks.rl import build_env @@ -64,7 +65,9 @@ def train_from_config(config_path: str): seed = int(trainer_cfg.get("seed", 1)) device_str = trainer_cfg.get("device", "cpu") iterations = int(trainer_cfg.get("iterations", 250)) - rollout_steps = int(trainer_cfg.get("rollout_steps", 2048)) + buffer_size = int( + trainer_cfg.get("buffer_size", trainer_cfg.get("rollout_steps", 2048)) + ) enable_eval = bool(trainer_cfg.get("enable_eval", False)) eval_freq = int(trainer_cfg.get("eval_freq", 10000)) save_freq = int(trainer_cfg.get("save_freq", 50000)) @@ -163,6 +166,9 @@ def train_from_config(config_path: str): ) env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg) + sample_obs, _ = env.reset() + sample_obs_td = dict_to_tensordict(sample_obs, device) + obs_dim = flatten_dict_observation(sample_obs_td).shape[-1] # Create evaluation environment 
only if enabled eval_env = None @@ -178,13 +184,17 @@ def train_from_config(config_path: str): # Build Policy via registry policy_name = policy_block["name"] + action_dim = policy_block.get("action_dim") + if action_dim is None: + raise ValueError("Policy config must define 'action_dim'.") + action_dim = int(action_dim) + env_action_dim = env.action_space.shape[-1] + if action_dim != env_action_dim: + raise ValueError( + f"Configured policy.action_dim={action_dim} does not match env action dim {env_action_dim}." + ) # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic/actor_only) if policy_name.lower() == "actor_critic": - # Get observation dimension from flattened observation space - # flattened_observation_space returns Box space for RL training - obs_dim = env.flattened_observation_space.shape[-1] - action_dim = env.action_space.shape[-1] - actor_cfg = policy_block.get("actor") critic_cfg = policy_block.get("critic") if actor_cfg is None or critic_cfg is None: @@ -197,16 +207,13 @@ def train_from_config(config_path: str): policy = build_policy( policy_block, - env.flattened_observation_space, - env.action_space, + obs_dim, + action_dim, device, actor=actor, critic=critic, ) elif policy_name.lower() == "actor_only": - obs_dim = env.flattened_observation_space.shape[-1] - action_dim = env.action_space.shape[-1] - actor_cfg = policy_block.get("actor") if actor_cfg is None: raise ValueError( @@ -217,15 +224,13 @@ def train_from_config(config_path: str): policy = build_policy( policy_block, - env.flattened_observation_space, - env.action_space, + obs_dim, + action_dim, device, actor=actor, ) else: - policy = build_policy( - policy_block, env.flattened_observation_space, env.action_space, device - ) + policy = build_policy(policy_block, obs_dim, action_dim, device) # Build Algorithm via factory algo_name = algo_block["name"].lower() @@ -276,7 +281,7 @@ def train_from_config(config_path: str): policy=policy, env=env, 
algorithm=algo, - num_steps=rollout_steps, + buffer_size=buffer_size, batch_size=algo_cfg["batch_size"], writer=writer, eval_freq=eval_freq if enable_eval else 0, # Disable eval if not enabled @@ -299,7 +304,7 @@ def train_from_config(config_path: str): f"Algorithm: {algo_name} (available: {get_registered_algo_names()})" ) - total_steps = int(iterations * rollout_steps * env.num_envs) + total_steps = int(iterations * buffer_size * env.num_envs) logger.log_info(f"Total steps: {total_steps} (iterations≈{iterations})") try: diff --git a/embodichain/agents/rl/utils/__init__.py b/embodichain/agents/rl/utils/__init__.py index 7bd835e8..bd6dbc4f 100644 --- a/embodichain/agents/rl/utils/__init__.py +++ b/embodichain/agents/rl/utils/__init__.py @@ -15,9 +15,11 @@ # ---------------------------------------------------------------------------- from .config import AlgorithmCfg -from .helper import flatten_dict_observation +from .helper import compute_gae, dict_to_tensordict, flatten_dict_observation __all__ = [ "AlgorithmCfg", + "compute_gae", + "dict_to_tensordict", "flatten_dict_observation", ] diff --git a/embodichain/agents/rl/utils/helper.py b/embodichain/agents/rl/utils/helper.py index 42259506..aa38d6d8 100644 --- a/embodichain/agents/rl/utils/helper.py +++ b/embodichain/agents/rl/utils/helper.py @@ -14,40 +14,114 @@ # limitations under the License. # ---------------------------------------------------------------------------- +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + import torch from tensordict import TensorDict +__all__ = [ + "compute_gae", + "dict_to_tensordict", + "flatten_dict_observation", +] -def flatten_dict_observation(obs: TensorDict) -> torch.Tensor: - """ - Flatten hierarchical TensorDict observations from ObservationManager. - Recursively traverse nested TensorDicts, collect all tensor values, - flatten each to (num_envs, -1), and concatenate in sorted key order. 
+def flatten_dict_observation(obs: TensorDict) -> torch.Tensor: + """Flatten a hierarchical observation TensorDict into a 2D tensor. Args: - obs: Nested TensorDict structure, e.g. TensorDict(robot=TensorDict(qpos=..., qvel=...), ...) + obs: Observation TensorDict with batch dimension `[num_envs]`. Returns: - Concatenated flat tensor of shape (num_envs, total_dim) + Flattened observation tensor of shape `[num_envs, obs_dim]`. """ - obs_list = [] + obs_list: list[torch.Tensor] = [] - def _collect_tensors(d, prefix=""): - """Recursively collect tensors from nested TensorDicts in sorted order.""" - for key in sorted(d.keys()): - full_key = f"{prefix}/{key}" if prefix else key - value = d[key] + def _collect_tensors(data: TensorDict) -> None: + for key in sorted(data.keys()): + value = data[key] if isinstance(value, TensorDict): - _collect_tensors(value, full_key) + _collect_tensors(value) elif isinstance(value, torch.Tensor): - # Flatten tensor to (num_envs, -1) shape obs_list.append(value.flatten(start_dim=1)) _collect_tensors(obs) if not obs_list: - raise ValueError("No tensors found in observation TensorDict") + raise ValueError("No tensors found in observation TensorDict.") + + return torch.cat(obs_list, dim=-1) + + +def dict_to_tensordict( + obs_dict: TensorDict | Mapping[str, Any], device: torch.device | str +) -> TensorDict: + """Convert an environment observation mapping into a TensorDict. + + Args: + obs_dict: Environment observation returned by `reset()` or `step()`. + device: Target device for the resulting TensorDict. + + Returns: + Observation TensorDict moved onto the target device. + """ + if isinstance(obs_dict, TensorDict): + return obs_dict.to(device) + if not isinstance(obs_dict, Mapping): + raise TypeError( + f"Expected observation mapping or TensorDict, got {type(obs_dict)!r}." 
+ ) + return TensorDict.from_dict(dict(obs_dict), device=device) + + +def compute_gae( + rollout: TensorDict, gamma: float, gae_lambda: float +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute GAE over a rollout with batch shape `[num_envs, time]`. + + Args: + rollout: Rollout TensorDict containing `value` and `next` transition data. + gamma: Discount factor. + gae_lambda: GAE lambda coefficient. + + Returns: + Tuple of `(advantages, returns)`, both shaped `[num_envs, time]`. + """ + rewards = rollout["next", "reward"].float() + dones = rollout["next", "done"].bool() + values = rollout["value"].float() + + if rewards.ndim != 2: + raise ValueError( + f"Expected reward tensor with shape [num_envs, time], got {rewards.shape}." + ) + + next_values = _get_next_values(rollout, values) + num_envs, time_dim = rewards.shape + advantages = torch.zeros_like(rewards) + last_advantage = torch.zeros(num_envs, device=rewards.device, dtype=rewards.dtype) + + for t in reversed(range(time_dim)): + not_done = (~dones[:, t]).float() + delta = rewards[:, t] + gamma * next_values[:, t] * not_done - values[:, t] + last_advantage = delta + gamma * gae_lambda * not_done * last_advantage + advantages[:, t] = last_advantage + + returns = advantages + values + rollout["advantage"] = advantages + rollout["return"] = returns + return advantages, returns + + +def _get_next_values(rollout: TensorDict, values: torch.Tensor) -> torch.Tensor: + """Resolve next-step values for GAE bootstrap.""" + next_value = rollout.get(("next", "value"), None) + if next_value is not None: + return next_value.float() - result = torch.cat(obs_list, dim=-1) - return result + next_values = torch.zeros_like(values) + next_values[:, :-1] = values[:, 1:] + return next_values diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 5b17a8e0..54987967 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -16,7 +16,7 @@ from __future__ 
import annotations -from typing import Dict, Any, Tuple, Callable +from typing import Dict import time import numpy as np import torch @@ -25,6 +25,8 @@ import wandb from tensordict import TensorDict +from embodichain.agents.rl.buffer import RolloutBuffer +from embodichain.agents.rl.collector import SyncCollector from embodichain.lab.gym.envs.managers.event_manager import EventManager from .helper import flatten_dict_observation @@ -37,7 +39,7 @@ def __init__( policy, env, algorithm, - num_steps: int, + buffer_size: int, batch_size: int, writer: SummaryWriter | None, eval_freq: int, @@ -54,7 +56,7 @@ def __init__( self.env = env self.eval_env = eval_env self.algorithm = algorithm - self.num_steps = num_steps + self.buffer_size = buffer_size self.batch_size = batch_size self.writer = writer self.eval_freq = eval_freq @@ -76,27 +78,20 @@ def __init__( self.ret_window = deque(maxlen=100) self.len_window = deque(maxlen=100) - # initial obs (assume env returns torch tensors already on target device) - obs, _ = self.env.reset() - - # Initialize algorithm's buffer - # Flatten TensorDict observations from ObservationManager to tensor for RL algorithms - if isinstance(obs, TensorDict): - obs_tensor = flatten_dict_observation(obs) - obs_dim = obs_tensor.shape[-1] - num_envs = obs_tensor.shape[0] - # Store flattened observation for RL training - self.obs = obs_tensor - - action_space = getattr(self.env, "action_space", None) - action_dim = action_space.shape[-1] if action_space else None - if action_dim is None: - raise RuntimeError( - "Env must expose action_space with shape for buffer initialization." 
- ) - - # Algorithm manages its own buffer - self.algorithm.initialize_buffer(num_steps, num_envs, obs_dim, action_dim) + self.buffer = RolloutBuffer() + self.collector = SyncCollector( + env=self.env, + policy=self.policy, + device=self.device, + reset_every_rollout=bool( + getattr( + getattr(self.algorithm, "cfg", None), "reset_every_rollout", False + ) + ), + ) + num_envs = getattr(self.env, "num_envs", None) + if num_envs is None: + raise RuntimeError("Env must expose num_envs for trainer statistics.") # episode stats tracked on device to avoid repeated CPU round-trips self.curr_ret = torch.zeros(num_envs, dtype=torch.float32, device=self.device) @@ -137,7 +132,7 @@ def train(self, total_timesteps: int): print(f"Start training, total steps: {total_timesteps}") while self.global_step < total_timesteps: self._collect_rollout() - losses = self.algorithm.update() + losses = self.algorithm.update(self.buffer.get(flatten=False)) self._log_train(losses) if ( self.eval_freq > 0 @@ -150,11 +145,13 @@ def train(self, total_timesteps: int): @torch.no_grad() def _collect_rollout(self): - """Collect a rollout. 
Algorithm controls the data collection process.""" + """Collect a rollout with the synchronous collector.""" # Callback function for statistics and logging - def on_step(obs, actions, reward, done, info, next_obs): + def on_step(tensordict: TensorDict, info: dict): """Callback called at each step during rollout collection.""" + reward = tensordict["next", "reward"] + done = tensordict["next", "done"] # Episode stats (stay on device; convert only when episode ends) self.curr_ret += reward self.curr_len += 1 @@ -167,10 +164,7 @@ def on_step(obs, actions, reward, done, info, next_obs): self.curr_ret[done_idx] = 0 self.curr_len[done_idx] = 0 - # Update global step and observation - # next_obs is already flattened in algorithm's collect_rollout - self.obs = next_obs - self.global_step += next_obs.shape[0] + self.global_step += tensordict.batch_size[0] if isinstance(info, dict): rewards_dict = info.get("rewards") @@ -183,14 +177,11 @@ def on_step(obs, actions, reward, done, info, next_obs): if log_dict and self.use_wandb: wandb.log(log_dict, step=self.global_step) - # Algorithm controls data collection - result = self.algorithm.collect_rollout( - env=self.env, - policy=self.policy, - obs=self.obs, - num_steps=self.num_steps, + rollout = self.collector.collect( + num_steps=self.buffer_size, on_step_callback=on_step, ) + self.buffer.add(rollout) def _log_train(self, losses: Dict[str, float]): if self.writer: diff --git a/tests/agents/test_rl.py b/tests/agents/test_rl.py index e0a5beaf..f356502c 100644 --- a/tests/agents/test_rl.py +++ b/tests/agents/test_rl.py @@ -68,7 +68,7 @@ def setup_method(self): test_train_config = train_config.copy() test_train_config["trainer"]["gym_config"] = self.temp_gym_config_path test_train_config["trainer"]["iterations"] = 2 - test_train_config["trainer"]["rollout_steps"] = 32 + test_train_config["trainer"]["buffer_size"] = 32 test_train_config["trainer"]["eval_freq"] = 1000000 # Disable eval test_train_config["trainer"]["save_freq"] = 
1000000 # Disable save test_train_config["trainer"]["headless"] = True From 3ef4e45cab925c46f063c688a24d4885e6055c44 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 10:27:19 +0000 Subject: [PATCH 02/23] Refactor: rl rollout around shared tensordict --- docs/source/overview/rl/algorithm.md | 37 +-- docs/source/overview/rl/models.md | 15 +- docs/source/overview/rl/trainer.md | 5 +- docs/source/tutorial/rl.rst | 79 +++--- .../agents/rl/buffer/standard_buffer.py | 128 +++++++++- embodichain/agents/rl/collector/base.py | 1 + .../agents/rl/collector/sync_collector.py | 52 ++-- embodichain/agents/rl/utils/trainer.py | 23 +- embodichain/lab/gym/envs/base_env.py | 2 + embodichain/lab/gym/envs/embodied_env.py | 123 ++++++++-- tests/agents/test_shared_rollout.py | 225 ++++++++++++++++++ 11 files changed, 553 insertions(+), 137 deletions(-) create mode 100644 tests/agents/test_shared_rollout.py diff --git a/docs/source/overview/rl/algorithm.md b/docs/source/overview/rl/algorithm.md index a22110e8..e7970709 100644 --- a/docs/source/overview/rl/algorithm.md +++ b/docs/source/overview/rl/algorithm.md @@ -5,20 +5,17 @@ This module contains the core implementations of reinforcement learning algorith ## Main Classes and Functions ### BaseAlgorithm -- Abstract base class for RL algorithms, defining common interfaces such as buffer initialization, data collection, and update. +- Abstract base class for RL algorithms, defining a single update interface over a collected rollout. - Key methods: - - `initialize_buffer(num_steps, num_envs, obs_dim, action_dim)`: Initialize the trajectory buffer. - - `collect_rollout(env, policy, obs, num_steps, on_step_callback)`: Collect interaction data. - - `update()`: Update the policy based on collected data. -- Designed to be algorithm-agnostic; Trainer only depends on this interface to support various RL algorithms. 
-- Supports multi-environment parallel collection, compatible with Gymnasium/IsaacGym environments. + - `update(rollout)`: Update the policy based on a shared rollout `TensorDict`. +- Designed to be algorithm-agnostic; `Trainer` handles collection while algorithms focus on loss computation and optimization. +- Supports multi-environment parallel collection through a shared `[N, T]` rollout `TensorDict`. ### PPO - Mainstream on-policy algorithm, supports Generalized Advantage Estimation (GAE), policy update, and hyperparameter configuration. - Key methods: - - `_compute_gae(rewards, values, dones)`: Generalized Advantage Estimation. - - `collect_rollout`: Collect trajectories and compute advantages/returns. - - `update`: Multi-epoch minibatch optimization, including entropy, value, and policy loss, with gradient clipping. + - `compute_gae(rollout, gamma, gae_lambda)`: Generalized Advantage Estimation over a shared rollout `TensorDict`. + - `update(rollout)`: Multi-epoch minibatch optimization, including entropy, value, and policy loss, with gradient clipping. - Supports custom callbacks, detailed logging, and GPU acceleration. - Typical training flow: collect rollout → compute advantage/return → multi-epoch minibatch optimization. - Supports advantage normalization, entropy regularization, value loss weighting, etc. @@ -31,8 +28,7 @@ This module contains the core implementations of reinforcement learning algorith - Key methods: - `_compute_step_returns_and_mask(rewards, dones)`: Step-wise discounted returns and valid-step mask. - `_compute_step_group_advantages(step_returns, seq_mask)`: Per-step group normalization with masked mean/std. - - `collect_rollout`: Collect trajectories and compute step-wise advantages. - - `update`: Multi-epoch minibatch optimization with optional KL penalty. + - `update(rollout)`: Multi-epoch minibatch optimization with optional KL penalty. 
- Supports both **Embodied AI** (dense reward, from-scratch training) and **VLA** (sparse reward, fine-tuning) modes via `kl_coef` configuration. ### Config Classes @@ -43,19 +39,11 @@ This module contains the core implementations of reinforcement learning algorith ## Code Example ```python class BaseAlgorithm: - def initialize_buffer(self, num_steps, num_envs, obs_dim, action_dim): - ... - def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None): - ... - def update(self): + def update(self, rollout): ... class PPO(BaseAlgorithm): - def _compute_gae(self, rewards, values, dones): - ... - def collect_rollout(self, ...): - ... - def update(self): + def update(self, rollout): ... ``` @@ -71,10 +59,9 @@ class PPO(BaseAlgorithm): - Typical usage: ```python algo = PPO(cfg, policy) -buffer = algo.initialize_buffer(...) -for _ in range(num_iterations): - algo.collect_rollout(...) - algo.update() +rollout = collector.collect(buffer_size, rollout=buffer.start_rollout()) +buffer.add(rollout) +algo.update(buffer.get(flatten=False)) ``` --- diff --git a/docs/source/overview/rl/models.md b/docs/source/overview/rl/models.md index c67c58ae..56de7926 100644 --- a/docs/source/overview/rl/models.md +++ b/docs/source/overview/rl/models.md @@ -7,9 +7,10 @@ This module contains RL policy networks and related model implementations, suppo ### Policy - Abstract base class for RL policies; all policies must inherit from it. - Unified interface: - - `get_action(obs, deterministic=False)`: Sample or output actions. - - `get_value(obs)`: Estimate state value. - - `evaluate_actions(obs, actions)`: Evaluate action probabilities, entropy, and value. + - `forward(tensordict, deterministic=False)`: Write action, log prob, and value into a `TensorDict`. + - `get_value(tensordict)`: Estimate state value into a `TensorDict`. + - `evaluate_actions(tensordict)`: Evaluate action probabilities, entropy, and value from a `TensorDict`. 
+- `get_action(obs, deterministic=False)` is retained as a compatibility layer for evaluation and legacy callers. - Supports GPU deployment and distributed training. ### ActorCritic @@ -19,8 +20,8 @@ This module contains RL policy networks and related model implementations, suppo - Actor-only policy without Critic. Used with GRPO (Group Relative Policy Optimization), which estimates advantages via group-level return comparison instead of a value function. - Supports Gaussian action distributions, learnable log_std, suitable for continuous action spaces. - Key methods: - - `get_action`: Actor network outputs mean, samples action, returns log_prob and critic value. - - `evaluate_actions`: Used for loss calculation in PPO/SAC algorithms. + - `forward`: Actor network outputs mean, samples action, and writes policy outputs into a `TensorDict`. + - `evaluate_actions`: Used for loss calculation in PPO/GRPO algorithms. - Custom actor/critic network architectures supported (e.g., MLP/CNN/Transformer). ### MLP @@ -29,14 +30,14 @@ This module contains RL policy networks and related model implementations, suppo - Supports orthogonal initialization and output reshaping. ### Factory Functions -- `build_policy(policy_block, obs_space, action_space, device, ...)`: Automatically build policy from config. +- `build_policy(policy_block, obs_dim, action_dim, device, ...)`: Automatically build policy from config. - `build_mlp_from_cfg(module_cfg, in_dim, out_dim)`: Automatically build MLP from config. 
## Usage Example ```python actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim) critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1) -policy = build_policy(policy_block, obs_space, action_space, device, actor=actor, critic=critic) +policy = build_policy(policy_block, obs_dim, action_dim, device, actor=actor, critic=critic) action, log_prob, value = policy.get_action(obs) ``` diff --git a/docs/source/overview/rl/trainer.md b/docs/source/overview/rl/trainer.md index 1a2b0fe3..ffe09cbd 100644 --- a/docs/source/overview/rl/trainer.md +++ b/docs/source/overview/rl/trainer.md @@ -20,7 +20,7 @@ This module implements the main RL training loop, logging management, and event- ## Main Methods - `train(total_timesteps)`: Main training loop, automatically collects data, updates policy, and logs. -- `_collect_rollout()`: Collect one rollout, supports custom callback statistics. +- `_collect_rollout()`: Collect one rollout through `SyncCollector`, supports custom callback statistics. - `_log_train(losses)`: Log training loss, reward, sampling speed, etc. - `_eval_once()`: Periodic evaluation, records evaluation metrics. - `save_checkpoint()`: Save model parameters and training state. @@ -35,7 +35,7 @@ This module implements the main RL training loop, logging management, and event- ## Usage Example ```python -trainer = Trainer(policy, env, algorithm, num_steps, batch_size, writer, ...) +trainer = Trainer(policy, env, algorithm, buffer_size, batch_size, writer, ...) trainer.train(total_steps) trainer.save_checkpoint() ``` @@ -44,6 +44,7 @@ trainer.save_checkpoint() - Custom event modules can be implemented for environment reset, data collection, evaluation, etc. - Supports multi-environment parallelism and distributed training. - Training process can be flexibly adjusted via config files. +- The current trainer uses a shared rollout `TensorDict`: collector writes policy-side fields and `EmbodiedEnv` writes environment-side `next.*` fields through `set_rollout_buffer()`. 
## Practical Tips - It is recommended to perform periodic evaluation and model saving to prevent loss of progress during training. diff --git a/docs/source/tutorial/rl.rst b/docs/source/tutorial/rl.rst index b0126dde..bcefb80a 100644 --- a/docs/source/tutorial/rl.rst +++ b/docs/source/tutorial/rl.rst @@ -57,14 +57,14 @@ Configuration Sections Runtime Settings ^^^^^^^^^^^^^^^^ -The ``runtime`` section controls experiment setup: +The ``trainer`` section controls experiment setup: - **exp_name**: Experiment name (used for output directories) - **seed**: Random seed for reproducibility -- **cuda**: Whether to use GPU (default: true) +- **device**: Runtime device string, e.g. ``"cpu"`` or ``"cuda:0"`` - **headless**: Whether to run simulation in headless mode - **iterations**: Number of training iterations -- **rollout_steps**: Steps per rollout (e.g., 1024) +- **buffer_size**: Steps collected per rollout (e.g., 1024) - **eval_freq**: Frequency of evaluation (in steps) - **save_freq**: Frequency of checkpoint saving (in steps) - **use_wandb**: Whether to enable Weights & Biases logging (set in JSON config) @@ -109,7 +109,7 @@ Policy Configuration The ``policy`` section defines the neural network policy: - **name**: Policy name (e.g., "actor_critic", "vla") -- **cfg**: Policy-specific hyperparameters (empty for actor_critic) +- **action_dim**: Policy output action dimension; must match ``env.action_space`` - **actor**: Actor network configuration (required for actor_critic) - **critic**: Critic network configuration (required for actor_critic) @@ -119,16 +119,20 @@ Example: "policy": { "name": "actor_critic", - "cfg": {}, + "action_dim": 8, "actor": { "type": "mlp", - "hidden_sizes": [256, 256], - "activation": "relu" + "network_cfg": { + "hidden_sizes": [256, 256], + "activation": "relu" + } }, "critic": { "type": "mlp", - "hidden_sizes": [256, 256], - "activation": "relu" + "network_cfg": { + "hidden_sizes": [256, 256], + "activation": "relu" + } } } @@ -231,9 +235,9 
@@ Training Process The training process follows this sequence: -1. **Rollout Phase**: Algorithm collects trajectories by interacting with the environment (via ``collect_rollout``). During this phase, the trainer performs dense per-step logging of rewards and metrics from environment info. -2. **Advantage/Return Computation**: Algorithm computes advantages and returns (e.g. GAE for PPO, step-wise group normalization for GRPO; stored in buffer extras) -3. **Update Phase**: Algorithm updates the policy using collected data (e.g., PPO) +1. **Rollout Phase**: ``SyncCollector`` interacts with the environment and writes policy-side fields into a shared rollout ``TensorDict``. ``EmbodiedEnv`` writes environment-side ``next.*`` fields into the same rollout via ``set_rollout_buffer()``. +2. **Advantage/Return Computation**: Algorithm computes advantages and returns from the collected rollout (e.g. GAE for PPO, step-wise group normalization for GRPO) +3. **Update Phase**: Algorithm updates the policy with ``update(rollout)`` 4. **Logging**: Trainer logs training losses and aggregated metrics to TensorBoard and Weights & Biases 5. **Evaluation** (periodic): Trainer evaluates the current policy 6. 
**Checkpointing** (periodic): Trainer saves model checkpoints @@ -252,22 +256,18 @@ All policies must inherit from the ``Policy`` abstract base class: device: torch.device @abstractmethod - def get_action( - self, obs: torch.Tensor, deterministic: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Returns (action, log_prob, value)""" + def forward(self, tensordict, deterministic: bool = False): + """Writes action, sample_log_prob, and value into the TensorDict.""" raise NotImplementedError @abstractmethod - def get_value(self, obs: torch.Tensor) -> torch.Tensor: - """Returns value estimate""" + def get_value(self, tensordict): + """Writes value estimate into the TensorDict.""" raise NotImplementedError @abstractmethod - def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Returns (log_prob, entropy, value)""" + def evaluate_actions(self, tensordict): + """Writes log_prob, entropy, and value into the TensorDict.""" raise NotImplementedError Available Policies @@ -292,13 +292,13 @@ Adding a New Algorithm To add a new algorithm: 1. Create a new algorithm class in ``embodichain/agents/rl/algo/`` -2. Implement ``initialize_buffer()``, ``collect_rollout()``, and ``update()`` methods +2. Implement ``update(rollout)`` and consume the shared rollout ``TensorDict`` 3. Register in ``algo/__init__.py``: .. 
code-block:: python + from tensordict import TensorDict from embodichain.agents.rl.algo import BaseAlgorithm, register_algo - from embodichain.agents.rl.buffer import RolloutBuffer @register_algo("my_algo") class MyAlgorithm(BaseAlgorithm): @@ -306,24 +306,11 @@ To add a new algorithm: self.cfg = cfg self.policy = policy self.device = torch.device(cfg.device) - self.buffer = None - - def initialize_buffer(self, num_steps, num_envs, obs_dim, action_dim): - """Initialize the algorithm's buffer.""" - self.buffer = RolloutBuffer(num_steps, num_envs, obs_dim, action_dim, self.device) - - def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None): - """Control data collection process (interact with env, fill buffer, compute advantages/returns).""" - # Collect trajectories - # Compute advantages/returns (e.g., GAE for on-policy algorithms) - # Attach extras to buffer: self.buffer.set_extras({"advantages": adv, "returns": ret}) - # Return empty dict (dense logging handled in trainer) - return {} - def update(self): - """Update the policy using collected data.""" - # Access extras from buffer: self.buffer._extras.get("advantages") - # Use self.buffer to update policy + def update(self, rollout: TensorDict): + """Update the policy using a collected rollout.""" + # compute advantages / returns from rollout + # optimize policy parameters return {"loss": 0.0} Adding a New Policy @@ -340,16 +327,16 @@ To add a new policy: @register_policy("my_policy") class MyPolicy(Policy): - def __init__(self, obs_space, action_space, device, config): + def __init__(self, obs_dim, action_dim, device, config): super().__init__() self.device = device # Initialize your networks here - def get_action(self, obs, deterministic=False): + def forward(self, tensordict, deterministic=False): ... - def get_value(self, obs): + def get_value(self, tensordict): ... - def evaluate_actions(self, obs, actions): + def evaluate_actions(self, tensordict): ... 
Adding a New Environment @@ -414,7 +401,7 @@ Best Practices - **Observation Format**: Environments should provide consistent observation shape/types (torch.float32) and a single ``done = terminated | truncated``. -- **Algorithm Interface**: Algorithms must implement ``initialize_buffer()``, ``collect_rollout()``, and ``update()`` methods. The algorithm completely controls data collection and buffer management. +- **Algorithm Interface**: Algorithms implement ``update(rollout)`` and consume a shared rollout ``TensorDict``. Collection is handled by ``SyncCollector`` plus environment-side rollout writes in ``EmbodiedEnv``. - **Reward Configuration**: Use the ``RewardManager`` in your environment config to define reward components. Organize reward components in ``info["rewards"]`` dictionary and metrics in ``info["metrics"]`` dictionary. The trainer performs dense per-step logging directly from environment info. diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py index fe4528af..9196b893 100644 --- a/embodichain/agents/rl/buffer/standard_buffer.py +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -18,30 +18,58 @@ import math +import torch from tensordict import TensorDict __all__ = ["RolloutBuffer"] class RolloutBuffer: - """Single-rollout buffer backed by a TensorDict.""" + """Single-rollout buffer backed by a preallocated TensorDict.""" - def __init__(self) -> None: - self._rollout: TensorDict | None = None + def __init__( + self, + num_envs: int, + rollout_len: int, + obs_dim: int, + action_dim: int, + device: torch.device, + ) -> None: + self.num_envs = num_envs + self.rollout_len = rollout_len + self.obs_dim = obs_dim + self.action_dim = action_dim + self.device = device + self._rollout = self._allocate_rollout() + self._is_full = False - def add(self, rollout: TensorDict) -> None: - """Store a single rollout with batch shape `[num_envs, time]`.""" - if self._rollout is not None: + def 
start_rollout(self) -> TensorDict: + """Return the shared rollout TensorDict for collector write-in.""" + if self._is_full: raise RuntimeError("RolloutBuffer already contains a rollout.") - self._rollout = rollout.clone() + self._clear_dynamic_fields() + return self._rollout + + def add(self, rollout: TensorDict) -> None: + """Mark the shared rollout as ready for consumption.""" + if rollout is not self._rollout: + raise ValueError( + "RolloutBuffer only accepts its shared rollout TensorDict." + ) + if tuple(rollout.batch_size) != (self.num_envs, self.rollout_len): + raise ValueError( + "Rollout batch size does not match buffer allocation: " + f"expected ({self.num_envs}, {self.rollout_len}), got {tuple(rollout.batch_size)}." + ) + self._is_full = True def get(self, flatten: bool = True) -> TensorDict: """Return the stored rollout and clear the buffer.""" - if self._rollout is None: + if not self._is_full: raise RuntimeError("RolloutBuffer is empty.") rollout = self._rollout - self._rollout = None + self._is_full = False if not flatten: return rollout @@ -51,4 +79,84 @@ def get(self, flatten: bool = True) -> TensorDict: def is_full(self) -> bool: """Return whether a rollout is waiting to be consumed.""" - return self._rollout is not None + return self._is_full + + def _allocate_rollout(self) -> TensorDict: + """Preallocate rollout storage with batch shape `[num_envs, time]`.""" + return TensorDict( + { + "observation": torch.empty( + self.num_envs, + self.rollout_len, + self.obs_dim, + dtype=torch.float32, + device=self.device, + ), + "action": torch.empty( + self.num_envs, + self.rollout_len, + self.action_dim, + dtype=torch.float32, + device=self.device, + ), + "sample_log_prob": torch.empty( + self.num_envs, + self.rollout_len, + dtype=torch.float32, + device=self.device, + ), + "value": torch.empty( + self.num_envs, + self.rollout_len, + dtype=torch.float32, + device=self.device, + ), + "next": { + "observation": torch.empty( + self.num_envs, + 
self.rollout_len, + self.obs_dim, + dtype=torch.float32, + device=self.device, + ), + "reward": torch.empty( + self.num_envs, + self.rollout_len, + dtype=torch.float32, + device=self.device, + ), + "done": torch.empty( + self.num_envs, + self.rollout_len, + dtype=torch.bool, + device=self.device, + ), + "terminated": torch.empty( + self.num_envs, + self.rollout_len, + dtype=torch.bool, + device=self.device, + ), + "truncated": torch.empty( + self.num_envs, + self.rollout_len, + dtype=torch.bool, + device=self.device, + ), + "value": torch.empty( + self.num_envs, + self.rollout_len, + dtype=torch.float32, + device=self.device, + ), + }, + }, + batch_size=[self.num_envs, self.rollout_len], + device=self.device, + ) + + def _clear_dynamic_fields(self) -> None: + """Drop algorithm-added fields before reusing the shared rollout.""" + for key in ("advantage", "return", "seq_mask", "seq_return", "entropy"): + if key in self._rollout.keys(): + del self._rollout[key] diff --git a/embodichain/agents/rl/collector/base.py b/embodichain/agents/rl/collector/base.py index 7c047aec..de3f2967 100644 --- a/embodichain/agents/rl/collector/base.py +++ b/embodichain/agents/rl/collector/base.py @@ -31,6 +31,7 @@ class BaseCollector(ABC): def collect( self, num_steps: int, + rollout: TensorDict | None = None, on_step_callback: Callable[[TensorDict, dict], None] | None = None, ) -> TensorDict: """Collect a rollout and return it as a TensorDict.""" diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index 5985897d..3cf41931 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -46,15 +46,26 @@ def __init__( def collect( self, num_steps: int, + rollout: TensorDict | None = None, on_step_callback: Callable[[TensorDict, dict], None] | None = None, ) -> TensorDict: self.policy.train() if self.reset_every_rollout: self.obs_td = self._reset_env() - rollout_steps: 
list[TensorDict] = [] + if rollout is None: + raise ValueError( + "SyncCollector.collect() requires a preallocated rollout TensorDict." + ) + if tuple(rollout.batch_size) != (self.env.num_envs, num_steps): + raise ValueError( + "Preallocated rollout batch size mismatch: " + f"expected ({self.env.num_envs}, {num_steps}), got {tuple(rollout.batch_size)}." + ) + if hasattr(self.env, "set_rollout_buffer"): + self.env.set_rollout_buffer(rollout) - for _ in range(num_steps): + for step_idx in range(num_steps): obs_tensor = flatten_dict_observation(self.obs_td) step_td = TensorDict( {"observation": obs_tensor}, @@ -62,33 +73,26 @@ def collect( device=self.device, ) self.policy.forward(step_td) + rollout["observation"][:, step_idx] = obs_tensor + rollout["action"][:, step_idx] = step_td["action"] + rollout["sample_log_prob"][:, step_idx] = step_td["sample_log_prob"] + rollout["value"][:, step_idx] = step_td["value"] next_obs, reward, terminated, truncated, env_info = self.env.step( self._to_action_dict(step_td["action"]) ) next_obs_td = dict_to_tensordict(next_obs, self.device) - next_obs_tensor = flatten_dict_observation(next_obs_td) - done = (terminated | truncated).bool() - - step_td["next"] = TensorDict( - { - "observation": next_obs_tensor, - "reward": reward.float(), - "done": done, - "terminated": terminated.bool(), - "truncated": truncated.bool(), - }, - batch_size=step_td.batch_size, - device=self.device, + self._write_step( + rollout=rollout, + step_idx=step_idx, + step_td=step_td, ) - rollout_steps.append(step_td.clone()) if on_step_callback is not None: - on_step_callback(step_td, env_info) + on_step_callback(rollout[:, step_idx], env_info) self.obs_td = next_obs_td - rollout = torch.stack(rollout_steps, dim=1) self._attach_next_values(rollout) return rollout @@ -116,3 +120,15 @@ def _to_action_dict(self, action: torch.Tensor) -> dict[str, torch.Tensor]: am.action_type if am else getattr(self.env, "action_type", "delta_qpos") ) return {action_type: action} + + 
def _write_step( + self, + rollout: TensorDict, + step_idx: int, + step_td: TensorDict, + ) -> None: + """Write policy-side fields for one transition into the shared rollout TensorDict.""" + rollout["observation"][:, step_idx] = step_td["observation"] + rollout["action"][:, step_idx] = step_td["action"] + rollout["sample_log_prob"][:, step_idx] = step_td["sample_log_prob"] + rollout["value"][:, step_idx] = step_td["value"] diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 54987967..98739af5 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -77,8 +77,21 @@ def __init__( self.start_time = time.time() self.ret_window = deque(maxlen=100) self.len_window = deque(maxlen=100) - - self.buffer = RolloutBuffer() + num_envs = getattr(self.env, "num_envs", None) + if num_envs is None: + raise RuntimeError("Env must expose num_envs for trainer statistics.") + obs_dim = getattr(self.policy, "obs_dim", None) + action_dim = getattr(self.policy, "action_dim", None) + if obs_dim is None or action_dim is None: + raise RuntimeError("Policy must expose obs_dim and action_dim.") + + self.buffer = RolloutBuffer( + num_envs=num_envs, + rollout_len=self.buffer_size, + obs_dim=obs_dim, + action_dim=action_dim, + device=self.device, + ) self.collector = SyncCollector( env=self.env, policy=self.policy, @@ -89,9 +102,6 @@ def __init__( ) ), ) - num_envs = getattr(self.env, "num_envs", None) - if num_envs is None: - raise RuntimeError("Env must expose num_envs for trainer statistics.") # episode stats tracked on device to avoid repeated CPU round-trips self.curr_ret = torch.zeros(num_envs, dtype=torch.float32, device=self.device) @@ -152,7 +162,7 @@ def on_step(tensordict: TensorDict, info: dict): """Callback called at each step during rollout collection.""" reward = tensordict["next", "reward"] done = tensordict["next", "done"] - # Episode stats (stay on device; convert only when episode ends) + 
# Episode stats self.curr_ret += reward self.curr_len += 1 done_idx = torch.nonzero(done, as_tuple=False).squeeze(-1) @@ -179,6 +189,7 @@ def on_step(tensordict: TensorDict, info: dict): rollout = self.collector.collect( num_steps=self.buffer_size, + rollout=self.buffer.start_rollout(), on_step_callback=on_step, ) self.buffer.add(rollout) diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py index 9042351e..ab945b45 100644 --- a/embodichain/lab/gym/envs/base_env.py +++ b/embodichain/lab/gym/envs/base_env.py @@ -645,6 +645,8 @@ def step( rewards=rewards, dones=dones, info=info, + terminateds=terminateds, + truncateds=truncateds, **kwargs, ) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 870a00ef..69b3160a 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -264,6 +264,7 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): # For example, a shared rollout buffer initialized in model training process and passed to the environment for data collection. 
self.rollout_buffer: TensorDict | None = None self._max_rollout_steps = 0 + self._rollout_buffer_mode: str | None = None if self.cfg.init_rollout_buffer: self.rollout_buffer = init_rollout_buffer_from_gym_space( obs_space=self.observation_space, @@ -273,6 +274,7 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): device=self.device, ) self._max_rollout_steps = self.rollout_buffer.shape[1] + self._rollout_buffer_mode = "episode" self.current_rollout_step = 0 @@ -295,6 +297,8 @@ def set_rollout_buffer(self, rollout_buffer: TensorDict) -> None: ) self.rollout_buffer = rollout_buffer self._max_rollout_steps = self.rollout_buffer.shape[1] + self.current_rollout_step = 0 + self._rollout_buffer_mode = self._infer_rollout_buffer_mode(rollout_buffer) def _init_sim_state(self, **kwargs): """Initialize the simulation state at the beginning of scene creation.""" @@ -416,31 +420,20 @@ def _hook_after_sim_step( if self.rollout_buffer is not None: buffer_device = self.rollout_buffer.device if self.current_rollout_step < self._max_rollout_steps: - # Extract data into episode buffer. - self.rollout_buffer["obs"][:, self.current_rollout_step, ...].copy_( - obs.to(buffer_device), non_blocking=True - ) - if isinstance(action, TensorDict): - action_to_store = ( - action["qpos"] - if "qpos" in action - else (action["qvel"] if "qvel" in action else action["qf"]) + if self._rollout_buffer_mode == "external_rl": + self._write_external_rl_rollout_step( + obs=obs, + rewards=rewards, + dones=dones, + terminateds=kwargs.get("terminateds"), + truncateds=kwargs.get("truncateds"), ) - elif isinstance(action, torch.Tensor): - action_to_store = action else: - logger.log_warning( - f"Unexpected action type {type(action)} in _hook_after_sim_step; " - "skipping action storage in rollout buffer." 
+ self._write_episode_rollout_step( + obs=obs, + action=action, + rewards=rewards, ) - action_to_store = None - if action_to_store is not None: - self.rollout_buffer["actions"][ - :, self.current_rollout_step, ... - ].copy_(action_to_store.to(buffer_device), non_blocking=True) - self.rollout_buffer["rewards"][:, self.current_rollout_step].copy_( - rewards.to(buffer_device), non_blocking=True - ) self.current_rollout_step += 1 else: logger.log_warning( @@ -514,7 +507,10 @@ def _initialize_episode( ) # Clear episode buffers and reset success status for environments being reset - if self.rollout_buffer is not None: + if ( + self.rollout_buffer is not None + and self._rollout_buffer_mode != "external_rl" + ): self.current_rollout_step = 0 self.episode_success_status[env_ids_to_process] = False @@ -528,6 +524,87 @@ def _initialize_episode( if self.cfg.rewards: self.reward_manager.reset(env_ids=env_ids) + def _infer_rollout_buffer_mode(self, rollout_buffer: TensorDict) -> str: + """Infer whether the rollout buffer is env-owned episode data or external RL data.""" + if "next" in rollout_buffer.keys() and "observation" in rollout_buffer.keys(): + return "external_rl" + return "episode" + + def _write_episode_rollout_step( + self, + obs: EnvObs, + action: EnvAction, + rewards: torch.Tensor, + ) -> None: + """Write one step into the legacy episode recording rollout buffer.""" + buffer_device = self.rollout_buffer.device + self.rollout_buffer["obs"][:, self.current_rollout_step, ...].copy_( + obs.to(buffer_device), non_blocking=True + ) + if isinstance(action, TensorDict): + action_to_store = ( + action["qpos"] + if "qpos" in action + else (action["qvel"] if "qvel" in action else action["qf"]) + ) + elif isinstance(action, torch.Tensor): + action_to_store = action + else: + logger.log_warning( + f"Unexpected action type {type(action)} in _hook_after_sim_step; " + "skipping action storage in rollout buffer." 
+ ) + action_to_store = None + if action_to_store is not None: + self.rollout_buffer["actions"][:, self.current_rollout_step, ...].copy_( + action_to_store.to(buffer_device), non_blocking=True + ) + self.rollout_buffer["rewards"][:, self.current_rollout_step].copy_( + rewards.to(buffer_device), non_blocking=True + ) + + def _write_external_rl_rollout_step( + self, + obs: EnvObs, + rewards: torch.Tensor, + dones: torch.Tensor, + terminateds: torch.Tensor | None, + truncateds: torch.Tensor | None, + ) -> None: + """Write environment-side fields into an externally managed RL rollout buffer.""" + from embodichain.agents.rl.utils import flatten_dict_observation + + buffer_device = self.rollout_buffer.device + obs_to_store = ( + flatten_dict_observation(obs) if isinstance(obs, TensorDict) else obs + ) + self.rollout_buffer["next", "observation"][:, self.current_rollout_step].copy_( + obs_to_store.to(buffer_device), non_blocking=True + ) + self.rollout_buffer["next", "reward"][:, self.current_rollout_step].copy_( + rewards.to(buffer_device), non_blocking=True + ) + self.rollout_buffer["next", "done"][:, self.current_rollout_step].copy_( + dones.to(buffer_device), non_blocking=True + ) + + terminateds = ( + terminateds + if terminateds is not None + else torch.zeros_like(dones, dtype=torch.bool) + ) + truncateds = ( + truncateds + if truncateds is not None + else torch.zeros_like(dones, dtype=torch.bool) + ) + self.rollout_buffer["next", "terminated"][:, self.current_rollout_step].copy_( + terminateds.to(buffer_device), non_blocking=True + ) + self.rollout_buffer["next", "truncated"][:, self.current_rollout_step].copy_( + truncateds.to(buffer_device), non_blocking=True + ) + def _step_action(self, action: EnvAction) -> EnvAction: """Set action control command into simulation. 
diff --git a/tests/agents/test_shared_rollout.py b/tests/agents/test_shared_rollout.py new file mode 100644 index 00000000..3ebf62cf --- /dev/null +++ b/tests/agents/test_shared_rollout.py @@ -0,0 +1,225 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from copy import deepcopy + +import torch +from tensordict import TensorDict + +from embodichain.agents.rl.buffer import RolloutBuffer +from embodichain.agents.rl.collector import SyncCollector +from embodichain.agents.rl.utils import flatten_dict_observation +from embodichain.lab.gym.envs.tasks.rl import build_env +from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES +from embodichain.lab.sim import SimulationManagerCfg, SimulationManager +from embodichain.utils.utility import load_json + + +class _FakePolicy: + def __init__(self, obs_dim: int, action_dim: int, device: torch.device) -> None: + self.obs_dim = obs_dim + self.action_dim = action_dim + self.device = device + + def train(self) -> None: + pass + + def forward(self, tensordict: TensorDict) -> TensorDict: + obs = tensordict["observation"] + tensordict["action"] = obs[:, : self.action_dim] * 0.25 + tensordict["sample_log_prob"] = obs.sum(dim=-1) * 0.1 + tensordict["value"] 
= obs.mean(dim=-1) + return tensordict + + def get_value(self, tensordict: TensorDict) -> TensorDict: + tensordict["value"] = tensordict["observation"].mean(dim=-1) + return tensordict + + +class _FakeEnv: + def __init__(self, num_envs: int, obs_dim: int, action_dim: int, device: torch.device): + self.num_envs = num_envs + self.obs_dim = obs_dim + self.action_dim = action_dim + self.device = device + self.action_type = "delta_qpos" + self.rollout_buffer: TensorDict | None = None + self.current_rollout_step = 0 + self._obs = self._make_obs(step=0) + + def reset(self, **kwargs): + self.current_rollout_step = 0 + self._obs = self._make_obs(step=0) + return self._obs, {} + + def set_rollout_buffer(self, rollout_buffer: TensorDict) -> None: + self.rollout_buffer = rollout_buffer + self.current_rollout_step = 0 + + def step(self, action_dict): + action = action_dict[self.action_type] + step_idx = self.current_rollout_step + 1 + next_obs = self._make_obs(step=step_idx) + reward = action.sum(dim=-1) + terminated = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device) + truncated = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device) + + if self.rollout_buffer is not None: + self.rollout_buffer["next", "observation"][:, self.current_rollout_step] = ( + flatten_dict_observation(next_obs) + ) + self.rollout_buffer["next", "reward"][:, self.current_rollout_step] = reward + self.rollout_buffer["next", "done"][:, self.current_rollout_step] = ( + terminated | truncated + ) + self.rollout_buffer["next", "terminated"][:, self.current_rollout_step] = ( + terminated + ) + self.rollout_buffer["next", "truncated"][:, self.current_rollout_step] = ( + truncated + ) + self.current_rollout_step += 1 + + self._obs = next_obs + return next_obs, reward, terminated, truncated, {} + + def _make_obs(self, step: int) -> TensorDict: + base = torch.full( + (self.num_envs, self.obs_dim), + fill_value=float(step), + dtype=torch.float32, + device=self.device, + ) + return 
TensorDict( + { + "agent": TensorDict( + {"state": base}, + batch_size=[self.num_envs], + device=self.device, + ) + }, + batch_size=[self.num_envs], + device=self.device, + ) + + +def test_shared_rollout_collects_policy_and_env_fields(): + device = torch.device("cpu") + num_envs = 3 + rollout_len = 4 + obs_dim = 5 + action_dim = 2 + + env = _FakeEnv( + num_envs=num_envs, + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + ) + policy = _FakePolicy(obs_dim=obs_dim, action_dim=action_dim, device=device) + collector = SyncCollector(env=env, policy=policy, device=device) + buffer = RolloutBuffer( + num_envs=num_envs, + rollout_len=rollout_len, + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + ) + + rollout = collector.collect( + num_steps=rollout_len, + rollout=buffer.start_rollout(), + ) + buffer.add(rollout) + stored = buffer.get(flatten=False) + + assert stored.batch_size == torch.Size([num_envs, rollout_len]) + assert torch.allclose(stored["observation"][:, 0], torch.zeros(num_envs, obs_dim)) + assert torch.allclose( + stored["next", "observation"][:, 0], + torch.ones(num_envs, obs_dim), + ) + assert torch.allclose( + stored["action"][:, 0], + torch.zeros(num_envs, action_dim), + ) + assert torch.allclose( + stored["sample_log_prob"][:, 1], + torch.full((num_envs,), 0.5, dtype=torch.float32), + ) + assert torch.allclose( + stored["next", "reward"][:, 2], + torch.full((num_envs,), 1.0, dtype=torch.float32), + ) + assert torch.allclose( + stored["next", "value"][:, -1], + torch.full((num_envs,), 4.0, dtype=torch.float32), + ) + + +def test_embodied_env_writes_next_fields_into_external_rollout(): + gym_config = load_json("configs/agents/rl/basic/cart_pole/gym_config.json") + env_cfg = config_to_cfg(gym_config, manager_modules=DEFAULT_MANAGER_MODULES) + env_cfg = deepcopy(env_cfg) + env_cfg.num_envs = 2 + env_cfg.sim_cfg = SimulationManagerCfg( + headless=True, + sim_device=torch.device("cpu"), + enable_rt=False, + gpu_id=0, + ) + + env = 
build_env(gym_config["id"], base_env_cfg=env_cfg) + try: + obs, _ = env.reset() + obs_dim = flatten_dict_observation(obs).shape[-1] + action_dim = env.action_space.shape[-1] + buffer = RolloutBuffer( + num_envs=env.num_envs, + rollout_len=4, + obs_dim=obs_dim, + action_dim=action_dim, + device=torch.device("cpu"), + ) + rollout = buffer.start_rollout() + env.set_rollout_buffer(rollout) + + action = torch.zeros( + env.num_envs, + action_dim, + dtype=torch.float32, + device=env.device, + ) + next_obs, reward, terminated, truncated, _ = env.step({"delta_qpos": action}) + next_obs_flat = flatten_dict_observation(next_obs).cpu() + done = (terminated | truncated).cpu() + + assert env.current_rollout_step == 1 + assert torch.allclose(rollout["next", "observation"][:, 0].cpu(), next_obs_flat) + assert torch.allclose(rollout["next", "reward"][:, 0].cpu(), reward.cpu()) + assert torch.equal(rollout["next", "done"][:, 0].cpu(), done) + assert torch.equal( + rollout["next", "terminated"][:, 0].cpu(), terminated.cpu() + ) + assert torch.equal( + rollout["next", "truncated"][:, 0].cpu(), truncated.cpu() + ) + finally: + env.close() + if SimulationManager.is_instantiated(): + SimulationManager.get_instance().destroy() From e22be664ea921601984136852a4d4cc61e9ee0f4 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 10:37:05 +0000 Subject: [PATCH 03/23] Reformat files --- tests/agents/test_shared_rollout.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/tests/agents/test_shared_rollout.py b/tests/agents/test_shared_rollout.py index 3ebf62cf..4f1d506a 100644 --- a/tests/agents/test_shared_rollout.py +++ b/tests/agents/test_shared_rollout.py @@ -52,7 +52,9 @@ def get_value(self, tensordict: TensorDict) -> TensorDict: class _FakeEnv: - def __init__(self, num_envs: int, obs_dim: int, action_dim: int,
device: torch.device + ): self.num_envs = num_envs self.obs_dim = obs_dim self.action_dim = action_dim @@ -87,12 +89,12 @@ def step(self, action_dict): self.rollout_buffer["next", "done"][:, self.current_rollout_step] = ( terminated | truncated ) - self.rollout_buffer["next", "terminated"][:, self.current_rollout_step] = ( - terminated - ) - self.rollout_buffer["next", "truncated"][:, self.current_rollout_step] = ( - truncated - ) + self.rollout_buffer["next", "terminated"][ + :, self.current_rollout_step + ] = terminated + self.rollout_buffer["next", "truncated"][ + :, self.current_rollout_step + ] = truncated self.current_rollout_step += 1 self._obs = next_obs @@ -213,12 +215,8 @@ def test_embodied_env_writes_next_fields_into_external_rollout(): assert torch.allclose(rollout["next", "observation"][:, 0].cpu(), next_obs_flat) assert torch.allclose(rollout["next", "reward"][:, 0].cpu(), reward.cpu()) assert torch.equal(rollout["next", "done"][:, 0].cpu(), done) - assert torch.equal( - rollout["next", "terminated"][:, 0].cpu(), terminated.cpu() - ) - assert torch.equal( - rollout["next", "truncated"][:, 0].cpu(), truncated.cpu() - ) + assert torch.equal(rollout["next", "terminated"][:, 0].cpu(), terminated.cpu()) + assert torch.equal(rollout["next", "truncated"][:, 0].cpu(), truncated.cpu()) finally: env.close() if SimulationManager.is_instantiated(): From 61d5910da6b14a783de600bd529e51d876282632 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 11:07:45 +0000 Subject: [PATCH 04/23] Update --- embodichain/agents/rl/algo/grpo.py | 10 +++--- embodichain/agents/rl/algo/ppo.py | 10 +++--- .../agents/rl/collector/sync_collector.py | 35 ++++++++++++++++--- embodichain/agents/rl/models/policy.py | 13 +++---- embodichain/lab/gym/envs/embodied_env.py | 1 - 5 files changed, 45 insertions(+), 24 deletions(-) diff --git a/embodichain/agents/rl/algo/grpo.py b/embodichain/agents/rl/algo/grpo.py index 4cef287c..79d95f56 100644 --- 
a/embodichain/agents/rl/algo/grpo.py +++ b/embodichain/agents/rl/algo/grpo.py @@ -18,7 +18,7 @@ import math from copy import deepcopy -from typing import Dict +from typing import Dict, Iterator import torch from tensordict import TensorDict @@ -197,10 +197,8 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: def _iterate_minibatches( self, rollout: TensorDict, batch_size: int - ) -> list[TensorDict]: + ) -> Iterator[TensorDict]: total = rollout.batch_size[0] indices = torch.randperm(total, device=self.device) - return [ - rollout[indices[start : start + batch_size]] - for start in range(0, total, batch_size) - ] + for start in range(0, total, batch_size): + yield rollout[indices[start : start + batch_size]] diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index b8d787b6..199ee048 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -15,7 +15,7 @@ # ---------------------------------------------------------------------------- import math -from typing import Dict +from typing import Dict, Iterator import torch from tensordict import TensorDict @@ -109,10 +109,8 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: def _iterate_minibatches( self, rollout: TensorDict, batch_size: int - ) -> list[TensorDict]: + ) -> Iterator[TensorDict]: total = rollout.batch_size[0] indices = torch.randperm(total, device=self.device) - return [ - rollout[indices[start : start + batch_size]] - for start in range(0, total, batch_size) - ] + for start in range(0, total, batch_size): + yield rollout[indices[start : start + batch_size]] diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index 3cf41931..1eb85466 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -41,6 +41,7 @@ def __init__( self.policy = policy self.device = device self.reset_every_rollout = 
reset_every_rollout + self._supports_shared_rollout = hasattr(self.env, "set_rollout_buffer") self.obs_td = self._reset_env() def collect( @@ -62,7 +63,7 @@ def collect( "Preallocated rollout batch size mismatch: " f"expected ({self.env.num_envs}, {num_steps}), got {tuple(rollout.batch_size)}." ) - if hasattr(self.env, "set_rollout_buffer"): + if self._supports_shared_rollout: self.env.set_rollout_buffer(rollout) for step_idx in range(num_steps): @@ -73,10 +74,6 @@ def collect( device=self.device, ) self.policy.forward(step_td) - rollout["observation"][:, step_idx] = obs_tensor - rollout["action"][:, step_idx] = step_td["action"] - rollout["sample_log_prob"][:, step_idx] = step_td["sample_log_prob"] - rollout["value"][:, step_idx] = step_td["value"] next_obs, reward, terminated, truncated, env_info = self.env.step( self._to_action_dict(step_td["action"]) @@ -87,6 +84,15 @@ def collect( step_idx=step_idx, step_td=step_td, ) + if not self._supports_shared_rollout: + self._write_env_step( + rollout=rollout, + step_idx=step_idx, + next_obs_td=next_obs_td, + reward=reward, + terminated=terminated, + truncated=truncated, + ) if on_step_callback is not None: on_step_callback(rollout[:, step_idx], env_info) @@ -132,3 +138,22 @@ def _write_step( rollout["action"][:, step_idx] = step_td["action"] rollout["sample_log_prob"][:, step_idx] = step_td["sample_log_prob"] rollout["value"][:, step_idx] = step_td["value"] + + def _write_env_step( + self, + rollout: TensorDict, + step_idx: int, + next_obs_td: TensorDict, + reward: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + ) -> None: + """Populate transition-side fields when the environment does not own the rollout.""" + done = terminated | truncated + rollout["next", "observation"][:, step_idx] = flatten_dict_observation( + next_obs_td + ) + rollout["next", "reward"][:, step_idx] = reward.to(self.device) + rollout["next", "done"][:, step_idx] = done.to(self.device) + rollout["next", "terminated"][:, 
step_idx] = terminated.to(self.device) + rollout["next", "truncated"][:, step_idx] = truncated.to(self.device) diff --git a/embodichain/agents/rl/models/policy.py b/embodichain/agents/rl/models/policy.py index 0e8faac3..6938268b 100644 --- a/embodichain/agents/rl/models/policy.py +++ b/embodichain/agents/rl/models/policy.py @@ -60,12 +60,13 @@ def get_action( - log_prob: Log probability of the action, shape (batch_size,) - value: Value estimate, shape (batch_size,) """ - td = TensorDict( - {"observation": obs}, - batch_size=[obs.shape[0]], - device=obs.device, - ) - td = self.forward(td, deterministic=deterministic) + with torch.no_grad(): + td = TensorDict( + {"observation": obs}, + batch_size=[obs.shape[0]], + device=obs.device, + ) + td = self.forward(td, deterministic=deterministic) return td["action"], td["sample_log_prob"], td["value"] @abstractmethod diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 69b3160a..ccd3bacd 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -418,7 +418,6 @@ def _hook_after_sim_step( ): # TODO: We may make the data collection customizable for rollout buffer. 
if self.rollout_buffer is not None: - buffer_device = self.rollout_buffer.device if self.current_rollout_step < self._max_rollout_steps: if self._rollout_buffer_mode == "external_rl": self._write_external_rl_rollout_step( From 5d9cb80f15fb8c732cc7a9415cc54b622990edb5 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 11:21:10 +0000 Subject: [PATCH 05/23] Update --- embodichain/agents/rl/collector/sync_collector.py | 1 + embodichain/lab/gym/envs/embodied_env.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index 1eb85466..9554a2d6 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -44,6 +44,7 @@ def __init__( self._supports_shared_rollout = hasattr(self.env, "set_rollout_buffer") self.obs_td = self._reset_env() + @torch.no_grad() def collect( self, num_steps: int, diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index ccd3bacd..ad97e512 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -52,6 +52,7 @@ from embodichain.lab.gym.utils.gym_utils import ( init_rollout_buffer_from_gym_space, ) +from embodichain.agents.rl.utils import flatten_dict_observation from embodichain.utils import configclass, logger @@ -571,8 +572,6 @@ def _write_external_rl_rollout_step( truncateds: torch.Tensor | None, ) -> None: """Write environment-side fields into an externally managed RL rollout buffer.""" - from embodichain.agents.rl.utils import flatten_dict_observation - buffer_device = self.rollout_buffer.device obs_to_store = ( flatten_dict_observation(obs) if isinstance(obs, TensorDict) else obs From 86d42d2156bc42cefdb5e7bc7ce3fb0029fb7c69 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 14:11:58 +0000 
Subject: [PATCH 06/23] Move compute_gae from rl.utils into rl.algo.common --- embodichain/agents/rl/algo/common.py | 72 +++++++++++++++++++++++++++ embodichain/agents/rl/algo/ppo.py | 3 +- embodichain/agents/rl/utils/helper.py | 51 ------------------- 3 files changed, 74 insertions(+), 52 deletions(-) create mode 100644 embodichain/agents/rl/algo/common.py diff --git a/embodichain/agents/rl/algo/common.py b/embodichain/agents/rl/algo/common.py new file mode 100644 index 00000000..eb3c5d17 --- /dev/null +++ b/embodichain/agents/rl/algo/common.py @@ -0,0 +1,72 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch +from tensordict import TensorDict + +__all__ = ["compute_gae"] + + +def compute_gae( + rollout: TensorDict, gamma: float, gae_lambda: float +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute GAE over a rollout with batch shape `[num_envs, time]`. + + Args: + rollout: Rollout TensorDict containing `value` and `next` transition data. + gamma: Discount factor. + gae_lambda: GAE lambda coefficient. + + Returns: + Tuple of `(advantages, returns)`, both shaped `[num_envs, time]`. 
+ """ + rewards = rollout["next", "reward"].float() + dones = rollout["next", "done"].bool() + values = rollout["value"].float() + + if rewards.ndim != 2: + raise ValueError( + f"Expected reward tensor with shape [num_envs, time], got {rewards.shape}." + ) + + next_values = _get_next_values(rollout, values) + num_envs, time_dim = rewards.shape + advantages = torch.zeros_like(rewards) + last_advantage = torch.zeros(num_envs, device=rewards.device, dtype=rewards.dtype) + + for t in reversed(range(time_dim)): + not_done = (~dones[:, t]).float() + delta = rewards[:, t] + gamma * next_values[:, t] * not_done - values[:, t] + last_advantage = delta + gamma * gae_lambda * not_done * last_advantage + advantages[:, t] = last_advantage + + returns = advantages + values + rollout["advantage"] = advantages + rollout["return"] = returns + return advantages, returns + + +def _get_next_values(rollout: TensorDict, values: torch.Tensor) -> torch.Tensor: + """Resolve next-step values for GAE bootstrap.""" + next_value = rollout.get(("next", "value"), None) + if next_value is not None: + return next_value.float() + + next_values = torch.zeros_like(values) + next_values[:, :-1] = values[:, 1:] + return next_values diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index 199ee048..f360f74b 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -20,8 +20,9 @@ import torch from tensordict import TensorDict -from embodichain.agents.rl.utils import AlgorithmCfg, compute_gae +from embodichain.agents.rl.utils import AlgorithmCfg from embodichain.utils import configclass +from .common import compute_gae from .base import BaseAlgorithm diff --git a/embodichain/agents/rl/utils/helper.py b/embodichain/agents/rl/utils/helper.py index aa38d6d8..e443a649 100644 --- a/embodichain/agents/rl/utils/helper.py +++ b/embodichain/agents/rl/utils/helper.py @@ -23,7 +23,6 @@ from tensordict import TensorDict __all__ = [ - "compute_gae", 
"dict_to_tensordict", "flatten_dict_observation", ] @@ -75,53 +74,3 @@ def dict_to_tensordict( f"Expected observation mapping or TensorDict, got {type(obs_dict)!r}." ) return TensorDict.from_dict(dict(obs_dict), device=device) - - -def compute_gae( - rollout: TensorDict, gamma: float, gae_lambda: float -) -> tuple[torch.Tensor, torch.Tensor]: - """Compute GAE over a rollout with batch shape `[num_envs, time]`. - - Args: - rollout: Rollout TensorDict containing `value` and `next` transition data. - gamma: Discount factor. - gae_lambda: GAE lambda coefficient. - - Returns: - Tuple of `(advantages, returns)`, both shaped `[num_envs, time]`. - """ - rewards = rollout["next", "reward"].float() - dones = rollout["next", "done"].bool() - values = rollout["value"].float() - - if rewards.ndim != 2: - raise ValueError( - f"Expected reward tensor with shape [num_envs, time], got {rewards.shape}." - ) - - next_values = _get_next_values(rollout, values) - num_envs, time_dim = rewards.shape - advantages = torch.zeros_like(rewards) - last_advantage = torch.zeros(num_envs, device=rewards.device, dtype=rewards.dtype) - - for t in reversed(range(time_dim)): - not_done = (~dones[:, t]).float() - delta = rewards[:, t] + gamma * next_values[:, t] * not_done - values[:, t] - last_advantage = delta + gamma * gae_lambda * not_done * last_advantage - advantages[:, t] = last_advantage - - returns = advantages + values - rollout["advantage"] = advantages - rollout["return"] = returns - return advantages, returns - - -def _get_next_values(rollout: TensorDict, values: torch.Tensor) -> torch.Tensor: - """Resolve next-step values for GAE bootstrap.""" - next_value = rollout.get(("next", "value"), None) - if next_value is not None: - return next_value.float() - - next_values = torch.zeros_like(values) - next_values[:, :-1] = values[:, 1:] - return next_values From 32d0b0971f5c478fafe022ba2afa0bbe887d7435 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 
14:15:41 +0000 Subject: [PATCH 07/23] Remove rollout_buffer.py --- .../agents/rl/buffer/rollout_buffer.py | 21 ------------------- 1 file changed, 21 deletions(-) delete mode 100644 embodichain/agents/rl/buffer/rollout_buffer.py diff --git a/embodichain/agents/rl/buffer/rollout_buffer.py b/embodichain/agents/rl/buffer/rollout_buffer.py deleted file mode 100644 index e2854261..00000000 --- a/embodichain/agents/rl/buffer/rollout_buffer.py +++ /dev/null @@ -1,21 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ---------------------------------------------------------------------------- - -from __future__ import annotations - -from .standard_buffer import RolloutBuffer - -__all__ = ["RolloutBuffer"] From 76c7569b7ed6224ed5dfa67f42be5e92a1da0a2a Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 14:17:21 +0000 Subject: [PATCH 08/23] Get dimension info from env directly --- .../rl/basic/cart_pole/train_config.json | 1 - .../rl/basic/cart_pole/train_config_grpo.json | 1 - configs/agents/rl/push_cube/train_config.json | 1 - embodichain/agents/rl/train.py | 18 +++++++++--------- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/configs/agents/rl/basic/cart_pole/train_config.json b/configs/agents/rl/basic/cart_pole/train_config.json index f4e99372..02a302d1 100644 --- a/configs/agents/rl/basic/cart_pole/train_config.json +++ b/configs/agents/rl/basic/cart_pole/train_config.json @@ -35,7 +35,6 @@ }, "policy": { "name": "actor_critic", - "action_dim": 2, "actor": { "type": "mlp", "network_cfg": { diff --git a/configs/agents/rl/basic/cart_pole/train_config_grpo.json b/configs/agents/rl/basic/cart_pole/train_config_grpo.json index 1caf6b0d..4da5cab7 100644 --- a/configs/agents/rl/basic/cart_pole/train_config_grpo.json +++ b/configs/agents/rl/basic/cart_pole/train_config_grpo.json @@ -36,7 +36,6 @@ }, "policy": { "name": "actor_only", - "action_dim": 2, "actor": { "type": "mlp", "network_cfg": { diff --git a/configs/agents/rl/push_cube/train_config.json b/configs/agents/rl/push_cube/train_config.json index c79776b6..d44aa0b3 100644 --- a/configs/agents/rl/push_cube/train_config.json +++ b/configs/agents/rl/push_cube/train_config.json @@ -38,7 +38,6 @@ }, "policy": { "name": "actor_critic", - "action_dim": 8, "actor": { "type": "mlp", "network_cfg": { diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py index 62418ad0..a11192f1 100644 --- a/embodichain/agents/rl/train.py +++ 
b/embodichain/agents/rl/train.py @@ -184,11 +184,9 @@ def train_from_config(config_path: str): # Build Policy via registry policy_name = policy_block["name"] - action_dim = policy_block.get("action_dim") - if action_dim is None: - raise ValueError("Policy config must define 'action_dim'.") - action_dim = int(action_dim) env_action_dim = env.action_space.shape[-1] + action_dim = policy_block.get("action_dim", env_action_dim) + action_dim = int(action_dim) if action_dim != env_action_dim: raise ValueError( f"Configured policy.action_dim={action_dim} does not match env action dim {env_action_dim}." @@ -207,8 +205,8 @@ def train_from_config(config_path: str): policy = build_policy( policy_block, - obs_dim, - action_dim, + env.observation_space, + env.action_space, device, actor=actor, critic=critic, @@ -224,13 +222,15 @@ def train_from_config(config_path: str): policy = build_policy( policy_block, - obs_dim, - action_dim, + env.observation_space, + env.action_space, device, actor=actor, ) else: - policy = build_policy(policy_block, obs_dim, action_dim, device) + policy = build_policy( + policy_block, env.observation_space, env.action_space, device + ) # Build Algorithm via factory algo_name = algo_block["name"].lower() From 73239b614f4d0e3717fe9854160bf0c9cba27643 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 14:56:33 +0000 Subject: [PATCH 09/23] Move compute_gae from rl.utils into rl.algo.common --- embodichain/agents/rl/algo/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/embodichain/agents/rl/algo/__init__.py b/embodichain/agents/rl/algo/__init__.py index 4b69879a..0aefde03 100644 --- a/embodichain/agents/rl/algo/__init__.py +++ b/embodichain/agents/rl/algo/__init__.py @@ -20,6 +20,7 @@ import torch from .base import BaseAlgorithm +from .common import compute_gae from .ppo import PPOCfg, PPO from .grpo import GRPOCfg, GRPO @@ -51,6 +52,7 @@ def build_algo(name: str, cfg_kwargs: Dict[str, float], policy, 
device: torch.de "PPO", "GRPOCfg", "GRPO", + "compute_gae", "get_registered_algo_names", "build_algo", ] From f3b666dd95f11d92328a89ff37cba2173557a337 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 14:58:57 +0000 Subject: [PATCH 10/23] Get info from env --- embodichain/agents/rl/models/__init__.py | 49 +++++++++++++++++++++--- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/embodichain/agents/rl/models/__init__.py b/embodichain/agents/rl/models/__init__.py index ccbe7d92..51cf7653 100644 --- a/embodichain/agents/rl/models/__init__.py +++ b/embodichain/agents/rl/models/__init__.py @@ -16,7 +16,10 @@ from __future__ import annotations +import inspect from typing import Dict, Type + +from gymnasium import spaces import torch from .actor_critic import ActorCritic @@ -42,15 +45,30 @@ def get_policy_class(name: str) -> Type[Policy] | None: return _POLICY_REGISTRY.get(name) +def _resolve_space_dim(space_or_dim: spaces.Space | int, name: str) -> int: + """Resolve a flattened feature dimension from an integer or simple Box space.""" + if isinstance(space_or_dim, int): + return space_or_dim + if isinstance(space_or_dim, spaces.Box) and len(space_or_dim.shape) > 0: + return int(space_or_dim.shape[-1]) + raise TypeError( + f"{name} must be an int or a flat Box space for MLP-based policies, got {type(space_or_dim)!r}." + ) + + def build_policy( policy_block: dict, - obs_dim: int, - action_dim: int, + obs_space: spaces.Space | int, + action_space: spaces.Space | int, device: torch.device, actor: torch.nn.Module | None = None, critic: torch.nn.Module | None = None, ) -> Policy: - """Build policy strictly from json-like block: { name: ..., cfg: {...} }""" + """Build a policy from config using spaces for extensibility. + + Built-in MLP policies still resolve flattened `obs_dim` / `action_dim`, while + custom policies may accept richer `obs_space` / `action_space` inputs. 
+ """ name = policy_block["name"].lower() if name not in _POLICY_REGISTRY: available = ", ".join(get_registered_policy_names()) @@ -58,11 +76,14 @@ def build_policy( f"Policy '{name}' is not registered. Available policies: {available}" ) policy_cls = _POLICY_REGISTRY[name] + if name == "actor_critic": if actor is None or critic is None: raise ValueError( "ActorCritic policy requires external 'actor' and 'critic' modules." ) + obs_dim = _resolve_space_dim(obs_space, "obs_space") + action_dim = _resolve_space_dim(action_space, "action_space") return policy_cls( obs_dim=obs_dim, action_dim=action_dim, @@ -73,14 +94,32 @@ def build_policy( elif name == "actor_only": if actor is None: raise ValueError("ActorOnly policy requires external 'actor' module.") + obs_dim = _resolve_space_dim(obs_space, "obs_space") + action_dim = _resolve_space_dim(action_space, "action_space") return policy_cls( obs_dim=obs_dim, action_dim=action_dim, device=device, actor=actor, ) - else: - return policy_cls(obs_dim=obs_dim, action_dim=action_dim, device=device) + + init_params = inspect.signature(policy_cls.__init__).parameters + build_kwargs: dict[str, object] = {"device": device} + if "obs_space" in init_params: + build_kwargs["obs_space"] = obs_space + elif "obs_dim" in init_params: + build_kwargs["obs_dim"] = _resolve_space_dim(obs_space, "obs_space") + + if "action_space" in init_params: + build_kwargs["action_space"] = action_space + elif "action_dim" in init_params: + build_kwargs["action_dim"] = _resolve_space_dim(action_space, "action_space") + + if "actor" in init_params and actor is not None: + build_kwargs["actor"] = actor + if "critic" in init_params and critic is not None: + build_kwargs["critic"] = critic + return policy_cls(**build_kwargs) def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP: From ab41911155667bcabd3eb2026b7d6ffbcb0c8e7b Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 15:04:23 +0000 Subject: 
[PATCH 11/23] Move compute_gae from rl.utils into rl.algo.common --- embodichain/agents/rl/utils/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/embodichain/agents/rl/utils/__init__.py b/embodichain/agents/rl/utils/__init__.py index bd6dbc4f..30f59136 100644 --- a/embodichain/agents/rl/utils/__init__.py +++ b/embodichain/agents/rl/utils/__init__.py @@ -15,11 +15,10 @@ # ---------------------------------------------------------------------------- from .config import AlgorithmCfg -from .helper import compute_gae, dict_to_tensordict, flatten_dict_observation +from .helper import dict_to_tensordict, flatten_dict_observation __all__ = [ "AlgorithmCfg", - "compute_gae", "dict_to_tensordict", "flatten_dict_observation", ] From 5e815c5391117be10326bdc28f4138e271f5c6dc Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 15:08:03 +0000 Subject: [PATCH 12/23] Use expert and rl for variable --- embodichain/lab/gym/envs/embodied_env.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index ad97e512..8e1e0e64 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -275,7 +275,7 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): device=self.device, ) self._max_rollout_steps = self.rollout_buffer.shape[1] - self._rollout_buffer_mode = "episode" + self._rollout_buffer_mode = "expert" self.current_rollout_step = 0 @@ -420,8 +420,8 @@ def _hook_after_sim_step( # TODO: We may make the data collection customizable for rollout buffer. 
if self.rollout_buffer is not None: if self.current_rollout_step < self._max_rollout_steps: - if self._rollout_buffer_mode == "external_rl": - self._write_external_rl_rollout_step( + if self._rollout_buffer_mode == "rl": + self._write_rl_rollout_step( obs=obs, rewards=rewards, dones=dones, @@ -507,10 +507,7 @@ def _initialize_episode( ) # Clear episode buffers and reset success status for environments being reset - if ( - self.rollout_buffer is not None - and self._rollout_buffer_mode != "external_rl" - ): + if self.rollout_buffer is not None and self._rollout_buffer_mode != "rl": self.current_rollout_step = 0 self.episode_success_status[env_ids_to_process] = False @@ -525,10 +522,10 @@ def _initialize_episode( self.reward_manager.reset(env_ids=env_ids) def _infer_rollout_buffer_mode(self, rollout_buffer: TensorDict) -> str: - """Infer whether the rollout buffer is env-owned episode data or external RL data.""" + """Infer whether the rollout buffer is expert recording or RL training data.""" if "next" in rollout_buffer.keys() and "observation" in rollout_buffer.keys(): - return "external_rl" - return "episode" + return "rl" + return "expert" def _write_episode_rollout_step( self, @@ -563,7 +560,7 @@ def _write_episode_rollout_step( rewards.to(buffer_device), non_blocking=True ) - def _write_external_rl_rollout_step( + def _write_rl_rollout_step( self, obs: EnvObs, rewards: torch.Tensor, @@ -585,7 +582,6 @@ def _write_external_rl_rollout_step( self.rollout_buffer["next", "done"][:, self.current_rollout_step].copy_( dones.to(buffer_device), non_blocking=True ) - terminateds = ( terminateds if terminateds is not None From 4a661b8c76499162a228a083229ef4a3c745682d Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 15:17:27 +0000 Subject: [PATCH 13/23] Change observation to obs --- embodichain/agents/rl/buffer/standard_buffer.py | 4 ++-- embodichain/agents/rl/collector/sync_collector.py | 10 ++++------ 
embodichain/agents/rl/models/actor_critic.py | 6 +++--- embodichain/agents/rl/models/actor_only.py | 6 +++--- embodichain/agents/rl/models/policy.py | 6 +++--- embodichain/lab/gym/envs/embodied_env.py | 4 ++-- tests/agents/test_shared_rollout.py | 12 ++++++------ 7 files changed, 23 insertions(+), 25 deletions(-) diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py index 9196b893..ec706c14 100644 --- a/embodichain/agents/rl/buffer/standard_buffer.py +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -85,7 +85,7 @@ def _allocate_rollout(self) -> TensorDict: """Preallocate rollout storage with batch shape `[num_envs, time]`.""" return TensorDict( { - "observation": torch.empty( + "obs": torch.empty( self.num_envs, self.rollout_len, self.obs_dim, @@ -112,7 +112,7 @@ def _allocate_rollout(self) -> TensorDict: device=self.device, ), "next": { - "observation": torch.empty( + "obs": torch.empty( self.num_envs, self.rollout_len, self.obs_dim, diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index 9554a2d6..0adcfb8d 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -70,7 +70,7 @@ def collect( for step_idx in range(num_steps): obs_tensor = flatten_dict_observation(self.obs_td) step_td = TensorDict( - {"observation": obs_tensor}, + {"obs": obs_tensor}, batch_size=[obs_tensor.shape[0]], device=self.device, ) @@ -109,7 +109,7 @@ def _attach_next_values(self, rollout: TensorDict) -> None: next_values[:, :-1] = rollout["value"][:, 1:] last_next_td = TensorDict( - {"observation": rollout["next", "observation"][:, -1]}, + {"obs": rollout["next", "obs"][:, -1]}, batch_size=[rollout.batch_size[0]], device=self.device, ) @@ -135,7 +135,7 @@ def _write_step( step_td: TensorDict, ) -> None: """Write policy-side fields for one transition into the shared rollout TensorDict.""" - 
rollout["observation"][:, step_idx] = step_td["observation"] + rollout["obs"][:, step_idx] = step_td["obs"] rollout["action"][:, step_idx] = step_td["action"] rollout["sample_log_prob"][:, step_idx] = step_td["sample_log_prob"] rollout["value"][:, step_idx] = step_td["value"] @@ -151,9 +151,7 @@ def _write_env_step( ) -> None: """Populate transition-side fields when the environment does not own the rollout.""" done = terminated | truncated - rollout["next", "observation"][:, step_idx] = flatten_dict_observation( - next_obs_td - ) + rollout["next", "obs"][:, step_idx] = flatten_dict_observation(next_obs_td) rollout["next", "reward"][:, step_idx] = reward.to(self.device) rollout["next", "done"][:, step_idx] = done.to(self.device) rollout["next", "terminated"][:, step_idx] = terminated.to(self.device) diff --git a/embodichain/agents/rl/models/actor_critic.py b/embodichain/agents/rl/models/actor_critic.py index f2002e04..7ee73927 100644 --- a/embodichain/agents/rl/models/actor_critic.py +++ b/embodichain/agents/rl/models/actor_critic.py @@ -73,7 +73,7 @@ def _distribution(self, obs: torch.Tensor) -> Normal: def forward( self, tensordict: TensorDict, deterministic: bool = False ) -> TensorDict: - obs = tensordict["observation"] + obs = tensordict["obs"] dist = self._distribution(obs) mean = dist.mean action = mean if deterministic else dist.sample() @@ -83,11 +83,11 @@ def forward( return tensordict def get_value(self, tensordict: TensorDict) -> TensorDict: - tensordict["value"] = self.critic(tensordict["observation"]).squeeze(-1) + tensordict["value"] = self.critic(tensordict["obs"]).squeeze(-1) return tensordict def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: - obs = tensordict["observation"] + obs = tensordict["obs"] action = tensordict["action"] dist = self._distribution(obs) tensordict["sample_log_prob"] = dist.log_prob(action).sum(dim=-1) diff --git a/embodichain/agents/rl/models/actor_only.py b/embodichain/agents/rl/models/actor_only.py index 
e80109c4..8a905477 100644 --- a/embodichain/agents/rl/models/actor_only.py +++ b/embodichain/agents/rl/models/actor_only.py @@ -59,7 +59,7 @@ def _distribution(self, obs: torch.Tensor) -> Normal: def forward( self, tensordict: TensorDict, deterministic: bool = False ) -> TensorDict: - obs = tensordict["observation"] + obs = tensordict["obs"] dist = self._distribution(obs) mean = dist.mean action = mean if deterministic else dist.sample() @@ -71,14 +71,14 @@ def forward( return tensordict def get_value(self, tensordict: TensorDict) -> TensorDict: - obs = tensordict["observation"] + obs = tensordict["obs"] tensordict["value"] = torch.zeros( obs.shape[0], device=self.device, dtype=obs.dtype ) return tensordict def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: - obs = tensordict["observation"] + obs = tensordict["obs"] action = tensordict["action"] dist = self._distribution(obs) tensordict["sample_log_prob"] = dist.log_prob(action).sum(dim=-1) diff --git a/embodichain/agents/rl/models/policy.py b/embodichain/agents/rl/models/policy.py index 6938268b..58dfda9a 100644 --- a/embodichain/agents/rl/models/policy.py +++ b/embodichain/agents/rl/models/policy.py @@ -62,7 +62,7 @@ def get_action( """ with torch.no_grad(): td = TensorDict( - {"observation": obs}, + {"obs": obs}, batch_size=[obs.shape[0]], device=obs.device, ) @@ -81,7 +81,7 @@ def get_value(self, tensordict: TensorDict) -> TensorDict: """Write value estimate for the given observations into the TensorDict. Args: - tensordict: Input TensorDict containing `observation`. + tensordict: Input TensorDict containing `obs`. Returns: TensorDict with `value` populated. @@ -93,7 +93,7 @@ def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: """Evaluate actions and write log prob, entropy, and values. Args: - tensordict: TensorDict containing `observation` and `action`. + tensordict: TensorDict containing `obs` and `action`. 
Returns: TensorDict with `sample_log_prob`, `entropy`, and `value` populated. diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 8e1e0e64..613d4bd9 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -523,7 +523,7 @@ def _initialize_episode( def _infer_rollout_buffer_mode(self, rollout_buffer: TensorDict) -> str: """Infer whether the rollout buffer is expert recording or RL training data.""" - if "next" in rollout_buffer.keys() and "observation" in rollout_buffer.keys(): + if "next" in rollout_buffer.keys() and "obs" in rollout_buffer.keys(): return "rl" return "expert" @@ -573,7 +573,7 @@ def _write_rl_rollout_step( obs_to_store = ( flatten_dict_observation(obs) if isinstance(obs, TensorDict) else obs ) - self.rollout_buffer["next", "observation"][:, self.current_rollout_step].copy_( + self.rollout_buffer["next", "obs"][:, self.current_rollout_step].copy_( obs_to_store.to(buffer_device), non_blocking=True ) self.rollout_buffer["next", "reward"][:, self.current_rollout_step].copy_( diff --git a/tests/agents/test_shared_rollout.py b/tests/agents/test_shared_rollout.py index 4f1d506a..67e8d64e 100644 --- a/tests/agents/test_shared_rollout.py +++ b/tests/agents/test_shared_rollout.py @@ -40,14 +40,14 @@ def train(self) -> None: pass def forward(self, tensordict: TensorDict) -> TensorDict: - obs = tensordict["observation"] + obs = tensordict["obs"] tensordict["action"] = obs[:, : self.action_dim] * 0.25 tensordict["sample_log_prob"] = obs.sum(dim=-1) * 0.1 tensordict["value"] = obs.mean(dim=-1) return tensordict def get_value(self, tensordict: TensorDict) -> TensorDict: - tensordict["value"] = tensordict["observation"].mean(dim=-1) + tensordict["value"] = tensordict["obs"].mean(dim=-1) return tensordict @@ -82,7 +82,7 @@ def step(self, action_dict): truncated = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device) if self.rollout_buffer is not None: - 
self.rollout_buffer["next", "observation"][:, self.current_rollout_step] = ( + self.rollout_buffer["next", "obs"][:, self.current_rollout_step] = ( flatten_dict_observation(next_obs) ) self.rollout_buffer["next", "reward"][:, self.current_rollout_step] = reward @@ -151,9 +151,9 @@ def test_shared_rollout_collects_policy_and_env_fields(): stored = buffer.get(flatten=False) assert stored.batch_size == torch.Size([num_envs, rollout_len]) - assert torch.allclose(stored["observation"][:, 0], torch.zeros(num_envs, obs_dim)) + assert torch.allclose(stored["obs"][:, 0], torch.zeros(num_envs, obs_dim)) assert torch.allclose( - stored["next", "observation"][:, 0], + stored["next", "obs"][:, 0], torch.ones(num_envs, obs_dim), ) assert torch.allclose( @@ -212,7 +212,7 @@ def test_embodied_env_writes_next_fields_into_external_rollout(): done = (terminated | truncated).cpu() assert env.current_rollout_step == 1 - assert torch.allclose(rollout["next", "observation"][:, 0].cpu(), next_obs_flat) + assert torch.allclose(rollout["next", "obs"][:, 0].cpu(), next_obs_flat) assert torch.allclose(rollout["next", "reward"][:, 0].cpu(), reward.cpu()) assert torch.equal(rollout["next", "done"][:, 0].cpu(), done) assert torch.equal(rollout["next", "terminated"][:, 0].cpu(), terminated.cpu()) From 55bf83b8753acf6a0954baf984c49449878df5ef Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 15:44:32 +0000 Subject: [PATCH 14/23] Fix: flatten Dict observation space --- embodichain/agents/rl/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py index a11192f1..bf08c746 100644 --- a/embodichain/agents/rl/train.py +++ b/embodichain/agents/rl/train.py @@ -169,6 +169,7 @@ def train_from_config(config_path: str): sample_obs, _ = env.reset() sample_obs_td = dict_to_tensordict(sample_obs, device) obs_dim = flatten_dict_observation(sample_obs_td).shape[-1] + flat_obs_space = 
env.flattened_observation_space # Create evaluation environment only if enabled eval_env = None @@ -205,7 +206,7 @@ def train_from_config(config_path: str): policy = build_policy( policy_block, - env.observation_space, + flat_obs_space, env.action_space, device, actor=actor, @@ -222,7 +223,7 @@ def train_from_config(config_path: str): policy = build_policy( policy_block, - env.observation_space, + flat_obs_space, env.action_space, device, actor=actor, From b72c2e6f510595eab6292bbae66c8d9a15a9eef4 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 16:12:52 +0000 Subject: [PATCH 15/23] Update evaluate_action to non inplace --- embodichain/agents/rl/algo/grpo.py | 4 ++-- embodichain/agents/rl/algo/ppo.py | 2 +- embodichain/agents/rl/models/actor_critic.py | 13 +++++++++---- embodichain/agents/rl/models/actor_only.py | 13 ++++++++----- embodichain/agents/rl/models/policy.py | 4 ++-- 5 files changed, 22 insertions(+), 14 deletions(-) diff --git a/embodichain/agents/rl/algo/grpo.py b/embodichain/agents/rl/algo/grpo.py index 79d95f56..9ce5d17b 100644 --- a/embodichain/agents/rl/algo/grpo.py +++ b/embodichain/agents/rl/algo/grpo.py @@ -142,7 +142,7 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: advantages = batch["advantage"].detach() seq_mask_batch = batch["seq_mask"].float() - eval_batch = self.policy.evaluate_actions(batch.clone()) + eval_batch = self.policy.evaluate_actions(batch) logprobs = eval_batch["sample_log_prob"] entropy = eval_batch["entropy"] ratio = (logprobs - old_logprobs).exp() @@ -161,7 +161,7 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: if self.ref_policy is not None: with torch.no_grad(): - ref_batch = self.ref_policy.evaluate_actions(batch.clone()) + ref_batch = self.ref_policy.evaluate_actions(batch) ref_logprobs = ref_batch["sample_log_prob"] log_ref_over_pi = ref_logprobs - logprobs kl_per = torch.exp(log_ref_over_pi) - log_ref_over_pi - 1.0 diff --git 
a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index f360f74b..ae34a284 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -67,7 +67,7 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: returns = batch["return"].clone() batch_advantages = ((batch["advantage"] - adv_mean) / adv_std).detach() - eval_batch = self.policy.evaluate_actions(batch.clone()) + eval_batch = self.policy.evaluate_actions(batch) logprobs = eval_batch["sample_log_prob"] entropy = eval_batch["entropy"] values = eval_batch["value"] diff --git a/embodichain/agents/rl/models/actor_critic.py b/embodichain/agents/rl/models/actor_critic.py index 7ee73927..32caf0e3 100644 --- a/embodichain/agents/rl/models/actor_critic.py +++ b/embodichain/agents/rl/models/actor_critic.py @@ -90,7 +90,12 @@ def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: obs = tensordict["obs"] action = tensordict["action"] dist = self._distribution(obs) - tensordict["sample_log_prob"] = dist.log_prob(action).sum(dim=-1) - tensordict["entropy"] = dist.entropy().sum(dim=-1) - tensordict["value"] = self.critic(obs).squeeze(-1) - return tensordict + return TensorDict( + { + "sample_log_prob": dist.log_prob(action).sum(dim=-1), + "entropy": dist.entropy().sum(dim=-1), + "value": self.critic(obs).squeeze(-1), + }, + batch_size=tensordict.batch_size, + device=tensordict.device, + ) diff --git a/embodichain/agents/rl/models/actor_only.py b/embodichain/agents/rl/models/actor_only.py index 8a905477..3d6d1f78 100644 --- a/embodichain/agents/rl/models/actor_only.py +++ b/embodichain/agents/rl/models/actor_only.py @@ -81,9 +81,12 @@ def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: obs = tensordict["obs"] action = tensordict["action"] dist = self._distribution(obs) - tensordict["sample_log_prob"] = dist.log_prob(action).sum(dim=-1) - tensordict["entropy"] = dist.entropy().sum(dim=-1) - tensordict["value"] = torch.zeros( - obs.shape[0], 
device=self.device, dtype=obs.dtype + return TensorDict( + { + "sample_log_prob": dist.log_prob(action).sum(dim=-1), + "entropy": dist.entropy().sum(dim=-1), + "value": torch.zeros(obs.shape[0], device=self.device, dtype=obs.dtype), + }, + batch_size=tensordict.batch_size, + device=tensordict.device, ) - return tensordict diff --git a/embodichain/agents/rl/models/policy.py b/embodichain/agents/rl/models/policy.py index 58dfda9a..61741021 100644 --- a/embodichain/agents/rl/models/policy.py +++ b/embodichain/agents/rl/models/policy.py @@ -90,12 +90,12 @@ def get_value(self, tensordict: TensorDict) -> TensorDict: @abstractmethod def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: - """Evaluate actions and write log prob, entropy, and values. + """Evaluate actions and return current policy outputs. Args: tensordict: TensorDict containing `obs` and `action`. Returns: - TensorDict with `sample_log_prob`, `entropy`, and `value` populated. + A new TensorDict containing `sample_log_prob`, `entropy`, and `value`. 
""" raise NotImplementedError From 3bf84a66fd5278bca3bd5c56dd1680668bb52a58 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 16:17:57 +0000 Subject: [PATCH 16/23] Move iterate_minibatches to buffer/utils.py --- embodichain/agents/rl/algo/grpo.py | 15 ++++------- embodichain/agents/rl/algo/ppo.py | 15 ++++------- embodichain/agents/rl/buffer/__init__.py | 3 ++- embodichain/agents/rl/buffer/utils.py | 34 ++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 21 deletions(-) create mode 100644 embodichain/agents/rl/buffer/utils.py diff --git a/embodichain/agents/rl/algo/grpo.py b/embodichain/agents/rl/algo/grpo.py index 9ce5d17b..d8d1a3ef 100644 --- a/embodichain/agents/rl/algo/grpo.py +++ b/embodichain/agents/rl/algo/grpo.py @@ -18,11 +18,12 @@ import math from copy import deepcopy -from typing import Dict, Iterator +from typing import Dict import torch from tensordict import TensorDict +from embodichain.agents.rl.buffer import iterate_minibatches from embodichain.agents.rl.utils import AlgorithmCfg from embodichain.utils import configclass from .base import BaseAlgorithm @@ -137,7 +138,9 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: total_weight = 0.0 for _ in range(self.cfg.n_epochs): - for batch in self._iterate_minibatches(flat_rollout, self.cfg.batch_size): + for batch in iterate_minibatches( + flat_rollout, self.cfg.batch_size, self.device + ): old_logprobs = batch["sample_log_prob"].clone() advantages = batch["advantage"].detach() seq_mask_batch = batch["seq_mask"].float() @@ -194,11 +197,3 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: "entropy": total_entropy / max(1.0, total_weight), "approx_ref_kl": total_kl / max(1.0, total_weight), } - - def _iterate_minibatches( - self, rollout: TensorDict, batch_size: int - ) -> Iterator[TensorDict]: - total = rollout.batch_size[0] - indices = torch.randperm(total, device=self.device) - for start in range(0, total, batch_size): - yield 
rollout[indices[start : start + batch_size]] diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index ae34a284..6036d3b5 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -15,11 +15,12 @@ # ---------------------------------------------------------------------------- import math -from typing import Dict, Iterator +from typing import Dict import torch from tensordict import TensorDict +from embodichain.agents.rl.buffer import iterate_minibatches from embodichain.agents.rl.utils import AlgorithmCfg from embodichain.utils import configclass from .common import compute_gae @@ -62,7 +63,9 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: total_steps = 0 for _ in range(self.cfg.n_epochs): - for batch in self._iterate_minibatches(flat_rollout, self.cfg.batch_size): + for batch in iterate_minibatches( + flat_rollout, self.cfg.batch_size, self.device + ): old_logprobs = batch["sample_log_prob"].clone() returns = batch["return"].clone() batch_advantages = ((batch["advantage"] - adv_mean) / adv_std).detach() @@ -107,11 +110,3 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: "value_loss": total_value_loss / max(1, total_steps), "entropy": total_entropy / max(1, total_steps), } - - def _iterate_minibatches( - self, rollout: TensorDict, batch_size: int - ) -> Iterator[TensorDict]: - total = rollout.batch_size[0] - indices = torch.randperm(total, device=self.device) - for start in range(0, total, batch_size): - yield rollout[indices[start : start + batch_size]] diff --git a/embodichain/agents/rl/buffer/__init__.py b/embodichain/agents/rl/buffer/__init__.py index d90b2a06..db0dd4dd 100644 --- a/embodichain/agents/rl/buffer/__init__.py +++ b/embodichain/agents/rl/buffer/__init__.py @@ -15,5 +15,6 @@ # ---------------------------------------------------------------------------- from .standard_buffer import RolloutBuffer +from .utils import iterate_minibatches -__all__ = ["RolloutBuffer"] 
+__all__ = ["RolloutBuffer", "iterate_minibatches"] diff --git a/embodichain/agents/rl/buffer/utils.py b/embodichain/agents/rl/buffer/utils.py new file mode 100644 index 00000000..00b62a0c --- /dev/null +++ b/embodichain/agents/rl/buffer/utils.py @@ -0,0 +1,34 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from collections.abc import Iterator + +import torch +from tensordict import TensorDict + +__all__ = ["iterate_minibatches"] + + +def iterate_minibatches( + rollout: TensorDict, batch_size: int, device: torch.device +) -> Iterator[TensorDict]: + """Yield shuffled minibatches from a flattened rollout.""" + total = rollout.batch_size[0] + indices = torch.randperm(total, device=device) + for start in range(0, total, batch_size): + yield rollout[indices[start : start + batch_size]] From 8473e2aecfea97873ce0f3e6ff91c9476e44d73b Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 16:26:39 +0000 Subject: [PATCH 17/23] Return tensordict in get_action and uses get_action instead of calling forward() directly --- .../agents/rl/collector/sync_collector.py | 2 +- embodichain/agents/rl/models/policy.py | 21 ++++++------------- embodichain/agents/rl/utils/trainer.py | 8 ++++++- 
tests/agents/test_shared_rollout.py | 5 +++++ 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index 0adcfb8d..704c8f6a 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -74,7 +74,7 @@ def collect( batch_size=[obs_tensor.shape[0]], device=self.device, ) - self.policy.forward(step_td) + step_td = self.policy.get_action(step_td) next_obs, reward, terminated, truncated, env_info = self.env.step( self._to_action_dict(step_td["action"]) diff --git a/embodichain/agents/rl/models/policy.py b/embodichain/agents/rl/models/policy.py index 61741021..dd71faa5 100644 --- a/embodichain/agents/rl/models/policy.py +++ b/embodichain/agents/rl/models/policy.py @@ -46,28 +46,19 @@ def __init__(self) -> None: super().__init__() def get_action( - self, obs: torch.Tensor, deterministic: bool = False - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Compatibility layer for tensor-only callers. + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + """Sample actions into the provided TensorDict without gradients. Args: - obs: Observation tensor of shape (batch_size, obs_dim) + tensordict: Input TensorDict containing `obs`. deterministic: If True, return the mean action; otherwise sample Returns: - Tuple of (action, log_prob, value): - - action: Sampled action tensor of shape (batch_size, action_dim) - - log_prob: Log probability of the action, shape (batch_size,) - - value: Value estimate, shape (batch_size,) + TensorDict with `action`, `sample_log_prob`, and `value` populated. 
""" with torch.no_grad(): - td = TensorDict( - {"obs": obs}, - batch_size=[obs.shape[0]], - device=obs.device, - ) - td = self.forward(td, deterministic=deterministic) - return td["action"], td["sample_log_prob"], td["value"] + return self.forward(tensordict, deterministic=deterministic) @abstractmethod def forward( diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 98739af5..6dfb14b6 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -260,7 +260,13 @@ def _eval_once(self, num_episodes: int = 5): # Run episode until all environments complete while not done_mask.all(): # Get deterministic actions from policy - actions, _, _ = self.policy.get_action(obs, deterministic=True) + action_td = TensorDict( + {"obs": obs}, + batch_size=[num_envs], + device=self.device, + ) + action_td = self.policy.get_action(action_td, deterministic=True) + actions = action_td["action"] am = getattr(self.eval_env, "action_manager", None) action_type = ( am.action_type diff --git a/tests/agents/test_shared_rollout.py b/tests/agents/test_shared_rollout.py index 67e8d64e..31269325 100644 --- a/tests/agents/test_shared_rollout.py +++ b/tests/agents/test_shared_rollout.py @@ -46,6 +46,11 @@ def forward(self, tensordict: TensorDict) -> TensorDict: tensordict["value"] = obs.mean(dim=-1) return tensordict + def get_action( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + return self.forward(tensordict) + def get_value(self, tensordict: TensorDict) -> TensorDict: tensordict["value"] = tensordict["obs"].mean(dim=-1) return tensordict From 901bee641d12b521edb507649699b20aefc2068c Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 16:35:10 +0000 Subject: [PATCH 18/23] Remove extra next.obs buffer --- embodichain/agents/rl/buffer/standard_buffer.py | 7 ------- embodichain/agents/rl/collector/sync_collector.py | 6 ++---- 
embodichain/lab/gym/envs/embodied_env.py | 6 ------ tests/agents/test_shared_rollout.py | 8 +------- 4 files changed, 3 insertions(+), 24 deletions(-) diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py index ec706c14..8a56d5f3 100644 --- a/embodichain/agents/rl/buffer/standard_buffer.py +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -112,13 +112,6 @@ def _allocate_rollout(self) -> TensorDict: device=self.device, ), "next": { - "obs": torch.empty( - self.num_envs, - self.rollout_len, - self.obs_dim, - dtype=torch.float32, - device=self.device, - ), "reward": torch.empty( self.num_envs, self.rollout_len, diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index 704c8f6a..22448d05 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -89,7 +89,6 @@ def collect( self._write_env_step( rollout=rollout, step_idx=step_idx, - next_obs_td=next_obs_td, reward=reward, terminated=terminated, truncated=truncated, @@ -108,8 +107,9 @@ def _attach_next_values(self, rollout: TensorDict) -> None: next_values = torch.zeros_like(rollout["value"]) next_values[:, :-1] = rollout["value"][:, 1:] + final_obs = flatten_dict_observation(self.obs_td) last_next_td = TensorDict( - {"obs": rollout["next", "obs"][:, -1]}, + {"obs": final_obs}, batch_size=[rollout.batch_size[0]], device=self.device, ) @@ -144,14 +144,12 @@ def _write_env_step( self, rollout: TensorDict, step_idx: int, - next_obs_td: TensorDict, reward: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, ) -> None: """Populate transition-side fields when the environment does not own the rollout.""" done = terminated | truncated - rollout["next", "obs"][:, step_idx] = flatten_dict_observation(next_obs_td) rollout["next", "reward"][:, step_idx] = reward.to(self.device) rollout["next", "done"][:, step_idx] = 
done.to(self.device) rollout["next", "terminated"][:, step_idx] = terminated.to(self.device) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 613d4bd9..995234d7 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -570,12 +570,6 @@ def _write_rl_rollout_step( ) -> None: """Write environment-side fields into an externally managed RL rollout buffer.""" buffer_device = self.rollout_buffer.device - obs_to_store = ( - flatten_dict_observation(obs) if isinstance(obs, TensorDict) else obs - ) - self.rollout_buffer["next", "obs"][:, self.current_rollout_step].copy_( - obs_to_store.to(buffer_device), non_blocking=True - ) self.rollout_buffer["next", "reward"][:, self.current_rollout_step].copy_( rewards.to(buffer_device), non_blocking=True ) diff --git a/tests/agents/test_shared_rollout.py b/tests/agents/test_shared_rollout.py index 31269325..847bc55a 100644 --- a/tests/agents/test_shared_rollout.py +++ b/tests/agents/test_shared_rollout.py @@ -87,9 +87,6 @@ def step(self, action_dict): truncated = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device) if self.rollout_buffer is not None: - self.rollout_buffer["next", "obs"][:, self.current_rollout_step] = ( - flatten_dict_observation(next_obs) - ) self.rollout_buffer["next", "reward"][:, self.current_rollout_step] = reward self.rollout_buffer["next", "done"][:, self.current_rollout_step] = ( terminated | truncated @@ -158,8 +155,7 @@ def test_shared_rollout_collects_policy_and_env_fields(): assert stored.batch_size == torch.Size([num_envs, rollout_len]) assert torch.allclose(stored["obs"][:, 0], torch.zeros(num_envs, obs_dim)) assert torch.allclose( - stored["next", "obs"][:, 0], - torch.ones(num_envs, obs_dim), + stored["value"][:, 1], torch.ones(num_envs, dtype=torch.float32) ) assert torch.allclose( stored["action"][:, 0], @@ -213,11 +209,9 @@ def test_embodied_env_writes_next_fields_into_external_rollout(): 
device=env.device, ) next_obs, reward, terminated, truncated, _ = env.step({"delta_qpos": action}) - next_obs_flat = flatten_dict_observation(next_obs).cpu() done = (terminated | truncated).cpu() assert env.current_rollout_step == 1 - assert torch.allclose(rollout["next", "obs"][:, 0].cpu(), next_obs_flat) assert torch.allclose(rollout["next", "reward"][:, 0].cpu(), reward.cpu()) assert torch.equal(rollout["next", "done"][:, 0].cpu(), done) assert torch.equal(rollout["next", "terminated"][:, 0].cpu(), terminated.cpu()) From 49412a32902c956c06f5fc60ca7256ba108aca1c Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Thu, 12 Mar 2026 16:57:14 +0000 Subject: [PATCH 19/23] Update docs --- docs/source/overview/rl/buffer.md | 73 ++++++++++++++++--------------- docs/source/overview/rl/models.md | 23 +++++++--- docs/source/tutorial/rl.rst | 15 +++++-- 3 files changed, 66 insertions(+), 45 deletions(-) diff --git a/docs/source/overview/rl/buffer.md b/docs/source/overview/rl/buffer.md index 06a38640..1f23a75b 100644 --- a/docs/source/overview/rl/buffer.md +++ b/docs/source/overview/rl/buffer.md @@ -5,61 +5,64 @@ This module implements the data buffer for RL training, responsible for storing ## Main Classes and Structure ### RolloutBuffer -- Used for on-policy algorithms (such as PPO, GRPO), efficiently stores observations, actions, rewards, dones, values, and logprobs for each step. -- Supports multi-environment parallelism (shape: [T, N, ...]), all data allocated on GPU. +- Used for on-policy algorithms (such as PPO, GRPO), storing a shared rollout `TensorDict` for collector and algorithm stages. +- Supports multi-environment parallelism with rollout batch shape `[N, T]`, all data allocated on GPU. 
- Structure fields: - - `obs`: Observation tensor, float32, shape [T, N, obs_dim] - - `actions`: Action tensor, float32, shape [T, N, action_dim] - - `rewards`: Reward tensor, float32, shape [T, N] - - `dones`: Done flags, bool, shape [T, N] - - `values`: Value estimates, float32, shape [T, N] - - `logprobs`: Action log probabilities, float32, shape [T, N] - - `_extras`: Algorithm-specific fields (e.g., advantages, returns), dict[str, Tensor] + - `obs`: Flattened observation tensor, float32, shape `[N, T, obs_dim]` + - `action`: Action tensor, float32, shape `[N, T, action_dim]` + - `sample_log_prob`: Action log probabilities, float32, shape `[N, T]` + - `value`: Value estimates, float32, shape `[N, T]` + - `next.reward`: Reward tensor, float32, shape `[N, T]` + - `next.done`: Done flags, bool, shape `[N, T]` + - `next.terminated`: Termination flags, bool, shape `[N, T]` + - `next.truncated`: Truncation flags, bool, shape `[N, T]` + - `next.value`: Bootstrap next-state values, float32, shape `[N, T]` + - Algorithm-added fields such as `advantage`, `return`, `seq_mask`, and `seq_return` ## Main Methods -- `add(obs, action, reward, done, value, logprob)`: Add one step of data. -- `set_extras(extras)`: Attach algorithm-related tensors (e.g., advantages, returns). -- `iterate_minibatches(batch_size)`: Randomly sample minibatches, returns dict (including all fields and extras). -- Supports efficient GPU shuffle and indexing for large-scale training. +- `start_rollout()`: Returns the shared preallocated rollout `TensorDict` for collector write-in. +- `add(rollout)`: Marks the shared rollout as ready for consumption. +- `get(flatten=True)`: Returns the stored rollout, optionally flattened over `[N, T]`. +- `iterate_minibatches(rollout, batch_size, device)`: Shared batching utility in `buffer/utils.py`. 
## Usage Example ```python -buffer = RolloutBuffer(num_steps, num_envs, obs_dim, action_dim, device) -for t in range(num_steps): - buffer.add(obs, action, reward, done, value, logprob) -buffer.set_extras({"advantages": adv, "returns": ret}) -for batch in buffer.iterate_minibatches(batch_size): - # batch["obs"], batch["actions"], batch["advantages"] ... +buffer = RolloutBuffer(num_envs, rollout_len, obs_dim, action_dim, device) +rollout = collector.collect(num_steps=rollout_len, rollout=buffer.start_rollout()) +buffer.add(rollout) + +rollout = buffer.get(flatten=False) +for batch in iterate_minibatches(rollout.reshape(-1), batch_size, device): + # batch["obs"], batch["action"], batch["advantage"] ... pass ``` ## Design and Extension -- Supports multi-environment parallel collection, compatible with Gymnasium/IsaacGym environments. -- All data is allocated on GPU to avoid frequent CPU-GPU copying. -- The extras field can be flexibly extended to meet different algorithm needs (e.g., GAE, TD-lambda, distributional advantages). -- The iterator automatically shuffles to improve training stability. -- Compatible with various RL algorithms (PPO, GRPO, A2C, SAC, etc.), custom fields and sampling logic supported. +- Supports multi-environment parallel collection, compatible with Gymnasium-style vectorized environments. +- All tensors are preallocated on device to avoid frequent CPU-GPU copying. +- Algorithm-specific fields are attached directly onto the shared rollout `TensorDict` during optimization. +- The shared minibatch iterator automatically shuffles flattened rollout entries for PPO/GRPO style updates. ## Code Example ```python class RolloutBuffer: - def __init__(self, num_steps, num_envs, obs_dim, action_dim, device): - # Initialize tensors + def __init__(self, num_envs, rollout_len, obs_dim, action_dim, device): + # Preallocate rollout TensorDict ... 
- def add(self, obs, action, reward, done, value, logprob): - # Add data + def start_rollout(self): + # Return shared rollout storage ... - def set_extras(self, extras): - # Attach algorithm-related tensors + def add(self, rollout): + # Mark rollout as full ... - def iterate_minibatches(self, batch_size): - # Random minibatch sampling + def get(self, flatten=True): + # Consume rollout ... ``` ## Practical Tips -- It is recommended to call set_extras after each rollout to ensure advantage/return tensors align with main data. -- When using iterate_minibatches, set batch_size appropriately for training stability. -- Extend the extras field as needed for custom sampling and statistics. +- The rollout buffer stores flattened RL observations; structured observations should be flattened or encoded before entering this buffer. +- `next.value` is kept for bootstrap convenience, while `next.obs` is intentionally not stored to reduce duplicated memory. +- Use `buffer/utils.py` for shared minibatch iteration instead of duplicating batching logic in each algorithm. --- diff --git a/docs/source/overview/rl/models.md b/docs/source/overview/rl/models.md index 56de7926..9c85cd9f 100644 --- a/docs/source/overview/rl/models.md +++ b/docs/source/overview/rl/models.md @@ -7,10 +7,10 @@ This module contains RL policy networks and related model implementations, suppo ### Policy - Abstract base class for RL policies; all policies must inherit from it. - Unified interface: - - `forward(tensordict, deterministic=False)`: Write action, log prob, and value into a `TensorDict`. + - `get_action(tensordict, deterministic=False)`: Sample actions into a `TensorDict` without gradients. + - `forward(tensordict, deterministic=False)`: Low-level action/value write path used by policy implementations. - `get_value(tensordict)`: Estimate state value into a `TensorDict`. - - `evaluate_actions(tensordict)`: Evaluate action probabilities, entropy, and value from a `TensorDict`. 
-- `get_action(obs, deterministic=False)` is retained as a compatibility layer for evaluation and legacy callers. + - `evaluate_actions(tensordict)`: Return optimization-time policy outputs from a `TensorDict`. - Supports GPU deployment and distributed training. ### ActorCritic @@ -30,15 +30,26 @@ This module contains RL policy networks and related model implementations, suppo - Supports orthogonal initialization and output reshaping. ### Factory Functions -- `build_policy(policy_block, obs_dim, action_dim, device, ...)`: Automatically build policy from config. +- `build_policy(policy_block, obs_space, action_space, device, ...)`: Automatically build policy from config. - `build_mlp_from_cfg(module_cfg, in_dim, out_dim)`: Automatically build MLP from config. ## Usage Example ```python actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim) critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1) -policy = build_policy(policy_block, obs_dim, action_dim, device, actor=actor, critic=critic) -action, log_prob, value = policy.get_action(obs) +policy = build_policy( + policy_block, + env.flattened_observation_space, + env.action_space, + device, + actor=actor, + critic=critic, +) +step_td = TensorDict({"obs": obs}, batch_size=[obs.shape[0]], device=obs.device) +step_td = policy.get_action(step_td) +action = step_td["action"] +log_prob = step_td["sample_log_prob"] +value = step_td["value"] ``` ## Extension and Customization diff --git a/docs/source/tutorial/rl.rst b/docs/source/tutorial/rl.rst index bcefb80a..81063fb8 100644 --- a/docs/source/tutorial/rl.rst +++ b/docs/source/tutorial/rl.rst @@ -109,7 +109,7 @@ Policy Configuration The ``policy`` section defines the neural network policy: - **name**: Policy name (e.g., "actor_critic", "vla") -- **action_dim**: Policy output action dimension; must match ``env.action_space`` +- **action_dim**: Optional policy output action dimension. If omitted, it is inferred from ``env.action_space``. 
- **actor**: Actor network configuration (required for actor_critic) - **critic**: Critic network configuration (required for actor_critic) @@ -119,7 +119,6 @@ Example: "policy": { "name": "actor_critic", - "action_dim": 8, "actor": { "type": "mlp", "network_cfg": { @@ -235,7 +234,7 @@ Training Process The training process follows this sequence: -1. **Rollout Phase**: ``SyncCollector`` interacts with the environment and writes policy-side fields into a shared rollout ``TensorDict``. ``EmbodiedEnv`` writes environment-side ``next.*`` fields into the same rollout via ``set_rollout_buffer()``. +1. **Rollout Phase**: ``SyncCollector`` interacts with the environment and writes policy-side fields into a shared rollout ``TensorDict``. ``EmbodiedEnv`` writes environment-side step fields such as ``next.reward``, ``next.done``, ``next.terminated``, and ``next.truncated`` into the same rollout via ``set_rollout_buffer()``. 2. **Advantage/Return Computation**: Algorithm computes advantages and returns from the collected rollout (e.g. GAE for PPO, step-wise group normalization for GRPO) 3. **Update Phase**: Algorithm updates the policy with ``update(rollout)`` 4. **Logging**: Trainer logs training losses and aggregated metrics to TensorBoard and Weights & Biases @@ -255,6 +254,10 @@ All policies must inherit from the ``Policy`` abstract base class: class Policy(nn.Module, ABC): device: torch.device + def get_action(self, tensordict, deterministic: bool = False): + """Samples action, sample_log_prob, and value into the TensorDict.""" + ... 
+ @abstractmethod def forward(self, tensordict, deterministic: bool = False): """Writes action, sample_log_prob, and value into the TensorDict.""" @@ -267,7 +270,7 @@ All policies must inherit from the ``Policy`` abstract base class: @abstractmethod def evaluate_actions(self, tensordict): - """Writes log_prob, entropy, and value into the TensorDict.""" + """Returns a new TensorDict with log_prob, entropy, and value.""" raise NotImplementedError Available Policies @@ -332,6 +335,8 @@ To add a new policy: self.device = device # Initialize your networks here + def get_action(self, tensordict, deterministic=False): + ... def forward(self, tensordict, deterministic=False): ... def get_value(self, tensordict): @@ -339,6 +344,8 @@ To add a new policy: def evaluate_actions(self, tensordict): ... +Current built-in MLP policies use flattened observations in the training path. If your policy requires structured or multi-modal inputs, keep the richer ``obs_space`` interface and define a matching rollout/collector schema. 
+ Adding a New Environment ------------------------ From 5afe0d777327957672a8f00855b40c8605d31ed8 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 13 Mar 2026 05:51:27 +0000 Subject: [PATCH 20/23] Restructure buffer --- embodichain/agents/rl/algo/common.py | 36 +++-- embodichain/agents/rl/algo/grpo.py | 17 ++- embodichain/agents/rl/algo/ppo.py | 4 +- embodichain/agents/rl/buffer/__init__.py | 4 +- .../agents/rl/buffer/standard_buffer.py | 125 +++++++++++------- embodichain/agents/rl/buffer/utils.py | 45 ++++++- .../agents/rl/collector/sync_collector.py | 56 +++++--- embodichain/agents/rl/utils/trainer.py | 4 +- embodichain/lab/gym/envs/embodied_env.py | 33 +++-- tests/agents/test_shared_rollout.py | 31 +++-- 10 files changed, 233 insertions(+), 122 deletions(-) diff --git a/embodichain/agents/rl/algo/common.py b/embodichain/agents/rl/algo/common.py index eb3c5d17..9dcf2f78 100644 --- a/embodichain/agents/rl/algo/common.py +++ b/embodichain/agents/rl/algo/common.py @@ -25,18 +25,20 @@ def compute_gae( rollout: TensorDict, gamma: float, gae_lambda: float ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute GAE over a rollout with batch shape `[num_envs, time]`. + """Compute GAE over a rollout stored as `[num_envs, time + 1]`. Args: - rollout: Rollout TensorDict containing `value` and `next` transition data. + rollout: Rollout TensorDict where `value[:, -1]` stores the bootstrap + value for the final observation and transition-only fields reserve + their last slot as padding. gamma: Discount factor. gae_lambda: GAE lambda coefficient. Returns: Tuple of `(advantages, returns)`, both shaped `[num_envs, time]`. 
""" - rewards = rollout["next", "reward"].float() - dones = rollout["next", "done"].bool() + rewards = rollout["reward"][:, :-1].float() + dones = rollout["done"][:, :-1].bool() values = rollout["value"].float() if rewards.ndim != 2: @@ -44,29 +46,23 @@ def compute_gae( f"Expected reward tensor with shape [num_envs, time], got {rewards.shape}." ) - next_values = _get_next_values(rollout, values) num_envs, time_dim = rewards.shape - advantages = torch.zeros_like(rewards) + if values.shape != (num_envs, time_dim + 1): + raise ValueError( + "Expected value tensor with shape [num_envs, time + 1], got " + f"{values.shape} for rewards shape {rewards.shape}." + ) + advantages = torch.zeros_like(rollout["reward"].float()) last_advantage = torch.zeros(num_envs, device=rewards.device, dtype=rewards.dtype) for t in reversed(range(time_dim)): not_done = (~dones[:, t]).float() - delta = rewards[:, t] + gamma * next_values[:, t] * not_done - values[:, t] + delta = rewards[:, t] + gamma * values[:, t + 1] * not_done - values[:, t] last_advantage = delta + gamma * gae_lambda * not_done * last_advantage advantages[:, t] = last_advantage - returns = advantages + values + returns = torch.zeros_like(advantages) + returns[:, :-1] = advantages[:, :-1] + values[:, :-1] rollout["advantage"] = advantages rollout["return"] = returns - return advantages, returns - - -def _get_next_values(rollout: TensorDict, values: torch.Tensor) -> torch.Tensor: - """Resolve next-step values for GAE bootstrap.""" - next_value = rollout.get(("next", "value"), None) - if next_value is not None: - return next_value.float() - - next_values = torch.zeros_like(values) - next_values[:, :-1] = values[:, 1:] - return next_values + return advantages[:, :-1], returns[:, :-1] diff --git a/embodichain/agents/rl/algo/grpo.py b/embodichain/agents/rl/algo/grpo.py index d8d1a3ef..2b2d73d3 100644 --- a/embodichain/agents/rl/algo/grpo.py +++ b/embodichain/agents/rl/algo/grpo.py @@ -23,7 +23,7 @@ import torch from tensordict 
import TensorDict -from embodichain.agents.rl.buffer import iterate_minibatches +from embodichain.agents.rl.buffer import iterate_minibatches, transition_view from embodichain.agents.rl.utils import AlgorithmCfg from embodichain.utils import configclass from .base import BaseAlgorithm @@ -121,16 +121,19 @@ def update(self, rollout: TensorDict) -> Dict[str, float]: f"num_envs={num_envs}, group_size={self.cfg.group_size}." ) - rewards = rollout["next", "reward"].float() - dones = rollout["next", "done"].bool() + rewards = rollout["reward"][:, :-1].float() + dones = rollout["done"][:, :-1].bool() step_returns, seq_mask = self._compute_step_returns_and_mask(rewards, dones) - rollout["advantage"] = self._compute_step_group_advantages( + rollout["advantage"] = torch.zeros_like(rollout["reward"], dtype=torch.float32) + rollout["advantage"][:, :-1] = self._compute_step_group_advantages( step_returns, seq_mask ) - rollout["seq_mask"] = seq_mask - rollout["seq_return"] = step_returns + rollout["seq_mask"] = torch.zeros_like(rollout["reward"], dtype=torch.float32) + rollout["seq_mask"][:, :-1] = seq_mask + rollout["seq_return"] = torch.zeros_like(rollout["reward"], dtype=torch.float32) + rollout["seq_return"][:, :-1] = step_returns - flat_rollout = rollout.reshape(math.prod(rollout.batch_size)) + flat_rollout = transition_view(rollout, flatten=True) total_actor_loss = 0.0 total_entropy = 0.0 diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index 6036d3b5..4ba88bb4 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -20,7 +20,7 @@ import torch from tensordict import TensorDict -from embodichain.agents.rl.buffer import iterate_minibatches +from embodichain.agents.rl.buffer import iterate_minibatches, transition_view from embodichain.agents.rl.utils import AlgorithmCfg from embodichain.utils import configclass from .common import compute_gae @@ -51,7 +51,7 @@ def update(self, rollout: TensorDict) -> 
Dict[str, float]: """Update the policy using a collected rollout.""" rollout = rollout.clone() compute_gae(rollout, gamma=self.cfg.gamma, gae_lambda=self.cfg.gae_lambda) - flat_rollout = rollout.reshape(math.prod(rollout.batch_size)) + flat_rollout = transition_view(rollout, flatten=True) advantages = flat_rollout["advantage"] adv_mean = advantages.mean() diff --git a/embodichain/agents/rl/buffer/__init__.py b/embodichain/agents/rl/buffer/__init__.py index db0dd4dd..a83c0776 100644 --- a/embodichain/agents/rl/buffer/__init__.py +++ b/embodichain/agents/rl/buffer/__init__.py @@ -15,6 +15,6 @@ # ---------------------------------------------------------------------------- from .standard_buffer import RolloutBuffer -from .utils import iterate_minibatches +from .utils import iterate_minibatches, transition_view -__all__ = ["RolloutBuffer", "iterate_minibatches"] +__all__ = ["RolloutBuffer", "iterate_minibatches", "transition_view"] diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py index 8a56d5f3..2df69f86 100644 --- a/embodichain/agents/rl/buffer/standard_buffer.py +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -16,16 +16,22 @@ from __future__ import annotations -import math - import torch from tensordict import TensorDict +from .utils import transition_view + __all__ = ["RolloutBuffer"] class RolloutBuffer: - """Single-rollout buffer backed by a preallocated TensorDict.""" + """Single-rollout buffer backed by a preallocated TensorDict. + + The shared rollout uses a uniform `[num_envs, time + 1]` layout. For + transition-only fields such as `action`, `reward`, and `done`, the final + time index is reused as padding so the collector, environment, and + algorithms can share a single TensorDict batch shape. + """ def __init__( self, @@ -56,15 +62,20 @@ def add(self, rollout: TensorDict) -> None: raise ValueError( "RolloutBuffer only accepts its shared rollout TensorDict." 
) - if tuple(rollout.batch_size) != (self.num_envs, self.rollout_len): + if tuple(rollout.batch_size) != (self.num_envs, self.rollout_len + 1): raise ValueError( "Rollout batch size does not match buffer allocation: " - f"expected ({self.num_envs}, {self.rollout_len}), got {tuple(rollout.batch_size)}." + f"expected ({self.num_envs}, {self.rollout_len + 1}), got {tuple(rollout.batch_size)}." ) + self._validate_rollout_layout(rollout) self._is_full = True def get(self, flatten: bool = True) -> TensorDict: - """Return the stored rollout and clear the buffer.""" + """Return the stored rollout and clear the buffer. + + When `flatten` is True, the rollout is first converted to a transition + view that drops the padded final slot from transition-only fields. + """ if not self._is_full: raise RuntimeError("RolloutBuffer is empty.") @@ -74,77 +85,68 @@ def get(self, flatten: bool = True) -> TensorDict: if not flatten: return rollout - total_batch = math.prod(rollout.batch_size) - return rollout.reshape(total_batch) + return transition_view(rollout, flatten=True) def is_full(self) -> bool: """Return whether a rollout is waiting to be consumed.""" return self._is_full def _allocate_rollout(self) -> TensorDict: - """Preallocate rollout storage with batch shape `[num_envs, time]`.""" + """Preallocate rollout storage with uniform `[num_envs, time + 1]` shape.""" return TensorDict( { "obs": torch.empty( self.num_envs, - self.rollout_len, + self.rollout_len + 1, self.obs_dim, dtype=torch.float32, device=self.device, ), "action": torch.empty( self.num_envs, - self.rollout_len, + self.rollout_len + 1, self.action_dim, dtype=torch.float32, device=self.device, ), "sample_log_prob": torch.empty( self.num_envs, - self.rollout_len, + self.rollout_len + 1, dtype=torch.float32, device=self.device, ), "value": torch.empty( self.num_envs, - self.rollout_len, + self.rollout_len + 1, dtype=torch.float32, device=self.device, ), - "next": { - "reward": torch.empty( - self.num_envs, - 
self.rollout_len, - dtype=torch.float32, - device=self.device, - ), - "done": torch.empty( - self.num_envs, - self.rollout_len, - dtype=torch.bool, - device=self.device, - ), - "terminated": torch.empty( - self.num_envs, - self.rollout_len, - dtype=torch.bool, - device=self.device, - ), - "truncated": torch.empty( - self.num_envs, - self.rollout_len, - dtype=torch.bool, - device=self.device, - ), - "value": torch.empty( - self.num_envs, - self.rollout_len, - dtype=torch.float32, - device=self.device, - ), - }, + "reward": torch.empty( + self.num_envs, + self.rollout_len + 1, + dtype=torch.float32, + device=self.device, + ), + "done": torch.empty( + self.num_envs, + self.rollout_len + 1, + dtype=torch.bool, + device=self.device, + ), + "terminated": torch.empty( + self.num_envs, + self.rollout_len + 1, + dtype=torch.bool, + device=self.device, + ), + "truncated": torch.empty( + self.num_envs, + self.rollout_len + 1, + dtype=torch.bool, + device=self.device, + ), }, - batch_size=[self.num_envs, self.rollout_len], + batch_size=[self.num_envs, self.rollout_len + 1], device=self.device, ) @@ -153,3 +155,34 @@ def _clear_dynamic_fields(self) -> None: for key in ("advantage", "return", "seq_mask", "seq_return", "entropy"): if key in self._rollout.keys(): del self._rollout[key] + self._reset_padding_slot() + + def _reset_padding_slot(self) -> None: + """Reset the last transition-only slot reused as padding.""" + last_idx = self.rollout_len + self._rollout["action"][:, last_idx].zero_() + self._rollout["sample_log_prob"][:, last_idx].zero_() + self._rollout["reward"][:, last_idx].zero_() + self._rollout["done"][:, last_idx].fill_(False) + self._rollout["terminated"][:, last_idx].fill_(False) + self._rollout["truncated"][:, last_idx].fill_(False) + + def _validate_rollout_layout(self, rollout: TensorDict) -> None: + """Validate the expected tensor shapes for the shared rollout.""" + expected_shapes = { + "obs": (self.num_envs, self.rollout_len + 1, self.obs_dim), + "action": 
(self.num_envs, self.rollout_len + 1, self.action_dim), + "sample_log_prob": (self.num_envs, self.rollout_len + 1), + "value": (self.num_envs, self.rollout_len + 1), + "reward": (self.num_envs, self.rollout_len + 1), + "done": (self.num_envs, self.rollout_len + 1), + "terminated": (self.num_envs, self.rollout_len + 1), + "truncated": (self.num_envs, self.rollout_len + 1), + } + for key, expected_shape in expected_shapes.items(): + actual_shape = tuple(rollout[key].shape) + if actual_shape != expected_shape: + raise ValueError( + f"Rollout field '{key}' shape mismatch: expected {expected_shape}, " + f"got {actual_shape}." + ) diff --git a/embodichain/agents/rl/buffer/utils.py b/embodichain/agents/rl/buffer/utils.py index 00b62a0c..7c0d265b 100644 --- a/embodichain/agents/rl/buffer/utils.py +++ b/embodichain/agents/rl/buffer/utils.py @@ -21,7 +21,50 @@ import torch from tensordict import TensorDict -__all__ = ["iterate_minibatches"] +__all__ = ["iterate_minibatches", "transition_view"] + + +def transition_view(rollout: TensorDict, flatten: bool = False) -> TensorDict: + """Build a transition-aligned TensorDict from a rollout. + + The shared rollout uses a uniform `[num_envs, time + 1]` layout. For + transition-only fields such as `action`, `reward`, and `done`, the final + slot is reserved as padding so that all rollout fields share the same batch + shape. This helper drops that padded slot and exposes the valid transition + slices as a TensorDict with batch shape `[num_envs, time]`. + + Args: + rollout: Rollout TensorDict with root batch shape `[num_envs, time + 1]`. + flatten: If True, return a flattened `[num_envs * time]` view. + + Returns: + TensorDict containing transition-aligned fields. 
+ """ + action = rollout["action"][:, :-1] + num_envs, time_dim = action.shape[:2] + td = TensorDict( + { + "obs": rollout["obs"][:, :-1], + "action": action, + "sample_log_prob": rollout["sample_log_prob"][:, :-1], + "value": rollout["value"][:, :-1], + "next_value": rollout["value"][:, 1:], + "reward": rollout["reward"][:, :-1], + "done": rollout["done"][:, :-1], + "terminated": rollout["terminated"][:, :-1], + "truncated": rollout["truncated"][:, :-1], + }, + batch_size=[num_envs, time_dim], + device=rollout.device, + ) + + for key in ("advantage", "return", "seq_mask", "seq_return", "entropy"): + if key in rollout.keys(): + td[key] = rollout[key][:, :-1] + + if flatten: + return td.reshape(num_envs * time_dim) + return td def iterate_minibatches( diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py index 22448d05..16c5b584 100644 --- a/embodichain/agents/rl/collector/sync_collector.py +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -59,19 +59,21 @@ def collect( raise ValueError( "SyncCollector.collect() requires a preallocated rollout TensorDict." ) - if tuple(rollout.batch_size) != (self.env.num_envs, num_steps): + if tuple(rollout.batch_size) != (self.env.num_envs, num_steps + 1): raise ValueError( "Preallocated rollout batch size mismatch: " - f"expected ({self.env.num_envs}, {num_steps}), got {tuple(rollout.batch_size)}." + f"expected ({self.env.num_envs}, {num_steps + 1}), got {tuple(rollout.batch_size)}." 
) + self._validate_rollout(rollout, num_steps) if self._supports_shared_rollout: self.env.set_rollout_buffer(rollout) + initial_obs = flatten_dict_observation(self.obs_td) + rollout["obs"][:, 0] = initial_obs for step_idx in range(num_steps): - obs_tensor = flatten_dict_observation(self.obs_td) step_td = TensorDict( - {"obs": obs_tensor}, - batch_size=[obs_tensor.shape[0]], + {"obs": rollout["obs"][:, step_idx]}, + batch_size=[rollout.batch_size[0]], device=self.device, ) step_td = self.policy.get_action(step_td) @@ -93,29 +95,26 @@ def collect( terminated=terminated, truncated=truncated, ) + rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td) if on_step_callback is not None: on_step_callback(rollout[:, step_idx], env_info) self.obs_td = next_obs_td - self._attach_next_values(rollout) + self._attach_final_value(rollout) return rollout - def _attach_next_values(self, rollout: TensorDict) -> None: - """Populate `next.value` for GAE bootstrap.""" - next_values = torch.zeros_like(rollout["value"]) - next_values[:, :-1] = rollout["value"][:, 1:] - - final_obs = flatten_dict_observation(self.obs_td) + def _attach_final_value(self, rollout: TensorDict) -> None: + """Populate the bootstrap value for the final observed state.""" + final_obs = rollout["obs"][:, -1] last_next_td = TensorDict( {"obs": final_obs}, batch_size=[rollout.batch_size[0]], device=self.device, ) self.policy.get_value(last_next_td) - next_values[:, -1] = last_next_td["value"] - rollout["next", "value"] = next_values + rollout["value"][:, -1] = last_next_td["value"] def _reset_env(self) -> TensorDict: obs, _ = self.env.reset() @@ -135,7 +134,6 @@ def _write_step( step_td: TensorDict, ) -> None: """Write policy-side fields for one transition into the shared rollout TensorDict.""" - rollout["obs"][:, step_idx] = step_td["obs"] rollout["action"][:, step_idx] = step_td["action"] rollout["sample_log_prob"][:, step_idx] = step_td["sample_log_prob"] rollout["value"][:, step_idx] = 
step_td["value"] @@ -150,7 +148,27 @@ def _write_env_step( ) -> None: """Populate transition-side fields when the environment does not own the rollout.""" done = terminated | truncated - rollout["next", "reward"][:, step_idx] = reward.to(self.device) - rollout["next", "done"][:, step_idx] = done.to(self.device) - rollout["next", "terminated"][:, step_idx] = terminated.to(self.device) - rollout["next", "truncated"][:, step_idx] = truncated.to(self.device) + rollout["reward"][:, step_idx] = reward.to(self.device) + rollout["done"][:, step_idx] = done.to(self.device) + rollout["terminated"][:, step_idx] = terminated.to(self.device) + rollout["truncated"][:, step_idx] = truncated.to(self.device) + + def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None: + """Validate rollout layout expected by the collector.""" + expected_shapes = { + "obs": (self.env.num_envs, num_steps + 1, self.policy.obs_dim), + "action": (self.env.num_envs, num_steps + 1, self.policy.action_dim), + "sample_log_prob": (self.env.num_envs, num_steps + 1), + "value": (self.env.num_envs, num_steps + 1), + "reward": (self.env.num_envs, num_steps + 1), + "done": (self.env.num_envs, num_steps + 1), + "terminated": (self.env.num_envs, num_steps + 1), + "truncated": (self.env.num_envs, num_steps + 1), + } + for key, expected_shape in expected_shapes.items(): + actual_shape = tuple(rollout[key].shape) + if actual_shape != expected_shape: + raise ValueError( + f"Preallocated rollout field '{key}' shape mismatch: " + f"expected {expected_shape}, got {actual_shape}." 
+ ) diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 6dfb14b6..4f660232 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -160,8 +160,8 @@ def _collect_rollout(self): # Callback function for statistics and logging def on_step(tensordict: TensorDict, info: dict): """Callback called at each step during rollout collection.""" - reward = tensordict["next", "reward"] - done = tensordict["next", "done"] + reward = tensordict["reward"] + done = tensordict["done"] # Episode stats self.curr_ret += reward self.curr_len += 1 diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 995234d7..ebe26e7f 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -290,16 +290,23 @@ def set_rollout_buffer(self, rollout_buffer: TensorDict) -> None: such as a shared rollout buffer initialized in model training process and passed to the environment for data collection. Args: - rollout_buffer (TensorDict): The rollout buffer to be set. The shape of the buffer should be (num_envs, max_episode_steps, *data_shape) for each key. + rollout_buffer (TensorDict): The rollout buffer to be set. RL + rollouts use a uniform `[num_envs, time + 1]` layout so all + fields share the same batch shape; the last slot of + transition-only fields is reserved as padding. Expert buffers + keep the legacy `[num_envs, time]` batch layout. """ - if len(rollout_buffer.shape) != 2: - logger.log_error( - f"Invalid rollout buffer shape: {rollout_buffer.shape}. The expected shape is (num_envs, max_episode_steps) for each key." 
- ) self.rollout_buffer = rollout_buffer - self._max_rollout_steps = self.rollout_buffer.shape[1] - self.current_rollout_step = 0 self._rollout_buffer_mode = self._infer_rollout_buffer_mode(rollout_buffer) + if self._rollout_buffer_mode == "rl": + self._max_rollout_steps = self.rollout_buffer.batch_size[1] - 1 + else: + if len(rollout_buffer.shape) != 2: + logger.log_error( + f"Invalid rollout buffer shape: {rollout_buffer.shape}. The expected shape is (num_envs, max_episode_steps) for each key." + ) + self._max_rollout_steps = self.rollout_buffer.shape[1] + self.current_rollout_step = 0 def _init_sim_state(self, **kwargs): """Initialize the simulation state at the beginning of scene creation.""" @@ -523,7 +530,9 @@ def _initialize_episode( def _infer_rollout_buffer_mode(self, rollout_buffer: TensorDict) -> str: """Infer whether the rollout buffer is expert recording or RL training data.""" - if "next" in rollout_buffer.keys() and "obs" in rollout_buffer.keys(): + if {"obs", "action", "reward", "done", "value"}.issubset( + set(rollout_buffer.keys()) + ): return "rl" return "expert" @@ -570,10 +579,10 @@ def _write_rl_rollout_step( ) -> None: """Write environment-side fields into an externally managed RL rollout buffer.""" buffer_device = self.rollout_buffer.device - self.rollout_buffer["next", "reward"][:, self.current_rollout_step].copy_( + self.rollout_buffer["reward"][:, self.current_rollout_step].copy_( rewards.to(buffer_device), non_blocking=True ) - self.rollout_buffer["next", "done"][:, self.current_rollout_step].copy_( + self.rollout_buffer["done"][:, self.current_rollout_step].copy_( dones.to(buffer_device), non_blocking=True ) terminateds = ( @@ -586,10 +595,10 @@ def _write_rl_rollout_step( if truncateds is not None else torch.zeros_like(dones, dtype=torch.bool) ) - self.rollout_buffer["next", "terminated"][:, self.current_rollout_step].copy_( + self.rollout_buffer["terminated"][:, self.current_rollout_step].copy_( terminateds.to(buffer_device), 
non_blocking=True ) - self.rollout_buffer["next", "truncated"][:, self.current_rollout_step].copy_( + self.rollout_buffer["truncated"][:, self.current_rollout_step].copy_( truncateds.to(buffer_device), non_blocking=True ) diff --git a/tests/agents/test_shared_rollout.py b/tests/agents/test_shared_rollout.py index 847bc55a..29b7fb77 100644 --- a/tests/agents/test_shared_rollout.py +++ b/tests/agents/test_shared_rollout.py @@ -87,14 +87,14 @@ def step(self, action_dict): truncated = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device) if self.rollout_buffer is not None: - self.rollout_buffer["next", "reward"][:, self.current_rollout_step] = reward - self.rollout_buffer["next", "done"][:, self.current_rollout_step] = ( + self.rollout_buffer["reward"][:, self.current_rollout_step] = reward + self.rollout_buffer["done"][:, self.current_rollout_step] = ( terminated | truncated ) - self.rollout_buffer["next", "terminated"][ + self.rollout_buffer["terminated"][ :, self.current_rollout_step ] = terminated - self.rollout_buffer["next", "truncated"][ + self.rollout_buffer["truncated"][ :, self.current_rollout_step ] = truncated self.current_rollout_step += 1 @@ -152,8 +152,13 @@ def test_shared_rollout_collects_policy_and_env_fields(): buffer.add(rollout) stored = buffer.get(flatten=False) - assert stored.batch_size == torch.Size([num_envs, rollout_len]) + assert stored.batch_size == torch.Size([num_envs, rollout_len + 1]) + assert stored["obs"].shape == torch.Size([num_envs, rollout_len + 1, obs_dim]) assert torch.allclose(stored["obs"][:, 0], torch.zeros(num_envs, obs_dim)) + assert torch.allclose( + stored["obs"][:, -1], + torch.full((num_envs, obs_dim), float(rollout_len), dtype=torch.float32), + ) assert torch.allclose( stored["value"][:, 1], torch.ones(num_envs, dtype=torch.float32) ) @@ -166,13 +171,17 @@ def test_shared_rollout_collects_policy_and_env_fields(): torch.full((num_envs,), 0.5, dtype=torch.float32), ) assert torch.allclose( - stored["next", 
"reward"][:, 2], + stored["reward"][:, 2], torch.full((num_envs,), 1.0, dtype=torch.float32), ) assert torch.allclose( - stored["next", "value"][:, -1], + stored["value"][:, -1], torch.full((num_envs,), 4.0, dtype=torch.float32), ) + assert torch.allclose( + stored["action"][:, -1], + torch.zeros(num_envs, action_dim, dtype=torch.float32), + ) def test_embodied_env_writes_next_fields_into_external_rollout(): @@ -212,10 +221,10 @@ def test_embodied_env_writes_next_fields_into_external_rollout(): done = (terminated | truncated).cpu() assert env.current_rollout_step == 1 - assert torch.allclose(rollout["next", "reward"][:, 0].cpu(), reward.cpu()) - assert torch.equal(rollout["next", "done"][:, 0].cpu(), done) - assert torch.equal(rollout["next", "terminated"][:, 0].cpu(), terminated.cpu()) - assert torch.equal(rollout["next", "truncated"][:, 0].cpu(), truncated.cpu()) + assert torch.allclose(rollout["reward"][:, 0].cpu(), reward.cpu()) + assert torch.equal(rollout["done"][:, 0].cpu(), done) + assert torch.equal(rollout["terminated"][:, 0].cpu(), terminated.cpu()) + assert torch.equal(rollout["truncated"][:, 0].cpu(), truncated.cpu()) finally: env.close() if SimulationManager.is_instantiated(): From d99bbe5fd9282ec314426dd249b8daf55c79fd36 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 13 Mar 2026 05:51:39 +0000 Subject: [PATCH 21/23] Update docs --- docs/source/overview/rl/algorithm.md | 4 +-- docs/source/overview/rl/buffer.md | 38 +++++++++++++++++----------- docs/source/overview/rl/trainer.md | 2 +- docs/source/tutorial/rl.rst | 4 +-- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/docs/source/overview/rl/algorithm.md b/docs/source/overview/rl/algorithm.md index e7970709..162e487f 100644 --- a/docs/source/overview/rl/algorithm.md +++ b/docs/source/overview/rl/algorithm.md @@ -9,12 +9,12 @@ This module contains the core implementations of reinforcement learning algorith - Key methods: - `update(rollout)`: Update the 
policy based on a shared rollout `TensorDict`. - Designed to be algorithm-agnostic; `Trainer` handles collection while algorithms focus on loss computation and optimization. -- Supports multi-environment parallel collection through a shared `[N, T]` rollout `TensorDict`. +- Consumes a shared `[N, T + 1]` rollout `TensorDict` and typically converts it to a transition-aligned view over the valid first `T` steps before optimization. ### PPO - Mainstream on-policy algorithm, supports Generalized Advantage Estimation (GAE), policy update, and hyperparameter configuration. - Key methods: - - `compute_gae(rollout, gamma, gae_lambda)`: Generalized Advantage Estimation over a shared rollout `TensorDict`. + - `compute_gae(rollout, gamma, gae_lambda)`: Generalized Advantage Estimation over a shared rollout `TensorDict`, using `value[:, -1]` as the bootstrap value and ignoring the padded final transition slot. - `update(rollout)`: Multi-epoch minibatch optimization, including entropy, value, and policy loss, with gradient clipping. - Supports custom callbacks, detailed logging, and GPU acceleration. - Typical training flow: collect rollout → compute advantage/return → multi-epoch minibatch optimization. diff --git a/docs/source/overview/rl/buffer.md b/docs/source/overview/rl/buffer.md index 1f23a75b..955e11b6 100644 --- a/docs/source/overview/rl/buffer.md +++ b/docs/source/overview/rl/buffer.md @@ -6,23 +6,30 @@ This module implements the data buffer for RL training, responsible for storing ### RolloutBuffer - Used for on-policy algorithms (such as PPO, GRPO), storing a shared rollout `TensorDict` for collector and algorithm stages. -- Supports multi-environment parallelism with rollout batch shape `[N, T]`, all data allocated on GPU. +- Supports multi-environment parallelism with rollout batch shape `[N, T + 1]`, all data allocated on GPU. 
- Structure fields: - - `obs`: Flattened observation tensor, float32, shape `[N, T, obs_dim]` - - `action`: Action tensor, float32, shape `[N, T, action_dim]` - - `sample_log_prob`: Action log probabilities, float32, shape `[N, T]` - - `value`: Value estimates, float32, shape `[N, T]` - - `next.reward`: Reward tensor, float32, shape `[N, T]` - - `next.done`: Done flags, bool, shape `[N, T]` - - `next.terminated`: Termination flags, bool, shape `[N, T]` - - `next.truncated`: Truncation flags, bool, shape `[N, T]` - - `next.value`: Bootstrap next-state values, float32, shape `[N, T]` + - `obs`: Flattened observation tensor, float32, shape `[N, T + 1, obs_dim]` + - `action`: Action tensor, float32, shape `[N, T + 1, action_dim]` + - `sample_log_prob`: Action log probabilities, float32, shape `[N, T + 1]` + - `value`: Value estimates, float32, shape `[N, T + 1]` + - `reward`: Reward tensor, float32, shape `[N, T + 1]` + - `done`: Done flags, bool, shape `[N, T + 1]` + - `terminated`: Termination flags, bool, shape `[N, T + 1]` + - `truncated`: Truncation flags, bool, shape `[N, T + 1]` - Algorithm-added fields such as `advantage`, `return`, `seq_mask`, and `seq_return` +The final time index is valid for `obs` and `value`, where it stores the last +observation and bootstrap value. For transition-only fields (`action`, `reward`, +`done`, etc.), the final slot is padding so all rollout fields can share the +same `[N, T + 1]` batch shape. + ## Main Methods - `start_rollout()`: Returns the shared preallocated rollout `TensorDict` for collector write-in. - `add(rollout)`: Marks the shared rollout as ready for consumption. -- `get(flatten=True)`: Returns the stored rollout, optionally flattened over `[N, T]`. +- `get(flatten=True)`: Returns the stored rollout after converting it to a + transition view over the valid first `T` steps. +- `transition_view(rollout, flatten=False)`: Builds a transition-aligned view + that drops the padded final slot from transition-only fields. 
- `iterate_minibatches(rollout, batch_size, device)`: Shared batching utility in `buffer/utils.py`. ## Usage Example @@ -32,7 +39,8 @@ rollout = collector.collect(num_steps=rollout_len, rollout=buffer.start_rollout( buffer.add(rollout) rollout = buffer.get(flatten=False) -for batch in iterate_minibatches(rollout.reshape(-1), batch_size, device): +flat_rollout = transition_view(rollout, flatten=True) +for batch in iterate_minibatches(flat_rollout, batch_size, device): # batch["obs"], batch["action"], batch["advantage"] ... pass ``` @@ -41,7 +49,7 @@ for batch in iterate_minibatches(rollout.reshape(-1), batch_size, device): - Supports multi-environment parallel collection, compatible with Gymnasium-style vectorized environments. - All tensors are preallocated on device to avoid frequent CPU-GPU copying. - Algorithm-specific fields are attached directly onto the shared rollout `TensorDict` during optimization. -- The shared minibatch iterator automatically shuffles flattened rollout entries for PPO/GRPO style updates. +- The shared minibatch iterator automatically shuffles flattened transition entries for PPO/GRPO style updates. ## Code Example ```python @@ -62,7 +70,7 @@ class RolloutBuffer: ## Practical Tips - The rollout buffer stores flattened RL observations; structured observations should be flattened or encoded before entering this buffer. -- `next.value` is kept for bootstrap convenience, while `next.obs` is intentionally not stored to reduce duplicated memory. -- Use `buffer/utils.py` for shared minibatch iteration instead of duplicating batching logic in each algorithm. +- `value[:, -1]` stores the bootstrap value of the final observation. The final slot of transition-only fields is padding and should be ignored during optimization. +- Use `transition_view()` plus `iterate_minibatches()` instead of duplicating rollout slicing logic in each algorithm. 
--- diff --git a/docs/source/overview/rl/trainer.md b/docs/source/overview/rl/trainer.md index ffe09cbd..5ef4ee99 100644 --- a/docs/source/overview/rl/trainer.md +++ b/docs/source/overview/rl/trainer.md @@ -44,7 +44,7 @@ trainer.save_checkpoint() - Custom event modules can be implemented for environment reset, data collection, evaluation, etc. - Supports multi-environment parallelism and distributed training. - Training process can be flexibly adjusted via config files. -- The current trainer uses a shared rollout `TensorDict`: collector writes policy-side fields and `EmbodiedEnv` writes environment-side `next.*` fields through `set_rollout_buffer()`. +- The current trainer uses a shared rollout `TensorDict` with uniform `[N, T + 1]` layout: collector writes policy-side fields, `EmbodiedEnv` writes environment-side step fields through `set_rollout_buffer()`, and algorithms consume the valid first `T` steps through `transition_view()`. ## Practical Tips - It is recommended to perform periodic evaluation and model saving to prevent loss of progress during training. diff --git a/docs/source/tutorial/rl.rst b/docs/source/tutorial/rl.rst index 81063fb8..28054648 100644 --- a/docs/source/tutorial/rl.rst +++ b/docs/source/tutorial/rl.rst @@ -234,8 +234,8 @@ Training Process The training process follows this sequence: -1. **Rollout Phase**: ``SyncCollector`` interacts with the environment and writes policy-side fields into a shared rollout ``TensorDict``. ``EmbodiedEnv`` writes environment-side step fields such as ``next.reward``, ``next.done``, ``next.terminated``, and ``next.truncated`` into the same rollout via ``set_rollout_buffer()``. -2. **Advantage/Return Computation**: Algorithm computes advantages and returns from the collected rollout (e.g. GAE for PPO, step-wise group normalization for GRPO) +1. 
**Rollout Phase**: ``SyncCollector`` interacts with the environment and writes policy-side fields into a shared rollout ``TensorDict`` with uniform ``[N, T + 1]`` layout. ``EmbodiedEnv`` writes environment-side step fields such as ``reward``, ``done``, ``terminated``, and ``truncated`` into the same rollout via ``set_rollout_buffer()``. The final slot of transition-only fields is reserved as padding, while ``obs[:, -1]`` and ``value[:, -1]`` remain valid bootstrap data. +2. **Advantage/Return Computation**: Algorithm computes advantages and returns from the collected rollout (e.g. GAE for PPO, step-wise group normalization for GRPO) and converts it to a transition-aligned view over the valid first ``T`` steps before minibatch optimization. 3. **Update Phase**: Algorithm updates the policy with ``update(rollout)`` 4. **Logging**: Trainer logs training losses and aggregated metrics to TensorBoard and Weights & Biases 5. **Evaluation** (periodic): Trainer evaluates the current policy From 34bf0b40c73729b5c75ac63c9dffdf9fb8ee09b3 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 13 Mar 2026 05:53:24 +0000 Subject: [PATCH 22/23] Reformat files --- tests/agents/test_shared_rollout.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/agents/test_shared_rollout.py b/tests/agents/test_shared_rollout.py index 29b7fb77..89335229 100644 --- a/tests/agents/test_shared_rollout.py +++ b/tests/agents/test_shared_rollout.py @@ -91,12 +91,8 @@ def step(self, action_dict): self.rollout_buffer["done"][:, self.current_rollout_step] = ( terminated | truncated ) - self.rollout_buffer["terminated"][ - :, self.current_rollout_step - ] = terminated - self.rollout_buffer["truncated"][ - :, self.current_rollout_step - ] = truncated + self.rollout_buffer["terminated"][:, self.current_rollout_step] = terminated + self.rollout_buffer["truncated"][:, self.current_rollout_step] = truncated self.current_rollout_step += 1 self._obs = 
next_obs From 9b44992d6cdce05c2150ab654c89e69254b1650e Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 13 Mar 2026 06:05:32 +0000 Subject: [PATCH 23/23] Update --- embodichain/agents/rl/algo/grpo.py | 1 - embodichain/agents/rl/algo/ppo.py | 1 - embodichain/lab/gym/envs/embodied_env.py | 23 ++++++++++++++++++----- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/embodichain/agents/rl/algo/grpo.py b/embodichain/agents/rl/algo/grpo.py index 2b2d73d3..12f7c32f 100644 --- a/embodichain/agents/rl/algo/grpo.py +++ b/embodichain/agents/rl/algo/grpo.py @@ -16,7 +16,6 @@ from __future__ import annotations -import math from copy import deepcopy from typing import Dict diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index 4ba88bb4..e33ee5b3 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -14,7 +14,6 @@ # limitations under the License. # ---------------------------------------------------------------------------- -import math from typing import Dict import torch diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index ebe26e7f..5e40d6fd 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -52,7 +52,6 @@ from embodichain.lab.gym.utils.gym_utils import ( init_rollout_buffer_from_gym_space, ) -from embodichain.agents.rl.utils import flatten_dict_observation from embodichain.utils import configclass, logger @@ -299,7 +298,15 @@ def set_rollout_buffer(self, rollout_buffer: TensorDict) -> None: self.rollout_buffer = rollout_buffer self._rollout_buffer_mode = self._infer_rollout_buffer_mode(rollout_buffer) if self._rollout_buffer_mode == "rl": - self._max_rollout_steps = self.rollout_buffer.batch_size[1] - 1 + batch_size = self.rollout_buffer.batch_size + if len(batch_size) != 2: + message = ( + f"Invalid RL rollout buffer batch size: {batch_size}. 
" + "Expected a 2D batch layout [num_envs, time + 1] for RL rollouts." + ) + logger.log_error(message) + raise ValueError(message) + self._max_rollout_steps = batch_size[1] - 1 else: if len(rollout_buffer.shape) != 2: logger.log_error( @@ -530,9 +537,15 @@ def _initialize_episode( def _infer_rollout_buffer_mode(self, rollout_buffer: TensorDict) -> str: """Infer whether the rollout buffer is expert recording or RL training data.""" - if {"obs", "action", "reward", "done", "value"}.issubset( - set(rollout_buffer.keys()) - ): + if { + "obs", + "action", + "reward", + "done", + "value", + "terminated", + "truncated", + }.issubset(set(rollout_buffer.keys())): return "rl" return "expert"