
Refactor RL rollout with collector/buffer separation and shared TensorDict#175

Open
yangchen73 wants to merge 19 commits into main from yc/vla_rl

Conversation

@yangchen73
Collaborator

Description

This PR refactors the RL rollout stack around a shared TensorDict.

  • Separates rollout collection and buffer management
  • Updates PPO/GRPO to use the new collector and buffer interfaces
  • Uses a shared TensorDict between the environment and trainer for rollout data

Type of change

  • Enhancement (non-breaking change which improves existing functionality)
  • Documentation update

Checklist

  • I have run the black . command to format the code base.
  • I have made corresponding changes to the documentation.
  • I have added tests that prove my fix is effective or that my feature works.
  • Dependencies have been updated, if applicable.

Contributor

Copilot AI left a comment


Pull request overview

This PR refactors the RL rollout pipeline to use a shared, preallocated TensorDict that is written jointly by the rollout collector (policy-side fields) and the environment (transition-side next.* fields), while updating PPO/GRPO, trainer wiring, docs, and configs accordingly.

Changes:

  • Introduces SyncCollector and a new RolloutBuffer implementation that shares a preallocated rollout TensorDict.
  • Updates PPO/GRPO + trainer loop to consume TensorDict rollouts (including GAE via compute_gae).
  • Extends EmbodiedEnv/BaseEnv to support writing next.* fields into an externally provided rollout buffer; updates docs/configs/tests.

Reviewed changes

Copilot reviewed 28 out of 28 changed files in this pull request and generated 6 comments.

File Description
tests/agents/test_shared_rollout.py Adds coverage for shared-rollout behavior across fake env and real EmbodiedEnv.
tests/agents/test_rl.py Updates trainer config key to buffer_size.
embodichain/lab/gym/envs/embodied_env.py Adds external-rollout mode inference + writes next.* fields into shared RL rollout.
embodichain/lab/gym/envs/base_env.py Passes terminateds/truncateds to post-step hook for rollout writing.
embodichain/agents/rl/utils/trainer.py Switches trainer to collector+buffer flow and passes rollout into algorithm.update().
embodichain/agents/rl/utils/helper.py Adds dict_to_tensordict + compute_gae; refines flattening helper.
embodichain/agents/rl/utils/__init__.py Exposes new RL utilities in package exports.
embodichain/agents/rl/train.py Plumbs buffer_size, derives obs_dim from reset obs, validates action_dim, updates policy build callsites.
embodichain/agents/rl/models/policy.py Moves policy API to TensorDict-based forward/get_value/evaluate_actions with get_action compatibility wrapper.
embodichain/agents/rl/models/actor_only.py Updates ActorOnly to TensorDict-based interfaces and constructor dims.
embodichain/agents/rl/models/actor_critic.py Updates ActorCritic to TensorDict-based interfaces and constructor dims.
embodichain/agents/rl/models/__init__.py Updates build_policy() signature to (obs_dim, action_dim, device, ...) and keeps registry.
embodichain/agents/rl/collector/sync_collector.py Implements synchronous rollout collection into a preallocated shared TensorDict.
embodichain/agents/rl/collector/base.py Defines collector interface for returning TensorDict rollouts.
embodichain/agents/rl/collector/__init__.py Exports collector types.
embodichain/agents/rl/buffer/standard_buffer.py Introduces shared TensorDict-backed rollout buffer.
embodichain/agents/rl/buffer/rollout_buffer.py Re-exports RolloutBuffer to new implementation.
embodichain/agents/rl/buffer/__init__.py Updates buffer exports to new implementation.
embodichain/agents/rl/algo/ppo.py Refactors PPO to compute/update from TensorDict rollouts (GAE via helper).
embodichain/agents/rl/algo/grpo.py Refactors GRPO to compute/update from TensorDict rollouts.
embodichain/agents/rl/algo/base.py Simplifies algorithm interface to update(rollout: TensorDict).
docs/source/tutorial/rl.rst Updates tutorial to new collector/buffer/shared-rollout flow and config schema.
docs/source/overview/rl/trainer.md Updates trainer overview for SyncCollector + shared rollout TensorDict.
docs/source/overview/rl/models.md Updates model API docs to TensorDict-native policy interface.
docs/source/overview/rl/algorithm.md Updates algorithm docs to update(rollout)-only interface and shared rollout usage.
configs/agents/rl/push_cube/train_config.json Renames rollout_steps to buffer_size and adds policy.action_dim.
configs/agents/rl/basic/cart_pole/train_config_grpo.json Renames rollout_steps to buffer_size and adds policy.action_dim.
configs/agents/rl/basic/cart_pole/train_config.json Renames rollout_steps to buffer_size and adds policy.action_dim.


Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 28 out of 28 changed files in this pull request and generated 2 comments.



Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 28 out of 28 changed files in this pull request and generated 1 comment.



Comment on lines +576 to +580
obs_to_store = (
    flatten_dict_observation(obs) if isinstance(obs, TensorDict) else obs
)
self.rollout_buffer["next", "observation"][:, self.current_rollout_step].copy_(
    obs_to_store.to(buffer_device), non_blocking=True

Copilot AI Mar 12, 2026


In external RL rollout mode, obs here comes from BaseEnv.step() before its internal auto-reset of done envs (reset happens after _hook_after_sim_step). But SyncCollector advances its next-step observation from the next_obs returned by env.step(), which is after that reset. This makes rollout['next','observation'] inconsistent with the observation used for the next policy step whenever an env terminates/truncates mid-rollout. Consider aligning semantics (e.g., have the collector write next.observation from the returned next_obs, or pass both terminal and post-reset observations into the hook / store terminal obs under a separate key).

@@ -16,91 +16,6 @@

from __future__ import annotations

from typing import Dict, Iterator
from .standard_buffer import RolloutBuffer

Contributor


Could we remove this file? It seems useless for now.

Collaborator Author


Yes, I've removed it.

policy_block: dict,
obs_space: spaces.Space,
action_space: spaces.Space,
obs_dim: int,
Contributor


Why change to int? For VLA we have a dict-like obs space.

Collaborator Author


Oh, I copied it from the original version and forgot to remove it.

Collaborator Author


It now uses env.action_space and env.observation_space directly.


__all__ = [
"AlgorithmCfg",
"compute_gae",
Contributor


Since it will only be used by RL algorithms, we should put compute_gae into the algo module (maybe a new file that collects code shared across algorithms).

Collaborator Author


I moved it out of rl.utils into rl.algo.common and updated PPO to import it from there.
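For reference, a minimal sketch of what such a compute_gae helper typically looks like, assuming the [B, T] rollout layout used in this PR (the exact signature in rl.algo.common may differ):

```python
import torch

def compute_gae(rewards, values, dones, last_values, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over a [B, T] rollout.

    rewards, values, dones: [B, T]; last_values: [B] bootstrap values.
    Returns (advantages, returns), both [B, T].
    """
    advantages = torch.zeros_like(rewards)
    gae = torch.zeros_like(last_values)
    next_values = last_values
    T = rewards.shape[1]
    for t in reversed(range(T)):
        not_done = 1.0 - dones[:, t].float()
        # TD residual: done at step t cuts the bootstrap from the next value
        delta = rewards[:, t] + gamma * next_values * not_done - values[:, t]
        gae = delta + gamma * lam * not_done * gae
        advantages[:, t] = gae
        next_values = values[:, t]
    returns = advantages + values
    return advantages, returns
```

With gamma = lam = 1 and zero values, the advantage at step t reduces to the sum of remaining rewards, which is a handy sanity check.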

if self.rollout_buffer is not None:
if (
self.rollout_buffer is not None
and self._rollout_buffer_mode != "external_rl"
Contributor


We could use expert and rl as the values for this variable.

Collaborator Author

@yangchen73 yangchen73 Mar 12, 2026


Agreed, that is clearer.


terminateds = (
    terminateds
    if terminateds is not None
Contributor


Can we remove truncateds and terminateds? dones has already been computed and stored in the buffer, so the algorithm can use it directly.

Collaborator Author

@yangchen73 yangchen73 Mar 12, 2026


That makes sense, but I think it is better to keep them since we may need them in the near future.
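The redundancy being discussed is just an elementwise OR (plain tensors here; the field names follow this PR's buffer layout):

```python
import torch

# done is derivable from the two flags the buffer also stores separately
terminated = torch.tensor([True, False, False])  # env reached a terminal state
truncated = torch.tensor([False, True, False])   # episode cut off by time limit
done = terminated | truncated
```

Keeping the split costs two extra bool fields but preserves the terminal-vs-timeout distinction, which matters if value bootstrapping at truncation is added later.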

"approx_ref_kl": total_kl / max(1.0, total_weight),
}

def _iterate_minibatches(
Contributor


_iterate_minibatches is duplicated between PPO and GRPO. It should be a method of the buffer.

Collaborator Author


I moved it to buffer/utils.py.
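A hypothetical sketch of what such a shared minibatch helper could look like, with a plain dict of [B, T, ...] tensors standing in for the TensorDict (the name and flattening convention are assumptions, not the PR's actual code):

```python
import torch

def iterate_minibatches(rollout, num_minibatches, generator=None):
    """Flatten [B, T, ...] fields to [B*T, ...], shuffle, and yield minibatches."""
    flat = {k: v.reshape(-1, *v.shape[2:]) for k, v in rollout.items()}
    n = next(iter(flat.values())).shape[0]
    perm = torch.randperm(n, generator=generator)
    mb_size = n // num_minibatches
    for i in range(num_minibatches):
        idx = perm[i * mb_size : (i + 1) * mb_size]
        # every field is indexed by the same permutation, keeping samples aligned
        yield {k: v[idx] for k, v in flat.items()}
```

Putting this on the buffer lets PPO and GRPO share one shuffling/alignment implementation instead of two diverging copies.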

dtype=torch.float32,
device=self.device,
),
"next": {
Contributor


Why create an extra buffer with shape (B, T, ...)? It would be a waste of memory.

Collaborator Author


I removed the extra next.obs buffer; the structure is now:
obs: [B, T+1, obs_dim]
action: [B, T, action_dim]
sample_log_prob: [B, T]
value: [B, T]
reward: [B, T]
done: [B, T]
terminated: [B, T]
truncated: [B, T]

Contributor


The lengths will be unequal. We could just add 1 to the time length of obs (T+1 in total).
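A sketch of that preallocation, with a plain dict standing in for the TensorDict (make_rollout_storage is a hypothetical name; shapes follow the layout listed above):

```python
import torch

def make_rollout_storage(B, T, obs_dim, action_dim, device="cpu"):
    """Preallocate a shared rollout store with T+1 obs slots."""
    f32 = dict(dtype=torch.float32, device=device)
    return {
        # obs gets T+1 slots so obs[:, t + 1] doubles as step t's next
        # observation, avoiding a duplicate (B, T, obs_dim) "next" buffer
        "obs": torch.zeros(B, T + 1, obs_dim, **f32),
        "action": torch.zeros(B, T, action_dim, **f32),
        "sample_log_prob": torch.zeros(B, T, **f32),
        "value": torch.zeros(B, T, **f32),
        "reward": torch.zeros(B, T, **f32),
        "done": torch.zeros(B, T, dtype=torch.bool, device=device),
        "terminated": torch.zeros(B, T, dtype=torch.bool, device=device),
        "truncated": torch.zeros(B, T, dtype=torch.bool, device=device),
    }
```

Relative to a separate next.obs buffer, this halves observation memory at the cost of one extra time slot.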

batch_size=[obs_tensor.shape[0]],
device=self.device,
)
self.policy.forward(step_td)
Contributor


Why use forward here? forward will track gradients; we should use get_action instead.

Collaborator Author


I added @torch.no_grad() to it.

Collaborator Author


Updated: get_action() now returns a TensorDict, and the collector uses get_action() instead of calling forward() directly.
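A minimal sketch of that shape, with a plain dict standing in for the returned TensorDict (TinyPolicy and its layer sizes are hypothetical; the field names match this PR's rollout schema):

```python
import torch
import torch.nn as nn

class TinyPolicy(nn.Module):
    """Toy actor-critic with an inference-only get_action path."""

    def __init__(self, obs_dim, action_dim):
        super().__init__()
        self.actor = nn.Linear(obs_dim, action_dim)
        self.critic = nn.Linear(obs_dim, 1)

    @torch.no_grad()  # rollout collection builds no autograd graph
    def get_action(self, obs):
        mean = self.actor(obs)
        dist = torch.distributions.Normal(mean, torch.ones_like(mean))
        action = dist.sample()
        return {
            "action": action,
            "sample_log_prob": dist.log_prob(action).sum(-1),
            "value": self.critic(obs).squeeze(-1),
        }
```

Gradients are then recomputed only in evaluate_actions during the update, so collection stays cheap and the collector never holds a graph.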
