Skip to content

[compat] Support Qwen2-Audio with newer transformers#9453

Open
MWXGOD wants to merge 2 commits into
modelscope:mainfrom
MWXGOD:fix-qwen2-audio-transformers-compat
Open

[compat] Support Qwen2-Audio with newer transformers#9453
MWXGOD wants to merge 2 commits into
modelscope:mainfrom
MWXGOD:fix-qwen2-audio-transformers-compat

Conversation

@MWXGOD
Copy link
Copy Markdown

@MWXGOD MWXGOD commented May 30, 2026

Qwen2-Audio Compatibility with Newer Transformers for RLHF/GRPO

PR type

  • Bug fix
  • Feature enhancement
  • Documentation
  • Test

PR information

Motivation

Qwen2-Audio is currently constrained to the older transformers 4.48-era
encoding path. This is inconvenient for RLHF workflows such as GRPO because
recent trl versions require newer transformers releases.

When Qwen2-Audio is used with a newer transformers stack, the old audio
placeholder encoding path may encode <|AUDIO|> incorrectly, which can lead to
unstable training or corrupted inference outputs. This PR proposes a small,
Qwen2-Audio-only compatibility update so that SFT, inference, and GRPO/RLHF can
run in one environment.

Summary of changes

  • Encode Qwen2-Audio <|AUDIO|> contexts with the Qwen2-Audio processor instead
    of treating them as generic text-only contexts.
  • Load audio with librosa.load(..., sr=processor.feature_extractor.sampling_rate).
  • Keep the existing labels and loss_scale construction after input_ids are
    returned by the processor.
  • Keep all non-Qwen2-Audio models on the existing tokenization path.
  • Relax Qwen2-Audio requirements from transformers>=4.45,<4.49 to
    transformers>=4.48,<6.
  • Add a minimal transformers>=5.0 cache compatibility patch for generation.

Suggested implementation

Line numbers below refer to the current official snapshot and may shift slightly
after upstream changes.

1. swift/template/base.py

Near the top-level imports, around line 1-15, add:

import warnings
import librosa

Replace _encode_context_list, currently starting around line 1064, with an
audio-aware version:

def _encode_context_list(
    self,
    context_list: List[Context],
    loss_scale_list: Optional[List[float]] = None,
    audio_path_list: Optional[List[str]] = None,
) -> Tuple[List[int], List[int], List[float]]:
    is_binary_loss_scale = self.is_binary_loss_scale
    if is_binary_loss_scale is None:
        is_binary_loss_scale = self.loss_scale.is_binary_loss_scale
    input_ids: List[int] = []
    labels: List[int] = []
    loss_scale: List[float] = []
    if loss_scale_list is None:
        loss_scale_list = [0.] * len(context_list)

    audio_ptr = 0
    for context, loss_weight in zip(context_list, loss_scale_list):
        if isinstance(context, str) and '<|AUDIO|>' in context:
            if audio_path_list is None or audio_ptr >= len(audio_path_list):
                warnings.warn(
                    'Found <|AUDIO|> but no matching audio input; fallback to text tokenization',
                    RuntimeWarning)
                token_list = self._tokenize(context)
            else:
                sample_rate = self.processor.feature_extractor.sampling_rate
                wav, _ = librosa.load(audio_path_list[audio_ptr], sr=sample_rate, mono=True)
                encoded = self.processor(
                    text=context,
                    audio=wav,
                    sampling_rate=sample_rate,
                    return_tensors=None,
                    add_special_tokens=False,
                )
                token_list = encoded['input_ids']
                if len(token_list) > 0 and isinstance(token_list[0], list):
                    token_list = token_list[0]
                audio_ptr += 1
        else:
            token_list = self._tokenize(context) if isinstance(context, str) else context

        input_ids += token_list
        if loss_weight > 0.0:
            labels += token_list
        else:
            labels += [-100] * len(token_list)
        if not is_binary_loss_scale:
            loss_scale.extend([loss_weight] * len(token_list))
    if is_binary_loss_scale:
        loss_scale = None
    return input_ids, labels, loss_scale

