diff --git a/configs/agents/rl/push_cube/train_config.json b/configs/agents/rl/push_cube/train_config.json index ae5026b2..b7639c06 100644 --- a/configs/agents/rl/push_cube/train_config.json +++ b/configs/agents/rl/push_cube/train_config.json @@ -9,14 +9,15 @@ "gpu_id": 0, "num_envs": 64, "iterations": 1000, - "rollout_steps": 1024, + "buffer_size": 1024, "enable_eval": true, "num_eval_envs": 16, "num_eval_episodes": 3, - "eval_freq": 2, + "eval_freq": 200, "save_freq": 200, "use_wandb": false, "wandb_project_name": "embodychain-push_cube", + "model_type": "standard", "events": { "eval": { "record_camera": { @@ -38,6 +39,7 @@ }, "policy": { "name": "actor_critic", + "action_dim": 8, "actor": { "type": "mlp", "network_cfg": { diff --git a/configs/agents/rl/stack_bowls/train_config_grpo.json b/configs/agents/rl/stack_bowls/train_config_grpo.json new file mode 100644 index 00000000..78e9121f --- /dev/null +++ b/configs/agents/rl/stack_bowls/train_config_grpo.json @@ -0,0 +1,61 @@ +{ + "trainer": { + "exp_name": "stack_bowls_vla_grpo", + "gym_config": "/root/workspace/research/embodichain/configs/gym/stack_bowls/gym_config.json", + "seed": 42, + "device": "cuda:0", + "headless": true, + "enable_rt": false, + "gpu_id": 0, + "num_envs": 8, + "iterations": 1000, + "buffer_size": 64, + "enable_eval": false, + "eval_freq": 0, + "save_freq": 100, + "use_wandb": false, + "wandb_project_name": "embodychain-stack_bowls", + "import_modules": [ + "dexechain.lab.gym.envs.tasks.tableware.stack_bowls_v1.stack_bowls" + ], + "model_type": "standard" + }, + "policy": { + "name": "vla", + "action_dim": 14, + "vla_config": { + "model_path": "/root/workspace/output/stack_bowls/checkpoint-19000", + "instruction": "Stack the bowls.", + "inference_horizon": 32, + "action_std_init": 0.01, + "robot_type": "CobotMagic", + "gripper_open_value": 0.05, + "gripper_closed_value": 0.0, + "action_key_order": [ + "left_armdelta_qpos", + "left_eefgripper", + "right_armdelta_qpos", + "right_eefgripper" + 
], + "model_config": { + "torch_dtype": "float32" + } + } + }, + "algorithm": { + "name": "grpo", + "cfg": { + "learning_rate": 1e-5, + "n_epochs": 4, + "batch_size": 256, + "gamma": 0.99, + "clip_coef": 0.2, + "ent_coef": 0.001, + "kl_coef": 0.02, + "group_size": 4, + "eps": 1e-8, + "max_grad_norm": 1.0, + "truncate_at_first_done": true + } + } +} diff --git a/configs/agents/rl/vla_example/train_config.json b/configs/agents/rl/vla_example/train_config.json new file mode 100644 index 00000000..87583f38 --- /dev/null +++ b/configs/agents/rl/vla_example/train_config.json @@ -0,0 +1,70 @@ +{ + "trainer": { + "exp_name": "vla_fine_tuning_ppo", + "gym_config": "configs/agents/rl/push_cube/gym_config.json", + "seed": 42, + "device": "cuda:0", + "headless": true, + "enable_rt": false, + "gpu_id": 0, + "num_envs": 32, + "iterations": 500, + "buffer_size": 2048, + "enable_eval": true, + "num_eval_envs": 8, + "num_eval_episodes": 3, + "eval_freq": 100, + "save_freq": 100, + "use_wandb": true, + "wandb_project_name": "embodychain-vla-training", + "model_type": "vla", + "events": { + "eval": { + "record_camera": { + "func": "record_camera_data_async", + "mode": "interval", + "interval_step": 1, + "params": { + "name": "main_cam", + "resolution": [640, 480], + "eye": [-1.4, 1.4, 2.0], + "target": [0, 0, 0], + "up": [0, 0, 1], + "intrinsics": [600, 600, 320, 240], + "save_path": "./outputs/videos/vla_eval" + } + } + } + } + }, + "policy": { + "name": "vla", + "action_dim": 7, + "vla_config": { + "model_path": "checkpoints/pretrained_vla_model.pth", + "model_class": "vla_models.GPTVLAModel", + "model_config": { + "vision_encoder": "resnet50", + "language_model": "gpt2-medium", + "action_head_hidden_size": 512, + "freeze_vision_encoder": false, + "freeze_language_model": false + } + } + }, + "algorithm": { + "name": "ppo", + "cfg": { + "learning_rate": 1e-5, + "n_epochs": 4, + "batch_size": 2048, + "gamma": 0.99, + "gae_lambda": 0.95, + "rollout_time_first": false, + "clip_coef": 
0.2, + "ent_coef": 0.001, + "vf_coef": 0.5, + "max_grad_norm": 1.0 + } + } +} diff --git a/docs/rl_training_guide.md b/docs/rl_training_guide.md new file mode 100644 index 00000000..3db3d072 --- /dev/null +++ b/docs/rl_training_guide.md @@ -0,0 +1,292 @@ +# RL Training Framework Guide + +TensorDict-based RL framework supporting standard PPO and asynchronous VLA training. + +--- + +## Quick Start + +### Configuration + +```json +{ + "trainer": { + "buffer_size": 2048, + "model_type": "standard" // or "vla" + }, + "policy": {"name": "actor_critic"}, + "algorithm": { + "name": "ppo", + "cfg": { + "learning_rate": 3e-4, + "gamma": 0.99, + "n_epochs": 10, + "batch_size": 64 + } + } +} +``` + +### Run Training + +```bash +python embodichain/agents/rl/train.py --config configs/agents/rl/my_config.json +``` + +--- + +## Architecture + +``` +Trainer → Collector (sync/async) → Buffer (standard/vla) → Algorithm (PPO) +``` + +**Components**: +- **Collector**: Gather data from environment (SyncCollector / AsyncCollector) +- **Buffer**: Store transitions (RolloutBuffer / VLABuffer) +- **Algorithm**: Update policy (PPO) +- **Trainer**: Coordinate training loop + +--- + +## Training Modes + +### Standard Mode (Default) + +**For**: Normal models (<100ms inference/step) + +``` +SyncCollector → Collect 2048 steps → Train → Clear buffer → Repeat +``` + +**Config**: `{"trainer": {"model_type": "standard"}}` + +**Pros**: Simple, stable, low memory, no staleness + +### VLA Async Mode + +**For**: Large models (>1 sec inference/step) + +``` +Background: AsyncCollector → Continuously collect → VLABuffer +Main: Wait for buffer full → Train → Repeat +``` + +**Config**: `{"trainer": {"model_type": "vla"}}` + +**Pros**: 2-3x speedup via parallel collection +**Cons**: Data staleness, higher memory + +--- + +## Collectors + +### SyncCollector + +Collects complete rollout synchronously: + +```python +from embodichain.agents.rl.collector import SyncCollector + +collector = SyncCollector(env, 
policy, device, callback) +rollout = collector.collect(num_steps=2048) # [T, N, ...] +``` + +### AsyncCollector + +Runs in background thread: + +```python +from embodichain.agents.rl.collector import AsyncCollector + +collector = AsyncCollector(env, policy, buffer, device, callback) +collector.start() # Begin background collection +# ... buffer fills automatically ... +collector.stop() # Stop collection +``` + +--- + +## Buffers + +### RolloutBuffer (Standard) + +Single-use buffer: + +```python +from embodichain.agents.rl.buffer import RolloutBuffer + +buffer = RolloutBuffer(buffer_size=2048, device=device) +buffer.add(rollout) # [T, N, ...] +data = buffer.get(flatten=True) # [T*N, ...], auto-clears +``` + +### VLABuffer (Async) + +Circular FIFO buffer: + +```python +from embodichain.agents.rl.buffer import VLABuffer + +buffer = VLABuffer(buffer_size=4096, device=device) +buffer.add(transition) # Single step +data = buffer.get(flatten=True) # [buffer_size, ...] when full +``` + +**Circular behavior**: `[T0,T1,T2,T3]` → add T4 → `[T4,T1,T2,T3]` (T0 overwritten) + +--- + +## VLA Integration + +### 1. Implement Model + +```python +class MyVLAModel(nn.Module): + def forward(self, obs: TensorDict) -> TensorDict: + # Add 'action', 'sample_log_prob', 'value' + ... + def get_value(self, obs: TensorDict) -> TensorDict: + # Add 'value' + ... + def evaluate_actions(self, obs: TensorDict) -> TensorDict: + # Add 'sample_log_prob', 'entropy', 'value' + ... +``` + +### 2. Implement Loading + +Edit `embodichain/agents/rl/models/vla_policy.py`: + +```python +def load_vla_model(model_path, model_class, model_config, device): + model = MyVLAModel(**model_config) + model.load_state_dict(torch.load(model_path)) + return model.to(device) +``` + +### 3. 
Configure + +```json +{ + "trainer": {"model_type": "vla"}, + "policy": { + "name": "vla", + "vla_config": { + "model_path": "checkpoints/vla.pt", + "model_class": "MyVLAModel", + "model_config": {} + } + } +} +``` + +--- + +## Common APIs + +### Trainer + +```python +from embodichain.agents.rl.utils import Trainer + +trainer = Trainer( + policy, env, algorithm, + buffer_size=2048, + model_type="standard", # or "vla" + ... +) +trainer.train(total_timesteps=1000000) +``` + +### Buffer Methods + +```python +buffer.add(data) # Add data +data = buffer.get(flatten=True) # Retrieve data +buffer.is_full() # Check ready status +buffer.clear() # Clear buffer +buffer.get_stats() # Statistics +``` + +### Algorithm + +```python +from embodichain.agents.rl.algo import PPO, PPOCfg + +algorithm = PPO(PPOCfg(...), policy) +losses = algorithm.update(rollout) # Returns loss dict +``` + +--- + +## FAQ + +**Q: When use VLA mode?** +A: Inference >100ms/step AND GPU training fast + +**Q: Buffer size?** +A: Standard: 2048-4096 (rollout size). VLA: 2048-4096 (buffer capacity) + +**Q: Data staleness impact?** +A: Minor. PPO robust to staleness. 
2-3x speedup >> small penalty + +**Q: Debug data flow?** +A: `buffer.get_stats()` or `_print_tensordict_tree(rollout)` in ppo.py + +--- + +## Workflows + +### Standard + +```python +collector = SyncCollector(env, policy, device, callback) +while step < total: + rollout = collector.collect(num_steps=2048) + buffer.add(rollout) + data = buffer.get(flatten=True) + losses = algorithm.update(data) +``` + +### VLA + +```python +collector = AsyncCollector(env, policy, buffer, device, callback) +collector.start() +while step < total: + while not buffer.is_full(): + time.sleep(0.1) + data = buffer.get(flatten=True) + losses = algorithm.update(data) +collector.stop() +``` + +--- + +## File Structure + +``` +embodichain/agents/rl/ +├── train.py # Entry point +├── algo/ppo.py # PPO algorithm +├── buffer/ +│ ├── standard_buffer.py # RolloutBuffer +│ └── vla_buffer.py # VLABuffer +├── collector/ +│ ├── base.py # BaseCollector +│ ├── sync_collector.py # SyncCollector +│ └── async_collector.py # AsyncCollector +├── models/ +│ ├── actor_critic.py # Standard policy +│ └── vla_policy.py # VLA wrapper +└── utils/trainer.py # Training coordinator +``` + +--- + +## References + +- [TensorDict Docs](https://pytorch.org/tensordict/) +- [PPO Paper](https://arxiv.org/abs/1707.06347) +- Example configs: `configs/agents/rl/` diff --git a/docs/source/tutorial/rl.rst b/docs/source/tutorial/rl.rst index b0126dde..24b979e3 100644 --- a/docs/source/tutorial/rl.rst +++ b/docs/source/tutorial/rl.rst @@ -64,7 +64,7 @@ The ``runtime`` section controls experiment setup: - **cuda**: Whether to use GPU (default: true) - **headless**: Whether to run simulation in headless mode - **iterations**: Number of training iterations -- **rollout_steps**: Steps per rollout (e.g., 1024) +- **buffer_size**: Steps per rollout (e.g., 1024) - **eval_freq**: Frequency of evaluation (in steps) - **save_freq**: Frequency of checkpoint saving (in steps) - **use_wandb**: Whether to enable Weights & Biases logging (set in 
JSON config) diff --git a/embodichain/agents/rl/ARCHITECTURE.md b/embodichain/agents/rl/ARCHITECTURE.md new file mode 100644 index 00000000..c83e2ff1 --- /dev/null +++ b/embodichain/agents/rl/ARCHITECTURE.md @@ -0,0 +1,216 @@ +# RL训练框架架构 + +## 总体流程 + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Trainer │ +│ (训练总协调者) │ +│ │ +│ ┌─────────────────┐ ┌──────────────────┐ │ +│ │ 初始化阶段 │ │ 训练循环 │ │ +│ │ │ │ │ │ +│ │ 1. 创建Policy │───────▶│ while epoch: │ │ +│ │ 2. 创建Algo │ │ ├─ 收集数据 │ │ +│ │ 3. 创建Collector│ │ ├─ 更新策略 │ │ +│ │ 4. 创建Env │ │ └─ 评估性能 │ │ +│ └─────────────────┘ └──────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────┘ + │ + ┌──────────────┼──────────────┐ + │ │ │ + ▼ ▼ ▼ + ┌──────────┐ ┌──────────┐ ┌──────────┐ + │ Collector│ │Algorithm │ │ Policy │ + └──────────┘ └──────────┘ └──────────┘ +``` + +## 核心组件 + +### 1. Trainer(训练器) +**职责**:总协调者,串联所有组件 +``` +训练循环: + for epoch in range(n_epochs): + ├─ rollout = collector.collect(n_steps) # 收集数据 + ├─ metrics = algorithm.update(rollout) # 更新策略 + └─ eval_reward = evaluate(policy) # 评估性能 +``` + +### 2. Collector(数据收集器) +**职责**:与环境交互,收集经验数据 + +``` +┌─────────────────────────────────────────────┐ +│ Collector 类型 │ +├─────────────────────────────────────────────┤ +│ │ +│ ┌──────────────────┐ ┌─────────────────┐ │ +│ │ SyncCollector │ │ AsyncCollector │ │ +│ │ (同步收集) │ │ (异步收集) │ │ +│ │ │ │ │ │ +│ │ 用于标准RL算法 │ │ 用于VLA模型 │ │ +│ │ - PPO │ │ - 后台持续收集 │ │ +│ │ - SAC │ │ - 独立线程 │ │ +│ └──────────────────┘ └─────────────────┘ │ +│ │ +└─────────────────────────────────────────────┘ + +工作流程: + obs = env.reset() + for step in range(n_steps): + ├─ policy.forward(obs, deterministic=False) # 采样动作 + ├─ next_obs, reward, done = env.step(action) + └─ 存储到 TensorDict: (obs, action, reward, done, value) + return rollout_tensordict # [T, N] 格式 +``` + +### 3. 
Algorithm(算法) +**职责**:策略更新逻辑 + +``` +┌─────────────────────────────────────────────┐ +│ Algorithm 类型 │ +├─────────────────────────────────────────────┤ +│ │ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ PPO │ │ SAC │ ... │ +│ │ │ │ │ │ +│ │ - GAE计算 │ │ - Q学习 │ │ +│ │ - Clip损失 │ │ - Soft更新 │ │ +│ │ - 价值损失 │ │ - 熵正则化 │ │ +│ └──────────────┘ └──────────────┘ │ +│ │ +└─────────────────────────────────────────────┘ + +工作流程: + def update(rollout: TensorDict) -> dict: + ├─ 计算优势函数 (GAE) + ├─ 多轮优化循环 + │ ├─ policy.evaluate_actions(batch) # 重新计算log_prob + │ ├─ 计算loss (clip + value + entropy) + │ └─ optimizer.step() + └─ return metrics +``` + +### 4. Policy(策略) +**职责**:神经网络,输出动作和价值 + +``` +┌─────────────────────────────────────────────┐ +│ Policy 类型 │ +├─────────────────────────────────────────────┤ +│ │ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ ActorCritic │ │ VLAPolicy │ │ +│ │ │ │ │ │ +│ │ - MLP网络 │ │ - 视觉语言 │ │ +│ │ - 高斯策略 │ │ - 预训练模型 │ │ +│ └──────────────┘ └──────────────┘ │ +│ │ +└─────────────────────────────────────────────┘ + +接口方法: + 1. forward(obs, deterministic=False) + ├─ 训练时:采样动作 (deterministic=False) + ├─ 评估时:确定性动作 (deterministic=True) + └─ 返回:action, log_prob, value + + 2. evaluate_actions(obs, action) + └─ 重新计算给定动作的log_prob和entropy + + 3. get_value(obs) + └─ 仅返回价值估计 +``` + +## 数据流动(TensorDict) + +``` +Environment ──▶ Collector ──▶ Algorithm ──▶ Policy + │ │ │ │ + │ TensorDict TensorDict Parameters + │ [T, N] [batch] Update + │ │ │ │ + └───────────────┴──────────────┴────────────┘ + +TensorDict 结构: +{ + "observation": Tensor or nested TensorDict, + "action": Tensor[T, N, action_dim], + "reward": Tensor[T, N, 1], + "done": Tensor[T, N, 1], + "value": Tensor[T, N, 1], + "sample_log_prob": Tensor[T, N, 1], + "advantage": Tensor[T, N, 1], # GAE计算后添加 + "return": Tensor[T, N, 1], # GAE计算后添加 +} +``` + +## 完整训练流程示例 + +```python +# 1. 初始化组件 +trainer = Trainer( + env=env, + policy=ActorCritic(...), + algorithm=PPO(...), +) + +# 2. 
创建Collector +collector = SyncCollector( + env=env, + policy=policy, + device=device, +) + +# 3. 训练循环 +for epoch in range(n_epochs): + + # 3.1 收集数据 + rollout = collector.collect( + n_steps=2048, + reset=True, + ) + # rollout: TensorDict[T=2048, N=num_envs] + + # 3.2 更新策略 + metrics = algorithm.update(rollout) + # metrics: {"loss": ..., "clip_frac": ..., ...} + + # 3.3 评估性能 + eval_reward = trainer.evaluate( + n_episodes=10, + deterministic=True, # 评估时使用确定性动作 + ) + + # 3.4 日志记录 + print(f"Epoch {epoch}: reward={eval_reward}, loss={metrics['loss']}") +``` + +## 关键设计原则 + +### 1. 职责分离 +- **Trainer**: 协调者,不涉及具体实现 +- **Collector**: 只负责数据收集,不做策略更新 +- **Algorithm**: 只负责策略更新,不做数据收集 +- **Policy**: 只负责网络前向,不涉及训练逻辑 + +### 2. 统一接口 +- 所有组件使用 **TensorDict** 进行数据传递 +- Policy暴露统一接口:`forward()`, `evaluate_actions()`, `get_value()` +- 易于切换不同实现(ActorCritic ↔ VLAPolicy) + +### 3. 灵活扩展 +- 添加新算法:继承 `BaseAlgorithm`,实现 `update()` +- 添加新策略:继承 `Policy`,实现三个抽象方法 +- 添加新收集器:继承 `BaseCollector`,实现 `collect()` + +### 4. 确定性评估 +```python +# 训练时(随机采样,探索) +policy.forward(obs, deterministic=False) # 使用 dist.sample() + +# 评估时(确定性,稳定) +policy.forward(obs, deterministic=True) # 使用 dist.mean +``` diff --git a/embodichain/agents/rl/algo/base.py b/embodichain/agents/rl/algo/base.py index fcb3fc00..914c1dba 100644 --- a/embodichain/agents/rl/algo/base.py +++ b/embodichain/agents/rl/algo/base.py @@ -18,35 +18,27 @@ from typing import Dict, Any, Callable import torch +from tensordict import TensorDict class BaseAlgorithm: - """Base class for RL algorithms. + """Base class for RL algorithms following TorchRL conventions. - Algorithms must implement buffer initialization, rollout collection, and - policy update. Trainer depends only on this interface to remain - algorithm-agnostic. + Algorithms implement policy updates using TensorDict. + Data collection is handled separately by Collector classes (SyncCollector/AsyncCollector). 
""" device: torch.device - def initialize_buffer( - self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int - ) -> None: - """Initialize internal buffer(s) required by the algorithm.""" - raise NotImplementedError + def update(self, rollout: TensorDict) -> Dict[str, float]: + """Update policy using collected rollout data. - def collect_rollout( - self, - env, - policy, - obs: torch.Tensor, - num_steps: int, - on_step_callback: Callable | None = None, - ) -> Dict[str, Any]: - """Collect trajectories and return logging info (e.g., reward components).""" - raise NotImplementedError + Args: + rollout: TensorDict containing collected rollout data from Collector + Expected batch_size format: [T, N] for on-policy algorithms + where T is trajectory length and N is number of environments - def update(self) -> Dict[str, float]: - """Update policy using collected data and return training losses.""" + Returns: + Dictionary of training metrics (losses, learning stats, etc.) + """ raise NotImplementedError diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index b1256ce0..b2ef6349 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -14,17 +14,56 @@ # limitations under the License. 
# ---------------------------------------------------------------------------- -import torch -from typing import Dict, Any, Tuple, Callable +from __future__ import annotations +import torch from tensordict import TensorDict -from embodichain.agents.rl.utils import AlgorithmCfg, flatten_dict_observation -from embodichain.agents.rl.buffer import RolloutBuffer +from embodichain.agents.rl.utils import AlgorithmCfg, compute_gae from embodichain.utils import configclass from .base import BaseAlgorithm +def _print_tensordict_tree(td, prefix="", is_last=True, name="TensorDict"): + """Recursively print TensorDict structure in tree format.""" + connector = "└── " if is_last else "├── " + + # Print current node + batch_info = ( + f"batch_size={list(td.batch_size)}" if hasattr(td, "batch_size") else "" + ) + device_info = f"device={td.device}" if hasattr(td, "device") else "" + meta_info = ", ".join(filter(None, [batch_info, device_info])) + print(f"{prefix}{connector}{name}: TensorDict ({meta_info})") + + # Prepare prefix for children + extension = " " if is_last else "│ " + new_prefix = prefix + extension + + # Get all keys + keys = sorted(td.keys()) if hasattr(td, "keys") else [] + + for i, key in enumerate(keys): + is_last_child = i == len(keys) - 1 + value = td[key] + + if isinstance(value, TensorDict): + # Recursively print nested TensorDict + _print_tensordict_tree(value, new_prefix, is_last_child, name=key) + elif isinstance(value, torch.Tensor): + # Print tensor info + child_connector = "└── " if is_last_child else "├── " + shape_str = "x".join(map(str, value.shape)) + dtype_str = str(value.dtype).replace("torch.", "") + print( + f"{new_prefix}{child_connector}{key}: Tensor([{shape_str}], {dtype_str})" + ) + else: + # Print other types + child_connector = "└── " if is_last_child else "├── " + print(f"{new_prefix}{child_connector}{key}: {type(value).__name__}") + + @configclass class PPOCfg(AlgorithmCfg): """Configuration for the PPO algorithm.""" @@ -33,132 +72,106 @@ 
class PPOCfg(AlgorithmCfg): clip_coef: float = 0.2 ent_coef: float = 0.01 vf_coef: float = 0.5 + rollout_time_first: bool = False class PPO(BaseAlgorithm): - """PPO algorithm operating via Policy and RolloutBuffer (algo-agnostic design).""" + """PPO algorithm using TensorDict for all data flow. + Data collection is handled by Collector classes (SyncCollector/AsyncCollector). + """ def __init__(self, cfg: PPOCfg, policy): self.cfg = cfg self.policy = policy self.device = torch.device(cfg.device) self.optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.learning_rate) - self.buffer: RolloutBuffer | None = None - # no per-rollout aggregation for dense logging - - def _compute_gae( - self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Internal method to compute GAE. Only called by collect_rollout.""" - T, N = rewards.shape - advantages = torch.zeros_like(rewards, device=self.device) - last_adv = torch.zeros(N, device=self.device) - for t in reversed(range(T)): - next_value = values[t + 1] if t < T - 1 else torch.zeros_like(values[0]) - not_done = (~dones[t]).float() - delta = rewards[t] + self.cfg.gamma * next_value * not_done - values[t] - last_adv = ( - delta + self.cfg.gamma * self.cfg.gae_lambda * not_done * last_adv - ) - advantages[t] = last_adv - returns = advantages + values - return advantages, returns - - def initialize_buffer( - self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int - ): - """Initialize the rollout buffer. Called by trainer before first rollout.""" - self.buffer = RolloutBuffer( - num_steps, num_envs, obs_dim, action_dim, self.device - ) - - def collect_rollout( - self, - env, - policy, - obs: torch.Tensor, - num_steps: int, - on_step_callback: Callable | None = None, - ) -> Dict[str, Any]: - """Collect a rollout. Algorithm controls the data collection process.""" - if self.buffer is None: - raise RuntimeError( - "Buffer not initialized. 
Call initialize_buffer() first." - ) - - policy.train() - self.buffer.step = 0 - current_obs = obs - for t in range(num_steps): - # Get action from policy - actions, log_prob, value = policy.get_action( - current_obs, deterministic=False - ) - - # Wrap action as dict for env processing - am = getattr(env, "action_manager", None) - action_type = ( - am.action_type if am else getattr(env, "action_type", "delta_qpos") - ) - action_dict = {action_type: actions} - - # Step environment - result = env.step(action_dict) - next_obs, reward, terminated, truncated, env_info = result - done = terminated | truncated - # Light dtype normalization - reward = reward.float() - done = done.bool() - - # Flatten TensorDict observation from ObservationManager if needed - if isinstance(next_obs, TensorDict): - next_obs = flatten_dict_observation(next_obs) + def update(self, rollout: TensorDict) -> dict: + """Update the policy using collected rollout TensorDict (TorchRL style). - # Add to buffer - self.buffer.add(current_obs, actions, reward, done, value, log_prob) + Args: + rollout: TensorDict with batch_size=[T, N] from collect_rollout() + OR [size] from VLA buffer - # Dense logging is handled in Trainer.on_step via info; no aggregation here - # Call callback for statistics and logging - if on_step_callback is not None: - on_step_callback(current_obs, actions, reward, done, env_info, next_obs) + Returns: + Dictionary of training metrics + """ + # Ensure 2D format [T, N] for GAE computation + if len(rollout.batch_size) == 1: + rollout = rollout.unsqueeze(1) # [size] -> [size, 1] - current_obs = next_obs + # GAE layout: use config (default False = [N, T] batch-first) + time_first = self.cfg.rollout_time_first - # Compute advantages/returns and attach to buffer extras - adv, ret = self._compute_gae( - self.buffer.rewards, self.buffer.values, self.buffer.dones + rollout = compute_gae( + rollout, + gamma=self.cfg.gamma, + gae_lambda=self.cfg.gae_lambda, + time_first=time_first, ) - 
self.buffer.set_extras({"advantages": adv, "returns": ret}) - # No aggregated logging results; Trainer performs dense per-step logging - return {} + # Flatten to [T*N, ...] for training + flat_data = rollout.reshape(-1) + total_samples = flat_data.batch_size[0] - def update(self) -> dict: - """Update the policy using the collected rollout buffer.""" - if self.buffer is None: - raise RuntimeError("Buffer not initialized. Call collect_rollout() first.") - - # Normalize advantages (optional, common default) - adv = self.buffer._extras.get("advantages") - adv = (adv - adv.mean()) / (adv.std() + 1e-8) + # Normalize advantages globally + advantages = flat_data["advantage"] + advantages_normalized = (advantages - advantages.mean()) / ( + advantages.std() + 1e-8 + ) + flat_data["advantage"] = advantages_normalized total_actor_loss = 0.0 total_value_loss = 0.0 total_entropy = 0.0 total_steps = 0 - - for _ in range(self.cfg.n_epochs): - for batch in self.buffer.iterate_minibatches(self.cfg.batch_size): - obs = batch["obs"] - actions = batch["actions"] - old_logprobs = batch["logprobs"] - returns = batch["returns"] - advantages = ( - (batch["advantages"] - adv.mean()) / (adv.std() + 1e-8) - ).detach() - - logprobs, entropy, values = self.policy.evaluate_actions(obs, actions) + total_clip_fraction = 0.0 + total_approx_kl = 0.0 + + for epoch in range(self.cfg.n_epochs): + # Shuffle data each epoch + indices = torch.randperm(total_samples, device=self.device) + shuffled_data = flat_data[indices] + + # Iterate over minibatches + num_minibatches = ( + total_samples + self.cfg.batch_size - 1 + ) // self.cfg.batch_size + for i in range(num_minibatches): + start_idx = i * self.cfg.batch_size + end_idx = min(start_idx + self.cfg.batch_size, total_samples) + batch_td = shuffled_data[start_idx:end_idx] + + # Extract data from TensorDict batch + old_logprobs = batch_td["sample_log_prob"] + returns = batch_td["value_target"] + advantages = batch_td[ + "advantage" + ] # Note: advantages 
are already normalized globally before shuffling + + # Evaluate actions with current policy + self.policy.evaluate_actions(batch_td) + + # Get updated values + logprobs = batch_td["sample_log_prob"] + entropy = batch_td["entropy"] + values = batch_td["value"] + + # Ensure shapes match (squeeze if needed) + if old_logprobs.dim() > 1: + old_logprobs = old_logprobs.squeeze(-1) + if logprobs.dim() > 1: + logprobs = logprobs.squeeze(-1) + if values.dim() > 1: + values = values.squeeze(-1) + if returns.dim() > 1: + returns = returns.squeeze(-1) + if advantages.dim() > 1: + advantages = advantages.squeeze(-1) + if entropy.dim() > 1: + entropy = entropy.squeeze(-1) + + # PPO loss computation ratio = (logprobs - old_logprobs).exp() surr1 = ratio * advantages surr2 = ( @@ -171,6 +184,13 @@ def update(self) -> dict: value_loss = torch.nn.functional.mse_loss(values, returns) entropy_loss = -entropy.mean() + # Diagnostics + with torch.no_grad(): + clip_fraction = ( + ((ratio - 1.0).abs() > self.cfg.clip_coef).float().mean() + ) + approx_kl = ((ratio - 1.0) - (logprobs - old_logprobs)).mean() + loss = ( actor_loss + self.cfg.vf_coef * value_loss @@ -184,14 +204,18 @@ def update(self) -> dict: ) self.optimizer.step() - bs = obs.shape[0] + bs = batch_td.batch_size[0] total_actor_loss += actor_loss.item() * bs total_value_loss += value_loss.item() * bs total_entropy += (-entropy_loss.item()) * bs + total_clip_fraction += clip_fraction.item() * bs + total_approx_kl += approx_kl.item() * bs total_steps += bs return { "actor_loss": total_actor_loss / max(1, total_steps), "value_loss": total_value_loss / max(1, total_steps), "entropy": total_entropy / max(1, total_steps), + "clip_fraction": total_clip_fraction / max(1, total_steps), + "approx_kl": total_approx_kl / max(1, total_steps), } diff --git a/embodichain/agents/rl/buffer/__init__.py b/embodichain/agents/rl/buffer/__init__.py index 5080d251..bceb1fde 100644 --- a/embodichain/agents/rl/buffer/__init__.py +++ 
b/embodichain/agents/rl/buffer/__init__.py @@ -14,6 +14,15 @@ # limitations under the License. # ---------------------------------------------------------------------------- -from .rollout_buffer import RolloutBuffer +""" +Buffer module for RL training. -__all__ = ["RolloutBuffer"] +Provides two buffer implementations: +- RolloutBuffer: Standard PPO buffer (single rollout, use and discard) +- VLABuffer: VLA buffer (FIFO multi-rollout accumulation for slow inference) +""" + +from .vla_buffer import VLABuffer +from .standard_buffer import RolloutBuffer + +__all__ = ["RolloutBuffer", "VLABuffer"] diff --git a/embodichain/agents/rl/buffer/rollout_buffer.py b/embodichain/agents/rl/buffer/rollout_buffer.py deleted file mode 100644 index cbd66f4e..00000000 --- a/embodichain/agents/rl/buffer/rollout_buffer.py +++ /dev/null @@ -1,106 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -from __future__ import annotations - -from typing import Dict, Iterator - -import torch - - -class RolloutBuffer: - """On-device rollout buffer for on-policy algorithms. - - Stores (obs, actions, rewards, dones, values, logprobs) over time. - After finalize(), exposes advantages/returns and minibatch iteration. 
- """ - - def __init__( - self, - num_steps: int, - num_envs: int, - obs_dim: int, - action_dim: int, - device: torch.device, - ): - self.num_steps = num_steps - self.num_envs = num_envs - self.obs_dim = obs_dim - self.action_dim = action_dim - self.device = device - - T, N = num_steps, num_envs - self.obs = torch.zeros(T, N, obs_dim, dtype=torch.float32, device=device) - self.actions = torch.zeros(T, N, action_dim, dtype=torch.float32, device=device) - self.rewards = torch.zeros(T, N, dtype=torch.float32, device=device) - self.dones = torch.zeros(T, N, dtype=torch.bool, device=device) - self.values = torch.zeros(T, N, dtype=torch.float32, device=device) - self.logprobs = torch.zeros(T, N, dtype=torch.float32, device=device) - - self.step = 0 - # Container for algorithm-specific extra fields (e.g., advantages, returns) - self._extras: dict[str, torch.Tensor] = {} - - def add( - self, - obs: torch.Tensor, - action: torch.Tensor, - reward: torch.Tensor, - done: torch.Tensor, - value: torch.Tensor, - logprob: torch.Tensor, - ) -> None: - t = self.step - self.obs[t].copy_(obs) - self.actions[t].copy_(action) - self.rewards[t].copy_(reward) - self.dones[t].copy_(done) - self.values[t].copy_(value) - self.logprobs[t].copy_(logprob) - self.step += 1 - - def set_extras(self, extras: dict[str, torch.Tensor]) -> None: - """Attach algorithm-specific tensors (shape [T, N, ...]) for batching. 
- - Examples: - {"advantages": adv, "returns": ret} - """ - self._extras = extras or {} - - def iterate_minibatches(self, batch_size: int) -> Iterator[Dict[str, torch.Tensor]]: - T, N = self.num_steps, self.num_envs - total = T * N - indices = torch.randperm(total, device=self.device) - for start in range(0, total, batch_size): - idx = indices[start : start + batch_size] - t_idx = idx // N - n_idx = idx % N - batch = { - "obs": self.obs[t_idx, n_idx], - "actions": self.actions[t_idx, n_idx], - "rewards": self.rewards[t_idx, n_idx], - "dones": self.dones[t_idx, n_idx], - "values": self.values[t_idx, n_idx], - "logprobs": self.logprobs[t_idx, n_idx], - } - # Slice extras if present and shape aligned to [T, N, ...] - for name, tensor in self._extras.items(): - try: - batch[name] = tensor[t_idx, n_idx] - except Exception: - # Skip misaligned extras silently - continue - yield batch diff --git a/embodichain/agents/rl/buffer/standard_buffer.py b/embodichain/agents/rl/buffer/standard_buffer.py new file mode 100644 index 00000000..a838e1b2 --- /dev/null +++ b/embodichain/agents/rl/buffer/standard_buffer.py @@ -0,0 +1,117 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class RolloutBuffer:
    """Standard on-policy rollout buffer for PPO (matches mainstream implementations).

    Unlike VLA buffer which accumulates multiple rollouts with FIFO eviction,
    this buffer follows the standard PPO pattern:
    - Stores exactly ONE rollout at a time
    - After training, the buffer is cleared (on-policy: use once and discard)
    - Simple and efficient for normal-sized models

    Interface compatible with VLABuffer for easy switching.
    """

    def __init__(self, buffer_size: int, device: torch.device):
        """Initialize standard rollout buffer.

        Args:
            buffer_size: Buffer size from config. Kept only for interface
                compatibility with VLABuffer; this buffer always holds at
                most one rollout regardless of this value.
            device: Device to store tensors on.
        """
        self.buffer_size = buffer_size
        self.device = device
        self._rollout: Optional[TensorDict] = None

    def add(self, rollout: TensorDict) -> None:
        """Add a rollout to the buffer, replacing any existing rollout.

        Args:
            rollout: TensorDict with batch_size=[N, T, ...] (batch-first,
                matching what SyncCollector.collect returns via
                torch.stack(..., dim=1)).
        """
        # Standard PPO: replace existing rollout (not accumulate).
        self._rollout = rollout.to(self.device)

    def get(self, flatten: bool = True) -> TensorDict:
        """Get the rollout and clear the buffer (standard PPO behavior).

        Args:
            flatten: If True, flatten to [N*T, ...] for minibatch sampling.
                If False, return as [N, T, ...] (batch-first).

        Returns:
            TensorDict with rollout data.

        Raises:
            ValueError: If the buffer is empty.
        """
        if self._rollout is None:
            raise ValueError("Buffer is empty")

        rollout = self._rollout
        # Clear after retrieval (on-policy: use once).
        self._rollout = None

        if flatten:
            # Flatten [N, T, ...] -> [N*T, ...]
            return rollout.reshape(-1)
        return rollout

    def clear(self) -> None:
        """Clear buffer."""
        self._rollout = None

    def is_full(self) -> bool:
        """Return True if the buffer contains a rollout ready for training."""
        return self._rollout is not None

    def __len__(self) -> int:
        """Return 1 if buffer has data, 0 otherwise."""
        return 1 if self._rollout is not None else 0

    def get_num_rollouts(self) -> int:
        """Return current number of rollouts in buffer (0 or 1)."""
        return 1 if self._rollout is not None else 0

    def get_num_transitions(self) -> int:
        """Return total number of transitions stored.

        Robust to any batch rank: multiplies all batch dimensions instead of
        assuming exactly two, so a pre-flattened [N*T] rollout also reports
        the correct count (previous version indexed batch_size[0] and
        batch_size[1] and raised IndexError on rank-1 batches).
        """
        if self._rollout is None:
            return 0
        total = 1
        for dim in self._rollout.batch_size:
            total *= int(dim)
        return total

    def get_stats(self) -> dict:
        """Get buffer statistics for logging.

        Returns:
            Dict with buffer stats (occupancy, capacity, transition count).
        """
        return {
            "buffer_size": 1 if self._rollout is not None else 0,
            "buffer_capacity": self.buffer_size,
            "total_transitions": self.get_num_transitions(),
            "buffer_usage": 1.0 if self._rollout is not None else 0.0,
        }
class VLABuffer:
    """Rollout buffer for VLA RL using a batch-first (N, T) layout.

    A complete rollout must be stored at once because GAE needs sequential
    timesteps from the same trajectory; the async collector gathers T steps
    per env and then deposits the whole rollout here.

    Characteristics:
    - Rollout-level storage: only full [N, T] rollouts are accepted
    - Batch-first layout: data is kept and returned as [N, T, ...]
    - Thread-safe: a background collector writes while the trainer reads
    - Single slot: at most one rollout is held at a time
    """

    def __init__(
        self,
        buffer_size: int,
        device: torch.device,
        num_envs: int,
    ):
        """Set up the buffer.

        Args:
            buffer_size: Total transitions per rollout (T * N).
            device: Device for tensors.
            num_envs: Number of parallel environments (N).

        Raises:
            ValueError: If buffer_size is not a multiple of num_envs.
        """
        self.buffer_size = buffer_size
        self.device = device
        self.num_envs = num_envs
        steps, remainder = divmod(buffer_size, num_envs)
        self.rollout_length = steps  # T
        if remainder:
            raise ValueError(
                f"buffer_size ({buffer_size}) must be divisible by num_envs ({num_envs})"
            )

        self._rollout: Optional[TensorDict] = None  # [N, T, ...]
        self._lock = threading.Lock()

    def add_rollout(self, rollout: TensorDict) -> None:
        """Store one complete rollout. Fixed layout: [N, T] (batch-first).

        GAE requires same-trajectory timesteps; only full rollouts are accepted.

        Args:
            rollout: TensorDict with batch_size=[N, T, ...]

        Raises:
            ValueError: If the leading batch dims differ from (N, T).
        """
        with self._lock:
            n_matches = rollout.batch_size[0] == self.num_envs
            t_matches = rollout.batch_size[1] == self.rollout_length
            if not (n_matches and t_matches):
                raise ValueError(
                    f"Rollout shape {rollout.batch_size} does not match "
                    f"expected (N={self.num_envs}, T={self.rollout_length})"
                )
            self._rollout = rollout.to(self.device)

    def add_batch(self, transitions: TensorDict) -> None:
        """Deprecated: Use add_rollout. Batch must be a complete rollout [N, T]."""
        if len(transitions.batch_size) < 2:
            raise NotImplementedError(
                "VLABuffer requires full rollout. Use add_rollout(rollout) with [N, T]."
            )
        self.add_rollout(transitions)

    def get(self, flatten: bool = True) -> TensorDict:
        """Retrieve and remove the stored rollout (thread-safe).

        Args:
            flatten: When True, reshape to [N*T, ...] for minibatch sampling.

        Returns:
            TensorDict with batch_size=[N, T], or [N*T] when flatten=True.

        Raises:
            ValueError: If no rollout is stored.
        """
        with self._lock:
            if self._rollout is None:
                raise ValueError("Buffer is empty")

            stored = self._rollout
            self._rollout = None
            return stored.reshape(-1) if flatten else stored

    def clear(self) -> None:
        """Drop any stored rollout."""
        with self._lock:
            self._rollout = None

    def __len__(self) -> int:
        """Return 1 if a rollout is stored, 0 otherwise."""
        with self._lock:
            return int(self._rollout is not None)

    def is_full(self) -> bool:
        """True when one complete rollout is ready."""
        with self._lock:
            return self._rollout is not None

    def get_num_rollouts(self) -> int:
        """Return 1 if a rollout is stored, 0 otherwise."""
        with self._lock:
            return int(self._rollout is not None)

    def get_stats(self) -> dict:
        """Return buffer statistics for logging."""
        with self._lock:
            return {
                "buffer_size": self.rollout_length * self.num_envs,
                "rollout_length": self.rollout_length,
                "num_envs": self.num_envs,
                "layout": "batch_first",
                "has_rollout": self._rollout is not None,
            }
b/embodichain/agents/rl/collector/__init__.py new file mode 100644 index 00000000..eede4937 --- /dev/null +++ b/embodichain/agents/rl/collector/__init__.py @@ -0,0 +1,26 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .base import BaseCollector +from .sync_collector import SyncCollector +from .async_collector import AsyncCollector, AsyncCollectorStats + +__all__ = [ + "BaseCollector", + "SyncCollector", + "AsyncCollector", + "AsyncCollectorStats", +] diff --git a/embodichain/agents/rl/collector/async_collector.py b/embodichain/agents/rl/collector/async_collector.py new file mode 100644 index 00000000..465accb9 --- /dev/null +++ b/embodichain/agents/rl/collector/async_collector.py @@ -0,0 +1,291 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class AsyncCollector(BaseCollector):
    """Asynchronous data collector for VLA RL scenarios.

    Runs in a background thread to continuously collect transitions while
    the main thread performs model updates. Designed for scenarios where
    model inference is slow (e.g., VLA) but training is fast.

    Key features:
    - Background thread: Continuous data collection
    - Thread-safe buffer: Lock-protected writes
    - Rollout-level collection: T steps per env are accumulated locally,
      then added to the buffer as one complete rollout (GAE needs
      sequential same-trajectory timesteps; see _collect_loop)
    - Episode statistics tracking: Rewards and lengths

    Usage:
        collector = AsyncCollector(env, policy, buffer, device, ...)
        collector.start()  # Begin background collection
        # ... main thread does training ...
        collector.stop()  # Stop collection
    """

    def __init__(
        self,
        env,
        policy,
        buffer,
        device: torch.device,
        on_step_callback: Optional[Callable] = None,
    ):
        """Initialize async collector.

        Args:
            env: Environment to collect from
            policy: Policy for action selection
            buffer: VLABuffer instance (shared with Trainer)
            device: Device for tensor operations
            on_step_callback: Optional callback(transition, env_info) called after each step
        """
        super().__init__(env, policy, device, on_step_callback)
        self.buffer = buffer

        # Thread control
        self._running = False
        self._thread: Optional[threading.Thread] = None
        self._lock = threading.Lock()  # guards the counters below

        # Episode statistics (written by worker thread, read via get_stats)
        self._episode_count = 0
        self._step_count = 0

    def start(self):
        """Start background collection thread.

        Raises:
            RuntimeError: If the collector is already running.
        """
        if self._running:
            raise RuntimeError("Collector is already running")

        self._running = True
        # Daemon thread: does not block interpreter shutdown.
        self._thread = threading.Thread(target=self._collect_loop, daemon=True)
        self._thread.start()
        print("[AsyncCollector] Background collection started")

    def collect(self, **kwargs) -> TensorDict:
        """For AsyncCollector, data is collected continuously in background.

        This method is just for interface compatibility with BaseCollector.
        Actual data retrieval happens through buffer.get().

        Raises:
            NotImplementedError: Always; use buffer.get() instead.
        """
        raise NotImplementedError(
            "AsyncCollector collects data in background thread. "
            "Use buffer.get() to retrieve data instead."
        )

    def stop(self):
        """Stop background collection thread (joins with a 5s timeout)."""
        if not self._running:
            return

        self._running = False
        if self._thread is not None:
            # NOTE(review): if the worker is blocked inside env.step it may
            # outlive the timeout; we only warn in that case.
            self._thread.join(timeout=5.0)
            if self._thread.is_alive():
                print("[AsyncCollector] Warning: Thread did not stop cleanly")

        print(
            f"[AsyncCollector] Stopped (collected {self._step_count} steps, {self._episode_count} episodes)"
        )

    def is_running(self) -> bool:
        """Check if collector is currently running."""
        return self._running

    def get_stats(self) -> dict:
        """Get collection statistics (thread-safe snapshot)."""
        with self._lock:
            return {
                "steps_collected": self._step_count,
                "episodes_collected": self._episode_count,
            }

    def _collect_loop(self):
        """Background thread main loop: collect full rollout, then add to buffer.

        GAE requires sequential timesteps within the same trajectory. We accumulate
        T steps (one rollout) locally, then add the complete rollout to buffer.
        This ensures correct per-env trajectory ordering for GAE computation.
        """
        rollout_length = self.buffer.rollout_length  # T steps per rollout
        current_td = self.obs_tensordict

        while self._running:
            try:
                rollout_list = []

                for t in range(rollout_length):
                    # Policy forward (no_grad for inference)
                    with torch.no_grad():
                        self.policy.train()
                        self.policy.forward(current_td)

                    # Prefer a pre-formatted env action when the policy set one.
                    action = (
                        current_td["env_action"]
                        if "env_action" in current_td.keys()
                        else current_td["action"]
                    )
                    env_action = self._format_env_action(action)

                    next_obs_dict, reward, terminated, truncated, env_info = (
                        self.env.step(env_action)
                    )

                    # Convert env dict observation to TensorDict at the boundary.
                    next_obs_td = dict_to_tensordict(next_obs_dict, self.device)
                    done = terminated | truncated
                    next_obs_for_td = next_obs_td["observation"]
                    # Let stateful policies reset per-env internal state.
                    if hasattr(self.policy, "reset_envs"):
                        self.policy.reset_envs(done, next_obs_for_td)
                    batch_size = next_obs_td.batch_size[0]

                    # NOTE(review): reward/done flags are normalized to [N, 1];
                    # assumes the env returns them as [N] or [N, 1] — confirm.
                    next_td = TensorDict(
                        {
                            "observation": next_obs_for_td,
                            "reward": (
                                reward.float().unsqueeze(-1)
                                if reward.dim() == 1
                                else reward.float()
                            ),
                            "done": (
                                done.bool().unsqueeze(-1)
                                if done.dim() == 1
                                else done.bool()
                            ),
                            "terminated": (
                                terminated.bool().unsqueeze(-1)
                                if terminated.dim() == 1
                                else terminated.bool()
                            ),
                            "truncated": (
                                truncated.bool().unsqueeze(-1)
                                if truncated.dim() == 1
                                else truncated.bool()
                            ),
                        },
                        batch_size=torch.Size([batch_size]),
                        device=self.device,
                    )

                    # Bootstrap value of the next state (used later for GAE).
                    with torch.no_grad():
                        next_value_td = TensorDict(
                            {"observation": next_obs_for_td},
                            batch_size=next_td.batch_size,
                            device=self.device,
                        )
                        self.policy.get_value(next_value_td)
                        next_td["value"] = next_value_td["value"]

                    current_td["next"] = next_td
                    # Clone so later in-place edits don't alias stored steps.
                    rollout_list.append(current_td.clone())

                    if self.on_step_callback is not None:
                        self.on_step_callback(current_td, env_info)

                    if done.any():
                        with self._lock:
                            self._episode_count += done.sum().item()

                    current_td = next_obs_td

                # Stack along dim=1: list of [N, ...] -> [N, T, ...] (batch-first)
                rollout = torch.stack(rollout_list, dim=1)
                self.obs_tensordict = current_td

                with self._lock:
                    self.buffer.add_rollout(rollout)
                    self._step_count += rollout.batch_size[0] * rollout.batch_size[1]

            except Exception as e:
                # Worker thread must not die silently; log and exit the loop.
                print(f"[AsyncCollector] Error in collection loop: {e}")
                import traceback

                traceback.print_exc()
                break

        print("[AsyncCollector] Collection loop exited")


class AsyncCollectorStats:
    """Helper class to track async collection statistics safely."""

    def __init__(self, num_envs: int, device: torch.device):
        self.device = device
        self.num_envs = num_envs

        # Episode tracking (on device)
        self.curr_ret = torch.zeros(num_envs, dtype=torch.float32, device=device)
        self.curr_len = torch.zeros(num_envs, dtype=torch.int32, device=device)

        # Completed episodes (CPU, sliding window of the last 100)
        self.ret_window = deque(maxlen=100)
        self.len_window = deque(maxlen=100)
        self._lock = threading.Lock()

    def update(self, reward: torch.Tensor, done: torch.Tensor):
        """Update episode statistics (thread-safe).

        Args:
            reward: Reward tensor [N] or [N, 1]
            done: Done tensor [N] or [N, 1]
        """
        # Ensure correct shape
        if reward.dim() > 1:
            reward = reward.squeeze(-1)
        if done.dim() > 1:
            done = done.squeeze(-1)

        with self._lock:
            # Update cumulative stats
            self.curr_ret += reward
            self.curr_len += 1

            # Handle completed episodes
            done_idx = torch.nonzero(done, as_tuple=False).squeeze(-1)
            if done_idx.numel() > 0:
                finished_ret = self.curr_ret[done_idx].detach().cpu().tolist()
                finished_len = self.curr_len[done_idx].detach().cpu().tolist()
                self.ret_window.extend(finished_ret)
                self.len_window.extend(finished_len)

                # Reset for finished episodes
                self.curr_ret[done_idx] = 0
                self.curr_len[done_idx] = 0

    def get_avg_stats(self) -> tuple[float, float]:
        """Get average episode return and length (thread-safe).

        Returns:
            (avg_return, avg_length) or (nan, nan) if no episodes completed
        """
        with self._lock:
            if len(self.ret_window) == 0:
                return float("nan"), float("nan")
            return float(sum(self.ret_window) / len(self.ret_window)), float(
                sum(self.len_window) / len(self.len_window)
            )
class BaseCollector(ABC):
    """Common base for rollout collectors.

    Declares the interface every collector implements and performs the
    shared setup: storing env/policy/device, optionally binding the policy
    to the env, and capturing the initial observation.
    """

    def __init__(
        self,
        env,
        policy,
        device: torch.device,
        on_step_callback: Optional[Callable] = None,
    ):
        """Store collaborators and capture the initial observation.

        Args:
            env: Environment to collect from.
            policy: Policy for action selection.
            device: Device for tensor operations.
            on_step_callback: Optional callback(tensordict, env_info) invoked
                after every environment step.
        """
        self.env = env
        self.policy = policy
        self.device = device
        self.on_step_callback = on_step_callback

        # Give the policy a handle to the env when it supports binding.
        if hasattr(self.policy, "bind_env"):
            self.policy.bind_env(self.env)

        # Reset once up front so the first collect() starts from a valid state.
        first_obs, _ = self.env.reset()
        self.obs_tensordict = dict_to_tensordict(first_obs, self.device)

    def _format_env_action(self, action: torch.Tensor):
        """Shape the policy output the way the target environment expects.

        With an ActionManager configured, the env consumes a mapping keyed by
        the active action term name; otherwise the raw tensor is passed
        through and applied directly as a joint-space command.
        """
        manager = getattr(self.env, "action_manager", None)
        if manager is None:
            return action
        return {manager.action_type: action}

    @abstractmethod
    def collect(self, num_steps: int | None = None, **kwargs) -> TensorDict:
        """Collect data from the environment.

        Args:
            num_steps: Number of steps to collect (required by SyncCollector,
                ignored by AsyncCollector which collects continuously).

        Returns:
            TensorDict with the collected data.
        """
        pass
class SyncCollector(BaseCollector):
    """Synchronous data collector for standard RL training.

    Collects a complete rollout of specified length, then returns it.
    Used with RolloutBuffer for standard PPO training.

    Usage:
        collector = SyncCollector(env, policy, device)
        rollout = collector.collect(num_steps=2048)
        buffer.add(rollout)
    """

    def collect(self, num_steps: int | None = None, **kwargs) -> TensorDict:
        """Collect a synchronous rollout.

        Args:
            num_steps: Number of steps to collect (required)

        Returns:
            TensorDict with batch_size=[N, T] (batch-first) containing full rollout

        Raises:
            TypeError: If num_steps is not provided.
        """
        if num_steps is None:
            raise TypeError("SyncCollector.collect() requires num_steps")
        self.policy.train()
        current_td = self.obs_tensordict
        rollout_list = []

        for t in range(num_steps):
            # Policy forward: adds "action", "sample_log_prob", "value" to tensordict
            self.policy.forward(current_td)

            # Extract action for environment step; prefer a pre-formatted
            # "env_action" when the policy provides one.
            action = (
                current_td["env_action"]
                if "env_action" in current_td.keys()
                else current_td["action"]
            )
            env_action = self._format_env_action(action)

            # Environment step - returns tuple (env returns dict, not TensorDict)
            next_obs, reward, terminated, truncated, env_info = self.env.step(
                env_action
            )

            # Convert env dict observation to TensorDict at boundary
            next_obs_td = dict_to_tensordict(next_obs, self.device)

            # Build "next" TensorDict
            done = terminated | truncated
            next_obs_for_td = next_obs_td["observation"]
            # Let stateful policies (e.g. recurrent/VLA) reset per-env state.
            if hasattr(self.policy, "reset_envs"):
                self.policy.reset_envs(done, next_obs_for_td)

            # Ensure batch_size consistency - use next_obs_td's batch_size
            batch_size = next_obs_td.batch_size[0]

            # NOTE(review): reward/done/terminated/truncated are normalized
            # to shape [N, 1]; assumes env returns them as [N] or [N, 1].
            next_td = TensorDict(
                {
                    "observation": next_obs_for_td,
                    "reward": (
                        reward.float().unsqueeze(-1)
                        if reward.dim() == 1
                        else reward.float()
                    ),
                    "done": (
                        done.bool().unsqueeze(-1) if done.dim() == 1 else done.bool()
                    ),
                    "terminated": (
                        terminated.bool().unsqueeze(-1)
                        if terminated.dim() == 1
                        else terminated.bool()
                    ),
                    "truncated": (
                        truncated.bool().unsqueeze(-1)
                        if truncated.dim() == 1
                        else truncated.bool()
                    ),
                },
                batch_size=torch.Size([batch_size]),
                device=self.device,
            )

            # Compute next value for GAE (bootstrap value)
            with torch.no_grad():
                next_value_td = TensorDict(
                    {"observation": next_obs_for_td},
                    batch_size=next_td.batch_size,
                    device=self.device,
                )
                self.policy.get_value(next_value_td)
                next_td["value"] = next_value_td["value"]

            # Add "next" to current tensordict
            current_td["next"] = next_td

            # Store complete transition (clone so later in-place edits
            # don't alias the stored step)
            rollout_list.append(current_td.clone())

            # Callback for statistics and logging
            if self.on_step_callback is not None:
                self.on_step_callback(current_td, env_info)

            # Prepare next iteration - use the converted TensorDict
            current_td = next_obs_td

        # Update observation for next collection
        self.obs_tensordict = current_td

        # Stack along dim=1: list of [N,...] -> [N, T, ...] (batch-first)
        rollout = torch.stack(rollout_list, dim=1)
        return rollout
def build_policy(
    policy_block: dict,
    action_dim: int,
    device: torch.device,
    actor: torch.nn.Module | None = None,
    critic: torch.nn.Module | None = None,
) -> Policy:
    """Build a policy from a json-like config block.

    With the TensorDict architecture only the action dimension is needed;
    observations are handled through the TensorDict structure.

    Args:
        policy_block: Config dict with a 'name' key.
        action_dim: Dimension of the action space.
        device: Device to place the policy on.
        actor: Actor network (required for actor_critic / actor_only).
        critic: Critic network (required for actor_critic).

    Returns:
        Initialized Policy instance.

    Raises:
        ValueError: If the name is unregistered or a required module is missing.
    """
    name = policy_block["name"].lower()
    if name not in _POLICY_REGISTRY:
        available = ", ".join(get_registered_policy_names())
        raise ValueError(
            f"Policy '{name}' is not registered. Available policies: {available}"
        )

    # VLA policies are assembled by their dedicated factory.
    if name == "vla":
        return build_vla_policy(policy_block, action_dim=action_dim, device=device)

    policy_cls = _POLICY_REGISTRY[name]

    if name == "actor_critic":
        if actor is None or critic is None:
            raise ValueError(
                "ActorCritic policy requires external 'actor' and 'critic' modules."
            )
        return policy_cls(
            action_dim=action_dim,
            device=device,
            actor=actor,
            critic=critic,
        )

    if name == "actor_only":
        if actor is None:
            raise ValueError("ActorOnly policy requires external 'actor' module.")
        return policy_cls(action_dim=action_dim, device=device, actor=actor)

    # Generic fallback for any other registered policy.
    return policy_cls(action_dim=action_dim, device=device)
"get_policy_class", "Policy", diff --git a/embodichain/agents/rl/models/actor_critic.py b/embodichain/agents/rl/models/actor_critic.py index 35f9a961..5a089bb1 100644 --- a/embodichain/agents/rl/models/actor_critic.py +++ b/embodichain/agents/rl/models/actor_critic.py @@ -16,11 +16,12 @@ from __future__ import annotations -from typing import Dict, Any, Tuple +from typing import Dict, Any import torch import torch.nn as nn from torch.distributions.normal import Normal +from tensordict import TensorDict from .mlp import MLP from .policy import Policy @@ -28,31 +29,31 @@ class ActorCritic(Policy): """Actor-Critic with learnable log_std for Gaussian policy. - This is a placeholder implementation of the Policy interface that: - - Encapsulates MLP networks (actor + critic) that need to be trained by RL algorithms + Uses TensorDict for all data I/O following TorchRL conventions. + This implementation: + - Encapsulates MLP networks (actor + critic) trained by RL algorithms - Handles internal computation: MLP output → mean + learnable log_std → Normal distribution - - Provides a uniform interface for RL algorithms (PPO, SAC, etc.) + - Provides a uniform TensorDict-based interface for RL algorithms (PPO, SAC, etc.) This allows seamless swapping with other policy implementations (e.g., VLAPolicy) without modifying RL algorithm code. 
    def _extract_obs_tensor(self, tensordict: TensorDict) -> torch.Tensor:
        """Extract observation as flat tensor from TensorDict.

        For nested TensorDict observations, flattens all leaf tensors and
        concatenates them (keys visited in sorted order for determinism).
        For plain tensor observations, returns the tensor as is.

        Args:
            tensordict: Input TensorDict with "observation" key

        Returns:
            Flattened observation tensor

        Raises:
            ValueError: If no tensors are found in the observation.
        """
        obs = tensordict["observation"]

        # Handle nested TensorDict by collecting all leaf tensors
        obs_list = []

        def _collect(item):
            # Duck typing: if it has keys(), treat as TensorDict
            if hasattr(item, "keys"):
                for key in sorted(item.keys()):
                    _collect(item[key])
            else:
                # Leaf tensor: keep the batch dim, flatten the rest.
                # NOTE(review): assumes every leaf is at least 2-D ([N, ...]);
                # flatten(start_dim=1) fails on 1-D leaves — confirm upstream.
                obs_list.append(item.flatten(start_dim=1))

        _collect(obs)

        if len(obs_list) == 0:
            raise ValueError("No tensors found in observation")
        elif len(obs_list) == 1:
            return obs_list[0]
        else:
            return torch.cat(obs_list, dim=-1)

    @torch.no_grad()
    def forward(
        self, tensordict: TensorDict, deterministic: bool = False
    ) -> TensorDict:
        """Forward pass: sample action and compute value (in-place modification).

        Runs under @torch.no_grad(): intended for rollout collection only;
        gradient-carrying evaluation goes through evaluate_actions().

        Args:
            tensordict: Must contain "observation" key
            deterministic: If True, use mean instead of sampling

        Returns:
            Same tensordict with added keys:
            - "action": Sampled or deterministic action
            - "sample_log_prob": Log probability of action, shape [N, 1]
            - "value": Value estimate, shape [N, 1]
            - "loc": Distribution mean
            - "scale": Distribution std
        """
        obs_tensor = self._extract_obs_tensor(tensordict)

        # Actor forward: mean of the Gaussian; std from learnable log_std
        # clamped to [log_std_min, log_std_max] and broadcast over the batch.
        mean = self.actor(obs_tensor)
        log_std = self.log_std.clamp(self.log_std_min, self.log_std_max)
        std = log_std.exp().expand(mean.shape[0], -1)

        dist = Normal(mean, std)

        # Sample action or use mean
        if deterministic:
            action = mean
        else:
            action = dist.sample()

        # Compute log probability (summed over action dims, kept as [N, 1])
        log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True)

        # Critic forward - keep shape [N, 1] for consistency with reward/done
        value = self.critic(obs_tensor)

        # Add to tensordict (in-place)
        tensordict["action"] = action
        tensordict["sample_log_prob"] = log_prob
        tensordict["value"] = value
        tensordict["loc"] = mean
        tensordict["scale"] = std

        return tensordict

    @torch.no_grad()
    def get_value(self, tensordict: TensorDict) -> TensorDict:
        """Get value estimate for observations (in-place modification).

        Args:
            tensordict: Must contain "observation" key

        Returns:
            Same tensordict with added key:
            - "value": Value estimate, shape [N, 1]
        """
        obs_tensor = self._extract_obs_tensor(tensordict)
        value = self.critic(obs_tensor)  # Keep shape [N, 1]
        tensordict["value"] = value
        return tensordict

    def evaluate_actions(self, tensordict: TensorDict) -> TensorDict:
        """Evaluate actions for policy gradient computation (in-place modification).

        Runs WITH gradients (unlike forward/get_value) so PPO losses can
        backpropagate through both actor and critic.

        Args:
            tensordict: Must contain "observation" and "action" keys

        Returns:
            Same tensordict with added keys:
            - "sample_log_prob": Log probability of actions, shape [N, 1]
            - "entropy": Entropy of action distribution, shape [N, 1]
            - "value": Value estimate, shape [N, 1]
        """
        obs_tensor = self._extract_obs_tensor(tensordict)
        actions = tensordict["action"]

        # Actor forward (same distribution construction as forward())
        mean = self.actor(obs_tensor)
        log_std = self.log_std.clamp(self.log_std_min, self.log_std_max)
        std = log_std.exp().expand(mean.shape[0], -1)
        dist = Normal(mean, std)

        # Evaluate given actions - keep shape [N, 1] for consistency
        log_prob = dist.log_prob(actions).sum(dim=-1, keepdim=True)
        entropy = dist.entropy().sum(dim=-1, keepdim=True)

        # Critic forward - keep shape [N, 1]
        value = self.critic(obs_tensor)

        # Add to tensordict (in-place)
        tensordict["sample_log_prob"] = log_prob
        tensordict["entropy"] = entropy
        tensordict["value"] = value

        return tensordict
deterministic: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - mean = self.actor(obs) + def forward( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + obs_tensor = self._extract_obs_tensor(tensordict) + mean = self.actor(obs_tensor) log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) std = log_std.exp().expand(mean.shape[0], -1) dist = Normal(mean, std) action = mean if deterministic else dist.sample() - log_prob = dist.log_prob(action).sum(dim=-1) - value = torch.zeros(obs.shape[0], device=self.device, dtype=obs.dtype) - return action, log_prob, value + log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True) + value = torch.zeros( + (obs_tensor.shape[0], 1), device=self.device, dtype=obs_tensor.dtype + ) + + tensordict["action"] = action + tensordict["sample_log_prob"] = log_prob + tensordict["value"] = value + tensordict["loc"] = mean + tensordict["scale"] = std + return tensordict @torch.no_grad() - def get_value(self, obs: torch.Tensor) -> torch.Tensor: - return torch.zeros(obs.shape[0], device=self.device, dtype=obs.dtype) + def get_value(self, tensordict: TensorDict) -> TensorDict: + obs_tensor = self._extract_obs_tensor(tensordict) + tensordict["value"] = torch.zeros( + (obs_tensor.shape[0], 1), device=self.device, dtype=obs_tensor.dtype + ) + return tensordict - def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - mean = self.actor(obs) + def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + obs_tensor = self._extract_obs_tensor(tensordict) + actions = tensordict["action"] log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) + mean = self.actor(obs_tensor) std = log_std.exp().expand(mean.shape[0], -1) dist = Normal(mean, std) - log_prob = dist.log_prob(actions).sum(dim=-1) - entropy = dist.entropy().sum(dim=-1) - value = torch.zeros(obs.shape[0], device=self.device, dtype=obs.dtype) - 
return log_prob, entropy, value + log_prob = dist.log_prob(actions).sum(dim=-1, keepdim=True) + entropy = dist.entropy().sum(dim=-1, keepdim=True) + value = torch.zeros( + (obs_tensor.shape[0], 1), device=self.device, dtype=obs_tensor.dtype + ) + + tensordict["sample_log_prob"] = log_prob + tensordict["entropy"] = entropy + tensordict["value"] = value + return tensordict diff --git a/embodichain/agents/rl/models/policy.py b/embodichain/agents/rl/models/policy.py index 21c13a96..ef04f50a 100644 --- a/embodichain/agents/rl/models/policy.py +++ b/embodichain/agents/rl/models/policy.py @@ -19,13 +19,15 @@ This module defines an abstract Policy base class that all RL policies must inherit from. A Policy encapsulates the neural networks and exposes a uniform interface for RL algorithms (e.g., PPO, SAC) to interact with. + +All data I/O now uses TensorDict for structured, extensible data flow. """ from __future__ import annotations -from typing import Tuple from abc import ABC, abstractmethod import torch.nn as nn +from tensordict import TensorDict import torch @@ -37,6 +39,7 @@ class Policy(nn.Module, ABC): - Encapsulates neural networks that are trained by RL algorithms - Handles internal computations (e.g., network output → distribution) - Provides a uniform interface for algorithms (PPO, SAC, etc.) + - Uses TensorDict for all inputs and outputs (no tensor fallback) """ device: torch.device @@ -46,49 +49,54 @@ def __init__(self) -> None: super().__init__() @abstractmethod - def get_action( - self, obs: torch.Tensor, deterministic: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Sample an action from the policy. + def forward(self, tensordict: TensorDict) -> TensorDict: + """Forward pass that adds action to the input tensordict (in-place). + + This is the main inference method following TorchRL conventions. 
Args: - obs: Observation tensor of shape (batch_size, obs_dim) - deterministic: If True, return the mean action; otherwise sample + tensordict: Input TensorDict containing at minimum: + - "observation": Observation tensor or nested TensorDict Returns: - Tuple of (action, log_prob, value): - - action: Sampled action tensor of shape (batch_size, action_dim) - - log_prob: Log probability of the action, shape (batch_size,) - - value: Value estimate, shape (batch_size,) + The same TensorDict (modified in-place) with added fields: + - "action": Sampled action tensor + - "sample_log_prob": Log probability of the sampled action + - "value": Value estimate (optional, for actor-critic) + - "loc": Distribution mean (optional, for continuous actions) + - "scale": Distribution std (optional, for continuous actions) """ raise NotImplementedError @abstractmethod - def get_value(self, obs: torch.Tensor) -> torch.Tensor: + def get_value(self, tensordict: TensorDict) -> TensorDict: """Get value estimate for given observations. Args: - obs: Observation tensor of shape (batch_size, obs_dim) + tensordict: Input TensorDict containing: + - "observation": Observation data Returns: - Value estimate tensor of shape (batch_size,) + TensorDict with added field: + - "value": Value estimate tensor """ raise NotImplementedError @abstractmethod - def evaluate_actions( - self, obs: torch.Tensor, actions: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: """Evaluate actions and compute log probabilities, entropy, and values. + Used during policy updates to recompute action probabilities. 
+ Args: - obs: Observation tensor of shape (batch_size, obs_dim) - actions: Action tensor of shape (batch_size, action_dim) + tensordict: Input TensorDict containing: + - "observation": Observation data + - "action": Actions to evaluate Returns: - Tuple of (log_prob, entropy, value): - - log_prob: Log probability of actions, shape (batch_size,) - - entropy: Entropy of the action distribution, shape (batch_size,) - - value: Value estimate, shape (batch_size,) + TensorDict with added fields: + - "sample_log_prob": Log probability of actions + - "entropy": Entropy of the action distribution + - "value": Value estimate """ raise NotImplementedError diff --git a/embodichain/agents/rl/models/vla_policy.py b/embodichain/agents/rl/models/vla_policy.py new file mode 100644 index 00000000..8f1fd80f --- /dev/null +++ b/embodichain/agents/rl/models/vla_policy.py @@ -0,0 +1,809 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import math +import sys +from pathlib import Path +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tensordict import TensorDict +from torch.distributions.normal import Normal + +from .policy import Policy + +__all__ = ["VLAPolicy", "load_vla_model", "build_vla_policy"] + + +class VLAPolicy(Policy): + """Wrap a pretrained DexForceVLA model with the RL Policy interface.""" + + def __init__( + self, + action_dim: int, + device: torch.device, + vla_model: nn.Module, + instruction: str = "Stack the bowls.", + inference_horizon: int = 32, + action_std_init: float = 0.02, + robot_type: str = "CobotMagic", + gripper_open_value: float = 0.05, + gripper_closed_value: float = 0.0, + action_key_order: Optional[list[str]] = None, + ) -> None: + super().__init__() + self.device = device + self.instruction = instruction + self.inference_horizon = inference_horizon + self.robot_type = robot_type + self.gripper_open_value = gripper_open_value + self.gripper_closed_value = gripper_closed_value + + self.vla_model = vla_model.to(self.device) + self.vla_model.eval() + + self._workspace_root = Path(__file__).resolve().parents[5] + self._dexechain_root = self._workspace_root / "embodichain" + if str(self._workspace_root) not in sys.path: + sys.path.append(str(self._workspace_root)) + if str(self._dexechain_root) not in sys.path: + sys.path.append(str(self._dexechain_root)) + + from dexechain.data.data_engine.indices_unifier import ( # pyright: ignore[reportMissingImports] + ActionIndicesGenerator, + ) + from dexechain.data.enum import ( # pyright: ignore[reportMissingImports] + ActionMode, + ControlParts, + EefNormalizer, + EndEffector, + JointType, + Modality, + ) + from dexechain.data.global_mapping import GlobalMapping # pyright: ignore[reportMissingImports] + from dexechain.lab.gym.utils.gym_utils import ( # 
pyright: ignore[reportMissingImports] + get_pk_serial_chain_from_robot_type, + ) + from dexechain.lab.gym.utils.misc import ( # pyright: ignore[reportMissingImports] + _data_key_to_control_part, + ) + from dexechain.utils.utility import get_right_name # pyright: ignore[reportMissingImports] + + self.ActionMode = ActionMode + self.ControlParts = ControlParts + self.EefNormalizer = EefNormalizer + self.EndEffector = EndEffector + self.JointType = JointType + self.Modality = Modality + self._data_key_to_control_part = _data_key_to_control_part + self._get_right_name = get_right_name + + self.indices_generator = ActionIndicesGenerator(self.vla_model.arm_dofs) + self.global_mapping = GlobalMapping(self.vla_model.arm_dofs) + self.pk_chain = get_pk_serial_chain_from_robot_type(self.robot_type) + + self.state_history_len = int(self.vla_model.state_history_len) + self.img_history_size = int(self.vla_model.img_history_size) + self.state_token_dim = int(self.vla_model.state_token_dim) + self.camera_used = list(getattr(self.vla_model, "camera_used", [])) + self.action_key_order = self._resolve_action_key_order(action_key_order) + self.action_dim = sum( + len(self.indices_generator.get([key])) for key in self.action_key_order + ) + if action_dim != self.action_dim: + raise ValueError( + f"Configured action_dim={action_dim} does not match decoded VLA " + f"action_dim={self.action_dim} for keys {self.action_key_order}." 
+ ) + self.full_action_indices = self.indices_generator.get(self.vla_model.output) + + self.log_std = nn.Parameter( + torch.full( + (self.action_dim,), + float(math.log(max(action_std_init, 1e-6))), + device=self.device, + ) + ) + self.log_std_min = -5.0 + self.log_std_max = 2.0 + critic_input_dim = self.state_history_len * self.state_token_dim + self.value_head = nn.Sequential( + nn.Linear(critic_input_dim, 256), + nn.ReLU(), + nn.Linear(256, 1), + ).to(self.device) + + self._runtime_env = None + self._runtime_robot = None + self._state_history: torch.Tensor | None = None + self._image_history: torch.Tensor | None = None + self._cached_chunk: torch.Tensor | None = None + self._cached_chunk_context: TensorDict | None = None + self._cached_chunk_step: torch.Tensor | None = None + + def bind_env(self, env) -> None: + self._runtime_env = env + if env is None: + self._runtime_robot = None + return + try: + self._runtime_robot = env.get_wrapper_attr("robot") + except Exception: + self._runtime_robot = None + + def _reset_chunk_cache(self, env_mask: torch.Tensor | None = None) -> None: + if env_mask is None: + self._cached_chunk = None + self._cached_chunk_context = None + self._cached_chunk_step = None + return + if self._cached_chunk_step is not None: + self._cached_chunk_step[env_mask] = self.inference_horizon + + @torch.no_grad() + def reset_envs( + self, done_mask: torch.Tensor, next_observation: TensorDict | None = None + ) -> None: + if done_mask.dim() > 1: + done_mask = done_mask.squeeze(-1) + done_mask = done_mask.to(device=self.device, dtype=torch.bool) + if not done_mask.any(): + return + + self._reset_chunk_cache(done_mask) + + if next_observation is None: + if self._state_history is not None: + self._state_history[done_mask] = 0 + if self._image_history is not None: + self._image_history[done_mask] = 0 + return + + current_state, _ = self._build_state_vector(next_observation) + current_images = self._extract_current_images(next_observation) + 
reset_state_history = current_state.unsqueeze(1).repeat(1, self.state_history_len, 1) + reset_image_history = current_images.unsqueeze(1).repeat(1, self.img_history_size, 1, 1, 1, 1) + + if self._state_history is None or self._state_history.shape[0] != current_state.shape[0]: + self._state_history = reset_state_history + else: + self._state_history[done_mask] = reset_state_history[done_mask] + + if self._image_history is None or self._image_history.shape[0] != current_images.shape[0]: + self._image_history = reset_image_history + else: + self._image_history[done_mask] = reset_image_history[done_mask] + + def _resolve_action_key_order( + self, action_key_order: Optional[list[str]] + ) -> list[str]: + output_keys = list(self.vla_model.output) + if action_key_order: + return [key for key in action_key_order if key in output_keys] + + preferred_order = [ + self.ControlParts.LEFT_ARM.value + + self.ActionMode.RELATIVE.value + + self.JointType.QPOS.value, + self.ControlParts.LEFT_ARM.value + self.JointType.QPOS.value, + self.ControlParts.LEFT_EEF.value + self.EndEffector.GRIPPER.value, + self.ControlParts.RIGHT_ARM.value + + self.ActionMode.RELATIVE.value + + self.JointType.QPOS.value, + self.ControlParts.RIGHT_ARM.value + self.JointType.QPOS.value, + self.ControlParts.RIGHT_EEF.value + self.EndEffector.GRIPPER.value, + ] + resolved = [key for key in preferred_order if key in output_keys] + if not resolved: + raise ValueError(f"No supported VLA outputs found in {output_keys}") + return resolved + + def _fit_state_value( + self, key: str, value: torch.Tensor | object, dtype: torch.dtype + ) -> torch.Tensor: + tensor = ( + value.to(self.device, dtype=dtype) + if isinstance(value, torch.Tensor) + else torch.as_tensor(value, device=self.device, dtype=dtype) + ) + if tensor.dim() == 1: + tensor = tensor.unsqueeze(0) + + target_width = len(self.global_mapping.get_indices([key])) + if tensor.shape[-1] != target_width: + if target_width == 1: + tensor = tensor.mean(dim=-1, 
keepdim=True) + elif tensor.shape[-1] > target_width: + tensor = tensor[..., :target_width] + else: + raise ValueError( + f"State '{key}' width {tensor.shape[-1]} cannot fit target width {target_width}." + ) + return tensor + + def _normalize_gripper(self, qpos: torch.Tensor, key: str) -> torch.Tensor: + if self._runtime_robot is not None: + normalized = self.EefNormalizer.normalize_cobotmagic_gripper( + qpos, key, is_action=False, robot=self._runtime_robot + ) + return self._fit_state_value(key, normalized, qpos.dtype).clamp(0.0, 1.0) + + if qpos.dim() >= 2 and qpos.shape[-1] > 1: + qpos = qpos.mean(dim=-1, keepdim=True) + denom = max(self.gripper_open_value - self.gripper_closed_value, 1e-6) + normalized = 1.0 - (qpos - self.gripper_closed_value) / denom + return self._fit_state_value(key, normalized.clamp(0.0, 1.0), qpos.dtype) + + def _resolve_camera_image( + self, sensor_obs: TensorDict, camera_name: str + ) -> torch.Tensor | None: + if camera_name in sensor_obs: + return sensor_obs[camera_name]["color"][..., :3].to(self.device) + + for base_camera_name in sensor_obs.keys(): + if ( + self._get_right_name(base_camera_name) == camera_name + and "color_right" in sensor_obs[base_camera_name] + ): + return sensor_obs[base_camera_name]["color_right"][..., :3].to( + self.device + ) + + return None + + def _resize_camera_image(self, image: torch.Tensor) -> torch.Tensor: + target_size = int(getattr(self.vla_model, "img_size", 0) or 0) + if target_size <= 0: + return image + if image.shape[-3:-1] == (target_size, target_size): + return image + + resized = F.interpolate( + image.permute(0, 3, 1, 2).float(), + size=(target_size, target_size), + mode="bilinear", + align_corners=False, + ) + return resized.permute(0, 2, 3, 1).to(dtype=image.dtype) + + def _extract_current_images(self, observation: TensorDict) -> torch.Tensor: + sensor_obs = observation["sensor"] + images = [] + for camera_name in self.camera_used: + image = self._resolve_camera_image(sensor_obs, 
camera_name) + if image is None: + raise KeyError(f"Camera '{camera_name}' not found in observation.") + images.append(self._resize_camera_image(image)) + return torch.stack(images, dim=1) + + def _split_qpos(self, qpos: torch.Tensor) -> dict[str, torch.Tensor]: + arm_dofs_per_side = self.vla_model.arm_dofs // 2 + eef_dofs_total = qpos.shape[-1] - self.vla_model.arm_dofs + eef_dofs_per_side = max(eef_dofs_total // 2, 0) + + left_arm_end = arm_dofs_per_side + left_eef_end = left_arm_end + eef_dofs_per_side + right_arm_end = left_eef_end + arm_dofs_per_side + + return { + self.ControlParts.LEFT_ARM.value + self.JointType.QPOS.value: qpos[ + :, :left_arm_end + ], + self.ControlParts.LEFT_EEF.value + self.EndEffector.GRIPPER.value: qpos[ + :, left_arm_end:left_eef_end + ], + self.ControlParts.RIGHT_ARM.value + self.JointType.QPOS.value: qpos[ + :, left_eef_end:right_arm_end + ], + self.ControlParts.RIGHT_EEF.value + self.EndEffector.GRIPPER.value: qpos[ + :, right_arm_end: + ], + } + + def _build_state_vector( + self, observation: TensorDict + ) -> tuple[torch.Tensor, torch.Tensor]: + qpos = observation["robot"][self.JointType.QPOS.value].to(self.device) + qpos_chunks = self._split_qpos(qpos) + state_entries: dict[str, torch.Tensor] = {} + + if self._runtime_env is not None and self._runtime_robot is not None: + control_parts = ( + self._runtime_env.metadata.get("dataset", {}) + .get("robot_meta", {}) + .get("control_parts", []) + ) + if not control_parts: + control_parts = [ + self.ControlParts.LEFT_ARM.value, + self.ControlParts.LEFT_EEF.value, + self.ControlParts.RIGHT_ARM.value, + self.ControlParts.RIGHT_EEF.value, + ] + for key in self.vla_model.state_meta: + part = self._data_key_to_control_part( + robot=self._runtime_robot, + control_parts=control_parts, + data_key=key, + ) + if part is None: + continue + indices = self._runtime_robot.get_joint_ids(part, remove_mimic=True) + qpos_data = qpos[:, indices] + if self.EndEffector.GRIPPER.value in key: + 
state_entries[key] = self._normalize_gripper(qpos_data, key) + else: + normalized = self.EefNormalizer.normalize_eef( + qpos_data, part, robot=self._runtime_robot + ) + state_entries[key] = self._fit_state_value(key, normalized, qpos.dtype) + else: + state_entries = { + self.ControlParts.LEFT_ARM.value + + self.JointType.QPOS.value: qpos_chunks[ + self.ControlParts.LEFT_ARM.value + self.JointType.QPOS.value + ], + self.ControlParts.RIGHT_ARM.value + + self.JointType.QPOS.value: qpos_chunks[ + self.ControlParts.RIGHT_ARM.value + self.JointType.QPOS.value + ], + self.ControlParts.LEFT_EEF.value + + self.EndEffector.GRIPPER.value: self._normalize_gripper( + qpos_chunks[ + self.ControlParts.LEFT_EEF.value + self.EndEffector.GRIPPER.value + ], + self.ControlParts.LEFT_EEF.value + self.EndEffector.GRIPPER.value, + ), + self.ControlParts.RIGHT_EEF.value + + self.EndEffector.GRIPPER.value: self._normalize_gripper( + qpos_chunks[ + self.ControlParts.RIGHT_EEF.value + + self.EndEffector.GRIPPER.value + ], + self.ControlParts.RIGHT_EEF.value + self.EndEffector.GRIPPER.value, + ), + } + + if self.pk_chain is not None: + from dexechain.lab.gym.utils.gym_utils import ( # pyright: ignore[reportMissingImports] + map_qpos_to_eef_pose, + ) + + arm_dofs_per_side = self.vla_model.arm_dofs // 2 + arm_qpos = torch.cat( + [ + qpos_chunks[ + self.ControlParts.LEFT_ARM.value + self.JointType.QPOS.value + ], + qpos_chunks[ + self.ControlParts.RIGHT_ARM.value + self.JointType.QPOS.value + ], + ], + dim=-1, + ) + eef_pose_dict = map_qpos_to_eef_pose( + self.pk_chain, + arm_qpos.to("cpu"), + control_parts=[ + self.ControlParts.LEFT_ARM.value, + self.ControlParts.RIGHT_ARM.value, + ], + control_ids=[ + list(range(0, arm_dofs_per_side)), + list(range(arm_dofs_per_side, arm_dofs_per_side * 2)), + ], + ) + eef_pose_dict = { + key: value.to(self.device, dtype=qpos.dtype) + if isinstance(value, torch.Tensor) + else torch.as_tensor(value, device=self.device, dtype=qpos.dtype) + for key, value in 
eef_pose_dict.items() + } + state_entries.update(eef_pose_dict) + + state_vector = torch.zeros( + (qpos.shape[0], self.state_token_dim), + device=self.device, + dtype=qpos.dtype, + ) + state_indicator = torch.zeros_like(state_vector) + for key in self.vla_model.state_meta: + if key not in state_entries: + continue + indices = self.global_mapping.get_indices([key]) + state_vector[:, indices] = state_entries[key] + state_indicator[:, indices] = 1 + return state_vector, state_indicator + + def _roll_history( + self, + history: torch.Tensor | None, + current: torch.Tensor, + history_len: int, + ) -> torch.Tensor: + if history is None or history.shape[0] != current.shape[0]: + return current.unsqueeze(1).repeat( + [1, history_len] + [1] * (current.dim() - 1) + ) + if history_len == 1: + return current.unsqueeze(1) + return torch.cat([history[:, 1:], current.unsqueeze(1)], dim=1) + + def _build_policy_context( + self, + observation: TensorDict, + update_history: bool, + cached_context: TensorDict | None = None, + ) -> tuple[dict[str, torch.Tensor | list[str]], torch.Tensor, TensorDict]: + current_state, current_state_indicator = self._build_state_vector(observation) + current_images = self._extract_current_images(observation) + + if cached_context is not None: + state_history = cached_context["state_history"].to(self.device) + image_history = cached_context["image_history"].to(self.device) + else: + state_history = self._roll_history( + self._state_history, current_state, self.state_history_len + ) + image_history = self._roll_history( + self._image_history, current_images, self.img_history_size + ) + if update_history: + self._state_history = state_history.detach().clone() + self._image_history = image_history.detach().clone() + + state_indicator = current_state_indicator.unsqueeze(1).repeat( + 1, state_history.shape[1], 1 + ) + action_indicator = torch.zeros( + ( + current_state.shape[0], + self.inference_horizon, + self.state_token_dim, + ), + device=self.device, + 
dtype=current_state.dtype, + ) + action_indicator[:, :, self.full_action_indices] = 1 + + batch = { + self.Modality.IMAGES.value: image_history, + self.Modality.STATES.value: state_history, + self.Modality.STATE_INDICATOR.value: state_indicator, + self.Modality.ACTION_INDICATOR.value: action_indicator, + "instruction": [self.instruction] * current_state.shape[0], + } + critic_input = state_history.reshape(state_history.shape[0], -1).float() + context = TensorDict( + { + "state_history": state_history.detach(), + "image_history": image_history.detach(), + }, + batch_size=[current_state.shape[0]], + device=self.device, + ) + return batch, critic_input, context + + def _slice_batch( + self, batch: dict[str, torch.Tensor | list[str]], mask: torch.Tensor + ) -> dict[str, torch.Tensor | list[str]]: + mask_list = mask.tolist() + sliced: dict[str, torch.Tensor | list[str]] = {} + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + sliced[key] = value[mask] + elif isinstance(value, list): + sliced[key] = [item for item, keep in zip(value, mask_list) if keep] + else: + sliced[key] = value + return sliced + + def _predict_chunk_actions( + self, batch: dict[str, torch.Tensor | list[str]] + ) -> torch.Tensor: + self.vla_model.eval() + data = self.vla_model.brain_infer( + batch, + action_mask=batch[self.Modality.ACTION_INDICATOR.value], + precomp_lang_embed=True, + use_fix_aug=False, + ) + data = self.vla_model._compute_priviliges(data) + data = self.vla_model._compute_adaptors(data) + data = self.vla_model.cerebellum(data, None) + + from dexechain.agents.dexforce_vla.models.utils import ( # pyright: ignore[reportMissingImports] + post_process, + ) + + data = post_process( + data, + is_training=False, + **self.vla_model.global_collection, + ) + return data[self.Modality.ACTIONS.value] + + def _decode_action_step( + self, step_action: torch.Tensor, observation: TensorDict + ) -> torch.Tensor: + current_qpos = 
observation["robot"][self.JointType.QPOS.value].to(self.device) + qpos_chunks = self._split_qpos(current_qpos) + decoded_parts: list[torch.Tensor] = [] + + for key in self.action_key_order: + indices = self.indices_generator.get([key]) + value = step_action[:, indices] + if ( + self.ActionMode.RELATIVE.value in key + and self.JointType.QPOS.value in key + ): + if key.startswith(self.ControlParts.LEFT_ARM.value): + value = ( + qpos_chunks[ + self.ControlParts.LEFT_ARM.value + self.JointType.QPOS.value + ] + + value + ) + elif key.startswith(self.ControlParts.RIGHT_ARM.value): + value = ( + qpos_chunks[ + self.ControlParts.RIGHT_ARM.value + self.JointType.QPOS.value + ] + + value + ) + elif self.EndEffector.GRIPPER.value in key: + value = self.gripper_closed_value + ( + 1.0 - value + ) * (self.gripper_open_value - self.gripper_closed_value) + decoded_parts.append(value) + + if not decoded_parts: + raise ValueError( + f"No action keys could be decoded from model outputs: {self.vla_model.output}" + ) + return torch.cat(decoded_parts, dim=-1).to(self.device) + + def _decode_first_action( + self, trajectory: torch.Tensor, observation: TensorDict + ) -> torch.Tensor: + return self._decode_action_step(trajectory[:, 0], observation) + + def _expand_env_action( + self, action: torch.Tensor, observation: TensorDict + ) -> torch.Tensor: + expanded_parts: list[torch.Tensor] = [] + offset = 0 + for key in self.action_key_order: + width = len(self.indices_generator.get([key])) + value = action[:, offset : offset + width] + offset += width + + if ( + self._runtime_robot is not None + and self.EndEffector.GRIPPER.value in key + and value.shape[-1] == 1 + ): + value = self.EefNormalizer.denormalize_cobotmagic_gripper( + value, key, robot=self._runtime_robot + ) + value = ( + value.to(self.device, dtype=action.dtype) + if isinstance(value, torch.Tensor) + else torch.as_tensor(value, device=self.device, dtype=action.dtype) + ) + if value.dim() == 1: + value = value.unsqueeze(0) + 
control_part = key.replace(self.EndEffector.GRIPPER.value, "") + target_dim = len( + self._runtime_robot.get_joint_ids( + control_part, remove_mimic=False + ) + ) + if target_dim > value.shape[-1]: + value = value.repeat(1, target_dim) + + expanded_parts.append(value) + + return torch.cat(expanded_parts, dim=-1).to(self.device) + + def _action_stats( + self, + mean_action: torch.Tensor, + deterministic: bool, + provided_action: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) + std = log_std.exp().expand(mean_action.shape[0], -1) + dist = Normal(mean_action, std) + if provided_action is not None: + action = provided_action + elif deterministic: + action = mean_action + else: + action = dist.rsample() + log_prob = dist.log_prob(action).sum(dim=-1, keepdim=True) + entropy = dist.entropy().sum(dim=-1, keepdim=True) + return action, log_prob, entropy + + @torch.no_grad() + def forward( + self, tensordict: TensorDict, deterministic: bool = False + ) -> TensorDict: + observation = tensordict["observation"] + batch, critic_input, context = self._build_policy_context( + observation, update_history=True + ) + batch_size = observation.batch_size[0] + + if ( + self._cached_chunk is None + or self._cached_chunk_context is None + or self._cached_chunk_step is None + or self._cached_chunk.shape[0] != batch_size + ): + self._cached_chunk = None + self._cached_chunk_context = None + self._cached_chunk_step = torch.full( + (batch_size,), + self.inference_horizon, + device=self.device, + dtype=torch.long, + ) + + refresh_mask = self._cached_chunk_step >= self.inference_horizon + if refresh_mask.any(): + refresh_batch = self._slice_batch(batch, refresh_mask) + refresh_trajectory = self._predict_chunk_actions(refresh_batch) + refresh_context = context[refresh_mask] + + if self._cached_chunk is None: + chunk_shape = (batch_size,) + tuple(refresh_trajectory.shape[1:]) + 
self._cached_chunk = torch.zeros( + chunk_shape, + device=refresh_trajectory.device, + dtype=refresh_trajectory.dtype, + ) + if self._cached_chunk_context is None: + self._cached_chunk_context = context.clone() + + self._cached_chunk[refresh_mask] = refresh_trajectory + self._cached_chunk_context["state_history"][refresh_mask] = refresh_context[ + "state_history" + ] + self._cached_chunk_context["image_history"][refresh_mask] = refresh_context[ + "image_history" + ] + self._cached_chunk_step[refresh_mask] = 0 + + step_indices = self._cached_chunk_step.clone() + raw_step_actions = self._cached_chunk[ + torch.arange(batch_size, device=self.device), step_indices + ] + mean_action = self._decode_action_step(raw_step_actions, observation) + action, log_prob, _ = self._action_stats(mean_action, deterministic) + tensordict["action"] = action + tensordict["env_action"] = self._expand_env_action(action, observation) + tensordict["sample_log_prob"] = log_prob + tensordict["value"] = self.value_head(critic_input) + tensordict["policy_context"] = self._cached_chunk_context.clone() + tensordict["chunk_step_idx"] = step_indices.unsqueeze(-1) + tensordict["loc"] = mean_action + tensordict["scale"] = self.log_std.clamp( + self.log_std_min, self.log_std_max + ).exp().expand_as(mean_action) + self._cached_chunk_step = self._cached_chunk_step + 1 + return tensordict + + @torch.no_grad() + def get_value(self, tensordict: TensorDict) -> TensorDict: + _, critic_input, _ = self._build_policy_context( + tensordict["observation"], update_history=False + ) + tensordict["value"] = self.value_head(critic_input) + return tensordict + + def evaluate_actions(self, tensordict: TensorDict) -> TensorDict: + observation = tensordict["observation"] + actions = tensordict["action"] + context = tensordict.get("policy_context", None) + batch, critic_input, _ = self._build_policy_context( + observation, update_history=False, cached_context=context + ) + trajectory = self._predict_chunk_actions(batch) + 
if "chunk_step_idx" in tensordict.keys(): + step_idx = tensordict["chunk_step_idx"].squeeze(-1).long() + step_action = trajectory[ + torch.arange(trajectory.shape[0], device=trajectory.device), step_idx + ] + mean_action = self._decode_action_step(step_action, observation) + else: + mean_action = self._decode_first_action(trajectory, observation) + _, log_prob, entropy = self._action_stats( + mean_action, deterministic=False, provided_action=actions + ) + tensordict["sample_log_prob"] = log_prob + tensordict["entropy"] = entropy + tensordict["value"] = self.value_head(critic_input) + return tensordict + + +def load_vla_model( + model_path: str, + model_class: Optional[str] = None, + model_config: Optional[dict] = None, + device: torch.device = torch.device("cpu"), +) -> nn.Module: + """Load a pretrained DexForceVLA-compatible model.""" + workspace_root = Path(__file__).resolve().parents[5] + dexechain_root = workspace_root / "embodichain" + if str(workspace_root) not in sys.path: + sys.path.append(str(workspace_root)) + if str(dexechain_root) not in sys.path: + sys.path.append(str(dexechain_root)) + + model_config = {} if model_config is None else dict(model_config) + torch_dtype_name = model_config.pop("torch_dtype", "float32") + weight_dtype = getattr(torch, torch_dtype_name) + + if model_class: + module_name, class_name = model_class.rsplit(".", 1) + module = __import__(module_name, fromlist=[class_name]) + model_cls = getattr(module, class_name) + return model_cls.from_pretrained(model_path, dtype=weight_dtype).to(device) + + from dexechain.agents.dexforce_vla.models.dexforcevla_runner import ( # pyright: ignore[reportMissingImports] + DexForceVLA, + ) + + return DexForceVLA.from_pretrained(model_path, dtype=weight_dtype).to(device) + + +def build_vla_policy( + policy_block: dict, + action_dim: int, + device: torch.device, +) -> VLAPolicy: + """Build a VLAPolicy from configuration.""" + vla_config = policy_block.get("vla_config") + if vla_config is None: + 
raise ValueError("VLA policy requires 'vla_config' in policy block") + + model_path = vla_config.get("model_path") + if model_path is None: + raise ValueError("VLA config requires 'model_path'") + + vla_model = load_vla_model( + model_path=model_path, + model_class=vla_config.get("model_class"), + model_config=dict(vla_config.get("model_config", {})), + device=device, + ) + return VLAPolicy( + action_dim=action_dim, + device=device, + vla_model=vla_model, + instruction=vla_config.get("instruction", "Stack the bowls."), + inference_horizon=int(vla_config.get("inference_horizon", 32)), + action_std_init=float(vla_config.get("action_std_init", 0.02)), + robot_type=vla_config.get("robot_type", "CobotMagic"), + gripper_open_value=float(vla_config.get("gripper_open_value", 0.05)), + gripper_closed_value=float(vla_config.get("gripper_closed_value", 0.0)), + action_key_order=vla_config.get("action_key_order"), + ) diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py index 5c553c94..93682940 100644 --- a/embodichain/agents/rl/train.py +++ b/embodichain/agents/rl/train.py @@ -15,7 +15,9 @@ # ---------------------------------------------------------------------------- import argparse +import importlib import os +import sys import time from pathlib import Path @@ -64,7 +66,10 @@ def train_from_config(config_path: str): seed = int(trainer_cfg.get("seed", 1)) device_str = trainer_cfg.get("device", "cpu") iterations = int(trainer_cfg.get("iterations", 250)) - rollout_steps = int(trainer_cfg.get("rollout_steps", 2048)) + buffer_size = int( + trainer_cfg.get("buffer_size", trainer_cfg.get("rollout_steps", 2048)) + ) + model_type = trainer_cfg.get("model_type", "standard") enable_eval = bool(trainer_cfg.get("enable_eval", False)) eval_freq = int(trainer_cfg.get("eval_freq", 10000)) save_freq = int(trainer_cfg.get("save_freq", 50000)) @@ -74,6 +79,8 @@ def train_from_config(config_path: str): gpu_id = int(trainer_cfg.get("gpu_id", 0)) num_envs = 
trainer_cfg.get("num_envs", None) wandb_project_name = trainer_cfg.get("wandb_project_name", "embodychain-generic") + filter_dataset_saving = bool(trainer_cfg.get("filter_dataset_saving", True)) + import_modules = list(trainer_cfg.get("import_modules", [])) # Device if not isinstance(device_str, str): @@ -129,18 +136,30 @@ def train_from_config(config_path: str): if use_wandb: wandb.init(project=wandb_project_name, name=exp_name, config=cfg_json) + workspace_root = Path(__file__).resolve().parents[3] + dexechain_root = workspace_root / "embodichain" + if str(workspace_root) not in sys.path: + sys.path.append(str(workspace_root)) + if str(dexechain_root) not in sys.path: + sys.path.append(str(dexechain_root)) + for module_name in import_modules: + importlib.import_module(module_name) + gym_config_path = Path(trainer_cfg["gym_config"]) logger.log_info(f"Current working directory: {Path.cwd()}") gym_config_data = load_json(str(gym_config_path)) + if filter_dataset_saving: + gym_config_data = deepcopy(gym_config_data) + gym_config_data.get("env", {}).pop("dataset", None) gym_env_cfg = config_to_cfg( gym_config_data, manager_modules=DEFAULT_MANAGER_MODULES ) if num_envs is not None: gym_env_cfg.num_envs = int(num_envs) - - if num_envs is not None: - gym_env_cfg.num_envs = num_envs + gym_env_cfg.filter_dataset_saving = filter_dataset_saving + if filter_dataset_saving: + gym_env_cfg.init_rollout_buffer = False # Ensure sim configuration mirrors runtime overrides if gym_env_cfg.sim_cfg is None: @@ -171,6 +190,8 @@ def train_from_config(config_path: str): eval_gym_env_cfg = deepcopy(gym_env_cfg) eval_gym_env_cfg.num_envs = num_eval_envs eval_gym_env_cfg.sim_cfg.headless = True + eval_gym_env_cfg.filter_dataset_saving = True + eval_gym_env_cfg.init_rollout_buffer = False eval_env = build_env(gym_config_data["id"], base_env_cfg=eval_gym_env_cfg) logger.log_info( f"Evaluation environment created (num_envs={num_eval_envs}, headless=True)" @@ -178,13 +199,39 @@ def 
train_from_config(config_path: str): # Build Policy via registry policy_name = policy_block["name"] - # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic/actor_only) - if policy_name.lower() == "actor_critic": - # Get observation dimension from flattened observation space - # flattened_observation_space returns Box space for RL training - obs_dim = env.flattened_observation_space.shape[-1] - action_dim = env.action_space.shape[-1] + # Prefer explicit action_dim from config, but keep env-based fallback for older configs. + action_dim = policy_block.get("action_dim") + if action_dim is None: + action_space = getattr(env, "action_space", None) + if action_space is None or not hasattr(action_space, "shape"): + raise ValueError( + "Unable to infer action_dim. Please set 'policy.action_dim' explicitly " + "or expose env.action_space.shape." + ) + action_dim = int(action_space.shape[-1]) + + # Infer obs_dim from environment sampling (no gym space dependency) + # Env returns dict, we process it to infer dimensions + sample_obs, _ = env.reset() + + # Get obs_dim by flattening observation structure (env returns dict) + obs_list = [] + + def _collect(item): + """Recursively collect tensors from dict or direct tensor.""" + if hasattr(item, "keys"): # It's a dict + for key in sorted(item.keys()): + _collect(item[key]) + else: # It's a Tensor + obs_list.append(item.flatten(start_dim=1)) + + _collect(sample_obs) + obs_dim = sum(t.shape[-1] for t in obs_list) + + # Build policy based on type + policy_name_lower = policy_name.lower() + if policy_name_lower == "actor_critic": actor_cfg = policy_block.get("actor") critic_cfg = policy_block.get("critic") if actor_cfg is None or critic_cfg is None: @@ -197,16 +244,12 @@ def train_from_config(config_path: str): policy = build_policy( policy_block, - env.flattened_observation_space, - env.action_space, - device, + action_dim=action_dim, + device=device, actor=actor, critic=critic, ) - elif 
policy_name.lower() == "actor_only": - obs_dim = env.flattened_observation_space.shape[-1] - action_dim = env.action_space.shape[-1] - + elif policy_name_lower == "actor_only": actor_cfg = policy_block.get("actor") if actor_cfg is None: raise ValueError( @@ -217,15 +260,12 @@ def train_from_config(config_path: str): policy = build_policy( policy_block, - env.flattened_observation_space, - env.action_space, - device, + action_dim=action_dim, + device=device, actor=actor, ) else: - policy = build_policy( - policy_block, env.flattened_observation_space, env.action_space, device - ) + policy = build_policy(policy_block, action_dim=action_dim, device=device) # Build Algorithm via factory algo_name = algo_block["name"].lower() @@ -276,7 +316,7 @@ def train_from_config(config_path: str): policy=policy, env=env, algorithm=algo, - num_steps=rollout_steps, + buffer_size=buffer_size, batch_size=algo_cfg["batch_size"], writer=writer, eval_freq=eval_freq if enable_eval else 0, # Disable eval if not enabled @@ -288,6 +328,7 @@ def train_from_config(config_path: str): event_cfg=train_event_cfg, eval_event_cfg=eval_event_cfg if enable_eval else {}, num_eval_episodes=num_eval_episodes, + model_type=model_type, ) logger.log_info("Generic training initialized") @@ -299,7 +340,7 @@ def train_from_config(config_path: str): f"Algorithm: {algo_name} (available: {get_registered_algo_names()})" ) - total_steps = int(iterations * rollout_steps * env.num_envs) + total_steps = int(iterations * buffer_size * env.num_envs) logger.log_info(f"Total steps: {total_steps} (iterations≈{iterations})") try: diff --git a/embodichain/agents/rl/utils/__init__.py b/embodichain/agents/rl/utils/__init__.py index 7bd835e8..1f8befee 100644 --- a/embodichain/agents/rl/utils/__init__.py +++ b/embodichain/agents/rl/utils/__init__.py @@ -15,9 +15,19 @@ # ---------------------------------------------------------------------------- from .config import AlgorithmCfg -from .helper import flatten_dict_observation +from 
.helper import ( + compute_gae, + dict_to_tensordict, + flatten_dict_observation, + mean_scalar, + pack_log_dict, +) __all__ = [ "AlgorithmCfg", + "dict_to_tensordict", "flatten_dict_observation", + "mean_scalar", + "pack_log_dict", + "compute_gae", ] diff --git a/embodichain/agents/rl/utils/helper.py b/embodichain/agents/rl/utils/helper.py index 42259506..113792e1 100644 --- a/embodichain/agents/rl/utils/helper.py +++ b/embodichain/agents/rl/utils/helper.py @@ -14,40 +14,202 @@ # limitations under the License. # ---------------------------------------------------------------------------- +"""Helper utilities for RL training.""" + +from __future__ import annotations + +import numpy as np import torch from tensordict import TensorDict -def flatten_dict_observation(obs: TensorDict) -> torch.Tensor: +def dict_to_tensordict(obs_dict: dict | TensorDict, device: torch.device) -> TensorDict: + """Convert a nested observation dict into a TensorDict. + + Args: + obs_dict: Nested observation dictionary returned by the environment. + device: Device to place tensors on. + + Returns: + TensorDict with an outer ``"observation"`` key. """ - Flatten hierarchical TensorDict observations from ObservationManager. - Recursively traverse nested TensorDicts, collect all tensor values, - flatten each to (num_envs, -1), and concatenate in sorted key order. 
+ def _recursive_convert(data: dict | TensorDict) -> dict: + result = {} + for key, value in data.items(): + if isinstance(value, (dict, TensorDict)): + result[key] = _recursive_convert(value) + elif isinstance(value, torch.Tensor): + result[key] = value.to(device) + else: + result[key] = torch.tensor(value, device=device) + return result + + def _get_first_tensor_batch_size(data: dict) -> int | None: + for value in data.values(): + if isinstance(value, torch.Tensor): + return int(value.shape[0]) + if isinstance(value, dict): + batch_size = _get_first_tensor_batch_size(value) + if batch_size is not None: + return batch_size + return None + + if isinstance(obs_dict, TensorDict): + obs_td = obs_dict.to(device) + return TensorDict( + {"observation": obs_td}, + batch_size=obs_td.batch_size, + device=device, + ) + + converted = _recursive_convert(obs_dict) + batch_size = _get_first_tensor_batch_size(converted) + if batch_size is None: + batch_size = 1 + + obs_td = TensorDict(converted, batch_size=[batch_size], device=device) + return TensorDict({"observation": obs_td}, batch_size=[batch_size], device=device) + + +def flatten_dict_observation(obs: TensorDict) -> torch.Tensor: + """Flatten a nested observation TensorDict into a dense tensor. Args: - obs: Nested TensorDict structure, e.g. TensorDict(robot=TensorDict(qpos=..., qvel=...), ...) + obs: Nested observation TensorDict. Returns: - Concatenated flat tensor of shape (num_envs, total_dim) + A tensor shaped ``[num_envs, obs_dim]``. 
""" - obs_list = [] - def _collect_tensors(d, prefix=""): - """Recursively collect tensors from nested TensorDicts in sorted order.""" - for key in sorted(d.keys()): - full_key = f"{prefix}/{key}" if prefix else key - value = d[key] + obs_list: list[torch.Tensor] = [] + + def _collect_tensors(data: TensorDict) -> None: + for key in sorted(data.keys()): + value = data[key] if isinstance(value, TensorDict): - _collect_tensors(value, full_key) + _collect_tensors(value) elif isinstance(value, torch.Tensor): - # Flatten tensor to (num_envs, -1) shape obs_list.append(value.flatten(start_dim=1)) _collect_tensors(obs) if not obs_list: raise ValueError("No tensors found in observation TensorDict") + if len(obs_list) == 1: + return obs_list[0] + return torch.cat(obs_list, dim=-1) + + +def mean_scalar(x) -> float: + """Convert tensor or array to scalar float (mean if needed). + + Args: + x: Tensor, array, or scalar value + + Returns: + Float scalar value + """ + if hasattr(x, "detach"): + x = x.detach().cpu().numpy() + else: + x = np.asarray(x) + return float(np.mean(x)) + + +def pack_log_dict(prefix: str, data: dict) -> dict: + """Pack data dict into logging dict with prefix. + + Args: + prefix: Prefix for keys (e.g., "train", "eval") + data: Dictionary of values to pack + + Returns: + Dictionary with prefixed keys and scalar values + """ + if not isinstance(data, dict): + return {} + out = {} + for k, v in data.items(): + try: + out[f"{prefix}/{k}"] = mean_scalar(v) + except Exception: + continue + return out + + +def compute_gae( + rollout: TensorDict, + gamma: float, + gae_lambda: float, + time_first: bool = True, +) -> TensorDict: + """Compute Generalized Advantage Estimation (GAE) on rollout TensorDict. + + Supports two layouts: + - time_first=True (default): [T, N, ...] - TorchRL convention + - time_first=False: [N, T, ...] - batch-first, matches VLA training convention + + GAE requires sequential timesteps within the same trajectory. 
Both layouts + ensure correct per-env trajectory ordering. + + Args: + rollout: TensorDict with batch_size=[T, N] or [N, T] containing: + - "value": state values + - "next": TensorDict with "reward", "done", "value" (bootstrapped) + gamma: Discount factor + gae_lambda: GAE lambda parameter + time_first: If True, rollout is [T, N]; if False, rollout is [N, T] + + Returns: + TensorDict with added keys: "advantage", "value_target" + """ + device = rollout.device + + if time_first: + # [T, N, ...] + T, N = rollout.batch_size[:2] + values = rollout["value"] + rewards = rollout["next"]["reward"] + dones = rollout["next"]["done"].float() + if "value" in rollout["next"]: + bootstrap_values = rollout["next"]["value"] + else: + bootstrap_values = torch.zeros_like(values) + + advantages = torch.zeros_like(values) + gae = torch.zeros(N, 1, device=device) + + for t in reversed(range(T)): + delta = ( + rewards[t] + gamma * bootstrap_values[t] * (1.0 - dones[t]) - values[t] + ) + gae = delta + gamma * gae_lambda * (1.0 - dones[t]) * gae + advantages[t] = gae + else: + # [N, T, ...] 
- batch-first + N, T = rollout.batch_size[:2] + values = rollout["value"] + rewards = rollout["next"]["reward"] + dones = rollout["next"]["done"].float() + if "value" in rollout["next"]: + bootstrap_values = rollout["next"]["value"] + else: + bootstrap_values = torch.zeros_like(values) + + advantages = torch.zeros_like(values) + gae = torch.zeros(N, 1, device=device) + + for t in reversed(range(T)): + delta = ( + rewards[:, t] + + gamma * bootstrap_values[:, t] * (1.0 - dones[:, t]) + - values[:, t] + ) + gae = delta + gamma * gae_lambda * (1.0 - dones[:, t]) * gae + advantages[:, t] = gae - result = torch.cat(obs_list, dim=-1) - return result + value_targets = advantages + values + rollout["advantage"] = advantages + rollout["value_target"] = value_targets + return rollout diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 5b17a8e0..47e76e2f 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -16,8 +16,9 @@ from __future__ import annotations -from typing import Dict, Any, Tuple, Callable +import threading import time +from typing import Dict, Any, Tuple, Callable import numpy as np import torch from torch.utils.tensorboard import SummaryWriter @@ -26,7 +27,8 @@ from tensordict import TensorDict from embodichain.lab.gym.envs.managers.event_manager import EventManager -from .helper import flatten_dict_observation +from .helper import dict_to_tensordict, mean_scalar, pack_log_dict +from ..collector import SyncCollector, AsyncCollector class Trainer: @@ -37,7 +39,7 @@ def __init__( policy, env, algorithm, - num_steps: int, + buffer_size: int, batch_size: int, writer: SummaryWriter | None, eval_freq: int, @@ -49,12 +51,14 @@ def __init__( event_cfg=None, eval_event_cfg=None, num_eval_episodes: int = 5, + # Model type: "standard" (default PPO) or "vla" + model_type: str = "standard", ): self.policy = policy self.env = env self.eval_env = eval_env self.algorithm = algorithm - 
self.num_steps = num_steps + self.buffer_size = buffer_size self.batch_size = batch_size self.writer = writer self.eval_freq = eval_freq @@ -64,6 +68,36 @@ def __init__( self.use_wandb = use_wandb self.num_eval_episodes = num_eval_episodes + # Buffer setup (depends on model_type) + self.model_type = model_type + device = ( + algorithm.device + if hasattr(algorithm, "device") + else torch.device("cuda" if torch.cuda.is_available() else "cpu") + ) + + # Initialize observation and get num_envs (needed for VLA buffer) + obs, _ = env.reset() + self.obs_tensordict = dict_to_tensordict(obs, device) + num_envs = self.obs_tensordict.batch_size[0] + + if model_type == "vla": + # VLA model: rollout-level buffer with (B,T) layout for correct GAE + from embodichain.agents.rl.buffer import VLABuffer + + self.buffer = VLABuffer( + buffer_size=buffer_size, device=device, num_envs=num_envs + ) + elif model_type == "standard": + # Standard PPO model: single rollout, use and discard + from embodichain.agents.rl.buffer import RolloutBuffer + + self.buffer = RolloutBuffer(buffer_size=buffer_size, device=device) + else: + raise ValueError( + f"Unknown model_type: {model_type}. Use 'standard' or 'vla'." 
+ ) + if event_cfg is not None: self.event_manager = EventManager(event_cfg, env=self.env) if eval_event_cfg is not None: @@ -75,146 +109,194 @@ def __init__( self.start_time = time.time() self.ret_window = deque(maxlen=100) self.len_window = deque(maxlen=100) + self._stats_lock = ( + threading.Lock() + ) # Protects curr_ret, curr_len, ret_window, len_window (async mode) - # initial obs (assume env returns torch tensors already on target device) - obs, _ = self.env.reset() - - # Initialize algorithm's buffer - # Flatten TensorDict observations from ObservationManager to tensor for RL algorithms - if isinstance(obs, TensorDict): - obs_tensor = flatten_dict_observation(obs) - obs_dim = obs_tensor.shape[-1] - num_envs = obs_tensor.shape[0] - # Store flattened observation for RL training - self.obs = obs_tensor - - action_space = getattr(self.env, "action_space", None) - action_dim = action_space.shape[-1] if action_space else None - if action_dim is None: - raise RuntimeError( - "Env must expose action_space with shape for buffer initialization." 
- ) - - # Algorithm manages its own buffer - self.algorithm.initialize_buffer(num_steps, num_envs, obs_dim, action_dim) - - # episode stats tracked on device to avoid repeated CPU round-trips + # Episode stats tracked on device to avoid repeated CPU round-trips self.curr_ret = torch.zeros(num_envs, dtype=torch.float32, device=self.device) self.curr_len = torch.zeros(num_envs, dtype=torch.int32, device=self.device) # ---- lightweight helpers for dense logging ---- - @staticmethod - def _mean_scalar(x) -> float: - if hasattr(x, "detach"): - x = x.detach().cpu().numpy() - else: - x = np.asarray(x) - return float(np.mean(x)) - def _log_scalar_dict(self, prefix: str, data: dict): if not self.writer or not isinstance(data, dict): return for k, v in data.items(): try: self.writer.add_scalar( - f"{prefix}/{k}", self._mean_scalar(v), self.global_step + f"{prefix}/{k}", mean_scalar(v), self.global_step ) except Exception: continue - def _pack_log_dict(self, prefix: str, data: dict) -> dict: - if not isinstance(data, dict): - return {} - out = {} - for k, v in data.items(): - try: - out[f"{prefix}/{k}"] = self._mean_scalar(v) - except Exception: - continue - return out + def _create_step_callback(self) -> Callable: + """Create step callback for collectors. 
+ + Returns: + Callback function compatible with both sync and async collectors + """ + + def on_step(tensordict: TensorDict, env_info: dict): + """Callback called at each step during rollout collection.""" + # Extract reward and done from next subdictionary + reward = tensordict["next"]["reward"] + done = tensordict["next"]["done"] + + # Squeeze if needed + if reward.dim() > 1: + reward = reward.squeeze(-1) + if done.dim() > 1: + done = done.squeeze(-1) + + # Episode stats (thread-safe for async mode: collector writes, main reads) + with self._stats_lock: + self.curr_ret += reward + self.curr_len += 1 + done_idx = torch.nonzero(done, as_tuple=False).squeeze(-1) + if done_idx.numel() > 0: + finished_ret = self.curr_ret[done_idx].detach().cpu().tolist() + finished_len = self.curr_len[done_idx].detach().cpu().tolist() + self.ret_window.extend(finished_ret) + self.len_window.extend(finished_len) + self.curr_ret[done_idx] = 0 + self.curr_len[done_idx] = 0 + + # Log environment metrics + if isinstance(env_info, dict): + rewards_dict = env_info.get("rewards") + metrics_dict = env_info.get("metrics") + self._log_scalar_dict("rewards", rewards_dict) + self._log_scalar_dict("metrics", metrics_dict) + log_dict = {} + log_dict.update(pack_log_dict("rewards", rewards_dict)) + log_dict.update(pack_log_dict("metrics", metrics_dict)) + if log_dict and self.use_wandb: + wandb.log(log_dict, step=self.global_step) + + return on_step def train(self, total_timesteps: int): print(f"Start training, total steps: {total_timesteps}") + print(f"Model type: {self.model_type}") + + if self.model_type == "vla": + collector = AsyncCollector( + env=self.env, + policy=self.policy, + buffer=self.buffer, + device=self.device, + on_step_callback=self._create_step_callback(), + ) + self._train_async(collector, total_timesteps) + else: + collector = SyncCollector( + env=self.env, + policy=self.policy, + device=self.device, + on_step_callback=self._create_step_callback(), + ) + 
self._train_sync(collector, total_timesteps) + + def _train_sync(self, collector: SyncCollector, total_timesteps: int): + """Synchronous training loop (standard PPO).""" while self.global_step < total_timesteps: - self._collect_rollout() - losses = self.algorithm.update() - self._log_train(losses) + # Collect rollout + rollout = collector.collect(num_steps=self.buffer_size) + + # Update global step (rollout is [N, T]) + num_envs = rollout.batch_size[0] + num_steps = rollout.batch_size[1] if len(rollout.batch_size) > 1 else 1 + self.global_step += num_envs * num_steps + + self.buffer.add(rollout) + + # Train when buffer is full (pass [N, T] for correct GAE) + if self.buffer.is_full(): + data = self.buffer.get(flatten=False) + losses = self.algorithm.update(data) + self._log_train(losses) + + # Evaluation if ( self.eval_freq > 0 and self.eval_env is not None and self.global_step % self.eval_freq == 0 ): self._eval_once(num_episodes=self.num_eval_episodes) + + # Checkpoint if self.global_step % self.save_freq == 0: self.save_checkpoint() - @torch.no_grad() - def _collect_rollout(self): - """Collect a rollout. 
Algorithm controls the data collection process.""" + def _train_async(self, collector: AsyncCollector, total_timesteps: int): + """Asynchronous training loop (VLA mode).""" + collector.start() + print("[Trainer] Async collector started") + + try: + while self.global_step < total_timesteps: + # Wait for buffer to fill + while not self.buffer.is_full(): + time.sleep(0.1) + if not collector.is_running(): + raise RuntimeError("Async collector stopped unexpectedly") + + # Get data (flatten=False so PPO gets [N, T] for GAE) + data = self.buffer.get(flatten=False) + self.buffer.clear() # get() already clears; clear() is redundant + + # Update global step (data is [N, T]) + if len(data.batch_size) >= 2: + self.global_step += data.batch_size[0] * data.batch_size[1] + else: + self.global_step += data.batch_size[0] if data.batch_size else 0 + + losses = self.algorithm.update(data) + self._log_train(losses) + + # Evaluation (pause collector during eval) + if ( + self.eval_freq > 0 + and self.eval_env is not None + and self.global_step % self.eval_freq == 0 + ): + collector.stop() + self._eval_once(num_episodes=self.num_eval_episodes) + collector.start() + + # Checkpoint + if self.global_step % self.save_freq == 0: + self.save_checkpoint() + + finally: + collector.stop() + print("[Trainer] Async collector stopped") - # Callback function for statistics and logging - def on_step(obs, actions, reward, done, info, next_obs): - """Callback called at each step during rollout collection.""" - # Episode stats (stay on device; convert only when episode ends) - self.curr_ret += reward - self.curr_len += 1 - done_idx = torch.nonzero(done, as_tuple=False).squeeze(-1) - if done_idx.numel() > 0: - finished_ret = self.curr_ret[done_idx].detach().cpu().tolist() - finished_len = self.curr_len[done_idx].detach().cpu().tolist() - self.ret_window.extend(finished_ret) - self.len_window.extend(finished_len) - self.curr_ret[done_idx] = 0 - self.curr_len[done_idx] = 0 - - # Update global step and 
observation - # next_obs is already flattened in algorithm's collect_rollout - self.obs = next_obs - self.global_step += next_obs.shape[0] - - if isinstance(info, dict): - rewards_dict = info.get("rewards") - metrics_dict = info.get("metrics") - self._log_scalar_dict("rewards", rewards_dict) - self._log_scalar_dict("metrics", metrics_dict) - log_dict = {} - log_dict.update(self._pack_log_dict("rewards", rewards_dict)) - log_dict.update(self._pack_log_dict("metrics", metrics_dict)) - if log_dict and self.use_wandb: - wandb.log(log_dict, step=self.global_step) + def _log_train(self, losses: Dict[str, float]): + # Snapshot episode stats under lock (async mode: main reads, collector writes) + with self._stats_lock: + ret_list = list(self.ret_window) + len_list = list(self.len_window) - # Algorithm controls data collection - result = self.algorithm.collect_rollout( - env=self.env, - policy=self.policy, - obs=self.obs, - num_steps=self.num_steps, - on_step_callback=on_step, - ) + avgR = float(np.mean(ret_list)) if ret_list else float("nan") + avgL = float(np.mean(len_list)) if len_list else float("nan") - def _log_train(self, losses: Dict[str, float]): if self.writer: for k, v in losses.items(): self.writer.add_scalar(f"train/{k}", v, self.global_step) elapsed = max(1e-6, time.time() - self.start_time) sps = self.global_step / elapsed self.writer.add_scalar("charts/SPS", sps, self.global_step) - if len(self.ret_window) > 0: + if ret_list: self.writer.add_scalar( - "charts/episode_reward_avg_100", - float(np.mean(self.ret_window)), - self.global_step, + "charts/episode_reward_avg_100", avgR, self.global_step ) - if len(self.len_window) > 0: + if len_list: self.writer.add_scalar( - "charts/episode_length_avg_100", - float(np.mean(self.len_window)), - self.global_step, + "charts/episode_length_avg_100", avgL, self.global_step ) # console sps = self.global_step / max(1e-6, time.time() - self.start_time) - avgR = np.mean(self.ret_window) if len(self.ret_window) > 0 else 
float("nan") - avgL = np.mean(self.len_window) if len(self.len_window) > 0 else float("nan") print( f"[train] step={self.global_step} sps={sps:.0f} avgReward(100)={avgR:.3f} avgLength(100)={avgL:.1f}" ) @@ -240,14 +322,18 @@ def _eval_once(self, num_episodes: int = 5): num_episodes: Number of episodes to evaluate """ self.policy.eval() + if hasattr(self.policy, "bind_env"): + self.policy.bind_env(self.eval_env) + if hasattr(self.policy, "_reset_chunk_cache"): + self.policy._reset_chunk_cache() episode_returns = [] episode_lengths = [] for _ in range(num_episodes): - # Reset and initialize episode tracking + # Reset and initialize episode tracking - env returns dict, convert at boundary obs, _ = self.eval_env.reset() - obs = flatten_dict_observation(obs) - num_envs = obs.shape[0] if obs.ndim == 2 else 1 + obs = dict_to_tensordict(obs, self.device) + num_envs = obs.batch_size[0] done_mask = torch.zeros(num_envs, dtype=torch.bool, device=self.device) cumulative_reward = torch.zeros( @@ -257,28 +343,28 @@ def _eval_once(self, num_episodes: int = 5): # Run episode until all environments complete while not done_mask.all(): - # Get deterministic actions from policy - actions, _, _ = self.policy.get_action(obs, deterministic=True) - am = getattr(self.eval_env, "action_manager", None) - action_type = ( - am.action_type - if am - else getattr(self.eval_env, "action_type", "delta_qpos") + # Get deterministic actions for evaluation + obs_copy = obs.clone() + self.policy.forward(obs_copy, deterministic=True) + actions = ( + obs_copy["env_action"] + if "env_action" in obs_copy.keys() + else obs_copy["action"] ) - action_dict = {action_type: actions} + am = getattr(self.eval_env, "action_manager", None) + env_action = {am.action_type: actions} if am else actions - # Environment step - obs, reward, terminated, truncated, info = self.eval_env.step( - action_dict - ) - obs = ( - flatten_dict_observation(obs) - if isinstance(obs, TensorDict) - else obs + # Environment step - env 
returns dict, convert to TensorDict at boundary + next_obs, reward, terminated, truncated, info = self.eval_env.step( + env_action ) + next_obs = dict_to_tensordict(next_obs, self.device) + done = terminated | truncated + if hasattr(self.policy, "reset_envs"): + self.policy.reset_envs(done, next_obs["observation"]) + obs = next_obs # Update statistics only for still-running environments - done = terminated | truncated still_running = ~done_mask cumulative_reward[still_running] += reward[still_running].float() step_count[still_running] += 1 diff --git a/pyproject.toml b/pyproject.toml index ccc73cbb..a9b045bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ dependencies = [ "black==24.3.0", "fvcore", "h5py", - "tensordict" + "tensordict>=0.5.0", # For TensorDict-based RL data structures ] [project.optional-dependencies] diff --git a/tests/agents/test_rl.py b/tests/agents/test_rl.py index e0a5beaf..f356502c 100644 --- a/tests/agents/test_rl.py +++ b/tests/agents/test_rl.py @@ -68,7 +68,7 @@ def setup_method(self): test_train_config = train_config.copy() test_train_config["trainer"]["gym_config"] = self.temp_gym_config_path test_train_config["trainer"]["iterations"] = 2 - test_train_config["trainer"]["rollout_steps"] = 32 + test_train_config["trainer"]["buffer_size"] = 32 test_train_config["trainer"]["eval_freq"] = 1000000 # Disable eval test_train_config["trainer"]["save_freq"] = 1000000 # Disable save test_train_config["trainer"]["headless"] = True