diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py
index 1760cb2072..0d8295e4c7 100644
--- a/modelopt/torch/speculative/plugins/hf_dflash.py
+++ b/modelopt/torch/speculative/plugins/hf_dflash.py
@@ -61,13 +61,13 @@
 from ..dflash.conversion import DFlashDMRegistry
 from ..dflash.dflash_model import DFlashModel
+from .hf_spec_mixin import HFSpecDecMixin
 from .modeling_dflash import (  # noqa: F401
     DFlashAttention,
     DFlashBaseModelOutput,
     DFlashModule,
     build_target_layer_ids,
 )
-from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS
 
 logger = logging.getLogger(__name__)
 
@@ -75,50 +75,9 @@
 @DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
-class HFDFlashModel(DFlashModel):
+class HFDFlashModel(HFSpecDecMixin, DFlashModel):
     """DFlash Model for HuggingFace transformers."""
 
-    @property
-    def _base_model(self):
-        return self.get_submodule(self.base_model_path)
-
-    @property
-    def _base_model_embeddings(self):
-        return self.get_submodule(self.base_model_embeddings_path)
-
-    @property
-    def _base_model_lm_head(self):
-        return self.get_submodule(self.base_model_lm_head_path)
-
-    @property
-    def _base_llm_config(self):
-        return (
-            getattr(self.config, "text_config", None)
-            or getattr(self.config, "llm_config", None)
-            or self.config
-        )
-
-    def _find_base_model_parts(self):
-        """Locate base model submodules (backbone, embeddings, lm_head) by probing known paths.
-
-        Reuses the shared path constants from modeling_fakebase (same as EAGLE).
-        """
-        for name, paths in {
-            "base_model_path": _BASE_MODEL_PATHS,
-            "base_model_embeddings_path": _EMBED_TOKENS_PATHS,
-            "base_model_lm_head_path": _LM_HEAD_PATHS,
-        }.items():
-            for path in paths:
-                try:
-                    submodule = self.get_submodule(path)
-                    assert isinstance(submodule, torch.nn.Module)
-                    setattr(self, name, path)
-                    break
-                except Exception:
-                    continue
-            else:
-                raise ValueError(f"Part {name} not found in model")
-
     def modify(self, config):
         """Initialize DFlash draft module."""
         super().modify(config)
@@ -178,8 +137,6 @@ def modify(self, config):
         for param in self.parameters():
             param.requires_grad = False
 
-        self._find_base_model_parts()
-
         self.dflash_module = DFlashModule(self.dflash_config)
         # Match base model dtype/device. Skip if base is on meta (during from_pretrained
         # restore — the model will be moved to the correct device after weight loading).
@@ -203,6 +160,14 @@
     def get_exporter(self):
         return DFlashExporter(self)
 
+    def get_dummy_inputs(self) -> dict:
+        """Not yet implemented for DFlash."""
+        raise NotImplementedError(
+            "HFDFlashModel.get_dummy_inputs() is not yet implemented. "
+            "Required by unified HF quantization export (modelopt.torch.export.unified_export_hf). "
+            "Implement to enable quantization-export of DFlash speculative decoding models."
+        )
+
     def _sample_anchor_positions(self, seq_len, loss_mask, device):
         """Randomly sample anchor positions per sample.
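Both plugin classes now get their base-model plumbing through cooperative `super().modify(config)` calls. A minimal sketch of the resulting MRO ordering, using toy class names rather than the real modelopt classes:

```python
# Toy sketch (not the real classes): with the order Plugin(SpecDecMixin, AlgoBase),
# the mixin sits between the concrete plugin and the algorithm base in the MRO,
# so path discovery runs after the base modify() but before the subclass body.

class AlgoBase:  # stands in for DFlashModel / EagleModel
    def modify(self, config):
        print("1. algorithm base modify()")

class SpecDecMixin:  # stands in for HFSpecDecMixin
    def modify(self, config):
        super().modify(config)  # -> AlgoBase.modify via the MRO
        print("2. _find_base_model_parts() runs here")

class Plugin(SpecDecMixin, AlgoBase):  # stands in for HFDFlashModel
    def modify(self, config):
        super().modify(config)  # -> SpecDecMixin.modify via the MRO
        print("3. subclass body: base-model paths already discovered")

Plugin().modify(config=None)
# 1. algorithm base modify()
# 2. _find_base_model_parts() runs here
# 3. subclass body: base-model paths already discovered
```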
diff --git a/modelopt/torch/speculative/plugins/hf_eagle.py b/modelopt/torch/speculative/plugins/hf_eagle.py
index f2040d9d96..b7facc833d 100644
--- a/modelopt/torch/speculative/plugins/hf_eagle.py
+++ b/modelopt/torch/speculative/plugins/hf_eagle.py
@@ -36,8 +36,8 @@
     get_ttt_msk_func,
     temporary_set_config_value,
 )
+from .hf_spec_mixin import HFSpecDecMixin
 from .modeling_eagle import EagleBaseModelOutput, EagleModule
-from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS
 
 __all__ = ["HFARValidation", "HFEagleModel", "default_eagle_aux_layer_ids"]
 
@@ -55,80 +55,14 @@ def default_eagle_aux_layer_ids(num_layers: int) -> list[int]:
 @EagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
-class HFEagleModel(EagleModel):
+class HFEagleModel(HFSpecDecMixin, EagleModel):
     """Eagle Model Class for huggingface models."""
 
-    @property
-    def _base_model(self):
-        return self.get_submodule(self.base_model_path)
-
-    @property
-    def _base_model_embeddings(self):
-        return self.get_submodule(self.base_model_embeddings_path)
-
-    @property
-    def _base_model_lm_head(self):
-        return self.get_submodule(self.base_model_lm_head_path)
-
-    @property
-    def _base_llm_config(self):
-        """Return the llm config for the base model, from LLM or VLM."""
-        return (
-            getattr(self.config, "text_config", None)
-            or getattr(self.config, "llm_config", None)
-            or self.config
-        )
-
-    def _nvtx_range(self, name):
-        """Optionally create an NVTX range for the given name when config.eagle_enable_nvtx is set."""
-        if not self.eagle_enable_nvtx:
-            return contextlib.nullcontext()
-        try:
-            import torch.cuda.nvtx as nvtx
-
-            return nvtx.range(name)
-        except Exception as e:
-            print(f"Failed to create NVTX range {name}: {e}")
-            return contextlib.nullcontext()
-
-    def _find_base_model_parts(self):
-        """Find model parts from different models and set base_{part}_path attributes."""
-        base_model_parts_mapping = {
-            "base_model_path": _BASE_MODEL_PATHS,
-            "base_model_embeddings_path": _EMBED_TOKENS_PATHS,
-            "base_model_lm_head_path": _LM_HEAD_PATHS,
-        }
-
-        for name, paths in base_model_parts_mapping.items():
-            found_submodule = False
-            for path in paths:
-                try:
-                    submodule = self.get_submodule(path)
-                    assert isinstance(submodule, torch.nn.Module)
-                    print(f"Found {name} at {path}")
-                    found_submodule = True
-                    setattr(self, name, path)
-                    break
-                except Exception:
-                    continue
-            if not found_submodule:
-                raise ValueError(f"Part {name} not found in model")
-
-    def _activate_torch_compile(self):
-        import torch._dynamo
-
-        torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode
-
-        compile_targets = [
-            ("_prepare_eagle_inputs", {}),
-            ("_eagle_forward", {"mode": "max-autotune"}),
-            ("_eagle_loss", {"fullgraph": True}),
-        ]
-        for name, kwargs in compile_targets:
-            try:
-                setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
-            except Exception:  # noqa: PERF203
-                print(f"Disabling torch.compile for {name} due to compilation error.")
+    _compile_targets = [
+        ("_prepare_eagle_inputs", {}),
+        ("_eagle_forward", {"mode": "max-autotune"}),
+        ("_eagle_loss", {"fullgraph": True}),
+    ]
 
     def get_dummy_inputs(self) -> dict:
         """Construct dummy inputs for export forward pass."""
@@ -290,6 +224,9 @@ def modify(
         if self.eagle_config._attn_implementation is None:
             self.eagle_config._attn_implementation = "sdpa"
 
+        # Mixin interface attribute
+        self._enable_nvtx = self.eagle_enable_nvtx
+
         # Set default aux_hidden_state layers
         if (
             self.eagle_config.use_aux_hidden_state
@@ -307,8 +244,6 @@ def modify(
             decoder_cls,
         )
 
-        # find base model, lm head, and embeddings paths
-        self._find_base_model_parts()
         self.eagle_module.to(self._base_model.dtype).to(self._get_eagle_device())
 
         # EAGLE-3 auxiliary hidden_states
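The body of the removed `_activate_torch_compile` survives in the mixin (next file); each plugin now only declares what to compile via `_compile_targets`. A hedged sketch of that declarative pattern, with a toy class and a hypothetical `_draft_forward` method:

```python
# Toy sketch of the declarative compile-target pattern; ToyPlugin and
# _draft_forward are made-up names for illustration.
import torch

class ToyPlugin:
    # Same shape as HFSpecDecMixin._compile_targets: (method_name, compile kwargs).
    _compile_targets = [("_draft_forward", {})]

    def _draft_forward(self, x):
        return torch.relu(x) * 2

    def _activate_torch_compile(self):
        for name, kwargs in self._compile_targets:
            try:
                # setattr on the *instance* shadows the class method with the
                # compiled bound callable; the class definition is untouched.
                setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
            except Exception:
                print(f"Disabling torch.compile for {name} due to compilation error.")

plugin = ToyPlugin()
plugin._activate_torch_compile()
print(plugin._draft_forward(torch.ones(3)))  # compiled lazily on first call
```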
diff --git a/modelopt/torch/speculative/plugins/hf_spec_mixin.py b/modelopt/torch/speculative/plugins/hf_spec_mixin.py
new file mode 100644
index 0000000000..f02d687700
--- /dev/null
+++ b/modelopt/torch/speculative/plugins/hf_spec_mixin.py
@@ -0,0 +1,187 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Shared mixin for HuggingFace speculative decoding model classes."""
+
+import contextlib
+import logging
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, TypeAlias
+
+import torch
+
+from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    # At type-check time, pretend the mixin is an nn.Module so attribute access
+    # like self.get_submodule(...) type-checks. At runtime it remains a plain
+    # mixin (object) and gets nn.Module via the sibling base class in the MRO.
+    _Host: TypeAlias = torch.nn.Module
+else:
+    _Host = object
+
+__all__ = ["HFSpecDecMixin"]
+
+
+class HFSpecDecMixin(_Host, ABC):
+    """Mixin providing HuggingFace base-model discovery for speculative decoding plugins.
+
+    Provides shared properties and methods for locating base-model submodules
+    (backbone, embeddings, lm_head), plus NVTX profiling and torch.compile helpers.
+
+    Must be used with multiple inheritance alongside an algorithm-specific base
+    (EagleModel, DFlashModel, etc.) that inherits from DynamicModule.
+
+    Lifecycle:
+        Base-model paths are discovered automatically inside ``modify()`` via the
+        MRO hook below — subclasses only need to call ``super().modify(config)``
+        and the ``_base_model`` / ``_base_model_embeddings`` / ``_base_model_lm_head``
+        properties are ready to use in the rest of the subclass's ``modify()`` body.
+
+    Example::
+
+        @EagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
+        class HFEagleModel(HFSpecDecMixin, EagleModel): ...
+    """
+
+    # -- Host-supplied attributes (declared for type checkers) --
+
+    # Provided by the host (e.g., PreTrainedModel.config).
+    config: Any
+    # Set by ``_find_base_model_parts()``.
+    base_model_path: str
+    base_model_embeddings_path: str
+    base_model_lm_head_path: str
+
+    # -- Class attributes (subclasses may override) --
+
+    # List of (method_name, compile_kwargs) for _activate_torch_compile().
+    # Example: [("_eagle_forward", {"mode": "max-autotune"}), ("_eagle_loss", {"fullgraph": True})]
+    _compile_targets: list[tuple[str, dict]] = []
+
+    # Set to True in subclass ``modify()`` to enable NVTX ranges.
+    _enable_nvtx: bool = False
+
+    # -- Properties: base model access --
+
+    @property
+    def _base_model(self) -> torch.nn.Module:
+        return self.get_submodule(self.base_model_path)
+
+    @property
+    def _base_model_embeddings(self) -> torch.nn.Module:
+        return self.get_submodule(self.base_model_embeddings_path)
+
+    @property
+    def _base_model_lm_head(self) -> torch.nn.Module:
+        return self.get_submodule(self.base_model_lm_head_path)
+
+    @property
+    def _base_llm_config(self):
+        """Return the LLM config for the base model, handling VLM nesting."""
+        return (
+            getattr(self.config, "text_config", None)
+            or getattr(self.config, "llm_config", None)
+            or self.config
+        )
+
+    # -- Lifecycle hook --
+
+    def modify(self, config):
+        """Run base-class ``modify``, then auto-discover base-model paths.
+
+        Subclasses only need to call ``super().modify(config)`` first; the base-model
+        properties are then ready to use in the rest of the subclass's ``modify()`` body.
+        """
+        super().modify(config)
+        self._find_base_model_parts()
+
+    # -- Methods: model discovery --
+
+    def _find_base_model_parts(self):
+        """Find model parts from different models and set base_{part}_path attributes.
+
+        Iterates over candidate submodule paths from modeling_fakebase to locate the
+        base model backbone, embedding layer, and LM head.
+
+        Raises:
+            ValueError: If any required model part cannot be found.
+        """
+        for name, paths in {
+            "base_model_path": _BASE_MODEL_PATHS,
+            "base_model_embeddings_path": _EMBED_TOKENS_PATHS,
+            "base_model_lm_head_path": _LM_HEAD_PATHS,
+        }.items():
+            for path in paths:
+                try:
+                    self.get_submodule(path)
+                    setattr(self, name, path)
+                    logger.debug("Found %s at %s", name, path)
+                    break
+                except Exception:
+                    continue
+            else:
+                raise ValueError(f"Part {name} not found in model")
+
+    # -- Methods: profiling & compilation --
+
+    def _nvtx_range(self, name):
+        """Optionally create an NVTX range for profiling.
+
+        Enabled when the subclass sets ``self._enable_nvtx = True`` in ``modify()``.
+        """
+        if not self._enable_nvtx:
+            return contextlib.nullcontext()
+        try:
+            import torch.cuda.nvtx as nvtx
+
+            return nvtx.range(name)
+        except Exception as e:
+            print(f"Failed to create NVTX range {name}: {e}")
+            return contextlib.nullcontext()
+
+    def _activate_torch_compile(self):
+        """Apply ``torch.compile`` to methods listed in ``_compile_targets``.
+
+        Each entry is ``(method_name, extra_kwargs)`` passed to ``torch.compile(..., dynamic=False)``.
+        Compilation failures leave the method in eager mode and print a warning.
+        """
+        import torch._dynamo
+
+        torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode
+
+        for name, kwargs in self._compile_targets:
+            try:
+                setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
+            except Exception:  # noqa: PERF203
+                print(f"Disabling torch.compile for {name} due to compilation error.")
+
+    # -- Required interface --
+
+    @abstractmethod
+    def get_exporter(self):
+        """Return the exporter for the draft model."""
+
+    @abstractmethod
+    def get_dummy_inputs(self) -> dict:
+        """Construct dummy inputs for the export forward pass.
+
+        Used by unified HF quantization export to drive a fake forward when the
+        model's ``forward`` signature is non-standard (e.g. takes ``base_model_outputs``).
+        Subclasses that don't yet support this path should raise ``NotImplementedError``
+        with a clear message so callers fail loudly rather than silently.
+        """
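For reference, a minimal sketch of the probing loop in `_find_base_model_parts`, using a toy module and made-up candidate paths: `get_submodule()` raises for a missing path, so candidates are tried in order and the `for`/`else` raises only when none match:

```python
# Toy module and candidate list for illustration; the real candidates come
# from modeling_fakebase (_BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS).
import torch

toy = torch.nn.ModuleDict(
    {"model": torch.nn.ModuleDict({"embed_tokens": torch.nn.Embedding(8, 4)})}
)

candidates = ["embed_tokens", "model.embed_tokens"]  # probed in order
for path in candidates:
    try:
        toy.get_submodule(path)  # raises AttributeError if the path is absent
        print(f"resolved: {path}")  # -> "resolved: model.embed_tokens"
        break
    except Exception:
        continue
else:  # runs only if no candidate matched (no break)
    raise ValueError("Part not found in model")
```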
+ """ diff --git a/modelopt/torch/speculative/plugins/modeling_fakebase.py b/modelopt/torch/speculative/plugins/modeling_fakebase.py index 4ed06ed649..1f0151e361 100644 --- a/modelopt/torch/speculative/plugins/modeling_fakebase.py +++ b/modelopt/torch/speculative/plugins/modeling_fakebase.py @@ -32,7 +32,7 @@ PreTrainedModel, ) -# Candidate module paths searched in order — shared with HFEagleModel._find_base_model_parts +# Candidate module paths searched in order — used by HFSpecDecMixin._find_base_model_parts _EMBED_TOKENS_PATHS = [ "embed_tokens", "language_model.model.embed_tokens",