@@ -42,7 +42,10 @@ impl Qwen3Attention {
4242 "weight" ,
4343 ) ?;
4444 let query_bias = if config. attention_bias {
45- Some ( vb. pp ( "q_proj" ) . get ( hidden_size, "bias" ) ?)
45+ Some (
46+ vb. pp ( "q_proj" )
47+ . get ( num_attention_heads * attention_head_size, "bias" ) ?,
48+ )
4649 } else {
4750 None
4851 } ;
@@ -85,7 +88,7 @@ impl Qwen3Attention {
8588 let q_norm = RMSNorm :: load ( vb. pp ( "q_norm" ) , attention_head_size, config. rms_norm_eps ) ?;
8689 let k_norm = RMSNorm :: load ( vb. pp ( "k_norm" ) , attention_head_size, config. rms_norm_eps ) ?;
8790
88- let softmax_scale = ( 1. / ( attention_head_size as f64 ) . sqrt ( ) ) as f32 ;
91+ let softmax_scale = 1.0 / ( attention_head_size as f64 ) . sqrt ( ) as f32 ;
8992
9093 Ok ( Self {
9194 q_proj,
@@ -148,6 +151,28 @@ impl Qwen3Attention {
148151
149152 apply_rotary_inplace ( & q, & k, & cos, & sin, true ) ?;
150153
154+ let ( k, v) = if self . num_key_value_heads != self . num_attention_heads {
155+ if self . num_attention_heads % self . num_key_value_heads != 0 {
156+ candle:: bail!( "num_attention_heads must be a multiple of num_key_value_heads" ) ;
157+ }
158+ let repeat = self . num_attention_heads / self . num_key_value_heads ;
159+
160+ let ( total_tokens, n_kv_heads, head_dim) = k. dims3 ( ) ?;
161+
162+ let k = k
163+ . unsqueeze ( 2 ) ?
164+ . expand ( ( total_tokens, n_kv_heads, repeat, head_dim) ) ?
165+ . reshape ( ( total_tokens, n_kv_heads * repeat, head_dim) ) ?;
166+
167+ let v = v
168+ . unsqueeze ( 2 ) ?
169+ . expand ( ( total_tokens, n_kv_heads, repeat, head_dim) ) ?
170+ . reshape ( ( total_tokens, n_kv_heads * repeat, head_dim) ) ?;
171+ ( k, v)
172+ } else {
173+ ( k, v)
174+ } ;
175+
151176 let attention = flash_attn_varlen (
152177 & q,
153178 & k,
@@ -277,101 +302,20 @@ impl Qwen3Layer {
277302
278303 let mlp_output = self . mlp . forward ( & normed_attn_res_output) ?;
279304
280- Ok ( ( mlp_output, attn_res) )
281- }
282- }
283-
284- // Define ClassificationHead trait locally (following TEI pattern)
285- trait ClassificationHead {
286- fn forward ( & self , hidden_states : & Tensor ) -> Result < Tensor > ;
287- }
288-
289- // Qwen3 Classification Head implementation
290- #[ derive( Debug ) ]
291- struct Qwen3ClassificationHead {
292- dense : Linear ,
293- out_proj : Linear ,
294- activation : HiddenAct ,
295- span : tracing:: Span ,
296- }
297-
298- impl Qwen3ClassificationHead {
299- pub fn load ( vb : VarBuilder , config : & Qwen3Config ) -> Result < Self > {
300- let ( dense, out_proj) = if vb. contains_tensor ( "score.dense.weight" ) {
301- tracing:: info!( "Loading Qwen3 classifier with score layers" ) ;
302-
303- let dense_weight = vb
304- . pp ( "score.dense" )
305- . get ( ( config. hidden_size , config. hidden_size ) , "weight" ) ?;
306- let dense_bias = vb. pp ( "score.dense" ) . get ( config. hidden_size , "bias" ) ?;
307- let dense = Linear :: new ( dense_weight, Some ( dense_bias) , None ) ;
308-
309- let out_proj_weight = vb
310- . pp ( "score.out_proj" )
311- . get ( ( 1 , config. hidden_size ) , "weight" ) ?;
312- let out_proj_bias = vb. pp ( "score.out_proj" ) . get ( 1 , "bias" ) ?;
313- let out_proj = Linear :: new ( out_proj_weight, Some ( out_proj_bias) , None ) ;
314-
315- ( dense, out_proj)
316- } else if vb. contains_tensor ( "classifier.dense.weight" ) {
317- tracing:: info!( "Loading Qwen3 classifier with classifier layers" ) ;
318-
319- let dense_weight = vb
320- . pp ( "classifier.dense" )
321- . get ( ( config. hidden_size , config. hidden_size ) , "weight" ) ?;
322- let dense_bias = vb. pp ( "classifier.dense" ) . get ( config. hidden_size , "bias" ) ?;
323- let dense = Linear :: new ( dense_weight, Some ( dense_bias) , None ) ;
324-
325- let out_proj_weight = vb
326- . pp ( "classifier.out_proj" )
327- . get ( ( 1 , config. hidden_size ) , "weight" ) ?;
328- let out_proj_bias = vb. pp ( "classifier.out_proj" ) . get ( 1 , "bias" ) ?;
329- let out_proj = Linear :: new ( out_proj_weight, Some ( out_proj_bias) , None ) ;
330-
331- ( dense, out_proj)
332- } else {
333- candle:: bail!(
334- "Classification layers not found in model weights. \
335- Expected 'score.dense.weight' or 'classifier.dense.weight' for reranker models. \
336- This model may not be a trained reranker."
337- ) ;
338- } ;
339-
340- Ok ( Self {
341- dense,
342- out_proj,
343- activation : config. hidden_act . clone ( ) ,
344- span : tracing:: span!( tracing:: Level :: TRACE , "classifier" ) ,
345- } )
346- }
347- }
348-
349- impl ClassificationHead for Qwen3ClassificationHead {
350- fn forward ( & self , hidden_states : & Tensor ) -> Result < Tensor > {
351- let _enter = self . span . enter ( ) ;
352-
353- // Input is already pooled
354-
355- // Apply dense layer with activation
356- let hidden = self . dense . forward ( hidden_states) ?;
357- let hidden = self . activation . forward ( & hidden) ?;
358-
359- // Project to single score
360- let score = self . out_proj . forward ( & hidden) ?;
361-
362- // Squeeze to remove the last dimension if it's 1
363- score. squeeze ( candle:: D :: Minus1 )
305+ let output = ( & mlp_output + & attn_res) ?;
306+ Ok ( ( output, attn_res) )
364307 }
365308}
366309
367310pub struct FlashQwen3Model {
368311 embeddings : Embedding ,
312+ lm_head_weight : Tensor ,
369313 layers : Vec < Qwen3Layer > ,
370314 norm : RMSNorm ,
371315 cos_cache : Tensor ,
372316 sin_cache : Tensor ,
317+ model_type : ModelType ,
373318 pool : Pool ,
374- classifier : Option < Box < dyn ClassificationHead + Send > > ,
375319 pub device : Device ,
376320
377321 span : tracing:: Span ,
@@ -388,19 +332,12 @@ impl FlashQwen3Model {
388332 candle:: bail!( "FlashQwen3 requires DType::F16" )
389333 }
390334
391- let ( pool, classifier ) = match model_type {
335+ let pool = match & model_type {
392336 ModelType :: Classifier => {
393- let pool = Pool :: LastToken ;
394- let classifier: Box < dyn ClassificationHead + Send > =
395- Box :: new ( Qwen3ClassificationHead :: load ( vb. clone ( ) , config) ?) ;
396- ( pool, Some ( classifier) )
397- }
398- ModelType :: Embedding ( pool) => {
399- if pool == Pool :: Splade {
400- candle:: bail!( "`splade` is not supported for Qwen3" )
401- }
402- ( pool, None )
337+ candle:: bail!( "`classifier` model type is not supported for Qwen3" )
403338 }
339+ ModelType :: Embedding ( pool) => pool. clone ( ) ,
340+ ModelType :: ListwiseReranker => Pool :: LastToken ,
404341 } ;
405342
406343 // The Qwen3-Reranker models contain the `model` key
@@ -411,11 +348,13 @@ impl FlashQwen3Model {
411348 vb
412349 } ;
413350
414- let embeddings = Embedding :: new (
415- vb. pp ( "embed_tokens" )
416- . get ( ( config. vocab_size , config. hidden_size ) , "weight" ) ?,
417- config. hidden_size ,
418- ) ;
351+ let embed_weight = vb
352+ . pp ( "embed_tokens" )
353+ . get ( ( config. vocab_size , config. hidden_size ) , "weight" ) ?;
354+
355+ let embeddings = Embedding :: new ( embed_weight. clone ( ) , config. hidden_size ) ;
356+
357+ let lm_head_weight = embed_weight;
419358
420359 let layers = ( 0 ..config. num_hidden_layers )
421360 . map ( |index| Qwen3Layer :: load ( vb. pp ( format ! ( "layers.{index}" ) ) , config) )
@@ -438,12 +377,13 @@ impl FlashQwen3Model {
438377
439378 Ok ( Self {
440379 embeddings,
380+ lm_head_weight,
441381 layers,
442382 norm,
443383 cos_cache,
444384 sin_cache,
385+ model_type,
445386 pool,
446- classifier,
447387 device : vb. device ( ) . clone ( ) ,
448388 span : tracing:: span!( tracing:: Level :: TRACE , "model" ) ,
449389 } )
@@ -469,21 +409,19 @@ impl FlashQwen3Model {
469409 let cos = self . cos_cache . index_select ( & position_ids, 0 ) ?;
470410 let sin = self . sin_cache . index_select ( & position_ids, 0 ) ?;
471411
472- let mut residual = None ;
473412 for layer in & self . layers {
474- let ( h, r ) = layer. forward (
413+ let ( h, _r ) = layer. forward (
475414 & hidden_states,
476- residual . as_ref ( ) ,
415+ None ,
477416 & cu_seqlens,
478417 & cos,
479418 & sin,
480419 batch. max_length as usize ,
481420 ) ?;
482421 hidden_states = h;
483- residual = Some ( r) ;
484422 }
485423
486- let ( outputs, _) = self . norm . forward ( & hidden_states, residual . as_ref ( ) ) ?;
424+ let ( outputs, _) = self . norm . forward ( & hidden_states, None ) ?;
487425
488426 let has_pooling_requests = !batch. pooled_indices . is_empty ( ) ;
489427 let has_raw_requests = !batch. raw_indices . is_empty ( ) ;
@@ -553,7 +491,8 @@ impl FlashQwen3Model {
553491 // Concatenate all results
554492 Some ( Tensor :: cat ( & results?, 0 ) ?)
555493 } else {
556- Some ( ( outputs. sum_keepdim ( 0 ) ? / ( batch. max_length as f64 ) ) ?)
494+ let actual_len = batch. cumulative_seq_lengths [ 1 ] as f64 ;
495+ Some ( ( outputs. sum_keepdim ( 0 ) ? / actual_len) ?)
557496 }
558497 }
559498 Pool :: Splade => {
@@ -607,21 +546,64 @@ impl Model for FlashQwen3Model {
607546 }
608547
609548 fn predict ( & self , batch : Batch ) -> Result < Tensor > {
610- match & self . classifier {
611- None => candle:: bail!( "`predict` is not implemented for this model" ) ,
612- Some ( classifier) => {
613- // Run forward pass to get hidden states
614- let ( pooled_embeddings, _) = self . forward ( batch) ?;
615- match pooled_embeddings {
616- Some ( embeddings) => {
617- let scores = classifier. forward ( & embeddings) ?;
618- // Apply sigmoid to convert logits to probabilities
619- let probabilities = candle_nn:: ops:: sigmoid ( & scores) ?;
620- Ok ( probabilities)
621- }
622- None => candle:: bail!( "No pooled embeddings returned for classification" ) ,
549+ match & self . model_type {
550+ ModelType :: ListwiseReranker => {
551+ let _enter = self . span . enter ( ) ;
552+
553+ let batch_size = batch. cumulative_seq_lengths . len ( ) - 1 ;
554+ let shape = batch. input_ids . len ( ) ;
555+
556+ let input_ids = Tensor :: from_vec ( batch. input_ids , shape, & self . device ) ?;
557+ let position_ids = Tensor :: from_vec ( batch. position_ids , shape, & self . device ) ?;
558+ let cu_seqlens = Tensor :: from_vec (
559+ batch. cumulative_seq_lengths . clone ( ) ,
560+ batch_size + 1 ,
561+ & self . device ,
562+ ) ?;
563+
564+ let mut hidden_states = self . embeddings . forward ( & input_ids) ?;
565+
566+ let cos = self . cos_cache . index_select ( & position_ids, 0 ) ?;
567+ let sin = self . sin_cache . index_select ( & position_ids, 0 ) ?;
568+
569+ for layer in & self . layers {
570+ let ( h, _r) = layer. forward (
571+ & hidden_states,
572+ None ,
573+ & cu_seqlens,
574+ & cos,
575+ & sin,
576+ batch. max_length as usize ,
577+ ) ?;
578+ hidden_states = h;
623579 }
580+
581+ let ( outputs, _) = self . norm . forward ( & hidden_states, None ) ?;
582+
583+ let mut last_hidden_states = Vec :: with_capacity ( batch_size) ;
584+
585+ for i in 0 ..batch_size {
586+ let seq_end = batch. cumulative_seq_lengths [ i + 1 ] as usize ;
587+ let last_token_idx = seq_end - 1 ;
588+
589+ let h_last = outputs. i ( last_token_idx) ?; // [hidden_size]
590+ last_hidden_states. push ( h_last) ;
591+ }
592+
593+ let h_last = Tensor :: stack ( & last_hidden_states, 0 ) ?; // [bs, hidden_size]
594+
595+ let true_id = 9693u32 ;
596+ let false_id = 2152u32 ;
597+
598+ let ids = Tensor :: from_vec ( vec ! [ false_id, true_id] , 2 , & self . device ) ?;
599+ let w = self . lm_head_weight . index_select ( & ids, 0 ) ?; // [2, hidden_size]
600+ let logits = h_last. matmul ( & w. t ( ) ?) ?; // [bs, 2] (no, yes)
601+ let log_probs = candle_nn:: ops:: log_softmax ( & logits, D :: Minus1 ) ?;
602+ let scores = log_probs. i ( ( .., 1 ) ) ?. exp ( ) ?; // P("yes") ∈ (0,1)
603+
604+ Ok ( scores)
624605 }
606+ _ => candle:: bail!( "`predict` is only available for ModelType::ListwiseReranker" ) ,
625607 }
626608 }
627609}
0 commit comments