From 79d8da9b1c1be4e777736b907d3605a24378a47e Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 16 Aug 2025 22:24:53 +0900 Subject: [PATCH] fix: wrong pooling method for CUDA --- backends/candle/src/models/flash_qwen3.rs | 29 +++++++++++++---------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/backends/candle/src/models/flash_qwen3.rs b/backends/candle/src/models/flash_qwen3.rs index 48c3a0c2..5fba7665 100644 --- a/backends/candle/src/models/flash_qwen3.rs +++ b/backends/candle/src/models/flash_qwen3.rs @@ -1,6 +1,7 @@ use crate::flash_attn::flash_attn_varlen; use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, Linear, RMSNorm}; -use crate::models::{Model, Qwen3Config, Qwen3ClassificationHead}; +use crate::models::{Model, Qwen3Config}; +use crate::models::qwen3::Qwen3ClassificationHead; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; use candle_rotary::apply_rotary_inplace; @@ -309,7 +310,7 @@ impl FlashQwen3Model { ModelType::Classifier => { // Load classification head before the vb is modified let classification_head = Some(Qwen3ClassificationHead::load(vb.clone(), config)?); - (Pool::Cls, classification_head) // Use CLS pooling for classification + (Pool::LastToken, classification_head) // Use LastToken pooling for classification } ModelType::Embedding(pool) => (pool, None), }; @@ -404,8 +405,8 @@ impl FlashQwen3Model { // CLS and LastToken pooling Pool::Cls | Pool::LastToken => { if batch_size > 1 { - // Get token indices form cu_seqlens - let mut indices = match self.pool { + // Get token indices for each sequence + let all_indices = match self.pool { Pool::Cls => cu_seqlens.narrow(0, 0, batch_size)?, Pool::LastToken => { let end = cu_seqlens.narrow(0, 1, batch_size)?; @@ -414,19 +415,21 @@ impl FlashQwen3Model { _ => unreachable!(), }; - // If raw_indices is empty, we don't need to do anything with - // the pooled_indices - if has_raw_requests { - // We need the pooled indices to select the correct cls indices + // Select the appropriate indices based on pooled_indices + let indices = if has_raw_requests { + // Select only the sequences that need pooling + let pooled_indices_vec: Vec = batch.pooled_indices.iter() + .map(|&idx| idx as i64) + .collect(); let pooled_indices = Tensor::from_vec( - batch.pooled_indices.clone(), + pooled_indices_vec, batch.pooled_indices.len(), &self.device, )?; - - // Only select indices that requires pooling - indices = indices.index_select(&pooled_indices, 0)? - } + all_indices.index_select(&pooled_indices, 0)? + } else { + all_indices + }; // Select tokens Some(outputs.index_select(&indices, 0)?)