Conversation
Pull request overview
This pull request introduces a comprehensive refactoring of the RL training framework to use TensorDict-based data flow, replacing the previous tensor-based approach. The PR adds support for two training modes: standard synchronous PPO and asynchronous VLA training designed for scenarios with slow model inference.
Changes:
- Migrated entire RL pipeline to TensorDict-based architecture for structured, extensible data flow
- Introduced dual buffer system: RolloutBuffer (standard) and VLABuffer (async with FIFO)
- Added AsyncCollector for background data collection in VLA mode with thread-based parallelism
- Refactored Policy interface to use TensorDict inputs/outputs with in-place modifications
- Updated PPO algorithm to work with TensorDict rollouts and removed dependency on gym spaces
- Modified configuration to use `buffer_size` instead of `rollout_steps` and added an `action_dim` requirement
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 15 comments.
| File | Description |
|---|---|
| embodichain/agents/rl/utils/trainer.py | Refactored to support dual training modes (sync/async) with TensorDict |
| embodichain/agents/rl/utils/helper.py | Added dict_to_tensordict, compute_gae, and logging utilities |
| embodichain/agents/rl/utils/async_collector.py | New async data collector for VLA mode with background thread |
| embodichain/agents/rl/buffer/rollout_buffer.py | Renamed/refactored to VLABuffer with circular indexing |
| embodichain/agents/rl/buffer/standard_buffer.py | New RolloutBuffer for standard PPO mode |
| embodichain/agents/rl/buffer/__init__.py | Updated exports for dual buffer system |
| embodichain/agents/rl/algo/ppo.py | Refactored to use TensorDict data flow throughout |
| embodichain/agents/rl/algo/base.py | Updated base algorithm interface for TensorDict |
| embodichain/agents/rl/models/policy.py | Changed interface to TensorDict-based methods |
| embodichain/agents/rl/models/actor_critic.py | Implemented TensorDict-based policy with in-place modifications |
| embodichain/agents/rl/models/__init__.py | Removed gymnasium dependency, added action_dim parameter |
| embodichain/agents/rl/train.py | Added action_dim requirement, removed gym space dependency |
| tests/agents/test_rl.py | Updated test to use buffer_size parameter |
| configs/agents/rl/push_cube/train_config.json | Updated config with buffer_size, action_dim, and eval_freq |
| configs/agents/rl/basic/cart_pole/train_config.json | Updated config with buffer_size |
| docs/source/tutorial/rl.rst | Updated documentation to reference buffer_size |
| pyproject.toml | Added tensordict>=0.5.0 dependency |
Comments suppressed due to low confidence (1)
embodichain/agents/rl/train.py:289
- The `buffer_type` parameter is not read from the trainer config and not passed to the Trainer constructor (lines 273-289). This means the VLA async mode introduced in this PR cannot be used, as it will always default to "standard" mode. Add `buffer_type = trainer_cfg.get("buffer_type", "standard")` before the Trainer initialization and pass it as `buffer_type=buffer_type` to the Trainer constructor.
```python
trainer = Trainer(
    policy=policy,
    env=env,
    algorithm=algo,
    buffer_size=buffer_size,
    batch_size=algo_cfg["batch_size"],
    writer=writer,
    eval_freq=eval_freq if enable_eval else 0,  # Disable eval if not enabled
    save_freq=save_freq,
    checkpoint_dir=checkpoint_dir,
    exp_name=exp_name,
    use_wandb=use_wandb,
    eval_env=eval_env,  # None if enable_eval=False
    event_cfg=train_event_cfg,
    eval_event_cfg=eval_event_cfg if enable_eval else {},
    num_eval_episodes=num_eval_episodes,
)
```
```python
# Update global step
num_envs = tensordict.batch_size[0]
self.global_step += num_envs
```
The self.global_step variable is updated from the async collector thread (line 182 via callback) and potentially read from the main thread (lines 214, 244, 255). This creates a race condition. Consider using a thread-safe counter (e.g., threading.Lock protection or multiprocessing.Value) or tracking steps only in one thread.
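A minimal sketch of the thread-safe counter option suggested above, using only the standard library (`StepCounter` is an illustrative name, not the PR's actual API):

```python
import threading

class StepCounter:
    """Thread-safe step counter: a Lock guards both the increment and the read,
    so the collector thread and the main thread always see a consistent value."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._value = 0

    def add(self, n: int) -> None:
        with self._lock:
            self._value += n

    @property
    def value(self) -> int:
        with self._lock:
            return self._value

# Simulate the collector callback incrementing from several background threads.
counter = StepCounter()
threads = [
    threading.Thread(target=lambda: [counter.add(1) for _ in range(1000)])
    for _ in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(counter.value)  # 4000
```

With the lock, the final count is exact regardless of thread interleaving; `multiprocessing.Value` would be the analogous choice if collection moved to a separate process.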
Pull request overview
Copilot reviewed 26 out of 26 changed files in this pull request and generated 14 comments.
```python
def get(self, flatten: bool = True) -> TensorDict:
    """Get valid data from buffer.

    Args:
        flatten: If True, return flattened [size, ...]. Currently only supports True.

    Returns:
        TensorDict with batch_size=[size, ...] containing valid data
    """
    if not self._initialized or self.size == 0:
        raise ValueError("Buffer is empty")

    if not flatten:
        raise NotImplementedError("Only flatten=True is supported for VLABuffer")

    # Return first 'size' elements (valid data)
    # Note: Data is in insertion order up to write_pos, then wraps
    if self.size < self.buffer_size:
        # Buffer not yet full, data is [0:size]
        return self.buffer[: self.size]
    else:
        # Buffer full, need to rearrange to maintain temporal order
        # Oldest data is at write_pos, newest at write_pos-1
        indices = (
            torch.arange(
                self.write_pos,
                self.write_pos + self.buffer_size,
                device=self.device,
            )
            % self.buffer_size
        )
        return self.buffer[indices]

def clear(self) -> None:
    """Clear buffer (reset pointers, keep pre-allocated memory)."""
    self.write_pos = 0
    self.size = 0
    # Keep buffer allocated for reuse

def __len__(self) -> int:
    """Return current number of valid transitions."""
    return self.size

def is_full(self) -> bool:
    """Check if buffer is at full buffer_size."""
    return self.size >= self.buffer_size
```
The VLABuffer.get() and is_full() methods are called from the main thread while AsyncCollector writes to the buffer from a background thread, but these methods lack thread safety. The read of self.size and self.write_pos could return inconsistent values if a write is in progress. Additionally, buffer.get() performs complex operations (checking size, slicing buffer) that should be atomic with respect to concurrent writes. Consider adding thread synchronization or document that external locking is required.
I've added a lock for `get`/`is_full` for thread safety.
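The locking pattern discussed here can be sketched on a plain list-backed ring buffer (illustrative only, not the actual VLABuffer implementation): every read of `write_pos`/`size` and every write happens under one lock, so `get()` always sees a consistent snapshot.

```python
import threading

class LockedRingBuffer:
    """Circular FIFO buffer with a single lock guarding all pointer state."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._data = [None] * capacity
        self._write_pos = 0
        self._size = 0
        self._lock = threading.Lock()

    def add(self, item) -> None:
        with self._lock:
            self._data[self._write_pos] = item
            self._write_pos = (self._write_pos + 1) % self.capacity
            self._size = min(self._size + 1, self.capacity)

    def get(self) -> list:
        # Snapshot under the lock so size/write_pos cannot change mid-read.
        with self._lock:
            if self._size < self.capacity:
                return list(self._data[: self._size])
            # Full: the oldest element sits at write_pos, so rotate.
            idx = [(self._write_pos + i) % self.capacity for i in range(self.capacity)]
            return [self._data[i] for i in idx]

    def is_full(self) -> bool:
        with self._lock:
            return self._size >= self.capacity

buf = LockedRingBuffer(3)
for x in [1, 2, 3, 4]:
    buf.add(x)
print(buf.get())  # [2, 3, 4]  (oldest-to-newest after wraparound)
```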
```python
if deterministic:
    action = mean
else:
    dist = Normal(mean, std)
```
The distribution is created twice when deterministic=False. Line 130 creates dist = Normal(mean, std), then lines 136-137 create it again. This is wasteful. Consider refactoring to create the distribution once and use either dist.mean or dist.sample() based on the deterministic flag.
Suggested change: remove the duplicate `dist = Normal(mean, std)`.
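The single-construction refactor the review asks for can be sketched as follows. A tiny stand-in class is used here so the example stays dependency-free; the real code would use `torch.distributions.Normal`, and `select_action` is an illustrative name:

```python
import random

class Normal:
    """Minimal stand-in for torch.distributions.Normal (illustrative only)."""

    def __init__(self, mean: float, std: float) -> None:
        self.mean, self.std = mean, std

    def sample(self) -> float:
        return random.gauss(self.mean, self.std)

def select_action(mean: float, std: float, deterministic: bool = False) -> float:
    # Build the distribution exactly once, then branch on the flag.
    dist = Normal(mean, std)
    return dist.mean if deterministic else dist.sample()

print(select_action(0.5, 0.1, deterministic=True))  # 0.5
```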
```python
next_value_td = TensorDict(
    {"observation": next_obs_for_td},
    batch_size=next_td.batch_size,
    device=self.device,
)
self.policy.get_value(next_value_td)
next_td["value"] = next_value_td["value"]
```
The policy is accessed from both the background collector thread (lines 145-146, 200) and potentially from the main training thread during algorithm.update(). PyTorch tensors and models are not thread-safe by default. Concurrent access to the policy parameters during forward passes and gradient updates can lead to race conditions and corrupted gradients. Consider using locks to synchronize policy access, or ensure the policy is not being updated while the collector is running (e.g., by stopping collection during training).
Suggested change:

```python
# Protect policy access with lock to avoid races with training thread
with self._lock:
    next_value_td = TensorDict(
        {"observation": next_obs_for_td},
        batch_size=next_td.batch_size,
        device=self.device,
    )
    self.policy.get_value(next_value_td)
    next_td["value"] = next_value_td["value"]
```
```python
losses = self.algorithm.update(data)
self._log_train(losses)
```
After buffer.get() is called on line 238, the VLABuffer is not cleared (unlike RolloutBuffer which auto-clears). Since the buffer is full (size == buffer_size), the is_full() check on line 232 will immediately return True in the next iteration, causing the training loop to repeatedly train on the same data without waiting for new transitions. The buffer should be cleared after get(), or the is_full() logic should be modified to track whether data has been consumed.
Suggested change:

```python
# Clear async buffer after consumption to avoid retraining on stale data
if hasattr(self.buffer, "clear"):
    self.buffer.clear()
```
```python
# Prepare next iteration - use the converted TensorDict
current_td = next_obs_td
```
The collector does not handle episode resets when done=True. After an episode terminates (done flag is set), the environment should be reset to get a fresh initial observation for the next episode. Currently, the collector continues using next_obs even after termination, which could contain stale data. Most RL environments auto-reset on episode end, but this should be made explicit or documented as a requirement.
Suggested change:

```python
# Prepare next iteration:
# - if episode is done, reset env to get a fresh initial observation
# - otherwise, continue from next_obs_td
if done.any():
    reset_result = self.env.reset()
    # Support both Gym/Gymnasium-style (obs, info) and plain-obs resets
    if isinstance(reset_result, tuple):
        reset_obs = reset_result[0]
    else:
        reset_obs = reset_result
    current_td = dict_to_tensordict(reset_obs, self.device)
else:
    current_td = next_obs_td
```
```python
# Store complete transition
rollout_list.append(current_td.clone())
```
Calling .clone() on every transition creates a full copy of the TensorDict including all nested tensors, which can be memory-intensive for large rollouts. Since current_td is reassigned to next_obs_td on line 122 (which is a fresh TensorDict), the clone may be unnecessary. Consider whether a shallow copy or reference would suffice, or document why deep cloning is required here.
Suggested change:

```python
# Store complete transition (no clone needed: current_td is not mutated afterwards)
rollout_list.append(current_td)
```
I don't think so. Every loop does `current_td["next"] = next_td` and then `current_td = next_obs_td`. If we don't use `clone()`, every appended element is the same TensorDict reference, and the next loop overwrites its contents. As a result, every entry in `rollout_list` points to the same modified data.
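The aliasing hazard the reply describes can be demonstrated with plain dicts (TensorDicts behave analogously with respect to object identity): appending a reference that is later mutated means every list entry shows the final state, while appending a copy preserves each step.

```python
# Minimal aliasing demo: append the same mutable object vs. a per-step copy.
rollout_no_clone, rollout_cloned = [], []
current = {"step": 0}
for step in range(3):
    current["step"] = step                # in-place mutation each iteration
    rollout_no_clone.append(current)      # appends the SAME reference
    rollout_cloned.append(dict(current))  # appends an independent copy

print([td["step"] for td in rollout_no_clone])  # [2, 2, 2]
print([td["step"] for td in rollout_cloned])    # [0, 1, 2]
```

In TensorDict terms, `dict(current)` plays the role of `current_td.clone()`: without it, later in-place writes corrupt every stored transition.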
```python
# Update global step based on collected data (main thread only)
batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0
self.global_step += batch_size
```
The global_step update in async mode only counts batch_size from the returned data (line 242), not the actual number of environment steps taken. Since VLABuffer is continuously being written to by AsyncCollector (which tracks steps in _step_count), the global_step will not accurately reflect the total number of environment interactions. Consider synchronizing global_step with the collector's _step_count, or documenting this discrepancy.
Suggested change:

```python
# Update global step.
# Prefer the collector's step count (actual env interactions) if available,
# otherwise fall back to counting processed batch size.
batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0
steps_from_collector = getattr(collector, "_step_count", None)
if isinstance(steps_from_collector, int) and steps_from_collector > self.global_step:
    self.global_step = steps_from_collector
else:
    self.global_step += batch_size
```
```diff
@@ -166,30 +154,113 @@ def on_step(obs, actions, reward, done, info, next_obs):
     self.curr_ret[done_idx] = 0
     self.curr_len[done_idx] = 0

     # Update global step and observation
     # next_obs is already flattened in algorithm's collect_rollout
     self.obs = next_obs
     self.global_step += next_obs.shape[0]

-    if isinstance(info, dict):
-        rewards_dict = info.get("rewards")
-        metrics_dict = info.get("metrics")
-        self._log_scalar_dict("rewards", rewards_dict)
-        self._log_scalar_dict("metrics", metrics_dict)
+    # Log environment metrics
+    if isinstance(env_info, dict):
+        rewards_dict = env_info.get("rewards")
+        metrics_dict = env_info.get("metrics")
+        log_dict = {}
+        log_dict.update(pack_log_dict("rewards", rewards_dict))
+        log_dict.update(pack_log_dict("metrics", metrics_dict))
+        if log_dict and self.use_wandb:
+            wandb.log(log_dict, step=self.global_step)
```
The on_step callback modifies shared state (self.curr_ret, self.curr_len, self.ret_window, self.len_window, self.global_step) without thread synchronization. In async mode, this callback runs in the AsyncCollector background thread while the main thread could be accessing these same variables (e.g., in _log_train). This can cause race conditions and data corruption. Use threading.Lock to protect access to these shared variables, or ensure they're only accessed from one thread.
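One way to address this, sketched below with illustrative names: gather the episode-tracking fields into a small object whose every read and write goes through one `threading.Lock`, so the callback (background thread) and the logger (main thread) never see half-updated state.

```python
import threading

class EpisodeStats:
    """Shared episode-tracking state guarded by a single lock."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.global_step = 0
        self.ret_window: list = []

    def on_step(self, num_envs: int, episode_returns: list) -> None:
        # Called from the collector thread for every env step.
        with self._lock:
            self.global_step += num_envs
            self.ret_window.extend(episode_returns)
            self.ret_window = self.ret_window[-100:]  # keep a sliding window

    def snapshot(self) -> tuple:
        # Called from the main thread when logging.
        with self._lock:
            mean_ret = (
                sum(self.ret_window) / len(self.ret_window)
                if self.ret_window
                else 0.0
            )
            return self.global_step, mean_ret

stats = EpisodeStats()
stats.on_step(num_envs=4, episode_returns=[1.0, 3.0])
print(stats.snapshot())  # (4, 2.0)
```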
```python
def collect(self, **kwargs) -> TensorDict:
    """Collect data from environment.
```
Overridden method signature does not match the call site, which passes too many arguments, including an argument named `num_steps`. The overriding method `SyncCollector.collect` matches the call.
Suggested change:

```python
def collect(self, num_steps: int, **kwargs) -> TensorDict:
    """Collect data from environment.

    Args:
        num_steps: Number of steps to collect.
    """
```
```python
def collect(self, num_steps: int) -> TensorDict:
    """Collect a synchronous rollout.

    Args:
        num_steps: Number of steps to collect

    Returns:
        TensorDict with batch_size=[T, N] containing full rollout
    """
```
This method requires 2 positional arguments, whereas overridden BaseCollector.collect requires 1.
Suggested change:

```python
def collect(self, num_steps: int | None = None) -> TensorDict:
    """Collect a synchronous rollout.

    Args:
        num_steps: Number of steps to collect.

    Returns:
        TensorDict with batch_size=[T, N] containing full rollout
    """
    if num_steps is None:
        raise TypeError("num_steps must be provided for SyncCollector.collect()")
```
yangchen73 left a comment:
- Need to use `clone()`
- Add locks
Pull request overview
Copilot reviewed 26 out of 26 changed files in this pull request and generated 14 comments.
```python
self.global_step += data.batch_size[0] if data.batch_size else 0

losses = self.algorithm.update(data)
self._log_train(losses)
```
In async mode, the collector thread calls policy.forward/get_value while the main thread simultaneously runs algorithm.update() (backprop + optimizer.step) on the same policy. PyTorch modules/optimizers are not thread-safe, and concurrent CUDA kernels/parameter updates can lead to nondeterminism or crashes. Consider pausing collection during update, protecting policy access with a lock, or using a separate inference copy of the policy for the collector (periodically synced).
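The "separate inference copy" option can be sketched with plain dicts standing in for module state dicts (all names here are illustrative, not the PR's API): the trainer mutates `train_params` freely, and the collector only ever reads an atomically published snapshot.

```python
import copy
import threading

class PolicySync:
    """Publish/read pattern: trainer updates train_params, collector reads a
    periodically synced, lock-protected snapshot."""

    def __init__(self, params: dict) -> None:
        self._lock = threading.Lock()
        self.train_params = dict(params)   # owned by the training thread
        self._infer_params = dict(params)  # read-only view for the collector

    def publish(self) -> None:
        """Called by the trainer after optimizer.step() to push new weights."""
        with self._lock:
            self._infer_params = copy.deepcopy(self.train_params)

    def inference_view(self) -> dict:
        """Called by the collector before each forward pass."""
        with self._lock:
            return self._infer_params

sync = PolicySync({"w": 0.0})
sync.train_params["w"] = 1.5  # simulated gradient update
sync.publish()
print(sync.inference_view()["w"])  # 1.5
```

With real PyTorch modules, `publish()` would copy `policy.state_dict()` into a second module used only for inference, which avoids concurrent forward/backward on the same parameters entirely.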
````markdown
### VLABuffer (Async)

Circular FIFO buffer:
````

…followed by a Python example.
The guide describes VLABuffer as a circular FIFO buffer with step-level add(transition) and shows buffer.get(flatten=True), but the implementation only accepts full rollouts via add_rollout() and PPO’s GAE requires the unflattened [N, T] layout. Please update this section to match the actual VLABuffer API and required shapes, otherwise users will hit AttributeErrors or compute incorrect advantages.
```python
log_dict.update(pack_log_dict("rewards", rewards_dict))
log_dict.update(pack_log_dict("metrics", metrics_dict))
if log_dict and self.use_wandb:
    wandb.log(log_dict, step=self.global_step)
```
on_step_callback logs to W&B using self.global_step, but global_step is only incremented once per rollout in _train_sync/_train_async. Because on_step_callback is invoked for every env step, these logs will repeatedly use the same step value (overwriting or producing a flat x-axis). Consider incrementing global_step inside the callback (e.g., by num_envs per env.step) or passing an explicit per-step counter into the callback for logging.
Suggested change:

```python
# Use a dedicated per-environment-step counter for W&B logging.
# Lazily initialize it so we don't depend on __init__ details.
env_log_step = getattr(self, "_env_log_step", 0)
# Increment by the number of parallel environments (reward batch size).
if isinstance(reward, torch.Tensor):
    env_log_step += reward.shape[0]
else:
    env_log_step += 1
self._env_log_step = env_log_step
wandb.log(log_dict, step=env_log_step)
```
```diff
 @abstractmethod
-def get_action(
-    self, obs: torch.Tensor, deterministic: bool = False
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """Sample an action from the policy.
+def forward(self, tensordict: TensorDict) -> TensorDict:
+    """Forward pass that adds action to the input tensordict (in-place).
```
Policy.forward() in the abstract base class does not accept a deterministic argument, but Trainer._eval_once calls self.policy.forward(..., deterministic=True) and concrete implementations (ActorCritic/VLAPolicy) already support it. This makes the interface inconsistent and can cause TypeErrors for any other Policy implementation that follows the base signature. Consider updating the abstract method signature to include deterministic: bool = False (and documenting the expected behavior).
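The suggested signature fix can be sketched like this, with a plain dict standing in for TensorDict so the example stays dependency-free (`DummyPolicy` is illustrative, not a class from the PR):

```python
from abc import ABC, abstractmethod

class Policy(ABC):
    """Abstract policy whose forward() matches how callers actually use it."""

    @abstractmethod
    def forward(self, tensordict: dict, deterministic: bool = False) -> dict:
        """Add "action" (and related keys) to tensordict in place.

        Args:
            tensordict: must contain at least an "observation" entry.
            deterministic: if True, return the distribution mean instead of a sample.
        """

class DummyPolicy(Policy):
    def forward(self, tensordict: dict, deterministic: bool = False) -> dict:
        # Toy behavior: mean action is 0.0, "sampled" action is 0.1.
        tensordict["action"] = 0.0 if deterministic else 0.1
        return tensordict

td = DummyPolicy().forward({"observation": [1.0]}, deterministic=True)
print(td["action"])  # 0.0
```

Because the abstract method now declares `deterministic`, any subclass following the base signature works with `Trainer._eval_once`'s `deterministic=True` call instead of raising a TypeError.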
```python
while step < total:
    rollout = collector.collect(num_steps=2048)
    buffer.add(rollout)
    data = buffer.get(flatten=True)
    losses = algorithm.update(data)
```
This workflow calls buffer.get(flatten=True) and passes the flattened result to algorithm.update(). With the current PPO implementation, flattened input is treated as a [size, 1] rollout, making GAE effectively run with T=1 and producing incorrect advantages/targets. Update the example to pass the unflattened [N, T] rollout into update (flatten only inside PPO for minibatching).
```diff
 "num_envs": 64,
 "iterations": 1000,
-"rollout_steps": 1024,
+"buffer_size": 1024,
 "eval_freq": 2,
 "save_freq": 200,
```
This config was updated to use trainer.buffer_size, but train.py now requires policy.action_dim to be present. As-is, running this config will raise "Missing 'action_dim' in policy config". Consider adding an explicit action_dim to the policy block (and optionally trainer.model_type) so the example remains runnable.
```python
# Initialize observation and get num_envs (needed for VLA buffer)
obs, _ = env.reset()
self.obs_tensordict = dict_to_tensordict(obs, device)
num_envs = self.obs_tensordict.batch_size[0]
```
Trainer.init calls env.reset() to infer num_envs / seed obs, but both SyncCollector/AsyncCollector (via BaseCollector) also call env.reset(). This extra reset can be expensive and can change the initial state/episode accounting. Since the env exposes num_envs (used elsewhere in train.py), consider using env.num_envs (and/or env.device) instead of resetting here, and avoid storing an unused obs_tensordict.
Suggested change:

```python
# Initialize num_envs without forcing a reset when possible
if hasattr(env, "num_envs"):
    num_envs = env.num_envs
    # No need to create an initial obs_tensordict here; collectors will reset the env.
    self.obs_tensordict = None
else:
    # Fallback for environments that do not expose num_envs
    obs, _ = env.reset()
    self.obs_tensordict = dict_to_tensordict(obs, device)
    num_envs = self.obs_tensordict.batch_size[0]
```
```python
"""Helper utilities for RL training.

This module provides utility functions for RL algorithms.
"""
```
flatten_dict_observation() was removed from this module, but embodichain/lab/gym/envs/base_env.py still imports it to build flattened_observation_space. This will raise ImportError at runtime. Either restore a compatible flatten_dict_observation helper here (for backward compatibility) or update base_env.py in this PR to use the new TensorDict-based utilities.
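A backward-compatible helper could look roughly like the sketch below. Python lists stand in for tensors so the example is dependency-free; the real helper would operate on torch tensors and reshape each leaf to `(num_envs, -1)` before concatenating. The function name matches the removed helper, but the body is an assumption about its behavior inferred from the docstring quoted in this PR.

```python
def flatten_dict_observation(obs: dict) -> list:
    """Walk a nested observation dict, collect leaf values in sorted key
    order, and concatenate them into one flat sequence."""
    parts = []
    for key in sorted(obs):
        value = obs[key]
        if isinstance(value, dict):
            parts.extend(flatten_dict_observation(value))
        else:
            parts.extend(value)
    return parts

obs = {"proprio": {"qpos": [0.1, 0.2]}, "goal": [1.0]}
print(flatten_dict_observation(obs))  # [1.0, 0.1, 0.2]
```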
```diff
-__all__ = ["RolloutBuffer"]
+Provides two buffer implementations:
+- RolloutBuffer: Standard PPO buffer (single rollout, use and discard)
+- VLABuffer: VLA buffer (FIFO multi-rollout accumulation for slow inference)
```
The buffer module docstring says VLABuffer is “FIFO multi-rollout accumulation”, but VLABuffer currently stores exactly one rollout (present/None). This mismatch is likely to confuse users; update the docstring to reflect the current behavior, or adjust VLABuffer to match the documented semantics.
Suggested change:

```diff
-- VLABuffer: VLA buffer (FIFO multi-rollout accumulation for slow inference)
+- VLABuffer: VLA buffer (single-rollout accumulation optimized for slow inference)
```
```python
# Ensure 2D format [T, N] for GAE computation
if len(rollout.batch_size) == 1:
    rollout = rollout.unsqueeze(1)  # [size] -> [size, 1]
```
PPO.update() claims to support receiving a flattened rollout (batch_size=[size]) by doing rollout.unsqueeze(1), but that turns it into [size, 1] and makes GAE run with T=1 (incorrect unless the original rollout length was 1). Either require callers to pass an unflattened [N, T] / [T, N] rollout, or reshape using known (N, T) metadata (e.g., carry rollout_length/num_envs) before computing GAE.
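The "reshape using known metadata" option can be sketched as follows, with nested lists standing in for tensors (`restore_rollout_shape` is an illustrative name; the real code would call `tensor.view(num_envs, -1, ...)` or carry `rollout_length` alongside the data):

```python
def restore_rollout_shape(flat: list, num_envs: int) -> list:
    """Rebuild an [N, T] structure from a flattened rollout using known
    num_envs metadata, instead of unsqueezing to [size, 1] (which would make
    GAE run with T=1)."""
    size = len(flat)
    if size % num_envs != 0:
        raise ValueError(f"flat size {size} not divisible by num_envs {num_envs}")
    t = size // num_envs
    # Assumes env-major flattening: all of env 0's steps, then env 1's, etc.
    return [flat[i * t : (i + 1) * t] for i in range(num_envs)]

# 2 envs x 3 steps, flattened env-major:
print(restore_rollout_shape([1, 2, 3, 4, 5, 6], num_envs=2))  # [[1, 2, 3], [4, 5, 6]]
```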
Pull request overview
Copilot reviewed 26 out of 26 changed files in this pull request and generated 15 comments.
```diff
 def dict_to_tensordict(obs_dict: dict, device: torch.device) -> TensorDict:
-    """Flatten hierarchical TensorDict observations from ObservationManager.
-
-    Recursively traverse nested TensorDicts, collect all tensor values,
-    flatten each to (num_envs, -1), and concatenate in sorted key order.
+    """Convert a nested observation dict into a TensorDict.
+
+    Args:
+        obs_dict: Nested observation dictionary returned by the environment.
+        device: Device to place tensors on.
+
+    Returns:
+        TensorDict with an outer ``"observation"`` key.
     """
+
+    def _recursive_convert(data: dict) -> dict:
+        result = {}
+        for key, value in data.items():
+            if isinstance(value, dict):
+                result[key] = _recursive_convert(value)
+            elif isinstance(value, torch.Tensor):
+                result[key] = value.to(device)
+            else:
+                result[key] = torch.tensor(value, device=device)
+        return result
```
dict_to_tensordict() assumes the environment returns a nested Python dict, but BaseEnv.get_obs()/reset() returns a TensorDict. Passing a TensorDict here will iterate items and then hit the else: torch.tensor(value) branch for nested TensorDict leaves, raising an error. Consider accepting obs as TensorDict | dict and, when it's already a TensorDict, just move it to the target device and wrap it under the outer "observation" key (and also handle nested TensorDict values during recursion).
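The suggested `TensorDict | dict` dispatch can be sketched like this. A small dict subclass stands in for `tensordict.TensorDict` so the example runs without dependencies; `to_observation_td` is an illustrative name, and the tensor-conversion branch is elided (the real code would call `torch.as_tensor(value, device=device)`):

```python
class TensorDictLike(dict):
    """Stand-in for tensordict.TensorDict (illustrative only)."""

    def to(self, device):
        # Real TensorDict.to(device) moves every leaf tensor; here it's a no-op.
        return self

def _convert(value, device):
    if isinstance(value, TensorDictLike):
        return value.to(device)  # already structured: just move it
    if isinstance(value, dict):
        return {k: _convert(v, device) for k, v in value.items()}
    return value  # real code: torch.as_tensor(value, device=device)

def to_observation_td(obs, device=None) -> dict:
    """Accept TensorDict-like or plain nested dict observations and wrap
    the result under the outer "observation" key."""
    return {"observation": _convert(obs, device)}

print(to_observation_td({"qpos": [0.1]}))  # {'observation': {'qpos': [0.1]}}
```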
```python
action = current_td["action"]
action_type = getattr(self.env, "action_type", "delta_qpos")
action_dict = {action_type: action}
```
AsyncCollector builds action_dict using getattr(self.env, "action_type", "delta_qpos"), but EmbodiedEnv typically routes actions through env.action_manager and expects the active term name as the dict key. If the key doesn't match, ActionManager.process_action() will raise. Prefer am = getattr(self.env, "action_manager", None) and use am.action_type when present.
````python
def load_vla_model(
    model_path: str,
    model_class: Optional[str] = None,
    model_config: Optional[dict] = None,
    device: torch.device = torch.device("cpu"),
) -> nn.Module:
    """Load VLA model from checkpoint.

    This function should be implemented by the VLA team to load their
    pretrained VLA model (vision encoder, language model, action head, etc.).

    The returned module should have methods:
    - forward(obs) -> (action, log_prob, value)
    - get_value(obs) -> value
    - evaluate_actions(obs, actions) -> (log_prob, entropy, value)

    Args:
        model_path: Path to checkpoint file
        model_class: Fully qualified class name for VLA model
        model_config: Configuration dict for model initialization
        device: Device to load model on

    Returns:
        Initialized VLA model module

    Example implementation by VLA team:
    ```python
    def load_vla_model(model_path, model_class, model_config, device):
        import importlib

        # Import VLA model class
        module_name, class_name = model_class.rsplit(".", 1)
        module = importlib.import_module(module_name)
        ModelClass = getattr(module, class_name)

        # Initialize model
        model = ModelClass(**model_config)

        # Load checkpoint
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])

        model.to(device)
        model.eval()

        return model
    ```
    """
    raise NotImplementedError(
        "load_vla_model() must be implemented. "
        f"Model path: {model_path}, class: {model_class}, config: {model_config}"
    )
````
load_vla_model() unconditionally raises NotImplementedError, but the repo now includes a VLA example config and documentation implying VLA training works out of the box. This will cause any run using policy.name="vla" to fail immediately. Either provide a default generic implementation (e.g., importlib-load model_class + torch.load state_dict as shown in the docstring), or clearly mark VLA support as requiring downstream customization and avoid registering/building the policy unless a loader is provided.
## Data flow (TensorDict)

```
Environment ──▶ Collector ──▶ Algorithm ──▶ Policy
     │              │             │            │
     │         TensorDict    TensorDict    Parameters
     │           [T, N]        [batch]       Update
     │              │             │            │
     └──────────────┴─────────────┴────────────┘

TensorDict structure:
{
    "observation": Tensor or nested TensorDict,
    "action": Tensor[T, N, action_dim],
    "reward": Tensor[T, N, 1],
    "done": Tensor[T, N, 1],
    "value": Tensor[T, N, 1],
    "sample_log_prob": Tensor[T, N, 1],
    "advantage": Tensor[T, N, 1],  # added after GAE computation
    "return": Tensor[T, N, 1],     # added after GAE computation
}
```
This architecture doc states rollouts are shaped [T, N] and uses a key named "return", but the collectors in this PR stack as [N, T] (batch-first) and GAE writes "value_target". Please update the documented TensorDict layout/keys to match the actual implementation to avoid misleading readers.
```diff
 class BaseAlgorithm:
-    """Base class for RL algorithms.
+    """Base class for RL algorithms following TorchRL conventions.

-    Algorithms must implement buffer initialization, rollout collection, and
-    policy update. Trainer depends only on this interface to remain
-    algorithm-agnostic.
+    Algorithms implement policy updates using TensorDict.
+    Data collection is handled separately by Collector classes
+    (SyncCollector/AsyncCollector).
     """

     device: torch.device

-    def initialize_buffer(
-        self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int
-    ) -> None:
-        """Initialize internal buffer(s) required by the algorithm."""
-        raise NotImplementedError
-
-    def collect_rollout(
-        self,
-        env,
-        policy,
-        obs: torch.Tensor,
-        num_steps: int,
-        on_step_callback: Callable | None = None,
-    ) -> Dict[str, Any]:
-        """Collect trajectories and return logging info (e.g., reward components)."""
-        raise NotImplementedError
-
-    def update(self) -> Dict[str, float]:
-        """Update policy using collected data and return training losses."""
-        raise NotImplementedError
+    def update(self, rollout: TensorDict) -> Dict[str, float]:
+        """Update policy using collected rollout data.
+
+        Args:
+            rollout: TensorDict containing collected rollout data from Collector
+                Expected batch_size format: [T, N] for on-policy algorithms
+                where T is trajectory length and N is number of environments
+
+        Returns:
+            Dictionary of training metrics (losses, learning stats, etc.)
+        """
+        raise NotImplementedError
```
BaseAlgorithm now documents an update(self, rollout: TensorDict) interface, but the repository still registers GRPO (and GRPO implements the older initialize_buffer/collect_rollout/update() interface). Selecting algorithm.name="grpo" via build_algo will break at runtime once Trainer always calls algorithm.update(data) with an argument. Either update GRPO to the new interface or adjust the registry/trainer to support both interfaces.
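If GRPO cannot be migrated in this PR, a small trainer-side shim could bridge the two interfaces. A sketch (all class and function names below are illustrative; the real dispatch would live in the Trainer):

```python
import inspect

def run_update(algorithm, rollout):
    """Call the new update(rollout) when the algorithm supports it, otherwise
    fall back to the legacy zero-argument update()."""
    # Bound methods exclude self, so the new interface exposes >= 1 parameter.
    params = inspect.signature(algorithm.update).parameters
    if len(params) >= 1:
        return algorithm.update(rollout)
    return algorithm.update()

class NewAlgo:
    def update(self, rollout):
        return {"loss": 0.1}

class LegacyAlgo:
    def update(self):
        return {"loss": 0.2}

print(run_update(NewAlgo(), {"obs": []}))     # {'loss': 0.1}
print(run_update(LegacyAlgo(), {"obs": []}))  # {'loss': 0.2}
```

Explicitly updating GRPO to the new interface is cleaner long-term; the shim only keeps the registry from breaking in the interim.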
```python
action = current_td["action"]
action_type = getattr(self.env, "action_type", "delta_qpos")
action_dict = {action_type: action}
```
There was a problem hiding this comment.
Action key selection here ignores the environment's ActionManager. EmbodiedEnv exposes env.action_manager.action_type as the active term name; falling back to a hard-coded "delta_qpos" (or missing env.action_type) can produce an action dict with the wrong key and cause ActionManager.process_action() to raise (no matching key). Use am = getattr(self.env, "action_manager", None) and prefer am.action_type when available.
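The suggested fallback chain can be sketched as a small helper. The attribute names `action_manager` / `action_type` follow the review comment and may differ from the actual EmbodiedEnv API:

```python
def resolve_action_type(env, default: str = "delta_qpos") -> str:
    """Resolve the active action term name, preferring the ActionManager.

    Sketch of the suggested fix; attribute names are assumptions from the
    review comment, not verified against the repository.
    """
    am = getattr(env, "action_manager", None)
    if am is not None and getattr(am, "action_type", None):
        return am.action_type
    return getattr(env, "action_type", default)

# Minimal stand-in environments for illustration:
class EnvWithManager:
    class _AM:
        action_type = "joint_position"
    action_manager = _AM()

class BareEnv:
    pass
```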
```python
    @abstractmethod
    def get_action(
        self, obs: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample an action from the policy.

    def forward(self, tensordict: TensorDict) -> TensorDict:
        """Forward pass that adds action to the input tensordict (in-place).

        This is the main inference method following TorchRL conventions.

        Args:
            obs: Observation tensor of shape (batch_size, obs_dim)
            deterministic: If True, return the mean action; otherwise sample
            tensordict: Input TensorDict containing at minimum:
                - "observation": Observation tensor or nested TensorDict

        Returns:
            Tuple of (action, log_prob, value):
                - action: Sampled action tensor of shape (batch_size, action_dim)
                - log_prob: Log probability of the action, shape (batch_size,)
                - value: Value estimate, shape (batch_size,)
            The same TensorDict (modified in-place) with added fields:
                - "action": Sampled action tensor
                - "sample_log_prob": Log probability of the sampled action
                - "value": Value estimate (optional, for actor-critic)
                - "loc": Distribution mean (optional, for continuous actions)
                - "scale": Distribution std (optional, for continuous actions)
        """
        raise NotImplementedError
```
Policy.forward() is declared without a deterministic parameter, but callers in this PR pass deterministic=True during evaluation (Trainer._eval_once) and several concrete Policy implementations also define forward(..., deterministic: bool = False). As-is, a Policy implementation that follows the abstract signature will raise TypeError when used for evaluation. Update the abstract method signature (and docstring) to include deterministic: bool = False so the interface matches actual usage.
```python
"""
Buffer module for RL training.

__all__ = ["RolloutBuffer"]

Provides two buffer implementations:
- RolloutBuffer: Standard PPO buffer (single rollout, use and discard)
- VLABuffer: VLA buffer (FIFO multi-rollout accumulation for slow inference)
"""
```
The module docstring claims VLABuffer is a "FIFO multi-rollout accumulation" buffer, but the current VLABuffer implementation stores only a single rollout (self._rollout) and overwrites/clears it on get(). Either update the docstring to reflect the single-rollout behavior, or implement the FIFO accumulation described here.
### VLABuffer (Async)

Circular FIFO buffer:

```python
from embodichain.agents.rl.buffer import VLABuffer

buffer = VLABuffer(buffer_size=4096, device=device)
buffer.add(transition)  # Single step
data = buffer.get(flatten=True)  # [buffer_size, ...] when full
```
This guide describes VLABuffer as a circular FIFO buffer that accepts single-step buffer.add(transition) calls, but the implementation in this PR requires full rollouts via add_rollout() and only stores a single rollout at a time. The example code here should be updated to match the actual API/behavior (or the buffer implementation should be adjusted to match the guide).
### Standard

```python
collector = SyncCollector(env, policy, device, callback)
while step < total:
    rollout = collector.collect(num_steps=2048)
    buffer.add(rollout)
    data = buffer.get(flatten=True)
    losses = algorithm.update(data)
```

### VLA

```python
collector = AsyncCollector(env, policy, buffer, device, callback)
collector.start()
while step < total:
    while not buffer.is_full():
        time.sleep(0.1)
    data = buffer.get(flatten=True)
    losses = algorithm.update(data)
collector.stop()
```
The workflow examples call buffer.get(flatten=True) and then pass the flattened TensorDict to algorithm.update(data). With the current PPO.update implementation, a 1D TensorDict is reshaped/unsqueezed in a way that loses the original [N, T] trajectory structure, so GAE will be incorrect. Update the guide to use flatten=False when retrieving rollouts for on-policy algorithms (or update PPO.update to reject/handle flattened inputs explicitly).
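The information loss can be shown with plain lists — once trajectories are concatenated, the env boundaries that GAE needs are gone:

```python
# N=2 environments, T=3 steps, batch-first [N, T] rewards.
rewards = [[1.0, 2.0, 3.0],      # env 0 trajectory
           [10.0, 20.0, 30.0]]   # env 1 trajectory

# flatten=True concatenates trajectories into a single [N*T] axis.
flat = [r for traj in rewards for r in traj]

# Treating the flat data as one long trajectory (what an unconditional
# reshape to [N*T, 1] implies) puts env 1's first step immediately after
# env 0's last step, so GAE would bootstrap across the env boundary.
as_single_traj = [[r] for r in flat]
# The original [N, T] layout cannot be recovered without knowing N and T.
```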
Pull request overview
Copilot reviewed 27 out of 27 changed files in this pull request and generated 14 comments.
```python
if (
    self.eval_freq > 0
    and self.eval_env is not None
    and self.global_step % self.eval_freq == 0
):
    self._eval_once(num_episodes=self.num_eval_episodes)
```
eval_freq is checked via self.global_step % self.eval_freq == 0, but global_step is incremented in large chunks (num_envs * num_steps per rollout). This makes evaluation effectively never run unless eval_freq happens to exactly divide those chunk sizes. Consider tracking next_eval_step (or last_eval_step) and triggering when global_step >= next_eval_step instead of relying on modulo equality.
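A threshold-based schedule as suggested could look like this hypothetical helper (not in the repo), which fires whenever `global_step` crosses the next multiple of `freq` regardless of chunk size:

```python
class IntervalTrigger:
    """Fires when global_step crosses the next multiple of `freq`, even if
    steps advance in large rollout-sized chunks. Illustrative sketch."""

    def __init__(self, freq: int):
        self.freq = freq
        self.next_step = freq

    def should_fire(self, global_step: int) -> bool:
        if self.freq <= 0 or global_step < self.next_step:
            return False
        # Skip past every multiple already covered by this chunk,
        # so one chunk triggers at most one evaluation.
        while self.next_step <= global_step:
            self.next_step += self.freq
        return True
```

The trainer would then call `eval_trigger.should_fire(self.global_step)` in place of the modulo check, with an analogous trigger for checkpointing.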
```json
"eval_freq": 200,
"save_freq": 200,
```
With the current Trainer logic, eval_freq/save_freq are interpreted in environment steps (global_step increments by buffer_size * num_envs each rollout). Values like 200 will almost never satisfy global_step % eval_freq == 0 (e.g., 64 envs * 1024 steps = 65536 per rollout). Either adjust the config to use step-based values that align with global_step, or update Trainer to a threshold-based schedule.
```python
class Policy(nn.Module, ABC):
    """Abstract base class that all RL policies must implement.

    A Policy:
    - Encapsulates neural networks that are trained by RL algorithms
    - Handles internal computations (e.g., network output → distribution)
    - Provides a uniform interface for algorithms (PPO, SAC, etc.)
    - Uses TensorDict for all inputs and outputs (no tensor fallback)
    """
```
The docstring claims "no tensor fallback", but the implementations in this PR (e.g., ActorCritic/ActorOnly) accept both nested TensorDicts and plain tensors under tensordict["observation"]. Either enforce the TensorDict-only contract or adjust the base class docs to reflect the supported inputs.
```python
# Evaluation (pause collector during eval)
if (
    self.eval_freq > 0
    and self.eval_env is not None
    and self.global_step % self.eval_freq == 0
):
    collector.stop()
    self._eval_once(num_episodes=self.num_eval_episodes)
    collector.start()

# Checkpoint
if self.global_step % self.save_freq == 0:
    self.save_checkpoint()
```
Same issue as the sync loop: global_step % eval_freq == 0 (and similarly for save_freq) is fragile when global_step increases by rollout-sized jumps. This can prevent evaluation/checkpointing from ever triggering in async mode. Prefer a threshold-based schedule (>= next_eval_step / >= next_save_step) or an explicit iteration counter.
```json
"algorithm": {
    "name": "grpo",
    "cfg": {
        "learning_rate": 1e-5,
        "n_epochs": 4,
        "batch_size": 256,
        "gamma": 0.99,
        "clip_coef": 0.2,
        "ent_coef": 0.001,
        "kl_coef": 0.02,
        "group_size": 4,
        "eps": 1e-8,
        "max_grad_norm": 1.0,
        "truncate_at_first_done": true
```
This config selects "algorithm": {"name": "grpo"}, but GRPO in embodichain/agents/rl/algo/grpo.py still uses the old collect_rollout() / update() (no-args) interface and the old policy.get_action() API. With the new Trainer/Policy TensorDict flow, this config will fail at runtime unless GRPO is migrated to update(rollout: TensorDict) (and collection moved to Collectors) or the algorithm name is changed to one that supports the new interface.
```python
    @abstractmethod
    def get_action(
        self, obs: torch.Tensor, deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample an action from the policy.

    def forward(self, tensordict: TensorDict) -> TensorDict:
        """Forward pass that adds action to the input tensordict (in-place).

        This is the main inference method following TorchRL conventions.

        Args:
            obs: Observation tensor of shape (batch_size, obs_dim)
            deterministic: If True, return the mean action; otherwise sample
            tensordict: Input TensorDict containing at minimum:
                - "observation": Observation tensor or nested TensorDict
```
The abstract Policy.forward() signature doesn't include the deterministic flag, but both Trainer._eval_once() and all concrete policies in this PR call/implement forward(..., deterministic=...). To avoid runtime/type-checking mismatches for future policies, consider adding deterministic: bool = False to the base class method signature (and updating the docstring accordingly).
```python
for module_name in import_modules:
    importlib.import_module(module_name)
```
import_modules allows importing arbitrary module paths from a JSON config, which executes arbitrary Python code at runtime. If configs can come from untrusted sources (or are shared externally), this is a security risk; consider restricting to an allowlist/prefix (e.g., only dexechain.*/embodichain.*), or document clearly that configs must be trusted.
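The allowlist suggestion could be sketched as follows; the prefix tuple is an assumption about the trusted project namespaces:

```python
import importlib

# Assumed trusted project namespaces; adjust to the actual repositories.
ALLOWED_PREFIXES = ("embodichain.", "dexechain.")

def safe_import(module_name: str, allowed: tuple = ALLOWED_PREFIXES):
    """Import a config-specified module only if it sits under a trusted prefix."""
    if not module_name.startswith(allowed):
        raise ValueError(
            f"Refusing to import {module_name!r}: not under allowlist {allowed}"
        )
    return importlib.import_module(module_name)
```

This keeps JSON configs declarative: an attacker-controlled config can only load code that already ships inside the trusted packages.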
### VLABuffer (Async)

Circular FIFO buffer:

```python
from embodichain.agents.rl.buffer import VLABuffer

buffer = VLABuffer(buffer_size=4096, device=device)
buffer.add(transition)  # Single step
data = buffer.get(flatten=True)  # [buffer_size, ...] when full
```

**Circular behavior**: `[T0,T1,T2,T3]` → add T4 → `[T4,T1,T2,T3]` (T0 overwritten)
The VLABuffer section describes a circular FIFO buffer with buffer.add(transition) and overwrite semantics, but the implemented VLABuffer only accepts full rollouts via add_rollout(...) and stores a single rollout (no FIFO, no per-transition add). Please update the guide to match the actual API/behavior, or adjust the implementation to match the documented FIFO semantics.
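For reference, the overwrite order the guide documents corresponds to a write-pointer ring buffer. A minimal sketch of those semantics — illustrative only, since the shipped VLABuffer stores a single rollout instead:

```python
class RingRolloutBuffer:
    """Fixed-capacity ring buffer with the documented overwrite order:
    [T0,T1,T2,T3] -> add T4 -> [T4,T1,T2,T3] (oldest slot overwritten)."""

    def __init__(self, capacity: int):
        self.slots = [None] * capacity
        self.write = 0   # next slot to overwrite (the oldest, once full)
        self.count = 0

    def add(self, transition):
        self.slots[self.write] = transition
        self.write = (self.write + 1) % len(self.slots)
        self.count = min(self.count + 1, len(self.slots))

    def is_full(self) -> bool:
        return self.count == len(self.slots)
```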
```python
        Args:
            rollout: TensorDict with batch_size=[T, N] from collect_rollout()
                OR [size] from VLA buffer

        # Dense logging is handled in Trainer.on_step via info; no aggregation here
        # Call callback for statistics and logging
        if on_step_callback is not None:
            on_step_callback(current_obs, actions, reward, done, env_info, next_obs)

        Returns:
            Dictionary of training metrics
        """
        # Ensure 2D format [T, N] for GAE computation
        if len(rollout.batch_size) == 1:
            rollout = rollout.unsqueeze(1)  # [size] -> [size, 1]

        current_obs = next_obs
        # GAE layout: use config (default False = [N, T] batch-first)
        time_first = self.cfg.rollout_time_first
```
The update() docstring/comments still reference rollouts shaped [T, N], but the collectors/buffers in this PR produce batch-first [N, T] by default (and rollout_time_first=False assumes that). Please update the docstring and the "Ensure 2D format" comment to reflect the actual expected layout, and clarify what 1D rollouts (flattened [N*T]) mean here.
```python
if self.model_type == "vla":
    collector = AsyncCollector(
        env=self.env,
        policy=self.policy,
        buffer=self.buffer,
        device=self.device,
        on_step_callback=self._create_step_callback(),
    )
    self._train_async(collector, total_timesteps)
else:
```
There are existing tests for the standard training pipeline (tests/agents/test_rl.py), but the new async/VLA path (AsyncCollector + VLABuffer + model_type="vla") is untested. Adding at least a minimal smoke test that starts the async collector, waits for buffer.is_full(), runs a single PPO update, and shuts down cleanly would help catch threading/buffer regressions.
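A dependency-free version of such a smoke test can be sketched with stand-in classes (`FakeBuffer` and `FakeAsyncCollector` are hypothetical; the real test would use VLABuffer and AsyncCollector, whose APIs are assumed here from the guide):

```python
import threading
import time

class FakeBuffer:
    """Thread-safe stand-in for VLABuffer's add/is_full surface."""
    def __init__(self, capacity: int):
        self.items, self.capacity = [], capacity
        self._lock = threading.Lock()
    def add(self, item):
        with self._lock:
            if len(self.items) < self.capacity:
                self.items.append(item)
    def is_full(self) -> bool:
        with self._lock:
            return len(self.items) >= self.capacity

class FakeAsyncCollector:
    """Background thread that fills the buffer, mimicking AsyncCollector."""
    def __init__(self, buffer):
        self.buffer = buffer
        self._stop = threading.Event()
        self._thread = None
    def start(self):
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
    def _run(self):
        step = 0
        while not self._stop.is_set():
            self.buffer.add(step)
            step += 1
            time.sleep(0.001)
    def stop(self):
        self._stop.set()
        if self._thread is not None:
            self._thread.join(timeout=1.0)

def smoke_test() -> bool:
    buf = FakeBuffer(capacity=16)
    collector = FakeAsyncCollector(buf)
    collector.start()
    deadline = time.time() + 5.0          # bounded wait instead of busy loop
    while not buf.is_full() and time.time() < deadline:
        time.sleep(0.01)
    collector.stop()                       # must shut down cleanly
    return buf.is_full() and not collector._thread.is_alive()
```

The real test would swap in the PPO update between `is_full()` and `stop()`; the shape above already exercises the start/fill/stop lifecycle where threading regressions tend to appear.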
# RL Training Framework Guide

TensorDict-based RL framework supporting standard PPO and asynchronous VLA training.

## Quick Start

### Configuration

```json
{
  "trainer": { "buffer_size": 2048, "model_type": "standard" },  // or "vla"
  "policy": { "name": "actor_critic" },
  "algorithm": {
    "name": "ppo",
    "cfg": { "learning_rate": 3e-4, "gamma": 0.99, "n_epochs": 10, "batch_size": 64 }
  }
}
```

### Run Training
## Architecture

Components:

## Training Modes

### Standard Mode (Default)

For: Normal models (<100ms inference/step)

Config: `{"trainer": {"model_type": "standard"}}`

Pros: Simple, stable, low memory, no staleness

### VLA Async Mode

For: Large models (>1 sec inference/step)

Config: `{"trainer": {"model_type": "vla"}}`

Pros: 2-3x speedup via parallel collection
Cons: Data staleness, higher memory
## Collectors

### SyncCollector

Collects complete rollout synchronously:

### AsyncCollector

Runs in background thread:

## Buffers

### RolloutBuffer (Standard)

Single-use buffer:

### VLABuffer (Async)

Circular FIFO buffer:

Circular behavior: `[T0,T1,T2,T3]` → add T4 → `[T4,T1,T2,T3]` (T0 overwritten)

## VLA Integration
### 1. Implement Model

### 2. Implement Loading

Edit `embodichain/agents/rl/models/vla_policy.py`:

### 3. Configure

```json
{
  "trainer": { "model_type": "vla" },
  "policy": {
    "name": "vla",
    "vla_config": {
      "model_path": "checkpoints/vla.pt",
      "model_class": "MyVLAModel",
      "model_config": {}
    }
  }
}
```

## Common APIs
### Trainer

### Buffer Methods

### Algorithm

## FAQ
Q: When should I use VLA mode?
A: When inference takes >100ms/step and GPU training is fast.

Q: What buffer size?
A: Standard: 2048-4096 (rollout size). VLA: 2048-4096 (buffer capacity).

Q: How much does data staleness hurt?
A: Only slightly. PPO is robust to staleness, and the 2-3x speedup far outweighs the small penalty.

Q: How do I debug the data flow?
A: Use `buffer.get_stats()` or `_print_tensordict_tree(rollout)` in ppo.py.

## Workflows
### Standard

### VLA

## File Structure

## References

`configs/agents/rl/`