diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py index 0d9ac068f..94a1de57e 100644 --- a/modelopt/torch/puzzletron/puzzletron.py +++ b/modelopt/torch/puzzletron/puzzletron.py @@ -57,10 +57,10 @@ def puzzletron( # Step 1: score_pruning_activations (distributed processing) score_pruning_activations.launch_score_activations(hydra_cfg) - # # Step 2: pruning_ckpts (single process) - # if dist.is_master(): - # pruning_ckpts.launch_prune_ckpt(hydra_cfg) - # dist.barrier() + # Step 2: pruning_ckpts (single process) + if dist.is_master(): + pruning_ckpts.launch_prune_ckpt(hydra_cfg) + dist.barrier() # # Step 4: build_library_and_stats (single process) # if dist.is_master(): diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py index 46e403c5f..36e41c4b6 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/init_child_from_parent.py @@ -14,15 +14,22 @@ # limitations under the License. 
# mypy: ignore-errors -"""TODO Add description""" +"""Initialize child models from parent models using AnyModel approach with deci_x_patcher.""" import json import time +from pathlib import Path +from typing import Optional import torch import yaml +from transformers import AutoModelForCausalLM -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( GQAInitMode, HiddenSizeInitMode, @@ -31,85 +38,37 @@ create_child_state_dict, update_model_config, ) -from modelopt.torch.puzzletron.tools.checkpoint_utils import ( - copy_tokenizer, - load_model_config, - load_state_dict, -) +from modelopt.torch.puzzletron.tools.checkpoint_utils import copy_tokenizer, load_state_dict from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import ( _save_checkpoint, copy_deci_lm_hf_code, + load_model_config, ) from modelopt.torch.puzzletron.tools.logger import mprint - -""" - -Usage example - remove all/some routed experts: -=============================================== - -PARENT_DIR=".../meta-llama/Llama-4-Scout-17B-16E-Instruct--deci-hf" - -MLP_INIT_MODE="ConcatExpertsIntoDenseFFN" - -## remove all routed experts, turn the shared expert into a dense FFN -# OUTPUT_DIR="/.../micro_scout/Scout-remove-routed-experts" -# MODEL_CONFIG_OVERRIDES_JSON=' -# { -# "ffn": [ -# { -# "moe": null, -# "intermediate_size": 14336, -# "gated": true, -# "hidden_act": "silu" -# } -# ] -# } -# ' - -## concat the shared expert with one routed expert into a dense FFN -OUTPUT_DIR=".../scratch/micro_scout/Scout-ConcatExpertsIntoDenseFFN-concat-shared-and-3-routed" -MODEL_CONFIG_OVERRIDES_JSON=' -{ - "ffn": [ - { - "moe": null, - "intermediate_size": 14336, - "gated": true, - "hidden_act": "silu" - } - ] -} -' 
- -echo "" -echo "MODEL_CONFIG_OVERRIDES_JSON:" -echo "${MODEL_CONFIG_OVERRIDES_JSON}" - -python -m modelopt.torch.puzzletron.tools.bypassed_training.init_child_from_parent \ - --parent_checkpoint_dir="$PARENT_DIR" \ - --model_config_overrides_json="$MODEL_CONFIG_OVERRIDES_JSON" \ - --output_checkpoint_dir="$OUTPUT_DIR" \ - --mlp_init_mode="$MLP_INIT_MODE" \ - --mlp_init_config_yaml="$MLP_INIT_CONFIG_YAML" -""" +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import _get_model_class_from_config def init_child_from_parent( + descriptor: ModelDescriptor, + pruning_mixin, parent_checkpoint_dir: str, - model_config_overrides_json: str, + model_config_overrides_dict: dict | str, output_checkpoint_dir: str, gqa_init_mode: GQAInitMode, mlp_init_mode: MlpInitMode, - mlp_init_config_yaml: str | None, + mlp_init_config_yaml: Optional[str], linear_init_mode: LinearInitMode, - hidden_size_init_mode: HiddenSizeInitMode | None = None, - channel_importance_path: str | None = None, - max_workers: int | None = None, # Auto-calculate optimal workers if None - max_layer_workers: int | None = None, # Auto-calculate optimal workers if None + hidden_size_init_mode: Optional[HiddenSizeInitMode] = None, + channel_importance_path: Optional[str] = None, + max_workers: Optional[int] = None, # Auto-calculate optimal workers if None + max_layer_workers: Optional[int] = None, # Auto-calculate optimal workers if None ) -> None: - """Init child models from parent models in the style of bypass training, + """ + Init child models from parent models in the style of bypass training, but without having to run the entire bypass pipeline. + Uses AnyModel approach with deci_x_patcher for heterogeneous layer configurations. 
+ I/O Optimization Parameters: - max_workers: Number of threads for parallel file I/O (default: auto-calculate min(CPU count, num files)) - max_layer_workers: Number of threads for parallel layer processing (default: auto-calculate min(CPU count, num layers)) @@ -123,16 +82,16 @@ def init_child_from_parent( "We do not support random init of any subblock in this script to avoid initializing the student model" ) + descriptor = ModelDescriptorFactory.get(descriptor) + copy_tokenizer(parent_checkpoint_dir, output_checkpoint_dir) parent_model_config = load_model_config(parent_checkpoint_dir) parent_state_dict = load_state_dict(parent_checkpoint_dir) - # Parse the model config overrides - if isinstance(model_config_overrides_json, str): - model_config_overrides_dict = json.loads(model_config_overrides_json) - else: - model_config_overrides_dict = model_config_overrides_json + # Parse JSON if string + if isinstance(model_config_overrides_dict, str): + model_config_overrides_dict = json.loads(model_config_overrides_dict) # Separate global config overrides from block-level overrides global_config_overrides = {} @@ -146,7 +105,7 @@ def init_child_from_parent( # Load child model config with global overrides child_model_config = load_model_config( - checkpoint_dir=parent_checkpoint_dir, + parent_checkpoint_dir, model_config_overrides=global_config_overrides, ignore_unexpected_config_keys=True, ) @@ -159,12 +118,23 @@ def init_child_from_parent( ) with torch.device("meta"): - child_model = DeciLMForCausalLM(child_model_config) + # Pass block_configs explicitly so patcher works for VL models where + # decoder layers receive nested config (e.g., text_config) without block_configs + with deci_x_patcher( + model_descriptor=descriptor, block_configs=child_model_config.block_configs + ): + model_class = _get_model_class_from_config(child_model_config) + # AutoModelForCausalLM uses from_config(); concrete model classes use _from_config() + if model_class is AutoModelForCausalLM: + 
child_model = model_class.from_config(child_model_config, trust_remote_code=True) + else: + child_model = model_class._from_config(child_model_config) + child_state_dict_with_meta_tensors = child_model.state_dict() mlp_init_config = ( yaml.safe_load(mlp_init_config_yaml) - if isinstance(mlp_init_config_yaml, str) is None + if isinstance(mlp_init_config_yaml, str) else mlp_init_config_yaml ) @@ -172,6 +142,8 @@ def init_child_from_parent( mprint("Starting create_child_state_dict...") start_time = time.time() child_state_dict = create_child_state_dict( + pruning_mixin=pruning_mixin, + descriptor=descriptor, original_state_dict=parent_state_dict, new_state_dict=child_state_dict_with_meta_tensors, original_config=parent_model_config, @@ -182,7 +154,7 @@ def init_child_from_parent( linear_init_mode=linear_init_mode, hidden_size_init_mode=hidden_size_init_mode or HiddenSizeInitMode.CopyAsIs, channel_importance_path=channel_importance_path, - max_layer_workers=max_layer_workers, # Will auto-calculate if None + max_layer_workers=max_layer_workers, ) create_child_state_dict_time = time.time() - start_time mprint(f"create_child_state_dict completed in {create_child_state_dict_time:.2f} seconds") @@ -196,7 +168,8 @@ def init_child_from_parent( child_model_config, child_state_dict, output_checkpoint_dir, - max_workers=max_workers, # Will auto-calculate if None + descriptor, + max_workers=max_workers, ) save_checkpoint_time = time.time() - start_time mprint(f"_save_checkpoint completed in {save_checkpoint_time:.2f} seconds") @@ -207,7 +180,7 @@ def init_child_from_parent( total_core_time = create_child_state_dict_time + save_checkpoint_time actual_layer_workers = max_layer_workers if max_layer_workers else "auto" actual_io_workers = max_workers if max_workers else "auto" - mprint("\n=== PROFILING SUMMARY ===") + mprint("\n=== PROFILING SUMMARY ===") mprint( f"create_child_state_dict: {create_child_state_dict_time:.2f}s ({create_child_state_dict_time / total_core_time *
100:.1f}%)" ) @@ -216,4 +189,4 @@ def init_child_from_parent( ) mprint(f"Total core processing: {total_core_time:.2f}s") mprint(f"Optimizations: I/O workers={actual_io_workers}, Layer workers={actual_layer_workers}") - mprint("=========================\n") + mprint("=========================\n")