diff --git a/tests/rl/test_multi_task_agent_loop_manager.py b/tests/rl/test_multi_task_agent_loop_manager.py index 124e923e6..8db40aa22 100644 --- a/tests/rl/test_multi_task_agent_loop_manager.py +++ b/tests/rl/test_multi_task_agent_loop_manager.py @@ -167,7 +167,7 @@ def _fake_agent_loop(): rollout_ctl = MagicMock() rollout_ctl.continue_generation.remote = AsyncMock() rollout_ctl.pause_generation.remote = AsyncMock() - rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}}) + rollout_ctl.get_weight_update_targets.remote = AsyncMock(return_value=()) agent_loop = MagicMock() agent_loop.rollout_ctl = rollout_ctl return agent_loop @@ -177,7 +177,7 @@ def _fake_rollout_controller(): rollout_controller = MagicMock() rollout_controller.continue_generation.remote = AsyncMock() rollout_controller.pause_generation.remote = AsyncMock() - rollout_controller.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}}) + rollout_controller.get_weight_update_targets.remote = AsyncMock(return_value=()) return rollout_controller diff --git a/tests/rl/test_producer.py b/tests/rl/test_producer.py index 879a3ca79..3514542f8 100644 --- a/tests/rl/test_producer.py +++ b/tests/rl/test_producer.py @@ -100,7 +100,7 @@ def _build_agent_loop(self, sleep_by_id: dict[int, float] | None = None): mock_agent_loop = MagicMock() mock_agent_loop.rollout_ctl.continue_generation.remote = AsyncMock(return_value=None) mock_agent_loop.rollout_ctl.pause_generation.remote = AsyncMock(return_value=None) - mock_agent_loop.rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}}) + mock_agent_loop.rollout_ctl.get_weight_update_targets.remote = AsyncMock(return_value=()) async def mock_pause(): await mock_agent_loop.rollout_ctl.pause_generation.remote() diff --git a/tests/rl/test_rl_colocate_trainer.py b/tests/rl/test_rl_colocate_trainer.py index 389879db2..69f902d6d 100644 --- a/tests/rl/test_rl_colocate_trainer.py +++ b/tests/rl/test_rl_colocate_trainer.py @@ -75,7 +75,7 @@ def _build_fake_rollout_controller(): rollout_ctl = MagicMock() rollout_ctl.continue_generation.remote = AsyncMock(return_value=None) rollout_ctl.pause_generation.remote = AsyncMock(return_value=None) - rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}}) + rollout_ctl.get_weight_update_targets.remote = AsyncMock(return_value=()) return rollout_ctl diff --git a/tests/rl/test_rl_disaggregated_trainer.py b/tests/rl/test_rl_disaggregated_trainer.py index 2db43ceb8..80b96c6da 100644 --- a/tests/rl/test_rl_disaggregated_trainer.py +++ b/tests/rl/test_rl_disaggregated_trainer.py @@ -121,6 +121,7 @@ def _make_trainer(self, agent_loop_manager): trainer._benchmark_training_tokens = 0 trainer._cpu_resource_manager = None trainer._train_worker_cfg = SimpleNamespace(pack_max_length=16) + trainer._rollout_config = SimpleNamespace(weight_update_host=None, weight_update_port=30000) trainer._meta = SimpleNamespace( latest_exp=SimpleNamespace(exp_dir=str(Path(self.temp_dir.name) / "exp")), ) @@ -255,6 +256,28 @@ def test_fit_trains_non_empty_expired_batch_then_syncs_current_step(self): self.assertIn(("continue_produce", 1), manager.calls) self.assertEqual(trainer._cur_step, 1) + def test_fit_rebinds_weight_update_with_rollout_update_address(self): + # 验证非共卡后续同步权重时继续沿用 rollout config 中的 NCCL update 地址。 + train_sample = SimpleNamespace(group_id=1, rollout_id=1) + manager = _FakeManager([ProduceBatchResult(rollout_states=[[train_sample]])]) + trainer = self._make_trainer(manager) + trainer._rollout_config = SimpleNamespace(weight_update_host="10.0.0.1", weight_update_port=23456) + + with ( + patch("xtuner.v1.train.rl_trainer.asyncio_run", side_effect=asyncio.run), + patch("xtuner.v1.train.rl_trainer.bind_train_rollout") as bind_train_rollout_mock, + ): + trainer.fit() + + bind_train_rollout_mock.assert_called_once_with( + train_controller=trainer.train_controller, + rollout_controller=trainer.rollout_controller, + rollout_config=trainer._rollout_config, + weight_transport_type="nccl", + weight_update_host="10.0.0.1", + weight_update_port=23456, + ) + def test_fit_keeps_background_producer_running_while_training_blocks(self): # 验证非共卡训练阻塞在同步训练 batch 时,后台 producer 仍能继续调度。 train_sample = SimpleNamespace(group_id=1, rollout_id=1) diff --git a/tests/rl/test_rl_trainer_checkpoint.py b/tests/rl/test_rl_trainer_checkpoint.py index b3f2db0bc..cb2977b6c 100644 --- a/tests/rl/test_rl_trainer_checkpoint.py +++ b/tests/rl/test_rl_trainer_checkpoint.py @@ -104,7 +104,7 @@ def __init__(self): self.restart_inactive_workers = _RemoteMethod(return_value="rollout_restarted") self.onload_weights = _RemoteMethod(return_value="weights_loaded") self.onload_kvcache = _RemoteMethod(return_value="kvcache_loaded") - self.get_rollout_metadata = _RemoteMethod(return_value={"server_url_dict": {}}) + self.get_weight_update_targets = _RemoteMethod(return_value=()) self.set_enable_partial_rollout = _RemoteMethod(return_value=None) self.validate_registered_workers_to_proxy = _RemoteMethod(return_value=_AwaitableValue(None)) @@ -123,19 +123,24 @@ def __init__(self): self.fit_steps: list[int] = [] self.saved_checkpoints: list[Path] = [] self.resume_checkpoint_paths: list[Path] = [] - self.train_rollout_mode = None + self.weight_transport_type = None self.update_weights_count = 0 self.rollout_info = None - def update_rollout_info( - self, - info, - train_rollout_mode, - weight_update_host, - weight_update_port - ): - self.rollout_info = info - self.train_rollout_mode = train_rollout_mode + def bind_rollout_weight_update( + self, + *, + targets, + rollout_config, + weight_transport_type, + weight_update_host=None, + weight_update_port=None, + ): + self.rollout_info = { + "targets": targets, + "rollout_config": rollout_config, + } + self.weight_transport_type = weight_transport_type self.weight_update_host = weight_update_host self.weight_update_port = weight_update_port diff --git a/tests/rl/test_rollout_logic.py b/tests/rl/test_rollout_logic.py index cb3f73a47..ee2cb7052 100644 --- a/tests/rl/test_rollout_logic.py +++ b/tests/rl/test_rollout_logic.py @@ -23,12 +23,19 @@ from xtuner.v1.rl.agent_loop import AgentLoopConfig from xtuner.v1.rl.rollout.controller import RolloutController from xtuner.v1.rl.rollout.health_manager import RolloutHealthManager +from xtuner.v1.rl.rollout.lmdeploy import LMDeployWorker +from xtuner.v1.rl.rollout.rollout_topology import RolloutEngine, RolloutTopology, RolloutServerProcess from xtuner.v1.rl.rollout.proxy_manager import RolloutProxyManager -from xtuner.v1.rl.rollout.worker_registry import RolloutWorkerRegistry, WorkerLifecycleState, WorkerSnapshot +from xtuner.v1.rl.rollout.worker_registry import ( + RolloutWorkerRegistry, + WorkerLifecycleState, + WorkerSnapshot, +) from xtuner.v1.rl.rollout.sglang import SGLangWorker from xtuner.v1.rl.rollout.utils import PartialRolloutHandler, SessionRouter -from xtuner.v1.rl.rollout.worker import RolloutWorker +from xtuner.v1.rl.rollout.worker import RolloutWorker, RolloutWorkerInitResult from xtuner.v1.rl.utils.misc import delete_from_routedapiproxy +from xtuner.v1.rl.weight_update.data import RolloutWeightUpdateInfo from xtuner.v1.train.rl_trainer import BaseRLTrainer, _agent_loop_manager_requires_rollout_proxy from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult @@ -38,8 +45,11 @@ def __init__(self, result): self.result = result self.calls = [] - def remote(self): - self.calls.append(()) + def remote(self, *args, **kwargs): + if kwargs: + self.calls.append((args, kwargs)) + else: + self.calls.append(args) async def _result(): if isinstance(self.result, Exception): @@ -49,6 +59,31 @@ async def _result(): return _result() +def _register_started_servers( + registry, + entries, + *, + lifecycle_state=WorkerLifecycleState.ACTIVE, +): + entries = tuple(entries) + workers_by_rank = [None] * (max((rank for rank, _actor, _server_url, _session_url in entries), default=-1) + 1) + init_results = [] + for rank, actor, server_url, session_url in entries: + workers_by_rank[rank] = actor + init_results.append( + RolloutWorkerInitResult( + rank=rank, + server_url=server_url, + session_url=session_url, + ) + ) + registry.register_started_servers( + init_results=tuple(init_results), + workers_by_rank=tuple(workers_by_rank), + lifecycle_state=lifecycle_state, + ) + + class _FakeRolloutRouter: def __init__(self, worker): self.worker = worker @@ -122,6 +157,164 @@ def test_trainer_auto_enables_rollout_proxy_when_agent_loop_requires_it(self): self.assertTrue(trainer._rollout_config.enable_proxy) +class TestRolloutTopologyAPI(unittest.TestCase): + def _rollout_config( + self, + *, + tp: int, + ep: int, + num_gpus_per_engine: int, + gpus_per_node: int = 8, + ): + return SimpleNamespace( + api_key="test-key", + tensor_parallel_size=tp, + expert_parallel_size=ep, + num_gpus_per_engine=num_gpus_per_engine, + gpus_per_node=gpus_per_node, + extra_rollout_config={"lmdeploy_backend": "pytorch"}, + ) + + def _rank_bundle_idx_list(self, num_workers: int): + return [(rank, rank) for rank in range(num_workers)] + + def _rank_to_dist_init_addr(self, num_workers: int): + return {rank: f"host{rank}:25{rank:03d}" for rank in range(num_workers)} + + def _weight_update_targets(self, topology: RolloutTopology): + registry = RolloutWorkerRegistry(rollout_topology=topology) + _register_started_servers( + registry, + ( + ( + spec.worker_rank, + object(), + f"http://worker-{spec.worker_rank}", + f"http://session-{spec.worker_rank}", + ) + for spec in topology.server_launch_specs() + ), + ) + return registry.weight_update_targets() + + def _rollout_info(self, *, config, targets, train_rank: int): + return RolloutWeightUpdateInfo.from_targets( + rollout_config=config, + weight_update_targets=targets, + train_rank=train_rank, + weight_transport_type="ipc", + ) + + def test_rollout_topology_resolves_engine_dist_init_addr_when_created(self): + rank_to_dist_init_addr = {0: "host0:25000", 1: "host1:25004"} + dist_init_addr_owner_rank = 0 + engine = RolloutEngine( + engine_ranks=(0, 1), + dist_init_addr=rank_to_dist_init_addr[dist_init_addr_owner_rank], + server_processes=( + RolloutServerProcess( + worker_rank=0, + placement_group_bundle_idxs=(0,), + accepts_rollout_requests=True, + weight_update_ranks=(0, 1), + ), + RolloutServerProcess( + worker_rank=1, + placement_group_bundle_idxs=(1,), + weight_update_ranks=(), + accepts_rollout_requests=False, + ), + ), + ) + + topology = RolloutTopology( + engines=(engine,), + ) + + launch_specs = topology.server_launch_specs() + self.assertEqual(tuple(spec.worker_rank for spec in launch_specs), (0, 1)) + rank_0_launch_spec, rank_1_launch_spec = launch_specs + self.assertEqual(rank_0_launch_spec.dist_init_addr, "host0:25000") + self.assertEqual(rank_1_launch_spec.dist_init_addr, "host0:25000") + self.assertEqual(rank_0_launch_spec.engine_rank, 0) + self.assertEqual(rank_1_launch_spec.engine_rank, 1) + self.assertEqual(rank_1_launch_spec.placement_group_bundle_idxs, (1,)) + self.assertTrue(topology.is_request_entrypoint_rank(0)) + self.assertFalse(topology.is_request_entrypoint_rank(1)) + self.assertEqual(topology.lifecycle_group_for_server_rank(1), (0, 1)) + self.assertEqual( + tuple( + (server.worker_rank, server.weight_update_ranks) + for server in topology.weight_update_endpoint_processes() + ), + ((0, (0, 1)),), + ) + + def test_lmdeploy_tp16_weight_update_targets_match_legacy_mesh_and_url_semantics(self): + config = self._rollout_config(tp=16, ep=1, num_gpus_per_engine=16) + topology = LMDeployWorker.build_rollout_topology( + config, + self._rank_bundle_idx_list(16), + self._rank_to_dist_init_addr(16), + ) + targets = self._weight_update_targets(topology) + + self.assertEqual( + tuple((target.endpoint_rank, target.update_ranks) for target in targets), + ((0, tuple(range(16))),), + ) + self.assertEqual(self._rollout_info(config=config, targets=targets, train_rank=0).rollout_url, "http://worker-0") + self.assertIsNone(self._rollout_info(config=config, targets=targets, train_rank=1).rollout_url) + self.assertEqual( + self._rollout_info(config=config, targets=targets, train_rank=1).ipc_rank_mesh, + (tuple(range(16)),), + ) + + def test_lmdeploy_ep16_weight_update_targets_match_legacy_mesh_and_url_semantics(self): + config = self._rollout_config(tp=1, ep=16, num_gpus_per_engine=16) + topology = LMDeployWorker.build_rollout_topology( + config, + self._rank_bundle_idx_list(16), + self._rank_to_dist_init_addr(16), + ) + targets = self._weight_update_targets(topology) + + self.assertEqual( + tuple((target.endpoint_rank, target.update_ranks) for target in targets), + tuple((rank, (rank,)) for rank in range(16)), + ) + self.assertEqual(self._rollout_info(config=config, targets=targets, train_rank=0).rollout_url, "http://worker-0") + self.assertEqual( + self._rollout_info(config=config, targets=targets, train_rank=15).rollout_url, + "http://worker-15", + ) + self.assertEqual( + self._rollout_info(config=config, targets=targets, train_rank=0).ipc_rank_mesh, + tuple((rank,) for rank in range(16)), + ) + + def test_sglang_tp16_cross_node_weight_update_targets_match_legacy_mesh_and_url_semantics(self): + config = self._rollout_config(tp=16, ep=1, num_gpus_per_engine=16, gpus_per_node=8) + topology = SGLangWorker.build_rollout_topology( + config, + self._rank_bundle_idx_list(16), + self._rank_to_dist_init_addr(16), + ) + targets = self._weight_update_targets(topology) + + self.assertEqual(tuple(spec.worker_rank for spec in topology.server_launch_specs()), (0, 8)) + self.assertEqual( + tuple((target.endpoint_rank, target.update_ranks) for target in targets), + ((0, tuple(range(16))),), + ) + self.assertEqual(self._rollout_info(config=config, targets=targets, train_rank=0).rollout_url, "http://worker-0") + self.assertIsNone(self._rollout_info(config=config, targets=targets, train_rank=8).rollout_url) + self.assertEqual( + self._rollout_info(config=config, targets=targets, train_rank=8).ipc_rank_mesh, + (tuple(range(16)),), + ) + + class TestRolloutController(unittest.IsolatedAsyncioTestCase): def _state(self, uid: int, session_id: int) -> RolloutState: return RolloutState( @@ -142,6 +335,27 @@ def _build_controller(self, router): controller.logger = MagicMock() return controller + def _build_registry(self, ranks): + rollout_topology = RolloutTopology( + engines=tuple( + RolloutEngine( + engine_ranks=(rank,), + dist_init_addr=f"addr{rank}", + server_processes=( + RolloutServerProcess( + worker_rank=rank, + placement_group_bundle_idxs=(rank,), + weight_update_ranks=(rank,), + ), + ), + ) + for rank in ranks + ), + ) + return RolloutWorkerRegistry( + rollout_topology=rollout_topology, + ) + async def test_generate_fails_fast_when_no_active_worker(self): # router 找不到 active worker 时,controller 应直接把原样本标成 FAILED,避免请求悬挂。 state = self._state(uid=1, session_id=123) @@ -175,20 +389,13 @@ async def test_generate_routes_to_active_worker(self): def test_register_active_workers_to_proxy_delegates_active_session_urls(self): controller = RolloutController.__new__(RolloutController) - controller.registry = RolloutWorkerRegistry(engine_rank_mesh_array=[[0], [1]], rollout_config=SimpleNamespace()) - controller.registry.register_started_server( - rank=0, - actor=object(), - server_url="http://worker-0", - session_url="http://session-0", - is_request_entrypoint=True, - ) - controller.registry.register_started_server( - rank=1, - actor=object(), - server_url="http://worker-1", - session_url="http://session-1", - is_request_entrypoint=True, + controller.registry = self._build_registry((0, 1)) + _register_started_servers( + controller.registry, + ( + (0, object(), "http://worker-0", "http://session-0"), + (1, object(), "http://worker-1", "http://session-1"), + ), ) controller.registry.mark_unhealthy_ranks({1}) controller.proxy_manager = MagicMock() @@ -199,7 +406,7 @@ def test_register_active_workers_to_proxy_delegates_active_session_urls(self): def test_register_active_workers_to_proxy_noops_without_proxy_manager(self): controller = RolloutController.__new__(RolloutController) - controller.registry = RolloutWorkerRegistry(engine_rank_mesh_array=[], rollout_config=SimpleNamespace()) + controller.registry = MagicMock() controller.proxy_manager = None controller.register_active_workers_to_proxy() @@ -360,34 +567,59 @@ class TestRolloutWorkerRegistry(unittest.TestCase): def _worker_by_rank(self, registry, rank): return next(worker for worker in registry.all_workers() if worker.rank == rank) - def test_registry_filters_entrypoints_and_builds_metadata_snapshot(self): - config = SimpleNamespace() - registry = RolloutWorkerRegistry(engine_rank_mesh_array=[[0, 1]], rollout_config=config) - registry.register_started_server( - rank=0, - actor=object(), - server_url="http://worker-0", - session_url="http://session-0", - lifecycle_group_ranks=(0, 1), - is_request_entrypoint=True, - ) - registry.register_started_server( - rank=1, - actor=object(), - server_url="http://worker-1", - session_url=None, - lifecycle_group_ranks=(0, 1), - is_request_entrypoint=False, + def _runtime_layout( + self, + *, + engine_ranks=(0,), + server_processes=None, + ): + if server_processes is None: + server_processes = ( + RolloutServerProcess( + worker_rank=engine_ranks[0], + placement_group_bundle_idxs=tuple(range(len(engine_ranks))), + accepts_rollout_requests=True, + weight_update_ranks=tuple(engine_ranks), + ), + ) + dist_init_addr_owner_rank = server_processes[0].worker_rank + return RolloutTopology( + engines=( + RolloutEngine( + engine_ranks=tuple(engine_ranks), + dist_init_addr=f"addr{dist_init_addr_owner_rank}", + server_processes=tuple(server_processes), + ), + ), ) - metadata = registry.training_metadata_snapshot() + def test_registry_filters_entrypoints_and_tracks_lifecycle(self): + runtime_layout = self._runtime_layout( + engine_ranks=(0, 1), + server_processes=( + RolloutServerProcess( + worker_rank=0, + placement_group_bundle_idxs=(0,), + accepts_rollout_requests=True, + weight_update_ranks=(0, 1), + ), + RolloutServerProcess( + worker_rank=1, + placement_group_bundle_idxs=(1,), + weight_update_ranks=(), + accepts_rollout_requests=False, + ), + ), + ) + registry = RolloutWorkerRegistry(rollout_topology=runtime_layout) + _register_started_servers( + registry, + ( + (0, object(), "http://worker-0", "http://session-0"), + (1, object(), "http://worker-1", None), + ), + ) - self.assertEqual(metadata["engine_rank_mesh_array"], [[0, 1]]) - self.assertIs(metadata["rollout_config"], config) - self.assertEqual(metadata["server_url_dict"], {0: "http://worker-0"}) - self.assertEqual(metadata["worker_server_urls_status"], {"http://worker-0": True}) - self.assertEqual(metadata["worker_session_url_dict"], {0: "http://session-0"}) - self.assertEqual(metadata["worker_session_urls_status"], {"http://session-0": True}) active_entrypoint = registry.active_entrypoints()[0] self.assertIsInstance(active_entrypoint, WorkerSnapshot) self.assertEqual(active_entrypoint.rank, 0) @@ -395,11 +627,8 @@ def test_registry_filters_entrypoints_and_builds_metadata_snapshot(self): active_entrypoint.lifecycle_state = WorkerLifecycleState.INACTIVE unhealthy_groups = registry.mark_unhealthy_ranks({0}) - metadata = registry.training_metadata_snapshot() self.assertEqual(unhealthy_groups[0].ranks, (0, 1)) - self.assertEqual(metadata["worker_server_urls_status"], {"http://worker-0": False}) - self.assertEqual(metadata["worker_session_urls_status"], {"http://session-0": False}) self.assertEqual(tuple(worker.rank for worker in registry.inactive_workers()), (0, 1)) self.assertEqual(registry.active_entrypoints(), ()) claimed_groups = registry.claim_inactive_groups_for_recovery() @@ -408,13 +637,65 @@ def test_registry_filters_entrypoints_and_builds_metadata_snapshot(self): registry.set_group_recovery_result(claimed_groups[0], recovered=False) self.assertEqual(self._worker_by_rank(registry, 0).lifecycle_state, WorkerLifecycleState.INACTIVE) + def test_registry_projects_weight_update_targets_from_topology_and_runtime_state(self): + runtime_layout = self._runtime_layout(engine_ranks=(0, 1)) + registry = RolloutWorkerRegistry(rollout_topology=runtime_layout) + _register_started_servers( + registry, + ((0, object(), "http://worker-0", "http://session-0"),), + ) + + targets = registry.weight_update_targets() + + self.assertEqual(len(targets), 1) + target = targets[0] + self.assertEqual(target.endpoint_rank, 0) + self.assertEqual(target.update_ranks, (0, 1)) + self.assertEqual(target.engine_size, 2) + self.assertEqual(target.server_url, "http://worker-0") + self.assertEqual(target.lifecycle_state, WorkerLifecycleState.ACTIVE.value) + self.assertTrue(target.is_active) + class TestSessionRouter(unittest.IsolatedAsyncioTestCase): async def test_sticky_session_reselects_when_previous_entrypoint_is_inactive(self): actor_0 = object() actor_1 = object() - registry = RolloutWorkerRegistry(engine_rank_mesh_array=[[0], [1]], rollout_config=SimpleNamespace()) - registry.register_started_server(rank=0, actor=actor_0, server_url="http://worker-0") - registry.register_started_server(rank=1, actor=actor_1, server_url="http://worker-1") + rollout_topology = RolloutTopology( + engines=( + RolloutEngine( + engine_ranks=(0,), + dist_init_addr="addr0", + server_processes=( + RolloutServerProcess( + worker_rank=0, + placement_group_bundle_idxs=(0,), + weight_update_ranks=(0,), + ), + ), + ), + RolloutEngine( + engine_ranks=(1,), + dist_init_addr="addr1", + server_processes=( + RolloutServerProcess( + worker_rank=1, + placement_group_bundle_idxs=(1,), + weight_update_ranks=(1,), + ), + ), + ), + ), + ) + registry = RolloutWorkerRegistry( + rollout_topology=rollout_topology, + ) + _register_started_servers( + registry, + ( + (0, actor_0, "http://worker-0", "http://session-0"), + (1, actor_1, "http://worker-1", "http://session-1"), + ), + ) router = SessionRouter(registry, max_idle_seconds=None) self.assertIs(await router.get_worker(7), actor_0) @@ -519,6 +800,90 @@ async def test_pause_generation_sets_abort_flag(self): self.assertTrue(worker.receive_abort_request.is_set()) worker._send_abort_request.assert_awaited_once_with() + def test_init_binds_launch_spec_and_skips_session_server_for_non_entrypoint(self): + topology = RolloutTopology( + engines=( + RolloutEngine( + engine_ranks=(0, 1), + dist_init_addr="host0:25000", + server_processes=( + RolloutServerProcess( + worker_rank=0, + placement_group_bundle_idxs=(0,), + accepts_rollout_requests=True, + weight_update_ranks=(0, 1), + ), + RolloutServerProcess( + worker_rank=1, + placement_group_bundle_idxs=(1,), + accepts_rollout_requests=False, + weight_update_ranks=(), + ), + ), + ), + ), + ) + launch_spec_by_rank = {spec.worker_rank: spec for spec in topology.server_launch_specs()} + worker = RolloutWorker.__new__(RolloutWorker) + worker.rank = 1 + worker.server_launch_spec = None + worker.receive_abort_request = threading.Event() + worker.receive_abort_request.set() + worker.server_url = "http://worker-1" + worker.session_server_url = None + worker._launch_server = MagicMock() + worker.session_server_actor = None + + result = worker.init(launch_spec_by_rank[1]) + + worker._launch_server.assert_called_once_with() + self.assertIsNone(worker.session_server_actor) + self.assertFalse(worker.receive_abort_request.is_set()) + self.assertIs(worker.server_launch_spec, launch_spec_by_rank[1]) + self.assertEqual(result.rank, 1) + self.assertEqual(result.server_url, "http://worker-1") + self.assertIsNone(result.session_url) + + def test_reinit_reuses_bound_launch_spec(self): + topology = RolloutTopology( + engines=( + RolloutEngine( + engine_ranks=(0,), + dist_init_addr="host0:25000", + server_processes=( + RolloutServerProcess( + worker_rank=0, + placement_group_bundle_idxs=(0,), + accepts_rollout_requests=True, + weight_update_ranks=(0,), + ), + ), + ), + ), + ) + launch_spec = topology.server_launch_specs()[0] + worker = RolloutWorker.__new__(RolloutWorker) + worker.rank = 0 + worker.server_launch_spec = launch_spec + worker.receive_abort_request = threading.Event() + worker.server_url = "http://worker-0" + worker.session_server_url = None + worker._launch_server = MagicMock() + + def start_session_server(): + worker.session_server_url = "http://session-0" + + worker._start_session_server = MagicMock(side_effect=start_session_server) + + result = worker.reinit() + + worker._launch_server.assert_called_once_with() + worker._start_session_server.assert_called_once_with() + self.assertIs(worker.server_launch_spec, launch_spec) + self.assertEqual(result.rank, 0) + self.assertEqual(result.server_url, "http://worker-0") + self.assertEqual(result.session_url, "http://session-0") + async def test_safe_post_request_returns_aborted_without_sending_when_abort_flag_is_set(self): # safe post 在发送前发现 abort flag 时,应直接返回 REQUEST_ABORTED,不再发 HTTP 请求。 worker = RolloutWorker.__new__(RolloutWorker) @@ -667,20 +1032,41 @@ def _worker_by_rank(self, registry, rank): return next(worker for worker in registry.all_workers() if worker.rank == rank) def _build_registry(self, workers_info): + engines = [] + for rank in sorted(workers_info): + engines.append( + RolloutEngine( + engine_ranks=(rank,), + dist_init_addr=f"addr{rank}", + server_processes=( + RolloutServerProcess( + worker_rank=rank, + placement_group_bundle_idxs=(rank,), + accepts_rollout_requests=True, + weight_update_ranks=(rank,), + ), + ), + ) + ) + rollout_topology = RolloutTopology( + engines=tuple(engines), + ) registry = RolloutWorkerRegistry( - engine_rank_mesh_array=[sorted(workers_info)], - rollout_config=SimpleNamespace(), + rollout_topology=rollout_topology, + ) + _register_started_servers( + registry, + ( + ( + rank, + worker_info.actor, + worker_info.url, + worker_info.session_url or f"http://session-{rank}", + ) + for rank, worker_info in workers_info.items() + ), ) for rank, worker_info in workers_info.items(): - lifecycle_group_ranks = worker_info.lifecycle_group_ranks or (rank,) - registry.register_started_server( - rank=rank, - actor=worker_info.actor, - server_url=worker_info.url, - session_url=worker_info.session_url, - lifecycle_group_ranks=lifecycle_group_ranks, - is_request_entrypoint=worker_info.is_request_entrypoint, - ) if worker_info.lifecycle_state is WorkerLifecycleState.INACTIVE: registry.mark_unhealthy_ranks({rank}) return registry @@ -710,7 +1096,7 @@ def _build_manager( def test_marks_worker_inactive_after_consecutive_health_failures(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(False)) - worker_info = WorkerSnapshot(actor=actor, url="http://worker-0") + worker_info = WorkerSnapshot(rank=0, actor=actor, url="http://worker-0") workers_info = {0: worker_info} inactive_groups = [] listener = SimpleNamespace( @@ -738,7 +1124,7 @@ def test_marks_worker_inactive_after_consecutive_health_failures(self): def test_inactive_listener_runs_under_operation_lock(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(False)) - worker_info = WorkerSnapshot(actor=actor, url="http://worker-0") + worker_info = WorkerSnapshot(rank=0, actor=actor, url="http://worker-0") lock_acquired_by_listener = [] manager, _ = self._build_manager({0: worker_info}, failure_threshold=1) @@ -764,6 +1150,7 @@ def test_inactive_worker_is_not_cleaned_up_again(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) workers_info = { 0: WorkerSnapshot( + rank=0, actor=actor, url="http://worker-0", lifecycle_state=WorkerLifecycleState.INACTIVE, @@ -779,7 +1166,7 @@ def test_inactive_worker_is_not_cleaned_up_again(self): def test_health_check_threshold_zero_disables_periodic_health_check(self): # threshold <= 0 表示关闭周期健康监测,不应把 active worker 直接判 inactive。 actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(False)) - worker_info = WorkerSnapshot(actor=actor, url="http://worker-0") + worker_info = WorkerSnapshot(rank=0, actor=actor, url="http://worker-0") manager, registry = self._build_manager({0: worker_info}, failure_threshold=0) checked_count = manager._check_and_deactivate_failed_worker_groups() @@ -790,7 +1177,7 @@ def test_health_check_threshold_zero_disables_periodic_health_check(self): def test_fail_fast_health_check_still_runs_when_periodic_health_check_is_disabled(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(False)) - worker_info = WorkerSnapshot(actor=actor, url="http://worker-0") + worker_info = WorkerSnapshot(rank=0, actor=actor, url="http://worker-0") manager, registry = self._build_manager({0: worker_info}, failure_threshold=0) checked_count = manager._check_and_deactivate_failed_worker_groups(fail_fast=True) @@ -801,7 +1188,7 @@ def test_fail_fast_health_check_still_runs_when_periodic_health_check_is_disable def test_health_check_uses_configured_timeout(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) - worker_info = WorkerSnapshot(actor=actor, url="http://worker-0") + worker_info = WorkerSnapshot(rank=0, actor=actor, url="http://worker-0") manager, _ = self._build_manager({0: worker_info}, check_timeout=2.5) observed_timeouts = [] @@ -817,6 +1204,7 @@ async def fake_wait_for(awaitable, timeout): def test_shutdown_barrier_keeps_failed_shutdown_group_inactive(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) worker_info = WorkerSnapshot( + rank=0, actor=actor, url="http://worker-0", lifecycle_state=WorkerLifecycleState.INACTIVE, @@ -838,6 +1226,7 @@ def test_shutdown_barrier_keeps_failed_shutdown_group_inactive(self): def test_restart_barrier_keeps_failed_recovery_group_inactive(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) worker_info = WorkerSnapshot( + rank=0, actor=actor, url="http://worker-0", lifecycle_state=WorkerLifecycleState.INACTIVE, @@ -859,6 +1248,7 @@ def test_restart_barrier_keeps_failed_recovery_group_inactive(self): def test_restart_barrier_notifies_recovered_group_after_success(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) worker_info = WorkerSnapshot( + rank=0, actor=actor, url="http://worker-0", session_url="http://session-0", @@ -881,9 +1271,52 @@ def test_restart_barrier_notifies_recovered_group_after_success(self): self.assertEqual([group.ranks for group in recovered_groups], [(0,)]) self.assertTrue(all(worker.is_active() for worker in recovered_groups[0].workers)) + def test_restart_worker_group_uses_reinit(self): + init_result = RolloutWorkerInitResult( + rank=0, + server_url="http://worker-0", + session_url="http://session-0", + ) + actor = SimpleNamespace( + set_skip_load_weights=_FakeAsyncRemoteMethod(None), + init=_FakeAsyncRemoteMethod(init_result), + reinit=_FakeAsyncRemoteMethod(init_result), + check_health=_FakeAsyncRemoteMethod(True), + offload=_FakeAsyncRemoteMethod(None), + restore_skip_load_weights=_FakeAsyncRemoteMethod(None), + ) + worker_info = WorkerSnapshot( + rank=0, + actor=actor, + url="http://worker-0", + session_url="http://session-0", + lifecycle_state=WorkerLifecycleState.INACTIVE, + ) + manager, registry = self._build_manager({0: worker_info}) + group = registry.claim_inactive_groups_for_recovery()[0] + + def fake_ray_get(refs, timeout=None): + del timeout + return [asyncio.run(ref) for ref in refs] + + with ( + patch.object(manager, "_shutdown_worker_group", return_value=True), + patch("xtuner.v1.rl.rollout.health_manager.ray.get", side_effect=fake_ray_get), + ): + result = manager._restart_worker_group(group) + + self.assertTrue(result) + self.assertEqual(actor.set_skip_load_weights.calls, [(True,)]) + self.assertEqual(actor.reinit.calls, [()]) + self.assertEqual(actor.init.calls, []) + self.assertEqual(actor.check_health.calls, [()]) + self.assertEqual(actor.offload.calls, [()]) + self.assertEqual(actor.restore_skip_load_weights.calls, [()]) + def test_recovered_listener_runs_under_operation_lock(self): actor = SimpleNamespace(check_health=_FakeAsyncRemoteMethod(True)) worker_info = WorkerSnapshot( + rank=0, actor=actor, url="http://worker-0", lifecycle_state=WorkerLifecycleState.INACTIVE, diff --git a/tests/rl/test_update_weight_disaggregated.py b/tests/rl/test_update_weight_disaggregated.py index b53c9121d..7eea95977 100644 --- a/tests/rl/test_update_weight_disaggregated.py +++ b/tests/rl/test_update_weight_disaggregated.py @@ -118,12 +118,8 @@ def init_config(self): ) def _check_sglang_weights(self, rollout_controller, action): - info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) - active_urls = [ - url - for url, is_active in info_dict["worker_server_urls_status"].items() - if is_active - ] + targets = ray.get(rollout_controller.get_weight_update_targets.remote()) + active_urls = [target.server_url for target in targets if target.is_active] self.assertGreater(len(active_urls), 0) results = [] for url in active_urls: @@ -159,8 +155,12 @@ def test_sglang_disaggregated_update_weight_and_generate(self): input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) - info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) - train_controller.update_rollout_info(info_dict, train_rollout_mode="disaggregated") + targets = ray.get(rollout_controller.get_weight_update_targets.remote()) + train_controller.bind_rollout_weight_update( + targets=targets, + rollout_config=self.rollout_cfg, + weight_transport_type="nccl", + ) train_controller.update_weights() res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) @@ -194,8 +194,12 @@ def test_sglang_disaggregated_update_weight_equal_after_reset(self): self._check_sglang_weights(rollout_controller, action="snapshot_parameters") self._check_sglang_weights(rollout_controller, action="reset_parameters") - info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) - train_controller.update_rollout_info(info_dict, train_rollout_mode="disaggregated") + targets = ray.get(rollout_controller.get_weight_update_targets.remote()) + train_controller.bind_rollout_weight_update( + targets=targets, + rollout_config=self.rollout_cfg, + weight_transport_type="nccl", + ) train_controller.update_weights() self._check_sglang_weights(rollout_controller, action="compare_parameters") @@ -229,8 +233,12 @@ def test_lmdeploy_disaggregated_update_weight_and_generate(self): input_state = RolloutState(message=TEST_TEXT_MESSAGES, sample_params=sample_params) res_baseline = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) - info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) - train_controller.update_rollout_info(info_dict, train_rollout_mode="disaggregated") + targets = ray.get(rollout_controller.get_weight_update_targets.remote()) + train_controller.bind_rollout_weight_update( + targets=targets, + rollout_config=self.rollout_cfg, + weight_transport_type="nccl", + ) train_controller.update_weights() res_update_weight = ray.get(rollout_controller.generate.remote(rollout_state=input_state)) diff --git a/xtuner/v1/rl/rollout/controller.py b/xtuner/v1/rl/rollout/controller.py index ae4572334..b4309b285 100644 --- a/xtuner/v1/rl/rollout/controller.py +++ b/xtuner/v1/rl/rollout/controller.py @@ -1,5 +1,5 @@ import asyncio -from typing import TypeAlias +from typing import Any, TypeAlias from uuid import uuid4 import ray @@ -8,18 +8,20 @@ from xtuner.v1.data_proto.rl_data import RolloutState, Status from xtuner.v1.rl.utils import AutoAcceleratorWorkers +from xtuner.v1.rl.weight_update.data import RolloutWeightUpdateTarget from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger from .constants import ROLLOUT_RAY_GENERATE_MAX_CONCURRENCY from .health_manager import ROLLOUT_RAY_GET_TIMEOUT, RolloutHealthManager from .proxy_manager import RolloutProxyManager +from .rollout_topology import RolloutTopology from .utils import SessionRouter from .worker import ( ROLLOUT_CONCURRENCY_GROUP_GENERATE, RolloutConfig, get_rollout_worker_base_cls, ) -from .worker_registry import RolloutWorkerMetadata, RolloutWorkerRegistry +from .worker_registry import RolloutWorkerRegistry # Keep this as a Ray actor because Ray AgentLoop actors need a shared, cross-process handle to the same controller @@ -63,17 +65,9 @@ def __init__( ) self.health_manager.start() - def get_rollout_metadata(self) -> RolloutWorkerMetadata: - """Get information about the current rollout setup. - - Returns: - dict: A dictionary containing the engine mesh list, server URL - dictionary, and the rollout configuration. - """ - rollout_metadata = self.registry.training_metadata_snapshot() - self.logger.info(f"Rollout worker server URLs: {rollout_metadata['server_url_dict']}") - self.logger.info(f"Rollout worker session server URLs: {rollout_metadata['worker_session_url_dict']}") - return rollout_metadata + def get_weight_update_targets(self) -> tuple[RolloutWeightUpdateTarget, ...]: + """Return rollout endpoints that can receive weight update requests.""" + return self.registry.weight_update_targets() def register_active_workers_to_proxy(self) -> None: if self.proxy_manager is None: @@ -203,98 +197,87 @@ def _build_remote_worker_cls(self, worker_base_cls): }, )(worker_base_cls) - def _init_workers(self, placement_group: PlacementGroup) -> RolloutWorkerRegistry: - """Initializes and configures the pool of RolloutWorker actors. - - This method follows the same high-level flow as the legacy implementation: - create workers, initialize worker-local ports, build engine groups, - select workers that launch rollout servers, launch servers, and - expose request-entrypoint server URLs to rollout traffic. + def _create_worker_actors( + self, + placement_group: PlacementGroup, + ) -> tuple[tuple[Any, ...], tuple[tuple[int, int], ...]]: + """Create rollout worker actors. - Returns: - A registry containing all server-process workers and the public - training metadata mesh. + Returns workers_by_rank, which is indexed by rollout worker rank, and rank_bundle_indices, which maps worker + ranks to placement-group bundles. """ worker_base_cls = get_rollout_worker_base_cls(self.config) worker_cls = self._build_remote_worker_cls(worker_base_cls) - - # Create workers from placement group. - workers, rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group( + workers, rank_bundle_indices = AutoAcceleratorWorkers.from_placement_group( worker_cls, self.config, placement_group ) - rank_to_actor = {rank: worker for (rank, _), worker in zip(rank_bundle_idx_list, workers)} - - # Reserve worker-local ports for all actors first. build_engine_launch_specs - # uses the returned addresses to bind each ServerProcessSpec to its - # logical engine rendezvous address; only server-process owners call init(). - rank_to_dist_init_addr = { - rank: dist_init_addr - for (rank, _), dist_init_addr in zip( - rank_bundle_idx_list, - ray.get([worker.init_dist_port.remote() for worker in workers]), # type: ignore[attr-defined] - ) - } + workers_by_rank = tuple(workers) + return workers_by_rank, tuple(rank_bundle_indices) + + def _initialize_worker_ports_and_build_rollout_topology( + self, + workers_by_rank: tuple[Any, ...], + rank_bundle_indices: tuple[tuple[int, int], ...], + ) -> RolloutTopology: + """Initialize worker-local dist ports and build rollout topology. - # Build engine groups and server-process specs from the rank/bundle mapping. - engine_launch_specs = worker_base_cls.build_engine_launch_specs( + This performs the Ray init_dist_port handshake before building the topology, so the returned layout is bound to + runtime worker addresses. + """ + dist_init_results = ray.get( + [ + worker.init_dist_port.remote() # type: ignore[attr-defined] + for worker in workers_by_rank + ] + ) + worker_base_cls = get_rollout_worker_base_cls(self.config) + return worker_base_cls.build_rollout_topology( self.config, - rank_bundle_idx_list, - rank_to_dist_init_addr, + list(rank_bundle_indices), + dict(dist_init_results), ) - # Keep the public metadata mesh compatible with origin/main. Backends - # may expose a different update-weight mesh than their internal launch - # topology, e.g. LMDeploy EP has one logical engine but one public entry - # per request-serving EP rank. - engine_rank_mesh_array = worker_base_cls.build_metadata_engine_rank_mesh_array(engine_launch_specs) - - # Launch every server process described by the backend-specific specs. - server_rank_to_url = dict( - ray.get( - [ - rank_to_actor[server_process.worker_rank].init.remote( # type: ignore[attr-defined] - engine_launch_spec=engine_spec, - ) - for engine_spec in engine_launch_specs - for server_process in engine_spec.server_processes - ] - ) + + def _init_workers( + self, + placement_group: PlacementGroup, + ) -> RolloutWorkerRegistry: + """Initializes and configures the pool of RolloutWorker actors. + + This method follows the same high-level flow as the legacy implementation: + create workers, initialize worker-local ports, build the bound rollout + topology, launch rollout servers, and expose request-entrypoint server + URLs to rollout traffic. + + Returns: + A registry containing all server-process workers and runtime state. + """ + workers_by_rank, rank_bundle_indices = self._create_worker_actors(placement_group) + rollout_topology = self._initialize_worker_ports_and_build_rollout_topology( + workers_by_rank, + rank_bundle_indices, ) - session_url_by_rank = dict( + init_results = tuple( ray.get( [ - ( - rank_to_actor[server_process.worker_rank].get_session_server_info.remote() # type: ignore[attr-defined] - ) - for engine_spec in engine_launch_specs - for server_process in engine_spec.server_processes + workers_by_rank[launch_spec.worker_rank].init.remote(launch_spec) # type: ignore[attr-defined] + for launch_spec in rollout_topology.server_launch_specs() ] ) ) - registry = RolloutWorkerRegistry( - engine_rank_mesh_array=engine_rank_mesh_array, - rollout_config=self.config, + registry = RolloutWorkerRegistry(rollout_topology=rollout_topology) + registry.register_started_servers( + init_results=init_results, + workers_by_rank=workers_by_rank, + ) + + self.logger.info( + "Rollout worker registry snapshot: " + f"weight_update_targets={registry.weight_update_targets()}, " + f"active_entrypoints={registry.active_entrypoints()}, " + f"server_process_urls={[worker.url for worker in registry.all_workers()]}, " + f"lifecycle_groups={registry.lifecycle_groups()}" ) - for engine_spec in engine_launch_specs: - for server_process in engine_spec.server_processes: - rank = server_process.worker_rank - url = server_rank_to_url[rank] - session_url = session_url_by_rank.get(rank) - if server_process.accepts_rollout_requests and session_url is None: - raise RuntimeError(f"Rollout worker rank={rank} did not return session server URL during init.") - registry.register_started_server( - rank=rank, - actor=rank_to_actor[rank], - server_url=url, - session_url=session_url, - lifecycle_group_ranks=engine_spec.server_worker_ranks, - is_request_entrypoint=server_process.accepts_rollout_requests, - ) - - server_process_workers_info = registry.all_workers() - self.logger.info(f"Rollout server-process worker URLs: {[info.url for info in server_process_workers_info]}") - lifecycle_groups = sorted({info.lifecycle_group_ranks for info in server_process_workers_info}) - self.logger.info(f"Rollout worker lifecycle groups: {lifecycle_groups}") return registry diff --git a/xtuner/v1/rl/rollout/health_manager.py b/xtuner/v1/rl/rollout/health_manager.py index 846cd37bd..66d4009e5 100644 --- a/xtuner/v1/rl/rollout/health_manager.py +++ b/xtuner/v1/rl/rollout/health_manager.py @@ -486,9 +486,9 @@ def _restart_worker_group( ) init_results = ray.get( [ - # init() reuses the immutable launch spec cached on each actor - # during controller startup, including placement bundles and dist addr. - worker.actor.init.remote() # type: ignore[attr-defined] + # reinit() reuses the server launch spec bound during + # controller startup. + worker.actor.reinit.remote() # type: ignore[attr-defined] for worker in group.workers ], timeout=ROLLOUT_RAY_GET_TIMEOUT, @@ -505,11 +505,11 @@ def _restart_worker_group( return False for worker, init_result in zip(group.workers, init_results): - init_rank, init_url = init_result - if init_rank != worker.rank or init_url != worker.url: + if init_result.rank != worker.rank or init_result.server_url != worker.url: logger.error( f"Rollout worker restart returned unexpected endpoint: rank={worker.rank}, " - f"init_rank={init_rank}, expected_url={worker.url}, init_url={init_url}." + f"init_rank={init_result.rank}, expected_url={worker.url}, " + f"init_url={init_result.server_url}." ) self._shutdown_worker_group(group, wait_server_down=False, best_effort=True) return False diff --git a/xtuner/v1/rl/rollout/lmdeploy.py b/xtuner/v1/rl/rollout/lmdeploy.py index 7c4506238..94d8f1ddd 100644 --- a/xtuner/v1/rl/rollout/lmdeploy.py +++ b/xtuner/v1/rl/rollout/lmdeploy.py @@ -1,6 +1,6 @@ import os from argparse import Namespace -from typing import Any, Dict, List +from typing import Any, Dict, List, Mapping import numpy as np import ray @@ -10,7 +10,8 @@ from transformers import AutoTokenizer from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams -from .worker import EngineLaunchSpec, EngineLaunchSpecs, RolloutConfig, RolloutWorker, ServerProcessSpec +from .rollout_topology import RolloutEngine, RolloutServerProcess, RolloutTopology +from .worker import RolloutConfig, RolloutWorker SHARED_STORE = "shared_store" @@ -80,118 +81,133 @@ def __init__( self.lmdeploy_actor = None @classmethod - def build_engine_launch_specs( + def build_rollout_topology( cls, config: RolloutConfig, rank_bundle_idx_list: list[tuple[int, int]], - rank_to_dist_init_addr: dict[int, str] | None = None, - ) -> EngineLaunchSpecs: - """Build LMDeploy server launch layout. - - LMDeploy EP starts one request-serving server per EP rank. - - Example with expert_parallel_size=2: - rank_bundle_idx_list is [(0, 0), (1, 1), (2, 2), (3, 3)]. - rank identifies the rollout worker; bundle idx identifies the Ray - placement-group bundle that owns the GPU resource. - - If rank_to_dist_init_addr is: - {0: "addr0", 1: "addr1", 2: "addr2", 3: "addr3"} - - The launch specs are: - EngineLaunchSpec( - engine_ranks=(0, 1), - server_processes=( - ServerProcessSpec( - worker_rank=0, - placement_group_bundle_idxs=(0,), - dist_init_addr="addr0", - ), - ServerProcessSpec( - worker_rank=1, - placement_group_bundle_idxs=(1,), - dist_init_addr="addr0", - ), - ), - ) - EngineLaunchSpec( - engine_ranks=(2, 3), - server_processes=( - ServerProcessSpec( - worker_rank=2, - placement_group_bundle_idxs=(2,), - dist_init_addr="addr2", - ), - ServerProcessSpec( - worker_rank=3, - placement_group_bundle_idxs=(3,), - dist_init_addr="addr2", - ), - ), - ) - - Each EP rank launches a server process, so server_worker_ranks is the - same as engine_ranks, and every server accepts rollout requests. + rank_to_dist_init_addr: Mapping[int, str], + ) -> RolloutTopology: + """Build LMDeploy rollout topology with bound engine dist-init + addresses. + + ``rank_bundle_idx_list`` stores ``(worker_rank, bundle_idx)`` pairs. + + Example with ranks [(0, 0), (1, 1), (2, 2), (3, 3)] and addrs + {0: "addr0", 1: "addr1", 2: "addr2", 3: "addr3"}: + + +------+------------------------------------------------------------------+ + | Mode | RolloutEngine topology | + +------+------------------------------------------------------------------+ + | TP | RolloutEngine( | + | | engine_ranks=(0, 1), | + | | dist_init_addr="addr0", | + | | server_processes=( | + | | RolloutServerProcess( | + | | worker_rank=0, | + | | placement_group_bundle_idxs=(0, 1), | + | | weight_update_ranks=(0, 1), | + | | ), | + | | ), | + | | ) | + | | RolloutEngine( | + | | engine_ranks=(2, 3), | + | | dist_init_addr="addr2", | + | | server_processes=( | + | | RolloutServerProcess( | + | | worker_rank=2, | + | | placement_group_bundle_idxs=(2, 3), | + | | weight_update_ranks=(2, 3), | + | | ), | + | | ), | + | | ) | + +------+------------------------------------------------------------------+ + | EP | RolloutEngine( | + | | engine_ranks=(0, 1), | + | | dist_init_addr="addr0", | + | | server_processes=( | + | | RolloutServerProcess( | + | | worker_rank=0, | + | | placement_group_bundle_idxs=(0,), | + | | weight_update_ranks=(0,), | + | | ), | + | | RolloutServerProcess( | + | | worker_rank=1, | + | | placement_group_bundle_idxs=(1,), | + | | weight_update_ranks=(1,), | + | | ), | + | | ), | + | | ) | + | | RolloutEngine( | + | | engine_ranks=(2, 3), | + | | dist_init_addr="addr2", | + | | server_processes=( | + | | RolloutServerProcess( | + | | worker_rank=2, | + | | placement_group_bundle_idxs=(2,), | + | | weight_update_ranks=(2,), | + | | ), | + | | RolloutServerProcess( | + | | worker_rank=3, | + | | placement_group_bundle_idxs=(3,), | + | | weight_update_ranks=(3,), | + | | ), | + | | ), | + | | ) | + +------+------------------------------------------------------------------+ """ - if config.expert_parallel_size <= 1: - return RolloutWorker.build_engine_launch_specs( - config, - rank_bundle_idx_list, - rank_to_dist_init_addr, - ) - - ep_size = config.expert_parallel_size + engines: list[RolloutEngine] = [] num_workers = len(rank_bundle_idx_list) - if num_workers % ep_size != 0: - raise ValueError(f"num_rollout_workers={num_workers} must be divisible by expert_parallel_size={ep_size}.") - - engine_launch_specs: list[EngineLaunchSpec] = [] - for engine_start in range(0, num_workers, ep_size): - engine_meta = rank_bundle_idx_list[engine_start : engine_start + ep_size] - engine_ranks = tuple(rank for rank, _ in engine_meta) - engine_dist_init_addr = None if rank_to_dist_init_addr is None else rank_to_dist_init_addr[engine_ranks[0]] - # LMDeploy EP launches one server process for each EP rank. Each - # server owns exactly one placement-group bundle, and every server - # can be used as a rollout request entrypoint. - engine_launch_specs.append( - EngineLaunchSpec( - engine_ranks=engine_ranks, - server_processes=tuple( - ServerProcessSpec( - worker_rank=server_rank, - placement_group_bundle_idxs=(bundle_idx,), - dist_init_addr=engine_dist_init_addr, - ) - for server_rank, bundle_idx in engine_meta - ), + if config.expert_parallel_size <= 1: + num_gpus_per_engine = config.num_gpus_per_engine + if num_workers % num_gpus_per_engine != 0: + raise ValueError( + f"num_rollout_workers={num_workers} must be divisible by " + f"num_gpus_per_engine={num_gpus_per_engine}." ) - ) - return cls.validate_engine_launch_specs( - tuple(engine_launch_specs), - known_worker_ranks=tuple(rank for rank, _ in rank_bundle_idx_list), - ) - - @classmethod - def build_metadata_engine_rank_mesh_array( - cls, - engine_launch_specs: EngineLaunchSpecs, - ) -> list[list[int]]: - """Keep LMDeploy EP metadata compatible with origin/main. - - Pure EP uses one request-serving server per EP rank. The logical engine topology is still stored in - EngineLaunchSpec.engine_ranks for dp_rank and lifecycle operations, but update_weighter expects the public - metadata mesh to contain one single-rank entry per request server. - """ - metadata_engine_rank_mesh_array: list[list[int]] = [] - for engine_spec in engine_launch_specs: - request_entrypoint_servers = engine_spec.request_entrypoint_servers - if len(request_entrypoint_servers) > 1: - metadata_engine_rank_mesh_array.extend( - [server_process.worker_rank] for server_process in request_entrypoint_servers + for engine_start in range(0, num_workers, num_gpus_per_engine): + engine_meta = rank_bundle_idx_list[engine_start : engine_start + num_gpus_per_engine] + engine_ranks = tuple(rank for rank, _ in engine_meta) + engine_bundle_idxs = tuple(bundle_idx for _, bundle_idx in engine_meta) + dist_init_addr_owner_rank = engine_ranks[0] + engines.append( + RolloutEngine( + engine_ranks=engine_ranks, + dist_init_addr=rank_to_dist_init_addr[dist_init_addr_owner_rank], + server_processes=( + RolloutServerProcess( + worker_rank=engine_ranks[0], + placement_group_bundle_idxs=engine_bundle_idxs, + weight_update_ranks=engine_ranks, + ), + ), + ) + ) + else: + ep_size = config.expert_parallel_size + if num_workers % ep_size != 0: + raise ValueError( + f"num_rollout_workers={num_workers} must be divisible by expert_parallel_size={ep_size}." + ) + for engine_start in range(0, num_workers, ep_size): + engine_meta = rank_bundle_idx_list[engine_start : engine_start + ep_size] + engine_ranks = tuple(rank for rank, _ in engine_meta) + dist_init_addr_owner_rank = engine_ranks[0] + engines.append( + RolloutEngine( + engine_ranks=engine_ranks, + dist_init_addr=rank_to_dist_init_addr[dist_init_addr_owner_rank], + server_processes=tuple( + RolloutServerProcess( + worker_rank=server_rank, + placement_group_bundle_idxs=(bundle_idx,), + weight_update_ranks=(server_rank,), + ) + for server_rank, bundle_idx in engine_meta + ), + ) ) - else: - metadata_engine_rank_mesh_array.append(list(engine_spec.engine_ranks)) - return metadata_engine_rank_mesh_array + + return RolloutTopology(engines=tuple(engines)) def offload(self): """Offloads the model weights and KV cache.""" @@ -342,7 +358,7 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: "NPU": "ascend", } - extra_config = self.config.extra_rollout_config or dict() + extra_config = self.config.extra_rollout_config lmdeploy_config_kwargs = { k.replace("lmdeploy_", ""): v for k, v in extra_config.items() if k.startswith("lmdeploy_") } @@ -383,14 +399,13 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: if backend == "pytorch" and self.config.max_prefill_token_num: extra_engine_config["max_prefill_token_num"] = self.config.max_prefill_token_num + assert self.server_launch_spec is not None dp_rank = 0 if backend == "pytorch": # currently only support ep > 1 and tp == 1 / ep == 1 and tp > 1 assert ep_size == 1 or tp_size == 1 if ep_size > 1: - engine_launch_spec = self.engine_launch_spec - assert engine_launch_spec is not None - dp_rank = engine_launch_spec.engine_ranks.index(self.rank) + dp_rank = self.server_launch_spec.engine_rank backend_config = ( PytorchEngineConfig( @@ -413,7 +428,10 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: else TurbomindEngineConfig( tp=tp_size, max_batch_size=self.config.rollout_max_batch_size_per_instance, - devices=[bundle_idxs % self.config.gpus_per_node for bundle_idxs in self.engine_bundle_idxs], + devices=[ + bundle_idx % self.config.gpus_per_node + for bundle_idx in self.server_launch_spec.placement_group_bundle_idxs + ], empty_init=self.config.skip_load_weights, session_len=self.config.context_length, model_format="fp8" if self.config.enable_float8 else None, @@ -431,7 +449,9 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: env = { "LMDEPLOY_RAY_EXTERNAL_NS": ray_runtime_ctx.namespace, "LMDEPLOY_RAY_EXTERNAL_PG_NAME": current_pg_name, - "LMDEPLOY_RAY_EXTERNAL_PG_BUNDLES": ",".join(map(str, self.engine_bundle_idxs)), + "LMDEPLOY_RAY_EXTERNAL_PG_BUNDLES": ",".join( + map(str, self.server_launch_spec.placement_group_bundle_idxs) + ), } if self.accelerator == "NPU": @@ -444,7 +464,7 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: ) if tp_size > 1: - dist_addr, dist_port = self.dist_init_addr.split(":")[:2] + dist_addr, dist_port = self.server_launch_spec.dist_init_addr.split(":")[:2] env.update( { "LMDEPLOY_DIST_MASTER_ADDR": dist_addr, @@ -452,7 +472,7 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: } ) elif ep_size > 1: - dist_addr, dist_port = self.dist_init_addr.split(":")[:2] + dist_addr, dist_port = self.server_launch_spec.dist_init_addr.split(":")[:2] if speculative_num_draft_tokens is not None: deepep_max_tokens_per_rank = max_batch_size * (1 + speculative_num_draft_tokens) else: diff --git a/xtuner/v1/rl/rollout/rollout_topology.py b/xtuner/v1/rl/rollout/rollout_topology.py new file mode 100644 index 000000000..0ff672df6 --- /dev/null +++ b/xtuner/v1/rl/rollout/rollout_topology.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + + +__all__ = [ + "RolloutEngine", + "RolloutServerProcess", + "RolloutTopology", + "ServerLaunchSpec", +] + + +@dataclass(frozen=True) +class RolloutServerProcess: + """Static topology for one worker-owned rollout server process.""" + + # Worker rank that owns and starts this server process. + worker_rank: int + # Placement-group bundles assigned to this server process. + placement_group_bundle_idxs: tuple[int, ...] + # Rollout ranks updated through this server process. + weight_update_ranks: tuple[int, ...] + # Whether this server process can receive rollout generation requests. + accepts_rollout_requests: bool = True + # Node index used by backends that launch one server process per node. + node_rank: int = 0 + # Number of nodes participating in this server launch. + nnodes: int = 1 + + @property + def is_weight_update_endpoint(self) -> bool: + return bool(self.weight_update_ranks) + + +@dataclass(frozen=True) +class RolloutEngine: + """Static topology for one logical inference engine.""" + + # Rollout ranks that jointly form this logical inference engine. + engine_ranks: tuple[int, ...] + # Rendezvous address shared by every server process in this engine. + dist_init_addr: str + # Server processes that expose this engine to rollout traffic or control paths. + server_processes: tuple[RolloutServerProcess, ...] + + +@dataclass(frozen=True) +class ServerLaunchSpec: + """Worker-facing launch data projected from rollout topology.""" + + # Worker rank that should receive this launch spec. + worker_rank: int + # Placement-group bundles assigned to the launched server process. + placement_group_bundle_idxs: tuple[int, ...] + # Engine rendezvous address resolved by RolloutTopology. + dist_init_addr: str + # Rank of this worker inside the logical inference engine. + engine_rank: int + # Whether this server process can receive rollout generation requests. + accepts_rollout_requests: bool = True + # Node index for multi-node backend launches. + node_rank: int = 0 + # Number of nodes for multi-node backend launches. + nnodes: int = 1 + + +@dataclass(frozen=True) +class RolloutTopology: + """Immutable rollout topology after dist-init addresses are resolved. + + Actor handles, server URLs, session URLs, and lifecycle state belong to RolloutWorkerRegistry. + """ + + # Logical inference engines and their server-process topology. + engines: tuple[RolloutEngine, ...] + # Server-process lookup keyed by worker rank. + _server_process_by_rank: dict[int, RolloutServerProcess] = field(init=False, repr=False, compare=False) + # Lifecycle group lookup keyed by server-process worker rank. + _lifecycle_group_by_rank: dict[int, tuple[int, ...]] = field(init=False, repr=False, compare=False) + + def __post_init__(self) -> None: + if not self.engines: + raise ValueError("RolloutTopology must define at least one engine.") + + seen_engine_ranks: set[int] = set() + seen_bundle_idxs: set[int] = set() + server_process_by_rank: dict[int, RolloutServerProcess] = {} + lifecycle_group_by_rank: dict[int, tuple[int, ...]] = {} + for engine_index, engine in enumerate(self.engines): + if not engine.engine_ranks: + raise ValueError(f"RolloutTopology engine[{engine_index}] must define at least one engine rank.") + if len(set(engine.engine_ranks)) != len(engine.engine_ranks): + raise ValueError( + f"RolloutTopology engine[{engine_index}] has duplicate engine ranks: {engine.engine_ranks}." + ) + duplicate_engine_ranks = sorted(set(engine.engine_ranks).intersection(seen_engine_ranks)) + if duplicate_engine_ranks: + raise ValueError( + f"RolloutTopology engine[{engine_index}] engine ranks appear in more than one engine: " + f"{duplicate_engine_ranks}." + ) + seen_engine_ranks.update(engine.engine_ranks) + + if not engine.server_processes: + raise ValueError(f"RolloutTopology engine[{engine_index}] must define at least one server process.") + if not any(server.accepts_rollout_requests for server in engine.server_processes): + raise ValueError( + f"RolloutTopology engine[{engine_index}] must expose at least one request entrypoint." + ) + + engine_rank_set = set(engine.engine_ranks) + covered_update_ranks: set[int] = set() + for server in engine.server_processes: + duplicate_bundle_idxs = sorted(set(server.placement_group_bundle_idxs).intersection(seen_bundle_idxs)) + if duplicate_bundle_idxs: + raise ValueError( + f"RolloutTopology engine[{engine_index}] server worker_rank={server.worker_rank} " + f"reuses placement-group bundle indexes: {duplicate_bundle_idxs}." + ) + seen_bundle_idxs.update(server.placement_group_bundle_idxs) + if server.worker_rank not in engine_rank_set: + raise ValueError( + f"RolloutTopology engine[{engine_index}] server worker_rank={server.worker_rank} " + f"is not in engine_ranks={engine.engine_ranks}." + ) + unknown_update_ranks = sorted( + rank for rank in server.weight_update_ranks if rank not in engine_rank_set + ) + if unknown_update_ranks: + raise ValueError( + f"RolloutTopology engine[{engine_index}] server worker_rank={server.worker_rank} " + f"references unknown weight_update_ranks={unknown_update_ranks}." + ) + duplicate_update_ranks = sorted(set(server.weight_update_ranks).intersection(covered_update_ranks)) + if duplicate_update_ranks: + raise ValueError( + f"RolloutTopology engine[{engine_index}] has duplicate weight_update_ranks=" + f"{duplicate_update_ranks}." + ) + covered_update_ranks.update(server.weight_update_ranks) + + if covered_update_ranks != engine_rank_set: + missing_update_ranks = sorted(engine_rank_set.difference(covered_update_ranks)) + raise ValueError( + f"RolloutTopology engine[{engine_index}] weight_update_ranks do not cover engine ranks: " + f"{missing_update_ranks}." + ) + + lifecycle_group = tuple(server.worker_rank for server in engine.server_processes) + for server in engine.server_processes: + server_process_by_rank[server.worker_rank] = server + lifecycle_group_by_rank[server.worker_rank] = lifecycle_group + + object.__setattr__(self, "_server_process_by_rank", server_process_by_rank) + object.__setattr__(self, "_lifecycle_group_by_rank", lifecycle_group_by_rank) + + def server_launch_specs(self) -> tuple[ServerLaunchSpec, ...]: + return tuple( + ServerLaunchSpec( + worker_rank=server.worker_rank, + placement_group_bundle_idxs=server.placement_group_bundle_idxs, + dist_init_addr=engine.dist_init_addr, + engine_rank=engine.engine_ranks.index(server.worker_rank), + accepts_rollout_requests=server.accepts_rollout_requests, + node_rank=server.node_rank, + nnodes=server.nnodes, + ) + for engine in self.engines + for server in engine.server_processes + ) + + def lifecycle_groups(self) -> tuple[tuple[int, ...], ...]: + return tuple(dict.fromkeys(self._lifecycle_group_by_rank.values())) + + def weight_update_endpoint_processes(self) -> tuple[RolloutServerProcess, ...]: + return tuple( + server for engine in self.engines for server in engine.server_processes if server.is_weight_update_endpoint + ) + + def is_request_entrypoint_rank(self, rank: int) -> bool: + server = self._server_process_by_rank.get(rank) + return server is not None and server.accepts_rollout_requests + + def lifecycle_group_for_server_rank(self, rank: int) -> tuple[int, ...]: + try: + return self._lifecycle_group_by_rank[rank] + except KeyError: + raise KeyError(f"rank={rank} does not own a rollout server process.") from None diff --git a/xtuner/v1/rl/rollout/session_server.py b/xtuner/v1/rl/rollout/session_server.py index 666cfc740..92f65265a 100644 --- a/xtuner/v1/rl/rollout/session_server.py +++ b/xtuner/v1/rl/rollout/session_server.py @@ -85,6 +85,8 @@ def _bool_request_value(value: Any, default: bool = False) -> bool: def _request_uses_trace_store(req_body: dict) -> bool: + if req_body.get("session_id") is None or "messages" not in req_body: + return False return _bool_request_value(req_body.get("return_token_ids"), True) diff --git a/xtuner/v1/rl/rollout/sglang.py b/xtuner/v1/rl/rollout/sglang.py index 047fa2d5a..6825fdcfb 100644 --- a/xtuner/v1/rl/rollout/sglang.py +++ b/xtuner/v1/rl/rollout/sglang.py @@ -1,6 +1,6 @@ import base64 import os -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Mapping, Union import numpy as np import ray @@ -11,13 +11,8 @@ from xtuner.v1.data_proto.rl_data import RolloutState from xtuner.v1.utils import XTUNER_DETERMINISTIC -from .worker import ( - EngineLaunchSpec, - EngineLaunchSpecs, - RolloutConfig, - RolloutWorker, - ServerProcessSpec, -) +from .rollout_topology import RolloutEngine, RolloutServerProcess, RolloutTopology +from .worker import RolloutConfig, RolloutWorker class SGLangWorker(RolloutWorker): @@ -49,53 +44,85 @@ def __init__( self.enable_return_routed_experts = self.config.enable_return_routed_experts @classmethod - def build_engine_launch_specs( + def build_rollout_topology( cls, config: RolloutConfig, rank_bundle_idx_list: list[tuple[int, int]], - rank_to_dist_init_addr: dict[int, str] | None = None, - ) -> EngineLaunchSpecs: - """Build SGLang server launch layout. - - SGLang starts one server per node in a logical engine. Only node 0 is - used as the rollout request entrypoint. - - Example with expert_parallel_size=16 and gpus_per_node=8: - rank_bundle_idx_list is: - [(0, 0), (1, 1), ..., (15, 15)] - - If rank_to_dist_init_addr is: - {0: "addr0", 1: "addr1", ..., 15: "addr15"} - - The launch spec is: - EngineLaunchSpec( - engine_ranks=(0, 1, 2, 3, 4, 5, 6, 7, - 8, 9, 10, 11, 12, 13, 14, 15), - server_processes=( - ServerProcessSpec( - worker_rank=0, - placement_group_bundle_idxs=(0, 1, 2, 3, 4, 5, 6, 7), + rank_to_dist_init_addr: Mapping[int, str], + ) -> RolloutTopology: + """Build SGLang rollout topology with bound engine dist-init addresses. + + The normal SGLang topology starts one server process for each logical + engine. Cross-node engines are the special case: SGLang starts one + server process per node, but only node 0 accepts rollout requests and + owns the weight-update endpoint. + + Example with ``expert_parallel_size=2`` on one node: + RolloutTopology( + engines=( + RolloutEngine( + engine_ranks=(0, 1), dist_init_addr="addr0", - accepts_rollout_requests=True, - node_rank=0, - nnodes=2, + server_processes=( + RolloutServerProcess( + worker_rank=0, + placement_group_bundle_idxs=(0, 1), + accepts_rollout_requests=True, + weight_update_ranks=(0, 1), + node_rank=0, + nnodes=1, + ), + ), ), - ServerProcessSpec( - worker_rank=8, - placement_group_bundle_idxs=(8, 9, 10, 11, 12, 13, 14, 15), - dist_init_addr="addr0", - accepts_rollout_requests=False, - node_rank=1, - nnodes=2, + RolloutEngine( + engine_ranks=(2, 3), + dist_init_addr="addr2", + server_processes=( + RolloutServerProcess( + worker_rank=2, + placement_group_bundle_idxs=(2, 3), + accepts_rollout_requests=True, + weight_update_ranks=(2, 3), + node_rank=0, + nnodes=1, + ), + ), ), ), ) - SGLang starts one server per node, so server_worker_ranks is (0, 8). - Only the node-0 server accepts rollout requests. + Example with ``expert_parallel_size=16`` across two 8-GPU nodes: + RolloutTopology( + engines=( + RolloutEngine( + engine_ranks=(0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15), + dist_init_addr="addr0", + server_processes=( + RolloutServerProcess( + worker_rank=0, + placement_group_bundle_idxs=(0, 1, 2, 3, 4, 5, 6, 7), + accepts_rollout_requests=True, + weight_update_ranks=(0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15), + node_rank=0, + nnodes=2, + ), + RolloutServerProcess( + worker_rank=8, + placement_group_bundle_idxs=(8, 9, 10, 11, 12, 13, 14, 15), + accepts_rollout_requests=False, + weight_update_ranks=(), + node_rank=1, + nnodes=2, + ), + ), + ), + ), + ) """ num_workers = len(rank_bundle_idx_list) - num_gpus_per_engine = cls._get_num_gpus_per_engine(config) + num_gpus_per_engine = config.num_gpus_per_engine if num_workers % num_gpus_per_engine != 0: raise ValueError( f"num_rollout_workers={num_workers} must be divisible by num_gpus_per_engine={num_gpus_per_engine}." @@ -106,7 +133,7 @@ def build_engine_launch_specs( ) nnodes = max(1, num_gpus_per_engine // config.gpus_per_node) - engine_launch_specs: list[EngineLaunchSpec] = [] + engines = [] for engine_start in range(0, num_workers, num_gpus_per_engine): engine_meta = rank_bundle_idx_list[engine_start : engine_start + num_gpus_per_engine] engine_ranks = tuple(rank for rank, _ in engine_meta) @@ -115,30 +142,30 @@ def build_engine_launch_specs( # first rank of each node owns that node's bundles, while only node # 0 is exposed as the rollout request entrypoint. server_ranks = engine_ranks[:: config.gpus_per_node] - engine_dist_init_addr = None if rank_to_dist_init_addr is None else rank_to_dist_init_addr[server_ranks[0]] - server_processes: list[ServerProcessSpec] = [] + dist_init_addr_owner_rank = server_ranks[0] + server_processes = [] for node_rank, server_rank in enumerate(server_ranks): node_bundle_start = node_rank * config.gpus_per_node node_bundle_end = node_bundle_start + config.gpus_per_node server_processes.append( - ServerProcessSpec( + RolloutServerProcess( worker_rank=server_rank, placement_group_bundle_idxs=engine_bundle_idxs[node_bundle_start:node_bundle_end], - dist_init_addr=engine_dist_init_addr, accepts_rollout_requests=node_rank == 0, + weight_update_ranks=engine_ranks if node_rank == 0 else (), node_rank=node_rank, nnodes=nnodes, ) ) - engine_launch_specs.append( - EngineLaunchSpec( + engines.append( + RolloutEngine( engine_ranks=engine_ranks, + dist_init_addr=rank_to_dist_init_addr[dist_init_addr_owner_rank], server_processes=tuple(server_processes), ) ) - return cls.validate_engine_launch_specs( - tuple(engine_launch_specs), - known_worker_ranks=tuple(rank for rank, _ in rank_bundle_idx_list), + return RolloutTopology( + engines=tuple(engines), ) def _get_request_payload(self, rollout_state: RolloutState) -> dict: @@ -325,7 +352,7 @@ def _transform_rollout_config_to_server_configs(self): os.environ.pop("CUDA_VISIBLE_DEVICES", None) from sglang.srt.server_args import ServerArgs - extra_config = self.config.extra_rollout_config or dict() + extra_config = self.config.extra_rollout_config sglang_config_kwargs = { k.replace("sglang_", ""): v for k, v in extra_config.items() if k.startswith("sglang_") } @@ -338,13 +365,7 @@ def _transform_rollout_config_to_server_configs(self): ) tp_size = num_gpus_per_engine if self.config.expert_parallel_size > 1 else self.config.tensor_parallel_size ep_size = num_gpus_per_engine if self.config.expert_parallel_size > 1 else self.config.expert_parallel_size - server_process_spec = self._get_current_server_process_spec() - nnodes = ( - server_process_spec.nnodes - if server_process_spec is not None - else max(1, num_gpus_per_engine // self.config.gpus_per_node) - ) - node_rank = server_process_spec.node_rank if server_process_spec is not None else 0 + assert self.server_launch_spec is not None assigned_gpu_id = int(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) # SGLang 0.5.10 默认启用的 Piecewise CUDA Graph 在启动 warmup compile 阶段会报错。sglang的文档提到这个功能还是实验功能,可能还不太稳定(https://sgl-project-sglang-93.mintlify.app/optimization/cuda-graph#bug-report)。暂时先通过disable_piecewise_cuda_graph=True关掉改功能 @@ -354,11 +375,11 @@ def _transform_rollout_config_to_server_configs(self): host=self.host, port=self.server_port, nccl_port=self.nccl_port, - dist_init_addr=self.dist_init_addr, + dist_init_addr=self.server_launch_spec.dist_init_addr, base_gpu_id=assigned_gpu_id, gpu_id_step=1, - nnodes=nnodes, - node_rank=node_rank, + nnodes=self.server_launch_spec.nnodes, + node_rank=self.server_launch_spec.node_rank, skip_server_warmup=True, mem_fraction_static=self.config.gpu_memory_utilization, enable_memory_saver=True, diff --git a/xtuner/v1/rl/rollout/vllm.py b/xtuner/v1/rl/rollout/vllm.py index 8cbeaef69..a999baf98 100644 --- a/xtuner/v1/rl/rollout/vllm.py +++ b/xtuner/v1/rl/rollout/vllm.py @@ -2,7 +2,7 @@ import os import traceback from argparse import Namespace -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Mapping, Union import numpy as np import ray @@ -16,6 +16,7 @@ from xtuner.v1.data_proto.rl_data import RolloutState, Status, update_status_from_finish_reason from xtuner.v1.utils.device import get_device, get_torch_device_module +from .rollout_topology import RolloutTopology from .worker import RolloutConfig, RolloutWorker @@ -131,6 +132,15 @@ def run_lmdeploy_server_wrapper(server_namespace: Namespace): class vLLMWorker(RolloutWorker): + @classmethod + def build_rollout_topology( + cls, + config: RolloutConfig, + rank_bundle_idx_list: list[tuple[int, int]], + rank_to_dist_init_addr: Mapping[int, str], + ) -> RolloutTopology: + raise NotImplementedError("vLLM rollout topology has not been verified after topology refactor.") + def __init__( self, config: RolloutConfig, @@ -323,13 +333,14 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: args["limit_mm_per_prompt"] = {"image": 10, "video": 0} args["enable_log_requests"] = False args["uvicorn_log_level"] = "error" + assert self.server_launch_spec is not None env = { "VLLM_VERSION": "0.11.0", "TASK_QUEUE_ENABLE": "0", "CPU_AFFINITY_CONF": "2", "VLLM_USE_V1": "1", "VLLM_RAY_PER_WORKER_GPUS": "0.1", - "VLLM_RAY_BUNDLE_INDICES": ",".join(map(str, self.engine_bundle_idxs)), + "VLLM_RAY_BUNDLE_INDICES": ",".join(map(str, self.server_launch_spec.placement_group_bundle_idxs)), "VLLM_MONITOR": "1", "VLLM_ACCU_MONITOR": "0", "CUSTOM_SCHEDULE_KV_LIMIT": "0.9", diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py index c50ffa6bf..4d0ab1666 100644 --- a/xtuner/v1/rl/rollout/worker.py +++ b/xtuner/v1/rl/rollout/worker.py @@ -9,7 +9,7 @@ from abc import abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, TypeAlias, Union, cast +from typing import TYPE_CHECKING, Any, Callable, List, Literal, Mapping, Optional, Union, cast import httpx import ray @@ -40,6 +40,7 @@ from .constants import ROLLOUT_HTTP_MAX_CONNECTIONS, ROLLOUT_RAY_GENERATE_MAX_CONCURRENCY from .health_manager import ROLLOUT_RAY_GET_TIMEOUT +from .rollout_topology import RolloutTopology, ServerLaunchSpec from .session_server import SessionServerActor from .utils import PartialRolloutHandler @@ -53,56 +54,12 @@ @dataclass(frozen=True) -class ServerProcessSpec: - """How to start one rollout server process.""" - - # Worker rank that owns this server process. - worker_rank: int - # Placement-group bundle indexes assigned to this server process. - placement_group_bundle_idxs: tuple[int, ...] - # Distributed init address used by every server process in the same engine. - # Filled after init_dist_port initializes worker-local ports. - dist_init_addr: str | None = None - # Whether this server is exposed as a rollout request entrypoint. Some - # backends launch extra server processes that must participate in - # lifecycle/health operations but must not be added to worker_server_urls_map - # or receive normal rollout traffic. - accepts_rollout_requests: bool = True - # Node index of this server inside a multi-node logical engine. - node_rank: int = 0 - # Number of nodes used by this logical engine. - nnodes: int = 1 +class RolloutWorkerInitResult: + """Result returned by RolloutWorker.init() after its server starts.""" - -@dataclass(frozen=True) -class EngineLaunchSpec: - """How to launch rollout servers for one logical inference engine.""" - - # All worker ranks that form this logical inference engine. - engine_ranks: tuple[int, ...] - # Server processes required by this engine. - server_processes: tuple[ServerProcessSpec, ...] - - @property - def server_worker_ranks(self) -> tuple[int, ...]: - return tuple(server.worker_rank for server in self.server_processes) - - @property - def request_entrypoint_servers(self) -> tuple[ServerProcessSpec, ...]: - return tuple(server for server in self.server_processes if server.accepts_rollout_requests) - - @property - def request_entrypoint_worker_ranks(self) -> tuple[int, ...]: - return tuple(server.worker_rank for server in self.request_entrypoint_servers) - - @property - def placement_group_bundle_idxs(self) -> tuple[int, ...]: - return tuple( - bundle_idx for server in self.server_processes for bundle_idx in server.placement_group_bundle_idxs - ) - - -EngineLaunchSpecs: TypeAlias = tuple[EngineLaunchSpec, ...] + rank: int + server_url: str + session_url: str | None def get_rollout_worker_base_cls(config: "RolloutConfig") -> type["RolloutWorker"]: @@ -579,8 +536,7 @@ def __init__( self.accelerator = accelerator self.server_func: Callable self.endpoints: dict[str, str] = dict() - self.engine_rank_mesh_array: list[list[int]] = [] - self.engine_launch_spec: EngineLaunchSpec | None = None + self.server_launch_spec: ServerLaunchSpec | None = None # Keep this deliberately large so requests do not queue in the # RolloutWorker/httpx client; the inference engine owns rollout request # scheduling and queueing. @@ -588,7 +544,6 @@ def __init__( limits = httpx.Limits(max_connections=http_concurrency, max_keepalive_connections=100) self.client = httpx.AsyncClient(limits=limits, timeout=self.config.rollout_timeout) self.server_task = None - self.engine_bundle_idxs: list[int] = [] self.server_process: Optional[multiprocessing.Process] = None self.session_server_actor: Any | None = None self.session_server_url: str | None = None @@ -602,205 +557,56 @@ def __init__( self.logger.info(f"Using eos_token: {eos_token} for model at {self.config.model_path}") self.eos_token: List[int] = [eos_token] if isinstance(eos_token, int) else eos_token self.receive_abort_request = threading.Event() - self.dist_init_addr: str = "" self.serverl_url: str = "" self.partial_rollout_handler = PartialRolloutHandler() self.enable_partial_rollout: bool = False - @staticmethod - def _get_num_gpus_per_engine(config: RolloutConfig) -> int: - return config.num_gpus_per_engine - - @classmethod - def validate_engine_launch_specs( - cls, - engine_launch_specs: EngineLaunchSpecs, - *, - known_worker_ranks: tuple[int, ...] | None = None, - ) -> EngineLaunchSpecs: - """Validate backend launch layout before the controller launches - servers.""" - if not engine_launch_specs: - raise ValueError("engine_launch_specs must define at least one engine.") - - known_worker_rank_set = set(known_worker_ranks) if known_worker_ranks is not None else None - seen_engine_ranks: set[int] = set() - seen_server_ranks: set[int] = set() - seen_bundle_idxs: set[int] = set() - for engine_index, engine_spec in enumerate(engine_launch_specs): - if not engine_spec.engine_ranks: - raise ValueError(f"EngineLaunchSpec[{engine_index}] must define at least one engine rank.") - engine_rank_set = set(engine_spec.engine_ranks) - if len(engine_rank_set) != len(engine_spec.engine_ranks): - raise ValueError( - f"EngineLaunchSpec[{engine_index}] has duplicate engine ranks: {engine_spec.engine_ranks}." - ) - if known_worker_rank_set is not None: - unknown_engine_ranks = sorted( - rank for rank in engine_spec.engine_ranks if rank not in known_worker_rank_set - ) - if unknown_engine_ranks: - raise ValueError( - f"EngineLaunchSpec[{engine_index}] references unknown engine ranks: {unknown_engine_ranks}." - ) - duplicated_engine_ranks = sorted(rank for rank in engine_spec.engine_ranks if rank in seen_engine_ranks) - if duplicated_engine_ranks: - raise ValueError( - f"EngineLaunchSpec[{engine_index}] engine ranks appear in more than one engine: " - f"{duplicated_engine_ranks}." - ) - seen_engine_ranks.update(engine_spec.engine_ranks) - - if not engine_spec.server_processes: - raise ValueError(f"EngineLaunchSpec[{engine_index}] must define at least one server process.") - - for server_process in engine_spec.server_processes: - server_rank = server_process.worker_rank - if server_rank not in engine_rank_set: - raise ValueError( - f"EngineLaunchSpec[{engine_index}] server worker_rank={server_rank} " - f"must be part of engine_ranks={engine_spec.engine_ranks}." - ) - if server_rank in seen_server_ranks: - raise ValueError(f"Server worker_rank={server_rank} appears in more than one server process.") - seen_server_ranks.add(server_rank) - - if not server_process.placement_group_bundle_idxs: - raise ValueError(f"Server worker_rank={server_rank} must own at least one placement-group bundle.") - if len(set(server_process.placement_group_bundle_idxs)) != len( - server_process.placement_group_bundle_idxs - ): - raise ValueError( - f"Server worker_rank={server_rank} has duplicate placement-group bundles: " - f"{server_process.placement_group_bundle_idxs}." - ) - duplicated_bundle_idxs = sorted( - bundle_idx - for bundle_idx in server_process.placement_group_bundle_idxs - if bundle_idx in seen_bundle_idxs - ) - if duplicated_bundle_idxs: - raise ValueError( - f"Placement-group bundles are assigned to multiple server processes: {duplicated_bundle_idxs}." - ) - seen_bundle_idxs.update(server_process.placement_group_bundle_idxs) - - if server_process.nnodes < 1: - raise ValueError(f"Server worker_rank={server_rank} must have nnodes >= 1.") - if server_process.node_rank < 0 or server_process.node_rank >= server_process.nnodes: - raise ValueError( - f"Server worker_rank={server_rank} has invalid node_rank={server_process.node_rank} " - f"for nnodes={server_process.nnodes}." - ) - - if not engine_spec.request_entrypoint_servers: - raise ValueError(f"EngineLaunchSpec[{engine_index}] must expose at least one request entrypoint.") - - if known_worker_rank_set is not None: - missing_engine_ranks = sorted(known_worker_rank_set - seen_engine_ranks) - if missing_engine_ranks: - raise ValueError( - f"EngineLaunchSpecs do not cover known worker ranks in engine_ranks: {missing_engine_ranks}." - ) - - return engine_launch_specs - @classmethod - def build_engine_launch_specs( + @abstractmethod + def build_rollout_topology( cls, config: RolloutConfig, rank_bundle_idx_list: list[tuple[int, int]], - rank_to_dist_init_addr: dict[int, str] | None = None, - ) -> EngineLaunchSpecs: - """Build default launch spec: one request-serving server per engine.""" - num_gpus_per_engine = cls._get_num_gpus_per_engine(config) - num_workers = len(rank_bundle_idx_list) - if num_workers % num_gpus_per_engine != 0: - raise ValueError( - f"num_rollout_workers={num_workers} must be divisible by num_gpus_per_engine={num_gpus_per_engine}." - ) + rank_to_dist_init_addr: Mapping[int, str], + ) -> RolloutTopology: + raise NotImplementedError("Concrete rollout worker classes must implement build_rollout_topology().") - engine_launch_specs: list[EngineLaunchSpec] = [] - for engine_start in range(0, num_workers, num_gpus_per_engine): - engine_meta = rank_bundle_idx_list[engine_start : engine_start + num_gpus_per_engine] - engine_ranks = tuple(rank for rank, _ in engine_meta) - engine_bundle_idxs = tuple(bundle_idx for _, bundle_idx in engine_meta) - engine_dist_init_addr = None if rank_to_dist_init_addr is None else rank_to_dist_init_addr[engine_ranks[0]] - engine_launch_specs.append( - EngineLaunchSpec( - engine_ranks=engine_ranks, - server_processes=( - ServerProcessSpec( - worker_rank=engine_ranks[0], - placement_group_bundle_idxs=engine_bundle_idxs, - dist_init_addr=engine_dist_init_addr, - ), - ), - ) - ) - return cls.validate_engine_launch_specs( - tuple(engine_launch_specs), - known_worker_ranks=tuple(rank for rank, _ in rank_bundle_idx_list), - ) - - @classmethod - def build_metadata_engine_rank_mesh_array( - cls, - engine_launch_specs: EngineLaunchSpecs, - ) -> list[list[int]]: - """Build the public engine mesh returned in rollout metadata. + def set_enable_partial_rollout(self, enable: bool) -> None: + self.enable_partial_rollout = enable - By default, the public metadata mesh matches the logical engine topology. Backends with multiple request - servers per logical engine can override this to preserve their legacy update-weight mesh semantics. - """ - return [list(engine_spec.engine_ranks) for engine_spec in engine_launch_specs] + def _bind_server_launch_spec(self, server_launch_spec: ServerLaunchSpec) -> None: + if server_launch_spec.worker_rank != self.rank: + raise ValueError( + f"Server launch spec rank={server_launch_spec.worker_rank} does not match worker rank={self.rank}." + ) + self.server_launch_spec = server_launch_spec - def _get_current_server_process_spec( - self, - engine_launch_spec: EngineLaunchSpec | None = None, - ) -> ServerProcessSpec | None: - engine_launch_spec = engine_launch_spec or self.engine_launch_spec - if engine_launch_spec is None: - return None - - for server_process_spec in engine_launch_spec.server_processes: - if server_process_spec.worker_rank == self.rank: - return server_process_spec - raise RuntimeError( - f"Engine launch spec does not include rollout worker rank={self.rank} " - f"in server_worker_ranks={engine_launch_spec.server_worker_ranks}." - ) + def init(self, server_launch_spec: ServerLaunchSpec) -> RolloutWorkerInitResult: + """Bind the worker launch spec and initialize the rollout server.""" + self._bind_server_launch_spec(server_launch_spec) + return self._init_server() - def set_enable_partial_rollout(self, enable: bool) -> None: - self.enable_partial_rollout = enable + def reinit(self) -> RolloutWorkerInitResult: + """Reinitialize the rollout server using the previously bound launch + spec.""" + return self._init_server() - def init( - self, - *, - engine_launch_spec: EngineLaunchSpec | None = None, - ) -> tuple[int, str]: + def _init_server(self) -> RolloutWorkerInitResult: """Initialize the worker and launch the server. Returns: - Tuple[int, str]: A tuple containing the worker's rank and its - server URL. + Startup result containing rank, server URL, and session URL. """ - if engine_launch_spec is not None: - # Initial controller startup passes the immutable launch spec and caches - # it on the actor. Recovery calls init() without arguments after - # shutdown, intentionally reusing this cached placement/dist layout. - self.engine_launch_spec = engine_launch_spec - server_process_spec = cast( - ServerProcessSpec, - self._get_current_server_process_spec(engine_launch_spec), - ) - self.engine_bundle_idxs = list(server_process_spec.placement_group_bundle_idxs) - if server_process_spec.dist_init_addr is not None: - self.dist_init_addr = server_process_spec.dist_init_addr + if self.server_launch_spec is None: + raise RuntimeError("Rollout worker must bind a server launch spec before starting server.") self.receive_abort_request.clear() self._launch_server() self._start_session_server() - return (self.rank, self.server_url) + return RolloutWorkerInitResult( + rank=self.rank, + server_url=self.server_url, + session_url=self.session_server_url, + ) def set_skip_load_weights(self, skip_load_weights: bool) -> None: self.config = self.config.model_copy(update={"skip_load_weights": skip_load_weights}) @@ -808,7 +614,7 @@ def set_skip_load_weights(self, skip_load_weights: bool) -> None: def restore_skip_load_weights(self) -> None: self.config = self.config.model_copy(update={"skip_load_weights": self._default_skip_load_weights}) - def init_dist_port(self) -> str: + def init_dist_port(self) -> tuple[int, str]: """Initialize distributed communication ports. This method initializes four fixed ports for the distributed setup: @@ -816,7 +622,7 @@ def init_dist_port(self) -> str: for NCCL, and one for the session server. Returns: - str: The distributed initialization address (host:port). + Worker rank and distributed initialization address (host:port). """ local_rank = int(ray.get_runtime_context().get_accelerator_ids()[self.accelerator][0]) base_port = self.config.dist_port_base + local_rank * 4 @@ -825,9 +631,9 @@ def init_dist_port(self) -> str: self.server_port = base_port + 1 self.nccl_port = base_port + 2 self.session_server_port = base_port + 3 - self.dist_init_addr = f"{self.host}:{self.dist_port}" + dist_init_addr = f"{self.host}:{self.dist_port}" self.server_url = f"http://{self.host}:{self.server_port}" - return self.dist_init_addr + return self.rank, dist_init_addr def shutdown(self, *, stop_session_server: bool = False): """Shut down the worker, its server task, and any child processes.""" @@ -870,14 +676,15 @@ def shutdown(self, *, stop_session_server: bool = False): def _start_session_server(self) -> None: """Start the per-worker SessionServer proxy.""" - if self.session_server_actor is not None: + assert self.server_launch_spec is not None + if not self.server_launch_spec.accepts_rollout_requests or self.session_server_actor is not None: return current_pg = ray.util.get_current_placement_group() scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=current_pg, placement_group_capture_child_tasks=False, - placement_group_bundle_index=self.engine_bundle_idxs[0], + placement_group_bundle_index=self.server_launch_spec.placement_group_bundle_idxs[0], ) self.session_server_actor = ( ray.remote(SessionServerActor) @@ -897,6 +704,10 @@ def _start_session_server(self) -> None: self.session_server_actor.start.remote(), timeout=ROLLOUT_RAY_GET_TIMEOUT, ) + if self.session_server_url is None: + raise RuntimeError( + f"Request-entrypoint rollout worker rank={self.rank} did not start session server during init." + ) def _stop_session_server(self) -> None: if self.session_server_actor is not None: @@ -907,9 +718,6 @@ def _stop_session_server(self) -> None: self.session_server_actor = None self.session_server_url = None - def get_session_server_info(self) -> tuple[int, str | None]: - return self.rank, self.session_server_url - async def pause_generation(self): """Pause the worker's generation process.""" self.receive_abort_request.set() @@ -1218,11 +1026,12 @@ def _launch_server(self): else: # launch the server as ray task # so that the lmdeploy backend could get externl pg + assert self.server_launch_spec is not None current_pg = ray.util.get_current_placement_group() scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=current_pg, placement_group_capture_child_tasks=True, - placement_group_bundle_index=self.engine_bundle_idxs[0], + placement_group_bundle_index=self.server_launch_spec.placement_group_bundle_idxs[0], ) assert ray.is_initialized() ray_kwargs = ( diff --git a/xtuner/v1/rl/rollout/worker_registry.py b/xtuner/v1/rl/rollout/worker_registry.py index 4770dd59b..5d664fd14 100644 --- a/xtuner/v1/rl/rollout/worker_registry.py +++ b/xtuner/v1/rl/rollout/worker_registry.py @@ -1,16 +1,19 @@ from __future__ import annotations import threading +from collections.abc import Iterable, Sequence from dataclasses import dataclass, replace from enum import Enum -from typing import TYPE_CHECKING, Iterable, TypedDict +from typing import TYPE_CHECKING if TYPE_CHECKING: - from .worker import RolloutConfig, RolloutWorker + from xtuner.v1.rl.weight_update.data import RolloutWeightUpdateTarget + + from .rollout_topology import RolloutTopology + from .worker import RolloutWorker, RolloutWorkerInitResult __all__ = [ - "RolloutWorkerMetadata", "RolloutWorkerRegistry", "WorkerGroup", "WorkerLifecycleState", @@ -31,20 +34,18 @@ class WorkerLifecycleState(str, Enum): class WorkerSnapshot: """Read-only snapshot for one rollout server process.""" + # Worker rank that owns the runtime snapshot. + rank: int + # Ray actor handle for the rollout worker. actor: RolloutWorker + # Base URL of the rollout server process. url: str + # Session server URL used only by proxy/session routing. session_url: str | None = None - lifecycle_state: WorkerLifecycleState = WorkerLifecycleState.ACTIVE - lifecycle_group_ranks: tuple[int, ...] = () + # Whether this worker can receive rollout generation requests. is_request_entrypoint: bool = True - rank: int = -1 - - def __post_init__(self) -> None: - lifecycle_state = ( - WorkerLifecycleState.ACTIVE if self.lifecycle_state is None else WorkerLifecycleState(self.lifecycle_state) - ) - object.__setattr__(self, "lifecycle_state", lifecycle_state) - object.__setattr__(self, "lifecycle_group_ranks", tuple(self.lifecycle_group_ranks)) + # Current lifecycle state observed by registry and health manager. + lifecycle_state: WorkerLifecycleState = WorkerLifecycleState.ACTIVE def is_active(self) -> bool: return self.lifecycle_state is WorkerLifecycleState.ACTIVE @@ -52,37 +53,12 @@ def is_active(self) -> bool: @dataclass(frozen=True) class WorkerGroup: + # Worker ranks that share one lifecycle action. ranks: tuple[int, ...] + # Runtime snapshots for registered workers in this lifecycle group. workers: tuple[WorkerSnapshot, ...] -class RolloutWorkerMetadata(TypedDict): - """Legacy rollout worker metadata consumed by trainer/update-weight - code.""" - - engine_rank_mesh_array: list[list[int]] - server_url_dict: dict[int, str] - rollout_config: RolloutConfig - worker_server_urls_status: dict[str, bool] - worker_session_url_dict: dict[int, str] - worker_session_urls_status: dict[str, bool] - - -def _build_worker_groups(workers: Iterable[WorkerSnapshot]) -> dict[tuple[int, ...], WorkerGroup]: - grouped_workers: dict[tuple[int, ...], list[WorkerSnapshot]] = {} - for worker in workers: - group_ranks = worker.lifecycle_group_ranks or (worker.rank,) - grouped_workers.setdefault(group_ranks, []).append(worker) - - return { - group_ranks: WorkerGroup( - ranks=group_ranks, - workers=tuple(sorted(group_workers, key=lambda worker: worker.rank)), - ) - for group_ranks, group_workers in grouped_workers.items() - } - - class RolloutWorkerRegistry: """Own runtime rollout worker state and expose consistent query snapshots.""" @@ -90,37 +66,37 @@ class RolloutWorkerRegistry: def __init__( self, *, - engine_rank_mesh_array: list[list[int]], - rollout_config: RolloutConfig, + rollout_topology: RolloutTopology, ): - """Initialize an empty registry with the training-side metadata - projection.""" - self._engine_rank_mesh_array = [list(engine_ranks) for engine_ranks in engine_rank_mesh_array] - self._rollout_config = rollout_config + """Initialize an empty registry with the rollout topology.""" + self._rollout_topology = rollout_topology self._workers: dict[int, WorkerSnapshot] = {} self._lock = threading.RLock() - def register_started_server( + def register_started_servers( self, *, - rank: int, - actor: RolloutWorker, - server_url: str, - session_url: str | None = None, - lifecycle_group_ranks: tuple[int, ...] = (), - is_request_entrypoint: bool = True, + init_results: Iterable[RolloutWorkerInitResult], + workers_by_rank: Sequence[RolloutWorker], + lifecycle_state: WorkerLifecycleState = WorkerLifecycleState.ACTIVE, ) -> None: - """Register one worker actor after its rollout server process has - started.""" + """Register worker actors after their rollout server processes have + started. + + workers_by_rank must be indexed by rollout worker rank; each init_result.rank is used to select the + corresponding actor. + """ with self._lock: - self._workers[rank] = WorkerSnapshot( - rank=rank, - actor=actor, - url=server_url, - session_url=session_url, - lifecycle_group_ranks=lifecycle_group_ranks or (rank,), - is_request_entrypoint=is_request_entrypoint, - ) + for init_result in init_results: + rank = init_result.rank + self._workers[rank] = WorkerSnapshot( + rank=rank, + actor=workers_by_rank[rank], + url=init_result.server_url, + session_url=init_result.session_url, + is_request_entrypoint=self._rollout_topology.is_request_entrypoint_rank(rank), + lifecycle_state=lifecycle_state, + ) def all_workers(self) -> tuple[WorkerSnapshot, ...]: """Return a stable rank-ordered snapshot of all registered server- @@ -162,11 +138,28 @@ def active_entrypoint_by_rank(self, rank: int) -> WorkerSnapshot | None: return None return worker + def lifecycle_groups(self) -> tuple[tuple[int, ...], ...]: + """Return registered lifecycle groups in rank order.""" + with self._lock: + return tuple(sorted(self._rollout_topology.lifecycle_groups())) + + def _build_worker_groups(self) -> dict[tuple[int, ...], WorkerGroup]: + grouped_ranks = { + self._rollout_topology.lifecycle_group_for_server_rank(worker.rank) for worker in self._workers.values() + } + return { + group_ranks: WorkerGroup( + ranks=group_ranks, + workers=tuple(self._workers[rank] for rank in group_ranks if rank in self._workers), + ) + for group_ranks in grouped_ranks + } + def claim_inactive_groups_for_recovery(self) -> tuple[WorkerGroup, ...]: """Claim non-active worker groups by moving them to recovering state.""" with self._lock: - worker_groups = _build_worker_groups(self._workers.values()) + worker_groups = self._build_worker_groups() inactive_groups = [ group for group in worker_groups.values() @@ -184,16 +177,14 @@ def mark_unhealthy_ranks(self, ranks: set[int]) -> tuple[WorkerGroup, ...]: """Mark every lifecycle group containing a failed rank as inactive.""" with self._lock: failed_group_ranks = { - worker.lifecycle_group_ranks or (worker.rank,) - for rank, worker in self._workers.items() - if rank in ranks + self._rollout_topology.lifecycle_group_for_server_rank(rank) for rank in ranks if rank in self._workers } for group_ranks in failed_group_ranks: for rank in group_ranks: worker = self._workers.get(rank) if worker is not None: self._workers[rank] = replace(worker, lifecycle_state=WorkerLifecycleState.INACTIVE) - worker_groups = _build_worker_groups(self._workers.values()) + worker_groups = self._build_worker_groups() return tuple( worker_groups[group_ranks] for group_ranks in sorted(failed_group_ranks) @@ -214,29 +205,27 @@ def set_group_recovery_result( worker = self._workers.get(rank) if worker is not None: self._workers[rank] = replace(worker, lifecycle_state=lifecycle_state) - worker_groups = _build_worker_groups(self._workers.values()) + worker_groups = self._build_worker_groups() return worker_groups.get(group.ranks) - def training_metadata_snapshot(self) -> RolloutWorkerMetadata: - """Build the legacy trainer/update-weight metadata from one registry - snapshot.""" + def weight_update_targets(self) -> tuple[RolloutWeightUpdateTarget, ...]: + """Return weight-update targets resolved with current runtime state.""" + from xtuner.v1.rl.weight_update.data import RolloutWeightUpdateTarget + with self._lock: - request_entrypoints = {rank: info for rank, info in self._workers.items() if info.is_request_entrypoint} - worker_server_urls_map = {rank: info.url for rank, info in request_entrypoints.items()} - worker_server_urls_status = {info.url: info.is_active() for info in request_entrypoints.values()} - worker_session_url_dict: dict[int, str] = {} - worker_session_urls_status: dict[str, bool] = {} - for rank, info in request_entrypoints.items(): - if info.session_url is None: - continue - worker_session_url_dict[rank] = info.session_url - worker_session_urls_status[info.session_url] = info.is_active() - - return { - "engine_rank_mesh_array": [list(engine_ranks) for engine_ranks in self._engine_rank_mesh_array], - "server_url_dict": worker_server_urls_map, - "rollout_config": self._rollout_config, - "worker_server_urls_status": worker_server_urls_status, - "worker_session_url_dict": worker_session_url_dict, - "worker_session_urls_status": worker_session_urls_status, - } + targets: list[RolloutWeightUpdateTarget] = [] + for server in self._rollout_topology.weight_update_endpoint_processes(): + worker = self._workers.get(server.worker_rank) + if worker is None: + raise RuntimeError( + f"Rollout weight update endpoint rank={server.worker_rank} has not been registered." + ) + targets.append( + RolloutWeightUpdateTarget( + endpoint_rank=server.worker_rank, + update_ranks=server.weight_update_ranks, + server_url=worker.url, + lifecycle_state=worker.lifecycle_state.value, + ) + ) + return tuple(sorted(targets, key=lambda target: target.endpoint_rank)) diff --git a/xtuner/v1/rl/trainer/controller.py b/xtuner/v1/rl/trainer/controller.py index d0965f216..87638a8ab 100644 --- a/xtuner/v1/rl/trainer/controller.py +++ b/xtuner/v1/rl/trainer/controller.py @@ -290,12 +290,21 @@ def onload(self, target: Literal["model", "optimizer", "all"] = "all"): ray.get([worker.onload_optimizer.remote() for worker in self.workers], timeout=TRAIN_RAY_GET_TIMEOUT) # type: ignore return - def update_rollout_info(self, info_dict, train_rollout_mode, weight_update_host=None, weight_update_port=None): + def bind_rollout_weight_update( + self, + *, + targets, + rollout_config, + weight_transport_type, + weight_update_host=None, + weight_update_port=None, + ): ray.get( [ - worker.update_rollout_info.remote( - **info_dict, - train_rollout_mode=train_rollout_mode, + worker.bind_rollout_weight_update.remote( + targets=targets, + rollout_config=rollout_config, + weight_transport_type=weight_transport_type, weight_update_host=weight_update_host, weight_update_port=weight_update_port, ) diff --git a/xtuner/v1/rl/trainer/worker.py b/xtuner/v1/rl/trainer/worker.py index 00ce2e067..76e9bd372 100644 --- a/xtuner/v1/rl/trainer/worker.py +++ b/xtuner/v1/rl/trainer/worker.py @@ -7,11 +7,9 @@ from pathlib import Path from typing import ( TYPE_CHECKING, - Dict, Iterable, List, Sequence, - TypeAlias, TypedDict, cast, ) @@ -63,8 +61,6 @@ from ..rollout_is import merge_rollout_is_metrics -DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices -ServiceUrlMap: TypeAlias = Dict[int, str] # A dictionary mapping service names to their URLs DEVICE = get_device() DEVICE_MODULE = get_torch_device_module() @@ -277,8 +273,8 @@ def __init__( ) @ray_method - def update_rollout_info(self, *args, **kwargs): - return self.update_weighter.update_rollout_info(*args, **kwargs) + def bind_rollout_weight_update(self, *args, **kwargs): + return self.update_weighter.bind_rollout_weight_update(*args, **kwargs) @ray_method def update_weights(self): diff --git a/xtuner/v1/rl/utils/ray_utils.py b/xtuner/v1/rl/utils/ray_utils.py index ad6b1dcc9..2e0273980 100644 --- a/xtuner/v1/rl/utils/ray_utils.py +++ b/xtuner/v1/rl/utils/ray_utils.py @@ -159,6 +159,10 @@ def signal_handler(signum, frame): def bind_train_rollout( train_workers, rollout_controller, + rollout_config, + weight_transport_type, + weight_update_host=None, + weight_update_port=None, ) -> None: """Bind the training and rollout workers for updating weights. @@ -170,6 +174,17 @@ def bind_train_rollout( train_workers: A list of training worker actors. rollout_controller: The rollout controller actor. """ - info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) # type: ignore[attr-defined] - ray.get([worker.update_rollout_info.remote(**info_dict) for worker in train_workers]) # type: ignore[attr-defined] + targets = ray.get(rollout_controller.get_weight_update_targets.remote()) # type: ignore[attr-defined] + ray.get( + [ + worker.bind_rollout_weight_update.remote( + targets=targets, + rollout_config=rollout_config, + weight_transport_type=weight_transport_type, + weight_update_host=weight_update_host, + weight_update_port=weight_update_port, + ) + for worker in train_workers + ] + ) # type: ignore[attr-defined] return diff --git a/xtuner/v1/rl/weight_update/__init__.py b/xtuner/v1/rl/weight_update/__init__.py index 312f779fe..e2268c11f 100644 --- a/xtuner/v1/rl/weight_update/__init__.py +++ b/xtuner/v1/rl/weight_update/__init__.py @@ -1,10 +1,7 @@ from .data import ( - DeviceMeshRaw, RolloutBackend, - RolloutEngineInfo, RolloutWeightUpdateInfo, - ServiceUrlMap, - TrainRolloutMode, + RolloutWeightUpdateTarget, WeightTransportType, WeightUpdateBatch, ) @@ -24,19 +21,16 @@ __all__ = [ - "DeviceMeshRaw", "IPCBackendAdapter", "IPCWeightTransport", "LMDeployIPCBackendAdapter", "NCCLBackendAdapter", "NCCLWeightTransport", "RolloutBackend", - "RolloutEngineInfo", + "RolloutWeightUpdateTarget", "RolloutWeightUpdateInfo", "SGLangIPCBackendAdapter", "SGLangNCCLBackendAdapter", - "ServiceUrlMap", - "TrainRolloutMode", "UpdateWeighter", "WeightIterator", "WeightTransportType", diff --git a/xtuner/v1/rl/weight_update/data.py b/xtuner/v1/rl/weight_update/data.py index 6041ff643..08007b5a4 100644 --- a/xtuner/v1/rl/weight_update/data.py +++ b/xtuner/v1/rl/weight_update/data.py @@ -1,48 +1,213 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import Dict, List, Literal, TypeAlias +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast import torch -from torch.distributed.device_mesh import DeviceMesh -DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices. -ServiceUrlMap: TypeAlias = Dict[int, str] # A dictionary mapping rollout ranks to their server URLs. -RolloutEngineInfo: TypeAlias = list[tuple[int, str, int]] # (rollout rank, server url, engine gpu count) -TrainRolloutMode: TypeAlias = Literal["colocate", "disaggregated"] # Train and rollout deployment mode. +if TYPE_CHECKING: + from xtuner.v1.rl.rollout.worker import RolloutConfig + + RolloutBackend: TypeAlias = Literal["sglang", "vllm", "pytorch", "turbomind"] # Rollout inference backend. WeightTransportType: TypeAlias = Literal["ipc", "nccl"] # Supported weight transport types. -@dataclass +def _resolve_rollout_backend(rollout_config: RolloutConfig) -> RolloutBackend: + # Backend selection follows rollout launcher precedence. + if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": + backend = "sglang" + elif os.environ.get("XTUNER_USE_VLLM", "0") == "1": + backend = "vllm" + else: + backend = (rollout_config.extra_rollout_config or dict()).get("lmdeploy_backend", "pytorch") + + backend = backend.lower() + if backend not in ("sglang", "vllm", "pytorch", "turbomind"): + raise ValueError( + f"Unsupported rollout backend: {backend!r}. Expected 'sglang', 'vllm', 'pytorch' or 'turbomind'." + ) + return cast(RolloutBackend, backend) + + +def _validate_transport_type( + *, + weight_transport_type: WeightTransportType | str, + backend: RolloutBackend, +) -> WeightTransportType: + assert weight_transport_type is not None, "bind_rollout_weight_update() must set weight_transport_type." + + transport_type = weight_transport_type.lower() + if transport_type not in ("ipc", "nccl"): + raise ValueError(f"Unsupported weight_transport_type: {weight_transport_type!r}. Expected 'ipc' or 'nccl'.") + transport_type = cast(WeightTransportType, transport_type) + if transport_type == "nccl" and backend in ("vllm", "turbomind"): + raise NotImplementedError(f"NCCL weight transport is not supported for {backend} backend.") + return transport_type + + +@dataclass(frozen=True) +class RolloutWeightUpdateTarget: + """Runtime weight-update endpoint resolved from rollout registry state.""" + + # Server-process worker rank that receives weight update requests. + endpoint_rank: int + # Rollout ranks updated through this endpoint. + update_ranks: tuple[int, ...] + # Runtime rollout server URL resolved from WorkerSnapshot. + server_url: str + # Registry lifecycle state value for this endpoint. + lifecycle_state: str + + @property + def is_active(self) -> bool: + return self.lifecycle_state == "active" + + @property + def engine_size(self) -> int: + return len(self.update_ranks) + + +@dataclass(frozen=True) class RolloutWeightUpdateInfo: - # Common rollout metadata. - api_key: list[str] | str | None = None - rollout_url: str | None = None - backend: RolloutBackend | None = None - tp: int = 1 - ep: int = 1 - train_rollout_mode: TrainRolloutMode | None = None - transport_type: WeightTransportType | None = None - rollout_cfg_info: dict = field(default_factory=dict) - endpoints: dict[str, str] = field(default_factory=lambda: {"update_weights": "update_weights"}) - - # Colocated rollout metadata. - rollout_device_mesh: DeviceMesh | None = None - rollout_engine_rank_mesh_array: DeviceMeshRaw = field(default_factory=list) - - # Disaggregated rollout metadata. - rollout_server_url_dict: ServiceUrlMap = field(default_factory=dict) - worker_server_urls_status: dict[str, bool] = field(default_factory=dict) + # Rollout config owns api_key, backend choice, TP/EP, and default update host/port. + rollout_config: RolloutConfig + # Registry-resolved rollout update targets visible to every train worker. + weight_update_targets: tuple[RolloutWeightUpdateTarget, ...] + # Current train worker rank; used to derive the local weight update target. + train_rank: int + # Weight transport protocol; also determines rollout weight export strategy. + transport_type: WeightTransportType + # Resolved rollout backend used by transports and iterators. + backend: RolloutBackend + # Optional host used by NCCL external weight update groups. weight_update_host: str | None = None + # Optional port used by NCCL external weight update groups. weight_update_port: int | None = None + @classmethod + def from_targets( + cls, + *, + rollout_config: RolloutConfig, + weight_update_targets: tuple[RolloutWeightUpdateTarget, ...], + train_rank: int, + weight_transport_type: WeightTransportType | str, + weight_update_host: str | None = None, + weight_update_port: int | None = None, + ) -> RolloutWeightUpdateInfo: + backend = _resolve_rollout_backend(rollout_config) + tp = rollout_config.tensor_parallel_size + ep = rollout_config.expert_parallel_size + assert tp == 1 or ep == 1, "Either tensor parallel size or engine parallel size must be 1." + transport_type = _validate_transport_type( + weight_transport_type=weight_transport_type, + backend=backend, + ) + return cls( + rollout_config=rollout_config, + weight_update_targets=weight_update_targets, + train_rank=train_rank, + transport_type=transport_type, + backend=backend, + weight_update_host=weight_update_host, + weight_update_port=weight_update_port if weight_update_port is not None else 30000, + ) + + @property + def local_update_target(self) -> RolloutWeightUpdateTarget | None: + return next( + (target for target in self.weight_update_targets if self.train_rank == target.endpoint_rank), + None, + ) + + @property + def rollout_url(self) -> str | None: + target = self.local_update_target + if target is None or not target.is_active: + return None + return target.server_url + + @property + def ipc_rank_mesh(self) -> tuple[tuple[int, ...], ...]: + return tuple(target.update_ranks for target in self.weight_update_targets) + + @property + def _ipc_update_target(self) -> RolloutWeightUpdateTarget | None: + return next( + (target for target in self.weight_update_targets if self.train_rank in target.update_ranks), + None, + ) + + @property + def ipc_engine_parallel_rank(self) -> int | None: + target = self._ipc_update_target + if target is None: + return None + return target.update_ranks.index(self.train_rank) + + @property + def ipc_engine_parallel_size(self) -> int | None: + target = self._ipc_update_target + if target is None: + return None + return target.engine_size + + @property + def active_update_targets(self) -> tuple[RolloutWeightUpdateTarget, ...]: + return tuple(target for target in self.weight_update_targets if target.is_active) + + @property + def nccl_engine_infos(self) -> tuple[tuple[int, str, int], ...]: + return tuple( + (target.endpoint_rank, target.server_url, target.engine_size) for target in self.active_update_targets + ) + + @property + def transport_signature(self) -> tuple[Any, ...]: + target_signature = tuple( + ( + target.endpoint_rank, + tuple(int(rank) for rank in target.update_ranks), + target.server_url, + target.lifecycle_state, + ) + for target in self.weight_update_targets + ) + frozen_api_key = tuple(self.api_key) if isinstance(self.api_key, list) else self.api_key + return ( + self.transport_type, + self.backend, + self.tp, + self.ep, + frozen_api_key, + self.weight_update_host, + self.weight_update_port, + target_signature, + ) + + @property + def api_key(self) -> list[str] | str | None: + return self.rollout_config.api_key + + @property + def tp(self) -> int: + return self.rollout_config.tensor_parallel_size + + @property + def ep(self) -> int: + return self.rollout_config.expert_parallel_size + @dataclass class WeightUpdateBatch: """A single bucket of weights to send to rollout workers.""" + # HF-style named tensors or backend-specific tensors for one update bucket. state_dict: dict[str, torch.Tensor] + # Whether the train model uses EP and may need rollout EP slicing. train_enable_ep: bool = False + # Whether this is the final bucket in the current update stream. finished: bool = False diff --git a/xtuner/v1/rl/weight_update/transport.py b/xtuner/v1/rl/weight_update/transport.py index 2ae3639be..637da8648 100644 --- a/xtuner/v1/rl/weight_update/transport.py +++ b/xtuner/v1/rl/weight_update/transport.py @@ -14,6 +14,7 @@ import torch import torch.distributed as dist from packaging.version import parse as parse_version +from torch.distributed.device_mesh import DeviceMesh from torch.distributed.distributed_c10d import ( Backend, PrefixStore, @@ -39,7 +40,9 @@ @dataclass class WeightUpdateRequest: + # HTTP endpoint on the rollout server that should receive this update. endpoint: str + # JSON body sent to the rollout backend adapter endpoint. body: dict[str, Any] @@ -60,8 +63,6 @@ def __init__(self, *, rollout_info: RolloutWeightUpdateInfo, logger: Any, rank: self._adapter: WeightTransportAdapter | None = None self.rollout_url = self.rollout_info.rollout_url - if self.rollout_url is None: - self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip") @staticmethod def post_json(url: str, endpoint: str, payload: dict, *, api_key=None) -> dict: @@ -431,9 +432,12 @@ def __init__( self.config = config self._adapter = self._build_adapter() - assert self.rollout_info.rollout_device_mesh is not None - self.rollout_device_mesh = self.rollout_info.rollout_device_mesh - self.cpu_mesh = self.rollout_info.rollout_device_mesh["engine_parallel"] + self.ipc_update_device_mesh = DeviceMesh( + "cpu", + mesh=[list(ranks) for ranks in self.rollout_info.ipc_rank_mesh], + mesh_dim_names=("engine_instance", "engine_parallel"), + ) + self.cpu_mesh = self.ipc_update_device_mesh["engine_parallel"] self.cpu_group = self.cpu_mesh.get_group() self.head_rank = int(self.cpu_mesh.mesh[0].item()) @@ -461,9 +465,11 @@ def after_update_per_group(self) -> None: dist.barrier() def send(self, batch: WeightUpdateBatch) -> None: - if self.rollout_url is None: - self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip") + ipc_update_target = self.rollout_info._ipc_update_target + assert ipc_update_target is not None, "IPC rollout target for current train rank is not resolved." + if not ipc_update_target.is_active: return + rollout_url = ipc_update_target.server_url DEVICE_MODULE.empty_cache() try: @@ -475,7 +481,7 @@ def send(self, batch: WeightUpdateBatch) -> None: if dist.get_rank() == self.head_rank: request = self._adapter.build_request(batch, serialized_data) self.post_json( - self.rollout_url, + rollout_url, request.endpoint, request.body, api_key=self.rollout_info.api_key, @@ -652,36 +658,8 @@ def ensure_nccl_weight_update_group(self): if self.group is not None: return - # Map rollout rank to its engine size. - rank_to_engine_size = { - int(rank): len(engine_ranks) - for engine_ranks in self.rollout_info.rollout_engine_rank_mesh_array - for rank in engine_ranks - } - - # Deduplicate rollout engine URLs while keeping the first rank associated - # with each URL as the representative rank for that engine. - url_to_rank: dict[str, int] = {} - for rank, url in sorted( - self.rollout_info.rollout_server_url_dict.items(), - key=lambda item: int(item[0]), - ): - if url: - url_to_rank.setdefault(url, int(rank)) - - # Collect the representative rank, URL, and engine size needed to create - # the NCCL weight update process group. - engine_info = [ - ( - rank, - url, - rank_to_engine_size.get( - rank, - max(self.rollout_info.tp, self.rollout_info.ep), - ), - ) - for url, rank in url_to_rank.items() - ] + # RolloutWeightUpdateInfo owns the runtime target projection. + engine_info = self.rollout_info.nccl_engine_infos if not engine_info: self.logger.error("No active rollout engine url, cannot init sglang weight update group") diff --git a/xtuner/v1/rl/weight_update/update_weighter.py b/xtuner/v1/rl/weight_update/update_weighter.py index 7adb7ba8d..db4558845 100644 --- a/xtuner/v1/rl/weight_update/update_weighter.py +++ b/xtuner/v1/rl/weight_update/update_weighter.py @@ -1,18 +1,13 @@ from __future__ import annotations -import os -from typing import Any, cast - -from torch.distributed.device_mesh import DeviceMesh +from typing import Any from xtuner.v1.rl.rollout.worker import RolloutConfig from .data import ( - DeviceMeshRaw, - RolloutBackend, RolloutWeightUpdateInfo, - ServiceUrlMap, - TrainRolloutMode, + RolloutWeightUpdateTarget, + WeightTransportType, ) from .transport import IPCWeightTransport, NCCLWeightTransport, WeightTransport from .weight_iterator import WeightIterator @@ -24,86 +19,37 @@ def __init__(self, *, rank: int, logger: Any, config: Any, engine: Any): self.logger = logger self.config = config self._engine = engine - # Used to update weight to rollout engine. - self.rollout_info = RolloutWeightUpdateInfo() + # Bound rollout weight-update metadata, available after bind_rollout_weight_update(). + self.rollout_info: RolloutWeightUpdateInfo | None = None + # Lazily constructed iterator bound to the current rollout_info. + self.weight_iterator: WeightIterator | None = None self._global_hf_keys_mapping_cache: dict[str, list[str]] = {} - # Transport is initialized after update_rollout_info() is called. + # Transport is initialized after bind_rollout_weight_update() is called. self._transport: WeightTransport | None = None # Used to detect changes in rollout metadata that require resetting the transport. self._transport_signature: tuple[Any, ...] | None = None - @staticmethod - def _normalize_rollout_backend(rollout_config: RolloutConfig) -> RolloutBackend: - # Backend selection follows rollout launcher precedence: explicit SGLang/vLLM env vars win, - # otherwise the LMDeploy backend decides between pytorch and turbomind. - if os.environ.get("XTUNER_USE_SGLANG", "0") == "1": - backend = "sglang" - elif os.environ.get("XTUNER_USE_VLLM", "0") == "1": - backend = "vllm" - else: - backend = (rollout_config.extra_rollout_config or dict()).get("lmdeploy_backend", "pytorch") - - backend = backend.lower() - if backend not in ("sglang", "vllm", "pytorch", "turbomind"): - raise ValueError( - f"Unsupported rollout backend: {backend!r}. Expected 'sglang', 'vllm', 'pytorch' or 'turbomind'." - ) - return cast(RolloutBackend, backend) - - def update_rollout_info( + def bind_rollout_weight_update( self, - engine_rank_mesh_array: DeviceMeshRaw, - server_url_dict: ServiceUrlMap, + *, + targets: tuple[RolloutWeightUpdateTarget, ...], rollout_config: RolloutConfig, - worker_server_urls_status: dict[str, bool], - train_rollout_mode: TrainRolloutMode, + weight_transport_type: WeightTransportType, weight_update_host: str | None = None, weight_update_port: int | None = None, - worker_session_url_dict: ServiceUrlMap | None = None, - worker_session_urls_status: dict[str, bool] | None = None, ): - """Update the rollout information for the training worker.""" - - self.rollout_info.backend = self._normalize_rollout_backend(rollout_config) - self.set_train_rollout_mode(train_rollout_mode=train_rollout_mode) - - # Common rollout metadata. - tp = rollout_config.tensor_parallel_size - ep = rollout_config.expert_parallel_size - assert tp == 1 or ep == 1, "Either tensor parallel size or engine parallel size must be 1." - self.rollout_info.tp = tp - self.rollout_info.ep = ep - self.rollout_info.api_key = rollout_config.api_key - rollout_server_url = server_url_dict.get(self.rank, "") - if not worker_server_urls_status.get(rollout_server_url, False): - self.logger.error(f"Rollout server url {rollout_server_url} is not available.") - self.rollout_info.rollout_url = None - else: - self.rollout_info.rollout_url = rollout_server_url - - if self.rollout_info.transport_type == "ipc": - # Colocated rollout metadata. - # rollout_device_mesh is created after train_rollout_mode is set. - self.rollout_info.rollout_engine_rank_mesh_array = [ - [int(rank) for rank in ranks] for ranks in engine_rank_mesh_array - ] - self._ensure_rollout_device_mesh() - elif self.rollout_info.transport_type == "nccl": - # Disaggregated rollout metadata. - self.rollout_info.rollout_server_url_dict = {int(rank): url for rank, url in server_url_dict.items()} - self.rollout_info.worker_server_urls_status = worker_server_urls_status - self.rollout_info.weight_update_host = weight_update_host - self.rollout_info.weight_update_port = weight_update_port if weight_update_port is not None else 30000 - - new_transport_signature = self._build_transport_signature( - engine_rank_mesh_array=engine_rank_mesh_array, - server_url_dict=server_url_dict, - worker_server_urls_status=worker_server_urls_status, - train_rollout_mode=train_rollout_mode, - backend=self.rollout_info.backend, - tp=tp, - ep=ep, + """Bind this train worker to rollout weight-update targets.""" + + self.rollout_info = RolloutWeightUpdateInfo.from_targets( + rollout_config=rollout_config, + weight_update_targets=targets, + train_rank=self.rank, + weight_transport_type=weight_transport_type, + weight_update_host=weight_update_host, + weight_update_port=weight_update_port, ) + + new_transport_signature = self.rollout_info.transport_signature # Weight transports may cache resources derived from rollout metadata. # Since rollout workers can fail and recover with new URL/status/mesh metadata, # reset the cached transport whenever that metadata changes. @@ -121,90 +67,32 @@ def update_rollout_info( if self._transport is None: self._set_transport() - def _ensure_rollout_device_mesh(self): - if self.rollout_info.rollout_device_mesh is None: - # 非共卡 SGLang 不使用这个 mesh;只有共卡/旧权重同步路径需要 - # 用 rollout rank 构造 torch DeviceMesh。 - self.rollout_info.rollout_device_mesh = DeviceMesh( - "cpu", - mesh=self.rollout_info.rollout_engine_rank_mesh_array, - mesh_dim_names=("engine_instance", "engine_parallel"), - ) - - def set_train_rollout_mode(self, train_rollout_mode: TrainRolloutMode | str): - assert train_rollout_mode is not None, "update_rollout_info() must set train_rollout_mode." - - if self.rollout_info.backend is None: - raise RuntimeError("rollout backend is not set. Please set rollout backend in update_rollout_info().") - - mode = train_rollout_mode.lower() - if mode not in ("colocate", "disaggregated"): - raise ValueError( - f"Unsupported train_rollout_mode: {train_rollout_mode!r}. Expected 'colocate' or 'disaggregated'." - ) - mode = cast(TrainRolloutMode, mode) - self.rollout_info.train_rollout_mode = mode - if mode == "colocate": - self.rollout_info.transport_type = "ipc" - elif mode == "disaggregated": - self.rollout_info.transport_type = "nccl" - - backend = self.rollout_info.backend - if backend == "vllm" or backend == "turbomind": - raise NotImplementedError(f"Disaggregated train-rollout mode is not supported for {backend} backend.") - def update_weights(self): """Update the model weights.""" + assert self.rollout_info is not None, "bind_rollout_weight_update() must be called before update_weights()." assert self._transport is not None, ( f"Weight transport is not initialized. transport_type={self.rollout_info.transport_type!r}, " f"backend={self.rollout_info.backend!r}." ) + assert self.weight_iterator is not None, "Weight iterator is not initialized." self._transport.update(self.weight_iterator) def _set_transport(self) -> None: - if self.rollout_info.transport_type == "ipc": + rollout_info = self.rollout_info + assert rollout_info is not None, "bind_rollout_weight_update() must be called before setting transport." + if rollout_info.transport_type == "ipc": self._transport = IPCWeightTransport( rank=self.rank, logger=self.logger, config=self.config, - rollout_info=self.rollout_info, + rollout_info=rollout_info, ) - elif self.rollout_info.transport_type == "nccl": - self._transport = NCCLWeightTransport(rank=self.rank, logger=self.logger, rollout_info=self.rollout_info) + elif rollout_info.transport_type == "nccl": + self._transport = NCCLWeightTransport(rank=self.rank, logger=self.logger, rollout_info=rollout_info) else: raise NotImplementedError - def _build_transport_signature( - self, - *, - engine_rank_mesh_array: DeviceMeshRaw, - server_url_dict: ServiceUrlMap, - worker_server_urls_status: dict[str, bool], - train_rollout_mode: TrainRolloutMode, - backend: RolloutBackend, - tp: int, - ep: int, - ) -> tuple[Any, ...]: - mesh = tuple(tuple(int(rank) for rank in ranks) for ranks in engine_rank_mesh_array) - - active_urls = tuple( - sorted( - (int(rank), url) - for rank, url in server_url_dict.items() - if url and worker_server_urls_status.get(url, False) - ) - ) - - return ( - train_rollout_mode, - backend, - tp, - ep, - mesh, - active_urls, - ) - def _reset_transport(self) -> None: if self._transport is not None: self._transport.teardown() diff --git a/xtuner/v1/rl/weight_update/weight_iterator.py b/xtuner/v1/rl/weight_update/weight_iterator.py index 9e8c5b783..55f236ba1 100644 --- a/xtuner/v1/rl/weight_update/weight_iterator.py +++ b/xtuner/v1/rl/weight_update/weight_iterator.py @@ -38,7 +38,7 @@ def __init__( def iter_batch_groups(self): # Export path depends on rollout protocol: turbomind consumes layer-wise batches, # compose models update submodules in order, and plain models use HF-style batches. - if self.rollout_info.train_rollout_mode == "colocate" and self.rollout_info.backend == "turbomind": + if self.rollout_info.transport_type == "ipc" and self.rollout_info.backend == "turbomind": yield self.iter_layer_batches() return @@ -190,18 +190,20 @@ def iter_hf_batches(self, submodule=None, final_update=False): ) train_enable_ep = model.fsdp_config is not None and model.fsdp_config.ep_size > 1 - should_gather_train_ep_shards = self.rollout_info.train_rollout_mode == "disaggregated" and train_enable_ep + should_gather_train_ep_shards = self.rollout_info.transport_type == "nccl" and train_enable_ep if train_enable_ep: - if self.rollout_info.train_rollout_mode == "colocate" and self.rollout_info.ep > 1: - rollout_device_mesh = self.rollout_info.rollout_device_mesh - assert rollout_device_mesh is not None + if self.rollout_info.transport_type == "ipc" and self.rollout_info.ep > 1: + target_ep_rank = self.rollout_info.ipc_engine_parallel_rank + target_ep_size = self.rollout_info.ipc_engine_parallel_size + assert target_ep_rank is not None, "IPC rollout target for current train rank is not resolved." + assert target_ep_size is not None, "IPC rollout target size for current train rank is not resolved." # Colocated IPC can send only the expert slice needed by the local rollout # EP rank fused_gen = self._rl_get_fused_ep_hf_param( model, - target_ep_rank=rollout_device_mesh["engine_parallel"].get_coordinate()[0], - target_ep_size=rollout_device_mesh["engine_parallel"].size(), + target_ep_rank=target_ep_rank, + target_ep_size=target_ep_size, bucket_size=bucket_size, should_gather_train_ep_shards=should_gather_train_ep_shards, ) @@ -251,7 +253,6 @@ def iter_hf_batches(self, submodule=None, final_update=False): @torch.no_grad() def iter_layer_batches(self): """Update the model weights.""" - assert self.rollout_info.rollout_device_mesh is not None model = self._engine.model DEVICE_MODULE.empty_cache() diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index fdb49f3c7..57300d065 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -52,7 +52,7 @@ set_cpu_resource_manager, sort_rollout_state_for_deterministic, ) -from xtuner.v1.rl.weight_update.data import TrainRolloutMode +from xtuner.v1.rl.weight_update.data import WeightTransportType from xtuner.v1.train.trainer import LoadCheckpointConfig, XTunerMeta from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger, is_hf_model_path, set_deterministic, timer from xtuner.v1.utils.device import get_device, get_torch_device_module @@ -109,18 +109,20 @@ def check_fa3(): def bind_train_rollout( train_controller: TrainingController, rollout_controller: RolloutControllerProxy, - train_rollout_mode: TrainRolloutMode | str, + rollout_config: RolloutConfig, + weight_transport_type: WeightTransportType | str, weight_update_host: str | None = None, weight_update_port: int | None = None, ) -> None: """Bind the training and rollout workers for update weights.""" - info_dict = ray.get( - rollout_controller.get_rollout_metadata.remote(), # type: ignore[attr-defined] + targets = ray.get( + rollout_controller.get_weight_update_targets.remote(), # type: ignore[attr-defined] timeout=RL_TRAINER_RAY_GET_TIMEOUT, ) - train_controller.update_rollout_info( - info_dict, - train_rollout_mode=train_rollout_mode, + train_controller.bind_rollout_weight_update( + targets=targets, + rollout_config=rollout_config, + weight_transport_type=weight_transport_type, weight_update_host=weight_update_host, weight_update_port=weight_update_port, ) @@ -1561,7 +1563,8 @@ def __init__(self, cfg: RLColocateTrainerConfig): bind_train_rollout( train_controller=self.train_controller, rollout_controller=self.rollout_controller, - train_rollout_mode="colocate", + rollout_config=self._rollout_config, + weight_transport_type="ipc", ) replay_buffer = cfg.replay_buffer_config.build() @@ -1721,7 +1724,8 @@ def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict) -> bool bind_train_rollout( train_controller=self.train_controller, rollout_controller=self.rollout_controller, - train_rollout_mode="colocate", + rollout_config=self._rollout_config, + weight_transport_type="ipc", ) ray.get( self.rollout_controller.onload_weights.remote(), @@ -1768,7 +1772,8 @@ def __init__(self, cfg: RLDisaggregatedTrainerConfig): bind_train_rollout( train_controller=self.train_controller, rollout_controller=self.rollout_controller, - train_rollout_mode="disaggregated", + rollout_config=self._rollout_config, + weight_transport_type="nccl", weight_update_host=self._rollout_config.weight_update_host, weight_update_port=self._rollout_config.weight_update_port, ) @@ -1962,7 +1967,10 @@ async def _sync_weights_and_save(self, model_step: int, step_timer_dict: dict): bind_train_rollout( train_controller=self.train_controller, rollout_controller=self.rollout_controller, - train_rollout_mode="disaggregated", + rollout_config=self._rollout_config, + weight_transport_type="nccl", + weight_update_host=self._rollout_config.weight_update_host, + weight_update_port=self._rollout_config.weight_update_port, ) self.update_weights()