From 83571360c2a2202dd5521387e6059e943f52400f Mon Sep 17 00:00:00 2001 From: Daniel Korzekwa Date: Thu, 5 Mar 2026 09:52:47 -0800 Subject: [PATCH] Make activation scoring working Signed-off-by: Daniel Korzekwa --- .../activation_hooks/utils.py | 121 ++++------- .../score_pruning_activations.py | 2 +- modelopt/torch/puzzletron/puzzletron.py | 26 ++- .../torch/puzzletron/tools/robust_json.py | 5 + .../tools/sharded_checkpoint_utils.py | 205 +++++++++++++----- .../torch/puzzletron/tools/validate_model.py | 193 ++++++++--------- .../utils/validate_runtime_pipeline.py | 94 ++++++-- tests/gpu/torch/puzzletron/test_puzzletron.py | 51 ++--- 8 files changed, 405 insertions(+), 292 deletions(-) diff --git a/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py index ab7eed2ac..1b1485c71 100644 --- a/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py +++ b/modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py @@ -15,84 +15,57 @@ # mypy: ignore-errors """Provides a function to register activation hooks for a model. -Activation hooks are used to compute activation scores for pruning. -""" +Activation hooks are used to compute activation scores for pruning.""" -import re +from typing import Type -from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ( - ForwardHook, - IndependentChannelContributionHook, - IndependentKvHeadContributionHook, - IterativeChannelContributionHook, - LayerNormContributionHook, -) -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ForwardHook as ActivationsHook +from modelopt.torch.puzzletron.tools.logger import aprint def register_activation_hooks( - model: DeciLMForCausalLM, activation_hooks_kwargs: dict -) -> tuple[dict[str, ForwardHook], type[ForwardHook]]: - hook_class_map = { - "mlp.down_proj": { - "independent": IndependentChannelContributionHook, - "iterative": IterativeChannelContributionHook, - }, - "self_attn.o_proj": { - "independent_kv_head_contribution": IndependentKvHeadContributionHook, - }, - r"regex:experts\.\d+\.down_proj$": { # For MoE - "independent": IndependentChannelContributionHook, - }, - # TODO: maybe this is too generic, and we should have it specifically for - # input_layernorm and post_attention_layernorm; now it might select qk_norms - "layernorm": { - "layer_norm_contribution": LayerNormContributionHook, - }, - } - - activation_hooks = {} - target_layer = activation_hooks_kwargs.get("target_layer", "mlp.c_proj") - - if target_layer.startswith("regex:"): - target_layer_regex = target_layer[len("regex:") :] - pattern = re.compile(target_layer_regex) - - def match_predicate(module_name, module): - return pattern.search(module_name) - else: - - def match_predicate(module_name, module): - return module_name.endswith(target_layer) - - target_layer_hooks_map = hook_class_map.get(target_layer) - if target_layer_hooks_map is None: - raise ValueError(f"no hook classes found for: {target_layer}") - - hook_class = target_layer_hooks_map.get(activation_hooks_kwargs["method"]) - if hook_class is None: - raise ValueError(f"Unknown hook class: {hook_class}") - - if target_layer == "block": - pattern = re.compile(r"^transformer\.h\.\d+$") - - def match_predicate(module_name, module): - return pattern.match(module_name) - + model, + activation_hooks_kwargs: dict, + pruning_mixin, + hook_class: Type[ActivationsHook], +) -> dict[str, ActivationsHook]: + """Register activation hooks using the pruning mixin approach. + + Args: + model: The model to register hooks on. + activation_hooks_kwargs: Keyword arguments passed to hook constructors. + pruning_mixin: The pruning mixin that defines which modules to hook. + hook_class: The hook class to instantiate for each module. + + Returns: + Dictionary mapping module names to hook instances. + """ activation_hooks_kwargs["model"] = model - for module_name, module in model.named_modules(): - if match_predicate(module_name, module): - block_config = None - if block_idx_match := re.search(r"\.(\d+)\.", module_name): - block_idx = int(block_idx_match.group(1)) - block_config = model.config.block_configs[block_idx] - curr_activation_hooks_kwargs = { - **activation_hooks_kwargs, - "block_config": block_config, - } - - hook = hook_class(module, curr_activation_hooks_kwargs) - module.register_forward_hook(hook) - activation_hooks[module_name] = hook - return activation_hooks, hook_class + if hook_class not in pruning_mixin.supported_hooks(): + raise ValueError( + f"Hook class not supported for {pruning_mixin.__class__.__name__}, " + f"must be in {pruning_mixin.supported_hooks()}" + ) + + module_names_to_hook = pruning_mixin.get_module_names_to_hook(model) + activation_hooks = dict() + for block_idx, module_name in module_names_to_hook: + block_config = None + if block_idx is not None: + block_config = model.config.block_configs[block_idx] + curr_activation_hooks_kwargs = { + **activation_hooks_kwargs, + "block_config": block_config, + } + + module = model.get_submodule(module_name) + hook = hook_class(module, curr_activation_hooks_kwargs) + module.register_forward_hook(hook) + activation_hooks[module_name] = hook + + if len(activation_hooks) == 0: + raise ValueError("couldn't find any hooks") + + aprint(f"Found the following hooks: {activation_hooks.keys()}") + return activation_hooks diff --git a/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py b/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py index ef5e5e9ad..c043c20d5 100644 --- a/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py +++ b/modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py @@ -138,4 +138,4 @@ def launch_score_activations(cfg: DictConfig): mprint("Starting pruning activation scoring...") # The checkpoint manager inside validate_model handles all progress tracking - validate_model(args=cfg.pruning, pipeline_parallel=True) + validate_model(args=cfg.pruning) diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py index 1051fdbaf..0d9ac068f 100644 --- a/modelopt/torch/puzzletron/puzzletron.py +++ b/modelopt/torch/puzzletron/puzzletron.py @@ -15,6 +15,7 @@ """This module provides the main compression function for a model using MIP-based NAS search algorithm.""" +import hydra from omegaconf import DictConfig import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations @@ -51,24 +52,25 @@ def puzzletron( f"dataset_path={dataset_path}", ], ) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) # Step 1: score_pruning_activations (distributed processing) score_pruning_activations.launch_score_activations(hydra_cfg) - # Step 2: pruning_ckpts (single process) - if dist.is_master(): - pruning_ckpts.launch_prune_ckpt(hydra_cfg) - dist.barrier() + # # Step 2: pruning_ckpts (single process) + # if dist.is_master(): + # pruning_ckpts.launch_prune_ckpt(hydra_cfg) + # dist.barrier() - # Step 4: build_library_and_stats (single process) - if dist.is_master(): - build_library_and_stats.launch_build_library_and_stats(hydra_cfg) - dist.barrier() + # # Step 4: build_library_and_stats (single process) + # if dist.is_master(): + # build_library_and_stats.launch_build_library_and_stats(hydra_cfg) + # dist.barrier() - # Step 5: calc_one_block_scores (distributed processing) - scoring.launch_scoring(hydra_cfg) + # # Step 5: calc_one_block_scores (distributed processing) + # scoring.launch_scoring(hydra_cfg) - # Step 6: mip_and_realize_models (distributed processing) - mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) + # # Step 6: mip_and_realize_models (distributed processing) + # mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) return hydra_cfg diff --git a/modelopt/torch/puzzletron/tools/robust_json.py b/modelopt/torch/puzzletron/tools/robust_json.py index dbb561b82..3397de639 100644 --- a/modelopt/torch/puzzletron/tools/robust_json.py +++ b/modelopt/torch/puzzletron/tools/robust_json.py @@ -50,8 +50,13 @@ def default(self, o): # User-defined function in main — fallback to just the name return o.__name__ return f"{o.__module__}.{o.__qualname__}" + if inspect.isclass(o): + return f"{o.__module__}.{o.__qualname__}" if isinstance(o, datetime.timedelta): return str(o) + # Fallback for arbitrary objects: return their class path + if hasattr(o, "__class__") and hasattr(o.__class__, "__module__"): + return f"{o.__class__.__module__}.{o.__class__.__qualname__}" return super().default(o) diff --git a/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py index 1cb5e8489..1cf02dc93 100644 --- a/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py @@ -14,22 +14,30 @@ # limitations under the License. # mypy: ignore-errors -"""Provides utilities for distributed loading, saving, and manipulation of +""" +Provides utilities for distributed loading, saving, and manipulation of large language model checkpoints across multiple GPUs/processes. + +Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations. """ import json from collections.abc import Iterable, Mapping from pathlib import Path -from typing import Literal, cast +from types import SimpleNamespace +from typing import Literal, Type, cast import numpy as np import torch import torch.distributed import torch.nn as nn +import transformers +from huggingface_hub import split_torch_state_dict_into_shards from safetensors import safe_open from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME from transformers.utils.hub import cached_file, get_checkpoint_shard_files from typing_extensions import override @@ -43,23 +51,18 @@ ) from modelopt.torch.puzzletron.tools.checkpoint_utils import load_model_config, load_state_dict from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.utils.dummy_modules import ( + DummyBlock, + DummyLMHead, + DummyModule, + DummyWTE, +) from modelopt.torch.puzzletron.utils.utils import EmptyInitOnDevice -class DummyModule(nn.Module): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.register_load_state_dict_post_hook(self.load_state_dict_post_hook) - - @staticmethod - def load_state_dict_post_hook( - module: torch.nn.Module, incompatible_keys: torch.nn.modules.module._IncompatibleKeys - ) -> None: - incompatible_keys.missing_keys.clear() - incompatible_keys.unexpected_keys.clear() +class DeciLMDummyBlock(DummyModule): + """Dummy block for DeciLM models (used by replacement_library).""" - -class DummyBlock(DummyModule): def __init__(self, config: DeciLMConfig, block_index: int): super().__init__() self.config = config @@ -73,7 +76,9 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor | tuple[torc return x, None -class DummyWTE(DummyModule): +class DeciLMDummyWTE(DummyModule): + """Dummy word token embedding for DeciLM models (used by replacement_library).""" + def __init__(self, config: DeciLMConfig, dtype: torch.dtype | None = None): super().__init__() self.n_embd = config.get_hidden_size() @@ -86,7 +91,9 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor: return result -class DummyLMHead(DummyModule): +class DeciLMDummyLMHead(DummyModule): + """Dummy LM head for DeciLM models (used by replacement_library).""" + def __init__(self, config: DeciLMConfig): super().__init__() self.vocab_size = config.vocab_size @@ -98,24 +105,44 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return result -def create_local_shard_(model: DeciLMForCausalLM, owned_block_indexes: set[int]): - all_block_indexes = set(range(len(model.model.layers))) +def set_submodule(model: nn.Module, module_name: str, new_submodule: nn.Module) -> None: + """Set a submodule on a model by dotted path.""" + parts = module_name.split(".") + parent_path = ".".join(parts[:-1]) + attr = parts[-1] + parent_module = model.get_submodule(parent_path) if parent_path else model + setattr(parent_module, attr, new_submodule) + + +def create_local_shard_(model, owned_block_indexes: set[int], descriptor, runtime): + all_block_indexes = set(range(model.config.num_hidden_layers)) has_first_block = 0 in owned_block_indexes has_last_block = max(all_block_indexes) in owned_block_indexes unowned_block_indexes = all_block_indexes - owned_block_indexes for block_index in unowned_block_indexes: - model.model.layers[block_index] = cast( - "DeciLMDecoderLayer", DummyBlock(model.config, block_index) + decoder_layer_name = descriptor.layer_block_name(block_index) + decoder_layer = model.get_submodule(decoder_layer_name) + set_submodule( + model, + decoder_layer_name, + descriptor.create_dummy_block(decoder_layer, block_index=block_index), ) - if not has_first_block: - model.set_input_embeddings(DummyWTE(model.config)) + # If we have the last block with tied embeddings, keep embed_tokens so lm_head works. + # load_sharded_state_dict will load embed_tokens.weight from the first shard's checkpoint file, + # and since they're tied, lm_head.weight gets populated too. + if not has_first_block and not (has_last_block and model.config.tie_word_embeddings): + set_submodule( + model, + descriptor.input_embedding_name(), + DummyWTE(model.config.hidden_size, dtype=runtime.dtype), + ) if not has_last_block: - model.model.set_final_layer_norm(nn.Identity()) + set_submodule(model, descriptor.final_norm_name(), nn.Identity()) if not (model.config.tie_word_embeddings and has_first_block): - model.set_output_embeddings(DummyLMHead(model.config)) + set_submodule(model, descriptor.output_embedding_name(), DummyLMHead(model.config)) return model @@ -130,42 +157,74 @@ def create_dummy_model( rope_cls = rope_type_to_class[model_config.position_embedding_type] model.model.rotary_emb = rope_cls(config=model.config) - model.model.set_input_embeddings(DummyWTE(model.config, dtype)) + model.model.set_input_embeddings(DeciLMDummyWTE(model.config, dtype)) model.model.set_final_layer_norm(nn.Identity()) - model.set_output_embeddings(DummyLMHead(model.config)) + model.set_output_embeddings(DeciLMDummyLMHead(model.config)) for block_index in range(model_config.get_num_hidden_layers()): - model.model.layers[block_index] = DummyBlock(model.config, block_index) + model.model.layers[block_index] = DeciLMDummyBlock(model.config, block_index) return model +def _get_model_class_from_config(config: PretrainedConfig): + """ + Get the model class from config.architectures field. + Works for any model registered in transformers (CausalLM, VL models, etc.). + Falls back to AutoModelForCausalLM if architectures is not available. + """ + if hasattr(config, "architectures") and config.architectures: + model_class_name = config.architectures[0] + if hasattr(transformers, model_class_name): + return getattr(transformers, model_class_name) + mprint( + f"Warning: {model_class_name} not found in transformers, falling back to AutoModelForCausalLM" + ) + return AutoModelForCausalLM + + def load_and_shard_model( + descriptor, checkpoint_path: str | Path, owned_block_indexes: set[int] | Literal["auto"] = "auto", - model_config: DeciLMConfig | None = None, - model_config_overrides: Mapping | None = None, - model_dtype: torch.dtype = torch.bfloat16, -) -> DeciLMForCausalLM: + model_config: PretrainedConfig | None = None, +): checkpoint_path = Path(checkpoint_path) - with torch.device(dist.local_rank()): + runtime = SimpleNamespace( + device=torch.device(dist.local_rank()), + dtype=torch.bfloat16, + global_rank=dist.rank(), + world_size=dist.size(), + is_main_process=dist.is_master(), + is_last_process=dist.is_last_process(), + use_autocast=True, # Default: use autocast; descriptor can override + ) + + with runtime.device: if model_config is None: - model_config = load_model_config( - checkpoint_path, model_config_overrides, ignore_unexpected_config_keys=True - ) + model_config = load_model_config(checkpoint_path) if owned_block_indexes == "auto": owned_block_indexes = set( - np.array_split(np.arange(model_config.get_num_hidden_layers()), dist.size())[ - dist.rank() + np.array_split(np.arange(model_config.num_hidden_layers), runtime.world_size)[ + runtime.global_rank ] ) mprint("Initializing model shards") - model_shard = create_sharded_model( - model_config=model_config, - owned_block_indexes=owned_block_indexes, - ) + # Pass block_configs explicitly so patcher works for VL models where + # decoder layers receive nested config (e.g., text_config) without block_configs + from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher + + with deci_x_patcher( + model_descriptor=descriptor, block_configs=getattr(model_config, "block_configs", None) + ): + model_shard = create_sharded_model( + runtime=runtime, + descriptor=descriptor, + model_config=model_config, + owned_block_indexes=owned_block_indexes, + ) if (checkpoint_path / SAFE_WEIGHTS_NAME).exists() or ( checkpoint_path / SAFE_WEIGHTS_INDEX_NAME @@ -178,27 +237,47 @@ def load_and_shard_model( shard_state_dict = load_sharded_state_dict( model_name_or_path=str(checkpoint_path), keys_to_load=shard_keys, - device=torch.device(dist.local_rank()), + device=runtime.device, ) new_names = set(shard_state_dict.keys()) mprint(f"{new_names=}") - model_shard.load_state_dict(shard_state_dict, assign=True) + # strict=False: allows missing lm_head.weight when tie_word_embeddings=True (e.g., Llama 3.2 3B) + model_shard.load_state_dict(shard_state_dict, strict=False, assign=True) del shard_state_dict - if model_config.tie_word_embeddings and (0 in owned_block_indexes): - # re-tie the weights in case the connection was severed + # Re-tie weights after load_state_dict with assign=True, which severs the tie. + # Needed on first rank (owns embed_tokens) and last rank (owns lm_head). + has_first_block = 0 in owned_block_indexes + has_last_block = (model_config.num_hidden_layers - 1) in owned_block_indexes + if model_config.tie_word_embeddings and (has_first_block or has_last_block): model_shard.tie_weights() + + # On the last rank with tied embeddings, we kept embed_tokens in create_local_shard_() + # just to load the weight and tie it to lm_head. Now replace it with a dummy so it + # doesn't interfere with the pipeline forward pass (only rank 0 should run embed_tokens). + if model_config.tie_word_embeddings and has_last_block and not has_first_block: + set_submodule( + model_shard, + descriptor.input_embedding_name(), + DummyWTE(model_config.hidden_size, dtype=runtime.dtype), + ) else: mprint("Loading state_dict in main process") - state_dict = load_state_dict(checkpoint_path) if dist.is_master() else None + state_dict = load_state_dict(checkpoint_path) if runtime.is_main_process else None mprint("Distributing model to shards") load_state_dict_to_shards(model_shard=model_shard, loaded_state_dict=state_dict) del state_dict - model_shard.type(model_dtype) + descriptor.init_rotary_embedding(model_shard, runtime) + + model_shard.type(runtime.dtype) + + # Configure autocast based on model descriptor (some models like Qwen3-VL MoE + # have dtype bugs under autocast) + runtime.use_autocast = descriptor.uses_autocast() params_on_meta_device = [ param_name @@ -206,14 +285,16 @@ def load_and_shard_model( if param.device == torch.device("meta") ] assert len(params_on_meta_device) == 0, ( - f"[global_rank={dist.rank()}] Couldn't load params {params_on_meta_device}" + f"[global_rank={runtime.global_rank}] Couldn't load params {params_on_meta_device}" ) return model_shard def create_sharded_model( - model_config: DeciLMConfig, + runtime, + descriptor, + model_config: PretrainedConfig, owned_block_indexes: set[int], device: str | torch.device | None = "meta", dtype: torch.dtype | None = torch.float32, @@ -224,14 +305,24 @@ def create_sharded_model( dist.barrier() with EmptyInitOnDevice(device="meta", dtype=dtype): - model = DeciLMForCausalLM(model_config) - create_local_shard_(model=model, owned_block_indexes=owned_block_indexes) + # Get model class from config.architectures (works for CausalLM, VL models, etc.) + model_class = _get_model_class_from_config(model_config) + # AutoModelForCausalLM uses from_config(); concrete model classes use _from_config() + if model_class is AutoModelForCausalLM: + model = model_class.from_config(model_config, trust_remote_code=True) + else: + model = model_class._from_config(model_config) + create_local_shard_( + model=model, + owned_block_indexes=owned_block_indexes, + descriptor=descriptor, + runtime=runtime, + ) if device != torch.device("meta"): local_shard_state_dict = { k: torch.empty_like(v, device=device) for k, v in model.state_dict().items() } - model.load_state_dict(local_shard_state_dict, assign=True) return model @@ -288,7 +379,9 @@ def load_state_dict_to_shards( def save_sharded_model( model_shard: torch.nn.Module | dict[str, torch.Tensor], out_path: str | Path ): - """out_path is usually output_checkpoint_path / "model.safetensors" """ + """ + out_path is usually output_checkpoint_path / "model.safetensors" + """ dist.barrier() if isinstance(model_shard, torch.nn.Module): @@ -346,7 +439,9 @@ def load_sharded_state_dict( keys_to_load: Iterable[str] | None = None, device: torch.device | str = "cpu", ) -> dict[str, torch.Tensor]: - """keys_to_load: entire state_dict if None, else partial state_dict containing only these keys""" + """ + keys_to_load: entire state_dict if None, else partial state_dict containing only these keys + """ shard_paths = _resolve_shard_paths(model_name_or_path) # print(f"shard_paths: {shard_paths}") partial_state_dict = {} diff --git a/modelopt/torch/puzzletron/tools/validate_model.py b/modelopt/torch/puzzletron/tools/validate_model.py index 6c3dc3640..cb8eb996d 100644 --- a/modelopt/torch/puzzletron/tools/validate_model.py +++ b/modelopt/torch/puzzletron/tools/validate_model.py @@ -12,42 +12,49 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"""Provides a function to validate a model. Runs a model forward pass on a dataset and calculates +# mypy: ignore-errors +""" +Provides a function to validate a model. Runs a model forward pass on a dataset and calculates the loss, and optionally registers hooks to capture the inputs and the outputs of pytorch modules that are used for activation scoring for pruning. TODO: Consider moving this a separate module dedicated for scoring + +Uses native HuggingFace models with deci_x_patcher for heterogeneous layer configurations. """ import textwrap from pathlib import Path +from typing import Type import torch from omegaconf import DictConfig from torch import nn from torch.utils.data import DataLoader -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - PreTrainedModel, - PreTrainedTokenizerBase, -) +from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron.activation_scoring.activation_hooks.utils import ( register_activation_hooks, ) -from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_checkpoint +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer.no_op import Same from modelopt.torch.puzzletron.tools.logger import aprint, mprint -from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import ( + load_and_shard_model, + set_submodule, +) from modelopt.torch.puzzletron.utils.data.dataloaders import create_validation_dataloader -from modelopt.torch.puzzletron.utils.parsing import simple_parse_args_string +from modelopt.torch.puzzletron.utils.parsing import ( + simple_parse_args_string, # noqa: F401 (kept for backwards compat) +) from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import ( HiddenStatesAndLMHead, calculate_losses_pipeline, ) -from modelopt.torch.puzzletron.utils.validation import calculate_losses """ Two goals: @@ -70,7 +77,6 @@ def validate_model( tokenizer: PreTrainedTokenizerBase | None = None, target_hidden_states_per_batch: list[torch.Tensor] | None = None, return_hidden_states: bool = False, - pipeline_parallel: bool = False, calculate_full_score_ablations: bool = False, val_dataloader: DataLoader | None = None, ) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: @@ -79,86 +85,80 @@ def validate_model( Args: args: Configuration object containing the following attributes: - Model Configuration attributes: - - - ``model_name_or_path`` (str): Path to model checkpoint or HuggingFace model name. - Required unless model is passed directly. - - ``model_dtype`` (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). - - ``autocast_dtype`` (str or torch.dtype): Autocast data type for mixed precision. - - Dataset Configuration attributes: - - - ``dataset_path`` (str): Path to the validation dataset. - - ``tokenizer_name`` (str, optional): Tokenizer name/path. Uses model_name_or_path if not specified. - - ``data_column`` (str): Column name in dataset containing text data. - - ``block_size`` (int): Maximum sequence length for tokenization. - - ``eval_samples`` (int, optional): Number of samples to evaluate. Uses all if None. - - ``val_dataset_name`` (str): Name of validation dataset split. - - ``source_datasets_to_discard`` (list[str], optional): List of source datasets to exclude. - - ``load_dataset_fn`` (callable, optional): Custom function to load the dataset. - - Data Processing attributes: - - - ``micro_batch_size`` (int): Batch size for evaluation. - - ``seed`` (int): Random seed for reproducibility. - - ``shuffle_seed`` (int, optional): Seed for shuffling data. Uses seed if None. - - ``varlen`` (bool): Enable variable-length sequences. - - ``bos_rate`` (float): Rate of adding BOS token. - - ``fim_rate`` (float): Fill-in-the-middle rate for code completion tasks. - - ``fim_spm_rate`` (float): SPM-based fill-in-the-middle rate. - - Activation Hooks attributes: - - - ``activations_log_dir`` (str, optional): Directory to log activation scores. - If provided, hooks will be registered to capture activations. - - ``activation_hooks_kwargs`` (str or dict, optional): Arguments for activation hooks. - If string, comma-separated format: "arg1=val1,arg2=val2". - - Execution Options attributes: - - - ``calc_losses_on_cpu`` (bool): Calculate losses on CPU to avoid OOM. Very slow, not recommended. - - ``write_results`` (bool): Write validation results to file. + Model Configuration: + - model_name_or_path (str): Path to model checkpoint or HuggingFace model name. + Required unless model is passed directly. + - model_dtype (str or torch.dtype): Model data type (e.g., "torch.bfloat16", torch.float16). + - autocast_dtype (str or torch.dtype): Autocast data type for mixed precision. + + Dataset Configuration: + - dataset_path (str): Path to the validation dataset. + - tokenizer_name (str, optional): Tokenizer name/path. Uses model_name_or_path if not specified. + - data_column (str): Column name in dataset containing text data. + - block_size (int): Maximum sequence length for tokenization. + - eval_samples (int, optional): Number of samples to evaluate. Uses all if None. + - val_dataset_name (str): Name of validation dataset split. + - source_datasets_to_discard (list[str], optional): List of source datasets to exclude. + - load_dataset_fn (callable, optional): Custom function to load the dataset. + + Data Processing: + - micro_batch_size (int): Batch size for evaluation. + - seed (int): Random seed for reproducibility. + - shuffle_seed (int, optional): Seed for shuffling data. Uses seed if None. + - varlen (bool): Enable variable-length sequences. + - bos_rate (float): Rate of adding BOS token. + - fim_rate (float): Fill-in-the-middle rate for code completion tasks. + - fim_spm_rate (float): SPM-based fill-in-the-middle rate. + + Activation Hooks: + - activations_log_dir (str, optional): Directory to log activation scores. If provided, + hooks will be registered to capture activations. + - activation_hooks_kwargs (str or dict, optional): Arguments for activation hooks. + If string, comma-separated format: "arg1=val1,arg2=val2". + + Execution Options: + - calc_losses_on_cpu (bool): Calculate losses on CPU to avoid OOM. Very slow, not recommended. + - write_results (bool): Write validation results to file. model: Pre-loaded model. If None, will be loaded from args.model_name_or_path. tokenizer: Pre-loaded tokenizer. If None, will be loaded based on args. target_hidden_states_per_batch: Target hidden states for pipeline parallel evaluation. return_hidden_states: Whether to return hidden states from the model. - pipeline_parallel: Enable pipeline parallelism for large models. calculate_full_score_ablations: Calculate comprehensive teacher similarity scores. - False calculates only a small suite for efficiency. + False calculates only a small suite for efficiency. val_dataloader: Pre-created validation dataloader. If None, will be created from args. Returns: A tuple containing: - - losses: Dictionary mapping loss names to loss statistics (avg, per_sample). - hidden_states_per_batch: Hidden states and LM head outputs if return_hidden_states is True, else None. - Returns (None, None) if not on master rank. """ + descriptor = ModelDescriptorFactory.get(args.descriptor) + if val_dataloader is None: val_dataloader = prepare_dataloader(args, tokenizer) if dist.is_master() else None validation_full_iters = ( args.eval_samples // args.micro_batch_size ) # model pipeline, single data rank - model = prepare_model(args, model, pipeline_parallel) + model = prepare_model(args, descriptor=descriptor, model=model) just_model_forward = False checkpoint_manager = None activation_hooks = None if args.activations_log_dir is not None: - activation_hooks_kwargs = ( - simple_parse_args_string(args.activation_hooks_kwargs) - if isinstance(args.activation_hooks_kwargs, str) - else args.activation_hooks_kwargs - ) + activation_hooks_kwargs = args.activation_hooks_kwargs or {} activation_hooks_kwargs["validation_full_iters"] = validation_full_iters + hook_class = args.hook_class - # Create activation hooks first - activation_hooks, hook_class = register_activation_hooks( - model=model, activation_hooks_kwargs=activation_hooks_kwargs + # Create activation hooks using pruning mixin + activation_hooks = register_activation_hooks( + model=model, + activation_hooks_kwargs=activation_hooks_kwargs, + hook_class=hook_class, + pruning_mixin=args.pruning_mixin, ) # Create checkpoint manager with hooks @@ -181,26 +181,23 @@ def validate_model( else: mprint("No checkpoint found, starting fresh") just_model_forward = True - model.lm_head = nn.Identity() - - if not pipeline_parallel: - losses, hidden_states_per_batch = calculate_losses( - model=model, - dataloader=val_dataloader, - checkpoint_manager=checkpoint_manager, - ) - else: - losses, hidden_states_per_batch = calculate_losses_pipeline( - stitched_model=model, - dataloader=val_dataloader, - target_hidden_states_per_batch=target_hidden_states_per_batch, - return_hidden_states=return_hidden_states, - calculate_full_score_ablations=calculate_full_score_ablations, - calc_on_cpu=args.calc_losses_on_cpu, - just_model_forward=just_model_forward, - checkpoint_manager=checkpoint_manager, - autocast_dtype=getattr(torch, args.autocast_dtype.strip("torch.")), - ) + set_submodule(model, descriptor.output_embedding_name(), Same()) + + losses, hidden_states_per_batch = calculate_losses_pipeline( + stitched_model=model, + dataloader=val_dataloader, + target_hidden_states_per_batch=target_hidden_states_per_batch, + return_hidden_states=return_hidden_states, + calculate_full_score_ablations=calculate_full_score_ablations, + calc_on_cpu=args.calc_losses_on_cpu, + just_model_forward=just_model_forward, + checkpoint_manager=checkpoint_manager, + autocast_dtype=getattr( + torch, getattr(args, "autocast_dtype", "torch.bfloat16").strip("torch.") + ), + descriptor=descriptor, + use_autocast=descriptor.uses_autocast(), + ) if losses is not None: avg_losses = {loss_name: loss_log["avg"] for loss_name, loss_log in losses.items()} @@ -224,31 +221,13 @@ def validate_model( def prepare_model( - args: DictConfig, model: PreTrainedModel | None = None, pipeline_parallel: bool = False + args: DictConfig, + descriptor: Type[ModelDescriptor], + model: PreTrainedModel | None = None, ) -> nn.Module: if model is None: assert args.model_name_or_path is not None - if pipeline_parallel: - model = load_and_shard_model( - args.model_name_or_path, - model_config_overrides={"block_size": args.block_size}, - model_dtype=getattr(torch, args.model_dtype.strip("torch.")), - ) - else: - try: - model = load_checkpoint( - args.model_name_or_path, - model_config_overrides={"block_size": args.block_size}, - ignore_unexpected_config_keys=True, - ) - model.to("cuda") - except FileNotFoundError: - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, - torch_dtype="auto", - device_map="auto", - trust_remote_code=True, - ) + model = load_and_shard_model(descriptor=descriptor, checkpoint_path=args.model_name_or_path) model.eval() return model diff --git a/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py b/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py index db1e8f2ce..90fea13c5 100644 --- a/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py +++ b/modelopt/torch/puzzletron/utils/validate_runtime_pipeline.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Model evaluation utilities for models split across multiple GPUs in pipeline-parallel mode. +""" +Model evaluation utilities for models split across multiple GPUs in pipeline-parallel mode. Coordinates forward passes and loss computation through model shards distributed across GPUs using sewing_kit's StitchedModule framework. Relies on validation.py for core loss computation. @@ -22,16 +23,18 @@ """ # mypy: ignore-errors +import traceback +from contextlib import nullcontext +from typing import Type + import numpy as np import torch from torch.utils.data import DataLoader from tqdm import tqdm import modelopt.torch.utils.distributed as dist -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( - DeciLMForCausalLM, - LMHead, -) +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import LMHead from modelopt.torch.puzzletron.sewing_kit import ( ExternalTarget, InputArgs, @@ -51,6 +54,23 @@ from modelopt.torch.puzzletron.utils.validation import _organize_outputs, calculate_batch_outputs +def _log_forward_error(e: Exception, rank: int, batch_idx: int, num_batches: int) -> None: + """Log detailed error info for distributed forward pass failures. + + When one rank crashes during distributed forward, others may hang waiting for communication. + This logging helps diagnose which rank failed and why. + """ + error_msg = ( + f"\n{'=' * 60}\n" + f"[Rank {rank}] ERROR in stitched_model forward (batch {batch_idx}/{num_batches})\n" + f"Error: {type(e).__name__}: {e}\n" + f"{'=' * 60}\n" + f"{traceback.format_exc()}" + f"{'=' * 60}\n" + ) + print(error_msg, flush=True) + + class HiddenStatesAndLMHead(list): def __init__(self, hidden_states: list[torch.Tensor], lm_head_weights: torch.Tensor): super().__init__(hidden_states) @@ -59,7 +79,7 @@ def __init__(self, hidden_states: list[torch.Tensor], lm_head_weights: torch.Ten @torch.no_grad() def calculate_losses_pipeline( - stitched_model: StitchedModule | DeciLMForCausalLM, + stitched_model: StitchedModule, dataloader: DataLoader | None, target_hidden_states_per_batch: HiddenStatesAndLMHead | None = None, return_hidden_states: bool = False, @@ -68,8 +88,11 @@ def calculate_losses_pipeline( just_model_forward: bool = False, checkpoint_manager=None, autocast_dtype: torch.dtype = torch.bfloat16, + descriptor: Type[ModelDescriptor] = None, + use_autocast: bool = True, ) -> tuple[dict[str, dict], HiddenStatesAndLMHead | None] | tuple[None, None]: - """Do model forward on each batch and calculate LM loss. + """ + Do model forward on each batch and calculate LM loss. Optionally also calculate kl_div loss and other metrics from given target_hidden_states_per_batch. Optionally return hidden states per batch. Does not support data-parallel. @@ -87,8 +110,8 @@ def calculate_losses_pipeline( target_hidden_states_per_batch: list[torch.Tensor], returned if return_hidden_states=True """ - if isinstance(stitched_model, DeciLMForCausalLM): - stitched_model = perform_pipeline_stitches(stitched_model) + if not isinstance(stitched_model, StitchedModule): + stitched_model = perform_pipeline_stitches(stitched_model, descriptor) params = list(stitched_model.parameters()) model_device = params[0].device if params else "cpu" @@ -145,14 +168,24 @@ def calculate_losses_pipeline( stitched_model.eval() - with torch.autocast(device_type="cuda", dtype=autocast_dtype): + # Use autocast for mixed precision, or nullcontext if disabled + # (some models like Qwen3-VL MoE have dtype bugs under autocast) + autocast_ctx = ( + torch.autocast(device_type="cuda", dtype=autocast_dtype) if use_autocast else nullcontext() + ) + with autocast_ctx: + fake_input_ids = fake_tensor(1, seq_len, dtype=torch.long, device=model_device) for i_batch in progress_bar: if dist.is_master(): input_ids = all_input_ids[i_batch].to(model_device) else: - input_ids = fake_tensor(1, seq_len, dtype=torch.long) + input_ids = fake_input_ids - output = stitched_model({}, {}, input_ids) + try: + output = stitched_model({}, {}, input_ids) + except Exception as e: + _log_forward_error(e, dist.rank(), i_batch, num_batches) + raise if dist.is_last_process(): logits = output.captured_outputs.get("model_output") @@ -183,6 +216,16 @@ def calculate_losses_pipeline( outputs.append(batch_outputs) + # Free GPU memory after processing each batch + del logits, hidden_states, targets + if target_hidden_states is not None: + del target_hidden_states + if target_logits is not None: + del target_logits + + # Free output tensor memory on all ranks + del output + # Update checkpoint progress periodically if checkpoint_manager: checkpoint_manager.update_progress(i_batch + 1, num_batches) @@ -200,13 +243,28 @@ def calculate_losses_pipeline( return losses, hidden_states_per_batch -def perform_pipeline_stitches(model: DeciLMForCausalLM) -> StitchedModule: +def perform_pipeline_stitches( + model, + descriptor: Type[ModelDescriptor], +) -> StitchedModule: + """Create pipeline stitches for distributed model evaluation. + + Args: + model: The model to stitch (any HuggingFace model with AnyModel descriptor). + descriptor: ModelDescriptor for layer naming. + """ target = ModuleTarget("module", model) stitcher = Needle() + num_layers = model.config.num_hidden_layers + is_real_block = np.flatnonzero( - [not isinstance(block, DummyBlock) for block in model.model.layers] + [ + not isinstance(model.get_submodule(descriptor.layer_block_name(i)), DummyBlock) + for i in range(num_layers) + ] ) + first_block, last_block = is_real_block.min(), is_real_block.max() if dist.rank() != 0: @@ -216,7 +274,7 @@ def perform_pipeline_stitches(model: DeciLMForCausalLM) -> StitchedModule: name="activations", adapter=lambda x: InputArgs(x) ), target.input( - name=f"model.layers.{first_block}", + name=descriptor.layer_block_name(first_block), reducer=InputReducer( lambda acc, override, orig, *args: override + orig.drop_args(0) ), @@ -226,17 +284,17 @@ def perform_pipeline_stitches(model: DeciLMForCausalLM) -> StitchedModule: if not dist.is_last_process(): # send activations to next rank stitcher.stitch( - target.output(f"model.layers.{last_block}"), + target.output(descriptor.layer_block_name(last_block)), RemoteTarget(peer_rank=dist.rank() + 1).value(name="activations"), ) else: # register model output stitcher.stitch( - target.output(name="lm_head"), + target.output(name=descriptor.output_embedding_name()), ExternalTarget().output("model_output"), ) stitcher.stitch( - target.output(name="model.norm"), + target.output(name=descriptor.final_norm_name()), ExternalTarget().output("hidden_states"), ) diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index 23a4b61c2..585567715 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -24,6 +24,7 @@ from _test_utils.torch.puzzletron.utils import setup_test_model_and_data import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron import puzzletron from modelopt.torch.puzzletron.anymodel import convert_model # The e2e test to compress a model based on Local Neural Architecture Search (Mixed Integer Programing NAS search) @@ -42,26 +43,26 @@ ), [ ("llama_3_1_8b_instruct", "llama", "llama_3_1_8b_instruct", None, False), - ("llama_3_2_3b_instruct", "llama", "llama_3_1_8b_instruct", None, False), - ("qwen2_5_7b_instruct", "qwen2", "qwen2_5_7b_instruct", None, False), - ( - "mistral-small-24b-instruct-2501", - "mistral_small", - "mistral-small-24b-instruct-2501", - None, - False, - ), - ("qwen3-8b", "qwen3", "qwen3-8b", None, False), - ("qwen3-vl-30b-a3b-instruct", "qwen3_vl", "qwen3-vl-30b-a3b-instruct", None, True), - ("nemotron-nano-12b-v2", "nemotron_h_v2", "nemotron-nano-12b-v2", "*-", False), - ( - "nemotron-3-nano-30b-a3b-base-bf16", - "nemotron_h", - "nemotron-3-nano-30b-a3b-base-bf16", - "*E", - True, - ), - ("gpt-oss-20b", "gpt_oss_20b", "gpt-oss-20b", None, True), + # ("llama_3_2_3b_instruct", "llama", "llama_3_1_8b_instruct", None, False), + # ("qwen2_5_7b_instruct", "qwen2", "qwen2_5_7b_instruct", None, False), + # ( + # "mistral-small-24b-instruct-2501", + # "mistral_small", + # "mistral-small-24b-instruct-2501", + # None, + # False, + # ), + # ("qwen3-8b", "qwen3", "qwen3-8b", None, False), + # ("qwen3-vl-30b-a3b-instruct", "qwen3_vl", "qwen3-vl-30b-a3b-instruct", None, True), + # ("nemotron-nano-12b-v2", "nemotron_h_v2", "nemotron-nano-12b-v2", "*-", False), + # ( + # "nemotron-3-nano-30b-a3b-base-bf16", + # "nemotron_h", + # "nemotron-3-nano-30b-a3b-base-bf16", + # "*E", + # True, + # ), + # ("gpt-oss-20b", "gpt_oss_20b", "gpt-oss-20b", None, True), ], ) def test_puzzletron( @@ -106,7 +107,7 @@ def _test_puzzletron_multiprocess_job( puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( project_root_path, tmp_path, rank, hf_config_name, hybrid_override_pattern ) - hydra_config_dir = ( # noqa: F841 + hydra_config_dir = ( project_root_path / f"tests/gpu/torch/puzzletron/resources/configs/{hydra_config_subdir}" ) @@ -120,10 +121,10 @@ def _test_puzzletron_multiprocess_job( dist.barrier() # TODO commented for the duration of merging process from dkorzekwa/any_model to feature/puzzletron - # # Compress the model using a one-click approach - # puzzletron.puzzletron( - # str(hydra_config_dir), hydra_config_subdir, str(puzzle_dir), str(dataset_path) - # ) + # Compress the model using a one-click approach + puzzletron.puzzletron( + str(hydra_config_dir), hydra_config_subdir, str(puzzle_dir), str(dataset_path) + ) # # # # Check assertions