In _encode, around line 1472 after
self._simplify_context_list(...), call the audio-aware path only for
Qwen2-Audio:

res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list, inputs)
if self.tokenizer.model_meta.model_type and self.tokenizer.model_meta.model_type == 'qwen2_audio':
    input_ids, labels, loss_scale = self._encode_context_list(
        res_context_list, loss_scale_list, inputs.audios)
else:
    input_ids, labels, loss_scale = self._encode_context_list(res_context_list, loss_scale_list)

2. swift/model/models/qwen.py

If the target branch does not already import transformers, add it near the
top-level imports:

import transformers

Replace Qwen2AudioLoader, currently starting around line 1815, with:

class Qwen2AudioLoader(ModelLoader):

    @staticmethod
    def _is_transformers5() -> bool:
        return version.parse(transformers.__version__) >= version.parse('5.0.0')

    def _patch_transformers5_model(self, model: PreTrainedModel) -> PreTrainedModel:
        if not self._is_transformers5():
            return model
        generation_config = getattr(model, 'generation_config', None)
        if generation_config is not None and hasattr(generation_config, 'cache_implementation'):
            generation_config.cache_implementation = None
        _patch_hybrid_cache_device_update()
        return model

    def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
        from transformers import Qwen2AudioForConditionalGeneration
        self.auto_model_cls = self.auto_model_cls or Qwen2AudioForConditionalGeneration
        model = super().get_model(model_dir, *args, **kwargs)
        return self._patch_transformers5_model(model)

Update the Qwen2-Audio requirement, currently around line 1835:

requires=['transformers>=4.48', 'librosa'],

Add the cache helper after the Qwen2-Audio registration and before the next
model loader class:

def _patch_hybrid_cache_device_update() -> None:
    try:
        from transformers.cache_utils import HybridCache

        def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, *args,
                   **kwargs) -> Tuple[torch.Tensor]:
            self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
            self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
            return self._update_origin(key_states, value_states, layer_idx, *args, **kwargs)

        if not hasattr(HybridCache, '_update_origin'):
            HybridCache._update_origin = HybridCache.update
            HybridCache.update = update
    except ImportError:
        pass

Compatibility

This change is intentionally scoped to Qwen2-Audio. Other text and multimodal
models should continue to use the existing encode path.

Expected supported environments:

  • transformers>=4.48: preserve the existing Qwen2-Audio baseline.
  • transformers>=5.0: support newer trl/GRPO environments with the cache
    compatibility handling above.

Experiment results

Suggested smoke tests:

  • Qwen2-Audio SFT encode: a sample with <|AUDIO|> produces valid input_ids,
    labels, and loss_scale.
  • Qwen2-Audio inference: generation is readable under both 4.48-series and newer
    transformers versions.
  • GRPO/RLHF: a recent trl version, for example trl>=0.20, can initialize
    Qwen2-Audio and start rollout generation.
  • Regression: non-Qwen2-Audio models still use the original encode path.

Checklist

  • The change is limited to Qwen2-Audio compatibility.
  • No task-specific metric, reward function, or private dataset logic is included.
  • Qwen2-Audio SFT encode has been checked with an audio sample.
  • Qwen2-Audio inference has been checked with newer transformers.
  • A GRPO/RLHF initialization or rollout smoke test has been checked with recent trl.
  • pre-commit run --all-files has been run before submission.

Notes for reviewers

This PR does not introduce task-specific evaluation, reward, or dataset code. It
only addresses Qwen2-Audio compatibility with the newer dependency stack needed
by RLHF/GRPO.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for Qwen2-Audio models, including patching for transformers v5 compatibility and handling audio tokenization in the base template. The review comments highlight critical issues: a missing transformers import in qwen.py causing a NameError, a global librosa import in base.py that introduces an unwanted hard dependency, and a potential AttributeError in the audio tokenization path for non-audio models.

