diff --git a/fastdeploy/rl/dynamic_weight_manager.py b/fastdeploy/rl/dynamic_weight_manager.py index 5a4666d46b5..978015e7d18 100644 --- a/fastdeploy/rl/dynamic_weight_manager.py +++ b/fastdeploy/rl/dynamic_weight_manager.py @@ -55,6 +55,7 @@ def __init__(self, fd_config: FDConfig, models, local_rank: int): self._capture_model_state() self.rdma_handle = None self.use_gdr_checkpoint_transfer = envs.FD_USE_GDR_CHECKPOINT_TRANSFER + self._gdr_ct_handle = None if self.use_gdr_checkpoint_transfer: self.update_weights_by_gdr() @@ -175,14 +176,8 @@ def update_weights_by_gdr( f"load_strategy:{self.load_config.load_strategy}, step_id:{step_id}" ) - from checkpoint_transfer.transfer import CheckpointTransfer - - transfer_config = self._build_ct_transfer_config(config) - logger.info(f"CheckpointTransfer config:{transfer_config}") - ct_handle = CheckpointTransfer(transfer_config) - total_start = time.perf_counter() - asyncio.run(ct_handle.initialize()) + ct_handle = self._ensure_gdr_handle(config) try: weights_iterator = ct_handle.receive_weights_sync(step_id=step_id, output_framework="paddle") @@ -192,8 +187,9 @@ def update_weights_by_gdr( paddle.empty(target_param.shape, dtype=target_param.dtype)._share_buffer_to(target_param) logger.debug(f"Restored cleared parameter storage before GDR checkpoint transfer load: {name}") update_count, mtp_cache_count = self._load_models_from_weight_iterator(weights_iterator) - finally: - asyncio.run(ct_handle.cleanup()) + except Exception: + self._destroy_gdr_handle() + raise self._capture_model_state(log_params=False) total_cost = time.perf_counter() - total_start logger.info( @@ -210,6 +206,32 @@ def update_weights_by_gdr( "mtp_cache_count": mtp_cache_count, } + def _ensure_gdr_handle(self, config: dict): + """Return the cached CheckpointTransfer handle; build and initialize it from `config` on first use only.""" + if self._gdr_ct_handle is not None: + return self._gdr_ct_handle + + transfer_config = self._build_ct_transfer_config(config) + logger.info(f"CheckpointTransfer 
config:{transfer_config}") + + from checkpoint_transfer.transfer import CheckpointTransfer + + ct_handle = CheckpointTransfer(transfer_config) + asyncio.run(ct_handle.initialize()) + + self._gdr_ct_handle = ct_handle + logger.info("[GDR] CheckpointTransfer initialized and cached for reuse") + return ct_handle + + def _destroy_gdr_handle(self): + """Clean up and drop the cached GDR handle (e.g. on error); cleanup failures are logged, not raised.""" + if self._gdr_ct_handle is not None: + try: + asyncio.run(self._gdr_ct_handle.cleanup()) + except Exception as exc: + logger.warning(f"[GDR] CheckpointTransfer cleanup failed; dropping cached handle: {exc}") + self._gdr_ct_handle = None + def _build_ct_transfer_config(self, config: dict): from dataclasses import fields diff --git a/tests/rl/test_dynamic_weight_gdr.py b/tests/rl/test_dynamic_weight_gdr.py index f6b1be0baad..ed9c8dfca2b 100644 --- a/tests/rl/test_dynamic_weight_gdr.py +++ b/tests/rl/test_dynamic_weight_gdr.py @@ -183,6 +183,7 @@ def _make_manager(rsync_config=None, load_strategy="rsync"): manager.model_list = [_FakeModel()] manager.state_dict = {} manager.use_gdr_checkpoint_transfer = True + manager._gdr_ct_handle = None return manager @@ -256,25 +257,33 @@ def test_update_weights_by_gdr_gdr_mode(self): class FakeCheckpointTransfer: def __init__(self, config): self.config = config + self.step_ids = [] created.append(self) def receive_weights_sync(self, step_id, output_framework="paddle"): - self.step_id = step_id + self.step_ids.append(step_id) self.output_framework = output_framework - yield "model.layers.0.weight", object() + yield f"model.layers.{len(self.step_ids)}.weight", object() manager = _make_manager() with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer): result = manager.update_weights_by_gdr(version="step-1") + second_result = manager.update_weights_by_gdr(version="step-2") self.assertEqual(result["version"], "step-1") + self.assertEqual(second_result["version"], "step-2") self.assertEqual(result["update_count"], 1) + self.assertEqual(second_result["update_count"], 1) self.assertIn("total_cost", result) - 
self.assertEqual(manager.model_list[0].loaded[0][0], "model.layers.0.weight") + self.assertEqual( + [name for name, _ in manager.model_list[0].loaded], ["model.layers.1.weight", "model.layers.2.weight"] + ) + self.assertEqual(len(created), 1) + self.assertIs(manager._gdr_ct_handle, created[0]) self.assertTrue(created[0].initialized) - self.assertTrue(created[0].cleaned) - self.assertEqual(created[0].step_id, "step-1") + self.assertFalse(hasattr(created[0], "cleaned")) + self.assertEqual(created[0].step_ids, ["step-1", "step-2"]) self.assertEqual(created[0].output_framework, "paddle") self.assertEqual(created[0].config.kwargs["role"], _FakeRole.INFERENCE) self.assertEqual(created[0].config.kwargs["phase1_backend"], _FakePhase1Backend.GPU_DIRECT) @@ -313,9 +322,11 @@ def receive_weights_sync(self, step_id, output_framework="paddle"): self.assertEqual(created[0].config.kwargs["qsize"], 2) def test_gdr_checkpoint_transfer_receive_exception_propagates(self): + created = [] + class FakeCheckpointTransfer: def __init__(self, config): - pass + created.append(self) def receive_weights_sync(self, step_id, output_framework="paddle"): yield "model.layers.0.weight", object() @@ -333,6 +344,9 @@ def load_weights(self, weights_iterator): with self.assertRaisesRegex(RuntimeError, "receive failed"): manager.update_weights_by_gdr(version="step-error") + self.assertTrue(created[0].cleaned) + self.assertIsNone(manager._gdr_ct_handle) + def test_gdr_checkpoint_transfer_refreshes_state_dict_after_model_loader(self): loaded_param = object()