async collect buffer for VLA RL #122

Open
yhnsu wants to merge 16 commits into main from yhn/rl_vla

Conversation

@yhnsu (Collaborator) commented Feb 6, 2026

RL Training Framework Guide

A TensorDict-based RL framework supporting standard synchronous PPO and asynchronous VLA (vision-language-action) training.


Quick Start

Configuration

{
  "trainer": {
    "buffer_size": 2048,
    "model_type": "standard"  // or "vla"
  },
  "policy": {"name": "actor_critic"},
  "algorithm": {
    "name": "ppo",
    "cfg": {
      "learning_rate": 3e-4,
      "gamma": 0.99,
      "n_epochs": 10,
      "batch_size": 64
    }
  }
}

Run Training

python embodichain/agents/rl/train.py --config configs/agents/rl/my_config.json

Architecture

Trainer → Collector (sync/async) → Buffer (standard/vla) → Algorithm (PPO)

Components:

  • Collector: gathers data from the environment (SyncCollector / AsyncCollector)
  • Buffer: stores transitions (RolloutBuffer / VLABuffer)
  • Algorithm: updates the policy (PPO)
  • Trainer: coordinates the training loop

Training Modes

Standard Mode (Default)

For: typical models with fast inference (<100 ms/step)

SyncCollector → Collect 2048 steps → Train → Clear buffer → Repeat

Config: {"trainer": {"model_type": "standard"}}

Pros: Simple, stable, low memory, no staleness

VLA Async Mode

For: large models with slow inference (>1 s/step)

Background: AsyncCollector → Continuously collect → VLABuffer
Main:       Wait for buffer full → Train → Repeat

Config: {"trainer": {"model_type": "vla"}}

Pros: 2-3x speedup via parallel collection
Cons: Data staleness, higher memory


Collectors

SyncCollector

Collects complete rollout synchronously:

from embodichain.agents.rl.collector import SyncCollector

collector = SyncCollector(env, policy, device, callback)
rollout = collector.collect(num_steps=2048)  # [T, N, ...]

AsyncCollector

Runs in background thread:

from embodichain.agents.rl.collector import AsyncCollector

collector = AsyncCollector(env, policy, buffer, device, callback)
collector.start()   # Begin background collection
# ... buffer fills automatically ...
collector.stop()    # Stop collection

Buffers

RolloutBuffer (Standard)

Single-use buffer:

from embodichain.agents.rl.buffer import RolloutBuffer

buffer = RolloutBuffer(buffer_size=2048, device=device)
buffer.add(rollout)  # [T, N, ...]
data = buffer.get(flatten=True)  # [T*N, ...], auto-clears

VLABuffer (Async)

Circular FIFO buffer:

from embodichain.agents.rl.buffer import VLABuffer

buffer = VLABuffer(buffer_size=4096, device=device)
buffer.add(transition)  # Single step
data = buffer.get(flatten=True)  # [buffer_size, ...] when full

Circular behavior: [T0,T1,T2,T3] → add T4 → [T4,T1,T2,T3] (T0 overwritten)
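
To make the circular-write mechanics concrete, here is a minimal sketch of the indexing described above (illustrative only, using plain tensors; the actual VLABuffer stores TensorDicts and guards these fields with a lock):

```python
import torch

class CircularBufferSketch:
    """Overwrite order: [T0,T1,T2,T3] -> add T4 -> [T4,T1,T2,T3]."""

    def __init__(self, buffer_size: int, feat_dim: int):
        self.buffer = torch.zeros(buffer_size, feat_dim)
        self.buffer_size = buffer_size
        self.write_pos = 0  # next slot to overwrite
        self.size = 0       # number of valid entries

    def add(self, transition: torch.Tensor) -> None:
        self.buffer[self.write_pos] = transition  # overwrite oldest slot
        self.write_pos = (self.write_pos + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def get(self) -> torch.Tensor:
        if self.size < self.buffer_size:
            return self.buffer[: self.size]
        # When full, reorder so the oldest entry comes first.
        idx = (torch.arange(self.buffer_size) + self.write_pos) % self.buffer_size
        return self.buffer[idx]
```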


VLA Integration

1. Implement Model

class MyVLAModel(nn.Module):
    def forward(self, obs: TensorDict) -> TensorDict:
        # Add 'action', 'sample_log_prob', 'value'
        ...
    def get_value(self, obs: TensorDict) -> TensorDict:
        # Add 'value'
        ...
    def evaluate_actions(self, obs: TensorDict) -> TensorDict:
        # Add 'sample_log_prob', 'entropy', 'value'
        ...
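
As a concrete illustration of this interface, a toy implementation might look like the sketch below. The MLP backbone and Gaussian head are placeholders standing in for real vision/language components; only the TensorDict keys follow the contract above:

```python
import torch
import torch.nn as nn
from tensordict import TensorDict
from torch.distributions import Normal

class ToyVLAModel(nn.Module):
    """Illustrative stand-in implementing the three-method contract above."""

    def __init__(self, obs_dim: int, action_dim: int, hidden: int = 256):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(obs_dim, hidden), nn.Tanh())
        self.mean_head = nn.Linear(hidden, action_dim)
        self.value_head = nn.Linear(hidden, 1)
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def forward(self, obs: TensorDict) -> TensorDict:
        h = self.backbone(obs["observation"])
        dist = Normal(self.mean_head(h), self.log_std.exp())
        action = dist.sample()
        obs["action"] = action
        obs["sample_log_prob"] = dist.log_prob(action).sum(-1, keepdim=True)
        obs["value"] = self.value_head(h)
        return obs

    def get_value(self, obs: TensorDict) -> TensorDict:
        obs["value"] = self.value_head(self.backbone(obs["observation"]))
        return obs

    def evaluate_actions(self, obs: TensorDict) -> TensorDict:
        h = self.backbone(obs["observation"])
        dist = Normal(self.mean_head(h), self.log_std.exp())
        obs["sample_log_prob"] = dist.log_prob(obs["action"]).sum(-1, keepdim=True)
        obs["entropy"] = dist.entropy().sum(-1, keepdim=True)
        obs["value"] = self.value_head(h)
        return obs
```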

2. Implement Loading

Edit embodichain/agents/rl/models/vla_policy.py:

def load_vla_model(model_path, model_class, model_config, device):
    model = MyVLAModel(**model_config)
    model.load_state_dict(torch.load(model_path, map_location=device))
    return model.to(device)

3. Configure

{
  "trainer": {"model_type": "vla"},
  "policy": {
    "name": "vla",
    "vla_config": {
      "model_path": "checkpoints/vla.pt",
      "model_class": "MyVLAModel",
      "model_config": {}
    }
  }
}

Common APIs

Trainer

from embodichain.agents.rl.utils import Trainer

trainer = Trainer(
    policy, env, algorithm,
    buffer_size=2048,
    model_type="standard",  # or "vla"
    ...
)
trainer.train(total_timesteps=1000000)

Buffer Methods

buffer.add(data)            # Add data
data = buffer.get(flatten=True)  # Retrieve data
buffer.is_full()            # Check ready status
buffer.clear()              # Clear buffer
buffer.get_stats()          # Statistics

Algorithm

from embodichain.agents.rl.algo import PPO, PPOCfg

algorithm = PPO(PPOCfg(...), policy)
losses = algorithm.update(rollout)  # Returns loss dict

FAQ

Q: When should I use VLA mode?
A: When inference takes >100 ms/step and GPU training is otherwise fast.

Q: What buffer size should I use?
A: Standard: 2048-4096 (rollout size). VLA: 2048-4096 (buffer capacity).

Q: How much does data staleness hurt?
A: Only slightly; PPO is robust to mild staleness, and the 2-3x collection speedup outweighs the small penalty.

Q: How do I debug the data flow?
A: Use buffer.get_stats() or _print_tensordict_tree(rollout) in ppo.py.


Workflows

Standard

collector = SyncCollector(env, policy, device, callback)
buffer = RolloutBuffer(buffer_size=2048, device=device)
step = 0
while step < total_timesteps:
    rollout = collector.collect(num_steps=2048)
    buffer.add(rollout)
    data = buffer.get(flatten=True)  # auto-clears
    losses = algorithm.update(data)
    step += data.batch_size[0]

VLA

import time

collector = AsyncCollector(env, policy, buffer, device, callback)
collector.start()
step = 0
while step < total_timesteps:
    while not buffer.is_full():
        time.sleep(0.1)  # wait for background collection
    data = buffer.get(flatten=True)
    losses = algorithm.update(data)
    step += data.batch_size[0]
collector.stop()

File Structure

embodichain/agents/rl/
├── train.py              # Entry point
├── algo/ppo.py          # PPO algorithm
├── buffer/
│   ├── standard_buffer.py  # RolloutBuffer
│   └── vla_buffer.py       # VLABuffer
├── collector/
│   ├── base.py             # BaseCollector
│   ├── sync_collector.py   # SyncCollector
│   └── async_collector.py  # AsyncCollector
├── models/
│   ├── actor_critic.py     # Standard policy
│   └── vla_policy.py       # VLA wrapper
└── utils/trainer.py     # Training coordinator


Copilot AI review requested due to automatic review settings February 6, 2026 04:22
Copilot AI (Contributor) left a comment

Pull request overview

This pull request introduces a comprehensive refactoring of the RL training framework to use TensorDict-based data flow, replacing the previous tensor-based approach. The PR adds support for two training modes: standard synchronous PPO and asynchronous VLA training designed for scenarios with slow model inference.

Changes:

  • Migrated entire RL pipeline to TensorDict-based architecture for structured, extensible data flow
  • Introduced dual buffer system: RolloutBuffer (standard) and VLABuffer (async with FIFO)
  • Added AsyncCollector for background data collection in VLA mode with thread-based parallelism
  • Refactored Policy interface to use TensorDict inputs/outputs with in-place modifications
  • Updated PPO algorithm to work with TensorDict rollouts and removed dependency on gym spaces
  • Modified configuration to use buffer_size instead of rollout_steps and added action_dim requirement

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 15 comments.

Summary of changes per file:

  • embodichain/agents/rl/utils/trainer.py: Refactored to support dual training modes (sync/async) with TensorDict
  • embodichain/agents/rl/utils/helper.py: Added dict_to_tensordict, compute_gae, and logging utilities
  • embodichain/agents/rl/utils/async_collector.py: New async data collector for VLA mode with background thread
  • embodichain/agents/rl/buffer/rollout_buffer.py: Renamed/refactored to VLABuffer with circular indexing
  • embodichain/agents/rl/buffer/standard_buffer.py: New RolloutBuffer for standard PPO mode
  • embodichain/agents/rl/buffer/__init__.py: Updated exports for dual buffer system
  • embodichain/agents/rl/algo/ppo.py: Refactored to use TensorDict data flow throughout
  • embodichain/agents/rl/algo/base.py: Updated base algorithm interface for TensorDict
  • embodichain/agents/rl/models/policy.py: Changed interface to TensorDict-based methods
  • embodichain/agents/rl/models/actor_critic.py: Implemented TensorDict-based policy with in-place modifications
  • embodichain/agents/rl/models/__init__.py: Removed gymnasium dependency, added action_dim parameter
  • embodichain/agents/rl/train.py: Added action_dim requirement, removed gym space dependency
  • tests/agents/test_rl.py: Updated test to use buffer_size parameter
  • configs/agents/rl/push_cube/train_config.json: Updated config with buffer_size, action_dim, and eval_freq
  • configs/agents/rl/basic/cart_pole/train_config.json: Updated config with buffer_size
  • docs/source/tutorial/rl.rst: Updated documentation to reference buffer_size
  • pyproject.toml: Added tensordict>=0.5.0 dependency
Comments suppressed due to low confidence (1)

embodichain/agents/rl/train.py:289

  • The buffer_type parameter is not read from the trainer config and not passed to the Trainer constructor (line 273-289). This means the VLA async mode introduced in this PR cannot be used, as it will always default to "standard" mode. Add buffer_type = trainer_cfg.get("buffer_type", "standard") before the Trainer initialization and pass it as buffer_type=buffer_type to the Trainer constructor.
    trainer = Trainer(
        policy=policy,
        env=env,
        algorithm=algo,
        buffer_size=buffer_size,
        batch_size=algo_cfg["batch_size"],
        writer=writer,
        eval_freq=eval_freq if enable_eval else 0,  # Disable eval if not enabled
        save_freq=save_freq,
        checkpoint_dir=checkpoint_dir,
        exp_name=exp_name,
        use_wandb=use_wandb,
        eval_env=eval_env,  # None if enable_eval=False
        event_cfg=train_event_cfg,
        eval_event_cfg=eval_event_cfg if enable_eval else {},
        num_eval_episodes=num_eval_episodes,
    )



# Update global step
num_envs = tensordict.batch_size[0]
self.global_step += num_envs

Copilot AI commented Feb 6, 2026:

The self.global_step variable is updated from the async collector thread (line 182 via callback) and potentially read from the main thread (lines 214, 244, 255). This creates a race condition. Consider using a thread-safe counter (e.g., threading.Lock protection or multiprocessing.Value) or tracking steps only in one thread.
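
A lock-protected counter along the lines the review suggests might look like this (a sketch with hypothetical names, not the project's actual code):

```python
import threading

class AtomicCounter:
    """Thread-safe step counter: one writer thread, many readers."""

    def __init__(self) -> None:
        self._value = 0
        self._lock = threading.Lock()

    def add(self, n: int) -> int:
        with self._lock:
            self._value += n
            return self._value

    @property
    def value(self) -> int:
        with self._lock:
            return self._value

# Hypothetical usage:
#   trainer.global_step = AtomicCounter()
#   trainer.global_step.add(num_envs)    # collector thread
#   step = trainer.global_step.value     # main thread
```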

Copilot AI review requested due to automatic review settings February 6, 2026 07:51
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 26 out of 26 changed files in this pull request and generated 14 comments.



Comment on lines +113 to +158
def get(self, flatten: bool = True) -> TensorDict:
    """Get valid data from buffer.

    Args:
        flatten: If True, return flattened [size, ...]. Currently only supports True.

    Returns:
        TensorDict with batch_size=[size, ...] containing valid data
    """
    if not self._initialized or self.size == 0:
        raise ValueError("Buffer is empty")

    if not flatten:
        raise NotImplementedError("Only flatten=True is supported for VLABuffer")

    # Return first 'size' elements (valid data)
    # Note: Data is in insertion order up to write_pos, then wraps
    if self.size < self.buffer_size:
        # Buffer not yet full, data is [0:size]
        return self.buffer[: self.size]
    else:
        # Buffer full, need to rearrange to maintain temporal order
        # Oldest data is at write_pos, newest at write_pos-1
        indices = (
            torch.arange(
                self.write_pos,
                self.write_pos + self.buffer_size,
                device=self.device,
            )
            % self.buffer_size
        )
        return self.buffer[indices]

def clear(self) -> None:
    """Clear buffer (reset pointers, keep pre-allocated memory)."""
    self.write_pos = 0
    self.size = 0
    # Keep buffer allocated for reuse

def __len__(self) -> int:
    """Return current number of valid transitions."""
    return self.size

def is_full(self) -> bool:
    """Check if buffer is at full buffer_size."""
    return self.size >= self.buffer_size

Copilot AI commented Feb 6, 2026:

The VLABuffer.get() and is_full() methods are called from the main thread while AsyncCollector writes to the buffer from a background thread, but these methods lack thread safety. The read of self.size and self.write_pos could return inconsistent values if a write is in progress. Additionally, buffer.get() performs complex operations (checking size, slicing buffer) that should be atomic with respect to concurrent writes. Consider adding thread synchronization or document that external locking is required.

A collaborator replied:

I've added a lock for get/is_full for thread safety.

if deterministic:
    action = mean
else:
    dist = Normal(mean, std)

Copilot AI commented Feb 6, 2026:

The distribution is created twice when deterministic=False. Line 130 creates dist = Normal(mean, std), then lines 136-137 create it again. This is wasteful. Consider refactoring to create the distribution once and use either dist.mean or dist.sample() based on the deterministic flag.

Suggested change
dist = Normal(mean, std)
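
A create-once version along these lines might read as follows (a sketch; the mean/std placeholders stand in for the network outputs in the snippet above):

```python
import torch
from torch.distributions import Normal

mean, std = torch.zeros(4, 2), torch.ones(4, 2)  # placeholders for network outputs
deterministic = False

dist = Normal(mean, std)  # build the distribution exactly once
# Deterministic evaluation uses the distribution mean; training samples.
action = dist.mean if deterministic else dist.sample()
log_prob = dist.log_prob(action).sum(-1, keepdim=True)
```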

Comment on lines +195 to +201
next_value_td = TensorDict(
    {"observation": next_obs_for_td},
    batch_size=next_td.batch_size,
    device=self.device,
)
self.policy.get_value(next_value_td)
next_td["value"] = next_value_td["value"]

Copilot AI commented Feb 6, 2026:

The policy is accessed from both the background collector thread (lines 145-146, 200) and potentially from the main training thread during algorithm.update(). PyTorch tensors and models are not thread-safe by default. Concurrent access to the policy parameters during forward passes and gradient updates can lead to race conditions and corrupted gradients. Consider using locks to synchronize policy access, or ensure the policy is not being updated while the collector is running (e.g., by stopping collection during training).

Suggested change (before):

next_value_td = TensorDict(
    {"observation": next_obs_for_td},
    batch_size=next_td.batch_size,
    device=self.device,
)
self.policy.get_value(next_value_td)
next_td["value"] = next_value_td["value"]

Suggested change (after):

# Protect policy access with lock to avoid races with training thread
with self._lock:
    next_value_td = TensorDict(
        {"observation": next_obs_for_td},
        batch_size=next_td.batch_size,
        device=self.device,
    )
    self.policy.get_value(next_value_td)
    next_td["value"] = next_value_td["value"]


losses = self.algorithm.update(data)
self._log_train(losses)


Copilot AI commented Feb 6, 2026:

After buffer.get() is called on line 238, the VLABuffer is not cleared (unlike RolloutBuffer which auto-clears). Since the buffer is full (size == buffer_size), the is_full() check on line 232 will immediately return True in the next iteration, causing the training loop to repeatedly train on the same data without waiting for new transitions. The buffer should be cleared after get(), or the is_full() logic should be modified to track whether data has been consumed.

Suggested change (addition):

# Clear async buffer after consumption to avoid retraining on stale data
if hasattr(self.buffer, "clear"):
    self.buffer.clear()

Comment on lines +121 to +122
# Prepare next iteration - use the converted TensorDict
current_td = next_obs_td

Copilot AI commented Feb 6, 2026:

The collector does not handle episode resets when done=True. After an episode terminates (done flag is set), the environment should be reset to get a fresh initial observation for the next episode. Currently, the collector continues using next_obs even after termination, which could contain stale data. Most RL environments auto-reset on episode end, but this should be made explicit or documented as a requirement.

Suggested change (before):

# Prepare next iteration - use the converted TensorDict
current_td = next_obs_td

Suggested change (after):

# Prepare next iteration:
# - if episode is done, reset env to get a fresh initial observation
# - otherwise, continue from next_obs_td
if done.any():
    reset_result = self.env.reset()
    # Support both Gym/Gymnasium-style (obs, info) and plain-obs resets
    if isinstance(reset_result, tuple):
        reset_obs = reset_result[0]
    else:
        reset_obs = reset_result
    current_td = dict_to_tensordict(reset_obs, self.device)
else:
    current_td = next_obs_td

Comment on lines +114 to +115
# Store complete transition
rollout_list.append(current_td.clone())

Copilot AI commented Feb 6, 2026:

Calling .clone() on every transition creates a full copy of the TensorDict including all nested tensors, which can be memory-intensive for large rollouts. Since current_td is reassigned to next_obs_td on line 122 (which is a fresh TensorDict), the clone may be unnecessary. Consider whether a shallow copy or reference would suffice, or document why deep cloning is required here.

Suggested change (before):

# Store complete transition
rollout_list.append(current_td.clone())

Suggested change (after):

# Store complete transition (no clone needed: current_td is not mutated afterwards)
rollout_list.append(current_td)


@yangchen73 (Collaborator) commented Feb 9, 2026:

I don't think so. Every loop does current_td["next"] = next_td and then current_td = next_obs_td. If we don't use clone(), every appended element is the same TensorDict reference. Then the next loop overwrites its contents. As a result, every entry in rollout_list points to the same modified data.
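
A minimal demonstration of the aliasing hazard being discussed (illustrative; the key and shapes are made up):

```python
import torch
from tensordict import TensorDict

# Without clone: the list stores a reference, so later in-place
# mutation of the tensordict's tensors changes the stored entry too.
td = TensorDict({"observation": torch.zeros(2)}, batch_size=[])
rollout = []
rollout.append(td)                 # reference, not a copy
td["observation"] += 1.0           # in-place mutation in the "next loop"
print(rollout[0]["observation"])   # tensor([1., 1.]) -- entry changed

# With clone: the deep copy decouples the stored entry's storage.
td2 = TensorDict({"observation": torch.zeros(2)}, batch_size=[])
rollout_cloned = [td2.clone()]
td2["observation"] += 1.0
print(rollout_cloned[0]["observation"])  # tensor([0., 0.]) -- unaffected
```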

Comment on lines +240 to +242
# Update global step based on collected data (main thread only)
batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0
self.global_step += batch_size

Copilot AI commented Feb 6, 2026:

The global_step update in async mode only counts batch_size from the returned data (line 242), not the actual number of environment steps taken. Since VLABuffer is continuously being written to by AsyncCollector (which tracks steps in _step_count), the global_step will not accurately reflect the total number of environment interactions. Consider synchronizing global_step with the collector's _step_count, or documenting this discrepancy.

Suggested change (before):

# Update global step based on collected data (main thread only)
batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0
self.global_step += batch_size

Suggested change (after):

# Update global step.
# Prefer the collector's step count (actual env interactions) if available,
# otherwise fall back to counting processed batch size.
batch_size = data.batch_size[0] if len(data.batch_size) > 0 else 0
steps_from_collector = getattr(collector, "_step_count", None)
if isinstance(steps_from_collector, int) and steps_from_collector > self.global_step:
    self.global_step = steps_from_collector
else:
    self.global_step += batch_size

Comment on lines 146 to 167
@@ -166,30 +154,113 @@ def on_step(obs, actions, reward, done, info, next_obs):
        self.curr_ret[done_idx] = 0
        self.curr_len[done_idx] = 0

    # Update global step and observation
    # next_obs is already flattened in algorithm's collect_rollout
    self.obs = next_obs
    self.global_step += next_obs.shape[0]

    if isinstance(info, dict):
        rewards_dict = info.get("rewards")
        metrics_dict = info.get("metrics")
    # Log environment metrics
    if isinstance(env_info, dict):
        rewards_dict = env_info.get("rewards")
        metrics_dict = env_info.get("metrics")
        self._log_scalar_dict("rewards", rewards_dict)
        self._log_scalar_dict("metrics", metrics_dict)
        log_dict = {}
        log_dict.update(self._pack_log_dict("rewards", rewards_dict))
        log_dict.update(self._pack_log_dict("metrics", metrics_dict))
        log_dict.update(pack_log_dict("rewards", rewards_dict))
        log_dict.update(pack_log_dict("metrics", metrics_dict))
        if log_dict and self.use_wandb:
            wandb.log(log_dict, step=self.global_step)

Copilot AI commented Feb 6, 2026:

The on_step callback modifies shared state (self.curr_ret, self.curr_len, self.ret_window, self.len_window, self.global_step) without thread synchronization. In async mode, this callback runs in the AsyncCollector background thread while the main thread could be accessing these same variables (e.g., in _log_train). This can cause race conditions and data corruption. Use threading.Lock to protect access to these shared variables, or ensure they're only accessed from one thread.

Comment on lines +58 to +60
def collect(self, **kwargs) -> TensorDict:
    """Collect data from environment.


Copilot AI commented Feb 6, 2026:

The overridden method's signature does not match the call site: it is passed too many arguments, including one named 'num_steps'. The overriding method SyncCollector.collect matches the call.

Suggested change (before):

def collect(self, **kwargs) -> TensorDict:
    """Collect data from environment.

Suggested change (after):

def collect(self, num_steps: int, **kwargs) -> TensorDict:
    """Collect data from environment.

    Args:
        num_steps: Number of steps to collect.

Comment on lines +38 to +46
def collect(self, num_steps: int) -> TensorDict:
    """Collect a synchronous rollout.

    Args:
        num_steps: Number of steps to collect

    Returns:
        TensorDict with batch_size=[T, N] containing full rollout
    """

Copilot AI commented Feb 6, 2026:

This method requires 2 positional arguments, whereas overridden BaseCollector.collect requires 1.

Suggested change (before):

def collect(self, num_steps: int) -> TensorDict:
    """Collect a synchronous rollout.

    Args:
        num_steps: Number of steps to collect

    Returns:
        TensorDict with batch_size=[T, N] containing full rollout
    """

Suggested change (after):

def collect(self, num_steps: int | None = None) -> TensorDict:
    """Collect a synchronous rollout.

    Args:
        num_steps: Number of steps to collect.

    Returns:
        TensorDict with batch_size=[T, N] containing full rollout
    """
    if num_steps is None:
        raise TypeError("num_steps must be provided for SyncCollector.collect()")


@yangchen73 (Collaborator) left a comment:

  • need to use clone
  • add locks

Copilot AI review requested due to automatic review settings March 2, 2026 03:01
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 26 out of 26 changed files in this pull request and generated 14 comments.



Comment on lines +250 to +253
self.global_step += data.batch_size[0] if data.batch_size else 0

losses = self.algorithm.update(data)
self._log_train(losses)

Copilot AI commented Mar 2, 2026:

In async mode, the collector thread calls policy.forward/get_value while the main thread simultaneously runs algorithm.update() (backprop + optimizer.step) on the same policy. PyTorch modules/optimizers are not thread-safe, and concurrent CUDA kernels/parameter updates can lead to nondeterminism or crashes. Consider pausing collection during update, protecting policy access with a lock, or using a separate inference copy of the policy for the collector (periodically synced).
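
One way to realize the "separate inference copy, periodically synced" idea (a sketch with hypothetical names, not the PR's code):

```python
import copy
import threading
import torch

class PolicySnapshot:
    """Keeps a frozen copy of the policy for the collector thread,
    refreshed by the main thread after each optimizer update."""

    def __init__(self, policy: torch.nn.Module):
        self._lock = threading.Lock()
        self._snapshot = copy.deepcopy(policy).eval()
        for p in self._snapshot.parameters():
            p.requires_grad_(False)

    def sync(self, policy: torch.nn.Module) -> None:
        # Main thread: copy fresh weights into the inference model.
        with self._lock:
            self._snapshot.load_state_dict(policy.state_dict())

    @torch.no_grad()
    def act(self, td):
        # Collector thread: inference never touches the training policy.
        with self._lock:
            return self._snapshot(td)
```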

Comment on lines +124 to +128
### VLABuffer (Async)

Circular FIFO buffer:

```python

Copilot AI commented Mar 2, 2026:

The guide describes VLABuffer as a circular FIFO buffer with step-level add(transition) and shows buffer.get(flatten=True), but the implementation only accepts full rollouts via add_rollout() and PPO’s GAE requires the unflattened [N, T] layout. Please update this section to match the actual VLABuffer API and required shapes, otherwise users will hit AttributeErrors or compute incorrect advantages.

log_dict.update(pack_log_dict("rewards", rewards_dict))
log_dict.update(pack_log_dict("metrics", metrics_dict))
if log_dict and self.use_wandb:
    wandb.log(log_dict, step=self.global_step)

Copilot AI commented Mar 2, 2026:

on_step_callback logs to W&B using self.global_step, but global_step is only incremented once per rollout in _train_sync/_train_async. Because on_step_callback is invoked for every env step, these logs will repeatedly use the same step value (overwriting or producing a flat x-axis). Consider incrementing global_step inside the callback (e.g., by num_envs per env.step) or passing an explicit per-step counter into the callback for logging.

Suggested change (before):

wandb.log(log_dict, step=self.global_step)

Suggested change (after):

# Use a dedicated per-environment-step counter for W&B logging.
# Lazily initialize it so we don't depend on __init__ details.
env_log_step = getattr(self, "_env_log_step", 0)
# Increment by the number of parallel environments (reward batch size).
if isinstance(reward, torch.Tensor):
    env_log_step += reward.shape[0]
else:
    env_log_step += 1
self._env_log_step = env_log_step
wandb.log(log_dict, step=env_log_step)

Comment on lines 51 to +53
@abstractmethod
def get_action(
    self, obs: torch.Tensor, deterministic: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample an action from the policy.
def forward(self, tensordict: TensorDict) -> TensorDict:
    """Forward pass that adds action to the input tensordict (in-place).

Copilot AI commented Mar 2, 2026:

Policy.forward() in the abstract base class does not accept a deterministic argument, but Trainer._eval_once calls self.policy.forward(..., deterministic=True) and concrete implementations (ActorCritic/VLAPolicy) already support it. This makes the interface inconsistent and can cause TypeErrors for any other Policy implementation that follows the base signature. Consider updating the abstract method signature to include deterministic: bool = False (and documenting the expected behavior).

Comment on lines +245 to +249
while step < total:
    rollout = collector.collect(num_steps=2048)
    buffer.add(rollout)
    data = buffer.get(flatten=True)
    losses = algorithm.update(data)

Copilot AI commented Mar 2, 2026:

This workflow calls buffer.get(flatten=True) and passes the flattened result to algorithm.update(). With the current PPO implementation, flattened input is treated as a [size, 1] rollout, making GAE effectively run with T=1 and producing incorrect advantages/targets. Update the example to pass the unflattened [N, T] rollout into update (flatten only inside PPO for minibatching).

Comment on lines 10 to 14
"num_envs": 64,
"iterations": 1000,
"rollout_steps": 1024,
"buffer_size": 1024,
"eval_freq": 2,
"save_freq": 200,

Copilot AI commented Mar 2, 2026:

This config was updated to use trainer.buffer_size, but train.py now requires policy.action_dim to be present. As-is, running this config will raise "Missing 'action_dim' in policy config". Consider adding an explicit action_dim to the policy block (and optionally trainer.model_type) so the example remains runnable.

Comment on lines +79 to +82
# Initialize observation and get num_envs (needed for VLA buffer)
obs, _ = env.reset()
self.obs_tensordict = dict_to_tensordict(obs, device)
num_envs = self.obs_tensordict.batch_size[0]

Copilot AI commented Mar 2, 2026:

Trainer.__init__ calls env.reset() to infer num_envs and seed the observation, but both SyncCollector and AsyncCollector (via BaseCollector) also call env.reset(). This extra reset can be expensive and can change the initial state/episode accounting. Since the env exposes num_envs (used elsewhere in train.py), consider using env.num_envs (and/or env.device) instead of resetting here, and avoid storing an unused obs_tensordict.

Suggested change (before):

# Initialize observation and get num_envs (needed for VLA buffer)
obs, _ = env.reset()
self.obs_tensordict = dict_to_tensordict(obs, device)
num_envs = self.obs_tensordict.batch_size[0]

Suggested change (after):

# Initialize num_envs without forcing a reset when possible
if hasattr(env, "num_envs"):
    num_envs = env.num_envs
    # No need to create an initial obs_tensordict here; collectors will reset the env.
    self.obs_tensordict = None
else:
    # Fallback for environments that do not expose num_envs
    obs, _ = env.reset()
    self.obs_tensordict = dict_to_tensordict(obs, device)
    num_envs = self.obs_tensordict.batch_size[0]

Comment on lines 17 to 21
"""Helper utilities for RL training.

This module provides utility functions for RL algorithms.
"""


Copilot AI commented Mar 2, 2026:

flatten_dict_observation() was removed from this module, but embodichain/lab/gym/envs/base_env.py still imports it to build flattened_observation_space. This will raise ImportError at runtime. Either restore a compatible flatten_dict_observation helper here (for backward compatibility) or update base_env.py in this PR to use the new TensorDict-based utilities.

__all__ = ["RolloutBuffer"]
Provides two buffer implementations:
- RolloutBuffer: Standard PPO buffer (single rollout, use and discard)
- VLABuffer: VLA buffer (FIFO multi-rollout accumulation for slow inference)

Copilot AI commented Mar 2, 2026:

The buffer module docstring says VLABuffer is “FIFO multi-rollout accumulation”, but VLABuffer currently stores exactly one rollout (present/None). This mismatch is likely to confuse users; update the docstring to reflect the current behavior, or adjust VLABuffer to match the documented semantics.

Suggested change (before):

- VLABuffer: VLA buffer (FIFO multi-rollout accumulation for slow inference)

Suggested change (after):

- VLABuffer: VLA buffer (single-rollout accumulation optimized for slow inference)

Comment on lines +96 to 99
# Ensure 2D format [T, N] for GAE computation
if len(rollout.batch_size) == 1:
    rollout = rollout.unsqueeze(1)  # [size] -> [size, 1]


Copilot AI commented Mar 2, 2026:

PPO.update() claims to support receiving a flattened rollout (batch_size=[size]) by doing rollout.unsqueeze(1), but that turns it into [size, 1] and makes GAE run with T=1 (incorrect unless the original rollout length was 1). Either require callers to pass an unflattened [N, T] / [T, N] rollout, or reshape using known (N, T) metadata (e.g., carry rollout_length/num_envs) before computing GAE.
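
For reference, a batch-first GAE over an unflattened [N, T] rollout might be sketched as below (illustrative; key names follow the TensorDict layout used in this PR, including the "value_target" key mentioned in a later comment). The time recursion is exactly what a flattened [N*T] input cannot support:

```python
import torch
from tensordict import TensorDict

def compute_gae_batch_first(rollout: TensorDict, next_value: torch.Tensor,
                            gamma: float = 0.99, lam: float = 0.95) -> TensorDict:
    """GAE over batch-first rollouts: rollout.batch_size == [N, T]."""
    reward, done, value = rollout["reward"], rollout["done"], rollout["value"]
    N, T = rollout.batch_size
    advantage = torch.zeros_like(reward)
    last_gae = torch.zeros_like(next_value)
    next_v = next_value  # value estimate for the step after the rollout, [N, 1]
    for t in reversed(range(T)):
        not_done = 1.0 - done[:, t].float()
        delta = reward[:, t] + gamma * next_v * not_done - value[:, t]
        last_gae = delta + gamma * lam * not_done * last_gae
        advantage[:, t] = last_gae
        next_v = value[:, t]
    rollout["advantage"] = advantage
    rollout["value_target"] = advantage + value
    return rollout
```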

Copilot AI review requested due to automatic review settings March 9, 2026 14:02
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 26 out of 26 changed files in this pull request and generated 15 comments.



Comment on lines +26 to +46
def dict_to_tensordict(obs_dict: dict, device: torch.device) -> TensorDict:
    """Convert a nested observation dict into a TensorDict.

    Args:
        obs_dict: Nested observation dictionary returned by the environment.
        device: Device to place tensors on.

    Returns:
        TensorDict with an outer ``"observation"`` key.
    """
    Flatten hierarchical TensorDict observations from ObservationManager.

    Recursively traverse nested TensorDicts, collect all tensor values,
    flatten each to (num_envs, -1), and concatenate in sorted key order.
    def _recursive_convert(data: dict) -> dict:
        result = {}
        for key, value in data.items():
            if isinstance(value, dict):
                result[key] = _recursive_convert(value)
            elif isinstance(value, torch.Tensor):
                result[key] = value.to(device)
            else:
                result[key] = torch.tensor(value, device=device)
        return result

Copilot AI commented Mar 9, 2026:

dict_to_tensordict() assumes the environment returns a nested Python dict, but BaseEnv.get_obs()/reset() returns a TensorDict. Passing a TensorDict here will iterate items and then hit the else: torch.tensor(value) branch for nested TensorDict leaves, raising an error. Consider accepting obs as TensorDict | dict and, when it's already a TensorDict, just move it to the target device and wrap it under the outer "observation" key (and also handle nested TensorDict values during recursion).

Comment on lines +149 to +152
action = current_td["action"]
action_type = getattr(self.env, "action_type", "delta_qpos")
action_dict = {action_type: action}


Copilot AI commented Mar 9, 2026:

AsyncCollector builds action_dict using getattr(self.env, "action_type", "delta_qpos"), but EmbodiedEnv typically routes actions through env.action_manager and expects the active term name as the dict key. If the key doesn't match, ActionManager.process_action() will raise. Prefer am = getattr(self.env, "action_manager", None) and use am.action_type when present.

Comment on lines +142 to +193
def load_vla_model(
    model_path: str,
    model_class: Optional[str] = None,
    model_config: Optional[dict] = None,
    device: torch.device = torch.device("cpu"),
) -> nn.Module:
    """Load VLA model from checkpoint.

    This function should be implemented by the VLA team to load their
    pretrained VLA model (vision encoder, language model, action head, etc.).

    The returned module should have methods:
    - forward(obs) -> (action, log_prob, value)
    - get_value(obs) -> value
    - evaluate_actions(obs, actions) -> (log_prob, entropy, value)

    Args:
        model_path: Path to checkpoint file
        model_class: Fully qualified class name for VLA model
        model_config: Configuration dict for model initialization
        device: Device to load model on

    Returns:
        Initialized VLA model module

    Example implementation by VLA team:
    ```python
    def load_vla_model(model_path, model_class, model_config, device):
        import importlib

        # Import VLA model class
        module_name, class_name = model_class.rsplit(".", 1)
        module = importlib.import_module(module_name)
        ModelClass = getattr(module, class_name)

        # Initialize model
        model = ModelClass(**model_config)

        # Load checkpoint
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])

        model.to(device)
        model.eval()

        return model
    ```
    """
    raise NotImplementedError(
        "load_vla_model() must be implemented. "
        f"Model path: {model_path}, class: {model_class}, config: {model_config}"
    )

Copilot AI commented Mar 9, 2026:

load_vla_model() unconditionally raises NotImplementedError, but the repo now includes a VLA example config and documentation implying VLA training works out of the box. This will cause any run using policy.name="vla" to fail immediately. Either provide a default generic implementation (e.g., importlib-load model_class + torch.load state_dict as shown in the docstring), or clearly mark VLA support as requiring downstream customization and avoid registering/building the policy unless a loader is provided.

Comment on lines +127 to +147
## Data Flow (TensorDict)

```
Environment ──▶ Collector ──▶ Algorithm ──▶ Policy
     │              │             │            │
     │          TensorDict    TensorDict   Parameters
     │           [T, N]        [batch]      Update
     │              │             │            │
     └──────────────┴─────────────┴────────────┘

TensorDict structure:
{
    "observation": Tensor or nested TensorDict,
    "action": Tensor[T, N, action_dim],
    "reward": Tensor[T, N, 1],
    "done": Tensor[T, N, 1],
    "value": Tensor[T, N, 1],
    "sample_log_prob": Tensor[T, N, 1],
    "advantage": Tensor[T, N, 1],  # added after GAE computation
    "return": Tensor[T, N, 1],     # added after GAE computation
}

Copilot AI commented Mar 9, 2026:

This architecture doc states rollouts are shaped [T, N] and uses a key named "return", but the collectors in this PR stack as [N, T] (batch-first) and GAE writes "value_target". Please update the documented TensorDict layout/keys to match the actual implementation to avoid misleading readers.

Comment on lines 24 to 44
class BaseAlgorithm:
    """Base class for RL algorithms.
    """Base class for RL algorithms following TorchRL conventions.

    Algorithms must implement buffer initialization, rollout collection, and
    policy update. Trainer depends only on this interface to remain
    algorithm-agnostic.
    Algorithms implement policy updates using TensorDict.
    Data collection is handled separately by Collector classes (SyncCollector/AsyncCollector).
    """

    device: torch.device

    def initialize_buffer(
        self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int
    ) -> None:
        """Initialize internal buffer(s) required by the algorithm."""
        raise NotImplementedError
    def update(self, rollout: TensorDict) -> Dict[str, float]:
        """Update policy using collected rollout data.

    def collect_rollout(
        self,
        env,
        policy,
        obs: torch.Tensor,
        num_steps: int,
        on_step_callback: Callable | None = None,
    ) -> Dict[str, Any]:
        """Collect trajectories and return logging info (e.g., reward components)."""
        raise NotImplementedError
        Args:
            rollout: TensorDict containing collected rollout data from Collector
                Expected batch_size format: [T, N] for on-policy algorithms
                where T is trajectory length and N is number of environments

    def update(self) -> Dict[str, float]:
        """Update policy using collected data and return training losses."""
        Returns:
            Dictionary of training metrics (losses, learning stats, etc.)
        """
        raise NotImplementedError

Copilot AI commented Mar 9, 2026:

BaseAlgorithm now documents an update(self, rollout: TensorDict) interface, but the repository still registers GRPO (and GRPO implements the older initialize_buffer/collect_rollout/update() interface). Selecting algorithm.name="grpo" via build_algo will break at runtime once Trainer always calls algorithm.update(data) with an argument. Either update GRPO to the new interface or adjust the registry/trainer to support both interfaces.

Comment on lines +58 to +60
action = current_td["action"]
action_type = getattr(self.env, "action_type", "delta_qpos")
action_dict = {action_type: action}

Copilot AI commented Mar 9, 2026:

Action key selection here ignores the environment's ActionManager. EmbodiedEnv exposes env.action_manager.action_type as the active term name; falling back to a hard-coded "delta_qpos" (or missing env.action_type) can produce an action dict with the wrong key and cause ActionManager.process_action() to raise (no matching key). Use am = getattr(self.env, "action_manager", None) and prefer am.action_type when available.

Comment on lines 51 to 69
@abstractmethod
def get_action(
    self, obs: torch.Tensor, deterministic: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample an action from the policy.
def forward(self, tensordict: TensorDict) -> TensorDict:
    """Forward pass that adds action to the input tensordict (in-place).

    This is the main inference method following TorchRL conventions.

    Args:
        obs: Observation tensor of shape (batch_size, obs_dim)
        deterministic: If True, return the mean action; otherwise sample
        tensordict: Input TensorDict containing at minimum:
            - "observation": Observation tensor or nested TensorDict

    Returns:
        Tuple of (action, log_prob, value):
        - action: Sampled action tensor of shape (batch_size, action_dim)
        - log_prob: Log probability of the action, shape (batch_size,)
        - value: Value estimate, shape (batch_size,)
        The same TensorDict (modified in-place) with added fields:
        - "action": Sampled action tensor
        - "sample_log_prob": Log probability of the sampled action
        - "value": Value estimate (optional, for actor-critic)
        - "loc": Distribution mean (optional, for continuous actions)
        - "scale": Distribution std (optional, for continuous actions)
    """
    raise NotImplementedError

Copilot AI commented Mar 9, 2026:

Policy.forward() is declared without a deterministic parameter, but callers in this PR pass deterministic=True during evaluation (Trainer._eval_once) and several concrete Policy implementations also define forward(..., deterministic: bool = False). As-is, a Policy implementation that follows the abstract signature will raise TypeError when used for evaluation. Update the abstract method signature (and docstring) to include deterministic: bool = False so the interface matches actual usage.

Comment on lines +17 to +23
"""
Buffer module for RL training.

__all__ = ["RolloutBuffer"]
Provides two buffer implementations:
- RolloutBuffer: Standard PPO buffer (single rollout, use and discard)
- VLABuffer: VLA buffer (FIFO multi-rollout accumulation for slow inference)
"""

Copilot AI commented Mar 9, 2026:

The module docstring claims VLABuffer is a "FIFO multi-rollout accumulation" buffer, but the current VLABuffer implementation stores only a single rollout (self._rollout) and overwrites/clears it on get(). Either update the docstring to reflect the single-rollout behavior, or implement the FIFO accumulation described here.

Comment on lines +124 to +134
### VLABuffer (Async)

Circular FIFO buffer:

```python
from embodichain.agents.rl.buffer import VLABuffer

buffer = VLABuffer(buffer_size=4096, device=device)
buffer.add(transition) # Single step
data = buffer.get(flatten=True) # [buffer_size, ...] when full
```

Copilot AI commented Mar 9, 2026:

This guide describes VLABuffer as a circular FIFO buffer that accepts single-step buffer.add(transition) calls, but the implementation in this PR requires full rollouts via add_rollout() and only stores a single rollout at a time. The example code here should be updated to match the actual API/behavior (or the buffer implementation should be adjusted to match the guide).

Comment on lines +241 to +262
### Standard

```python
collector = SyncCollector(env, policy, device, callback)
while step < total:
    rollout = collector.collect(num_steps=2048)
    buffer.add(rollout)
    data = buffer.get(flatten=True)
    losses = algorithm.update(data)
```

### VLA

```python
collector = AsyncCollector(env, policy, buffer, device, callback)
collector.start()
while step < total:
    while not buffer.is_full():
        time.sleep(0.1)
    data = buffer.get(flatten=True)
    losses = algorithm.update(data)
collector.stop()

Copilot AI commented Mar 9, 2026:

The workflow examples call buffer.get(flatten=True) and then pass the flattened TensorDict to algorithm.update(data). With the current PPO.update implementation, a 1D TensorDict is reshaped/unsqueezed in a way that loses the original [N, T] trajectory structure, so GAE will be incorrect. Update the guide to use flatten=False when retrieving rollouts for on-policy algorithms (or update PPO.update to reject/handle flattened inputs explicitly).

Copilot AI review requested due to automatic review settings March 9, 2026 16:08
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 27 out of 27 changed files in this pull request and generated 14 comments.



Comment on lines 220 to 225
if (
    self.eval_freq > 0
    and self.eval_env is not None
    and self.global_step % self.eval_freq == 0
):
    self._eval_once(num_episodes=self.num_eval_episodes)

Copilot AI commented Mar 9, 2026:

eval_freq is checked via self.global_step % self.eval_freq == 0, but global_step is incremented in large chunks (num_envs * num_steps per rollout). This makes evaluation effectively never run unless eval_freq happens to exactly divide those chunk sizes. Consider tracking next_eval_step (or last_eval_step) and triggering when global_step >= next_eval_step instead of relying on modulo equality.
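
A threshold-based schedule along these lines might be sketched as follows (illustrative; the class and attribute names are hypothetical):

```python
class StepSchedule:
    """Fires whenever global_step crosses the next multiple-of-freq threshold,
    even when steps advance in large rollout-sized jumps."""

    def __init__(self, freq: int):
        self.freq = freq
        self.next_step = freq

    def should_fire(self, global_step: int) -> bool:
        if self.freq <= 0 or global_step < self.next_step:
            return False
        # Advance the threshold past the current step, preserving cadence.
        while self.next_step <= global_step:
            self.next_step += self.freq
        return True

# Hypothetical usage inside the training loop, after global_step is advanced:
#   if eval_schedule.should_fire(self.global_step):
#       self._eval_once(num_episodes=self.num_eval_episodes)
```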

Comment on lines +16 to 17
"eval_freq": 200,
"save_freq": 200,

Copilot AI commented Mar 9, 2026:

With the current Trainer logic, eval_freq/save_freq are interpreted in environment steps (global_step increments by buffer_size * num_envs each rollout). Values like 200 will almost never satisfy global_step % eval_freq == 0 (e.g., 64 envs * 1024 steps = 65536 per rollout). Either adjust the config to use step-based values that align with global_step, or update Trainer to a threshold-based schedule.

Comment on lines 35 to 43
class Policy(nn.Module, ABC):
    """Abstract base class that all RL policies must implement.

    A Policy:
    - Encapsulates neural networks that are trained by RL algorithms
    - Handles internal computations (e.g., network output → distribution)
    - Provides a uniform interface for algorithms (PPO, SAC, etc.)
    - Uses TensorDict for all inputs and outputs (no tensor fallback)
    """

Copilot AI commented Mar 9, 2026:

The docstring claims "no tensor fallback", but the implementations in this PR (e.g., ActorCritic/ActorOnly) accept both nested TensorDicts and plain tensors under tensordict["observation"]. Either enforce the TensorDict-only contract or adjust the base class docs to reflect the supported inputs.

Comment on lines +257 to +269
# Evaluation (pause collector during eval)
if (
    self.eval_freq > 0
    and self.eval_env is not None
    and self.global_step % self.eval_freq == 0
):
    collector.stop()
    self._eval_once(num_episodes=self.num_eval_episodes)
    collector.start()

# Checkpoint
if self.global_step % self.save_freq == 0:
    self.save_checkpoint()

Copilot AI commented Mar 9, 2026:

Same issue as the sync loop: global_step % eval_freq == 0 (and similarly for save_freq) is fragile when global_step increases by rollout-sized jumps. This can prevent evaluation/checkpointing from ever triggering in async mode. Prefer a threshold-based schedule (>= next_eval_step / >= next_save_step) or an explicit iteration counter.

Comment on lines +45 to +58
"algorithm": {
"name": "grpo",
"cfg": {
"learning_rate": 1e-5,
"n_epochs": 4,
"batch_size": 256,
"gamma": 0.99,
"clip_coef": 0.2,
"ent_coef": 0.001,
"kl_coef": 0.02,
"group_size": 4,
"eps": 1e-8,
"max_grad_norm": 1.0,
"truncate_at_first_done": true

Copilot AI commented Mar 9, 2026:

This config selects "algorithm": {"name": "grpo"}, but GRPO in embodichain/agents/rl/algo/grpo.py still uses the old collect_rollout() / update() (no-args) interface and the old policy.get_action() API. With the new Trainer/Policy TensorDict flow, this config will fail at runtime unless GRPO is migrated to update(rollout: TensorDict) (and collection moved to Collectors) or the algorithm name is changed to one that supports the new interface.

Comment on lines 51 to 60
@abstractmethod
def get_action(
    self, obs: torch.Tensor, deterministic: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample an action from the policy.
def forward(self, tensordict: TensorDict) -> TensorDict:
    """Forward pass that adds action to the input tensordict (in-place).

    This is the main inference method following TorchRL conventions.

    Args:
        obs: Observation tensor of shape (batch_size, obs_dim)
        deterministic: If True, return the mean action; otherwise sample
        tensordict: Input TensorDict containing at minimum:
            - "observation": Observation tensor or nested TensorDict


Copilot AI commented Mar 9, 2026:

The abstract Policy.forward() signature doesn't include the deterministic flag, but both Trainer._eval_once() and all concrete policies in this PR call/implement forward(..., deterministic=...). To avoid runtime/type-checking mismatches for future policies, consider adding deterministic: bool = False to the base class method signature (and updating the docstring accordingly).

Comment on lines +145 to +146
for module_name in import_modules:
    importlib.import_module(module_name)

Copilot AI commented Mar 9, 2026:

import_modules allows importing arbitrary module paths from a JSON config, which executes arbitrary Python code at runtime. If configs can come from untrusted sources (or are shared externally), this is a security risk; consider restricting to an allowlist/prefix (e.g., only dexechain.*/embodichain.*), or document clearly that configs must be trusted.
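
A prefix allowlist along the suggested lines could be sketched as follows (illustrative; the allowed prefixes are assumptions):

```python
import importlib

# Hypothetical allowlist restricting config-driven imports to project packages.
ALLOWED_PREFIXES = ("embodichain.", "dexechain.")

def safe_import_modules(import_modules: list[str]) -> None:
    for module_name in import_modules:
        if not module_name.startswith(ALLOWED_PREFIXES):
            raise ValueError(
                f"Refusing to import '{module_name}': not under an allowed prefix"
            )
        importlib.import_module(module_name)
```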

Comment on lines +124 to +136
### VLABuffer (Async)

Circular FIFO buffer:

```python
from embodichain.agents.rl.buffer import VLABuffer

buffer = VLABuffer(buffer_size=4096, device=device)
buffer.add(transition) # Single step
data = buffer.get(flatten=True) # [buffer_size, ...] when full
```

**Circular behavior**: `[T0,T1,T2,T3]` → add T4 → `[T4,T1,T2,T3]` (T0 overwritten)

Copilot AI commented Mar 9, 2026:

The VLABuffer section describes a circular FIFO buffer with buffer.add(transition) and overwrite semantics, but the implemented VLABuffer only accepts full rollouts via add_rollout(...) and stores a single rollout (no FIFO, no per-transition add). Please update the guide to match the actual API/behavior, or adjust the implementation to match the documented FIFO semantics.

Comment on lines +92 to +104
Args:
    rollout: TensorDict with batch_size=[T, N] from collect_rollout()
        OR [size] from VLA buffer

# Dense logging is handled in Trainer.on_step via info; no aggregation here
# Call callback for statistics and logging
if on_step_callback is not None:
    on_step_callback(current_obs, actions, reward, done, env_info, next_obs)
Returns:
    Dictionary of training metrics
"""
# Ensure 2D format [T, N] for GAE computation
if len(rollout.batch_size) == 1:
    rollout = rollout.unsqueeze(1)  # [size] -> [size, 1]

current_obs = next_obs
# GAE layout: use config (default False = [N, T] batch-first)
time_first = self.cfg.rollout_time_first

Copilot AI commented Mar 9, 2026:

The update() docstring/comments still reference rollouts shaped [T, N], but the collectors/buffers in this PR produce batch-first [N, T] by default (and rollout_time_first=False assumes that). Please update the docstring and the "Ensure 2D format" comment to reflect the actual expected layout, and clarify what 1D rollouts (flattened [N*T]) mean here.

Comment on lines +182 to +191
if self.model_type == "vla":
    collector = AsyncCollector(
        env=self.env,
        policy=self.policy,
        buffer=self.buffer,
        device=self.device,
        on_step_callback=self._create_step_callback(),
    )
    self._train_async(collector, total_timesteps)
else:

Copilot AI commented Mar 9, 2026:

There are existing tests for the standard training pipeline (tests/agents/test_rl.py), but the new async/VLA path (AsyncCollector + VLABuffer + model_type="vla") is untested. Adding at least a minimal smoke test that starts the async collector, waits for buffer.is_full(), runs a single PPO update, and shuts down cleanly would help catch threading/buffer regressions.
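
Such a smoke test might be sketched as below (pytest-style; the fixtures and exact constructor signatures are assumptions based on the guide above):

```python
import time

from embodichain.agents.rl.buffer import VLABuffer
from embodichain.agents.rl.collector import AsyncCollector

def test_async_vla_smoke(make_env, make_policy, make_ppo):  # hypothetical fixtures
    env, policy = make_env(num_envs=2), make_policy()
    algo = make_ppo(policy)
    buffer = VLABuffer(buffer_size=8, device="cpu")
    collector = AsyncCollector(env, policy, buffer, "cpu", None)
    collector.start()
    try:
        deadline = time.time() + 30.0
        while not buffer.is_full():
            assert time.time() < deadline, "collector never filled the buffer"
            time.sleep(0.05)
        data = buffer.get(flatten=True)  # see the earlier note on flatten semantics
        losses = algo.update(data)
        assert all(v == v for v in losses.values())  # NaN check
    finally:
        collector.stop()
```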
