Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/gpt-oss/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
kernels>=0.9.0
torch>2.7.1
trackio
transformers>=4.55.0
trl>=0.21.0
1 change: 0 additions & 1 deletion examples/llm_distill/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
pyarrow
torchao>=0.14.1
transformers<5.0
trl>=0.23.0
3 changes: 1 addition & 2 deletions examples/speculative_decoding/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
accelerate==1.12.0
transformers==5.0.0rc1
transformers>=5.0
3 changes: 0 additions & 3 deletions examples/vlm_ptq/requirements-vila.txt

This file was deleted.

2 changes: 1 addition & 1 deletion modelopt/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
try:
from transformers import __version__ as _transformers_version

if not (_Version("4.53") <= _Version(_transformers_version) < _Version("5.0")):
if not (_Version("4.55") <= _Version(_transformers_version)):
_warnings.warn(
f"transformers version {_transformers_version} is not tested with nvidia-modelopt and may cause issues. "
"Please install recommended version with `pip install nvidia-modelopt[hf]` if working with HF models.",
Expand Down
14 changes: 2 additions & 12 deletions modelopt/torch/speculative/plugins/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
from typing import Any

import torch
import transformers
from packaging.version import Version
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
Expand Down Expand Up @@ -77,14 +75,6 @@
CACHED_SHARD_TTT_MASKS = {}


def _get_empty_cache(config):
    """Return a fresh, empty ``DynamicCache``, bridging transformers API changes.

    transformers>=4.54 accepts a ``config`` argument when constructing a
    ``DynamicCache``; older releases take no arguments. This keeps unit tests
    working against both API generations.
    """
    tf_version = Version(transformers.__version__)
    if tf_version < Version("4.54"):
        return DynamicCache()
    return DynamicCache(config=config)


@MedusaDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
class HFMedusaModel(MedusaModel):
"""Medusa Model Class for huggingface models."""
Expand Down Expand Up @@ -908,9 +898,9 @@ def forward(
)

if not isinstance(past_key_values, Cache):
past_key_values = _get_empty_cache(self._base_llm_config)
past_key_values = DynamicCache(config=self._base_llm_config)
if not isinstance(eagle_cache, Cache):
eagle_cache = _get_empty_cache(self.eagle_module.config)
eagle_cache = DynamicCache(config=self.eagle_module.config)
past_key_values.eagle_cache = eagle_cache

# ====Prepare inputs for the first eagle forward pass====
Expand Down
58 changes: 58 additions & 0 deletions modelopt/torch/trace/plugins/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@

"""Utilities to describe symbols in the dynamic attention module."""

import torch
from packaging.version import Version as _Version
from torch import nn
from transformers import __version__ as _transformers_version
from transformers.models.bert.modeling_bert import BertAttention
from transformers.models.gptj.modeling_gptj import GPTJAttention

Expand Down Expand Up @@ -56,3 +59,58 @@ def get_hf_attn_sym_info_sortable(mod: nn.Module) -> SymInfo:
@SymMap.register([GPTJAttention])
def get_hf_attn_sym_info_unsortable(mod: nn.Module) -> SymInfo:
    """Return symbol info for attention modules whose heads cannot be re-sorted.

    GPT-J interleaves rotary position embeddings with the head dimension, so
    its attention heads are not freely permutable; register it with
    ``sortable_attn=False``. Bug fix: this previously passed
    ``sortable_attn=True``, which contradicted the function's name and made it
    identical to the sortable variant registered for BertAttention above.
    """
    return get_hf_attn_sym_info(sortable_attn=False)


# In transformers>=5.0, BertLayer.forward uses tuple unpacking on the BertAttention output
# (e.g. `self_attn_out, _ = self.attention(...)`), which FX symbolic tracing cannot handle when
# BertAttention is a registered leaf (the proxy is not iterable). Patch BertLayer.forward to use
# indexing instead, and call feed_forward_chunk directly (equivalent to apply_chunking_to_forward
# with chunk_size=0, which is the default for BERT).
if _Version(_transformers_version) >= _Version("5.0"):
from transformers.models.bert.modeling_bert import BertLayer as _BertLayer

def _fx_friendly_bert_layer_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_values=None,
    cache_position=None,
    **kwargs,
):
    """FX-traceable replacement for ``BertLayer.forward`` (transformers>=5.0).

    Mirrors the upstream implementation for the default BERT configuration
    (``chunk_size_feed_forward=0``) while avoiding two constructs that break
    ``torch.fx`` symbolic tracing when ``BertAttention`` is a registered leaf
    module: tuple unpacking of the attention output (an FX ``Proxy`` is not
    iterable) and the ``apply_chunking_to_forward`` indirection.
    Installed via ``_BertLayer.forward = ...`` only when the imported
    transformers version is >= 5.0.
    """
    # Use indexing instead of tuple-unpacking so FX can trace through BertLayer
    # when BertAttention is a registered leaf (returns an opaque Proxy).
    # Accept **kwargs so that a parent trace (e.g. BertEncoder) passing extra kwargs
    # like position_ids does not mark BertLayer as failed. However, do NOT forward
    # **kwargs into self.attention: FX represents **kwargs as a Proxy(_kwargs), so
    # unpacking it with ** would trigger "Proxy cannot be iterated". Additionally,
    # BertSelfAttention ignores these kwargs (e.g. position_ids) in practice.
    _attn_outputs = self.attention(
        hidden_states,
        attention_mask,
        past_key_values=past_key_values,
        cache_position=cache_position,
    )
    # Index rather than unpack: _attn_outputs may be an opaque FX Proxy.
    attention_output = _attn_outputs[0]

    # Decoder-style cross-attention path; only taken when encoder states are
    # supplied. NOTE(review): argument order for self.crossattention mirrors the
    # upstream positional signature — confirm against transformers 5.0 BertAttention.
    if self.is_decoder and encoder_hidden_states is not None:
        if not hasattr(self, "crossattention"):
            raise ValueError(
                f"If `encoder_hidden_states` are passed, {self} has to be instantiated with"
                " cross-attention layers by setting `config.add_cross_attention=True`"
            )
        _cross_outputs = self.crossattention(
            attention_output,
            None,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_values=past_key_values,
        )
        attention_output = _cross_outputs[0]

    # Call feed_forward_chunk directly (equivalent to apply_chunking_to_forward when
    # chunk_size_feed_forward=0, which is the BERT default).
    return self.feed_forward_chunk(attention_output)

