Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/rl/test_multi_task_agent_loop_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def _fake_agent_loop():
rollout_ctl = MagicMock()
rollout_ctl.continue_generation.remote = AsyncMock()
rollout_ctl.pause_generation.remote = AsyncMock()
rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}})
rollout_ctl.get_weight_update_targets.remote = AsyncMock(return_value=())
agent_loop = MagicMock()
agent_loop.rollout_ctl = rollout_ctl
return agent_loop
Expand All @@ -177,7 +177,7 @@ def _fake_rollout_controller():
rollout_controller = MagicMock()
rollout_controller.continue_generation.remote = AsyncMock()
rollout_controller.pause_generation.remote = AsyncMock()
rollout_controller.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}})
rollout_controller.get_weight_update_targets.remote = AsyncMock(return_value=())
return rollout_controller


Expand Down
2 changes: 1 addition & 1 deletion tests/rl/test_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _build_agent_loop(self, sleep_by_id: dict[int, float] | None = None):
mock_agent_loop = MagicMock()
mock_agent_loop.rollout_ctl.continue_generation.remote = AsyncMock(return_value=None)
mock_agent_loop.rollout_ctl.pause_generation.remote = AsyncMock(return_value=None)
mock_agent_loop.rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}})
mock_agent_loop.rollout_ctl.get_weight_update_targets.remote = AsyncMock(return_value=())

async def mock_pause():
await mock_agent_loop.rollout_ctl.pause_generation.remote()
Expand Down
2 changes: 1 addition & 1 deletion tests/rl/test_rl_colocate_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _build_fake_rollout_controller():
rollout_ctl = MagicMock()
rollout_ctl.continue_generation.remote = AsyncMock(return_value=None)
rollout_ctl.pause_generation.remote = AsyncMock(return_value=None)
rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}})
rollout_ctl.get_weight_update_targets.remote = AsyncMock(return_value=())
return rollout_ctl


Expand Down
23 changes: 23 additions & 0 deletions tests/rl/test_rl_disaggregated_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def _make_trainer(self, agent_loop_manager):
trainer._benchmark_training_tokens = 0
trainer._cpu_resource_manager = None
trainer._train_worker_cfg = SimpleNamespace(pack_max_length=16)
trainer._rollout_config = SimpleNamespace(weight_update_host=None, weight_update_port=30000)
trainer._meta = SimpleNamespace(
latest_exp=SimpleNamespace(exp_dir=str(Path(self.temp_dir.name) / "exp")),
)
Expand Down Expand Up @@ -255,6 +256,28 @@ def test_fit_trains_non_empty_expired_batch_then_syncs_current_step(self):
self.assertIn(("continue_produce", 1), manager.calls)
self.assertEqual(trainer._cur_step, 1)

def test_fit_rebinds_weight_update_with_rollout_update_address(self):
# 验证非共卡后续同步权重时继续沿用 rollout config 中的 NCCL update 地址。
train_sample = SimpleNamespace(group_id=1, rollout_id=1)
manager = _FakeManager([ProduceBatchResult(rollout_states=[[train_sample]])])
trainer = self._make_trainer(manager)
trainer._rollout_config = SimpleNamespace(weight_update_host="10.0.0.1", weight_update_port=23456)

with (
patch("xtuner.v1.train.rl_trainer.asyncio_run", side_effect=asyncio.run),
patch("xtuner.v1.train.rl_trainer.bind_train_rollout") as bind_train_rollout_mock,
):
trainer.fit()

bind_train_rollout_mock.assert_called_once_with(
train_controller=trainer.train_controller,
rollout_controller=trainer.rollout_controller,
rollout_config=trainer._rollout_config,
weight_transport_type="nccl",
weight_update_host="10.0.0.1",
weight_update_port=23456,
)

def test_fit_keeps_background_producer_running_while_training_blocks(self):
# 验证非共卡训练阻塞在同步训练 batch 时,后台 producer 仍能继续调度。
train_sample = SimpleNamespace(group_id=1, rollout_id=1)
Expand Down
27 changes: 16 additions & 11 deletions tests/rl/test_rl_trainer_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(self):
self.restart_inactive_workers = _RemoteMethod(return_value="rollout_restarted")
self.onload_weights = _RemoteMethod(return_value="weights_loaded")
self.onload_kvcache = _RemoteMethod(return_value="kvcache_loaded")
self.get_rollout_metadata = _RemoteMethod(return_value={"server_url_dict": {}})
self.get_weight_update_targets = _RemoteMethod(return_value=())
self.set_enable_partial_rollout = _RemoteMethod(return_value=None)
self.validate_registered_workers_to_proxy = _RemoteMethod(return_value=_AwaitableValue(None))

Expand All @@ -123,19 +123,24 @@ def __init__(self):
self.fit_steps: list[int] = []
self.saved_checkpoints: list[Path] = []
self.resume_checkpoint_paths: list[Path] = []
self.train_rollout_mode = None
self.weight_transport_type = None
self.update_weights_count = 0
self.rollout_info = None

def update_rollout_info(
self,
info,
train_rollout_mode,
weight_update_host,
weight_update_port
):
self.rollout_info = info
self.train_rollout_mode = train_rollout_mode
def bind_rollout_weight_update(
self,
*,
targets,
rollout_config,
weight_transport_type,
weight_update_host=None,
weight_update_port=None,
):
self.rollout_info = {
"targets": targets,
"rollout_config": rollout_config,
}
self.weight_transport_type = weight_transport_type
self.weight_update_host = weight_update_host
self.weight_update_port = weight_update_port

Expand Down
Loading
Loading