Refactor RL rollout with collector/buffer separation and shared TensorDict #175
yangchen73 wants to merge 19 commits into main from
Conversation
Pull request overview
This PR refactors the RL rollout pipeline to use a shared, preallocated TensorDict that is written jointly by the rollout collector (policy-side fields) and the environment (transition-side next.* fields), while updating PPO/GRPO, trainer wiring, docs, and configs accordingly.
Changes:
- Introduces SyncCollector and a new RolloutBuffer implementation that share a preallocated rollout TensorDict.
- Updates PPO/GRPO and the trainer loop to consume TensorDict rollouts (including GAE via compute_gae).
- Extends EmbodiedEnv/BaseEnv to support writing next.* fields into an externally provided rollout buffer; updates docs/configs/tests.
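A minimal sketch of the resulting control flow; all names here (the collector's `collect`, the algorithm's `update`, the metrics dict) are assumed from this summary, not the actual API:

```python
def train(collector, algo, num_iterations):
    """Hypothetical trainer loop: the collector fills the shared rollout
    TensorDict in place, and the algorithm consumes it via update(rollout)."""
    history = []
    for _ in range(num_iterations):
        rollout = collector.collect()   # writes policy- and env-side fields
        metrics = algo.update(rollout)  # PPO/GRPO read the same rollout object
        history.append(metrics)
    return history
```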
Reviewed changes
Copilot reviewed 28 out of 28 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| tests/agents/test_shared_rollout.py | Adds coverage for shared-rollout behavior across fake env and real EmbodiedEnv. |
| tests/agents/test_rl.py | Updates trainer config key to buffer_size. |
| embodichain/lab/gym/envs/embodied_env.py | Adds external-rollout mode inference + writes next.* fields into shared RL rollout. |
| embodichain/lab/gym/envs/base_env.py | Passes terminateds/truncateds to post-step hook for rollout writing. |
| embodichain/agents/rl/utils/trainer.py | Switches trainer to collector+buffer flow and passes rollout into algorithm.update(). |
| embodichain/agents/rl/utils/helper.py | Adds dict_to_tensordict + compute_gae; refines flattening helper. |
| embodichain/agents/rl/utils/__init__.py | Exposes new RL utilities in package exports. |
| embodichain/agents/rl/train.py | Plumbs buffer_size, derives obs_dim from reset obs, validates action_dim, updates policy build callsites. |
| embodichain/agents/rl/models/policy.py | Moves policy API to TensorDict-based forward/get_value/evaluate_actions with get_action compatibility wrapper. |
| embodichain/agents/rl/models/actor_only.py | Updates ActorOnly to TensorDict-based interfaces and constructor dims. |
| embodichain/agents/rl/models/actor_critic.py | Updates ActorCritic to TensorDict-based interfaces and constructor dims. |
| embodichain/agents/rl/models/__init__.py | Updates build_policy() signature to (obs_dim, action_dim, device, ...) and keeps registry. |
| embodichain/agents/rl/collector/sync_collector.py | Implements synchronous rollout collection into a preallocated shared TensorDict. |
| embodichain/agents/rl/collector/base.py | Defines collector interface for returning TensorDict rollouts. |
| embodichain/agents/rl/collector/__init__.py | Exports collector types. |
| embodichain/agents/rl/buffer/standard_buffer.py | Introduces shared TensorDict-backed rollout buffer. |
| embodichain/agents/rl/buffer/rollout_buffer.py | Re-exports RolloutBuffer to new implementation. |
| embodichain/agents/rl/buffer/__init__.py | Updates buffer exports to new implementation. |
| embodichain/agents/rl/algo/ppo.py | Refactors PPO to compute/update from TensorDict rollouts (GAE via helper). |
| embodichain/agents/rl/algo/grpo.py | Refactors GRPO to compute/update from TensorDict rollouts. |
| embodichain/agents/rl/algo/base.py | Simplifies algorithm interface to update(rollout: TensorDict). |
| docs/source/tutorial/rl.rst | Updates tutorial to new collector/buffer/shared-rollout flow and config schema. |
| docs/source/overview/rl/trainer.md | Updates trainer overview for SyncCollector + shared rollout TensorDict. |
| docs/source/overview/rl/models.md | Updates model API docs to TensorDict-native policy interface. |
| docs/source/overview/rl/algorithm.md | Updates algorithm docs to update(rollout)-only interface and shared rollout usage. |
| configs/agents/rl/push_cube/train_config.json | Renames rollout_steps to buffer_size and adds policy.action_dim. |
| configs/agents/rl/basic/cart_pole/train_config_grpo.json | Renames rollout_steps to buffer_size and adds policy.action_dim. |
| configs/agents/rl/basic/cart_pole/train_config.json | Renames rollout_steps to buffer_size and adds policy.action_dim. |
Pull request overview
Copilot reviewed 28 out of 28 changed files in this pull request and generated 2 comments.
Pull request overview
Copilot reviewed 28 out of 28 changed files in this pull request and generated 1 comment.
```python
obs_to_store = (
    flatten_dict_observation(obs) if isinstance(obs, TensorDict) else obs
)
self.rollout_buffer["next", "observation"][:, self.current_rollout_step].copy_(
    obs_to_store.to(buffer_device), non_blocking=True
)
```
In external RL rollout mode, obs here comes from BaseEnv.step() before its internal auto-reset of done envs (reset happens after _hook_after_sim_step). But SyncCollector advances its next-step observation from the next_obs returned by env.step(), which is after that reset. This makes rollout['next','observation'] inconsistent with the observation used for the next policy step whenever an env terminates/truncates mid-rollout. Consider aligning semantics (e.g., have the collector write next.observation from the returned next_obs, or pass both terminal and post-reset observations into the hook / store terminal obs under a separate key).
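A sketch of the suggested alignment: the collector writes next.observation from the next_obs returned by env.step() (the post-reset observation the policy will actually see next) and keeps the pre-reset terminal observation under a separate key. All names here are illustrative, not the actual SyncCollector API:

```python
class AlignedCollector:
    """Hypothetical collector that keeps rollout['next_observation'][t]
    identical to the observation fed to the policy at step t + 1."""

    def __init__(self, env, policy, rollout):
        self.env, self.policy, self.rollout = env, policy, rollout

    def collect(self, obs, num_steps):
        for t in range(num_steps):
            action = self.policy(obs)
            # env.step() may auto-reset done sub-envs; next_obs is post-reset
            next_obs, reward, done, info = self.env.step(action)
            self.rollout["observation"][t] = obs
            self.rollout["next_observation"][t] = next_obs
            if done:
                # pre-reset terminal obs stored separately, if the env exposes it
                self.rollout["terminal_observation"][t] = info.get(
                    "final_observation", next_obs
                )
            obs = next_obs  # identical to what was stored: no inconsistency
        return self.rollout
```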
```python
# embodichain/agents/rl/buffer/rollout_buffer.py (re-export shim after the refactor)
from __future__ import annotations

from typing import Dict, Iterator
from .standard_buffer import RolloutBuffer
```
Could we remove this file? It seems useless for now.
Yes, I've removed it.
```python
policy_block: dict,
obs_space: spaces.Space,
action_space: spaces.Space,
obs_dim: int,
```
Why change to int? For VLA we have a dict-like obs space.
Oh, I copied it from the original version and forgot to remove it.
It now uses env.action_space and env.observation_space directly.
```python
__all__ = [
    "AlgorithmCfg",
    "compute_gae",
```
Since it will only be used by RL algorithms, we should put compute_gae into the algo module (maybe a new file shared by all algorithms).
I moved it out of rl.utils into rl.algo.common and updated PPO to import it from there.
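For reference, GAE over a single trajectory can be sketched as follows; this is a plain-Python illustration of the standard recurrence, not the code moved into rl.algo.common:

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over one trajectory.

    delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    A_t     = delta_t + gamma * lam * (1 - done_t) * A_{t+1}
    """
    T = len(rewards)
    advantages = [0.0] * T
    gae = 0.0
    next_value = last_value  # bootstrap value for the step after the rollout
    for t in reversed(range(T)):
        nonterminal = 1.0 - float(dones[t])
        delta = rewards[t] + gamma * next_value * nonterminal - values[t]
        gae = delta + gamma * lam * nonterminal * gae
        advantages[t] = gae
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns
```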
```diff
-if self.rollout_buffer is not None:
+if (
+    self.rollout_buffer is not None
+    and self._rollout_buffer_mode != "external_rl"
```
We could use expert and rl as the values for this variable.
Agreed, that is clearer.
```python
terminateds = (
    terminateds
    if terminateds is not None
```
Can we remove truncateds and terminateds? dones has already been computed and stored in the buffer, so it can be used directly in the algorithm.
That makes sense, but I think it is better to keep them since we may need them in the near future.
embodichain/agents/rl/algo/grpo.py (outdated)
```python
        "approx_ref_kl": total_kl / max(1.0, total_weight),
    }


def _iterate_minibatches(
```
_iterate_minibatches is duplicated between PPO and GRPO. It should be a method of the Buffer.
I moved it to buffer/utils.py
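A shared helper of this shape could live in the suggested buffer/utils.py; the signature below is a guess, not the PR's actual implementation:

```python
import random


def iterate_minibatches(num_samples, minibatch_size, shuffle=True, seed=None):
    """Yield index lists that partition [0, num_samples) into minibatches,
    optionally shuffled, so PPO and GRPO can share one iteration path."""
    rng = random.Random(seed)
    indices = list(range(num_samples))
    if shuffle:
        rng.shuffle(indices)
    for start in range(0, num_samples, minibatch_size):
        yield indices[start:start + minibatch_size]
```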
```python
    dtype=torch.float32,
    device=self.device,
),
"next": {
```
Why create an extra buffer with shape (B, T, ...)? It would be a waste of memory.
I removed the extra next.obs buffer; the structure is now:

```text
obs:             [B, T+1, obs_dim]
action:          [B, T, action_dim]
sample_log_prob: [B, T]
value:           [B, T]
reward:          [B, T]
done:            [B, T]
terminated:      [B, T]
truncated:       [B, T]
```
The lengths will be unequal. We could just add 1 to the obs length (T+1 in total).
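Assuming plain torch tensors, that layout might be preallocated like this; an illustrative sketch with a made-up helper name, not the PR's RolloutBuffer:

```python
import torch


def make_rollout_storage(B, T, obs_dim, action_dim, device="cpu"):
    """obs gets T + 1 slots so obs[:, t + 1] doubles as the next-step
    observation for step t, avoiding a duplicated (B, T, obs_dim) buffer."""
    f32 = dict(dtype=torch.float32, device=device)
    flag = dict(dtype=torch.bool, device=device)
    return {
        "obs": torch.zeros(B, T + 1, obs_dim, **f32),
        "action": torch.zeros(B, T, action_dim, **f32),
        "sample_log_prob": torch.zeros(B, T, **f32),
        "value": torch.zeros(B, T, **f32),
        "reward": torch.zeros(B, T, **f32),
        "done": torch.zeros(B, T, **flag),
        "terminated": torch.zeros(B, T, **flag),
        "truncated": torch.zeros(B, T, **flag),
    }
```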
```python
    batch_size=[obs_tensor.shape[0]],
    device=self.device,
)
self.policy.forward(step_td)
```
Why use forward here? forward will build the autograd graph; we should use get_action instead.
I added @torch.no_grad() on top.
Updated. get_action() now returns a TensorDict, and the collector uses get_action() instead of calling forward() directly.
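The forward/get_action split can be sketched as follows; a minimal stand-in over plain tensors rather than the actual TensorDict-based policy:

```python
import torch
import torch.nn as nn


class PolicySketch(nn.Module):
    """Hypothetical policy: forward builds the autograd graph for updates;
    get_action wraps it in no_grad for rollout collection."""

    def __init__(self, obs_dim, action_dim):
        super().__init__()
        self.net = nn.Linear(obs_dim, action_dim)

    def forward(self, obs):
        return self.net(obs)

    @torch.no_grad()
    def get_action(self, obs):
        # inference-only path: no gradients are tracked during collection
        return self.forward(obs)
```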
Description
This PR refactors the RL rollout stack around a shared TensorDict.
Type of change
Checklist
- Ran the `black .` command to format the code base.