Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions modelopt/torch/_compress/tools/validate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
the loss, and optionally registers hooks to capture the inputs and the outputs
of pytorch modules that are used for activation scoring for pruning.

TODO: Consider moving this a separate module dedicated for scoring.
TODO: Consider moving this to a separate module dedicated for scoring
"""

import textwrap
Expand Down Expand Up @@ -130,11 +130,6 @@ def validate_model(
- hidden_states_per_batch: Hidden states and LM head outputs if return_hidden_states is True, else None.
Returns (None, None) if not on master rank.
"""
# convert model_dtype and autocast_dtype from string to torch.dtype
if isinstance(args.model_dtype, str):
args.model_dtype = getattr(torch, args.model_dtype.strip("torch."))
if isinstance(args.autocast_dtype, str):
args.autocast_dtype = getattr(torch, args.autocast_dtype.strip("torch."))

if val_dataloader is None:
val_dataloader = prepare_dataloader(args, tokenizer) if dist.is_master() else None
Expand Down Expand Up @@ -199,7 +194,7 @@ def validate_model(
calc_on_cpu=args.calc_losses_on_cpu,
just_model_forward=just_model_forward,
checkpoint_manager=checkpoint_manager,
autocast_dtype=args.autocast_dtype,
autocast_dtype=getattr(torch, args.autocast_dtype.strip("torch.")),
)

if losses is not None:
Expand Down Expand Up @@ -232,7 +227,7 @@ def prepare_model(
model = load_and_shard_model(
args.model_name_or_path,
model_config_overrides={"block_size": args.block_size},
model_dtype=args.model_dtype,
model_dtype=getattr(torch, args.model_dtype.strip("torch.")),
)
else:
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

"""Validates puzzle solutions by applying layer replacements and evaluating model performance.

TODO: Consider moving this a separate module dedicated for scoring.
TODO: Consider moving this to a separate module dedicated for scoring
"""

# mypy: ignore-errors
Expand All @@ -42,6 +42,7 @@
copy_tokenizer,
)
from modelopt.torch._compress.tools.checkpoint_utils_hf import (
copy_deci_lm_hf_code,
save_checkpoint,
save_safetensors_index,
)
Expand Down Expand Up @@ -182,7 +183,7 @@ def validate_puzzle_solutions(args: DictConfig) -> None:
save_checkpoint(model, checkpoint_dir)

copy_tokenizer(args.tokenizer_name, checkpoint_dir)
copy_hf_code(checkpoint_dir)
copy_deci_lm_hf_code(checkpoint_dir)

dist.barrier()

Expand Down Expand Up @@ -246,13 +247,6 @@ def save_checkpoint_as_symlinks(
)


def copy_hf_code(checkpoint_dir: Path) -> None:
code_dir = Path(__file__).parent / "deci_lm_hf_code"
print(f"copying hf code from {code_dir} ")
for file in code_dir.glob("*.py"):
shutil.copy(file, checkpoint_dir / file.name)


def _load_tokenizer(args: DictConfig) -> PreTrainedTokenizerBase:
tokenizer = None
if (tokenizer_name := getattr(args, "tokenizer_name", None)) is not None:
Expand Down