From f74ab048f836efda290664e7a479ded1a543f254 Mon Sep 17 00:00:00 2001 From: Grzegorz Karch Date: Thu, 23 Apr 2026 02:22:32 -0700 Subject: [PATCH 01/14] fix incomplete mapping of safetensors in generated puzzletron checkpoint Signed-off-by: Grzegorz Karch --- .../puzzletron/tools/checkpoint_utils_hf.py | 30 +++++++++++++++++++ ...validate_puzzle_with_multi_replacements.py | 4 +-- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index 69b8e5e29d..84650ebd2e 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -51,6 +51,7 @@ "load_model_config", "init_model_from_config", "save_checkpoint", + "save_checkpoint_from_shards", "save_subblocks", "save_model_config", ] @@ -200,6 +201,35 @@ def save_checkpoint( _save_checkpoint(model.config, model.state_dict(), checkpoint_dir, descriptor) +def save_checkpoint_from_shards( + model: PreTrainedModel, checkpoint_dir: Path | str, descriptor: "ModelDescriptor" +) -> None: + """Save a checkpoint whose weights are split across distributed ranks. + + Each rank holds only a subset of the model's layers (via ``load_and_shard_model``). + This function gathers every rank's partial state dict onto rank 0 so that + ``model.safetensors.index.json`` is built from the *complete* weight map. + Falls back to :func:`save_checkpoint` when running on a single process. 
+ """ + import modelopt.torch.utils.distributed as dist_utils + + local_sd = {k: v.cpu() for k, v in model.state_dict().items()} + if dist_utils.size() > 1: + import torch.distributed as tdist + + if dist_utils.is_master(): + gathered: list[dict] = [{}] * dist_utils.size() + tdist.gather_object(local_sd, gathered, dst=0) + full_sd: dict[str, torch.Tensor] = {} + for shard_sd in gathered: + full_sd.update(shard_sd) + _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor) + else: + tdist.gather_object(local_sd, dst=0) + else: + _save_checkpoint(model.config, local_sd, checkpoint_dir, descriptor) + + def _save_checkpoint( model_config: PretrainedConfig, state_dict: dict[str, torch.Tensor], diff --git a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py index d8471aee23..feda1f8aeb 100644 --- a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py +++ b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py @@ -41,7 +41,7 @@ from ..utils.validate_runtime_pipeline import perform_pipeline_stitches from . 
import validate_model from .checkpoint_utils import copy_tokenizer -from .checkpoint_utils_hf import save_checkpoint +from .checkpoint_utils_hf import save_checkpoint, save_checkpoint_from_shards from .common import resolve_torch_dtype from .sharded_checkpoint_utils import load_and_shard_model from .validation_utils import ( @@ -189,7 +189,7 @@ def validate_puzzle_solutions(args: DictConfig) -> None: # TODO: Loo into internal Puzzleron code to see how to save as symlinks # save_checkpoint_as_symlinks is currently not supported pass - save_checkpoint(model, checkpoint_dir, descriptor) + save_checkpoint_from_shards(model, checkpoint_dir, descriptor) copy_tokenizer( args.tokenizer_name, From 1e8e25afaa01387ca274506aae6de4579b56b01d Mon Sep 17 00:00:00 2001 From: Grzegorz Karch Date: Thu, 23 Apr 2026 05:00:14 -0700 Subject: [PATCH 02/14] moved imports to top of file Signed-off-by: Grzegorz Karch --- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index 84650ebd2e..adb639c2b9 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -29,12 +29,14 @@ from typing import TYPE_CHECKING, Any, BinaryIO import torch +import torch.distributed as tdist import transformers from safetensors.torch import save_file as safe_save_file from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel from transformers.dynamic_module_utils import get_class_from_dynamic_module from transformers.utils import SAFE_WEIGHTS_INDEX_NAME +import modelopt.torch.utils.distributed as dist_utils from modelopt.torch.utils import json_dumps from ..block_config import maybe_cast_block_configs @@ -211,12 +213,9 @@ def save_checkpoint_from_shards( ``model.safetensors.index.json`` is built from the *complete* weight 
map. Falls back to :func:`save_checkpoint` when running on a single process. """ - import modelopt.torch.utils.distributed as dist_utils local_sd = {k: v.cpu() for k, v in model.state_dict().items()} if dist_utils.size() > 1: - import torch.distributed as tdist - if dist_utils.is_master(): gathered: list[dict] = [{}] * dist_utils.size() tdist.gather_object(local_sd, gathered, dst=0) From 84d68a36d811c7930e41f9e705bce7e8f5a25a81 Mon Sep 17 00:00:00 2001 From: "coderabbitai[bot]" <136622811+coderabbitai[bot]@users.noreply.github.com> Date: Thu, 23 Apr 2026 22:41:57 +0200 Subject: [PATCH 03/14] =?UTF-8?q?=F0=9F=93=9D=20Add=20docstrings=20to=20`g?= =?UTF-8?q?karch/fix-incomplete-tensor-mapping`=20(#1331)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Docstrings generation was requested by @grzegorz-k-karch. * https://github.com/NVIDIA/Model-Optimizer/pull/1330#issuecomment-4303244743 The following files were modified: * `modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py` * `modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py`
ℹ️ Note
CodeRabbit cannot perform edits on its own pull requests yet.
--------- Signed-off-by: Grzegorz Karch Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: Grzegorz Karch --- .../puzzletron/tools/checkpoint_utils_hf.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index adb639c2b9..dfa6eb8233 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -206,12 +206,20 @@ def save_checkpoint( def save_checkpoint_from_shards( model: PreTrainedModel, checkpoint_dir: Path | str, descriptor: "ModelDescriptor" ) -> None: - """Save a checkpoint whose weights are split across distributed ranks. - - Each rank holds only a subset of the model's layers (via ``load_and_shard_model``). - This function gathers every rank's partial state dict onto rank 0 so that - ``model.safetensors.index.json`` is built from the *complete* weight map. - Falls back to :func:`save_checkpoint` when running on a single process. + """ + Save a checkpoint when the model's weights are sharded across distributed ranks. + + Gathers each rank's partial state dictionary onto rank 0 and writes a complete checkpoint + (including the safetensors index and subblocks) from the merged weights. On a single-process + run, saves directly from the local state dict. Only rank 0 performs the filesystem write; + non-master ranks only participate in the gather. + + Parameters: + model (PreTrainedModel): The model instance whose local state_dict contains this rank's + shard of weights. + checkpoint_dir (Path | str): Destination directory for the checkpoint files. + descriptor (ModelDescriptor): Descriptor used to partition weights into subblocks and build + the safetensors index. 
""" local_sd = {k: v.cpu() for k, v in model.state_dict().items()} From ff2afe18b74d11fa46167836767a086411940e91 Mon Sep 17 00:00:00 2001 From: "coderabbitai[bot]" <136622811+coderabbitai[bot]@users.noreply.github.com> Date: Fri, 24 Apr 2026 01:13:46 +0200 Subject: [PATCH 04/14] fix: CodeRabbit auto-fixes for PR #1330 (#1339) This stacked PR contains CodeRabbit auto-fixes for #1330. **Files modified:** - `modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py` Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: CodeRabbit --- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index dfa6eb8233..bc9c61f06a 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -225,14 +225,18 @@ def save_checkpoint_from_shards( local_sd = {k: v.cpu() for k, v in model.state_dict().items()} if dist_utils.size() > 1: if dist_utils.is_master(): - gathered: list[dict] = [{}] * dist_utils.size() + gathered: list[dict] = [None] * dist_utils.size() tdist.gather_object(local_sd, gathered, dst=0) full_sd: dict[str, torch.Tensor] = {} for shard_sd in gathered: + if shard_sd is None: + continue full_sd.update(shard_sd) _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor) else: tdist.gather_object(local_sd, dst=0) + # Barrier ensures all ranks wait until file I/O completes before continuing + dist_utils.barrier() else: _save_checkpoint(model.config, local_sd, checkpoint_dir, descriptor) @@ -484,4 +488,4 @@ def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str dataclasses.asdict(conf) if dataclasses.is_dataclass(conf) else conf for conf in model_config.block_configs ] - model_config.save_pretrained(checkpoint_dir) + 
model_config.save_pretrained(checkpoint_dir) \ No newline at end of file From f10fc53a7c410dcec090e984658b3dacac7f3925 Mon Sep 17 00:00:00 2001 From: Grzegorz Karch Date: Fri, 24 Apr 2026 01:20:31 +0200 Subject: [PATCH 05/14] new line at eof Signed-off-by: Grzegorz Karch --- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index bc9c61f06a..204d7a7737 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -488,4 +488,4 @@ def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str dataclasses.asdict(conf) if dataclasses.is_dataclass(conf) else conf for conf in model_config.block_configs ] - model_config.save_pretrained(checkpoint_dir) \ No newline at end of file + model_config.save_pretrained(checkpoint_dir) From 2136ab7bbbf7bd1f9d7d7b175bfcbeb36733562c Mon Sep 17 00:00:00 2001 From: "Grzegorz K. Karch" Date: Fri, 24 Apr 2026 09:41:56 +0200 Subject: [PATCH 06/14] Update modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Grzegorz K. 
Karch --- .../torch/puzzletron/tools/checkpoint_utils_hf.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index 204d7a7737..bc19cdf307 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -224,6 +224,8 @@ def save_checkpoint_from_shards( local_sd = {k: v.cpu() for k, v in model.state_dict().items()} if dist_utils.size() > 1: + if dist_utils.size() > 1: + save_err: str | None = None if dist_utils.is_master(): gathered: list[dict] = [None] * dist_utils.size() tdist.gather_object(local_sd, gathered, dst=0) @@ -232,11 +234,18 @@ def save_checkpoint_from_shards( if shard_sd is None: continue full_sd.update(shard_sd) - _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor) + try: + _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor) + except Exception as e: + save_err = repr(e) else: tdist.gather_object(local_sd, dst=0) + err_box = [save_err] + tdist.broadcast_object_list(err_box, src=0) # Barrier ensures all ranks wait until file I/O completes before continuing dist_utils.barrier() + if err_box[0] is not None: + raise RuntimeError(f"Checkpoint save failed on rank 0: {err_box[0]}") else: _save_checkpoint(model.config, local_sd, checkpoint_dir, descriptor) From 77b094dc92b03d914a23fa0c4777b00a81d30c2e Mon Sep 17 00:00:00 2001 From: Grzegorz Karch Date: Mon, 27 Apr 2026 02:20:31 -0700 Subject: [PATCH 07/14] fixed double condition line Signed-off-by: Grzegorz Karch --- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index bc19cdf307..1240d1c9b6 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ 
b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -223,7 +223,6 @@ def save_checkpoint_from_shards( """ local_sd = {k: v.cpu() for k, v in model.state_dict().items()} - if dist_utils.size() > 1: if dist_utils.size() > 1: save_err: str | None = None if dist_utils.is_master(): From 80a41ca2e5d3927b9f7e918d67ebc2ddfaa84037 Mon Sep 17 00:00:00 2001 From: Grzegorz Karch Date: Mon, 27 Apr 2026 03:08:04 -0700 Subject: [PATCH 08/14] added test for save_checkpoint_from_shards Signed-off-by: Grzegorz Karch --- ...validate_puzzle_with_multi_replacements.py | 2 +- .../tools/test_save_ckpt_from_shards.py | 167 ++++++++++++++++++ 2 files changed, 168 insertions(+), 1 deletion(-) create mode 100644 tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py diff --git a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py index feda1f8aeb..a46fba52d0 100644 --- a/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py +++ b/modelopt/torch/puzzletron/tools/validate_puzzle_with_multi_replacements.py @@ -41,7 +41,7 @@ from ..utils.validate_runtime_pipeline import perform_pipeline_stitches from . import validate_model from .checkpoint_utils import copy_tokenizer -from .checkpoint_utils_hf import save_checkpoint, save_checkpoint_from_shards +from .checkpoint_utils_hf import save_checkpoint_from_shards from .common import resolve_torch_dtype from .sharded_checkpoint_utils import load_and_shard_model from .validation_utils import ( diff --git a/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py b/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py new file mode 100644 index 0000000000..7bff957ca4 --- /dev/null +++ b/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for save_checkpoint_from_shards in checkpoint_utils_hf.""" + +import json +from functools import partial + +import pytest +import torch +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from safetensors.torch import load_file as safe_load_file +from transformers import AutoModelForCausalLM, LlamaConfig + +from modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor import ( + LlamaModelDescriptor, +) +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFETENSORS_SUBBLOCKS_DIR_NAME, + save_checkpoint_from_shards, +) + +TINY_LLAMA_CONFIG = dict( + hidden_size=32, + intermediate_size=64, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + max_position_embeddings=32, + vocab_size=32, + tie_word_embeddings=False, +) + + +def _make_tiny_llama(**overrides) -> AutoModelForCausalLM: + cfg = {**TINY_LLAMA_CONFIG, **overrides} + return AutoModelForCausalLM.from_config(LlamaConfig(**cfg)) + + +class TestSaveCheckpointFromShardsSingleProcess: + """Tests that run without torch.distributed (world_size=1 path).""" + + def test_creates_index_and_subblocks(self, tmp_path): + model = _make_tiny_llama() + save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) + + index_path = tmp_path / SAFE_WEIGHTS_INDEX_NAME + assert index_path.exists(), "safetensors index file was not 
written" + index = json.loads(index_path.read_text()) + assert "weight_map" in index + + subblocks_dir = tmp_path / SAFETENSORS_SUBBLOCKS_DIR_NAME + assert subblocks_dir.is_dir(), "subblocks directory was not created" + + shard_files = list(subblocks_dir.glob("*.safetensors")) + assert len(shard_files) > 0, "no safetensors shard files were saved" + + def test_weight_map_covers_all_state_dict_keys(self, tmp_path): + model = _make_tiny_llama() + expected_keys = set(model.state_dict().keys()) + + save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) + + index = json.loads((tmp_path / SAFE_WEIGHTS_INDEX_NAME).read_text()) + mapped_keys = set(index["weight_map"].keys()) + assert mapped_keys == expected_keys + + def test_saved_weights_match_original(self, tmp_path): + model = _make_tiny_llama() + original_sd = {k: v.clone().cpu() for k, v in model.state_dict().items()} + + save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) + + reloaded_sd = {} + for shard in (tmp_path / SAFETENSORS_SUBBLOCKS_DIR_NAME).glob("*.safetensors"): + reloaded_sd.update(safe_load_file(str(shard))) + + assert set(reloaded_sd.keys()) == set(original_sd.keys()) + for key in original_sd: + torch.testing.assert_close(reloaded_sd[key], original_sd[key]) + + def test_config_json_saved(self, tmp_path): + model = _make_tiny_llama() + save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) + + config_path = tmp_path / "config.json" + assert config_path.exists(), "config.json was not saved" + cfg = json.loads(config_path.read_text()) + assert cfg["num_hidden_layers"] == TINY_LLAMA_CONFIG["num_hidden_layers"] + + def test_tie_word_embeddings_excluded(self, tmp_path): + model = _make_tiny_llama(tie_word_embeddings=True) + save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) + + index = json.loads((tmp_path / SAFE_WEIGHTS_INDEX_NAME).read_text()) + assert "lm_head.weight" not in index["weight_map"] + + reloaded_sd = {} + for shard in (tmp_path / 
SAFETENSORS_SUBBLOCKS_DIR_NAME).glob("*.safetensors"): + reloaded_sd.update(safe_load_file(str(shard))) + assert "lm_head.weight" not in reloaded_sd + + def test_subblock_filenames_follow_descriptor_groups(self, tmp_path): + model = _make_tiny_llama() + save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) + + index = json.loads((tmp_path / SAFE_WEIGHTS_INDEX_NAME).read_text()) + filenames = set(index["weight_map"].values()) + + expected_substrings = {"embeddings", "lm_head", "block_0_ffn", "block_0_attention"} + for substr in expected_substrings: + assert any( + substr in f for f in filenames + ), f"no shard filename contains '{substr}'" + + +def _distributed_save_worker(rank, world_size, checkpoint_dir): + """Worker that shards a model's state dict across ranks and saves.""" + model = _make_tiny_llama() + full_sd = model.state_dict() + keys = sorted(full_sd.keys()) + per_rank = len(keys) // world_size + start = rank * per_rank + end = start + per_rank if rank < world_size - 1 else len(keys) + shard_keys = keys[start:end] + + # Zero out keys not owned by this rank so gather reconstructs the full dict. 
+ for k in keys: + if k not in shard_keys: + full_sd[k] = torch.zeros_like(full_sd[k]) + + model.load_state_dict(full_sd) + save_checkpoint_from_shards(model, checkpoint_dir, LlamaModelDescriptor) + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="need >=2 GPUs for multi-rank test" +) +class TestSaveCheckpointFromShardsMultiProcess: + """Tests that exercise the distributed gather path (world_size > 1).""" + + def test_distributed_save_creates_valid_checkpoint(self, tmp_path): + spawn_multiprocess_job(2, partial(_distributed_save_worker, checkpoint_dir=tmp_path)) + + index_path = tmp_path / SAFE_WEIGHTS_INDEX_NAME + assert index_path.exists() + index = json.loads(index_path.read_text()) + + model = _make_tiny_llama() + expected_keys = set(model.state_dict().keys()) + assert set(index["weight_map"].keys()) == expected_keys + + shard_files = list((tmp_path / SAFETENSORS_SUBBLOCKS_DIR_NAME).glob("*.safetensors")) + assert len(shard_files) > 0 From 830d2588965333cb21281c81597fd464d0a0c31a Mon Sep 17 00:00:00 2001 From: Grzegorz Karch Date: Mon, 27 Apr 2026 03:15:42 -0700 Subject: [PATCH 09/14] formatting Signed-off-by: Grzegorz Karch --- .../torch/puzzletron/tools/test_save_ckpt_from_shards.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py b/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py index 7bff957ca4..7306692ace 100644 --- a/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py +++ b/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py @@ -122,9 +122,7 @@ def test_subblock_filenames_follow_descriptor_groups(self, tmp_path): expected_substrings = {"embeddings", "lm_head", "block_0_ffn", "block_0_attention"} for substr in expected_substrings: - assert any( - substr in f for f in filenames - ), f"no shard filename contains '{substr}'" + assert any(substr in f for f in filenames), f"no shard filename contains '{substr}'" def 
_distributed_save_worker(rank, world_size, checkpoint_dir): @@ -146,9 +144,7 @@ def _distributed_save_worker(rank, world_size, checkpoint_dir): save_checkpoint_from_shards(model, checkpoint_dir, LlamaModelDescriptor) -@pytest.mark.skipif( - torch.cuda.device_count() < 2, reason="need >=2 GPUs for multi-rank test" -) +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="need >=2 GPUs for multi-rank test") class TestSaveCheckpointFromShardsMultiProcess: """Tests that exercise the distributed gather path (world_size > 1).""" From 604ea3b089a4b99ccb7b4fd2e2615d94ce8aef5a Mon Sep 17 00:00:00 2001 From: Grzegorz Karch Date: Mon, 27 Apr 2026 03:22:40 -0700 Subject: [PATCH 10/14] ruff feedback Signed-off-by: Grzegorz Karch --- .../tools/test_save_ckpt_from_shards.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py b/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py index 7306692ace..60ca086265 100644 --- a/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py +++ b/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py @@ -33,16 +33,16 @@ save_checkpoint_from_shards, ) -TINY_LLAMA_CONFIG = dict( - hidden_size=32, - intermediate_size=64, - num_hidden_layers=2, - num_attention_heads=4, - num_key_value_heads=2, - max_position_embeddings=32, - vocab_size=32, - tie_word_embeddings=False, -) +TINY_LLAMA_CONFIG = { + "hidden_size": 32, + "intermediate_size": 64, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "max_position_embeddings": 32, + "vocab_size": 32, + "tie_word_embeddings": False, +} def _make_tiny_llama(**overrides) -> AutoModelForCausalLM: From 53a8f752cec7f4af9185dfbf683eeae57a60674f Mon Sep 17 00:00:00 2001 From: "Grzegorz K. 
Karch" Date: Mon, 27 Apr 2026 14:13:35 +0200 Subject: [PATCH 11/14] Update tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Signed-off-by: Grzegorz K. Karch --- .../tools/test_save_ckpt_from_shards.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py b/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py index 60ca086265..ebb2c4a781 100644 --- a/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py +++ b/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py @@ -33,21 +33,9 @@ save_checkpoint_from_shards, ) -TINY_LLAMA_CONFIG = { - "hidden_size": 32, - "intermediate_size": 64, - "num_hidden_layers": 2, - "num_attention_heads": 4, - "num_key_value_heads": 2, - "max_position_embeddings": 32, - "vocab_size": 32, - "tie_word_embeddings": False, -} - - -def _make_tiny_llama(**overrides) -> AutoModelForCausalLM: - cfg = {**TINY_LLAMA_CONFIG, **overrides} - return AutoModelForCausalLM.from_config(LlamaConfig(**cfg)) +from _test_utils.torch.transformers_models import get_tiny_llama +... 
+model = get_tiny_llama(num_hidden_layers=2) class TestSaveCheckpointFromShardsSingleProcess: From 4d016be9033d775937cb4bf0e97bdc32d29ed46c Mon Sep 17 00:00:00 2001 From: Grzegorz Karch Date: Mon, 27 Apr 2026 05:16:07 -0700 Subject: [PATCH 12/14] response to feedback Signed-off-by: Grzegorz Karch --- .../tools/test_save_ckpt_from_shards.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py b/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py index ebb2c4a781..f9c57632e5 100644 --- a/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py +++ b/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py @@ -21,8 +21,8 @@ import pytest import torch from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.transformers_models import get_tiny_llama from safetensors.torch import load_file as safe_load_file -from transformers import AutoModelForCausalLM, LlamaConfig from modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor import ( LlamaModelDescriptor, @@ -33,16 +33,12 @@ save_checkpoint_from_shards, ) -from _test_utils.torch.transformers_models import get_tiny_llama -... 
-model = get_tiny_llama(num_hidden_layers=2) - class TestSaveCheckpointFromShardsSingleProcess: """Tests that run without torch.distributed (world_size=1 path).""" def test_creates_index_and_subblocks(self, tmp_path): - model = _make_tiny_llama() + model = get_tiny_llama() save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) index_path = tmp_path / SAFE_WEIGHTS_INDEX_NAME @@ -57,7 +53,7 @@ def test_creates_index_and_subblocks(self, tmp_path): assert len(shard_files) > 0, "no safetensors shard files were saved" def test_weight_map_covers_all_state_dict_keys(self, tmp_path): - model = _make_tiny_llama() + model = get_tiny_llama() expected_keys = set(model.state_dict().keys()) save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) @@ -67,7 +63,7 @@ def test_weight_map_covers_all_state_dict_keys(self, tmp_path): assert mapped_keys == expected_keys def test_saved_weights_match_original(self, tmp_path): - model = _make_tiny_llama() + model = get_tiny_llama() original_sd = {k: v.clone().cpu() for k, v in model.state_dict().items()} save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) @@ -81,16 +77,16 @@ def test_saved_weights_match_original(self, tmp_path): torch.testing.assert_close(reloaded_sd[key], original_sd[key]) def test_config_json_saved(self, tmp_path): - model = _make_tiny_llama() + model = get_tiny_llama() save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) config_path = tmp_path / "config.json" assert config_path.exists(), "config.json was not saved" cfg = json.loads(config_path.read_text()) - assert cfg["num_hidden_layers"] == TINY_LLAMA_CONFIG["num_hidden_layers"] + assert cfg["num_hidden_layers"] == get_tiny_llama().config.num_hidden_layers def test_tie_word_embeddings_excluded(self, tmp_path): - model = _make_tiny_llama(tie_word_embeddings=True) + model = get_tiny_llama(tie_word_embeddings=True) save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) index = json.loads((tmp_path / 
SAFE_WEIGHTS_INDEX_NAME).read_text()) @@ -102,7 +98,7 @@ def test_tie_word_embeddings_excluded(self, tmp_path): assert "lm_head.weight" not in reloaded_sd def test_subblock_filenames_follow_descriptor_groups(self, tmp_path): - model = _make_tiny_llama() + model = get_tiny_llama() save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) index = json.loads((tmp_path / SAFE_WEIGHTS_INDEX_NAME).read_text()) @@ -115,7 +111,7 @@ def test_subblock_filenames_follow_descriptor_groups(self, tmp_path): def _distributed_save_worker(rank, world_size, checkpoint_dir): """Worker that shards a model's state dict across ranks and saves.""" - model = _make_tiny_llama() + model = get_tiny_llama() full_sd = model.state_dict() keys = sorted(full_sd.keys()) per_rank = len(keys) // world_size @@ -143,7 +139,7 @@ def test_distributed_save_creates_valid_checkpoint(self, tmp_path): assert index_path.exists() index = json.loads(index_path.read_text()) - model = _make_tiny_llama() + model = get_tiny_llama() expected_keys = set(model.state_dict().keys()) assert set(index["weight_map"].keys()) == expected_keys From 74d8ebc2e8b9d65b9838e71e17e325801fdbe2f0 Mon Sep 17 00:00:00 2001 From: Grzegorz Karch Date: Mon, 27 Apr 2026 05:48:02 -0700 Subject: [PATCH 13/14] reduced number of tests for saving shards Signed-off-by: Grzegorz Karch --- .../tools/test_save_ckpt_from_shards.py | 61 +++++++------------ 1 file changed, 23 insertions(+), 38 deletions(-) diff --git a/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py b/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py index f9c57632e5..a31c687cc1 100644 --- a/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py +++ b/tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py @@ -37,54 +37,37 @@ class TestSaveCheckpointFromShardsSingleProcess: """Tests that run without torch.distributed (world_size=1 path).""" - def test_creates_index_and_subblocks(self, tmp_path): + def 
test_creates_config_index_and_subblocks(self, tmp_path): model = get_tiny_llama() + expected_keys = set(model.state_dict().keys()) save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) + # test safetensors index file exists and contains weight map index_path = tmp_path / SAFE_WEIGHTS_INDEX_NAME assert index_path.exists(), "safetensors index file was not written" index = json.loads(index_path.read_text()) assert "weight_map" in index + assert set(index["weight_map"].keys()) == expected_keys + # test subblocks directory exists and contains shard files subblocks_dir = tmp_path / SAFETENSORS_SUBBLOCKS_DIR_NAME assert subblocks_dir.is_dir(), "subblocks directory was not created" + assert len(list(subblocks_dir.glob("*.safetensors"))) > 0, ( + "no safetensors shard files were saved" + ) - shard_files = list(subblocks_dir.glob("*.safetensors")) - assert len(shard_files) > 0, "no safetensors shard files were saved" - - def test_weight_map_covers_all_state_dict_keys(self, tmp_path): - model = get_tiny_llama() - expected_keys = set(model.state_dict().keys()) - - save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) - - index = json.loads((tmp_path / SAFE_WEIGHTS_INDEX_NAME).read_text()) - mapped_keys = set(index["weight_map"].keys()) - assert mapped_keys == expected_keys - - def test_saved_weights_match_original(self, tmp_path): - model = get_tiny_llama() - original_sd = {k: v.clone().cpu() for k, v in model.state_dict().items()} - - save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) - - reloaded_sd = {} - for shard in (tmp_path / SAFETENSORS_SUBBLOCKS_DIR_NAME).glob("*.safetensors"): - reloaded_sd.update(safe_load_file(str(shard))) - - assert set(reloaded_sd.keys()) == set(original_sd.keys()) - for key in original_sd: - torch.testing.assert_close(reloaded_sd[key], original_sd[key]) - - def test_config_json_saved(self, tmp_path): - model = get_tiny_llama() - save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) - + # 
test config.json saved config_path = tmp_path / "config.json" assert config_path.exists(), "config.json was not saved" cfg = json.loads(config_path.read_text()) assert cfg["num_hidden_layers"] == get_tiny_llama().config.num_hidden_layers + # test subblock filenames follow descriptor groups + filenames = set(index["weight_map"].values()) + expected_substrings = {"embeddings", "lm_head", "block_0_ffn", "block_0_attention"} + for substr in expected_substrings: + assert any(substr in f for f in filenames), f"no shard filename contains '{substr}'" + def test_tie_word_embeddings_excluded(self, tmp_path): model = get_tiny_llama(tie_word_embeddings=True) save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) @@ -97,16 +80,18 @@ def test_tie_word_embeddings_excluded(self, tmp_path): reloaded_sd.update(safe_load_file(str(shard))) assert "lm_head.weight" not in reloaded_sd - def test_subblock_filenames_follow_descriptor_groups(self, tmp_path): + def test_saved_weights_match_original(self, tmp_path): model = get_tiny_llama() + original_sd = {k: v.clone().cpu() for k, v in model.state_dict().items()} save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor) - index = json.loads((tmp_path / SAFE_WEIGHTS_INDEX_NAME).read_text()) - filenames = set(index["weight_map"].values()) + reloaded_sd = {} + for shard in (tmp_path / SAFETENSORS_SUBBLOCKS_DIR_NAME).glob("*.safetensors"): + reloaded_sd.update(safe_load_file(str(shard))) - expected_substrings = {"embeddings", "lm_head", "block_0_ffn", "block_0_attention"} - for substr in expected_substrings: - assert any(substr in f for f in filenames), f"no shard filename contains '{substr}'" + assert set(reloaded_sd.keys()) == set(original_sd.keys()) + for key in original_sd: + torch.testing.assert_close(reloaded_sd[key], original_sd[key]) def _distributed_save_worker(rank, world_size, checkpoint_dir): From 238158d8c8e697b7a11c5250b2a03bac074985b2 Mon Sep 17 00:00:00 2001 From: Keval Morabia 
<28916987+kevalmorabia97@users.noreply.github.com> Date: Mon, 27 Apr 2026 06:51:09 -0700 Subject: [PATCH 14/14] Add other minor puzzletron fixes Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .../anymodel/model_descriptor/model_descriptor_factory.py | 2 +- tests/gpu/torch/puzzletron/test_puzzletron.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py index 74aaf311bf..cff972a51e 100644 --- a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor_factory.py @@ -33,7 +33,7 @@ "qwen3": "qwen3", "nemotron_h": "nemotron_h", "nemotron_h_v2": "nemotron_h_v2", - "gpt_oss_20b": "gpt_oss_20b", + "gpt_oss": "gpt_oss", } diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index a393e1e086..d44cbc71e9 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -25,11 +25,6 @@ from _test_utils.torch.puzzletron.utils import setup_test_model_and_data from packaging.version import Version -# The puzzletron pipeline imports mip unconditionally at module level. In NeMo containers -# the [puzzletron] extras are not pre-installed, so importing the test file fails with a -# deep ModuleNotFoundError. Skip early with an actionable message instead. -pytest.importorskip("mip", reason="pip install -e '.[puzzletron]' to install MIP solver") - import modelopt.torch.puzzletron as mtpz import modelopt.torch.utils.distributed as dist