Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,84 +15,57 @@
# mypy: ignore-errors

"""Provides a function to register activation hooks for a model.
Activation hooks are used to compute activation scores for pruning.
"""
Activation hooks are used to compute activation scores for pruning."""

import re
from typing import Type

from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import (
ForwardHook,
IndependentChannelContributionHook,
IndependentKvHeadContributionHook,
IterativeChannelContributionHook,
LayerNormContributionHook,
)
from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM
from modelopt.torch.nas.plugins.megatron_hooks.base_hooks import ForwardHook as ActivationsHook
from modelopt.torch.puzzletron.tools.logger import aprint


def register_activation_hooks(
model: DeciLMForCausalLM, activation_hooks_kwargs: dict
) -> tuple[dict[str, ForwardHook], type[ForwardHook]]:
hook_class_map = {
"mlp.down_proj": {
"independent": IndependentChannelContributionHook,
"iterative": IterativeChannelContributionHook,
},
"self_attn.o_proj": {
"independent_kv_head_contribution": IndependentKvHeadContributionHook,
},
r"regex:experts\.\d+\.down_proj$": { # For MoE
"independent": IndependentChannelContributionHook,
},
# TODO: maybe this is too generic, and we should have it specifically for
# input_layernorm and post_attention_layernorm; now it might select qk_norms
"layernorm": {
"layer_norm_contribution": LayerNormContributionHook,
},
}

activation_hooks = {}
target_layer = activation_hooks_kwargs.get("target_layer", "mlp.c_proj")

if target_layer.startswith("regex:"):
target_layer_regex = target_layer[len("regex:") :]
pattern = re.compile(target_layer_regex)

def match_predicate(module_name, module):
return pattern.search(module_name)
else:

def match_predicate(module_name, module):
return module_name.endswith(target_layer)

target_layer_hooks_map = hook_class_map.get(target_layer)
if target_layer_hooks_map is None:
raise ValueError(f"no hook classes found for: {target_layer}")

hook_class = target_layer_hooks_map.get(activation_hooks_kwargs["method"])
if hook_class is None:
raise ValueError(f"Unknown hook class: {hook_class}")

if target_layer == "block":
pattern = re.compile(r"^transformer\.h\.\d+$")

def match_predicate(module_name, module):
return pattern.match(module_name)

model,
activation_hooks_kwargs: dict,
pruning_mixin,
hook_class: Type[ActivationsHook],
) -> dict[str, ActivationsHook]:
"""Register activation hooks using the pruning mixin approach.

Args:
model: The model to register hooks on.
activation_hooks_kwargs: Keyword arguments passed to hook constructors.
pruning_mixin: The pruning mixin that defines which modules to hook.
hook_class: The hook class to instantiate for each module.

Returns:
Dictionary mapping module names to hook instances.
"""
activation_hooks_kwargs["model"] = model
for module_name, module in model.named_modules():
if match_predicate(module_name, module):
block_config = None
if block_idx_match := re.search(r"\.(\d+)\.", module_name):
block_idx = int(block_idx_match.group(1))
block_config = model.config.block_configs[block_idx]
curr_activation_hooks_kwargs = {
**activation_hooks_kwargs,
"block_config": block_config,
}

hook = hook_class(module, curr_activation_hooks_kwargs)
module.register_forward_hook(hook)
activation_hooks[module_name] = hook

return activation_hooks, hook_class
if hook_class not in pruning_mixin.supported_hooks():
raise ValueError(
f"Hook class not supported for {pruning_mixin.__class__.__name__}, "
f"must be in {pruning_mixin.supported_hooks()}"
)

module_names_to_hook = pruning_mixin.get_module_names_to_hook(model)
activation_hooks = dict()
for block_idx, module_name in module_names_to_hook:
block_config = None
if block_idx is not None:
block_config = model.config.block_configs[block_idx]
curr_activation_hooks_kwargs = {
**activation_hooks_kwargs,
"block_config": block_config,
}

module = model.get_submodule(module_name)
hook = hook_class(module, curr_activation_hooks_kwargs)
module.register_forward_hook(hook)
activation_hooks[module_name] = hook

if len(activation_hooks) == 0:
raise ValueError("couldn't find any hooks")

aprint(f"Found the following hooks: {activation_hooks.keys()}")
return activation_hooks
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,4 @@ def launch_score_activations(cfg: DictConfig):
mprint("Starting pruning activation scoring...")

# The checkpoint manager inside validate_model handles all progress tracking
validate_model(args=cfg.pruning, pipeline_parallel=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this is removed?

validate_model(args=cfg.pruning)
26 changes: 14 additions & 12 deletions modelopt/torch/puzzletron/puzzletron.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""This module provides the main compression function for a model using MIP-based NAS search algorithm."""

import hydra
from omegaconf import DictConfig

import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations
Expand Down Expand Up @@ -51,24 +52,25 @@ def puzzletron(
f"dataset_path={dataset_path}",
],
)
hydra_cfg = hydra.utils.instantiate(hydra_cfg)

# Step 1: score_pruning_activations (distributed processing)
score_pruning_activations.launch_score_activations(hydra_cfg)

# Step 2: pruning_ckpts (single process)
if dist.is_master():
pruning_ckpts.launch_prune_ckpt(hydra_cfg)
dist.barrier()
# # Step 2: pruning_ckpts (single process)
# if dist.is_master():
# pruning_ckpts.launch_prune_ckpt(hydra_cfg)
# dist.barrier()

# Step 4: build_library_and_stats (single process)
if dist.is_master():
build_library_and_stats.launch_build_library_and_stats(hydra_cfg)
dist.barrier()
# # Step 4: build_library_and_stats (single process)
# if dist.is_master():
# build_library_and_stats.launch_build_library_and_stats(hydra_cfg)
# dist.barrier()

# Step 5: calc_one_block_scores (distributed processing)
scoring.launch_scoring(hydra_cfg)
# # Step 5: calc_one_block_scores (distributed processing)
# scoring.launch_scoring(hydra_cfg)

# Step 6: mip_and_realize_models (distributed processing)
mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
# # Step 6: mip_and_realize_models (distributed processing)
# mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)

return hydra_cfg
5 changes: 5 additions & 0 deletions modelopt/torch/puzzletron/tools/robust_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,13 @@ def default(self, o):
# User-defined function in main — fallback to just the name
return o.__name__
return f"{o.__module__}.{o.__qualname__}"
if inspect.isclass(o):
return f"{o.__module__}.{o.__qualname__}"
if isinstance(o, datetime.timedelta):
return str(o)
# Fallback for arbitrary objects: return their class path
if hasattr(o, "__class__") and hasattr(o.__class__, "__module__"):
return f"{o.__class__.__module__}.{o.__class__.__qualname__}"
return super().default(o)


Expand Down
Loading
Loading