Comment on lines +1817 to +1819
@staticmethod
def _is_transformers5() -> bool:
return version.parse(transformers.__version__) >= version.parse('5.0.0')
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The transformers module is not imported at the top of swift/model/models/qwen.py. Calling _is_transformers5() will raise a NameError at runtime. Please import transformers locally inside the method or at the module level.

Suggested change
@staticmethod
def _is_transformers5() -> bool:
return version.parse(transformers.__version__) >= version.parse('5.0.0')
@staticmethod
def _is_transformers5() -> bool:
import transformers
return version.parse(transformers.__version__) >= version.parse('5.0.0')

Comment thread swift/template/base.py
Comment on lines 2 to 15
import hashlib
import inspect
import librosa
import math
import os
import random
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Importing librosa globally at the top of swift/template/base.py makes it a hard dependency for all models in the repository, even if they do not use audio features. This will cause an ImportError for users who do not have librosa installed. Please remove the global import and load librosa lazily where it is actually needed.

Suggested change
import hashlib
import inspect
import librosa
import math
import os
import random
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy
import hashlib
import inspect
import math
import os
import random
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy

Comment thread swift/template/base.py
Comment on lines +1082 to +1101
for context, loss_weight in zip(context_list, loss_scale_list):
if isinstance(context, str) and '<|AUDIO|>' in context:
if audio_path_list is None or audio_ptr >= len(audio_path_list):
warnings.warn('Found <|AUDIO|> but no matching audio input; fallback to text tokenization',
RuntimeWarning)
token_list = self._tokenize(context)
else:
sample_rate = self.processor.feature_extractor.sampling_rate
wav, _ = librosa.load(audio_path_list[audio_ptr], sr=sample_rate, mono=True)
encoded = self.processor(
text=context,
audio=wav,
sampling_rate=sample_rate,
return_tensors=None,
add_special_tokens=False,
)
token_list = encoded['input_ids']
if len(token_list) > 0 and isinstance(token_list[0], list):
token_list = token_list[0]
audio_ptr += 1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To prevent other models from accidentally triggering the audio processing path (which would raise an AttributeError since non-audio models do not have a feature_extractor on their processor), we should restrict this block to only run when the model type is qwen2_audio. Additionally, we should import librosa locally here to avoid making it a global hard dependency.

        for context, loss_weight in zip(context_list, loss_scale_list):
            if (isinstance(context, str)
                    and '<|AUDIO|>' in context
                    and getattr(self.tokenizer, 'model_meta', None)
                    and getattr(self.tokenizer.model_meta, 'model_type', None) == 'qwen2_audio'):
                if audio_path_list is None or audio_ptr >= len(audio_path_list):
                    warnings.warn('Found <|AUDIO|> but no matching audio input; fallback to text tokenization',
                                  RuntimeWarning)
                    token_list = self._tokenize(context)
                else:
                    import librosa
                    sample_rate = self.processor.feature_extractor.sampling_rate
                    wav, _ = librosa.load(audio_path_list[audio_ptr], sr=sample_rate, mono=True)
                    encoded = self.processor(
                        text=context,
                        audio=wav,
                        sampling_rate=sample_rate,
                        return_tensors=None,
                        add_special_tokens=False,
                    )
                    token_list = encoded['input_ids']
                    if len(token_list) > 0 and isinstance(token_list[0], list):
                        token_list = token_list[0]
                    audio_ptr += 1

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review. I updated the PR to:

  1. import transformers locally in _is_transformers5;
  2. remove the global librosa import and import it lazily in the Qwen2-Audio branch;
  3. restrict the audio placeholder branch to qwen2_audio only.

@hjh0119
Copy link
Copy Markdown
Collaborator

hjh0119 commented May 30, 2026

This is inconvenient for RLHF workflows such as GRPO because
recent trl versions require newer transformers releases.

To my knowledge, Swift's GRPO should be compatible with Transformers 4. Could you provide the specific error message?

@MWXGOD
Copy link
Copy Markdown
Author

MWXGOD commented May 30, 2026 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants