From 45c4f7008022b52dd9b1b36ae3938fbec0f23fba Mon Sep 17 00:00:00 2001 From: "coderabbitai[bot]" <136622811+coderabbitai[bot]@users.noreply.github.com> Date: Thu, 23 Apr 2026 23:07:56 +0000 Subject: [PATCH] fix: apply CodeRabbit auto-fixes Fixed 1 file(s) based on 1 unresolved review comment. Co-authored-by: CodeRabbit --- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index dfa6eb8233..bc9c61f06a 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -225,14 +225,18 @@ def save_checkpoint_from_shards( local_sd = {k: v.cpu() for k, v in model.state_dict().items()} if dist_utils.size() > 1: if dist_utils.is_master(): - gathered: list[dict] = [{}] * dist_utils.size() + gathered: list[dict] = [None] * dist_utils.size() tdist.gather_object(local_sd, gathered, dst=0) full_sd: dict[str, torch.Tensor] = {} for shard_sd in gathered: + if shard_sd is None: + continue full_sd.update(shard_sd) _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor) else: tdist.gather_object(local_sd, dst=0) + # Barrier ensures all ranks wait until file I/O completes before continuing + dist_utils.barrier() else: _save_checkpoint(model.config, local_sd, checkpoint_dir, descriptor) @@ -484,4 +488,4 @@ def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str dataclasses.asdict(conf) if dataclasses.is_dataclass(conf) else conf for conf in model_config.block_configs ] - model_config.save_pretrained(checkpoint_dir) + model_config.save_pretrained(checkpoint_dir) \ No newline at end of file