diff --git a/configs/agents/rl/basic/cart_pole/train_config.json b/configs/agents/rl/basic/cart_pole/train_config.json index 06031b5e..02a302d1 100644 --- a/configs/agents/rl/basic/cart_pole/train_config.json +++ b/configs/agents/rl/basic/cart_pole/train_config.json @@ -9,7 +9,7 @@ "gpu_id": 0, "num_envs": 64, "iterations": 1000, - "rollout_steps": 1024, + "buffer_size": 1024, "eval_freq": 200, "save_freq": 200, "use_wandb": false, diff --git a/configs/agents/rl/basic/cart_pole/train_config_grpo.json b/configs/agents/rl/basic/cart_pole/train_config_grpo.json index 6625afe1..4da5cab7 100644 --- a/configs/agents/rl/basic/cart_pole/train_config_grpo.json +++ b/configs/agents/rl/basic/cart_pole/train_config_grpo.json @@ -9,7 +9,7 @@ "gpu_id": 0, "num_envs": 64, "iterations": 1000, - "rollout_steps": 1024, + "buffer_size": 1024, "eval_freq": 200, "save_freq": 200, "use_wandb": true, diff --git a/configs/agents/rl/push_cube/train_config.json b/configs/agents/rl/push_cube/train_config.json index b58e2c63..d44aa0b3 100644 --- a/configs/agents/rl/push_cube/train_config.json +++ b/configs/agents/rl/push_cube/train_config.json @@ -9,7 +9,7 @@ "gpu_id": 0, "num_envs": 64, "iterations": 1000, - "rollout_steps": 1024, + "buffer_size": 1024, "enable_eval": true, "num_eval_envs": 16, "num_eval_episodes": 3, diff --git a/docs/source/overview/rl/algorithm.md b/docs/source/overview/rl/algorithm.md index a22110e8..162e487f 100644 --- a/docs/source/overview/rl/algorithm.md +++ b/docs/source/overview/rl/algorithm.md @@ -5,20 +5,17 @@ This module contains the core implementations of reinforcement learning algorith ## Main Classes and Functions ### BaseAlgorithm -- Abstract base class for RL algorithms, defining common interfaces such as buffer initialization, data collection, and update. +- Abstract base class for RL algorithms, defining a single update interface over a collected rollout. 
- Key methods: - - `initialize_buffer(num_steps, num_envs, obs_dim, action_dim)`: Initialize the trajectory buffer. - - `collect_rollout(env, policy, obs, num_steps, on_step_callback)`: Collect interaction data. - - `update()`: Update the policy based on collected data. -- Designed to be algorithm-agnostic; Trainer only depends on this interface to support various RL algorithms. -- Supports multi-environment parallel collection, compatible with Gymnasium/IsaacGym environments. + - `update(rollout)`: Update the policy based on a shared rollout `TensorDict`. +- Designed to be algorithm-agnostic; `Trainer` handles collection while algorithms focus on loss computation and optimization. +- Consumes a shared `[N, T + 1]` rollout `TensorDict` and typically converts it to a transition-aligned view over the valid first `T` steps before optimization. ### PPO - Mainstream on-policy algorithm, supports Generalized Advantage Estimation (GAE), policy update, and hyperparameter configuration. - Key methods: - - `_compute_gae(rewards, values, dones)`: Generalized Advantage Estimation. - - `collect_rollout`: Collect trajectories and compute advantages/returns. - - `update`: Multi-epoch minibatch optimization, including entropy, value, and policy loss, with gradient clipping. + - `compute_gae(rollout, gamma, gae_lambda)`: Shared module-level helper (defined in `algo/common.py`) for Generalized Advantage Estimation over a shared rollout `TensorDict`, using `value[:, -1]` as the bootstrap value and ignoring the padded final transition slot. It also writes `advantage` and `return` back into the rollout. + - `update(rollout)`: Multi-epoch minibatch optimization, including entropy, value, and policy loss, with gradient clipping. - Supports custom callbacks, detailed logging, and GPU acceleration. - Typical training flow: collect rollout → compute advantage/return → multi-epoch minibatch optimization. - Supports advantage normalization, entropy regularization, value loss weighting, etc. 
@@ -31,8 +28,7 @@ This module contains the core implementations of reinforcement learning algorith - Key methods: - `_compute_step_returns_and_mask(rewards, dones)`: Step-wise discounted returns and valid-step mask. - `_compute_step_group_advantages(step_returns, seq_mask)`: Per-step group normalization with masked mean/std. - - `collect_rollout`: Collect trajectories and compute step-wise advantages. - - `update`: Multi-epoch minibatch optimization with optional KL penalty. + - `update(rollout)`: Multi-epoch minibatch optimization with optional KL penalty. - Supports both **Embodied AI** (dense reward, from-scratch training) and **VLA** (sparse reward, fine-tuning) modes via `kl_coef` configuration. ### Config Classes @@ -43,19 +39,11 @@ This module contains the core implementations of reinforcement learning algorith ## Code Example ```python class BaseAlgorithm: - def initialize_buffer(self, num_steps, num_envs, obs_dim, action_dim): - ... - def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None): - ... - def update(self): + def update(self, rollout): ... class PPO(BaseAlgorithm): - def _compute_gae(self, rewards, values, dones): - ... - def collect_rollout(self, ...): - ... - def update(self): + def update(self, rollout): ... ``` @@ -71,10 +59,9 @@ class PPO(BaseAlgorithm): - Typical usage: ```python algo = PPO(cfg, policy) -buffer = algo.initialize_buffer(...) -for _ in range(num_iterations): - algo.collect_rollout(...) 
- algo.update() +rollout = collector.collect(buffer_size, rollout=buffer.start_rollout()) +buffer.add(rollout) +algo.update(buffer.get(flatten=False)) ``` --- diff --git a/docs/source/overview/rl/buffer.md b/docs/source/overview/rl/buffer.md index 06a38640..955e11b6 100644 --- a/docs/source/overview/rl/buffer.md +++ b/docs/source/overview/rl/buffer.md @@ -5,61 +5,72 @@ This module implements the data buffer for RL training, responsible for storing ## Main Classes and Structure ### RolloutBuffer -- Used for on-policy algorithms (such as PPO, GRPO), efficiently stores observations, actions, rewards, dones, values, and logprobs for each step. -- Supports multi-environment parallelism (shape: [T, N, ...]), all data allocated on GPU. +- Used for on-policy algorithms (such as PPO, GRPO), storing a shared rollout `TensorDict` for collector and algorithm stages. +- Supports multi-environment parallelism with rollout batch shape `[N, T + 1]`, all data preallocated on the configured `device`. - Structure fields: - - `obs`: Observation tensor, float32, shape [T, N, obs_dim] - - `actions`: Action tensor, float32, shape [T, N, action_dim] - - `rewards`: Reward tensor, float32, shape [T, N] - - `dones`: Done flags, bool, shape [T, N] - - `values`: Value estimates, float32, shape [T, N] - - `logprobs`: Action log probabilities, float32, shape [T, N] - - `_extras`: Algorithm-specific fields (e.g., advantages, returns), dict[str, Tensor] + - `obs`: Flattened observation tensor, float32, shape `[N, T + 1, obs_dim]` + - `action`: Action tensor, float32, shape `[N, T + 1, action_dim]` + - `sample_log_prob`: Action log probabilities, float32, shape `[N, T + 1]` + - `value`: Value estimates, float32, shape `[N, T + 1]` + - `reward`: Reward tensor, float32, shape `[N, T + 1]` + - `done`: Done flags, bool, shape `[N, T + 1]` + - `terminated`: Termination flags, bool, shape `[N, T + 1]` + - `truncated`: Truncation flags, bool, shape `[N, T + 1]` + - Algorithm-added fields such as `advantage`, `return`, 
`seq_mask`, and `seq_return` + +The final time index is valid for `obs` and `value`, where it stores the last +observation and bootstrap value. For transition-only fields (`action`, `reward`, +`done`, etc.), the final slot is padding so all rollout fields can share the +same `[N, T + 1]` batch shape. ## Main Methods -- `add(obs, action, reward, done, value, logprob)`: Add one step of data. -- `set_extras(extras)`: Attach algorithm-related tensors (e.g., advantages, returns). -- `iterate_minibatches(batch_size)`: Randomly sample minibatches, returns dict (including all fields and extras). -- Supports efficient GPU shuffle and indexing for large-scale training. +- `start_rollout()`: Returns the shared preallocated rollout `TensorDict` for collector write-in. +- `add(rollout)`: Marks the shared rollout as ready for consumption. +- `get(flatten=True)`: Returns the stored rollout after converting it to a + transition view over the valid first `T` steps. +- `transition_view(rollout, flatten=False)`: Builds a transition-aligned view + that drops the padded final slot from transition-only fields. +- `iterate_minibatches(rollout, batch_size, device)`: Shared batching utility in `buffer/utils.py`. ## Usage Example ```python -buffer = RolloutBuffer(num_steps, num_envs, obs_dim, action_dim, device) -for t in range(num_steps): - buffer.add(obs, action, reward, done, value, logprob) -buffer.set_extras({"advantages": adv, "returns": ret}) -for batch in buffer.iterate_minibatches(batch_size): - # batch["obs"], batch["actions"], batch["advantages"] ... +buffer = RolloutBuffer(num_envs, rollout_len, obs_dim, action_dim, device) +rollout = collector.collect(num_steps=rollout_len, rollout=buffer.start_rollout()) +buffer.add(rollout) + +rollout = buffer.get(flatten=False) +flat_rollout = transition_view(rollout, flatten=True) +for batch in iterate_minibatches(flat_rollout, batch_size, device): + # batch["obs"], batch["action"], batch["advantage"] ... 
pass ``` ## Design and Extension -- Supports multi-environment parallel collection, compatible with Gymnasium/IsaacGym environments. -- All data is allocated on GPU to avoid frequent CPU-GPU copying. -- The extras field can be flexibly extended to meet different algorithm needs (e.g., GAE, TD-lambda, distributional advantages). -- The iterator automatically shuffles to improve training stability. -- Compatible with various RL algorithms (PPO, GRPO, A2C, SAC, etc.), custom fields and sampling logic supported. +- Supports multi-environment parallel collection, compatible with Gymnasium-style vectorized environments. +- All tensors are preallocated on device to avoid frequent CPU-GPU copying. +- Algorithm-specific fields are attached directly onto the shared rollout `TensorDict` during optimization. +- The shared minibatch iterator automatically shuffles flattened transition entries for PPO/GRPO style updates. ## Code Example ```python class RolloutBuffer: - def __init__(self, num_steps, num_envs, obs_dim, action_dim, device): - # Initialize tensors + def __init__(self, num_envs, rollout_len, obs_dim, action_dim, device): + # Preallocate rollout TensorDict ... - def add(self, obs, action, reward, done, value, logprob): - # Add data + def start_rollout(self): + # Return shared rollout storage ... - def set_extras(self, extras): - # Attach algorithm-related tensors + def add(self, rollout): + # Mark rollout as full ... - def iterate_minibatches(self, batch_size): - # Random minibatch sampling + def get(self, flatten=True): + # Consume rollout ... ``` ## Practical Tips -- It is recommended to call set_extras after each rollout to ensure advantage/return tensors align with main data. -- When using iterate_minibatches, set batch_size appropriately for training stability. -- Extend the extras field as needed for custom sampling and statistics. 
+- The rollout buffer stores flattened RL observations; structured observations should be flattened or encoded before entering this buffer. +- `value[:, -1]` stores the bootstrap value of the final observation. The final slot of transition-only fields is padding and should be ignored during optimization. +- Use `transition_view()` plus `iterate_minibatches()` instead of duplicating rollout slicing logic in each algorithm. --- diff --git a/docs/source/overview/rl/models.md b/docs/source/overview/rl/models.md index c67c58ae..9c85cd9f 100644 --- a/docs/source/overview/rl/models.md +++ b/docs/source/overview/rl/models.md @@ -7,9 +7,10 @@ This module contains RL policy networks and related model implementations, suppo ### Policy - Abstract base class for RL policies; all policies must inherit from it. - Unified interface: - - `get_action(obs, deterministic=False)`: Sample or output actions. - - `get_value(obs)`: Estimate state value. - - `evaluate_actions(obs, actions)`: Evaluate action probabilities, entropy, and value. + - `get_action(tensordict, deterministic=False)`: Sample actions into a `TensorDict` without gradients. + - `forward(tensordict, deterministic=False)`: Low-level action/value write path used by policy implementations. + - `get_value(tensordict)`: Estimate state value into a `TensorDict`. + - `evaluate_actions(tensordict)`: Return optimization-time policy outputs from a `TensorDict`. - Supports GPU deployment and distributed training. ### ActorCritic @@ -19,8 +20,8 @@ This module contains RL policy networks and related model implementations, suppo - Actor-only policy without Critic. Used with GRPO (Group Relative Policy Optimization), which estimates advantages via group-level return comparison instead of a value function. - Supports Gaussian action distributions, learnable log_std, suitable for continuous action spaces. - Key methods: - - `get_action`: Actor network outputs mean, samples action, returns log_prob and critic value. 
- - `evaluate_actions`: Used for loss calculation in PPO/SAC algorithms. + - `forward`: Actor network outputs mean, samples action, and writes policy outputs into a `TensorDict`. + - `evaluate_actions`: Used for loss calculation in PPO/GRPO algorithms. - Custom actor/critic network architectures supported (e.g., MLP/CNN/Transformer). ### MLP @@ -36,8 +37,19 @@ This module contains RL policy networks and related model implementations, suppo ```python actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim) critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1) -policy = build_policy(policy_block, obs_space, action_space, device, actor=actor, critic=critic) -action, log_prob, value = policy.get_action(obs) +policy = build_policy( + policy_block, + env.flattened_observation_space, + env.action_space, + device, + actor=actor, + critic=critic, +) +step_td = TensorDict({"obs": obs}, batch_size=[obs.shape[0]], device=obs.device) +step_td = policy.get_action(step_td) +action = step_td["action"] +log_prob = step_td["sample_log_prob"] +value = step_td["value"] ``` ## Extension and Customization diff --git a/docs/source/overview/rl/trainer.md b/docs/source/overview/rl/trainer.md index 1a2b0fe3..5ef4ee99 100644 --- a/docs/source/overview/rl/trainer.md +++ b/docs/source/overview/rl/trainer.md @@ -20,7 +20,7 @@ This module implements the main RL training loop, logging management, and event- ## Main Methods - `train(total_timesteps)`: Main training loop, automatically collects data, updates policy, and logs. -- `_collect_rollout()`: Collect one rollout, supports custom callback statistics. +- `_collect_rollout()`: Collect one rollout through `SyncCollector`, supports custom callback statistics. - `_log_train(losses)`: Log training loss, reward, sampling speed, etc. - `_eval_once()`: Periodic evaluation, records evaluation metrics. - `save_checkpoint()`: Save model parameters and training state. 
@@ -35,7 +35,7 @@ This module implements the main RL training loop, logging management, and event- ## Usage Example ```python -trainer = Trainer(policy, env, algorithm, num_steps, batch_size, writer, ...) +trainer = Trainer(policy, env, algorithm, buffer_size, batch_size, writer, ...) trainer.train(total_steps) trainer.save_checkpoint() ``` @@ -44,6 +44,7 @@ trainer.save_checkpoint() - Custom event modules can be implemented for environment reset, data collection, evaluation, etc. - Supports multi-environment parallelism and distributed training. - Training process can be flexibly adjusted via config files. +- The current trainer uses a shared rollout `TensorDict` with uniform `[N, T + 1]` layout: collector writes policy-side fields, `EmbodiedEnv` writes environment-side step fields through `set_rollout_buffer()`, and algorithms consume the valid first `T` steps through `transition_view()`. ## Practical Tips - It is recommended to perform periodic evaluation and model saving to prevent loss of progress during training. diff --git a/docs/source/tutorial/rl.rst b/docs/source/tutorial/rl.rst index b0126dde..28054648 100644 --- a/docs/source/tutorial/rl.rst +++ b/docs/source/tutorial/rl.rst @@ -57,14 +57,14 @@ Configuration Sections Runtime Settings ^^^^^^^^^^^^^^^^ -The ``runtime`` section controls experiment setup: +The ``trainer`` section controls experiment setup: - **exp_name**: Experiment name (used for output directories) - **seed**: Random seed for reproducibility -- **cuda**: Whether to use GPU (default: true) +- **device**: Runtime device string, e.g. 
``"cpu"`` or ``"cuda:0"`` - **headless**: Whether to run simulation in headless mode - **iterations**: Number of training iterations -- **rollout_steps**: Steps per rollout (e.g., 1024) +- **buffer_size**: Steps collected per rollout (e.g., 1024) - **eval_freq**: Frequency of evaluation (in steps) - **save_freq**: Frequency of checkpoint saving (in steps) - **use_wandb**: Whether to enable Weights & Biases logging (set in JSON config) @@ -109,7 +109,7 @@ Policy Configuration The ``policy`` section defines the neural network policy: - **name**: Policy name (e.g., "actor_critic", "vla") -- **cfg**: Policy-specific hyperparameters (empty for actor_critic) +- **action_dim**: Optional policy output action dimension. If omitted, it is inferred from ``env.action_space``. - **actor**: Actor network configuration (required for actor_critic) - **critic**: Critic network configuration (required for actor_critic) @@ -119,16 +119,19 @@ Example: "policy": { "name": "actor_critic", - "cfg": {}, "actor": { "type": "mlp", - "hidden_sizes": [256, 256], - "activation": "relu" + "network_cfg": { + "hidden_sizes": [256, 256], + "activation": "relu" + } }, "critic": { "type": "mlp", - "hidden_sizes": [256, 256], - "activation": "relu" + "network_cfg": { + "hidden_sizes": [256, 256], + "activation": "relu" + } } } @@ -231,9 +234,9 @@ Training Process The training process follows this sequence: -1. **Rollout Phase**: Algorithm collects trajectories by interacting with the environment (via ``collect_rollout``). During this phase, the trainer performs dense per-step logging of rewards and metrics from environment info. -2. **Advantage/Return Computation**: Algorithm computes advantages and returns (e.g. GAE for PPO, step-wise group normalization for GRPO; stored in buffer extras) -3. **Update Phase**: Algorithm updates the policy using collected data (e.g., PPO) +1. 
**Rollout Phase**: ``SyncCollector`` interacts with the environment and writes policy-side fields into a shared rollout ``TensorDict`` with uniform ``[N, T + 1]`` layout. ``EmbodiedEnv`` writes environment-side step fields such as ``reward``, ``done``, ``terminated``, and ``truncated`` into the same rollout via ``set_rollout_buffer()``. The final slot of transition-only fields is reserved as padding, while ``obs[:, -1]`` and ``value[:, -1]`` remain valid bootstrap data. +2. **Advantage/Return Computation**: Algorithm computes advantages and returns from the collected rollout (e.g. GAE for PPO, step-wise group normalization for GRPO) and converts it to a transition-aligned view over the valid first ``T`` steps before minibatch optimization. +3. **Update Phase**: Algorithm updates the policy with ``update(rollout)`` 4. **Logging**: Trainer logs training losses and aggregated metrics to TensorBoard and Weights & Biases 5. **Evaluation** (periodic): Trainer evaluates the current policy 6. **Checkpointing** (periodic): Trainer saves model checkpoints @@ -251,23 +254,23 @@ All policies must inherit from the ``Policy`` abstract base class: class Policy(nn.Module, ABC): device: torch.device + def get_action(self, tensordict, deterministic: bool = False): + """Samples action, sample_log_prob, and value into the TensorDict.""" + ... 
+ @abstractmethod - def get_action( - self, obs: torch.Tensor, deterministic: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Returns (action, log_prob, value)""" + def forward(self, tensordict, deterministic: bool = False): + """Writes action, sample_log_prob, and value into the TensorDict.""" raise NotImplementedError @abstractmethod - def get_value(self, obs: torch.Tensor) -> torch.Tensor: - """Returns value estimate""" + def get_value(self, tensordict): + """Writes value estimate into the TensorDict.""" raise NotImplementedError @abstractmethod - def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Returns (log_prob, entropy, value)""" + def evaluate_actions(self, tensordict): + """Returns a new TensorDict with log_prob, entropy, and value.""" raise NotImplementedError Available Policies @@ -292,13 +295,13 @@ Adding a New Algorithm To add a new algorithm: 1. Create a new algorithm class in ``embodichain/agents/rl/algo/`` -2. Implement ``initialize_buffer()``, ``collect_rollout()``, and ``update()`` methods +2. Implement ``update(rollout)`` and consume the shared rollout ``TensorDict`` 3. Register in ``algo/__init__.py``: .. 
code-block:: python + from tensordict import TensorDict from embodichain.agents.rl.algo import BaseAlgorithm, register_algo - from embodichain.agents.rl.buffer import RolloutBuffer @register_algo("my_algo") class MyAlgorithm(BaseAlgorithm): @@ -306,24 +309,11 @@ To add a new algorithm: self.cfg = cfg self.policy = policy self.device = torch.device(cfg.device) - self.buffer = None - def initialize_buffer(self, num_steps, num_envs, obs_dim, action_dim): - """Initialize the algorithm's buffer.""" - self.buffer = RolloutBuffer(num_steps, num_envs, obs_dim, action_dim, self.device) - - def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None): - """Control data collection process (interact with env, fill buffer, compute advantages/returns).""" - # Collect trajectories - # Compute advantages/returns (e.g., GAE for on-policy algorithms) - # Attach extras to buffer: self.buffer.set_extras({"advantages": adv, "returns": ret}) - # Return empty dict (dense logging handled in trainer) - return {} - - def update(self): - """Update the policy using collected data.""" - # Access extras from buffer: self.buffer._extras.get("advantages") - # Use self.buffer to update policy + def update(self, rollout: TensorDict): + """Update the policy using a collected rollout.""" + # compute advantages / returns from rollout + # optimize policy parameters return {"loss": 0.0} Adding a New Policy @@ -340,17 +330,21 @@ To add a new policy: @register_policy("my_policy") class MyPolicy(Policy): - def __init__(self, obs_space, action_space, device, config): + def __init__(self, obs_dim, action_dim, device, config): super().__init__() self.device = device # Initialize your networks here - def get_action(self, obs, deterministic=False): + def get_action(self, tensordict, deterministic=False): ... - def get_value(self, obs): + def forward(self, tensordict, deterministic=False): ... - def evaluate_actions(self, obs, actions): + def get_value(self, tensordict): ... 
+ def evaluate_actions(self, tensordict): + ... + +Current built-in MLP policies use flattened observations in the training path. If your policy requires structured or multi-modal inputs, keep the richer ``obs_space`` interface and define a matching rollout/collector schema. Adding a New Environment ------------------------ @@ -414,7 +408,7 @@ Best Practices - **Observation Format**: Environments should provide consistent observation shape/types (torch.float32) and a single ``done = terminated | truncated``. -- **Algorithm Interface**: Algorithms must implement ``initialize_buffer()``, ``collect_rollout()``, and ``update()`` methods. The algorithm completely controls data collection and buffer management. +- **Algorithm Interface**: Algorithms implement ``update(rollout)`` and consume a shared rollout ``TensorDict``. Collection is handled by ``SyncCollector`` plus environment-side rollout writes in ``EmbodiedEnv``. - **Reward Configuration**: Use the ``RewardManager`` in your environment config to define reward components. Organize reward components in ``info["rewards"]`` dictionary and metrics in ``info["metrics"]`` dictionary. The trainer performs dense per-step logging directly from environment info. 
diff --git a/embodichain/agents/rl/algo/__init__.py b/embodichain/agents/rl/algo/__init__.py index 4b69879a..0aefde03 100644 --- a/embodichain/agents/rl/algo/__init__.py +++ b/embodichain/agents/rl/algo/__init__.py @@ -20,6 +20,7 @@ import torch from .base import BaseAlgorithm +from .common import compute_gae from .ppo import PPOCfg, PPO from .grpo import GRPOCfg, GRPO @@ -51,6 +52,7 @@ def build_algo(name: str, cfg_kwargs: Dict[str, float], policy, device: torch.de "PPO", "GRPOCfg", "GRPO", + "compute_gae", "get_registered_algo_names", "build_algo", ] diff --git a/embodichain/agents/rl/algo/base.py b/embodichain/agents/rl/algo/base.py index fcb3fc00..b1516472 100644 --- a/embodichain/agents/rl/algo/base.py +++ b/embodichain/agents/rl/algo/base.py @@ -16,37 +16,19 @@ from __future__ import annotations -from typing import Dict, Any, Callable +from typing import Dict import torch +from tensordict import TensorDict class BaseAlgorithm: """Base class for RL algorithms. - Algorithms must implement buffer initialization, rollout collection, and - policy update. Trainer depends only on this interface to remain - algorithm-agnostic. + Algorithms only implement policy updates over collected rollouts. 
""" device: torch.device - def initialize_buffer( - self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int - ) -> None: - """Initialize internal buffer(s) required by the algorithm.""" - raise NotImplementedError - - def collect_rollout( - self, - env, - policy, - obs: torch.Tensor, - num_steps: int, - on_step_callback: Callable | None = None, - ) -> Dict[str, Any]: - """Collect trajectories and return logging info (e.g., reward components).""" - raise NotImplementedError - - def update(self) -> Dict[str, float]: + def update(self, rollout: TensorDict) -> Dict[str, float]: """Update policy using collected data and return training losses.""" raise NotImplementedError diff --git a/embodichain/agents/rl/algo/common.py b/embodichain/agents/rl/algo/common.py new file mode 100644 index 00000000..9dcf2f78 --- /dev/null +++ b/embodichain/agents/rl/algo/common.py @@ -0,0 +1,68 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch +from tensordict import TensorDict + +__all__ = ["compute_gae"] + + +def compute_gae( + rollout: TensorDict, gamma: float, gae_lambda: float +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute GAE over a rollout stored as `[num_envs, time + 1]`. 
+ + Args: + rollout: Rollout TensorDict where `value[:, -1]` stores the bootstrap + value for the final observation and transition-only fields reserve + their last slot as padding. + gamma: Discount factor. + gae_lambda: GAE lambda coefficient. + + Returns: + Tuple of `(advantages, returns)`, both shaped `[num_envs, time]`. + """ + rewards = rollout["reward"][:, :-1].float() + dones = rollout["done"][:, :-1].bool() + values = rollout["value"].float() + + if rewards.ndim != 2: + raise ValueError( + f"Expected reward tensor with shape [num_envs, time], got {rewards.shape}." + ) + + num_envs, time_dim = rewards.shape + if values.shape != (num_envs, time_dim + 1): + raise ValueError( + "Expected value tensor with shape [num_envs, time + 1], got " + f"{values.shape} for rewards shape {rewards.shape}." + ) + advantages = torch.zeros_like(rollout["reward"].float()) + last_advantage = torch.zeros(num_envs, device=rewards.device, dtype=rewards.dtype) + + for t in reversed(range(time_dim)): + not_done = (~dones[:, t]).float() + delta = rewards[:, t] + gamma * values[:, t + 1] * not_done - values[:, t] + last_advantage = delta + gamma * gae_lambda * not_done * last_advantage + advantages[:, t] = last_advantage + + returns = torch.zeros_like(advantages) + returns[:, :-1] = advantages[:, :-1] + values[:, :-1] + rollout["advantage"] = advantages + rollout["return"] = returns + return advantages[:, :-1], returns[:, :-1] diff --git a/embodichain/agents/rl/algo/grpo.py b/embodichain/agents/rl/algo/grpo.py index 4654ed54..12f7c32f 100644 --- a/embodichain/agents/rl/algo/grpo.py +++ b/embodichain/agents/rl/algo/grpo.py @@ -17,13 +17,13 @@ from __future__ import annotations from copy import deepcopy -from typing import Any, Callable, Dict +from typing import Dict import torch from tensordict import TensorDict -from embodichain.agents.rl.buffer import RolloutBuffer -from embodichain.agents.rl.utils import AlgorithmCfg, flatten_dict_observation +from embodichain.agents.rl.buffer import 
iterate_minibatches, transition_view +from embodichain.agents.rl.utils import AlgorithmCfg from embodichain.utils import configclass from .base import BaseAlgorithm @@ -38,15 +38,12 @@ class GRPOCfg(AlgorithmCfg): kl_coef: float = 0.02 group_size: int = 4 eps: float = 1e-8 - # Collect fresh groups every rollout instead of continuing from prior states. reset_every_rollout: bool = True - # If True, do not optimize steps after the first done in each environment - # during a rollout. This better matches "one completion per prompt". truncate_at_first_done: bool = True class GRPO(BaseAlgorithm): - """Group Relative Policy Optimization on top of RolloutBuffer.""" + """Group Relative Policy Optimization on top of TensorDict rollouts.""" def __init__(self, cfg: GRPOCfg, policy): if cfg.group_size < 2: @@ -57,9 +54,6 @@ def __init__(self, cfg: GRPOCfg, policy): self.policy = policy self.device = torch.device(cfg.device) self.optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.learning_rate) - self.buffer: RolloutBuffer | None = None - # Only create ref_policy when kl_coef > 0 (e.g. VLA fine-tuning). - # For from-scratch training (CartPole etc.), kl_coef=0 avoids the "tight band" problem. if self.cfg.kl_coef > 0.0: self.ref_policy = deepcopy(policy).to(self.device).eval() for param in self.ref_policy.parameters(): @@ -67,142 +61,78 @@ def __init__(self, cfg: GRPOCfg, policy): else: self.ref_policy = None - def initialize_buffer( - self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int - ) -> None: - if num_envs % self.cfg.group_size != 0: - raise ValueError( - f"GRPO requires num_envs divisible by group_size, got " - f"num_envs={num_envs}, group_size={self.cfg.group_size}." 
- ) - self.buffer = RolloutBuffer( - num_steps, num_envs, obs_dim, action_dim, self.device - ) - def _compute_step_returns_and_mask( self, rewards: torch.Tensor, dones: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute step-wise discounted returns R_t = r_t + gamma * R_{t+1} and mask. - - Solves causal + discount bias: each step's return only depends on future rewards. - Returns: - step_returns: shape [T, N], discounted return from step t onward. - seq_mask: shape [T, N], 1 for valid steps, 0 after first done (if truncate). - """ - t_steps, n_envs = rewards.shape + """Compute discounted returns and valid-step mask over `[N, T]` rollout.""" + n_envs, t_steps = rewards.shape seq_mask = torch.ones( - (t_steps, n_envs), dtype=torch.float32, device=self.device + (n_envs, t_steps), dtype=torch.float32, device=self.device ) step_returns = torch.zeros( - (t_steps, n_envs), dtype=torch.float32, device=self.device + (n_envs, t_steps), dtype=torch.float32, device=self.device ) alive = torch.ones(n_envs, dtype=torch.float32, device=self.device) for t in range(t_steps): - seq_mask[t] = alive + seq_mask[:, t] = alive if self.cfg.truncate_at_first_done: - alive = alive * (~dones[t]).float() + alive = alive * (~dones[:, t]).float() running_return = torch.zeros(n_envs, dtype=torch.float32, device=self.device) for t in reversed(range(t_steps)): running_return = ( - rewards[t] + self.cfg.gamma * running_return * (~dones[t]).float() + rewards[:, t] + self.cfg.gamma * running_return * (~dones[:, t]).float() ) - step_returns[t] = running_return + step_returns[:, t] = running_return return step_returns, seq_mask def _compute_step_group_advantages( self, step_returns: torch.Tensor, seq_mask: torch.Tensor ) -> torch.Tensor: - """Per-step group normalization with masked mean/std for variable-length sequences. - - When group members have different survival lengths, only compare against - peers still alive at that step (avoids dead envs' zeros dragging down the mean). 
- """ - t_steps, n_envs = step_returns.shape + """Normalize per-step returns within each environment group.""" + n_envs, t_steps = step_returns.shape group_size = self.cfg.group_size - returns_grouped = step_returns.view(t_steps, n_envs // group_size, group_size) - mask_grouped = seq_mask.view(t_steps, n_envs // group_size, group_size) + returns_grouped = step_returns.view(n_envs // group_size, group_size, t_steps) + mask_grouped = seq_mask.view(n_envs // group_size, group_size, t_steps) - valid_count = mask_grouped.sum(dim=2, keepdim=True) + valid_count = mask_grouped.sum(dim=1, keepdim=True) valid_count_safe = torch.clamp(valid_count, min=1.0) group_mean = (returns_grouped * mask_grouped).sum( - dim=2, keepdim=True + dim=1, keepdim=True ) / valid_count_safe diff_sq = ((returns_grouped - group_mean) ** 2) * mask_grouped - group_var = diff_sq.sum(dim=2, keepdim=True) / valid_count_safe + group_var = diff_sq.sum(dim=1, keepdim=True) / valid_count_safe group_std = torch.sqrt(group_var) - adv = (returns_grouped - group_mean) / (group_std + self.cfg.eps) - adv = adv.view(t_steps, n_envs) * seq_mask - return adv - - def collect_rollout( - self, - env, - policy, - obs: torch.Tensor, - num_steps: int, - on_step_callback: Callable | None = None, - ) -> Dict[str, Any]: - if self.buffer is None: - raise RuntimeError( - "Buffer not initialized. Call initialize_buffer() first." 
- ) - - policy.train() - self.buffer.step = 0 - current_obs = obs + advantages = (returns_grouped - group_mean) / (group_std + self.cfg.eps) + return advantages.view(n_envs, t_steps) * seq_mask - if self.cfg.reset_every_rollout: - current_obs, _ = env.reset() - if isinstance(current_obs, TensorDict): - current_obs = flatten_dict_observation(current_obs) - - for _ in range(num_steps): - actions, log_prob, _ = policy.get_action(current_obs, deterministic=False) - am = getattr(env, "action_manager", None) - action_type = ( - am.action_type if am else getattr(env, "action_type", "delta_qpos") - ) - action_dict = {action_type: actions} - next_obs, reward, terminated, truncated, env_info = env.step(action_dict) - done = (terminated | truncated).bool() - reward = reward.float() - - if isinstance(next_obs, TensorDict): - next_obs = flatten_dict_observation(next_obs) - - # GRPO does not use value function targets; store zeros in value slot. - value_placeholder = torch.zeros_like(reward) - self.buffer.add( - current_obs, actions, reward, done, value_placeholder, log_prob + def update(self, rollout: TensorDict) -> Dict[str, float]: + rollout = rollout.clone() + num_envs = rollout.batch_size[0] + if num_envs % self.cfg.group_size != 0: + raise ValueError( + f"GRPO requires num_envs divisible by group_size, got " + f"num_envs={num_envs}, group_size={self.cfg.group_size}." 
) - if on_step_callback is not None: - on_step_callback(current_obs, actions, reward, done, env_info, next_obs) - current_obs = next_obs - - step_returns, seq_mask = self._compute_step_returns_and_mask( - self.buffer.rewards, self.buffer.dones + rewards = rollout["reward"][:, :-1].float() + dones = rollout["done"][:, :-1].bool() + step_returns, seq_mask = self._compute_step_returns_and_mask(rewards, dones) + rollout["advantage"] = torch.zeros_like(rollout["reward"], dtype=torch.float32) + rollout["advantage"][:, :-1] = self._compute_step_group_advantages( + step_returns, seq_mask ) - advantages = self._compute_step_group_advantages(step_returns, seq_mask) + rollout["seq_mask"] = torch.zeros_like(rollout["reward"], dtype=torch.float32) + rollout["seq_mask"][:, :-1] = seq_mask + rollout["seq_return"] = torch.zeros_like(rollout["reward"], dtype=torch.float32) + rollout["seq_return"][:, :-1] = step_returns - self.buffer.set_extras( - { - "advantages": advantages, - "seq_mask": seq_mask, - "seq_return": step_returns, - } - ) - return {} - - def update(self) -> Dict[str, float]: - if self.buffer is None: - raise RuntimeError("Buffer not initialized. 
Call collect_rollout() first.") + flat_rollout = transition_view(rollout, flatten=True) total_actor_loss = 0.0 total_entropy = 0.0 @@ -210,14 +140,16 @@ def update(self) -> Dict[str, float]: total_weight = 0.0 for _ in range(self.cfg.n_epochs): - for batch in self.buffer.iterate_minibatches(self.cfg.batch_size): - obs = batch["obs"] - actions = batch["actions"] - old_logprobs = batch["logprobs"] - advantages = batch["advantages"].detach() - seq_mask = batch["seq_mask"].float() - - logprobs, entropy, _ = self.policy.evaluate_actions(obs, actions) + for batch in iterate_minibatches( + flat_rollout, self.cfg.batch_size, self.device + ): + old_logprobs = batch["sample_log_prob"].clone() + advantages = batch["advantage"].detach() + seq_mask_batch = batch["seq_mask"].float() + + eval_batch = self.policy.evaluate_actions(batch) + logprobs = eval_batch["sample_log_prob"] + entropy = eval_batch["entropy"] ratio = (logprobs - old_logprobs).exp() surr1 = ratio * advantages surr2 = ( @@ -226,20 +158,19 @@ def update(self) -> Dict[str, float]: ) * advantages ) - actor_num = -(torch.min(surr1, surr2) * seq_mask).sum() - denom = torch.clamp(seq_mask.sum(), min=1.0) + actor_num = -(torch.min(surr1, surr2) * seq_mask_batch).sum() + denom = torch.clamp(seq_mask_batch.sum(), min=1.0) actor_loss = actor_num / denom - entropy_loss = -(entropy * seq_mask).sum() / denom + entropy_loss = -(entropy * seq_mask_batch).sum() / denom if self.ref_policy is not None: with torch.no_grad(): - ref_logprobs, _, _ = self.ref_policy.evaluate_actions( - obs, actions - ) + ref_batch = self.ref_policy.evaluate_actions(batch) + ref_logprobs = ref_batch["sample_log_prob"] log_ref_over_pi = ref_logprobs - logprobs kl_per = torch.exp(log_ref_over_pi) - log_ref_over_pi - 1.0 - kl = (kl_per * seq_mask).sum() / denom + kl = (kl_per * seq_mask_batch).sum() / denom else: kl = torch.tensor(0.0, device=self.device) @@ -258,7 +189,7 @@ def update(self) -> Dict[str, float]: weight = float(denom.item()) 
total_actor_loss += actor_loss.item() * weight - masked_entropy = (entropy * seq_mask).sum() / denom + masked_entropy = (entropy * seq_mask_batch).sum() / denom total_entropy += masked_entropy.item() * weight total_kl += kl.item() * weight total_weight += weight diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index b1256ce0..e33ee5b3 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -14,14 +14,15 @@ # limitations under the License. # ---------------------------------------------------------------------------- -import torch -from typing import Dict, Any, Tuple, Callable +from typing import Dict +import torch from tensordict import TensorDict -from embodichain.agents.rl.utils import AlgorithmCfg, flatten_dict_observation -from embodichain.agents.rl.buffer import RolloutBuffer +from embodichain.agents.rl.buffer import iterate_minibatches, transition_view +from embodichain.agents.rl.utils import AlgorithmCfg from embodichain.utils import configclass +from .common import compute_gae from .base import BaseAlgorithm @@ -36,112 +37,24 @@ class PPOCfg(AlgorithmCfg): class PPO(BaseAlgorithm): - """PPO algorithm operating via Policy and RolloutBuffer (algo-agnostic design).""" + """PPO algorithm consuming TensorDict rollouts.""" def __init__(self, cfg: PPOCfg, policy): self.cfg = cfg self.policy = policy self.device = torch.device(cfg.device) self.optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.learning_rate) - self.buffer: RolloutBuffer | None = None # no per-rollout aggregation for dense logging - def _compute_gae( - self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Internal method to compute GAE. 
Only called by collect_rollout.""" - T, N = rewards.shape - advantages = torch.zeros_like(rewards, device=self.device) - last_adv = torch.zeros(N, device=self.device) - for t in reversed(range(T)): - next_value = values[t + 1] if t < T - 1 else torch.zeros_like(values[0]) - not_done = (~dones[t]).float() - delta = rewards[t] + self.cfg.gamma * next_value * not_done - values[t] - last_adv = ( - delta + self.cfg.gamma * self.cfg.gae_lambda * not_done * last_adv - ) - advantages[t] = last_adv - returns = advantages + values - return advantages, returns - - def initialize_buffer( - self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int - ): - """Initialize the rollout buffer. Called by trainer before first rollout.""" - self.buffer = RolloutBuffer( - num_steps, num_envs, obs_dim, action_dim, self.device - ) - - def collect_rollout( - self, - env, - policy, - obs: torch.Tensor, - num_steps: int, - on_step_callback: Callable | None = None, - ) -> Dict[str, Any]: - """Collect a rollout. Algorithm controls the data collection process.""" - if self.buffer is None: - raise RuntimeError( - "Buffer not initialized. Call initialize_buffer() first." 
- ) - - policy.train() - self.buffer.step = 0 - current_obs = obs - - for t in range(num_steps): - # Get action from policy - actions, log_prob, value = policy.get_action( - current_obs, deterministic=False - ) - - # Wrap action as dict for env processing - am = getattr(env, "action_manager", None) - action_type = ( - am.action_type if am else getattr(env, "action_type", "delta_qpos") - ) - action_dict = {action_type: actions} - - # Step environment - result = env.step(action_dict) - next_obs, reward, terminated, truncated, env_info = result - done = terminated | truncated - # Light dtype normalization - reward = reward.float() - done = done.bool() - - # Flatten TensorDict observation from ObservationManager if needed - if isinstance(next_obs, TensorDict): - next_obs = flatten_dict_observation(next_obs) - - # Add to buffer - self.buffer.add(current_obs, actions, reward, done, value, log_prob) - - # Dense logging is handled in Trainer.on_step via info; no aggregation here - # Call callback for statistics and logging - if on_step_callback is not None: - on_step_callback(current_obs, actions, reward, done, env_info, next_obs) - - current_obs = next_obs - - # Compute advantages/returns and attach to buffer extras - adv, ret = self._compute_gae( - self.buffer.rewards, self.buffer.values, self.buffer.dones - ) - self.buffer.set_extras({"advantages": adv, "returns": ret}) - - # No aggregated logging results; Trainer performs dense per-step logging - return {} - - def update(self) -> dict: - """Update the policy using the collected rollout buffer.""" - if self.buffer is None: - raise RuntimeError("Buffer not initialized. 
Call collect_rollout() first.") - - # Normalize advantages (optional, common default) - adv = self.buffer._extras.get("advantages") - adv = (adv - adv.mean()) / (adv.std() + 1e-8) + def update(self, rollout: TensorDict) -> Dict[str, float]: + """Update the policy using a collected rollout.""" + rollout = rollout.clone() + compute_gae(rollout, gamma=self.cfg.gamma, gae_lambda=self.cfg.gae_lambda) + flat_rollout = transition_view(rollout, flatten=True) + + advantages = flat_rollout["advantage"] + adv_mean = advantages.mean() + adv_std = advantages.std().clamp_min(1e-8) total_actor_loss = 0.0 total_value_loss = 0.0 @@ -149,23 +62,24 @@ def update(self) -> dict: total_steps = 0 for _ in range(self.cfg.n_epochs): - for batch in self.buffer.iterate_minibatches(self.cfg.batch_size): - obs = batch["obs"] - actions = batch["actions"] - old_logprobs = batch["logprobs"] - returns = batch["returns"] - advantages = ( - (batch["advantages"] - adv.mean()) / (adv.std() + 1e-8) - ).detach() - - logprobs, entropy, values = self.policy.evaluate_actions(obs, actions) + for batch in iterate_minibatches( + flat_rollout, self.cfg.batch_size, self.device + ): + old_logprobs = batch["sample_log_prob"].clone() + returns = batch["return"].clone() + batch_advantages = ((batch["advantage"] - adv_mean) / adv_std).detach() + + eval_batch = self.policy.evaluate_actions(batch) + logprobs = eval_batch["sample_log_prob"] + entropy = eval_batch["entropy"] + values = eval_batch["value"] ratio = (logprobs - old_logprobs).exp() - surr1 = ratio * advantages + surr1 = ratio * batch_advantages surr2 = ( torch.clamp( ratio, 1.0 - self.cfg.clip_coef, 1.0 + self.cfg.clip_coef ) - * advantages + * batch_advantages ) actor_loss = -torch.min(surr1, surr2).mean() value_loss = torch.nn.functional.mse_loss(values, returns) @@ -184,7 +98,7 @@ def update(self) -> dict: ) self.optimizer.step() - bs = obs.shape[0] + bs = batch.batch_size[0] total_actor_loss += actor_loss.item() * bs total_value_loss += 
value_loss.item() * bs total_entropy += (-entropy_loss.item()) * bs diff --git a/embodichain/agents/rl/buffer/__init__.py b/embodichain/agents/rl/buffer/__init__.py index 5080d251..a83c0776 100644 --- a/embodichain/agents/rl/buffer/__init__.py +++ b/embodichain/agents/rl/buffer/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. # ---------------------------------------------------------------------------- -from .rollout_buffer import RolloutBuffer +from .standard_buffer import RolloutBuffer +from .utils import iterate_minibatches, transition_view -__all__ = ["RolloutBuffer"] +__all__ = ["RolloutBuffer", "iterate_minibatches", "transition_view"] diff --git a/embodichain/agents/rl/buffer/rollout_buffer.py b/embodichain/agents/rl/buffer/rollout_buffer.py deleted file mode 100644 index cbd66f4e..00000000 --- a/embodichain/agents/rl/buffer/rollout_buffer.py +++ /dev/null @@ -1,106 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -from typing import Dict, Iterator - -import torch - - -class RolloutBuffer: - """On-device rollout buffer for on-policy algorithms. - - Stores (obs, actions, rewards, dones, values, logprobs) over time. - After finalize(), exposes advantages/returns and minibatch iteration. 
- """ - - def __init__( - self, - num_steps: int, - num_envs: int, - obs_dim: int, - action_dim: int, - device: torch.device, - ): - self.num_steps = num_steps - self.num_envs = num_envs - self.obs_dim = obs_dim - self.action_dim = action_dim - self.device = device - - T, N = num_steps, num_envs - self.obs = torch.zeros(T, N, obs_dim, dtype=torch.float32, device=device) - self.actions = torch.zeros(T, N, action_dim, dtype=torch.float32, device=device) - self.rewards = torch.zeros(T, N, dtype=torch.float32, device=device) - self.dones = torch.zeros(T, N, dtype=torch.bool, device=device) - self.values = torch.zeros(T, N, dtype=torch.float32, device=device) - self.logprobs = torch.zeros(T, N, dtype=torch.float32, device=device) - - self.step = 0 - # Container for algorithm-specific extra fields (e.g., advantages, returns) - self._extras: dict[str, torch.Tensor] = {} - - def add( - self, - obs: torch.Tensor, - action: torch.Tensor, - reward: torch.Tensor, - done: torch.Tensor, - value: torch.Tensor, - logprob: torch.Tensor, - ) -> None: - t = self.step - self.obs[t].copy_(obs) - self.actions[t].copy_(action) - self.rewards[t].copy_(reward) - self.dones[t].copy_(done) - self.values[t].copy_(value) - self.logprobs[t].copy_(logprob) - self.step += 1 - - def set_extras(self, extras: dict[str, torch.Tensor]) -> None: - """Attach algorithm-specific tensors (shape [T, N, ...]) for batching. 
- - Examples: - {"advantages": adv, "returns": ret} - """ - self._extras = extras or {} - - def iterate_minibatches(self, batch_size: int) -> Iterator[Dict[str, torch.Tensor]]: - T, N = self.num_steps, self.num_envs - total = T * N - indices = torch.randperm(total, device=self.device) - for start in range(0, total, batch_size): - idx = indices[start : start + batch_size] - t_idx = idx // N - n_idx = idx % N - batch = { - "obs": self.obs[t_idx, n_idx], - "actions": self.actions[t_idx, n_idx], - "rewards": self.rewards[t_idx, n_idx], - "dones": self.dones[t_idx, n_idx], - "values": self.values[t_idx, n_idx], - "logprobs": self.logprobs[t_idx, n_idx], - } - # Slice extras if present and shape aligned to [T, N, ...] - for name, tensor in self._extras.items(): - try: - batch[name] = tensor[t_idx, n_idx] - except Exception: - # Skip misaligned extras silently - continue - yield batch diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py new file mode 100644 index 00000000..2df69f86 --- /dev/null +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -0,0 +1,188 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch +from tensordict import TensorDict + +from .utils import transition_view + +__all__ = ["RolloutBuffer"] + + +class RolloutBuffer: + """Single-rollout buffer backed by a preallocated TensorDict. + + The shared rollout uses a uniform `[num_envs, time + 1]` layout. For + transition-only fields such as `action`, `reward`, and `done`, the final + time index is reused as padding so the collector, environment, and + algorithms can share a single TensorDict batch shape. + """ + + def __init__( + self, + num_envs: int, + rollout_len: int, + obs_dim: int, + action_dim: int, + device: torch.device, + ) -> None: + self.num_envs = num_envs + self.rollout_len = rollout_len + self.obs_dim = obs_dim + self.action_dim = action_dim + self.device = device + self._rollout = self._allocate_rollout() + self._is_full = False + + def start_rollout(self) -> TensorDict: + """Return the shared rollout TensorDict for collector write-in.""" + if self._is_full: + raise RuntimeError("RolloutBuffer already contains a rollout.") + self._clear_dynamic_fields() + return self._rollout + + def add(self, rollout: TensorDict) -> None: + """Mark the shared rollout as ready for consumption.""" + if rollout is not self._rollout: + raise ValueError( + "RolloutBuffer only accepts its shared rollout TensorDict." + ) + if tuple(rollout.batch_size) != (self.num_envs, self.rollout_len + 1): + raise ValueError( + "Rollout batch size does not match buffer allocation: " + f"expected ({self.num_envs}, {self.rollout_len + 1}), got {tuple(rollout.batch_size)}." + ) + self._validate_rollout_layout(rollout) + self._is_full = True + + def get(self, flatten: bool = True) -> TensorDict: + """Return the stored rollout and clear the buffer. + + When `flatten` is True, the rollout is first converted to a transition + view that drops the padded final slot from transition-only fields. 
+ """ + if not self._is_full: + raise RuntimeError("RolloutBuffer is empty.") + + rollout = self._rollout + self._is_full = False + + if not flatten: + return rollout + + return transition_view(rollout, flatten=True) + + def is_full(self) -> bool: + """Return whether a rollout is waiting to be consumed.""" + return self._is_full + + def _allocate_rollout(self) -> TensorDict: + """Preallocate rollout storage with uniform `[num_envs, time + 1]` shape.""" + return TensorDict( + { + "obs": torch.empty( + self.num_envs, + self.rollout_len + 1, + self.obs_dim, + dtype=torch.float32, + device=self.device, + ), + "action": torch.empty( + self.num_envs, + self.rollout_len + 1, + self.action_dim, + dtype=torch.float32, + device=self.device, + ), + "sample_log_prob": torch.empty( + self.num_envs, + self.rollout_len + 1, + dtype=torch.float32, + device=self.device, + ), + "value": torch.empty( + self.num_envs, + self.rollout_len + 1, + dtype=torch.float32, + device=self.device, + ), + "reward": torch.empty( + self.num_envs, + self.rollout_len + 1, + dtype=torch.float32, + device=self.device, + ), + "done": torch.empty( + self.num_envs, + self.rollout_len + 1, + dtype=torch.bool, + device=self.device, + ), + "terminated": torch.empty( + self.num_envs, + self.rollout_len + 1, + dtype=torch.bool, + device=self.device, + ), + "truncated": torch.empty( + self.num_envs, + self.rollout_len + 1, + dtype=torch.bool, + device=self.device, + ), + }, + batch_size=[self.num_envs, self.rollout_len + 1], + device=self.device, + ) + + def _clear_dynamic_fields(self) -> None: + """Drop algorithm-added fields before reusing the shared rollout.""" + for key in ("advantage", "return", "seq_mask", "seq_return", "entropy"): + if key in self._rollout.keys(): + del self._rollout[key] + self._reset_padding_slot() + + def _reset_padding_slot(self) -> None: + """Reset the last transition-only slot reused as padding.""" + last_idx = self.rollout_len + self._rollout["action"][:, last_idx].zero_() + 
self._rollout["sample_log_prob"][:, last_idx].zero_() + self._rollout["reward"][:, last_idx].zero_() + self._rollout["done"][:, last_idx].fill_(False) + self._rollout["terminated"][:, last_idx].fill_(False) + self._rollout["truncated"][:, last_idx].fill_(False) + + def _validate_rollout_layout(self, rollout: TensorDict) -> None: + """Validate the expected tensor shapes for the shared rollout.""" + expected_shapes = { + "obs": (self.num_envs, self.rollout_len + 1, self.obs_dim), + "action": (self.num_envs, self.rollout_len + 1, self.action_dim), + "sample_log_prob": (self.num_envs, self.rollout_len + 1), + "value": (self.num_envs, self.rollout_len + 1), + "reward": (self.num_envs, self.rollout_len + 1), + "done": (self.num_envs, self.rollout_len + 1), + "terminated": (self.num_envs, self.rollout_len + 1), + "truncated": (self.num_envs, self.rollout_len + 1), + } + for key, expected_shape in expected_shapes.items(): + actual_shape = tuple(rollout[key].shape) + if actual_shape != expected_shape: + raise ValueError( + f"Rollout field '{key}' shape mismatch: expected {expected_shape}, " + f"got {actual_shape}." + ) diff --git a/embodichain/agents/rl/buffer/utils.py b/embodichain/agents/rl/buffer/utils.py new file mode 100644 index 00000000..7c0d265b --- /dev/null +++ b/embodichain/agents/rl/buffer/utils.py @@ -0,0 +1,77 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from collections.abc import Iterator + +import torch +from tensordict import TensorDict + +__all__ = ["iterate_minibatches", "transition_view"] + + +def transition_view(rollout: TensorDict, flatten: bool = False) -> TensorDict: + """Build a transition-aligned TensorDict from a rollout. + + The shared rollout uses a uniform `[num_envs, time + 1]` layout. For + transition-only fields such as `action`, `reward`, and `done`, the final + slot is reserved as padding so that all rollout fields share the same batch + shape. This helper drops that padded slot and exposes the valid transition + slices as a TensorDict with batch shape `[num_envs, time]`. + + Args: + rollout: Rollout TensorDict with root batch shape `[num_envs, time + 1]`. + flatten: If True, return a flattened `[num_envs * time]` view. + + Returns: + TensorDict containing transition-aligned fields. 
+ """ + action = rollout["action"][:, :-1] + num_envs, time_dim = action.shape[:2] + td = TensorDict( + { + "obs": rollout["obs"][:, :-1], + "action": action, + "sample_log_prob": rollout["sample_log_prob"][:, :-1], + "value": rollout["value"][:, :-1], + "next_value": rollout["value"][:, 1:], + "reward": rollout["reward"][:, :-1], + "done": rollout["done"][:, :-1], + "terminated": rollout["terminated"][:, :-1], + "truncated": rollout["truncated"][:, :-1], + }, + batch_size=[num_envs, time_dim], + device=rollout.device, + ) + + for key in ("advantage", "return", "seq_mask", "seq_return", "entropy"): + if key in rollout.keys(): + td[key] = rollout[key][:, :-1] + + if flatten: + return td.reshape(num_envs * time_dim) + return td + + +def iterate_minibatches( + rollout: TensorDict, batch_size: int, device: torch.device +) -> Iterator[TensorDict]: + """Yield shuffled minibatches from a flattened rollout.""" + total = rollout.batch_size[0] + indices = torch.randperm(total, device=device) + for start in range(0, total, batch_size): + yield rollout[indices[start : start + batch_size]] diff --git a/embodichain/agents/rl/collector/__init__.py b/embodichain/agents/rl/collector/__init__.py new file mode 100644 index 00000000..683c0ad5 --- /dev/null +++ b/embodichain/agents/rl/collector/__init__.py @@ -0,0 +1,20 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .base import BaseCollector +from .sync_collector import SyncCollector + +__all__ = ["BaseCollector", "SyncCollector"] diff --git a/embodichain/agents/rl/collector/base.py b/embodichain/agents/rl/collector/base.py new file mode 100644 index 00000000..de3f2967 --- /dev/null +++ b/embodichain/agents/rl/collector/base.py @@ -0,0 +1,38 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Callable + +from tensordict import TensorDict + +__all__ = ["BaseCollector"] + + +class BaseCollector(ABC): + """Base class for rollout collectors.""" + + @abstractmethod + def collect( + self, + num_steps: int, + rollout: TensorDict | None = None, + on_step_callback: Callable[[TensorDict, dict], None] | None = None, + ) -> TensorDict: + """Collect a rollout and return it as a TensorDict.""" + raise NotImplementedError diff --git a/embodichain/agents/rl/collector/sync_collector.py b/embodichain/agents/rl/collector/sync_collector.py new file mode 100644 index 00000000..16c5b584 --- /dev/null +++ b/embodichain/agents/rl/collector/sync_collector.py @@ -0,0 +1,174 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Callable + +import torch +from tensordict import TensorDict + +from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation +from .base import BaseCollector + +__all__ = ["SyncCollector"] + + +class SyncCollector(BaseCollector): + """Synchronously collect rollouts from a vectorized environment.""" + + def __init__( + self, + env, + policy, + device: torch.device, + reset_every_rollout: bool = False, + ) -> None: + self.env = env + self.policy = policy + self.device = device + self.reset_every_rollout = reset_every_rollout + self._supports_shared_rollout = hasattr(self.env, "set_rollout_buffer") + self.obs_td = self._reset_env() + + @torch.no_grad() + def collect( + self, + num_steps: int, + rollout: TensorDict | None = None, + on_step_callback: Callable[[TensorDict, dict], None] | None = None, + ) -> TensorDict: + self.policy.train() + if self.reset_every_rollout: + self.obs_td = self._reset_env() + + if rollout is None: + raise ValueError( + "SyncCollector.collect() requires a preallocated rollout TensorDict." + ) + if tuple(rollout.batch_size) != (self.env.num_envs, num_steps + 1): + raise ValueError( + "Preallocated rollout batch size mismatch: " + f"expected ({self.env.num_envs}, {num_steps + 1}), got {tuple(rollout.batch_size)}." 
+ ) + self._validate_rollout(rollout, num_steps) + if self._supports_shared_rollout: + self.env.set_rollout_buffer(rollout) + + initial_obs = flatten_dict_observation(self.obs_td) + rollout["obs"][:, 0] = initial_obs + for step_idx in range(num_steps): + step_td = TensorDict( + {"obs": rollout["obs"][:, step_idx]}, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + step_td = self.policy.get_action(step_td) + + next_obs, reward, terminated, truncated, env_info = self.env.step( + self._to_action_dict(step_td["action"]) + ) + next_obs_td = dict_to_tensordict(next_obs, self.device) + self._write_step( + rollout=rollout, + step_idx=step_idx, + step_td=step_td, + ) + if not self._supports_shared_rollout: + self._write_env_step( + rollout=rollout, + step_idx=step_idx, + reward=reward, + terminated=terminated, + truncated=truncated, + ) + rollout["obs"][:, step_idx + 1] = flatten_dict_observation(next_obs_td) + + if on_step_callback is not None: + on_step_callback(rollout[:, step_idx], env_info) + + self.obs_td = next_obs_td + + self._attach_final_value(rollout) + return rollout + + def _attach_final_value(self, rollout: TensorDict) -> None: + """Populate the bootstrap value for the final observed state.""" + final_obs = rollout["obs"][:, -1] + last_next_td = TensorDict( + {"obs": final_obs}, + batch_size=[rollout.batch_size[0]], + device=self.device, + ) + self.policy.get_value(last_next_td) + rollout["value"][:, -1] = last_next_td["value"] + + def _reset_env(self) -> TensorDict: + obs, _ = self.env.reset() + return dict_to_tensordict(obs, self.device) + + def _to_action_dict(self, action: torch.Tensor) -> dict[str, torch.Tensor]: + am = getattr(self.env, "action_manager", None) + action_type = ( + am.action_type if am else getattr(self.env, "action_type", "delta_qpos") + ) + return {action_type: action} + + def _write_step( + self, + rollout: TensorDict, + step_idx: int, + step_td: TensorDict, + ) -> None: + """Write policy-side fields for one transition 
into the shared rollout TensorDict.""" + rollout["action"][:, step_idx] = step_td["action"] + rollout["sample_log_prob"][:, step_idx] = step_td["sample_log_prob"] + rollout["value"][:, step_idx] = step_td["value"] + + def _write_env_step( + self, + rollout: TensorDict, + step_idx: int, + reward: torch.Tensor, + terminated: torch.Tensor, + truncated: torch.Tensor, + ) -> None: + """Populate transition-side fields when the environment does not own the rollout.""" + done = terminated | truncated + rollout["reward"][:, step_idx] = reward.to(self.device) + rollout["done"][:, step_idx] = done.to(self.device) + rollout["terminated"][:, step_idx] = terminated.to(self.device) + rollout["truncated"][:, step_idx] = truncated.to(self.device) + + def _validate_rollout(self, rollout: TensorDict, num_steps: int) -> None: + """Validate rollout layout expected by the collector.""" + expected_shapes = { + "obs": (self.env.num_envs, num_steps + 1, self.policy.obs_dim), + "action": (self.env.num_envs, num_steps + 1, self.policy.action_dim), + "sample_log_prob": (self.env.num_envs, num_steps + 1), + "value": (self.env.num_envs, num_steps + 1), + "reward": (self.env.num_envs, num_steps + 1), + "done": (self.env.num_envs, num_steps + 1), + "terminated": (self.env.num_envs, num_steps + 1), + "truncated": (self.env.num_envs, num_steps + 1), + } + for key, expected_shape in expected_shapes.items(): + actual_shape = tuple(rollout[key].shape) + if actual_shape != expected_shape: + raise ValueError( + f"Preallocated rollout field '{key}' shape mismatch: " + f"expected {expected_shape}, got {actual_shape}." 
+ ) diff --git a/embodichain/agents/rl/models/__init__.py b/embodichain/agents/rl/models/__init__.py index 4b0c0a0b..51cf7653 100644 --- a/embodichain/agents/rl/models/__init__.py +++ b/embodichain/agents/rl/models/__init__.py @@ -16,9 +16,11 @@ from __future__ import annotations +import inspect from typing import Dict, Type -import torch + from gymnasium import spaces +import torch from .actor_critic import ActorCritic from .actor_only import ActorOnly @@ -43,15 +45,30 @@ def get_policy_class(name: str) -> Type[Policy] | None: return _POLICY_REGISTRY.get(name) +def _resolve_space_dim(space_or_dim: spaces.Space | int, name: str) -> int: + """Resolve a flattened feature dimension from an integer or simple Box space.""" + if isinstance(space_or_dim, int): + return space_or_dim + if isinstance(space_or_dim, spaces.Box) and len(space_or_dim.shape) > 0: + return int(space_or_dim.shape[-1]) + raise TypeError( + f"{name} must be an int or a flat Box space for MLP-based policies, got {type(space_or_dim)!r}." + ) + + def build_policy( policy_block: dict, - obs_space: spaces.Space, - action_space: spaces.Space, + obs_space: spaces.Space | int, + action_space: spaces.Space | int, device: torch.device, actor: torch.nn.Module | None = None, critic: torch.nn.Module | None = None, ) -> Policy: - """Build policy strictly from json-like block: { name: ..., cfg: {...} }""" + """Build a policy from config using spaces for extensibility. + + Built-in MLP policies still resolve flattened `obs_dim` / `action_dim`, while + custom policies may accept richer `obs_space` / `action_space` inputs. + """ name = policy_block["name"].lower() if name not in _POLICY_REGISTRY: available = ", ".join(get_registered_policy_names()) @@ -59,18 +76,50 @@ def build_policy( f"Policy '{name}' is not registered. 
Available policies: {available}" ) policy_cls = _POLICY_REGISTRY[name] + if name == "actor_critic": if actor is None or critic is None: raise ValueError( "ActorCritic policy requires external 'actor' and 'critic' modules." ) - return policy_cls(obs_space, action_space, device, actor=actor, critic=critic) + obs_dim = _resolve_space_dim(obs_space, "obs_space") + action_dim = _resolve_space_dim(action_space, "action_space") + return policy_cls( + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + actor=actor, + critic=critic, + ) elif name == "actor_only": if actor is None: raise ValueError("ActorOnly policy requires external 'actor' module.") - return policy_cls(obs_space, action_space, device, actor=actor) - else: - return policy_cls(obs_space, action_space, device) + obs_dim = _resolve_space_dim(obs_space, "obs_space") + action_dim = _resolve_space_dim(action_space, "action_space") + return policy_cls( + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + actor=actor, + ) + + init_params = inspect.signature(policy_cls.__init__).parameters + build_kwargs: dict[str, object] = {"device": device} + if "obs_space" in init_params: + build_kwargs["obs_space"] = obs_space + elif "obs_dim" in init_params: + build_kwargs["obs_dim"] = _resolve_space_dim(obs_space, "obs_space") + + if "action_space" in init_params: + build_kwargs["action_space"] = action_space + elif "action_dim" in init_params: + build_kwargs["action_dim"] = _resolve_space_dim(action_space, "action_space") + + if "actor" in init_params and actor is not None: + build_kwargs["actor"] = actor + if "critic" in init_params and critic is not None: + build_kwargs["critic"] = critic + return policy_cls(**build_kwargs) def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP: diff --git a/embodichain/agents/rl/models/actor_critic.py b/embodichain/agents/rl/models/actor_critic.py index 35f9a961..32caf0e3 100644 --- a/embodichain/agents/rl/models/actor_critic.py +++ 
b/embodichain/agents/rl/models/actor_critic.py @@ -16,11 +16,11 @@ from __future__ import annotations -from typing import Dict, Any, Tuple - import torch import torch.nn as nn from torch.distributions.normal import Normal +from tensordict import TensorDict + from .mlp import MLP from .policy import Policy @@ -36,23 +36,21 @@ class ActorCritic(Policy): This allows seamless swapping with other policy implementations (e.g., VLAPolicy) without modifying RL algorithm code. - Implements: - - get_action(obs, deterministic=False) -> (action, log_prob, value) - - get_value(obs) - - evaluate_actions(obs, actions) -> (log_prob, entropy, value) + Implements TensorDict-native interfaces while preserving `get_action()` + compatibility for evaluation and legacy call-sites. """ def __init__( self, - obs_space, - action_space, + obs_dim: int, + action_dim: int, device: torch.device, actor: nn.Module, critic: nn.Module, ): super().__init__() - self.obs_dim = obs_space.shape[-1] - self.action_dim = action_space.shape[-1] + self.obs_dim = obs_dim + self.action_dim = action_dim self.device = device # Require external injection of actor and critic @@ -66,31 +64,38 @@ def __init__( self.log_std_min = -5.0 self.log_std_max = 2.0 - @torch.no_grad() - def get_action( - self, obs: torch.Tensor, deterministic: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def _distribution(self, obs: torch.Tensor) -> Normal: mean = self.actor(obs) log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) std = log_std.exp().expand(mean.shape[0], -1) - dist = Normal(mean, std) + return Normal(mean, std) + + def forward( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + obs = tensordict["obs"] + dist = self._distribution(obs) + mean = dist.mean action = mean if deterministic else dist.sample() - log_prob = dist.log_prob(action).sum(dim=-1) - value = self.critic(obs).squeeze(-1) - return action, log_prob, value + tensordict["action"] = action + 
tensordict["sample_log_prob"] = dist.log_prob(action).sum(dim=-1) + tensordict["value"] = self.critic(obs).squeeze(-1) + return tensordict - @torch.no_grad() - def get_value(self, obs: torch.Tensor) -> torch.Tensor: - return self.critic(obs).squeeze(-1) + def get_value(self, tensordict: TensorDict) -> TensorDict: + tensordict["value"] = self.critic(tensordict["obs"]).squeeze(-1) + return tensordict - def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - mean = self.actor(obs) - log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) - std = log_std.exp().expand(mean.shape[0], -1) - dist = Normal(mean, std) - log_prob = dist.log_prob(actions).sum(dim=-1) - entropy = dist.entropy().sum(dim=-1) - value = self.critic(obs).squeeze(-1) - return log_prob, entropy, value + def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + obs = tensordict["obs"] + action = tensordict["action"] + dist = self._distribution(obs) + return TensorDict( + { + "sample_log_prob": dist.log_prob(action).sum(dim=-1), + "entropy": dist.entropy().sum(dim=-1), + "value": self.critic(obs).squeeze(-1), + }, + batch_size=tensordict.batch_size, + device=tensordict.device, + ) diff --git a/embodichain/agents/rl/models/actor_only.py b/embodichain/agents/rl/models/actor_only.py index c54fd515..3d6d1f78 100644 --- a/embodichain/agents/rl/models/actor_only.py +++ b/embodichain/agents/rl/models/actor_only.py @@ -16,11 +16,11 @@ from __future__ import annotations -from typing import Tuple - import torch import torch.nn as nn from torch.distributions.normal import Normal +from tensordict import TensorDict + from .policy import Policy @@ -33,14 +33,14 @@ class ActorOnly(Policy): def __init__( self, - obs_space, - action_space, + obs_dim: int, + action_dim: int, device: torch.device, actor: nn.Module, ): super().__init__() - self.obs_dim = obs_space.shape[-1] - self.action_dim = action_space.shape[-1] + self.obs_dim = 
obs_dim + self.action_dim = action_dim self.device = device self.actor = actor @@ -50,31 +50,43 @@ def __init__( self.log_std_min = -5.0 self.log_std_max = 2.0 - @torch.no_grad() - def get_action( - self, obs: torch.Tensor, deterministic: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def _distribution(self, obs: torch.Tensor) -> Normal: mean = self.actor(obs) log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) std = log_std.exp().expand(mean.shape[0], -1) - dist = Normal(mean, std) + return Normal(mean, std) + + def forward( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + obs = tensordict["obs"] + dist = self._distribution(obs) + mean = dist.mean action = mean if deterministic else dist.sample() - log_prob = dist.log_prob(action).sum(dim=-1) - value = torch.zeros(obs.shape[0], device=self.device, dtype=obs.dtype) - return action, log_prob, value + tensordict["action"] = action + tensordict["sample_log_prob"] = dist.log_prob(action).sum(dim=-1) + tensordict["value"] = torch.zeros( + obs.shape[0], device=self.device, dtype=obs.dtype + ) + return tensordict - @torch.no_grad() - def get_value(self, obs: torch.Tensor) -> torch.Tensor: - return torch.zeros(obs.shape[0], device=self.device, dtype=obs.dtype) + def get_value(self, tensordict: TensorDict) -> TensorDict: + obs = tensordict["obs"] + tensordict["value"] = torch.zeros( + obs.shape[0], device=self.device, dtype=obs.dtype + ) + return tensordict - def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - mean = self.actor(obs) - log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) - std = log_std.exp().expand(mean.shape[0], -1) - dist = Normal(mean, std) - log_prob = dist.log_prob(actions).sum(dim=-1) - entropy = dist.entropy().sum(dim=-1) - value = torch.zeros(obs.shape[0], device=self.device, dtype=obs.dtype) - return log_prob, entropy, value + def 
evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + obs = tensordict["obs"] + action = tensordict["action"] + dist = self._distribution(obs) + return TensorDict( + { + "sample_log_prob": dist.log_prob(action).sum(dim=-1), + "entropy": dist.entropy().sum(dim=-1), + "value": torch.zeros(obs.shape[0], device=self.device, dtype=obs.dtype), + }, + batch_size=tensordict.batch_size, + device=tensordict.device, + ) diff --git a/embodichain/agents/rl/models/policy.py b/embodichain/agents/rl/models/policy.py index 21c13a96..dd71faa5 100644 --- a/embodichain/agents/rl/models/policy.py +++ b/embodichain/agents/rl/models/policy.py @@ -23,11 +23,11 @@ from __future__ import annotations -from typing import Tuple from abc import ABC, abstractmethod import torch.nn as nn import torch +from tensordict import TensorDict class Policy(nn.Module, ABC): @@ -45,50 +45,48 @@ class Policy(nn.Module, ABC): def __init__(self) -> None: super().__init__() - @abstractmethod def get_action( - self, obs: torch.Tensor, deterministic: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Sample an action from the policy. + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + """Sample actions into the provided TensorDict without gradients. Args: - obs: Observation tensor of shape (batch_size, obs_dim) + tensordict: Input TensorDict containing `obs`. deterministic: If True, return the mean action; otherwise sample Returns: - Tuple of (action, log_prob, value): - - action: Sampled action tensor of shape (batch_size, action_dim) - - log_prob: Log probability of the action, shape (batch_size,) - - value: Value estimate, shape (batch_size,) + TensorDict with `action`, `sample_log_prob`, and `value` populated. 
""" + with torch.no_grad(): + return self.forward(tensordict, deterministic=deterministic) + + @abstractmethod + def forward( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + """Write sampled actions and value estimates into the TensorDict.""" raise NotImplementedError @abstractmethod - def get_value(self, obs: torch.Tensor) -> torch.Tensor: - """Get value estimate for given observations. + def get_value(self, tensordict: TensorDict) -> TensorDict: + """Write value estimate for the given observations into the TensorDict. Args: - obs: Observation tensor of shape (batch_size, obs_dim) + tensordict: Input TensorDict containing `obs`. Returns: - Value estimate tensor of shape (batch_size,) + TensorDict with `value` populated. """ raise NotImplementedError @abstractmethod - def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Evaluate actions and compute log probabilities, entropy, and values. + def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + """Evaluate actions and return current policy outputs. Args: - obs: Observation tensor of shape (batch_size, obs_dim) - actions: Action tensor of shape (batch_size, action_dim) + tensordict: TensorDict containing `obs` and `action`. Returns: - Tuple of (log_prob, entropy, value): - - log_prob: Log probability of actions, shape (batch_size,) - - entropy: Entropy of the action distribution, shape (batch_size,) - - value: Value estimate, shape (batch_size,) + A new TensorDict containing `sample_log_prob`, `entropy`, and `value`. 
""" raise NotImplementedError diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py index ec7c968e..bf08c746 100644 --- a/embodichain/agents/rl/train.py +++ b/embodichain/agents/rl/train.py @@ -29,6 +29,7 @@ from embodichain.agents.rl.models import build_policy, get_registered_policy_names from embodichain.agents.rl.models import build_mlp_from_cfg from embodichain.agents.rl.algo import build_algo, get_registered_algo_names +from embodichain.agents.rl.utils import dict_to_tensordict, flatten_dict_observation from embodichain.agents.rl.utils.trainer import Trainer from embodichain.utils import logger from embodichain.lab.gym.envs.tasks.rl import build_env @@ -64,7 +65,9 @@ def train_from_config(config_path: str): seed = int(trainer_cfg.get("seed", 1)) device_str = trainer_cfg.get("device", "cpu") iterations = int(trainer_cfg.get("iterations", 250)) - rollout_steps = int(trainer_cfg.get("rollout_steps", 2048)) + buffer_size = int( + trainer_cfg.get("buffer_size", trainer_cfg.get("rollout_steps", 2048)) + ) enable_eval = bool(trainer_cfg.get("enable_eval", False)) eval_freq = int(trainer_cfg.get("eval_freq", 10000)) save_freq = int(trainer_cfg.get("save_freq", 50000)) @@ -163,6 +166,10 @@ def train_from_config(config_path: str): ) env = build_env(gym_config_data["id"], base_env_cfg=gym_env_cfg) + sample_obs, _ = env.reset() + sample_obs_td = dict_to_tensordict(sample_obs, device) + obs_dim = flatten_dict_observation(sample_obs_td).shape[-1] + flat_obs_space = env.flattened_observation_space # Create evaluation environment only if enabled eval_env = None @@ -178,13 +185,15 @@ def train_from_config(config_path: str): # Build Policy via registry policy_name = policy_block["name"] + env_action_dim = env.action_space.shape[-1] + action_dim = policy_block.get("action_dim", env_action_dim) + action_dim = int(action_dim) + if action_dim != env_action_dim: + raise ValueError( + f"Configured policy.action_dim={action_dim} does not match env action dim 
{env_action_dim}." + ) # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic/actor_only) if policy_name.lower() == "actor_critic": - # Get observation dimension from flattened observation space - # flattened_observation_space returns Box space for RL training - obs_dim = env.flattened_observation_space.shape[-1] - action_dim = env.action_space.shape[-1] - actor_cfg = policy_block.get("actor") critic_cfg = policy_block.get("critic") if actor_cfg is None or critic_cfg is None: @@ -197,16 +206,13 @@ def train_from_config(config_path: str): policy = build_policy( policy_block, - env.flattened_observation_space, + flat_obs_space, env.action_space, device, actor=actor, critic=critic, ) elif policy_name.lower() == "actor_only": - obs_dim = env.flattened_observation_space.shape[-1] - action_dim = env.action_space.shape[-1] - actor_cfg = policy_block.get("actor") if actor_cfg is None: raise ValueError( @@ -217,14 +223,14 @@ def train_from_config(config_path: str): policy = build_policy( policy_block, - env.flattened_observation_space, + flat_obs_space, env.action_space, device, actor=actor, ) else: policy = build_policy( - policy_block, env.flattened_observation_space, env.action_space, device + policy_block, env.observation_space, env.action_space, device ) # Build Algorithm via factory @@ -276,7 +282,7 @@ def train_from_config(config_path: str): policy=policy, env=env, algorithm=algo, - num_steps=rollout_steps, + buffer_size=buffer_size, batch_size=algo_cfg["batch_size"], writer=writer, eval_freq=eval_freq if enable_eval else 0, # Disable eval if not enabled @@ -299,7 +305,7 @@ def train_from_config(config_path: str): f"Algorithm: {algo_name} (available: {get_registered_algo_names()})" ) - total_steps = int(iterations * rollout_steps * env.num_envs) + total_steps = int(iterations * buffer_size * env.num_envs) logger.log_info(f"Total steps: {total_steps} (iterations≈{iterations})") try: diff --git 
a/embodichain/agents/rl/utils/__init__.py b/embodichain/agents/rl/utils/__init__.py index 7bd835e8..30f59136 100644 --- a/embodichain/agents/rl/utils/__init__.py +++ b/embodichain/agents/rl/utils/__init__.py @@ -15,9 +15,10 @@ # ---------------------------------------------------------------------------- from .config import AlgorithmCfg -from .helper import flatten_dict_observation +from .helper import dict_to_tensordict, flatten_dict_observation __all__ = [ "AlgorithmCfg", + "dict_to_tensordict", "flatten_dict_observation", ] diff --git a/embodichain/agents/rl/utils/helper.py b/embodichain/agents/rl/utils/helper.py index 42259506..e443a649 100644 --- a/embodichain/agents/rl/utils/helper.py +++ b/embodichain/agents/rl/utils/helper.py @@ -14,40 +14,63 @@ # limitations under the License. # ---------------------------------------------------------------------------- +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + import torch from tensordict import TensorDict +__all__ = [ + "dict_to_tensordict", + "flatten_dict_observation", +] -def flatten_dict_observation(obs: TensorDict) -> torch.Tensor: - """ - Flatten hierarchical TensorDict observations from ObservationManager. - Recursively traverse nested TensorDicts, collect all tensor values, - flatten each to (num_envs, -1), and concatenate in sorted key order. +def flatten_dict_observation(obs: TensorDict) -> torch.Tensor: + """Flatten a hierarchical observation TensorDict into a 2D tensor. Args: - obs: Nested TensorDict structure, e.g. TensorDict(robot=TensorDict(qpos=..., qvel=...), ...) + obs: Observation TensorDict with batch dimension `[num_envs]`. Returns: - Concatenated flat tensor of shape (num_envs, total_dim) + Flattened observation tensor of shape `[num_envs, obs_dim]`. 
""" - obs_list = [] + obs_list: list[torch.Tensor] = [] - def _collect_tensors(d, prefix=""): - """Recursively collect tensors from nested TensorDicts in sorted order.""" - for key in sorted(d.keys()): - full_key = f"{prefix}/{key}" if prefix else key - value = d[key] + def _collect_tensors(data: TensorDict) -> None: + for key in sorted(data.keys()): + value = data[key] if isinstance(value, TensorDict): - _collect_tensors(value, full_key) + _collect_tensors(value) elif isinstance(value, torch.Tensor): - # Flatten tensor to (num_envs, -1) shape obs_list.append(value.flatten(start_dim=1)) _collect_tensors(obs) if not obs_list: - raise ValueError("No tensors found in observation TensorDict") + raise ValueError("No tensors found in observation TensorDict.") + + return torch.cat(obs_list, dim=-1) - result = torch.cat(obs_list, dim=-1) - return result + +def dict_to_tensordict( + obs_dict: TensorDict | Mapping[str, Any], device: torch.device | str +) -> TensorDict: + """Convert an environment observation mapping into a TensorDict. + + Args: + obs_dict: Environment observation returned by `reset()` or `step()`. + device: Target device for the resulting TensorDict. + + Returns: + Observation TensorDict moved onto the target device. + """ + if isinstance(obs_dict, TensorDict): + return obs_dict.to(device) + if not isinstance(obs_dict, Mapping): + raise TypeError( + f"Expected observation mapping or TensorDict, got {type(obs_dict)!r}." 
+ ) + return TensorDict.from_dict(dict(obs_dict), device=device) diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 5b17a8e0..4f660232 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -16,7 +16,7 @@ from __future__ import annotations -from typing import Dict, Any, Tuple, Callable +from typing import Dict import time import numpy as np import torch @@ -25,6 +25,8 @@ import wandb from tensordict import TensorDict +from embodichain.agents.rl.buffer import RolloutBuffer +from embodichain.agents.rl.collector import SyncCollector from embodichain.lab.gym.envs.managers.event_manager import EventManager from .helper import flatten_dict_observation @@ -37,7 +39,7 @@ def __init__( policy, env, algorithm, - num_steps: int, + buffer_size: int, batch_size: int, writer: SummaryWriter | None, eval_freq: int, @@ -54,7 +56,7 @@ def __init__( self.env = env self.eval_env = eval_env self.algorithm = algorithm - self.num_steps = num_steps + self.buffer_size = buffer_size self.batch_size = batch_size self.writer = writer self.eval_freq = eval_freq @@ -75,28 +77,31 @@ def __init__( self.start_time = time.time() self.ret_window = deque(maxlen=100) self.len_window = deque(maxlen=100) - - # initial obs (assume env returns torch tensors already on target device) - obs, _ = self.env.reset() - - # Initialize algorithm's buffer - # Flatten TensorDict observations from ObservationManager to tensor for RL algorithms - if isinstance(obs, TensorDict): - obs_tensor = flatten_dict_observation(obs) - obs_dim = obs_tensor.shape[-1] - num_envs = obs_tensor.shape[0] - # Store flattened observation for RL training - self.obs = obs_tensor - - action_space = getattr(self.env, "action_space", None) - action_dim = action_space.shape[-1] if action_space else None - if action_dim is None: - raise RuntimeError( - "Env must expose action_space with shape for buffer initialization." 
- ) - - # Algorithm manages its own buffer - self.algorithm.initialize_buffer(num_steps, num_envs, obs_dim, action_dim) + num_envs = getattr(self.env, "num_envs", None) + if num_envs is None: + raise RuntimeError("Env must expose num_envs for trainer statistics.") + obs_dim = getattr(self.policy, "obs_dim", None) + action_dim = getattr(self.policy, "action_dim", None) + if obs_dim is None or action_dim is None: + raise RuntimeError("Policy must expose obs_dim and action_dim.") + + self.buffer = RolloutBuffer( + num_envs=num_envs, + rollout_len=self.buffer_size, + obs_dim=obs_dim, + action_dim=action_dim, + device=self.device, + ) + self.collector = SyncCollector( + env=self.env, + policy=self.policy, + device=self.device, + reset_every_rollout=bool( + getattr( + getattr(self.algorithm, "cfg", None), "reset_every_rollout", False + ) + ), + ) # episode stats tracked on device to avoid repeated CPU round-trips self.curr_ret = torch.zeros(num_envs, dtype=torch.float32, device=self.device) @@ -137,7 +142,7 @@ def train(self, total_timesteps: int): print(f"Start training, total steps: {total_timesteps}") while self.global_step < total_timesteps: self._collect_rollout() - losses = self.algorithm.update() + losses = self.algorithm.update(self.buffer.get(flatten=False)) self._log_train(losses) if ( self.eval_freq > 0 @@ -150,12 +155,14 @@ def train(self, total_timesteps: int): @torch.no_grad() def _collect_rollout(self): - """Collect a rollout. 
Algorithm controls the data collection process.""" + """Collect a rollout with the synchronous collector.""" # Callback function for statistics and logging - def on_step(obs, actions, reward, done, info, next_obs): + def on_step(tensordict: TensorDict, info: dict): """Callback called at each step during rollout collection.""" - # Episode stats (stay on device; convert only when episode ends) + reward = tensordict["reward"] + done = tensordict["done"] + # Episode stats self.curr_ret += reward self.curr_len += 1 done_idx = torch.nonzero(done, as_tuple=False).squeeze(-1) @@ -167,10 +174,7 @@ def on_step(obs, actions, reward, done, info, next_obs): self.curr_ret[done_idx] = 0 self.curr_len[done_idx] = 0 - # Update global step and observation - # next_obs is already flattened in algorithm's collect_rollout - self.obs = next_obs - self.global_step += next_obs.shape[0] + self.global_step += tensordict.batch_size[0] if isinstance(info, dict): rewards_dict = info.get("rewards") @@ -183,14 +187,12 @@ def on_step(obs, actions, reward, done, info, next_obs): if log_dict and self.use_wandb: wandb.log(log_dict, step=self.global_step) - # Algorithm controls data collection - result = self.algorithm.collect_rollout( - env=self.env, - policy=self.policy, - obs=self.obs, - num_steps=self.num_steps, + rollout = self.collector.collect( + num_steps=self.buffer_size, + rollout=self.buffer.start_rollout(), on_step_callback=on_step, ) + self.buffer.add(rollout) def _log_train(self, losses: Dict[str, float]): if self.writer: @@ -258,7 +260,13 @@ def _eval_once(self, num_episodes: int = 5): # Run episode until all environments complete while not done_mask.all(): # Get deterministic actions from policy - actions, _, _ = self.policy.get_action(obs, deterministic=True) + action_td = TensorDict( + {"obs": obs}, + batch_size=[num_envs], + device=self.device, + ) + action_td = self.policy.get_action(action_td, deterministic=True) + actions = action_td["action"] am = getattr(self.eval_env, 
"action_manager", None) action_type = ( am.action_type diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py index 9042351e..ab945b45 100644 --- a/embodichain/lab/gym/envs/base_env.py +++ b/embodichain/lab/gym/envs/base_env.py @@ -645,6 +645,8 @@ def step( rewards=rewards, dones=dones, info=info, + terminateds=terminateds, + truncateds=truncateds, **kwargs, ) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 870a00ef..5e40d6fd 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -264,6 +264,7 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): # For example, a shared rollout buffer initialized in model training process and passed to the environment for data collection. self.rollout_buffer: TensorDict | None = None self._max_rollout_steps = 0 + self._rollout_buffer_mode: str | None = None if self.cfg.init_rollout_buffer: self.rollout_buffer = init_rollout_buffer_from_gym_space( obs_space=self.observation_space, @@ -273,6 +274,7 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): device=self.device, ) self._max_rollout_steps = self.rollout_buffer.shape[1] + self._rollout_buffer_mode = "expert" self.current_rollout_step = 0 @@ -287,14 +289,31 @@ def set_rollout_buffer(self, rollout_buffer: TensorDict) -> None: such as a shared rollout buffer initialized in model training process and passed to the environment for data collection. Args: - rollout_buffer (TensorDict): The rollout buffer to be set. The shape of the buffer should be (num_envs, max_episode_steps, *data_shape) for each key. + rollout_buffer (TensorDict): The rollout buffer to be set. RL + rollouts use a uniform `[num_envs, time + 1]` layout so all + fields share the same batch shape; the last slot of + transition-only fields is reserved as padding. Expert buffers + keep the legacy `[num_envs, time]` batch layout. 
""" - if len(rollout_buffer.shape) != 2: - logger.log_error( - f"Invalid rollout buffer shape: {rollout_buffer.shape}. The expected shape is (num_envs, max_episode_steps) for each key." - ) self.rollout_buffer = rollout_buffer - self._max_rollout_steps = self.rollout_buffer.shape[1] + self._rollout_buffer_mode = self._infer_rollout_buffer_mode(rollout_buffer) + if self._rollout_buffer_mode == "rl": + batch_size = self.rollout_buffer.batch_size + if len(batch_size) != 2: + message = ( + f"Invalid RL rollout buffer batch size: {batch_size}. " + "Expected a 2D batch layout [num_envs, time + 1] for RL rollouts." + ) + logger.log_error(message) + raise ValueError(message) + self._max_rollout_steps = batch_size[1] - 1 + else: + if len(rollout_buffer.shape) != 2: + logger.log_error( + f"Invalid rollout buffer shape: {rollout_buffer.shape}. The expected shape is (num_envs, max_episode_steps) for each key." + ) + self._max_rollout_steps = self.rollout_buffer.shape[1] + self.current_rollout_step = 0 def _init_sim_state(self, **kwargs): """Initialize the simulation state at the beginning of scene creation.""" @@ -414,33 +433,21 @@ def _hook_after_sim_step( ): # TODO: We may make the data collection customizable for rollout buffer. if self.rollout_buffer is not None: - buffer_device = self.rollout_buffer.device if self.current_rollout_step < self._max_rollout_steps: - # Extract data into episode buffer. 
- self.rollout_buffer["obs"][:, self.current_rollout_step, ...].copy_( - obs.to(buffer_device), non_blocking=True - ) - if isinstance(action, TensorDict): - action_to_store = ( - action["qpos"] - if "qpos" in action - else (action["qvel"] if "qvel" in action else action["qf"]) + if self._rollout_buffer_mode == "rl": + self._write_rl_rollout_step( + obs=obs, + rewards=rewards, + dones=dones, + terminateds=kwargs.get("terminateds"), + truncateds=kwargs.get("truncateds"), ) - elif isinstance(action, torch.Tensor): - action_to_store = action else: - logger.log_warning( - f"Unexpected action type {type(action)} in _hook_after_sim_step; " - "skipping action storage in rollout buffer." + self._write_episode_rollout_step( + obs=obs, + action=action, + rewards=rewards, ) - action_to_store = None - if action_to_store is not None: - self.rollout_buffer["actions"][ - :, self.current_rollout_step, ... - ].copy_(action_to_store.to(buffer_device), non_blocking=True) - self.rollout_buffer["rewards"][:, self.current_rollout_step].copy_( - rewards.to(buffer_device), non_blocking=True - ) self.current_rollout_step += 1 else: logger.log_warning( @@ -514,7 +521,7 @@ def _initialize_episode( ) # Clear episode buffers and reset success status for environments being reset - if self.rollout_buffer is not None: + if self.rollout_buffer is not None and self._rollout_buffer_mode != "rl": self.current_rollout_step = 0 self.episode_success_status[env_ids_to_process] = False @@ -528,6 +535,86 @@ def _initialize_episode( if self.cfg.rewards: self.reward_manager.reset(env_ids=env_ids) + def _infer_rollout_buffer_mode(self, rollout_buffer: TensorDict) -> str: + """Infer whether the rollout buffer is expert recording or RL training data.""" + if { + "obs", + "action", + "reward", + "done", + "value", + "terminated", + "truncated", + }.issubset(set(rollout_buffer.keys())): + return "rl" + return "expert" + + def _write_episode_rollout_step( + self, + obs: EnvObs, + action: EnvAction, + rewards: 
torch.Tensor, + ) -> None: + """Write one step into the legacy episode recording rollout buffer.""" + buffer_device = self.rollout_buffer.device + self.rollout_buffer["obs"][:, self.current_rollout_step, ...].copy_( + obs.to(buffer_device), non_blocking=True + ) + if isinstance(action, TensorDict): + action_to_store = ( + action["qpos"] + if "qpos" in action + else (action["qvel"] if "qvel" in action else action["qf"]) + ) + elif isinstance(action, torch.Tensor): + action_to_store = action + else: + logger.log_warning( + f"Unexpected action type {type(action)} in _hook_after_sim_step; " + "skipping action storage in rollout buffer." + ) + action_to_store = None + if action_to_store is not None: + self.rollout_buffer["actions"][:, self.current_rollout_step, ...].copy_( + action_to_store.to(buffer_device), non_blocking=True + ) + self.rollout_buffer["rewards"][:, self.current_rollout_step].copy_( + rewards.to(buffer_device), non_blocking=True + ) + + def _write_rl_rollout_step( + self, + obs: EnvObs, + rewards: torch.Tensor, + dones: torch.Tensor, + terminateds: torch.Tensor | None, + truncateds: torch.Tensor | None, + ) -> None: + """Write environment-side fields into an externally managed RL rollout buffer.""" + buffer_device = self.rollout_buffer.device + self.rollout_buffer["reward"][:, self.current_rollout_step].copy_( + rewards.to(buffer_device), non_blocking=True + ) + self.rollout_buffer["done"][:, self.current_rollout_step].copy_( + dones.to(buffer_device), non_blocking=True + ) + terminateds = ( + terminateds + if terminateds is not None + else torch.zeros_like(dones, dtype=torch.bool) + ) + truncateds = ( + truncateds + if truncateds is not None + else torch.zeros_like(dones, dtype=torch.bool) + ) + self.rollout_buffer["terminated"][:, self.current_rollout_step].copy_( + terminateds.to(buffer_device), non_blocking=True + ) + self.rollout_buffer["truncated"][:, self.current_rollout_step].copy_( + truncateds.to(buffer_device), non_blocking=True + ) + def 
_step_action(self, action: EnvAction) -> EnvAction: """Set action control command into simulation. diff --git a/tests/agents/test_rl.py b/tests/agents/test_rl.py index e0a5beaf..f356502c 100644 --- a/tests/agents/test_rl.py +++ b/tests/agents/test_rl.py @@ -68,7 +68,7 @@ def setup_method(self): test_train_config = train_config.copy() test_train_config["trainer"]["gym_config"] = self.temp_gym_config_path test_train_config["trainer"]["iterations"] = 2 - test_train_config["trainer"]["rollout_steps"] = 32 + test_train_config["trainer"]["buffer_size"] = 32 test_train_config["trainer"]["eval_freq"] = 1000000 # Disable eval test_train_config["trainer"]["save_freq"] = 1000000 # Disable save test_train_config["trainer"]["headless"] = True diff --git a/tests/agents/test_shared_rollout.py b/tests/agents/test_shared_rollout.py new file mode 100644 index 00000000..89335229 --- /dev/null +++ b/tests/agents/test_shared_rollout.py @@ -0,0 +1,227 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from copy import deepcopy + +import torch +from tensordict import TensorDict + +from embodichain.agents.rl.buffer import RolloutBuffer +from embodichain.agents.rl.collector import SyncCollector +from embodichain.agents.rl.utils import flatten_dict_observation +from embodichain.lab.gym.envs.tasks.rl import build_env +from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES +from embodichain.lab.sim import SimulationManagerCfg, SimulationManager +from embodichain.utils.utility import load_json + + +class _FakePolicy: + def __init__(self, obs_dim: int, action_dim: int, device: torch.device) -> None: + self.obs_dim = obs_dim + self.action_dim = action_dim + self.device = device + + def train(self) -> None: + pass + + def forward(self, tensordict: TensorDict) -> TensorDict: + obs = tensordict["obs"] + tensordict["action"] = obs[:, : self.action_dim] * 0.25 + tensordict["sample_log_prob"] = obs.sum(dim=-1) * 0.1 + tensordict["value"] = obs.mean(dim=-1) + return tensordict + + def get_action( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + return self.forward(tensordict) + + def get_value(self, tensordict: TensorDict) -> TensorDict: + tensordict["value"] = tensordict["obs"].mean(dim=-1) + return tensordict + + +class _FakeEnv: + def __init__( + self, num_envs: int, obs_dim: int, action_dim: int, device: torch.device + ): + self.num_envs = num_envs + self.obs_dim = obs_dim + self.action_dim = action_dim + self.device = device + self.action_type = "delta_qpos" + self.rollout_buffer: TensorDict | None = None + self.current_rollout_step = 0 + self._obs = self._make_obs(step=0) + + def reset(self, **kwargs): + self.current_rollout_step = 0 + self._obs = self._make_obs(step=0) + return self._obs, {} + + def set_rollout_buffer(self, rollout_buffer: TensorDict) -> None: + self.rollout_buffer = 
rollout_buffer + self.current_rollout_step = 0 + + def step(self, action_dict): + action = action_dict[self.action_type] + step_idx = self.current_rollout_step + 1 + next_obs = self._make_obs(step=step_idx) + reward = action.sum(dim=-1) + terminated = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device) + truncated = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device) + + if self.rollout_buffer is not None: + self.rollout_buffer["reward"][:, self.current_rollout_step] = reward + self.rollout_buffer["done"][:, self.current_rollout_step] = ( + terminated | truncated + ) + self.rollout_buffer["terminated"][:, self.current_rollout_step] = terminated + self.rollout_buffer["truncated"][:, self.current_rollout_step] = truncated + self.current_rollout_step += 1 + + self._obs = next_obs + return next_obs, reward, terminated, truncated, {} + + def _make_obs(self, step: int) -> TensorDict: + base = torch.full( + (self.num_envs, self.obs_dim), + fill_value=float(step), + dtype=torch.float32, + device=self.device, + ) + return TensorDict( + { + "agent": TensorDict( + {"state": base}, + batch_size=[self.num_envs], + device=self.device, + ) + }, + batch_size=[self.num_envs], + device=self.device, + ) + + +def test_shared_rollout_collects_policy_and_env_fields(): + device = torch.device("cpu") + num_envs = 3 + rollout_len = 4 + obs_dim = 5 + action_dim = 2 + + env = _FakeEnv( + num_envs=num_envs, + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + ) + policy = _FakePolicy(obs_dim=obs_dim, action_dim=action_dim, device=device) + collector = SyncCollector(env=env, policy=policy, device=device) + buffer = RolloutBuffer( + num_envs=num_envs, + rollout_len=rollout_len, + obs_dim=obs_dim, + action_dim=action_dim, + device=device, + ) + + rollout = collector.collect( + num_steps=rollout_len, + rollout=buffer.start_rollout(), + ) + buffer.add(rollout) + stored = buffer.get(flatten=False) + + assert stored.batch_size == torch.Size([num_envs, rollout_len 
+ 1]) + assert stored["obs"].shape == torch.Size([num_envs, rollout_len + 1, obs_dim]) + assert torch.allclose(stored["obs"][:, 0], torch.zeros(num_envs, obs_dim)) + assert torch.allclose( + stored["obs"][:, -1], + torch.full((num_envs, obs_dim), float(rollout_len), dtype=torch.float32), + ) + assert torch.allclose( + stored["value"][:, 1], torch.ones(num_envs, dtype=torch.float32) + ) + assert torch.allclose( + stored["action"][:, 0], + torch.zeros(num_envs, action_dim), + ) + assert torch.allclose( + stored["sample_log_prob"][:, 1], + torch.full((num_envs,), 0.5, dtype=torch.float32), + ) + assert torch.allclose( + stored["reward"][:, 2], + torch.full((num_envs,), 1.0, dtype=torch.float32), + ) + assert torch.allclose( + stored["value"][:, -1], + torch.full((num_envs,), 4.0, dtype=torch.float32), + ) + assert torch.allclose( + stored["action"][:, -1], + torch.zeros(num_envs, action_dim, dtype=torch.float32), + ) + + +def test_embodied_env_writes_next_fields_into_external_rollout(): + gym_config = load_json("configs/agents/rl/basic/cart_pole/gym_config.json") + env_cfg = config_to_cfg(gym_config, manager_modules=DEFAULT_MANAGER_MODULES) + env_cfg = deepcopy(env_cfg) + env_cfg.num_envs = 2 + env_cfg.sim_cfg = SimulationManagerCfg( + headless=True, + sim_device=torch.device("cpu"), + enable_rt=False, + gpu_id=0, + ) + + env = build_env(gym_config["id"], base_env_cfg=env_cfg) + try: + obs, _ = env.reset() + obs_dim = flatten_dict_observation(obs).shape[-1] + action_dim = env.action_space.shape[-1] + buffer = RolloutBuffer( + num_envs=env.num_envs, + rollout_len=4, + obs_dim=obs_dim, + action_dim=action_dim, + device=torch.device("cpu"), + ) + rollout = buffer.start_rollout() + env.set_rollout_buffer(rollout) + + action = torch.zeros( + env.num_envs, + action_dim, + dtype=torch.float32, + device=env.device, + ) + next_obs, reward, terminated, truncated, _ = env.step({"delta_qpos": action}) + done = (terminated | truncated).cpu() + + assert env.current_rollout_step == 
1 + assert torch.allclose(rollout["reward"][:, 0].cpu(), reward.cpu()) + assert torch.equal(rollout["done"][:, 0].cpu(), done) + assert torch.equal(rollout["terminated"][:, 0].cpu(), terminated.cpu()) + assert torch.equal(rollout["truncated"][:, 0].cpu(), truncated.cpu()) + finally: + env.close() + if SimulationManager.is_instantiated(): + SimulationManager.get_instance().destroy()