From e55fff65c6afc17379cceb5bc6c6cc6eb66555da Mon Sep 17 00:00:00 2001 From: Francesco Cariaggi Date: Fri, 7 Nov 2025 17:41:39 +0100 Subject: [PATCH] Fix mel length computation in Qwen2-Audio --- .../models/qwen2_audio/modeling_qwen2_audio.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 736d67b1a2ad..d90324ef990a 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -347,9 +347,15 @@ def forward( ): r""" Args: - attention_mask (`torch.Tensor`)`, *optional*): - Qwen2Audio does not support masking of the `input_features`, this argument is preserved for compatibility, - but it is not used. By default the silence in the input log mel spectrogram are ignored. + input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`): + Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a + `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or + the soundfile library (`pip install soundfile`). To prepare the array into + `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding + and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Attention mask used in the encoder stack (after the convolutional layers). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. 
@@ -765,7 +771,7 @@ def forward( feature_attention_mask.sum(-1) ) batch_size, _, max_mel_seq_len = input_features.shape - max_seq_len = (max_mel_seq_len - 2) // 2 + 1 + max_seq_len = (max_mel_seq_len - 1) // 2 + 1 # Create a sequence tensor of shape (batch_size, max_seq_len) seq_range = ( torch.arange(0, max_seq_len, dtype=audio_feat_lengths.dtype, device=audio_feat_lengths.device)