diff --git a/xtuner/v1/rl/trainer/controller.py b/xtuner/v1/rl/trainer/controller.py index d0965f216..088fde354 100644 --- a/xtuner/v1/rl/trainer/controller.py +++ b/xtuner/v1/rl/trainer/controller.py @@ -290,12 +290,12 @@ def onload(self, target: Literal["model", "optimizer", "all"] = "all"): ray.get([worker.onload_optimizer.remote() for worker in self.workers], timeout=TRAIN_RAY_GET_TIMEOUT) # type: ignore return - def update_rollout_info(self, info_dict, train_rollout_mode, weight_update_host=None, weight_update_port=None): + def update_rollout_info(self, info_dict, weight_update_mode, weight_update_host=None, weight_update_port=None): ray.get( [ worker.update_rollout_info.remote( **info_dict, - train_rollout_mode=train_rollout_mode, + weight_update_mode=weight_update_mode, weight_update_host=weight_update_host, weight_update_port=weight_update_port, ) diff --git a/xtuner/v1/rl/weight_update/__init__.py b/xtuner/v1/rl/weight_update/__init__.py index 312f779fe..31f2b1652 100644 --- a/xtuner/v1/rl/weight_update/__init__.py +++ b/xtuner/v1/rl/weight_update/__init__.py @@ -1,19 +1,23 @@ from .data import ( DeviceMeshRaw, + DiskUpdateUpstreamTransport, RolloutBackend, RolloutEngineInfo, RolloutWeightUpdateInfo, ServiceUrlMap, - TrainRolloutMode, WeightTransportType, WeightUpdateBatch, ) from .transport import ( + DiskBackendAdapter, + DiskWeightTransport, IPCBackendAdapter, IPCWeightTransport, + LMDeployDiskBackendAdapter, LMDeployIPCBackendAdapter, NCCLBackendAdapter, NCCLWeightTransport, + SGLangDiskBackendAdapter, SGLangIPCBackendAdapter, SGLangNCCLBackendAdapter, WeightTransport, @@ -24,19 +28,23 @@ __all__ = [ + "DiskBackendAdapter", + "DiskUpdateUpstreamTransport", + "DiskWeightTransport", "DeviceMeshRaw", "IPCBackendAdapter", "IPCWeightTransport", + "LMDeployDiskBackendAdapter", "LMDeployIPCBackendAdapter", "NCCLBackendAdapter", "NCCLWeightTransport", "RolloutBackend", "RolloutEngineInfo", "RolloutWeightUpdateInfo", + "SGLangDiskBackendAdapter", "SGLangIPCBackendAdapter", "SGLangNCCLBackendAdapter", "ServiceUrlMap", - "TrainRolloutMode", "UpdateWeighter", "WeightIterator", "WeightTransportType", diff --git a/xtuner/v1/rl/weight_update/data.py b/xtuner/v1/rl/weight_update/data.py index 6041ff643..327c6f92e 100644 --- a/xtuner/v1/rl/weight_update/data.py +++ b/xtuner/v1/rl/weight_update/data.py @@ -10,9 +10,9 @@ DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices. ServiceUrlMap: TypeAlias = Dict[int, str] # A dictionary mapping rollout ranks to their server URLs. RolloutEngineInfo: TypeAlias = list[tuple[int, str, int]] # (rollout rank, server url, engine gpu count) -TrainRolloutMode: TypeAlias = Literal["colocate", "disaggregated"] # Train and rollout deployment mode. RolloutBackend: TypeAlias = Literal["sglang", "vllm", "pytorch", "turbomind"] # Rollout inference backend. -WeightTransportType: TypeAlias = Literal["ipc", "nccl"] # Supported weight transport types. +WeightTransportType: TypeAlias = Literal["ipc", "nccl", "disk"] # Supported weight transport types. +DiskUpdateUpstreamTransport: TypeAlias = Literal["ipc", "nccl"] # How disk-loaded weights are delivered to rollout. @dataclass @@ -23,7 +23,6 @@ class RolloutWeightUpdateInfo: backend: RolloutBackend | None = None tp: int = 1 ep: int = 1 - train_rollout_mode: TrainRolloutMode | None = None transport_type: WeightTransportType | None = None rollout_cfg_info: dict = field(default_factory=dict) endpoints: dict[str, str] = field(default_factory=lambda: {"update_weights": "update_weights"}) @@ -38,6 +37,10 @@ class RolloutWeightUpdateInfo: weight_update_host: str | None = None weight_update_port: int | None = None + # Disk update metadata. + hf_weight_path: str | None = None + disk_update_upstream_transport: DiskUpdateUpstreamTransport | None = None + @dataclass class WeightUpdateBatch: diff --git a/xtuner/v1/rl/weight_update/transport.py b/xtuner/v1/rl/weight_update/transport.py index 2ae3639be..19b9c38c7 100644 --- a/xtuner/v1/rl/weight_update/transport.py +++ b/xtuner/v1/rl/weight_update/transport.py @@ -739,7 +739,7 @@ def ensure_nccl_weight_update_group(self): def send(self, batch: WeightUpdateBatch) -> None: state_dict = batch.state_dict - if not state_dict: + if not state_dict and not batch.finished: return train_sync_group = self.get_train_update_sync_group() @@ -860,3 +860,156 @@ def teardown(self) -> None: self.group_name = None self.engine_urls = [] self.external_group_world_size = None + + +class DiskBackendAdapter: + def update(self, weight_iterator: Any) -> None: + raise NotImplementedError + + def teardown(self) -> None: + return + + +class SGLangDiskBackendAdapter(DiskBackendAdapter): + def __init__(self, *, rank: int, rollout_info: RolloutWeightUpdateInfo): + self.rank = rank + self.rollout_info = rollout_info + self.executor: ThreadPoolExecutor | None = None + + def build_request(self, hf_weight_path: str) -> WeightUpdateRequest: + # SGLang already owns the disk reload path. XTuner only needs to pass + # the HF checkpoint directory to the rollout server. + return WeightUpdateRequest( + endpoint="update_weights_from_disk", + body={ + "model_path": hf_weight_path, + "load_format": "safetensors", + "abort_all_requests": True, + "flush_cache": True, + }, + ) + + def update(self, weight_iterator: Any) -> None: + # SGLang consumes the checkpoint path on the rollout server side. + del weight_iterator + + hf_weight_path = self.rollout_info.hf_weight_path + if not hf_weight_path: + raise RuntimeError("Disk weight update requires rollout_info.hf_weight_path from rollout_config.") + + try: + if dist.get_rank() != 0: + dist.barrier() + return + + target_urls = list(dict.fromkeys(url for url in self.rollout_info.rollout_server_url_dict.values() if url)) + if not target_urls: + raise RuntimeError("Disk weight update requires at least one rollout server url.") + request = self.build_request(hf_weight_path) + self.executor = ThreadPoolExecutor(max_workers=max(1, len(target_urls))) + futures = [ + self.executor.submit( + WeightTransport.post_json, + url, + request.endpoint, + request.body, + api_key=self.rollout_info.api_key, + ) + for url in target_urls + ] + for future in futures: + result = future.result() + assert result.get("success", True), f"disk weight update failed: {result.get('message', result)}" + dist.barrier() + finally: + self.teardown() + DEVICE_MODULE.empty_cache() + + def teardown(self) -> None: + if self.executor is not None: + self.executor.shutdown(wait=False, cancel_futures=True) + self.executor = None + + +class LMDeployDiskBackendAdapter(DiskBackendAdapter): + def __init__( + self, + *, + rank: int, + logger: Any, + rollout_info: RolloutWeightUpdateInfo, + config: Any | None, + upstream_transport: str, + ): + self.upstream_transport = upstream_transport + self._batch_transport = self._build_batch_transport( + rank=rank, + logger=logger, + rollout_info=rollout_info, + config=config, + ) + + def _build_batch_transport( + self, + *, + rank: int, + logger: Any, + rollout_info: RolloutWeightUpdateInfo, + config: Any | None, + ) -> WeightTransport: + if self.upstream_transport == "ipc": + return IPCWeightTransport( + rank=rank, + logger=logger, + config=config, + rollout_info=rollout_info, + ) + elif self.upstream_transport == "nccl": + return NCCLWeightTransport(rank=rank, logger=logger, rollout_info=rollout_info) + else: + raise ValueError(f"Unsupported disk weight update upstream transport: {self.upstream_transport!r}") + + def update(self, weight_iterator: Any) -> None: + # WeightIterator.iter_batch_groups() switches disk mode to iter_disk_hf_batches(). + # The underlying LMDeploy transport then uses the existing tensor update endpoints. + self._batch_transport.update(weight_iterator) + + def teardown(self) -> None: + self._batch_transport.teardown() + + +class DiskWeightTransport(WeightTransport): + _disk_adapter: DiskBackendAdapter + + def __init__(self, *, rank: int, logger: Any, rollout_info: RolloutWeightUpdateInfo, config: Any | None = None): + super().__init__(rank=rank, logger=logger, rollout_info=rollout_info) + self.config = config + self._disk_adapter = self._build_adapter() + + def _build_adapter(self) -> DiskBackendAdapter: + if self.backend == "sglang": + return SGLangDiskBackendAdapter(rank=self.rank, rollout_info=self.rollout_info) + elif self.backend == "pytorch": + upstream_transport = self.rollout_info.disk_update_upstream_transport or "ipc" + return LMDeployDiskBackendAdapter( + rank=self.rank, + logger=self.logger, + config=self.config, + rollout_info=self.rollout_info, + upstream_transport=upstream_transport, + ) + raise ValueError(f"Unsupported disk weight update backend: {self.backend!r}") + + def update(self, weight_iterator: Any) -> None: + self._disk_adapter.update(weight_iterator) + + def send(self, batch: WeightUpdateBatch) -> None: + raise NotImplementedError("DiskWeightTransport bypasses WeightIterator batches.") + + def after_update_all_groups(self) -> None: + self._disk_adapter.teardown() + DEVICE_MODULE.empty_cache() + + def teardown(self) -> None: + self._disk_adapter.teardown() + super().teardown() diff --git a/xtuner/v1/rl/weight_update/update_weighter.py b/xtuner/v1/rl/weight_update/update_weighter.py index 7adb7ba8d..ece037b91 100644 --- a/xtuner/v1/rl/weight_update/update_weighter.py +++ b/xtuner/v1/rl/weight_update/update_weighter.py @@ -12,9 +12,9 @@ RolloutBackend, RolloutWeightUpdateInfo, ServiceUrlMap, - TrainRolloutMode, + WeightTransportType, ) -from .transport import IPCWeightTransport, NCCLWeightTransport, WeightTransport +from .transport import DiskWeightTransport, IPCWeightTransport, NCCLWeightTransport, WeightTransport from .weight_iterator import WeightIterator @@ -56,7 +56,7 @@ def update_rollout_info( server_url_dict: ServiceUrlMap, rollout_config: RolloutConfig, worker_server_urls_status: dict[str, bool], - train_rollout_mode: TrainRolloutMode, + weight_update_mode: WeightTransportType | str, weight_update_host: str | None = None, weight_update_port: int | None = None, worker_session_url_dict: ServiceUrlMap | None = None, @@ -65,7 +65,7 @@ def update_rollout_info( """Update the rollout information for the training worker.""" self.rollout_info.backend = self._normalize_rollout_backend(rollout_config) - self.set_train_rollout_mode(train_rollout_mode=train_rollout_mode) + self.set_weight_update_mode(weight_update_mode=weight_update_mode) # Common rollout metadata. tp = rollout_config.tensor_parallel_size @@ -74,6 +74,16 @@ def update_rollout_info( self.rollout_info.tp = tp self.rollout_info.ep = ep self.rollout_info.api_key = rollout_config.api_key + extra_rollout_config = rollout_config.extra_rollout_config or dict() + # PSEUDO: the disk transport consumes the HF checkpoint path from rollout_config. + # A future RolloutConfig field can replace the extra_rollout_config fallback. + self.rollout_info.hf_weight_path = cast( + str | None, getattr(rollout_config, "hf_weight_path", None) or extra_rollout_config.get("hf_weight_path") + ) + self.rollout_info.disk_update_upstream_transport = cast( + Any, + extra_rollout_config.get("disk_update_upstream_transport"), + ) rollout_server_url = server_url_dict.get(self.rank, "") if not worker_server_urls_status.get(rollout_server_url, False): self.logger.error(f"Rollout server url {rollout_server_url} is not available.") @@ -83,7 +93,6 @@ def update_rollout_info( if self.rollout_info.transport_type == "ipc": # Colocated rollout metadata. - # rollout_device_mesh is created after train_rollout_mode is set. self.rollout_info.rollout_engine_rank_mesh_array = [ [int(rank) for rank in ranks] for ranks in engine_rank_mesh_array ] @@ -94,12 +103,30 @@ def update_rollout_info( self.rollout_info.worker_server_urls_status = worker_server_urls_status self.rollout_info.weight_update_host = weight_update_host self.rollout_info.weight_update_port = weight_update_port if weight_update_port is not None else 30000 + elif self.rollout_info.transport_type == "disk": + # Disk update needs rollout URL metadata, and LMDeploy still needs to + # know whether the underlying transfer should mirror IPC or NCCL. + disk_update_upstream_transport = (self.rollout_info.disk_update_upstream_transport or "ipc").lower() + if disk_update_upstream_transport not in ("ipc", "nccl"): + raise ValueError( + f"disk_update_upstream_transport must be 'ipc' or 'nccl', got {disk_update_upstream_transport!r}." + ) + self.rollout_info.disk_update_upstream_transport = cast(Any, disk_update_upstream_transport) + self.rollout_info.rollout_engine_rank_mesh_array = [ + [int(rank) for rank in ranks] for ranks in engine_rank_mesh_array + ] + self.rollout_info.rollout_server_url_dict = {int(rank): url for rank, url in server_url_dict.items()} + self.rollout_info.worker_server_urls_status = worker_server_urls_status + self.rollout_info.weight_update_host = weight_update_host + self.rollout_info.weight_update_port = weight_update_port if weight_update_port is not None else 30000 + if disk_update_upstream_transport == "ipc": + self._ensure_rollout_device_mesh() new_transport_signature = self._build_transport_signature( engine_rank_mesh_array=engine_rank_mesh_array, server_url_dict=server_url_dict, worker_server_urls_status=worker_server_urls_status, - train_rollout_mode=train_rollout_mode, + weight_update_mode=self.rollout_info.transport_type, backend=self.rollout_info.backend, tp=tp, ep=ep, @@ -131,24 +158,21 @@ def _ensure_rollout_device_mesh(self): mesh_dim_names=("engine_instance", "engine_parallel"), ) - def set_train_rollout_mode(self, train_rollout_mode: TrainRolloutMode | str): - assert train_rollout_mode is not None, "update_rollout_info() must set train_rollout_mode." + def set_weight_update_mode(self, weight_update_mode: WeightTransportType | str): + assert weight_update_mode is not None, "update_rollout_info() must set weight_update_mode." if self.rollout_info.backend is None: raise RuntimeError("rollout backend is not set. Please set rollout backend in update_rollout_info().") - mode = train_rollout_mode.lower() - if mode not in ("colocate", "disaggregated"): + mode = weight_update_mode.lower() + if mode not in ("ipc", "nccl", "disk"): raise ValueError( - f"Unsupported train_rollout_mode: {train_rollout_mode!r}. Expected 'colocate' or 'disaggregated'." + f"Unsupported weight_update_mode: {weight_update_mode!r}. Expected 'ipc', 'nccl' or 'disk'." ) - mode = cast(TrainRolloutMode, mode) - self.rollout_info.train_rollout_mode = mode - if mode == "colocate": - self.rollout_info.transport_type = "ipc" - elif mode == "disaggregated": - self.rollout_info.transport_type = "nccl" + mode = cast(WeightTransportType, mode) + self.rollout_info.transport_type = mode + if mode == "nccl": backend = self.rollout_info.backend if backend == "vllm" or backend == "turbomind": raise NotImplementedError(f"Disaggregated train-rollout mode is not supported for {backend} backend.") @@ -172,6 +196,13 @@ def _set_transport(self) -> None: ) elif self.rollout_info.transport_type == "nccl": self._transport = NCCLWeightTransport(rank=self.rank, logger=self.logger, rollout_info=self.rollout_info) + elif self.rollout_info.transport_type == "disk": + self._transport = DiskWeightTransport( + rank=self.rank, + logger=self.logger, + config=self.config, + rollout_info=self.rollout_info, + ) else: raise NotImplementedError @@ -181,7 +212,7 @@ def _build_transport_signature( engine_rank_mesh_array: DeviceMeshRaw, server_url_dict: ServiceUrlMap, worker_server_urls_status: dict[str, bool], - train_rollout_mode: TrainRolloutMode, + weight_update_mode: WeightTransportType | None, backend: RolloutBackend, tp: int, ep: int, @@ -197,7 +228,7 @@ def _build_transport_signature( ) return ( - train_rollout_mode, + weight_update_mode, backend, tp, ep, diff --git a/xtuner/v1/rl/weight_update/weight_iterator.py b/xtuner/v1/rl/weight_update/weight_iterator.py index 9e8c5b783..9e58070b3 100644 --- a/xtuner/v1/rl/weight_update/weight_iterator.py +++ b/xtuner/v1/rl/weight_update/weight_iterator.py @@ -1,11 +1,14 @@ from __future__ import annotations +import json from itertools import chain +from pathlib import Path from typing import Any, cast import torch import torch.distributed as dist import tqdm +from safetensors import safe_open from torch.distributed.tensor import DTensor from xtuner.v1.model.compose.base import BaseComposeConfig @@ -38,7 +41,11 @@ def __init__( def iter_batch_groups(self): # Export path depends on rollout protocol: turbomind consumes layer-wise batches, # compose models update submodules in order, and plain models use HF-style batches. - if self.rollout_info.train_rollout_mode == "colocate" and self.rollout_info.backend == "turbomind": + if self.rollout_info.transport_type == "disk" and self.rollout_info.backend != "sglang": + yield self.iter_disk_hf_batches(final_update=True) + return + + if self.rollout_info.transport_type == "ipc" and self.rollout_info.backend == "turbomind": yield self.iter_layer_batches() return @@ -190,10 +197,10 @@ def iter_hf_batches(self, submodule=None, final_update=False): ) train_enable_ep = model.fsdp_config is not None and model.fsdp_config.ep_size > 1 - should_gather_train_ep_shards = self.rollout_info.train_rollout_mode == "disaggregated" and train_enable_ep + should_gather_train_ep_shards = self.rollout_info.transport_type == "nccl" and train_enable_ep if train_enable_ep: - if self.rollout_info.train_rollout_mode == "colocate" and self.rollout_info.ep > 1: + if self.rollout_info.transport_type == "ipc" and self.rollout_info.ep > 1: rollout_device_mesh = self.rollout_info.rollout_device_mesh assert rollout_device_mesh is not None # Colocated IPC can send only the expert slice needed by the local rollout @@ -350,3 +357,62 @@ def get_params(tensor_list, name_list, save_dtype): yield WeightUpdateBatch({}, finished=True) DEVICE_MODULE.empty_cache() + + @torch.no_grad() + def iter_disk_hf_batches(self, final_update: bool = True): + hf_weight_path = self.rollout_info.hf_weight_path + if not hf_weight_path: + raise RuntimeError("Disk weight update requires rollout_info.hf_weight_path.") + + loader = _HFCheckpointLoader(hf_weight_path) + bucket_size = int(self.config.update_weight_bucket_size_in_gb * 1024**3) + + state_dict: dict[str, torch.Tensor] = {} + bucket_bytes = 0 + for name, tensor in loader.iter_tensors(): + tensor = tensor.to(device=DEVICE, non_blocking=True).contiguous() + tensor_bytes = tensor.numel() * tensor.element_size() + + if state_dict and bucket_bytes + tensor_bytes > bucket_size: + yield WeightUpdateBatch(state_dict, finished=False) + state_dict = {} + bucket_bytes = 0 + + state_dict[name] = tensor + bucket_bytes += tensor_bytes + + if state_dict: + yield WeightUpdateBatch(state_dict, finished=False) + + if self.rollout_info.backend in ("pytorch", "vllm") and final_update: + yield WeightUpdateBatch({}, finished=True) + + DEVICE_MODULE.empty_cache() + + +class _HFCheckpointLoader: + def __init__(self, model_path: str | Path): + self.model_path = Path(model_path) + index_path = self.model_path / "model.safetensors.index.json" + single_path = self.model_path / "model.safetensors" + + if index_path.exists(): + with open(index_path) as f: + self.weight_map = json.load(f)["weight_map"] + elif single_path.exists(): + with safe_open(single_path, framework="pt", device="cpu") as f: + self.weight_map = dict.fromkeys(f.keys(), single_path.name) + else: + raise FileNotFoundError( + f"Cannot find model.safetensors.index.json or model.safetensors in {self.model_path}" + ) + + def iter_tensors(self): + current_file = None + buffer = None + for name, filename in self.weight_map.items(): + if filename != current_file: + buffer = safe_open(self.model_path / filename, framework="pt", device="cpu") + current_file = filename + assert buffer is not None + yield name, buffer.get_tensor(name) diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index eecb796dd..ab40e30d0 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -52,7 +52,6 @@ set_cpu_resource_manager, sort_rollout_state_for_deterministic, ) -from xtuner.v1.rl.weight_update.data import TrainRolloutMode from xtuner.v1.train.trainer import LoadCheckpointConfig, XTunerMeta from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_logger, is_hf_model_path, set_deterministic, timer from xtuner.v1.utils.device import get_device, get_torch_device_module @@ -119,7 +118,7 @@ def check_fa3(): def bind_train_rollout( train_controller: TrainingController, rollout_controller: RolloutControllerProxy, - train_rollout_mode: TrainRolloutMode | str, + weight_update_mode: Literal["ipc", "nccl", "disk"] | str, weight_update_host: str | None = None, weight_update_port: int | None = None, ) -> None: @@ -130,7 +129,7 @@ def bind_train_rollout( ) train_controller.update_rollout_info( info_dict, - train_rollout_mode=train_rollout_mode, + weight_update_mode=weight_update_mode, weight_update_host=weight_update_host, weight_update_port=weight_update_port, ) @@ -1599,7 +1598,7 @@ def __init__(self, cfg: RLColocateTrainerConfig): bind_train_rollout( train_controller=self.train_controller, rollout_controller=self.rollout_controller, - train_rollout_mode="colocate", + weight_update_mode="ipc", ) replay_buffer = cfg.replay_buffer_config.build() @@ -1759,7 +1758,7 @@ def _sync_weights_and_save(self, train_step: int, step_timer_dict: dict) -> bool bind_train_rollout( train_controller=self.train_controller, rollout_controller=self.rollout_controller, - train_rollout_mode="colocate", + weight_update_mode="ipc", ) ray.get( self.rollout_controller.onload_weights.remote(), @@ -1806,7 +1805,7 @@ def __init__(self, cfg: RLDisaggregatedTrainerConfig): bind_train_rollout( train_controller=self.train_controller, rollout_controller=self.rollout_controller, - train_rollout_mode="disaggregated", + weight_update_mode="nccl", weight_update_host=self._rollout_config.weight_update_host, weight_update_port=self._rollout_config.weight_update_port, ) @@ -2000,7 +1999,7 @@ async def _sync_weights_and_save(self, model_step: int, step_timer_dict: dict): bind_train_rollout( train_controller=self.train_controller, rollout_controller=self.rollout_controller, - train_rollout_mode="disaggregated", + weight_update_mode="nccl", ) self.update_weights()