Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions modelopt/torch/puzzletron/puzzletron.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def puzzletron(
# Step 1: score_pruning_activations (distributed processing)
score_pruning_activations.launch_score_activations(hydra_cfg)

# # Step 2: pruning_ckpts (single process)
# if dist.is_master():
# pruning_ckpts.launch_prune_ckpt(hydra_cfg)
# dist.barrier()
# Step 2: pruning_ckpts (single process)
if dist.is_master():
pruning_ckpts.launch_prune_ckpt(hydra_cfg)
dist.barrier()

# # Step 4: build_library_and_stats (single process)
# if dist.is_master():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,22 @@
# limitations under the License.
# mypy: ignore-errors

"""TODO Add description"""
"""Initialize child models from parent models using AnyModel approach with deci_x_patcher."""

import json
import time
from pathlib import Path
from typing import Optional

import torch
import yaml
from transformers import AutoModelForCausalLM

from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM
from modelopt.torch.puzzletron.anymodel.model_descriptor import (
ModelDescriptor,
ModelDescriptorFactory,
)
from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
from modelopt.torch.puzzletron.tools.bypassed_training.child_init import (
GQAInitMode,
HiddenSizeInitMode,
Expand All @@ -31,85 +38,37 @@
create_child_state_dict,
update_model_config,
)
from modelopt.torch.puzzletron.tools.checkpoint_utils import (
copy_tokenizer,
load_model_config,
load_state_dict,
)
from modelopt.torch.puzzletron.tools.checkpoint_utils import copy_tokenizer, load_state_dict
from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import (
_save_checkpoint,
copy_deci_lm_hf_code,
load_model_config,
)
from modelopt.torch.puzzletron.tools.logger import mprint

"""

Usage example - remove all/some routed experts:
===============================================

PARENT_DIR=".../meta-llama/Llama-4-Scout-17B-16E-Instruct--deci-hf"

MLP_INIT_MODE="ConcatExpertsIntoDenseFFN"

## remove all routed experts, turn the shared expert into a dense FFN
# OUTPUT_DIR="/.../micro_scout/Scout-remove-routed-experts"
# MODEL_CONFIG_OVERRIDES_JSON='
# {
# "ffn": [
# {
# "moe": null,
# "intermediate_size": 14336,
# "gated": true,
# "hidden_act": "silu"
# }
# ]
# }
# '

## concat the shared expert with one routed expert into a dense FFN
OUTPUT_DIR=".../scratch/micro_scout/Scout-ConcatExpertsIntoDenseFFN-concat-shared-and-3-routed"
MODEL_CONFIG_OVERRIDES_JSON='
{
"ffn": [
{
"moe": null,
"intermediate_size": 14336,
"gated": true,
"hidden_act": "silu"
}
]
}
'

echo ""
echo "MODEL_CONFIG_OVERRIDES_JSON:"
echo "${MODEL_CONFIG_OVERRIDES_JSON}"

python -m modelopt.torch.puzzletron.tools.bypassed_training.init_child_from_parent \
--parent_checkpoint_dir="$PARENT_DIR" \
--model_config_overrides_json="$MODEL_CONFIG_OVERRIDES_JSON" \
--output_checkpoint_dir="$OUTPUT_DIR" \
--mlp_init_mode="$MLP_INIT_MODE" \
--mlp_init_config_yaml="$MLP_INIT_CONFIG_YAML"
"""
from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import _get_model_class_from_config


def init_child_from_parent(
descriptor: ModelDescriptor,
pruning_mixin,
parent_checkpoint_dir: str,
model_config_overrides_json: str,
model_config_overrides_dict: dict | str,
output_checkpoint_dir: str,
gqa_init_mode: GQAInitMode,
mlp_init_mode: MlpInitMode,
mlp_init_config_yaml: str | None,
mlp_init_config_yaml: Optional[str],
linear_init_mode: LinearInitMode,
hidden_size_init_mode: HiddenSizeInitMode | None = None,
channel_importance_path: str | None = None,
max_workers: int | None = None, # Auto-calculate optimal workers if None
max_layer_workers: int | None = None, # Auto-calculate optimal workers if None
hidden_size_init_mode: Optional[HiddenSizeInitMode] = None,
channel_importance_path: Optional[str] = None,
max_workers: Optional[int] = None, # Auto-calculate optimal workers if None
max_layer_workers: Optional[int] = None, # Auto-calculate optimal workers if None
) -> None:
"""Init child models from parent models in the style of bypass training,
"""
Init child models from parent models in the style of bypass training,
but without having to run the entire bypass pipeline.

Uses AnyModel approach with deci_x_patcher for heterogeneous layer configurations.

I/O Optimization Parameters:
- max_workers: Number of threads for parallel file I/O (default: auto-calculate min(CPU count, num files))
- max_layer_workers: Number of threads for parallel layer processing (default: auto-calculate min(CPU count, num layers))
Expand All @@ -123,16 +82,16 @@ def init_child_from_parent(
"We do not support random init of any subblock in this script to avoid initializing the student model"
)

descriptor = ModelDescriptorFactory.get(descriptor)

copy_tokenizer(parent_checkpoint_dir, output_checkpoint_dir)

parent_model_config = load_model_config(parent_checkpoint_dir)
parent_state_dict = load_state_dict(parent_checkpoint_dir)

# Parse the model config overrides
if isinstance(model_config_overrides_json, str):
model_config_overrides_dict = json.loads(model_config_overrides_json)
else:
model_config_overrides_dict = model_config_overrides_json
# Parse JSON if string
if isinstance(model_config_overrides_dict, str):
model_config_overrides_dict = json.loads(model_config_overrides_dict)

# Separate global config overrides from block-level overrides
global_config_overrides = {}
Expand All @@ -146,7 +105,7 @@ def init_child_from_parent(

# Load child model config with global overrides
child_model_config = load_model_config(
checkpoint_dir=parent_checkpoint_dir,
parent_checkpoint_dir,
model_config_overrides=global_config_overrides,
ignore_unexpected_config_keys=True,
)
Expand All @@ -159,19 +118,32 @@ def init_child_from_parent(
)

with torch.device("meta"):
child_model = DeciLMForCausalLM(child_model_config)
# Pass block_configs explicitly so patcher works for VL models where
# decoder layers receive nested config (e.g., text_config) without block_configs
with deci_x_patcher(
model_descriptor=descriptor, block_configs=child_model_config.block_configs
):
model_class = _get_model_class_from_config(child_model_config)
# AutoModelForCausalLM uses from_config(); concrete model classes use _from_config()
if model_class is AutoModelForCausalLM:
child_model = model_class.from_config(child_model_config, trust_remote_code=True)
else:
child_model = model_class._from_config(child_model_config)

child_state_dict_with_meta_tensors = child_model.state_dict()

mlp_init_config = (
yaml.safe_load(mlp_init_config_yaml)
if isinstance(mlp_init_config_yaml, str) is None
if isinstance(mlp_init_config_yaml, str)
else mlp_init_config_yaml
)

# Profile create_child_state_dict with automatic layer parallelization
mprint("Starting create_child_state_dict...")
start_time = time.time()
child_state_dict = create_child_state_dict(
pruning_mixin=pruning_mixin,
descriptor=descriptor,
original_state_dict=parent_state_dict,
new_state_dict=child_state_dict_with_meta_tensors,
original_config=parent_model_config,
Expand All @@ -182,7 +154,7 @@ def init_child_from_parent(
linear_init_mode=linear_init_mode,
hidden_size_init_mode=hidden_size_init_mode or HiddenSizeInitMode.CopyAsIs,
channel_importance_path=channel_importance_path,
max_layer_workers=max_layer_workers, # Will auto-calculate if None
max_layer_workers=max_layer_workers,
)
create_child_state_dict_time = time.time() - start_time
mprint(f"create_child_state_dict completed in {create_child_state_dict_time:.2f} seconds")
Expand All @@ -196,7 +168,8 @@ def init_child_from_parent(
child_model_config,
child_state_dict,
output_checkpoint_dir,
max_workers=max_workers, # Will auto-calculate if None
descriptor,
max_workers=max_workers,
)
save_checkpoint_time = time.time() - start_time
mprint(f"_save_checkpoint completed in {save_checkpoint_time:.2f} seconds")
Expand All @@ -207,7 +180,7 @@ def init_child_from_parent(
total_core_time = create_child_state_dict_time + save_checkpoint_time
actual_layer_workers = max_layer_workers if max_layer_workers else "auto"
actual_io_workers = max_workers if max_workers else "auto"
mprint("\n=== PROFILING SUMMARY ===")
mprint(f"\n=== PROFILING SUMMARY ===")
mprint(
f"create_child_state_dict: {create_child_state_dict_time:.2f}s ({create_child_state_dict_time / total_core_time * 100:.1f}%)"
)
Expand All @@ -216,4 +189,4 @@ def init_child_from_parent(
)
mprint(f"Total core processing: {total_core_time:.2f}s")
mprint(f"Optimizations: I/O workers={actual_io_workers}, Layer workers={actual_layer_workers}")
mprint("=========================\n")
mprint(f"=========================\n")
Loading