Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions fastdeploy/rl/dynamic_weight_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self, fd_config: FDConfig, models, local_rank: int):
self._capture_model_state()
self.rdma_handle = None
self.use_gdr_checkpoint_transfer = envs.FD_USE_GDR_CHECKPOINT_TRANSFER
self._gdr_ct_handle = None

if self.use_gdr_checkpoint_transfer:
self.update_weights_by_gdr()
Expand Down Expand Up @@ -175,14 +176,8 @@ def update_weights_by_gdr(
f"load_strategy:{self.load_config.load_strategy}, step_id:{step_id}"
)

from checkpoint_transfer.transfer import CheckpointTransfer

transfer_config = self._build_ct_transfer_config(config)
logger.info(f"CheckpointTransfer config:{transfer_config}")
ct_handle = CheckpointTransfer(transfer_config)

total_start = time.perf_counter()
asyncio.run(ct_handle.initialize())
ct_handle = self._ensure_gdr_handle(config)
try:
weights_iterator = ct_handle.receive_weights_sync(step_id=step_id, output_framework="paddle")

Expand All @@ -192,8 +187,9 @@ def update_weights_by_gdr(
paddle.empty(target_param.shape, dtype=target_param.dtype)._share_buffer_to(target_param)
logger.debug(f"Restored cleared parameter storage before GDR checkpoint transfer load: {name}")
update_count, mtp_cache_count = self._load_models_from_weight_iterator(weights_iterator)
finally:
asyncio.run(ct_handle.cleanup())
except Exception:
self._destroy_gdr_handle()
raise
self._capture_model_state(log_params=False)
total_cost = time.perf_counter() - total_start
logger.info(
Expand All @@ -210,6 +206,32 @@ def update_weights_by_gdr(
"mtp_cache_count": mtp_cache_count,
}

def _ensure_gdr_handle(self, config: dict):
"""Lazily create and initialize the CheckpointTransfer handle (once)."""
if self._gdr_ct_handle is not None:

This comment was marked as outdated.

return self._gdr_ct_handle

transfer_config = self._build_ct_transfer_config(config)
logger.info(f"CheckpointTransfer config:{transfer_config}")

from checkpoint_transfer.transfer import CheckpointTransfer

ct_handle = CheckpointTransfer(transfer_config)
asyncio.run(ct_handle.initialize())

self._gdr_ct_handle = ct_handle

This comment was marked as outdated.

logger.info("[GDR] CheckpointTransfer initialized and cached for reuse")
return ct_handle

def _destroy_gdr_handle(self):
"""Destroy the cached GDR handle (e.g. on error)."""
if self._gdr_ct_handle is not None:
try:
asyncio.run(self._gdr_ct_handle.cleanup())
except Exception:

This comment was marked as outdated.

pass
self._gdr_ct_handle = None

def _build_ct_transfer_config(self, config: dict):
from dataclasses import fields

Expand Down
26 changes: 20 additions & 6 deletions tests/rl/test_dynamic_weight_gdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def _make_manager(rsync_config=None, load_strategy="rsync"):
manager.model_list = [_FakeModel()]
manager.state_dict = {}
manager.use_gdr_checkpoint_transfer = True
manager._gdr_ct_handle = None
return manager


Expand Down Expand Up @@ -256,25 +257,33 @@ def test_update_weights_by_gdr_gdr_mode(self):
class FakeCheckpointTransfer:
def __init__(self, config):
self.config = config
self.step_ids = []
created.append(self)

def receive_weights_sync(self, step_id, output_framework="paddle"):
self.step_id = step_id
self.step_ids.append(step_id)
self.output_framework = output_framework
yield "model.layers.0.weight", object()
yield f"model.layers.{len(self.step_ids)}.weight", object()

manager = _make_manager()

with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer):
result = manager.update_weights_by_gdr(version="step-1")
second_result = manager.update_weights_by_gdr(version="step-2")

self.assertEqual(result["version"], "step-1")
self.assertEqual(second_result["version"], "step-2")
self.assertEqual(result["update_count"], 1)
self.assertEqual(second_result["update_count"], 1)
self.assertIn("total_cost", result)
self.assertEqual(manager.model_list[0].loaded[0][0], "model.layers.0.weight")
self.assertEqual(
[name for name, _ in manager.model_list[0].loaded], ["model.layers.1.weight", "model.layers.2.weight"]
)
self.assertEqual(len(created), 1)
self.assertIs(manager._gdr_ct_handle, created[0])
self.assertTrue(created[0].initialized)
self.assertTrue(created[0].cleaned)
self.assertEqual(created[0].step_id, "step-1")
self.assertFalse(hasattr(created[0], "cleaned"))
self.assertEqual(created[0].step_ids, ["step-1", "step-2"])
self.assertEqual(created[0].output_framework, "paddle")
self.assertEqual(created[0].config.kwargs["role"], _FakeRole.INFERENCE)
self.assertEqual(created[0].config.kwargs["phase1_backend"], _FakePhase1Backend.GPU_DIRECT)
Expand Down Expand Up @@ -313,9 +322,11 @@ def receive_weights_sync(self, step_id, output_framework="paddle"):
self.assertEqual(created[0].config.kwargs["qsize"], 2)

def test_gdr_checkpoint_transfer_receive_exception_propagates(self):
created = []

class FakeCheckpointTransfer:
def __init__(self, config):
pass
created.append(self)

def receive_weights_sync(self, step_id, output_framework="paddle"):
yield "model.layers.0.weight", object()
Expand All @@ -333,6 +344,9 @@ def load_weights(self, weights_iterator):
with self.assertRaisesRegex(RuntimeError, "receive failed"):
manager.update_weights_by_gdr(version="step-error")

self.assertTrue(created[0].cleaned)
self.assertIsNone(manager._gdr_ct_handle)

def test_gdr_checkpoint_transfer_refreshes_state_dict_after_model_loader(self):
loaded_param = object()

Expand Down
Loading