diff --git a/modelopt/torch/_compress/tools/validate_model.py b/modelopt/torch/_compress/tools/validate_model.py index d3d71a419..456f9fab8 100644 --- a/modelopt/torch/_compress/tools/validate_model.py +++ b/modelopt/torch/_compress/tools/validate_model.py @@ -18,7 +18,7 @@ the loss, and optionally registers hooks to capture the inputs and the outputs of pytorch modules that are used for activation scoring for pruning. -TODO: Consider moving this a separate module dedicated for scoring. +TODO: Consider moving this to a separate module dedicated for scoring """ import textwrap @@ -130,11 +130,6 @@ def validate_model( - hidden_states_per_batch: Hidden states and LM head outputs if return_hidden_states is True, else None. Returns (None, None) if not on master rank. """ - # convert model_dtype and autocast_dtype from string to torch.dtype - if isinstance(args.model_dtype, str): - args.model_dtype = getattr(torch, args.model_dtype.strip("torch.")) - if isinstance(args.autocast_dtype, str): - args.autocast_dtype = getattr(torch, args.autocast_dtype.strip("torch.")) if val_dataloader is None: val_dataloader = prepare_dataloader(args, tokenizer) if dist.is_master() else None @@ -199,7 +194,7 @@ def validate_model( calc_on_cpu=args.calc_losses_on_cpu, just_model_forward=just_model_forward, checkpoint_manager=checkpoint_manager, - autocast_dtype=args.autocast_dtype, + autocast_dtype=getattr(torch, args.autocast_dtype.strip("torch.")), ) if losses is not None: @@ -232,7 +227,7 @@ def prepare_model( model = load_and_shard_model( args.model_name_or_path, model_config_overrides={"block_size": args.block_size}, - model_dtype=args.model_dtype, + model_dtype=getattr(torch, args.model_dtype.strip("torch.")), ) else: try: diff --git a/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py b/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py index ca0299868..fa021640a 100644 --- 
a/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py +++ b/modelopt/torch/_compress/tools/validate_puzzle_with_multi_replacements.py @@ -15,7 +15,7 @@ """Validates puzzle solutions by applying layer replacements and evaluating model performance. -TODO: Consider moving this a separate module dedicated for scoring. +TODO: Consider moving this to a separate module dedicated for scoring """ # mypy: ignore-errors @@ -42,6 +42,7 @@ copy_tokenizer, ) from modelopt.torch._compress.tools.checkpoint_utils_hf import ( + copy_deci_lm_hf_code, save_checkpoint, save_safetensors_index, ) @@ -182,7 +183,7 @@ def validate_puzzle_solutions(args: DictConfig) -> None: save_checkpoint(model, checkpoint_dir) copy_tokenizer(args.tokenizer_name, checkpoint_dir) - copy_hf_code(checkpoint_dir) + copy_deci_lm_hf_code(checkpoint_dir) dist.barrier() @@ -246,13 +247,6 @@ def save_checkpoint_as_symlinks( ) -def copy_hf_code(checkpoint_dir: Path) -> None: - code_dir = Path(__file__).parent / "deci_lm_hf_code" - print(f"copying hf code from {code_dir} ") - for file in code_dir.glob("*.py"): - shutil.copy(file, checkpoint_dir / file.name) - - def _load_tokenizer(args: DictConfig) -> PreTrainedTokenizerBase: tokenizer = None if (tokenizer_name := getattr(args, "tokenizer_name", None)) is not None: