diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py index 74aaf311bf..cff972a51e 100644 --- a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py @@ -33,7 +33,7 @@ "qwen3": "qwen3", "nemotron_h": "nemotron_h", "nemotron_h_v2": "nemotron_h_v2", - "gpt_oss_20b": "gpt_oss_20b", + "gpt_oss": "gpt_oss", } diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index 69b8e5e29d..1240d1c9b6 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -29,12 +29,14 @@ from typing import TYPE_CHECKING, Any, BinaryIO import torch +import torch.distributed as tdist import transformers from safetensors.torch import save_file as safe_save_file from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel from transformers.dynamic_module_utils import get_class_from_dynamic_module from transformers.utils import SAFE_WEIGHTS_INDEX_NAME +import modelopt.torch.utils.distributed as dist_utils from modelopt.torch.utils import json_dumps from ..block_config import maybe_cast_block_configs @@ -51,6 +53,7 @@ "load_model_config", "init_model_from_config", "save_checkpoint", + "save_checkpoint_from_shards", "save_subblocks", "save_model_config", ] @@ -200,6 +203,52 @@ def save_checkpoint( _save_checkpoint(model.config, model.state_dict(), checkpoint_dir, descriptor) +def save_checkpoint_from_shards( + model: PreTrainedModel, checkpoint_dir: Path | str, descriptor: "ModelDescriptor" +) -> None: + """ + Save a checkpoint when the model's weights are sharded across distributed ranks. + + Gathers each rank's partial state dictionary onto rank 0 and writes a complete checkpoint + (including the safetensors index and subblocks) from the merged weights. On a single-process + run, saves directly from the local state dict. Only rank 0 performs the filesystem write; + non-master ranks only participate in the gather. + + Parameters: + model (PreTrainedModel): The model instance whose local state_dict contains this rank's + shard of weights. + checkpoint_dir (Path | str): Destination directory for the checkpoint files. + descriptor (ModelDescriptor): Descriptor used to partition weights into subblocks and build + the safetensors index. 
+    """
+
+    local_sd = {k: v.cpu() for k, v in model.state_dict().items()}
+    if dist_utils.size() > 1:
+        save_err: str | None = None
+        if dist_utils.is_master():
+            gathered: list[dict | None] = [None] * dist_utils.size()
+            tdist.gather_object(local_sd, gathered, dst=0)
+            full_sd: dict[str, torch.Tensor] = {}
+            for shard_sd in gathered:
+                if shard_sd is None:
+                    continue
+                full_sd.update(shard_sd)
+            try:
+                _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor)
+            except Exception as e:
+                save_err = repr(e)
+        else:
+            tdist.gather_object(local_sd, dst=0)
+        err_box = [save_err]
+        tdist.broadcast_object_list(err_box, src=0)
+        # Barrier ensures all ranks wait until file I/O completes before continuing
+        dist_utils.barrier()
+        if err_box[0] is not None:
+            raise RuntimeError(f"Checkpoint save failed on rank 0: {err_box[0]}")
+    else:
+        _save_checkpoint(model.config, local_sd, checkpoint_dir, descriptor)
+
+
 def _save_checkpoint(
     model_config: PretrainedConfig,
     state_dict: dict[str, torch.Tensor],
diff --git a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py
index d8471aee23..a46fba52d0 100644
--- a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py
+++ b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py
@@ -41,7 +41,7 @@
 from ..utils.validate_runtime_pipeline import perform_pipeline_stitches
 from . import validate_model
 from .checkpoint_utils import copy_tokenizer
-from .checkpoint_utils_hf import save_checkpoint
+from .checkpoint_utils_hf import save_checkpoint_from_shards
 from .common import resolve_torch_dtype
 from .sharded_checkpoint_utils import load_and_shard_model
 from .validation_utils import (
@@ -189,7 +189,7 @@ def validate_puzzle_solutions(args: DictConfig) -> None:
             # TODO: Look into internal Puzzletron code to see how to save as symlinks
             # save_checkpoint_as_symlinks is currently not supported
             pass
-        save_checkpoint(model, checkpoint_dir, descriptor)
+        save_checkpoint_from_shards(model, checkpoint_dir, descriptor)
 
         copy_tokenizer(
             args.tokenizer_name,
diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py
index a393e1e086..d44cbc71e9 100644
--- a/tests/gpu/torch/puzzletron/test_puzzletron.py
+++ b/tests/gpu/torch/puzzletron/test_puzzletron.py
@@ -25,11 +25,6 @@
 from _test_utils.torch.puzzletron.utils import setup_test_model_and_data
 from packaging.version import Version
 
-# The puzzletron pipeline imports mip unconditionally at module level. In NeMo containers
-# the [puzzletron] extras are not pre-installed, so importing the test file fails with a
-# deep ModuleNotFoundError. Skip early with an actionable message instead.
-pytest.importorskip("mip", reason="pip install -e '.[puzzletron]' to install MIP solver")
-
 import modelopt.torch.puzzletron as mtpz
 import modelopt.torch.utils.distributed as dist
 
diff --git a/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py b/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py
new file mode 100644
index 0000000000..a31c687cc1
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py
@@ -0,0 +1,133 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for save_checkpoint_from_shards in checkpoint_utils_hf."""
+
+import json
+from functools import partial
+
+import pytest
+import torch
+from _test_utils.torch.distributed.utils import spawn_multiprocess_job
+from _test_utils.torch.transformers_models import get_tiny_llama
+from safetensors.torch import load_file as safe_load_file
+
+from modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor import (
+    LlamaModelDescriptor,
+)
+from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import (
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFETENSORS_SUBBLOCKS_DIR_NAME,
+    save_checkpoint_from_shards,
+)
+
+
+class TestSaveCheckpointFromShardsSingleProcess:
+    """Tests that run without torch.distributed (world_size=1 path)."""
+
+    def test_creates_config_index_and_subblocks(self, tmp_path):
+        model = get_tiny_llama()
+        expected_keys = set(model.state_dict().keys())
+        save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor)
+
+        # test safetensors index file exists and contains weight map
+        index_path = tmp_path / SAFE_WEIGHTS_INDEX_NAME
+        assert index_path.exists(), "safetensors index file was not written"
+        index = json.loads(index_path.read_text())
+        assert "weight_map" in index
+        assert set(index["weight_map"].keys()) == expected_keys
+
+        # test subblocks directory exists and contains shard files
+        subblocks_dir = tmp_path / SAFETENSORS_SUBBLOCKS_DIR_NAME
+        assert subblocks_dir.is_dir(), "subblocks directory was not created"
+        assert len(list(subblocks_dir.glob("*.safetensors"))) > 0, (
+            "no safetensors shard files were saved"
+        )
+
+        # test config.json saved
+        config_path = tmp_path / "config.json"
+        assert config_path.exists(), "config.json was not saved"
+        cfg = json.loads(config_path.read_text())
+        assert cfg["num_hidden_layers"] == model.config.num_hidden_layers
+
+        # test subblock filenames follow descriptor groups
+        filenames = set(index["weight_map"].values())
+        expected_substrings = {"embeddings", "lm_head", "block_0_ffn", "block_0_attention"}
+        for substr in expected_substrings:
+            assert any(substr in f for f in filenames), f"no shard filename contains '{substr}'"
+
+    def test_tie_word_embeddings_excluded(self, tmp_path):
+        model = get_tiny_llama(tie_word_embeddings=True)
+        save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor)
+
+        index = json.loads((tmp_path / SAFE_WEIGHTS_INDEX_NAME).read_text())
+        assert "lm_head.weight" not in index["weight_map"]
+
+        reloaded_sd = {}
+        for shard in (tmp_path / SAFETENSORS_SUBBLOCKS_DIR_NAME).glob("*.safetensors"):
+            reloaded_sd.update(safe_load_file(str(shard)))
+        assert "lm_head.weight" not in reloaded_sd
+
+    def test_saved_weights_match_original(self, tmp_path):
+        model = get_tiny_llama()
+        original_sd = {k: v.clone().cpu() for k, v in model.state_dict().items()}
+        save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor)
+
+        reloaded_sd = {}
+        for shard in (tmp_path / SAFETENSORS_SUBBLOCKS_DIR_NAME).glob("*.safetensors"):
+            reloaded_sd.update(safe_load_file(str(shard)))
+
+        assert set(reloaded_sd.keys()) == set(original_sd.keys())
+        for key in original_sd:
+            torch.testing.assert_close(reloaded_sd[key], original_sd[key])
+
+
+def _distributed_save_worker(rank, world_size, checkpoint_dir):
+    """Worker that shards a model's state dict across ranks and saves."""
+    model = get_tiny_llama()
+    full_sd = model.state_dict()
+    keys = sorted(full_sd.keys())
+    per_rank = len(keys) // world_size
+    start = rank * per_rank
+    end = start + per_rank if rank < world_size - 1 else len(keys)
+    shard_keys = set(keys[start:end])
+
+    # Zero out keys this rank does not own to mimic a sharded state dict. Since every
+    # rank still reports all keys, this checks the gather/save path, not exact values.
+    for k in keys:
+        if k not in shard_keys:
+            full_sd[k] = torch.zeros_like(full_sd[k])
+
+    model.load_state_dict(full_sd)
+    save_checkpoint_from_shards(model, checkpoint_dir, LlamaModelDescriptor)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="need >=2 GPUs for multi-rank test")
+class TestSaveCheckpointFromShardsMultiProcess:
+    """Tests that exercise the distributed gather path (world_size > 1)."""
+
+    def test_distributed_save_creates_valid_checkpoint(self, tmp_path):
+        spawn_multiprocess_job(2, partial(_distributed_save_worker, checkpoint_dir=tmp_path))
+
+        index_path = tmp_path / SAFE_WEIGHTS_INDEX_NAME
+        assert index_path.exists()
+        index = json.loads(index_path.read_text())
+
+        model = get_tiny_llama()
+        expected_keys = set(model.state_dict().keys())
+        assert set(index["weight_map"].keys()) == expected_keys
+
+        shard_files = list((tmp_path / SAFETENSORS_SUBBLOCKS_DIR_NAME).glob("*.safetensors"))
+        assert len(shard_files) > 0
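
Reviewer note: a minimal sketch of how the new entry point is expected to be driven in a
multi-rank run. The torchrun launch line, the gloo backend, the tiny LlamaConfig, and the
output path are illustrative assumptions; only save_checkpoint_from_shards and
LlamaModelDescriptor come from the changes above.

# sketch_save_from_shards.py -- hypothetical driver, not part of this diff.
# Assumed launch: torchrun --nproc_per_node=2 sketch_save_from_shards.py
import torch.distributed as tdist
from transformers import LlamaConfig, LlamaForCausalLM

from modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor import (
    LlamaModelDescriptor,
)
from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_checkpoint_from_shards


def main():
    tdist.init_process_group(backend="gloo")  # assumed backend; NCCL on GPU setups
    # Stand-in model; real callers pass the model whose state_dict holds this rank's shard.
    config = LlamaConfig(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        vocab_size=256,
    )
    model = LlamaForCausalLM(config)
    # Every rank must enter this call: non-master ranks feed the gather_object collective,
    # rank 0 merges the shards and writes the checkpoint, and the broadcast + barrier
    # inside re-raise any rank-0 save failure on every rank.
    save_checkpoint_from_shards(model, "out/checkpoint", LlamaModelDescriptor)
    tdist.destroy_process_group()


if __name__ == "__main__":
    main()

Design note: state dicts are moved to CPU before gather_object, which keeps the collective
off device memory; the trade-off is that rank 0 must hold the fully merged state dict in
host RAM while writing.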