From 772de6ab8e784cc532972f52e1a7dcdd3d8adda9 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Thu, 29 Jan 2026 18:54:43 +0800 Subject: [PATCH 1/3] optimize rollout-related ut execution time --- tests/ray/test_evaluator.py | 4 +- tests/ray/test_mock_rollout.py | 56 +++--- tests/ray/test_rl_trainer.py | 4 +- tests/ray/test_rollout.py | 347 +++++++++++++++++++------------- tests/ray/test_update_weight.py | 4 +- tests/ray/test_vl_rollout.py | 4 +- 6 files changed, 244 insertions(+), 175 deletions(-) diff --git a/tests/ray/test_evaluator.py b/tests/ray/test_evaluator.py index 7915d96b9..321070f87 100644 --- a/tests/ray/test_evaluator.py +++ b/tests/ray/test_evaluator.py @@ -22,10 +22,12 @@ class TestEvaluator(unittest.TestCase): @classmethod def setUpClass(cls) -> None: os.environ["XTUNER_USE_FA3"] = "1" - + os.environ["LMD_SKIP_WARMUP"] = "1" + @classmethod def tearDownClass(cls) -> None: del os.environ["XTUNER_USE_FA3"] + del os.environ["LMD_SKIP_WARMUP"] def init_config(self): self.resources_cfg = AcceleratorResourcesConfig( diff --git a/tests/ray/test_mock_rollout.py b/tests/ray/test_mock_rollout.py index 97c51ed00..c57ece907 100644 --- a/tests/ray/test_mock_rollout.py +++ b/tests/ray/test_mock_rollout.py @@ -1,4 +1,5 @@ import os +import asyncio import unittest import ray from transformers import AutoTokenizer @@ -47,7 +48,6 @@ class TestMockRollout(unittest.TestCase): def setUpClass(cls): os.environ["XTUNER_USE_FA3"] = "1" - @classmethod def tearDownClass(cls): del os.environ["XTUNER_USE_FA3"] @@ -61,15 +61,6 @@ def setUp(self): self.max_retry_times = 3 self.temp_dir = tempfile.TemporaryDirectory() self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - - self.resources_cfg = AcceleratorResourcesConfig( - accelerator=resource_map[torch.accelerator.current_accelerator().type], - num_workers=8, - num_cpus_per_worker=8, - cpu_memory_per_worker=16 * 1024**3, # 16 GB - ) - self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) - self.rollout_cfg = RolloutConfig( env="test_mock_rollout", model_path=MODEL_PATH, @@ -108,33 +99,40 @@ def tearDown(self): ray.shutdown() self.temp_dir.cleanup() - def _run_mock_test(self, mock_controller_cls, error_name: str): - rollout_controller = mock_controller_cls.remote(self.rollout_cfg, self.pg) - self.test_env = SingleTurnEnvironment.remote("env", self.pg, self.rollout_cfg, rollout_controller=rollout_controller) + async def _run_mock_test(self, mock_controller_cls, error_name, pg): + rollout_controller = mock_controller_cls.remote(self.rollout_cfg, pg) + self.test_env = SingleTurnEnvironment.remote("env", pg, self.rollout_cfg, rollout_controller=rollout_controller) self.test_dataflow = DataFlow.remote("dataflow", self.dataflow_cfg, self.replay_buffer_cfg, self.test_env) - completed_rollouts = ray.get(self.test_dataflow.run.remote(num=3))["data_groups"] - status = ray.get(self.test_dataflow.get_replaybuffer_status.remote()) + result = await self.test_dataflow.run.remote(num=3) + completed_rollouts = result["data_groups"] + status = await self.test_dataflow.get_replaybuffer_status.remote() self.assertEqual(len(completed_rollouts), 0, f"[{error_name}] Expected no rollouts to complete successfully.") self.assertEqual(status["remain_completed_samples_count"], 0, f"[{error_name}] Completed count in buffer should be 0.") self.assertEqual(status["remain_aborted_samples_count"], 0, f"[{error_name}] Expected no rollouts to be interrupted.") - ray.get(self.test_env.shutdown.remote()) + await self.test_env.shutdown.remote() 
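(Not part of the patch itself — a minimal standalone sketch of the Ray + asyncio pattern the rewritten _run_mock_test above relies on. Ray ObjectRefs are awaitable, so `await actor.method.remote()` yields to the event loop instead of blocking the driver the way `ray.get` does, which is what lets several mock-error scenarios share one event loop further down. Actor and function names here are made up for illustration.)

    import asyncio
    import ray

    @ray.remote
    class EchoActor:
        # Trivial stand-in for a rollout controller / environment actor.
        def ping(self, name: str) -> str:
            return f"pong:{name}"

    async def run_parallel_pings():
        actors = [EchoActor.remote() for _ in range(4)]
        # Awaiting the ObjectRefs suspends this coroutine rather than blocking
        # the driver thread, so the four actor calls overlap.
        results = await asyncio.gather(
            *[actor.ping.remote(f"actor-{i}") for i, actor in enumerate(actors)]
        )
        print(results)

    if __name__ == "__main__":
        ray.init(ignore_reinit_error=True)
        asyncio.run(run_parallel_pings())
        ray.shutdown()
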
@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_rollout_with_timeout_mock(self): - self._run_mock_test(MockTimeoutRolloutController, "timeout") + def test_parallel_mock_rollout(self): + async def run_parallel(): + res_cfg_small = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=2, + num_cpus_per_worker=2, + ) + + pgs = [AutoAcceleratorWorkers.build_placement_group(res_cfg_small, name=f"pg_{i}") for i in range(4)] + await asyncio.gather(*[pg.ready() for pg in pgs]) - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_rollout_with_request_error_mock(self): - self._run_mock_test(MockRequestErrorRolloutController, "request error") - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_rollout_with_client_error_mock(self): - self._run_mock_test(MockClientErrorRolloutController, "client error") - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_rollout_with_server_error_mock(self): - self._run_mock_test(MockServerErrorRolloutController, "server error") + tasks = [ + self._run_mock_test(MockTimeoutRolloutController, "timeout", pgs[0]), + self._run_mock_test(MockRequestErrorRolloutController, "request_error", pgs[1]), + self._run_mock_test(MockClientErrorRolloutController, "client_error", pgs[2]), + self._run_mock_test(MockServerErrorRolloutController, "server_error", pgs[3]), + ] + await asyncio.gather(*tasks) + + asyncio.run(run_parallel()) if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/tests/ray/test_rl_trainer.py b/tests/ray/test_rl_trainer.py index 930cf7b14..113c94fd8 100644 --- a/tests/ray/test_rl_trainer.py +++ b/tests/ray/test_rl_trainer.py @@ -38,11 +38,13 @@ class TestRLTrainer(unittest.TestCase): @classmethod def setUpClass(cls): os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" @classmethod def tearDownClass(cls): del os.environ["XTUNER_USE_FA3"] - + del os.environ["LMD_SKIP_WARMUP"] + def init_traine_worker_config(self, train_optimizer_steps, pack_max_length): model_cfg = get_model_config_from_hf(Path(MODEL_PATH)) optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False) diff --git a/tests/ray/test_rollout.py b/tests/ray/test_rollout.py index e03af21ac..72d4600a9 100644 --- a/tests/ray/test_rollout.py +++ b/tests/ray/test_rollout.py @@ -21,6 +21,7 @@ DataloaderConfig, DatasetConfig, ) +import asyncio TEST_TEXT_MESSAGES=[{"role": "user", "content": "Hello!"}] MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] @@ -36,11 +37,13 @@ class TestRollout(unittest.TestCase): @classmethod def setUpClass(cls) -> None: os.environ["XTUNER_USE_FA3"] = "1" - + os.environ["LMD_SKIP_WARMUP"] = "1" + @classmethod def tearDownClass(cls) -> None: del os.environ["XTUNER_USE_FA3"] - + del os.environ["LMD_SKIP_WARMUP"] + def init_config(self): self.resources_cfg = AcceleratorResourcesConfig( accelerator=resource_map[torch.accelerator.current_accelerator().type], @@ -50,20 +53,7 @@ def init_config(self): ) self.max_prompt_length = 512 self.max_response_length = 1024 - self.rollout_cfg = RolloutConfig( - env="test_rollout", - model_path=self.model_path, - model_name=os.path.basename(self.model_path).lower(), - tokenizer_path=self.model_path, - rollout_cross_node_comm=False, - 
tensor_parallel_size=1, - expert_parallel_size=1, - gpus_per_node=8, # gpu: 8, npu: 16 - dtype="bfloat16", - launch_server_method="ray", - context_length=self.max_prompt_length + self.max_response_length, - worker_log_dir=self.worker_log_dir, - ) + self.context_length = self.max_prompt_length + self.max_response_length from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k") self.judger_cfg = JudgerConfig( @@ -72,8 +62,8 @@ def init_config(self): ) self.dataflow_cfg = DataFlowConfig( env="test", - prompt_repeat_k=2, - global_batch_size=2, + prompt_repeat_k=1, + global_batch_size=1, enable_partial_rollout=0, max_retry_times=1, worker_log_dir=self.worker_log_dir, @@ -102,12 +92,10 @@ def init_config(self): def setUp(self): ray.init(num_cpus=80, ignore_reinit_error=True) self.data_path = TRAIN_DATA_PATH - self.model_path = MODEL_PATH self.temp_dir = tempfile.TemporaryDirectory() self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") self.init_config() - self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) def tearDown(self): ray.shutdown() @@ -117,6 +105,53 @@ def tearDown(self): self._cleanup_lmdeploy_ray_worker_wrapper() self.temp_dir.cleanup() + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") + def test_parallel_rollout(self): + resource_config = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=4, + num_cpus_per_worker=4, + cpu_memory_per_worker=8 * 1024**3, # 8 GB + ) + pg1 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="tp_pg") + pg2 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="ep_pg") + dense_model_path = MODEL_PATH + moe_model_path = MOE_MODEL_PATH + + async def run_both(): + return await asyncio.gather( + self._run_rollout(dense_model_path, 4, 1, pg1), # tp + self._run_rollout(moe_model_path, 1, 4, pg2), # tp + return_exceptions=False + ) + + asyncio.run(run_both()) + + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") + def test_parallel_model_save_and_resume(self): + resource_config = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=4, + num_cpus_per_worker=4, + cpu_memory_per_worker=8 * 1024**3, # 8 GB + ) + pg1 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="dense_pg") + pg2 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="moe_pg") + + async def run_both(): + return await asyncio.wait_for( + asyncio.gather( + self._run_dense_save_resume_sync_async(pg1), + self._run_moe_save_resume_with_r3(pg2), + return_exceptions=False + ), + timeout=300 + ) + try: + asyncio.run(run_both()) + except asyncio.TimeoutError: + self.fail("test_parallel_model_save_and_resume timed out after 300s") + def _cleanup_lmdeploy_ray_worker_wrapper(self): try: result = subprocess.run(["pkill", "-f", "ray::RayWorkerWrapper*"], capture_output=True, text=True, timeout=10) @@ -126,114 +161,70 @@ def _cleanup_lmdeploy_ray_worker_wrapper(self): except Exception as e: print(f"Error stopping ray::RayWorkerWrapper cluster: {e}") - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_lmdeploy_generate(self): - rollout_cfg = self.rollout_cfg.model_copy( - deep=True, - update=dict(tensor_parallel_size=2), - ) - 
rollout_cfg.model_post_init(None) - - sample_params = SampleParams(temperature=0.0) - rollout_controller = ray.remote(RolloutController).remote(rollout_cfg, self.pg) # type: ignore[attr-defined] - res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) - - self.assertEqual(res1.finish_reason, "stop") - print("Response from LMDeploy infer:", res1) - ray.get(rollout_controller.shutdown.remote(), timeout=300) - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_lmdeploy_dataflow(self): - rollout_cfg = self.rollout_cfg.model_copy( - deep=True, - update=dict( - expert_parallel_size=2, - model_path=self.model_path, - model_name=os.path.basename(MOE_MODEL_PATH).lower(), - tokenizer_path=MOE_MODEL_PATH, - ), - ) - rollout_cfg.model_post_init(None) + async def _run_rollout(self, model_path, tp_size, ep_size, pg): + rollout_config = RolloutConfig( + env="test_rollout", + model_path=model_path, + model_name=os.path.basename(model_path).lower(), + tokenizer_path=model_path, + tensor_parallel_size=tp_size, + expert_parallel_size=ep_size, + context_length=self.context_length, + worker_log_dir=self.worker_log_dir, + dist_port_base=38000, - self.dataflow_cfg.enable_partial_rollout = 0 - self.test_env = SingleTurnEnvironment.remote( - "test_env", - self.pg, - rollout_cfg=rollout_cfg, ) - self.test_flow = DataFlow.remote("test_env", - self.dataflow_cfg, - self.replay_buffer_cfg, - self.test_env - ) - responses = ray.get(self.test_flow.run.remote(), timeout=300)["data_groups"] - finished_samples_count = sum(1 for data in responses for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length") - self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size) - ray.get(self.test_env.shutdown.remote(), timeout=300) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + try: + result = await asyncio.wait_for(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES), timeout=300) + self.assertEqual(result.finish_reason, "stop") + except asyncio.TimeoutError: + self.fail("TP Rollout timed out!") + finally: + await asyncio.wait_for(rollout_controller.shutdown.remote(), timeout=300) - def _get_sorted_input_ids(self, responses): - """Helper to extract and sort input_ids from responses.""" - all_ids = [] - for data_items in responses: - for data_item in data_items: - all_ids.extend(data_item.data.input_ids) - all_ids.sort() - return all_ids - - def _run_dataflow_save_resume_test(self, rollout_cfg, dataflow_cfg): + async def _run_dataflow_save_resume_test(self, test_env, dataflow_cfg: DataFlowConfig, replay_buffer_cfg: ReplayBufferConfig): """ Generic driver for dataflow save/resume tests. """ # 1. Initialize Environment and DataFlow is_partial_rollout = dataflow_cfg.enable_partial_rollout == 1 - self.test_env = SingleTurnEnvironment.remote( - "test_env", - self.pg, - rollout_cfg=rollout_cfg, - ) - self.test_flow = DataFlow.remote( - "test_env", - dataflow_cfg, - self.replay_buffer_cfg, - self.test_env - ) + test_flow = DataFlow.remote("test_env", dataflow_cfg, replay_buffer_cfg, test_env) # 2. 
Initial Run - ray.get(self.test_flow.run.remote(), timeout=300) + await test_flow.run.remote() # Capture status before saving (critical for partial rollout consistency check) - rl_status_before_save = ray.get(self.test_flow.get_replaybuffer_status.remote()) + rl_status_before_save = await test_flow.get_replaybuffer_status.remote() # 3. Save save_dir = Path(self.temp_dir.name) / 'checkpoints' / f'ckpt-step-2' save_dir.mkdir(parents=True, exist_ok=True) - ray.get(self.test_flow.save.remote(save_dir)) + await test_flow.save.remote(save_dir) # Define run logic based on mode - def run_continuation(status_ref): + async def run_continuation(status_ref): if is_partial_rollout: remain = status_ref["remain_aborted_samples_count"] + status_ref["remain_completed_samples_count"] # Finish the remaining paused samples - return ray.get(self.test_flow.run.remote(num=remain, staleness_threshold=0), timeout=300)["data_groups"] + result = await test_flow.run.remote(num=remain, enable_partial_rollout=0) + return result["data_groups"] else: # Normal run - return ray.get(self.test_flow.run.remote(), timeout=300)["data_groups"] + result = await test_flow.run.remote() + return result["data_groups"] # continue running after save - responses_old = run_continuation(rl_status_before_save) - rb_status_old = ray.get(self.test_flow.get_replaybuffer_status.remote()) + responses_old = await run_continuation(rl_status_before_save) + rb_status_old = await test_flow.get_replaybuffer_status.remote() # resume from saved checkpoint - ray.get(self.test_flow.resume.remote(save_dir)) - rl_status_resume = ray.get(self.test_flow.get_replaybuffer_status.remote()) - responses_new = run_continuation(rl_status_resume) - rb_status_new = ray.get(self.test_flow.get_replaybuffer_status.remote()) - - # 6. Cleanup - ray.get(self.test_env.shutdown.remote(), timeout=300) + await test_flow.resume.remote(save_dir) + rl_status_resume = await test_flow.get_replaybuffer_status.remote() + responses_new = await run_continuation(rl_status_resume) + rb_status_new = await test_flow.get_replaybuffer_status.remote() - # 7. 
Assertions # Compare Data ids_old = self._get_sorted_input_ids(responses_old) ids_new = self._get_sorted_input_ids(responses_new) @@ -247,57 +238,110 @@ def run_continuation(status_ref): if is_partial_rollout: for key in rl_status_before_save: self.assertEqual(rl_status_before_save[key], rl_status_resume[key]) - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_lmdeploy_dataflow_save_resume(self): - rollout_cfg = self.rollout_cfg - dataflow_cfg = self.dataflow_cfg - dataflow_cfg.staleness_threshold = 0 - dataflow_cfg.enable_partial_rollout = 0 - self._run_dataflow_save_resume_test(rollout_cfg, dataflow_cfg) - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_lmdeploy_dataflow_save_resume_with_partial_rollout(self): - rollout_cfg = self.rollout_cfg - dataflow_cfg = self.dataflow_cfg - dataflow_cfg.staleness_threshold = 1 - dataflow_cfg.enable_partial_rollout = 1 - self._run_dataflow_save_resume_test(rollout_cfg, dataflow_cfg) - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_lmdeploy_dataflow_save_resume_with_partial_rollout_r3(self): - model_path = MOE_MODEL_PATH - rollout_cfg = RolloutConfig( + async def _run_dense_save_resume_sync_async(self, pg): + model_path = MODEL_PATH + worker_log_dir = os.path.join(self.worker_log_dir, "test_dense") + rollout_config = RolloutConfig( env="test_rollout", model_path=model_path, model_name=os.path.basename(model_path).lower(), tokenizer_path=model_path, - rollout_cross_node_comm=False, - tensor_parallel_size=1, - expert_parallel_size=1, - gpus_per_node=8, - dtype="bfloat16", - launch_server_method="ray", - context_length=self.max_prompt_length + self.max_response_length, - worker_log_dir=self.worker_log_dir, - enable_return_routed_experts=True, + context_length=self.context_length, + worker_log_dir=worker_log_dir, + dist_port_base=37000, + ) + test_env = SingleTurnEnvironment.remote( + "test_env", + pg, + rollout_cfg=rollout_config, + ) + sync_dataflow_cfg = DataFlowConfig( + env="test", + prompt_repeat_k=2, + global_batch_size=2, + enable_partial_rollout=0, + max_concurrent=2, + max_retry_times=1, + worker_log_dir=worker_log_dir, ) - dataflow_cfg = DataFlowConfig( + async_dataflow_cfg = DataFlowConfig( env="test", prompt_repeat_k=2, global_batch_size=2, enable_partial_rollout=1, staleness_threshold=1, + max_retry_times=1, worker_log_dir=self.worker_log_dir, ) - self._run_dataflow_save_resume_test(rollout_cfg, dataflow_cfg) + replay_buffer_cfg = ReplayBufferConfig( + dataset_cfg=self.train_dataset_cfg, + dataloader_cfg=self.dataloader_cfg, + tokenizer=self.tokenizer, + worker_log_dir=worker_log_dir, + ) + self._run_dataflow_save_resume_test(test_env, sync_dataflow_cfg, replay_buffer_cfg) + self._run_dataflow_save_resume_test(test_env, async_dataflow_cfg, replay_buffer_cfg) + + async def _run_moe_save_resume_with_r3(self, pg): + model_path = MOE_MODEL_PATH + worker_log_dir = os.path.join(self.worker_log_dir, "test_moe_r3") + rollout_config = RolloutConfig( + env="test_rollout", + model_path=model_path, + model_name=os.path.basename(model_path).lower(), + tokenizer_path=model_path, + expert_parallel_size=2, + context_length=self.context_length, + worker_log_dir=worker_log_dir, + dist_port_base=36000, + ) + test_env = SingleTurnEnvironment.remote( + "test_env", + pg, + rollout_cfg=rollout_config, + ) + async_dataflow_cfg = DataFlowConfig( + env="test", + 
prompt_repeat_k=2, + global_batch_size=2, + enable_partial_rollout=1, + max_concurrent=4, + max_retry_times=1, + worker_log_dir=worker_log_dir, + ) + replay_buffer_cfg = ReplayBufferConfig( + dataset_cfg=self.train_dataset_cfg, + dataloader_cfg=self.dataloader_cfg, + tokenizer=self.tokenizer, + worker_log_dir=worker_log_dir, + ) + self._run_dataflow_save_resume_test(test_env, async_dataflow_cfg, replay_buffer_cfg) + + def _get_sorted_input_ids(self, responses): + """Helper to extract and sort input_ids from responses.""" + all_ids = [] + for data_items in responses[0]: + for data_item in data_items: + all_ids.extend(data_item.data.input_ids) + all_ids.sort() + return all_ids @unittest.skip("skip lmdeploy turbomind generate test due to ci environment issue") def test_lmdeploy_turbomind_generate(self): from xtuner.v1.ray.rollout import LMDeployWorker - self.rollout_cfg.extra_rollout_config["lmdeploy_backend"] = "turbomind" + rollout_config = RolloutConfig( + env="test_rollout", + model_path=MODEL_PATH, + model_name=os.path.basename(MODEL_PATH).lower(), + tokenizer_path=MODEL_PATH, + context_length=self.context_length, + worker_log_dir=self.worker_log_dir, + extra_rollout_config={"lmdeploy_backend": "turbomind"}, + ) sample_params = SampleParams(temperature=0.0) - rollout_controller = ray.remote(RolloutController).remote(self.rollout_cfg, self.pg) # type: ignore[attr-defined] + pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) # type: ignore[attr-defined] res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) res2 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) self.assertEqual(res1, res2, f"res1 != res2, res1={res1}, res2={res2}") @@ -307,8 +351,18 @@ def test_lmdeploy_turbomind_generate(self): def test_sglang_generate(self): from xtuner.v1.ray.rollout import SGLangWorker self.rollout_cfg.launch_server_method="multiprocessing" + rollout_config = RolloutConfig( + env="test_rollout", + model_path=MODEL_PATH, + model_name=os.path.basename(MODEL_PATH).lower(), + tokenizer_path=MODEL_PATH, + context_length=self.context_length, + worker_log_dir=self.worker_log_dir, + launch_server_method="multiprocessing" + ) sample_params = SampleParams(temperature=0.0) - rollout_controller = ray.remote(RolloutController).remote(self.rollout_cfg, self.pg) # type: ignore[attr-defined] + pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) # type: ignore[attr-defined] res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) self.assertEqual(res1.finish_reason, "stop") print("Response from SGLang infer:", res1) @@ -317,21 +371,30 @@ def test_sglang_generate(self): @unittest.skipIf(os.environ.get("XTUNER_USE_SGLANG", "0") == "0", "lmdeploy backend is not enabled") def test_sglang_dataflow(self): self.dataflow_cfg.enable_partial_rollout = 0 - self.rollout_cfg.launch_server_method="multiprocessing" - self.test_env = SingleTurnEnvironment.remote( + rollout_config = RolloutConfig( + env="test_rollout", + model_path=MODEL_PATH, + model_name=os.path.basename(MODEL_PATH).lower(), + tokenizer_path=MODEL_PATH, + context_length=self.context_length, + worker_log_dir=self.worker_log_dir, + launch_server_method="multiprocessing" + ) + pg = 
AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) + test_env = SingleTurnEnvironment.remote( "test_env", - self.pg, - rollout_cfg=self.rollout_cfg, + pg, + rollout_cfg=rollout_config, ) - self.test_flow = DataFlow.remote("test_env", + test_flow = DataFlow.remote("test_env", self.dataflow_cfg, self.replay_buffer_cfg, - self.test_env + test_env ) - responses = ray.get(self.test_flow.run.remote(), timeout=300)["data_groups"] - finished_samples_count = sum(1 for data in responses for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length") + responses = ray.get(test_flow.run.remote(), timeout=300)["data_groups"] + finished_samples_count = sum(1 for data in responses[0] for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length") self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size) - ray.get(self.test_env.shutdown.remote(), timeout=300) + ray.get(test_env.shutdown.remote(), timeout=300) print("responses: ", responses) if __name__ == "__main__": diff --git a/tests/ray/test_update_weight.py b/tests/ray/test_update_weight.py index dfc3668ba..a8110039b 100644 --- a/tests/ray/test_update_weight.py +++ b/tests/ray/test_update_weight.py @@ -25,11 +25,13 @@ class TestUpdateWeight(unittest.TestCase): @classmethod def setUpClass(cls) -> None: os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" @classmethod def tearDownClass(cls) -> None: del os.environ["XTUNER_USE_FA3"] - + del os.environ["LMD_SKIP_WARMUP"] + def setUp(self): ray.init(num_cpus=80, ignore_reinit_error=True) self.model_path = MODEL_PATH diff --git a/tests/ray/test_vl_rollout.py b/tests/ray/test_vl_rollout.py index 4dc03db96..81621e9d2 100644 --- a/tests/ray/test_vl_rollout.py +++ b/tests/ray/test_vl_rollout.py @@ -35,11 +35,13 @@ class TestRollout(unittest.TestCase): @classmethod def setUpClass(cls) -> None: os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" @classmethod def tearDownClass(cls) -> None: del os.environ["XTUNER_USE_FA3"] - + del os.environ["LMD_SKIP_WARMUP"] + def init_config(self): self.resources_cfg = AcceleratorResourcesConfig( accelerator=resource_map[torch.accelerator.current_accelerator().type], From 4b63e139b0733bebf90a37d70179e8d82344d38d Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Tue, 3 Feb 2026 12:11:04 +0800 Subject: [PATCH 2/3] delete llm update_weights and optim vl update_weights --- tests/ray/test_rollout.py | 10 +- tests/ray/test_update_weight.py | 45 +++------ tests/ray/test_vl_update_weight.py | 155 ----------------------------- 3 files changed, 21 insertions(+), 189 deletions(-) delete mode 100644 tests/ray/test_vl_update_weight.py diff --git a/tests/ray/test_rollout.py b/tests/ray/test_rollout.py index 72d4600a9..203a4f469 100644 --- a/tests/ray/test_rollout.py +++ b/tests/ray/test_rollout.py @@ -117,11 +117,11 @@ def test_parallel_rollout(self): pg2 = AutoAcceleratorWorkers.build_placement_group(resource_config, name="ep_pg") dense_model_path = MODEL_PATH moe_model_path = MOE_MODEL_PATH - + dist_port_base = 38000 async def run_both(): return await asyncio.gather( - self._run_rollout(dense_model_path, 4, 1, pg1), # tp - self._run_rollout(moe_model_path, 1, 4, pg2), # tp + self._run_rollout(model_path=dense_model_path, tp_size=4, ep_size=1, pg=pg1, dist_port_base=dist_port_base), + self._run_rollout(model_path=moe_model_path, tp_size=1, ep_size=4, pg=pg2, dist_port_base=dist_port_base + 1024 * 
4), return_exceptions=False ) @@ -161,7 +161,7 @@ def _cleanup_lmdeploy_ray_worker_wrapper(self): except Exception as e: print(f"Error stopping ray::RayWorkerWrapper cluster: {e}") - async def _run_rollout(self, model_path, tp_size, ep_size, pg): + async def _run_rollout(self, model_path, tp_size, ep_size, pg, dist_port_base): rollout_config = RolloutConfig( env="test_rollout", model_path=model_path, @@ -171,7 +171,7 @@ async def _run_rollout(self, model_path, tp_size, ep_size, pg): expert_parallel_size=ep_size, context_length=self.context_length, worker_log_dir=self.worker_log_dir, - dist_port_base=38000, + dist_port_base=dist_port_base, ) rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) diff --git a/tests/ray/test_update_weight.py b/tests/ray/test_update_weight.py index a8110039b..fa008d3d7 100644 --- a/tests/ray/test_update_weight.py +++ b/tests/ray/test_update_weight.py @@ -3,7 +3,6 @@ import tempfile import ray -from xtuner.v1.ray.base import AutoAcceleratorWorkers from xtuner.v1.ray.rollout import RolloutController from xtuner.v1.data_proto.rl_data import SampleParams from xtuner.v1.config import ( @@ -11,27 +10,25 @@ FSDPConfig, LRConfig, ) -from xtuner.v1.model.moe.moe import BalancingLossConfig, ZLossConfig from xtuner.v1.ray.config.worker import RolloutConfig from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers from xtuner.v1.rl.base import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker from xtuner.v1.rl.grpo.loss import GRPOLossConfig as LossConfig -from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.model.compose.qwen3_vl import Qwen3VLDense4BConfig + +TEST_TEXT_MESSAGES = [{"role": "user", "content": "Hello!"}] +MODEL_PATH = os.environ["QWEN3_VL_DENSE_PATH"] -TEST_TEXT_MESSAGES=[{"role": "user", "content": "Hello!"}] -MODEL_PATH = os.environ["QWEN3_MOE_PATH"] class TestUpdateWeight(unittest.TestCase): @classmethod def setUpClass(cls) -> None: os.environ["XTUNER_USE_FA3"] = "1" - os.environ["LMD_SKIP_WARMUP"] = "1" @classmethod def tearDownClass(cls) -> None: del os.environ["XTUNER_USE_FA3"] - del os.environ["LMD_SKIP_WARMUP"] - + def setUp(self): ray.init(num_cpus=80, ignore_reinit_error=True) self.model_path = MODEL_PATH @@ -59,7 +56,7 @@ def init_config(self): rollout_cross_node_comm=False, tensor_parallel_size=4, expert_parallel_size=1, - gpus_per_node=8, # gpu: 8, npu: 16 + gpus_per_node=8, # gpu: 8, npu: 16 dtype="bfloat16", skip_load_weights=True, context_length=256, @@ -68,14 +65,9 @@ def init_config(self): ) # training config - model_cfg = get_model_config_from_hf(model_path=MODEL_PATH) - if hasattr(model_cfg, 'z_loss_cfg'): - model_cfg.z_loss_cfg = ZLossConfig() - if hasattr(model_cfg, 'balancing_loss_cfg'): - model_cfg.balancing_loss_cfg = BalancingLossConfig() + model_cfg = Qwen3VLDense4BConfig() optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False) fsdp_cfg: FSDPConfig = FSDPConfig(ep_size=4) - model_cfg.ep_size = fsdp_cfg.ep_size lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7) self.worker_cfg: WorkerConfig = WorkerConfig( model_cfg=model_cfg, @@ -122,37 +114,32 @@ def test_lmdeploy_update_weight_and_generate(self): # fixed sample params sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1) - # init rollout_update + # init rollout_controller and rollout baseline + self.rollout_cfg.skip_load_weights = False rollout_controller = ray.remote(RolloutController).remote( self.rollout_cfg, self.pg, ) + + res_baseline = 
ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) + + # start update weight test info_dict = ray.get(rollout_controller.get_rollout_info.remote()) ray.get(train_controller.update_rollout_info.remote(info_dict)) # update weights ray.get(rollout_controller.offload.remote()) - ray.get(rollout_controller.onload_weights.remote()) + ray.get(train_controller.onload.remote(target="all")) ray.get(train_controller.offload.remote(["optimizer"])) + ray.get(rollout_controller.onload_weights.remote()) ray.get(train_controller.update_weights.remote()) ray.get(train_controller.offload.remote(["model"])) ray.get(rollout_controller.onload_kvcache.remote()) res_update_weight = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) + self.assertEqual(res_update_weight.response, res_baseline.response) ray.get(rollout_controller.shutdown.remote(), timeout=60) - # init rollout_ref - self.rollout_cfg.skip_load_weights = False - rollout_controller_ref = ray.remote(RolloutController).remote( - self.rollout_cfg, - self.pg, - ) - - res_ref = ray.get(rollout_controller_ref.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) - ray.get(rollout_controller_ref.shutdown.remote(), timeout=60) - - self.assertEqual(res_update_weight.response, res_ref.response) - if __name__ == "__main__": test_instance = TestUpdateWeight() diff --git a/tests/ray/test_vl_update_weight.py b/tests/ray/test_vl_update_weight.py deleted file mode 100644 index bf0718a98..000000000 --- a/tests/ray/test_vl_update_weight.py +++ /dev/null @@ -1,155 +0,0 @@ -import os -import unittest -import tempfile -import ray - -from xtuner.v1.ray.rollout import RolloutController -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.ray.config.worker import RolloutConfig -from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.rl.base import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker -from xtuner.v1.rl.grpo.loss import GRPOLossConfig as LossConfig -from xtuner.v1.model.compose.qwen3_vl import Qwen3VLDense4BConfig - -TEST_TEXT_MESSAGES = [{"role": "user", "content": "Hello!"}] -MODEL_PATH = os.environ["QWEN3_VL_DENSE_PATH"] - - -class TestUpdateWeight(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - os.environ["XTUNER_USE_FA3"] = "1" - - @classmethod - def tearDownClass(cls) -> None: - del os.environ["XTUNER_USE_FA3"] - - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - self.model_path = MODEL_PATH - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - self.init_config() - self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) - - def tearDown(self): - ray.shutdown() - self.temp_dir.cleanup() - - def init_config(self): - self.resources_cfg = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=4, - num_cpus_per_worker=12, - cpu_memory_per_worker=16 * 1024 ** 3, # 16 GB - ) - self.rollout_cfg = RolloutConfig( - env="test_rollout", - model_path=MODEL_PATH, - model_name=os.path.basename(MODEL_PATH).lower(), - tokenizer_path=MODEL_PATH, - rollout_cross_node_comm=False, - tensor_parallel_size=4, - expert_parallel_size=1, - gpus_per_node=8, # gpu: 8, npu: 16 - dtype="bfloat16", - skip_load_weights=True, - context_length=256, - worker_log_dir=self.worker_log_dir, - 
gpu_memory_utilization=0.5, - ) - - # training config - model_cfg = Qwen3VLDense4BConfig() - optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False) - fsdp_cfg: FSDPConfig = FSDPConfig(ep_size=4) - lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7) - self.worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - optim_cfg=optim_cfg, - loss_cfg=LossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=False, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="eager"), - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - load_from=MODEL_PATH, - sp_size=1, - pack_max_length=1024, - ) - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_lmdeploy_update_weight_and_generate(self): - # init train - TrainingWorker = ray.remote( - runtime_env={ - "env_vars": { - "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", - "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", - } - }, - )(BaseTrainingWorker) - train_workers, _ = AutoAcceleratorWorkers.from_placement_group( - TrainingWorker, self.worker_cfg, self.pg - ) - futures = [ worker.test_all_reduce.remote() for worker in train_workers ] - ray.get(futures) - train_controller = TrainingController.remote( - workers=train_workers, - ) - ray.get(train_controller.__ray_ready__.remote()) - - # fixed sample params - sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1) - - # init rollout_update - rollout_controller = ray.remote(RolloutController).remote( - self.rollout_cfg, - self.pg, - ) - info_dict = ray.get(rollout_controller.get_rollout_info.remote()) - ray.get(train_controller.update_rollout_info.remote(info_dict)) - - # update weights - ray.get(rollout_controller.offload.remote()) - ray.get(rollout_controller.onload_weights.remote()) - ray.get(train_controller.offload.remote(["optimizer"])) - ray.get(train_controller.update_weights.remote()) - ray.get(train_controller.offload.remote(["model"])) - ray.get(rollout_controller.onload_kvcache.remote()) - - res_update_weight = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) - ray.get(rollout_controller.shutdown.remote(), timeout=60) - - # init rollout_ref - self.rollout_cfg.skip_load_weights = False - rollout_controller_ref = ray.remote(RolloutController).remote( - self.rollout_cfg, - self.pg, - ) - - res_ref = ray.get(rollout_controller_ref.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) - ray.get(rollout_controller_ref.shutdown.remote(), timeout=60) - - self.assertEqual(res_update_weight.response, res_ref.response) - - -if __name__ == "__main__": - test_instance = TestUpdateWeight() - test_instance.setUp() - try: - test_instance.test_lmdeploy_update_weight_and_generate() - finally: - test_instance.tearDown() From 15d267bf6d939d94be25f0d5bccabdec99ba53c6 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Tue, 3 Feb 2026 14:55:16 +0800 Subject: [PATCH 3/3] fix sglang ut typo --- tests/ray/test_rollout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ray/test_rollout.py b/tests/ray/test_rollout.py index 203a4f469..31f8542d3 100644 --- a/tests/ray/test_rollout.py +++ b/tests/ray/test_rollout.py @@ -392,7 +392,7 @@ def test_sglang_dataflow(self): test_env ) responses = ray.get(test_flow.run.remote(), timeout=300)["data_groups"] - finished_samples_count = sum(1 for data in responses[0] for item in data if 
item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length") + finished_samples_count = sum(1 for data in responses for item in data if item.env.rollout.finish_reason == "stop" or item.env.rollout.finish_reason == "length") self.assertEqual(finished_samples_count // self.dataflow_cfg.prompt_repeat_k, self.dataflow_cfg.global_batch_size) ray.get(test_env.shutdown.remote(), timeout=300) print("responses: ", responses)
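
(Not from the patch series — a small sketch, using a hypothetical helper name, of the port-spacing idea behind the `dist_port_base` / `dist_port_base + 1024 * 4` split introduced for test_parallel_rollout in PATCH 2: rollout engines launched concurrently on the same node must draw their distributed-init ports from disjoint ranges, otherwise the second engine can fail to bind.)

    # Hypothetical helper (not part of xtuner): space out dist_port_base values so
    # that concurrently launched rollout engines never bind overlapping port ranges.
    def allocate_port_bases(base: int, num_engines: int, ports_per_engine: int) -> list[int]:
        return [base + i * ports_per_engine for i in range(num_engines)]

    # Example: allocate_port_bases(38000, 2, 1024 * 4) -> [38000, 42096], matching
    # the two _run_rollout calls in test_parallel_rollout above.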