49 changes: 49 additions & 0 deletions modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
@@ -29,12 +29,14 @@
from typing import TYPE_CHECKING, Any, BinaryIO

import torch
import torch.distributed as tdist
import transformers
from safetensors.torch import save_file as safe_save_file
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

import modelopt.torch.utils.distributed as dist_utils
from modelopt.torch.utils import json_dumps

from ..block_config import maybe_cast_block_configs
@@ -51,6 +53,7 @@
    "load_model_config",
    "init_model_from_config",
    "save_checkpoint",
    "save_checkpoint_from_shards",
    "save_subblocks",
    "save_model_config",
]
@@ -200,6 +203,52 @@ def save_checkpoint(
    _save_checkpoint(model.config, model.state_dict(), checkpoint_dir, descriptor)


def save_checkpoint_from_shards(
    model: PreTrainedModel, checkpoint_dir: Path | str, descriptor: "ModelDescriptor"
) -> None:
    """Save a checkpoint when the model's weights are sharded across distributed ranks.

    Gathers each rank's partial state dictionary onto rank 0 and writes a complete checkpoint
    (including the safetensors index and subblocks) from the merged weights. On a single-process
    run, saves directly from the local state dict. Only rank 0 performs the filesystem write;
    non-master ranks only participate in the gather.

    Parameters:
        model (PreTrainedModel): The model instance whose local state_dict contains this rank's
            shard of weights.
        checkpoint_dir (Path | str): Destination directory for the checkpoint files.
        descriptor (ModelDescriptor): Descriptor used to partition weights into subblocks and
            build the safetensors index.
    """
    local_sd = {k: v.cpu() for k, v in model.state_dict().items()}
    if dist_utils.size() > 1:
        save_err: str | None = None
        if dist_utils.is_master():
            gathered: list[dict] = [None] * dist_utils.size()
            tdist.gather_object(local_sd, gathered, dst=0)
            full_sd: dict[str, torch.Tensor] = {}
            for shard_sd in gathered:
                if shard_sd is None:
                    continue
                full_sd.update(shard_sd)
            try:
                _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor)
            except Exception as e:
                save_err = repr(e)
        else:
            tdist.gather_object(local_sd, dst=0)
        err_box = [save_err]
        tdist.broadcast_object_list(err_box, src=0)
        # Barrier ensures all ranks wait until file I/O completes before continuing
        dist_utils.barrier()
        if err_box[0] is not None:
            raise RuntimeError(f"Checkpoint save failed on rank 0: {err_box[0]}")
    else:
        _save_checkpoint(model.config, local_sd, checkpoint_dir, descriptor)


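For orientation, a minimal usage sketch (not part of the diff; the tiny config and the output directory are illustrative, borrowed from the test added below): every rank calls the function with its local shard, and only rank 0 ends up writing files.

    from transformers import AutoModelForCausalLM, LlamaConfig

    from modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor import (
        LlamaModelDescriptor,
    )
    from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_checkpoint_from_shards

    # Tiny model stands in for a real sharded solution; in a torchrun launch each rank
    # would hold only its shard of the weights and all ranks must enter the call.
    model = AutoModelForCausalLM.from_config(
        LlamaConfig(hidden_size=32, intermediate_size=64, num_hidden_layers=2,
                    num_attention_heads=4, num_key_value_heads=2, vocab_size=32)
    )
    save_checkpoint_from_shards(model, "/tmp/puzzle_ckpt", LlamaModelDescriptor)
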
def _save_checkpoint(
model_config: PretrainedConfig,
state_dict: dict[str, torch.Tensor],
@@ -41,7 +41,7 @@
from ..utils.validate_runtime_pipeline import perform_pipeline_stitches
from . import validate_model
from .checkpoint_utils import copy_tokenizer
-from .checkpoint_utils_hf import save_checkpoint
+from .checkpoint_utils_hf import save_checkpoint_from_shards
from .common import resolve_torch_dtype
from .sharded_checkpoint_utils import load_and_shard_model
from .validation_utils import (
@@ -189,7 +189,7 @@ def validate_puzzle_solutions(args: DictConfig) -> None:
# TODO: Look into internal Puzzletron code to see how to save as symlinks
# save_checkpoint_as_symlinks is currently not supported
pass
-save_checkpoint(model, checkpoint_dir, descriptor)
+save_checkpoint_from_shards(model, checkpoint_dir, descriptor)

copy_tokenizer(
args.tokenizer_name,
163 changes: 163 additions & 0 deletions tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py
@@ -0,0 +1,163 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for save_checkpoint_from_shards in checkpoint_utils_hf."""

import json
from functools import partial

import pytest
import torch
from _test_utils.torch.distributed.utils import spawn_multiprocess_job
from safetensors.torch import load_file as safe_load_file
from transformers import AutoModelForCausalLM, LlamaConfig

from modelopt.torch.puzzletron.anymodel.models.llama.llama_model_descriptor import (
    LlamaModelDescriptor,
)
from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFETENSORS_SUBBLOCKS_DIR_NAME,
    save_checkpoint_from_shards,
)

TINY_LLAMA_CONFIG = {
    "hidden_size": 32,
    "intermediate_size": 64,
    "num_hidden_layers": 2,
    "num_attention_heads": 4,
    "num_key_value_heads": 2,
    "max_position_embeddings": 32,
    "vocab_size": 32,
    "tie_word_embeddings": False,
}


def _make_tiny_llama(**overrides) -> AutoModelForCausalLM:
    cfg = {**TINY_LLAMA_CONFIG, **overrides}
    return AutoModelForCausalLM.from_config(LlamaConfig(**cfg))
Review comment on lines +36 to +50 (kevalmorabia97, Apr 27, 2026): suggested replacing the
inline TINY_LLAMA_CONFIG and _make_tiny_llama helper with the shared test utility:

    from _test_utils.torch.transformers_models import get_tiny_llama
    ...
    model = get_tiny_llama(num_hidden_layers=2, tie_word_embeddings=False)

Review comment (resolved by kevalmorabia97): many of these tests could be combined into one to
avoid creating and saving the model again and again separately.


class TestSaveCheckpointFromShardsSingleProcess:
    """Tests that run without torch.distributed (world_size=1 path)."""

    def test_creates_index_and_subblocks(self, tmp_path):
        model = _make_tiny_llama()
        save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor)

        index_path = tmp_path / SAFE_WEIGHTS_INDEX_NAME
        assert index_path.exists(), "safetensors index file was not written"
        index = json.loads(index_path.read_text())
        assert "weight_map" in index

        subblocks_dir = tmp_path / SAFETENSORS_SUBBLOCKS_DIR_NAME
        assert subblocks_dir.is_dir(), "subblocks directory was not created"

        shard_files = list(subblocks_dir.glob("*.safetensors"))
        assert len(shard_files) > 0, "no safetensors shard files were saved"

    def test_weight_map_covers_all_state_dict_keys(self, tmp_path):
        model = _make_tiny_llama()
        expected_keys = set(model.state_dict().keys())

        save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor)

        index = json.loads((tmp_path / SAFE_WEIGHTS_INDEX_NAME).read_text())
        mapped_keys = set(index["weight_map"].keys())
        assert mapped_keys == expected_keys

    def test_saved_weights_match_original(self, tmp_path):
        model = _make_tiny_llama()
        original_sd = {k: v.clone().cpu() for k, v in model.state_dict().items()}

        save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor)

        reloaded_sd = {}
        for shard in (tmp_path / SAFETENSORS_SUBBLOCKS_DIR_NAME).glob("*.safetensors"):
            reloaded_sd.update(safe_load_file(str(shard)))

        assert set(reloaded_sd.keys()) == set(original_sd.keys())
        for key in original_sd:
            torch.testing.assert_close(reloaded_sd[key], original_sd[key])

    def test_config_json_saved(self, tmp_path):
        model = _make_tiny_llama()
        save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor)

        config_path = tmp_path / "config.json"
        assert config_path.exists(), "config.json was not saved"
        cfg = json.loads(config_path.read_text())
        assert cfg["num_hidden_layers"] == TINY_LLAMA_CONFIG["num_hidden_layers"]

    def test_tie_word_embeddings_excluded(self, tmp_path):
        model = _make_tiny_llama(tie_word_embeddings=True)
        save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor)

        index = json.loads((tmp_path / SAFE_WEIGHTS_INDEX_NAME).read_text())
        assert "lm_head.weight" not in index["weight_map"]

        reloaded_sd = {}
        for shard in (tmp_path / SAFETENSORS_SUBBLOCKS_DIR_NAME).glob("*.safetensors"):
            reloaded_sd.update(safe_load_file(str(shard)))
        assert "lm_head.weight" not in reloaded_sd

    def test_subblock_filenames_follow_descriptor_groups(self, tmp_path):
        model = _make_tiny_llama()
        save_checkpoint_from_shards(model, tmp_path, LlamaModelDescriptor)

        index = json.loads((tmp_path / SAFE_WEIGHTS_INDEX_NAME).read_text())
        filenames = set(index["weight_map"].values())

        expected_substrings = {"embeddings", "lm_head", "block_0_ffn", "block_0_attention"}
        for substr in expected_substrings:
            assert any(substr in f for f in filenames), f"no shard filename contains '{substr}'"
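Picking up the review comment above about combining tests, one possible consolidation is sketched below (the fixture name and test class are hypothetical, not part of this PR): save the checkpoint once per class and let the individual assertions share it.

@pytest.fixture(scope="class")
def saved_checkpoint(tmp_path_factory):
    # Hypothetical helper: build and save the tiny model once per test class.
    ckpt_dir = tmp_path_factory.mktemp("ckpt")
    model = _make_tiny_llama()
    original_sd = {k: v.clone().cpu() for k, v in model.state_dict().items()}
    save_checkpoint_from_shards(model, ckpt_dir, LlamaModelDescriptor)
    return ckpt_dir, original_sd


class TestSavedCheckpointContents:
    def test_index_weights_and_config(self, saved_checkpoint):
        ckpt_dir, original_sd = saved_checkpoint
        index = json.loads((ckpt_dir / SAFE_WEIGHTS_INDEX_NAME).read_text())
        assert set(index["weight_map"]) == set(original_sd)
        assert (ckpt_dir / "config.json").exists()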


def _distributed_save_worker(rank, world_size, checkpoint_dir):
    """Worker that shards a model's state dict across ranks and saves."""
    model = _make_tiny_llama()
    full_sd = model.state_dict()
    keys = sorted(full_sd.keys())
    per_rank = len(keys) // world_size
    start = rank * per_rank
    end = start + per_rank if rank < world_size - 1 else len(keys)
    shard_keys = keys[start:end]

    # Zero out keys not owned by this rank; the gather on rank 0 then still covers every key.
    for k in keys:
        if k not in shard_keys:
            full_sd[k] = torch.zeros_like(full_sd[k])

    model.load_state_dict(full_sd)
    save_checkpoint_from_shards(model, checkpoint_dir, LlamaModelDescriptor)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="need >=2 GPUs for multi-rank test")
class TestSaveCheckpointFromShardsMultiProcess:
    """Tests that exercise the distributed gather path (world_size > 1)."""

    def test_distributed_save_creates_valid_checkpoint(self, tmp_path):
        spawn_multiprocess_job(2, partial(_distributed_save_worker, checkpoint_dir=tmp_path))

        index_path = tmp_path / SAFE_WEIGHTS_INDEX_NAME
        assert index_path.exists()
        index = json.loads(index_path.read_text())

        model = _make_tiny_llama()
        expected_keys = set(model.state_dict().keys())
        assert set(index["weight_map"].keys()) == expected_keys

        shard_files = list((tmp_path / SAFETENSORS_SUBBLOCKS_DIR_NAME).glob("*.safetensors"))
        assert len(shard_files) > 0
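
To run just the multi-rank case locally (assuming a machine with at least two GPUs), something along the lines of:

    pytest tests/gpu/torch/puzzletron/tools/test_save_ckpt_from_shards.py -k MultiProcess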