_BertLayer.forward = _fx_friendly_bert_layer_forward
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ hf = [
"nltk",
"peft>=0.17.0",
"sentencepiece>=0.2.1", # Also implicitly used in test_unified_export_megatron, test_vllm_fakequant_megatron_export
"transformers>=4.53,<5.0", # Should match modelopt/torch/__init__.py and tox.ini
"transformers>=4.55", # Should match modelopt/torch/__init__.py and tox.ini
"wonderwords",
]
dev-lint = [
Expand Down
9 changes: 2 additions & 7 deletions tests/_test_utils/torch/transformers_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
import pytest
import torch
from _test_utils.torch.misc import set_seed
from packaging.version import Version

transformers = pytest.importorskip("transformers")
from transformers import (
AutoTokenizer,
BertConfig,
BertForQuestionAnswering,
GptOssConfig,
GptOssForCausalLM,
LlamaConfig,
LlamaForCausalLM,
Qwen3Config,
Expand All @@ -37,9 +38,6 @@
T5Tokenizer,
)

if Version(transformers.__version__) >= Version("4.55"):
from transformers import GptOssConfig, GptOssForCausalLM

import modelopt.torch.opt as mto

SEED = 1234
Expand Down Expand Up @@ -141,9 +139,6 @@ def get_tiny_t5(**config_kwargs) -> T5ForConditionalGeneration:

def get_tiny_gpt_oss(**config_kwargs) -> "GptOssForCausalLM":
set_seed(SEED)
if Version(transformers.__version__) < Version("4.55"):
pytest.skip("GptOssForCausalLM is not supported in transformers < 4.55")

kwargs = {
"num_hidden_layers": 4,
"num_local_experts": 8,
Expand Down
4 changes: 0 additions & 4 deletions tests/unit/torch/quantization/plugins/test_peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import pytest
import torch
from _test_utils.torch.transformers_models import get_tiny_gpt_oss, get_tiny_llama, tf_output_tester
from packaging.version import Version

pytest.importorskip("peft")
transformers = pytest.importorskip("transformers")
Expand Down Expand Up @@ -54,9 +53,6 @@ def test_convert_loralinear():
tf_output_tester(model_ref, model_test)


@pytest.mark.skipif(
Version(transformers.__version__) < Version("4.55"), reason="transformers < 4.55"
)
def test_peft_flow(tmp_path):
model_original = get_tiny_gpt_oss(num_hidden_layers=1)

Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ deps =
-e .[all,dev-test]

# Should match pyproject.toml
tf_min: transformers~=4.53.0
tf_min: transformers~=4.55.0
commands =
python -m pytest tests/unit {env:COV_ARGS:}

Expand Down
Loading