diff --git a/src/art/local/backend.py b/src/art/local/backend.py
index daa490204..4e5bb2546 100644
--- a/src/art/local/backend.py
+++ b/src/art/local/backend.py
@@ -1467,6 +1467,31 @@ async def _experimental_fork_checkpoint(
 
         shutil.copytree(source_checkpoint_dir, dest_checkpoint_dir)
 
+        # Make the fork effective for already-created local services. The
+        # checkpoint copy alone updates disk, but Unsloth may already have a
+        # cached trainer and a running vLLM server pointed at the fresh step-0
+        # adapter.
+        service = await self._get_service(cast(TrainableModel, model))
+        if "_state" in service.__dict__:
+            del service.__dict__["_state"]
+            if verbose:
+                print("Invalidated service _state cache for forked checkpoint")
+        service._forked_checkpoint_dir = dest_checkpoint_dir  # type: ignore[attr-defined]
+
+        server_started = bool(getattr(service, "_vllm_process", None)) or bool(
+            getattr(service, "_server_task", None)
+        )
+        register_lora = getattr(service, "register_lora_for_step", None)
+        if server_started and callable(register_lora):
+            await register_lora(selected_step, dest_checkpoint_dir)
+            if verbose:
+                print(
+                    f"Registered forked checkpoint {model.name}@{selected_step} "
+                    "with running inference service"
+                )
+        elif hasattr(service, "_latest_step"):
+            service._latest_step = selected_step  # type: ignore[attr-defined]
+
         if verbose:
             print(
                 f"Successfully forked checkpoint from {from_model} (step {selected_step}) to {model.name}"
diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py
index 99ec77d76..584a40915 100644
--- a/src/art/pipeline_trainer/trainer.py
+++ b/src/art/pipeline_trainer/trainer.py
@@ -90,6 +90,7 @@ def __init__(
         eval_every_n_steps: int = 20,
         eval_at_start: bool = True,
         save_checkpoint: bool = True,
+        save_checkpoint_artifact: bool = False,
         # Resumption
         resume: bool = True,
     ) -> None:
@@ -113,6 +114,8 @@ def __init__(
             raise ValueError("log_interval_seconds must be > 0")
         if discard_queue_multiplier <= 0:
             raise ValueError("discard_queue_multiplier must be > 0")
+        if save_checkpoint_artifact and not save_checkpoint:
+            raise ValueError("save_checkpoint_artifact=True requires save_checkpoint=True")
         self.model = model
         self.backend = backend
         self.rollout_fn = rollout_fn
@@ -136,6 +139,7 @@ def __init__(
         self.eval_every_n_steps = eval_every_n_steps
         self.eval_at_start = eval_at_start
         self.save_checkpoint = save_checkpoint
+        self.save_checkpoint_artifact = save_checkpoint_artifact
         self.resume = resume
         self.discard_queue_multiplier = discard_queue_multiplier
         self._discard_queue: list[TrajectoryGroup] = []
@@ -469,6 +473,16 @@ async def _training_stage(self) -> None:
                     batch,
                     **train_kwargs,
                 )
+                checkpoint_path = getattr(result, "checkpoint_path", None)
+                if (
+                    should_checkpoint
+                    and self.save_checkpoint_artifact
+                    and checkpoint_path is not None
+                ):
+                    self._save_checkpoint_artifact(
+                        checkpoint_path=checkpoint_path,
+                        step=result.step,
+                    )
             except Exception:
                 self._status.note_training_end()
                 raise
@@ -810,6 +824,17 @@ def _should_eval_step(self, step: int) -> bool:
             return False
         return (step - self.state.last_eval_step) >= self.eval_every_n_steps
 
+    def _save_checkpoint_artifact(self, *, checkpoint_path: str, step: int) -> None:
+        from art.utils.deployment import WandbDeploymentConfig, deploy_wandb
+
+        deploy_wandb(
+            model=self.model,
+            checkpoint_path=checkpoint_path,
+            step=step,
+            config=WandbDeploymentConfig(provenance=["local-rl"]),
+            verbose=True,
+        )
+
     def _read_pipeline_state(self) -> dict[str, Any]:
         state = self.model.read_state() or {}
         return state.get(PIPELINE_STATE_KEY, {})
diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py
index 2a5a60abf..22c0f6c94 100644
--- a/src/art/unsloth/service.py
+++ b/src/art/unsloth/service.py
@@ -108,6 +108,7 @@ class UnslothService:
     output_dir: str
     _is_sleeping: bool = False
     _latest_step: int = 0
+    _forked_checkpoint_dir: str | None = None
     _lora_id_counter: int = 1  # Start from 1 since 0 is reserved
     # Dedicated mode subprocess state
     _vllm_process: subprocess.Popen | None = field(default=None, repr=False)  # type: ignore[type-arg]
@@ -571,6 +572,14 @@ async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None:
         self._latest_step = step
         await llm.resume_generation()
 
+    async def _load_forked_checkpoint_if_needed(self) -> None:
+        forked_dir = self._forked_checkpoint_dir
+        if forked_dir is None:
+            return
+
+        await self._state.load_lora_adapter(forked_dir)
+        self._forked_checkpoint_dir = None
+
     async def train(
         self,
         disk_packed_tensors: DiskPackedTensors,
@@ -598,6 +607,8 @@ async def _train_dedicated(
         verbose: bool = False,
     ) -> AsyncIterator[dict[str, float]]:
         """Train in dedicated mode — no sleep/wake, vLLM keeps running on separate GPU."""
+        await self._load_forked_checkpoint_if_needed()
+
         async for result in run_unsloth_rl_training(
             self._state,
             disk_packed_tensors=disk_packed_tensors,
@@ -663,6 +674,8 @@ async def _train_shared(
         # Reload training model to GPU (after vLLM is asleep)
         self._state.reload_to_gpu()
 
+        await self._load_forked_checkpoint_if_needed()
+
         async for result in run_unsloth_rl_training(
             self._state,
             disk_packed_tensors=disk_packed_tensors,
diff --git a/src/art/utils/deployment/wandb.py b/src/art/utils/deployment/wandb.py
index 9ddf778e8..49202a41a 100644
--- a/src/art/utils/deployment/wandb.py
+++ b/src/art/utils/deployment/wandb.py
@@ -32,6 +32,18 @@ class WandbDeploymentConfig(DeploymentConfig):
     "Qwen/Qwen2.5-14B-Instruct",
 ]
 
+WANDB_BASE_MODEL_ALIASES = {
+    "unsloth/Meta-Llama-3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
+    "meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
+    "unsloth/Meta-Llama-3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
+    "meta-llama/Meta-Llama-3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
+}
+
+
+def get_wandb_base_model(base_model: str) -> str:
+    """Return the W&B inference base model id for compatible aliases."""
+    return WANDB_BASE_MODEL_ALIASES.get(base_model, base_model)
+
 
 def deploy_wandb(
     model: "TrainableModel",
@@ -54,7 +66,8 @@ def deploy_wandb(
     """
     import wandb
 
-    if model.base_model not in WANDB_SUPPORTED_BASE_MODELS:
+    wandb_base_model = get_wandb_base_model(model.base_model)
+    if wandb_base_model not in WANDB_SUPPORTED_BASE_MODELS:
         raise UnsupportedBaseModelDeploymentError(
             message=f"Base model {model.base_model} is not supported for serverless LoRA deployment by W&B. Supported models: {WANDB_SUPPORTED_BASE_MODELS}"
         )
@@ -77,7 +90,9 @@ def deploy_wandb(
         settings=wandb.Settings(api_key=os.environ["WANDB_API_KEY"]),
     )
     try:
-        metadata: dict[str, object] = {"wandb.base_model": model.base_model}
+        metadata: dict[str, object] = {"wandb.base_model": wandb_base_model}
+        if wandb_base_model != model.base_model:
+            metadata["source_base_model"] = model.base_model
         if config is not None:
             metadata["wandb.provenance"] = config.provenance
         artifact = wandb.Artifact(
diff --git a/tests/unit/test_pipeline_trainer_local_backend.py b/tests/unit/test_pipeline_trainer_local_backend.py
index 90e2c59d7..f3696d84e 100644
--- a/tests/unit/test_pipeline_trainer_local_backend.py
+++ b/tests/unit/test_pipeline_trainer_local_backend.py
@@ -15,6 +15,7 @@
 from art.megatron.train import load_adapter_into_model
 from art.pipeline_trainer.trainer import PipelineTrainer
 from art.preprocessing.tokenize import TokenizedResult
+from art.utils.deployment.wandb import get_wandb_base_model
 from art.utils.output_dirs import get_model_dir
 
 
@@ -159,6 +160,80 @@ async def test_pipeline_trainer_uses_same_train_kwargs_for_local_backend(
     }
 
 
+@pytest.mark.asyncio
+async def test_pipeline_trainer_saves_checkpoint_artifact_on_eval_step(
+    tmp_path: Path,
+) -> None:
+    model = TrainableModel(
+        name="pipeline-save-checkpoint-artifact",
+        project="pipeline-tests",
+        base_model="test-model",
+        base_path=str(tmp_path),
+    )
+    checkpoint_path = str(tmp_path / "checkpoint-1")
+    backend = MagicMock()
+    backend.train = AsyncMock(
+        return_value=SimpleNamespace(
+            step=1,
+            metrics={},
+            checkpoint_path=checkpoint_path,
+        )
+    )
+
+    trainer = _make_trainer(
+        model=model,
+        backend=backend,
+        eval_fn=AsyncMock(return_value=[]),
+        eval_every_n_steps=1,
+        save_checkpoint_artifact=True,
+    )
+    trainer._save_checkpoint_artifact = MagicMock()  # type: ignore[method-assign]
+    trainer._output_queue = asyncio.Queue()
+    await trainer._output_queue.put(_make_group([0.0, 1.0]))
+    await trainer._output_queue.put(None)
+
+    await trainer._training_stage()
+
+    assert backend.train.await_args.kwargs["save_checkpoint"] is True
+    trainer._save_checkpoint_artifact.assert_called_once_with(  # type: ignore[attr-defined]
+        checkpoint_path=checkpoint_path,
+        step=1,
+    )
+
+
+def test_pipeline_trainer_checkpoint_artifact_requires_checkpoint(
+    tmp_path: Path,
+) -> None:
+    model = TrainableModel(
+        name="pipeline-save-checkpoint-artifact-validation",
+        project="pipeline-tests",
+        base_model="test-model",
+        base_path=str(tmp_path),
+    )
+    backend = MagicMock()
+
+    with pytest.raises(
+        ValueError, match="save_checkpoint_artifact=True requires save_checkpoint=True"
+    ):
+        _make_trainer(
+            model=model,
+            backend=backend,
+            save_checkpoint=False,
+            save_checkpoint_artifact=True,
+        )
+
+
+def test_wandb_base_model_aliases_for_unsloth_llama() -> None:
+    assert (
+        get_wandb_base_model("unsloth/Meta-Llama-3.1-8B-Instruct")
+        == "meta-llama/Llama-3.1-8B-Instruct"
+    )
+    assert (
+        get_wandb_base_model("unsloth/Meta-Llama-3.1-70B-Instruct")
+        == "meta-llama/Llama-3.1-70B-Instruct"
+    )
+
+
 @pytest.mark.asyncio
 async def test_local_backend_train_translates_loss_fn(tmp_path: Path) -> None:
     model = TrainableModel(