Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,6 +1467,31 @@ async def _experimental_fork_checkpoint(

shutil.copytree(source_checkpoint_dir, dest_checkpoint_dir)

# Make the fork effective for already-created local services. The
# checkpoint copy alone updates disk, but Unsloth may already have a
# cached trainer and a running vLLM server pointed at the fresh step-0
# adapter.
service = await self._get_service(cast(TrainableModel, model))
if hasattr(service, "_state") and "_state" in service.__dict__:
del service.__dict__["_state"]
if verbose:
print("Invalidated service _state cache for forked checkpoint")
service._forked_checkpoint_dir = dest_checkpoint_dir # type: ignore[attr-defined]

server_started = bool(getattr(service, "_vllm_process", None)) or bool(
getattr(service, "_server_task", None)
)
register_lora = getattr(service, "register_lora_for_step", None)
if server_started and callable(register_lora):
await register_lora(selected_step, dest_checkpoint_dir)
if verbose:
print(
f"Registered forked checkpoint {model.name}@{selected_step} "
"with running inference service"
)
elif hasattr(service, "_latest_step"):
service._latest_step = selected_step # type: ignore[attr-defined]

if verbose:
print(
f"Successfully forked checkpoint from {from_model} (step {selected_step}) to {model.name}"
Expand Down
25 changes: 25 additions & 0 deletions src/art/pipeline_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
eval_every_n_steps: int = 20,
eval_at_start: bool = True,
save_checkpoint: bool = True,
save_checkpoint_artifact: bool = False,
# Resumption
resume: bool = True,
) -> None:
Expand All @@ -113,6 +114,8 @@ def __init__(
raise ValueError("log_interval_seconds must be > 0")
if discard_queue_multiplier <= 0:
raise ValueError("discard_queue_multiplier must be > 0")
if save_checkpoint_artifact and not save_checkpoint:
raise ValueError("save_checkpoint_artifact=True requires save_checkpoint=True")
self.model = model
self.backend = backend
self.rollout_fn = rollout_fn
Expand All @@ -136,6 +139,7 @@ def __init__(
self.eval_every_n_steps = eval_every_n_steps
self.eval_at_start = eval_at_start
self.save_checkpoint = save_checkpoint
self.save_checkpoint_artifact = save_checkpoint_artifact
self.resume = resume
self.discard_queue_multiplier = discard_queue_multiplier
self._discard_queue: list[TrajectoryGroup] = []
Expand Down Expand Up @@ -469,6 +473,16 @@ async def _training_stage(self) -> None:
batch,
**train_kwargs,
)
checkpoint_path = getattr(result, "checkpoint_path", None)
if (
should_checkpoint
and self.save_checkpoint_artifact
and checkpoint_path is not None
):
self._save_checkpoint_artifact(
checkpoint_path=checkpoint_path,
step=result.step,
)
except Exception:
self._status.note_training_end()
raise
Expand Down Expand Up @@ -810,6 +824,17 @@ def _should_eval_step(self, step: int) -> bool:
return False
return (step - self.state.last_eval_step) >= self.eval_every_n_steps

def _save_checkpoint_artifact(self, *, checkpoint_path: str, step: int) -> None:
    """Upload the checkpoint at ``checkpoint_path`` to W&B for ``step``."""
    # Imported lazily so the wandb dependency is only touched when
    # artifact saving is actually enabled.
    from art.utils.deployment import WandbDeploymentConfig, deploy_wandb

    deployment_config = WandbDeploymentConfig(provenance=["local-rl"])
    deploy_wandb(
        model=self.model,
        checkpoint_path=checkpoint_path,
        step=step,
        config=deployment_config,
        verbose=True,
    )

def _read_pipeline_state(self) -> dict[str, Any]:
    """Return the pipeline-trainer section of the persisted model state."""
    persisted = self.model.read_state()
    if not persisted:
        persisted = {}
    # Missing section means a fresh run: fall back to an empty dict.
    return persisted.get(PIPELINE_STATE_KEY, {})
Expand Down
13 changes: 13 additions & 0 deletions src/art/unsloth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class UnslothService:
output_dir: str
_is_sleeping: bool = False
_latest_step: int = 0
_forked_checkpoint_dir: str | None = None
_lora_id_counter: int = 1 # Start from 1 since 0 is reserved
# Dedicated mode subprocess state
_vllm_process: subprocess.Popen | None = field(default=None, repr=False) # type: ignore[type-arg]
Expand Down Expand Up @@ -571,6 +572,14 @@ async def register_lora_for_step(self, step: int, checkpoint_dir: str) -> None:
self._latest_step = step
await llm.resume_generation()

async def _load_forked_checkpoint_if_needed(self) -> None:
forked_dir = self._forked_checkpoint_dir
if forked_dir is None:
return

self._forked_checkpoint_dir = None
await self._state.load_lora_adapter(forked_dir)

async def train(
self,
disk_packed_tensors: DiskPackedTensors,
Expand Down Expand Up @@ -598,6 +607,8 @@ async def _train_dedicated(
verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
"""Train in dedicated mode — no sleep/wake, vLLM keeps running on separate GPU."""
await self._load_forked_checkpoint_if_needed()

async for result in run_unsloth_rl_training(
self._state,
disk_packed_tensors=disk_packed_tensors,
Expand Down Expand Up @@ -663,6 +674,8 @@ async def _train_shared(
# Reload training model to GPU (after vLLM is asleep)
self._state.reload_to_gpu()

await self._load_forked_checkpoint_if_needed()

async for result in run_unsloth_rl_training(
self._state,
disk_packed_tensors=disk_packed_tensors,
Expand Down
19 changes: 17 additions & 2 deletions src/art/utils/deployment/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ class WandbDeploymentConfig(DeploymentConfig):
"Qwen/Qwen2.5-14B-Instruct",
]

# Map Unsloth mirrors (and "Meta-"-prefixed upstream ids) of Llama 3.1 onto
# the canonical meta-llama ids used by W&B inference.
WANDB_BASE_MODEL_ALIASES = {
    "unsloth/Meta-Llama-3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/Meta-Llama-3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
    "unsloth/Meta-Llama-3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
    "meta-llama/Meta-Llama-3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
}


def get_wandb_base_model(base_model: str) -> str:
    """Return the W&B inference base model id for compatible aliases.

    Ids without a known alias are passed through unchanged.
    """
    canonical = WANDB_BASE_MODEL_ALIASES.get(base_model)
    return base_model if canonical is None else canonical


def deploy_wandb(
model: "TrainableModel",
Expand All @@ -54,7 +66,8 @@ def deploy_wandb(
"""
import wandb

if model.base_model not in WANDB_SUPPORTED_BASE_MODELS:
wandb_base_model = get_wandb_base_model(model.base_model)
if wandb_base_model not in WANDB_SUPPORTED_BASE_MODELS:
raise UnsupportedBaseModelDeploymentError(
message=f"Base model {model.base_model} is not supported for serverless LoRA deployment by W&B. Supported models: {WANDB_SUPPORTED_BASE_MODELS}"
)
Expand All @@ -77,7 +90,9 @@ def deploy_wandb(
settings=wandb.Settings(api_key=os.environ["WANDB_API_KEY"]),
)
try:
metadata: dict[str, object] = {"wandb.base_model": model.base_model}
metadata: dict[str, object] = {"wandb.base_model": wandb_base_model}
if wandb_base_model != model.base_model:
metadata["source_base_model"] = model.base_model
if config is not None:
metadata["wandb.provenance"] = config.provenance
artifact = wandb.Artifact(
Expand Down
75 changes: 75 additions & 0 deletions tests/unit/test_pipeline_trainer_local_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from art.megatron.train import load_adapter_into_model
from art.pipeline_trainer.trainer import PipelineTrainer
from art.preprocessing.tokenize import TokenizedResult
from art.utils.deployment.wandb import get_wandb_base_model
from art.utils.output_dirs import get_model_dir


Expand Down Expand Up @@ -159,6 +160,80 @@ async def test_pipeline_trainer_uses_same_train_kwargs_for_local_backend(
}


@pytest.mark.asyncio
async def test_pipeline_trainer_saves_checkpoint_artifact_on_eval_step(
    tmp_path: Path,
) -> None:
    """A checkpointed training step dispatches the artifact upload hook."""
    trainable = TrainableModel(
        name="pipeline-save-checkpoint-artifact",
        project="pipeline-tests",
        base_model="test-model",
        base_path=str(tmp_path),
    )
    ckpt_dir = str(tmp_path / "checkpoint-1")
    fake_backend = MagicMock()
    train_result = SimpleNamespace(step=1, metrics={}, checkpoint_path=ckpt_dir)
    fake_backend.train = AsyncMock(return_value=train_result)

    trainer = _make_trainer(
        model=trainable,
        backend=fake_backend,
        eval_fn=AsyncMock(return_value=[]),
        eval_every_n_steps=1,
        save_checkpoint_artifact=True,
    )
    # Stub the upload itself so only the dispatch logic is under test.
    trainer._save_checkpoint_artifact = MagicMock()  # type: ignore[method-assign]
    trainer._output_queue = asyncio.Queue()
    await trainer._output_queue.put(_make_group([0.0, 1.0]))
    await trainer._output_queue.put(None)

    await trainer._training_stage()

    assert fake_backend.train.await_args.kwargs["save_checkpoint"] is True
    trainer._save_checkpoint_artifact.assert_called_once_with(  # type: ignore[attr-defined]
        checkpoint_path=ckpt_dir,
        step=1,
    )


def test_pipeline_trainer_checkpoint_artifact_requires_checkpoint(
    tmp_path: Path,
) -> None:
    """Artifact saving cannot be enabled while checkpoint saving is disabled."""
    trainable = TrainableModel(
        name="pipeline-save-checkpoint-artifact-validation",
        project="pipeline-tests",
        base_model="test-model",
        base_path=str(tmp_path),
    )

    with pytest.raises(
        ValueError, match="save_checkpoint_artifact=True requires save_checkpoint=True"
    ):
        _make_trainer(
            model=trainable,
            backend=MagicMock(),
            save_checkpoint=False,
            save_checkpoint_artifact=True,
        )


def test_wandb_base_model_aliases_for_unsloth_llama() -> None:
    """Unsloth Llama 3.1 mirrors resolve to the canonical meta-llama ids."""
    expected = {
        "unsloth/Meta-Llama-3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
        "unsloth/Meta-Llama-3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
    }
    for alias, canonical in expected.items():
        assert get_wandb_base_model(alias) == canonical


@pytest.mark.asyncio
async def test_local_backend_train_translates_loss_fn(tmp_path: Path) -> None:
model = TrainableModel(
Expand Down
Loading