Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions xtuner/v1/rl/trainer/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,12 @@ def onload(self, target: Literal["model", "optimizer", "all"] = "all"):
ray.get([worker.onload_optimizer.remote() for worker in self.workers], timeout=TRAIN_RAY_GET_TIMEOUT) # type: ignore
return

def update_rollout_info(self, info_dict, train_rollout_mode, weight_update_host=None, weight_update_port=None):
def update_rollout_info(self, info_dict, weight_update_mode, weight_update_host=None, weight_update_port=None):
ray.get(
[
worker.update_rollout_info.remote(
**info_dict,
train_rollout_mode=train_rollout_mode,
weight_update_mode=weight_update_mode,
weight_update_host=weight_update_host,
weight_update_port=weight_update_port,
)
Expand Down
12 changes: 10 additions & 2 deletions xtuner/v1/rl/weight_update/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
from .data import (
DeviceMeshRaw,
DiskUpdateUpstreamTransport,
RolloutBackend,
RolloutEngineInfo,
RolloutWeightUpdateInfo,
ServiceUrlMap,
TrainRolloutMode,
WeightTransportType,
WeightUpdateBatch,
)
from .transport import (
DiskBackendAdapter,
DiskWeightTransport,
IPCBackendAdapter,
IPCWeightTransport,
LMDeployDiskBackendAdapter,
LMDeployIPCBackendAdapter,
NCCLBackendAdapter,
NCCLWeightTransport,
SGLangDiskBackendAdapter,
SGLangIPCBackendAdapter,
SGLangNCCLBackendAdapter,
WeightTransport,
Expand All @@ -24,19 +28,23 @@


__all__ = [
"DiskBackendAdapter",
"DiskUpdateUpstreamTransport",
"DiskWeightTransport",
"DeviceMeshRaw",
"IPCBackendAdapter",
"IPCWeightTransport",
"LMDeployDiskBackendAdapter",
"LMDeployIPCBackendAdapter",
"NCCLBackendAdapter",
"NCCLWeightTransport",
"RolloutBackend",
"RolloutEngineInfo",
"RolloutWeightUpdateInfo",
"SGLangDiskBackendAdapter",
"SGLangIPCBackendAdapter",
"SGLangNCCLBackendAdapter",
"ServiceUrlMap",
"TrainRolloutMode",
"UpdateWeighter",
"WeightIterator",
"WeightTransportType",
Expand Down
9 changes: 6 additions & 3 deletions xtuner/v1/rl/weight_update/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices.
ServiceUrlMap: TypeAlias = Dict[int, str] # A dictionary mapping rollout ranks to their server URLs.
RolloutEngineInfo: TypeAlias = list[tuple[int, str, int]] # (rollout rank, server url, engine gpu count)
TrainRolloutMode: TypeAlias = Literal["colocate", "disaggregated"] # Train and rollout deployment mode.
RolloutBackend: TypeAlias = Literal["sglang", "vllm", "pytorch", "turbomind"] # Rollout inference backend.
WeightTransportType: TypeAlias = Literal["ipc", "nccl"] # Supported weight transport types.
WeightTransportType: TypeAlias = Literal["ipc", "nccl", "disk"] # Supported weight transport types.
DiskUpdateUpstreamTransport: TypeAlias = Literal["ipc", "nccl"] # How disk-loaded weights are delivered to rollout.


@dataclass
Expand All @@ -23,7 +23,6 @@ class RolloutWeightUpdateInfo:
backend: RolloutBackend | None = None
tp: int = 1
ep: int = 1
train_rollout_mode: TrainRolloutMode | None = None
transport_type: WeightTransportType | None = None
rollout_cfg_info: dict = field(default_factory=dict)
endpoints: dict[str, str] = field(default_factory=lambda: {"update_weights": "update_weights"})
Expand All @@ -38,6 +37,10 @@ class RolloutWeightUpdateInfo:
weight_update_host: str | None = None
weight_update_port: int | None = None

# Disk update metadata.
hf_weight_path: str | None = None
disk_update_upstream_transport: DiskUpdateUpstreamTransport | None = None


@dataclass
class WeightUpdateBatch:
Expand Down
155 changes: 154 additions & 1 deletion xtuner/v1/rl/weight_update/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ def ensure_nccl_weight_update_group(self):

def send(self, batch: WeightUpdateBatch) -> None:
state_dict = batch.state_dict
if not state_dict:
if not state_dict and not batch.finished:
return

train_sync_group = self.get_train_update_sync_group()
Expand Down Expand Up @@ -860,3 +860,156 @@ def teardown(self) -> None:
self.group_name = None
self.engine_urls = []
self.external_group_world_size = None


class DiskBackendAdapter:
    """Base interface for backend-specific disk weight-update adapters.

    Subclasses are expected to override ``update``; ``teardown`` is a no-op
    here so subclasses without resources need not override it.
    """

    def update(self, weight_iterator: Any) -> None:
        """Apply a weight update; the base class always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def teardown(self) -> None:
        """Release adapter-held resources. The base adapter holds none."""
        return None


class SGLangDiskBackendAdapter(DiskBackendAdapter):
    """Disk weight-update adapter for the SGLang rollout backend.

    Instead of streaming tensors, it posts an ``update_weights_from_disk``
    request carrying ``rollout_info.hf_weight_path`` to every distinct rollout
    server URL. Only the process whose ``dist.get_rank()`` is 0 sends the
    requests; all other ranks just wait on a barrier.
    """

    def __init__(self, *, rank: int, rollout_info: RolloutWeightUpdateInfo):
        self.rank = rank
        self.rollout_info = rollout_info
        # Created on demand inside update() and released by teardown().
        self.executor: ThreadPoolExecutor | None = None

    def build_request(self, hf_weight_path: str) -> WeightUpdateRequest:
        """Build the ``update_weights_from_disk`` request for ``hf_weight_path``."""
        # SGLang already owns the disk reload path. XTuner only needs to pass
        # the HF checkpoint directory to the rollout server.
        return WeightUpdateRequest(
            endpoint="update_weights_from_disk",
            body={
                "model_path": hf_weight_path,
                "load_format": "safetensors",
                "abort_all_requests": True,
                "flush_cache": True,
            },
        )

    def update(self, weight_iterator: Any) -> None:
        """Ask every rollout server to reload its weights from disk.

        Args:
            weight_iterator: Ignored; SGLang reads the checkpoint itself.

        Raises:
            RuntimeError: If ``rollout_info.hf_weight_path`` is unset, if no
                rollout server URL is available, or if any server response
                reports ``success`` as falsy.
        """
        # SGLang consumes the checkpoint path on the rollout server side.
        del weight_iterator

        hf_weight_path = self.rollout_info.hf_weight_path
        if not hf_weight_path:
            raise RuntimeError("Disk weight update requires rollout_info.hf_weight_path from rollout_config.")

        try:
            if dist.get_rank() != 0:
                # Non-zero ranks only wait for rank 0 to finish the HTTP round-trips.
                dist.barrier()
                return

            # dict.fromkeys drops duplicate URLs while keeping first-seen order;
            # empty/None URLs are filtered out.
            target_urls = list(dict.fromkeys(url for url in self.rollout_info.rollout_server_url_dict.values() if url))
            if not target_urls:
                raise RuntimeError("Disk weight update requires at least one rollout server url.")
            request = self.build_request(hf_weight_path)
            # One worker per server so the requests are issued concurrently.
            self.executor = ThreadPoolExecutor(max_workers=max(1, len(target_urls)))
            futures = [
                self.executor.submit(
                    WeightTransport.post_json,
                    url,
                    request.endpoint,
                    request.body,
                    api_key=self.rollout_info.api_key,
                )
                for url in target_urls
            ]
            for future in futures:
                result = future.result()
                # Raise explicitly rather than assert: asserts are stripped under
                # ``python -O``, which would let a failed update pass silently.
                if not result.get("success", True):
                    raise RuntimeError(f"disk weight update failed: {result.get('message', result)}")
            # NOTE(review): if rank 0 raises above, it never reaches this barrier
            # while the other ranks are already waiting in theirs -- confirm the
            # process-group timeout / job supervisor handles that case.
            dist.barrier()
        finally:
            self.teardown()
            DEVICE_MODULE.empty_cache()

    def teardown(self) -> None:
        """Shut down the request executor, if any, without waiting for pending work."""
        if self.executor is not None:
            self.executor.shutdown(wait=False, cancel_futures=True)
            self.executor = None


class LMDeployDiskBackendAdapter(DiskBackendAdapter):
    """Disk weight-update adapter that delegates to an IPC or NCCL batch transport.

    The adapter owns a single underlying ``WeightTransport`` chosen from
    ``upstream_transport`` (``"ipc"`` or ``"nccl"``) and forwards ``update`` and
    ``teardown`` to it.
    """

    def __init__(
        self,
        *,
        rank: int,
        logger: Any,
        rollout_info: RolloutWeightUpdateInfo,
        config: Any | None,
        upstream_transport: str,
    ):
        # Must be set first: _build_batch_transport() reads it.
        self.upstream_transport = upstream_transport
        self._batch_transport = self._build_batch_transport(
            rank=rank, logger=logger, rollout_info=rollout_info, config=config
        )

    def _build_batch_transport(
        self,
        *,
        rank: int,
        logger: Any,
        rollout_info: RolloutWeightUpdateInfo,
        config: Any | None,
    ) -> WeightTransport:
        """Instantiate the transport named by ``self.upstream_transport``.

        Raises:
            ValueError: If the upstream transport is neither ``"ipc"`` nor ``"nccl"``.
        """
        kind = self.upstream_transport
        if kind == "nccl":
            # The NCCL transport is constructed without ``config``.
            return NCCLWeightTransport(rank=rank, logger=logger, rollout_info=rollout_info)
        if kind == "ipc":
            return IPCWeightTransport(rank=rank, logger=logger, config=config, rollout_info=rollout_info)
        raise ValueError(f"Unsupported disk weight update upstream transport: {kind!r}")

    def update(self, weight_iterator: Any) -> None:
        """Forward the weight iterator to the underlying batch transport."""
        # WeightIterator.iter_batch_groups() switches disk mode to iter_disk_hf_batches().
        # The underlying LMDeploy transport then uses the existing tensor update endpoints.
        transport = self._batch_transport
        transport.update(weight_iterator)

    def teardown(self) -> None:
        """Tear down the underlying batch transport."""
        transport = self._batch_transport
        transport.teardown()


class DiskWeightTransport(WeightTransport):
    """Weight transport that performs updates through a backend-specific disk adapter.

    The adapter is picked from ``self.backend`` at construction time:
    ``"sglang"`` maps to ``SGLangDiskBackendAdapter`` and ``"pytorch"`` maps to
    ``LMDeployDiskBackendAdapter``.
    """

    _disk_adapter: DiskBackendAdapter

    def __init__(self, *, rank: int, logger: Any, rollout_info: RolloutWeightUpdateInfo, config: Any | None = None):
        super().__init__(rank=rank, logger=logger, rollout_info=rollout_info)
        # config is only forwarded to the LMDeploy adapter; stored before the adapter is built.
        self.config = config
        self._disk_adapter = self._build_adapter()

    def _build_adapter(self) -> DiskBackendAdapter:
        """Create the disk adapter for the configured backend.

        Raises:
            ValueError: If the backend is neither ``"sglang"`` nor ``"pytorch"``.
        """
        backend = self.backend
        if backend == "sglang":
            return SGLangDiskBackendAdapter(rank=self.rank, rollout_info=self.rollout_info)
        if backend == "pytorch":
            # Fall back to IPC when no upstream transport was specified.
            upstream = self.rollout_info.disk_update_upstream_transport or "ipc"
            return LMDeployDiskBackendAdapter(
                rank=self.rank,
                logger=self.logger,
                config=self.config,
                rollout_info=self.rollout_info,
                upstream_transport=upstream,
            )
        raise ValueError(f"Unsupported disk weight update backend: {backend!r}")

    def update(self, weight_iterator: Any) -> None:
        """Delegate the whole update to the disk adapter."""
        adapter = self._disk_adapter
        adapter.update(weight_iterator)

    def send(self, batch: WeightUpdateBatch) -> None:
        """Unsupported on this transport; always raises ``NotImplementedError``."""
        raise NotImplementedError("DiskWeightTransport bypasses WeightIterator batches.")

    def after_update_all_groups(self) -> None:
        """Tear down the adapter and empty the device cache."""
        adapter = self._disk_adapter
        adapter.teardown()
        DEVICE_MODULE.empty_cache()

    def teardown(self) -> None:
        """Tear down the adapter first, then the base transport."""
        adapter = self._disk_adapter
        adapter.teardown()
        super().teardown()
Loading
Loading