@@ -13,13 +13,13 @@ use crate::compute_cap::{
1313use crate :: models:: {
1414 BertConfig , BertModel , DistilBertConfig , DistilBertModel , GTEConfig , GTEModel , JinaBertModel ,
1515 JinaCodeBertModel , MPNetConfig , MPNetModel , MistralConfig , Model , ModernBertConfig ,
16- ModernBertModel , NomicBertModel , NomicConfig , Qwen2Config ,
16+ ModernBertModel , NomicBertModel , NomicConfig , Qwen2Config , Qwen3Config ,
1717} ;
1818#[ cfg( feature = "cuda" ) ]
1919use crate :: models:: {
2020 FlashBertModel , FlashDistilBertModel , FlashGTEModel , FlashJinaBertModel ,
2121 FlashJinaCodeBertModel , FlashMistralModel , FlashModernBertModel , FlashNomicBertModel ,
22- FlashQwen2Model ,
22+ FlashQwen2Model , FlashQwen3Model ,
2323} ;
2424use anyhow:: Context ;
2525use candle:: { DType , Device } ;
@@ -103,6 +103,8 @@ enum Config {
103103 Gte ( GTEConfig ) ,
104104 #[ allow( dead_code) ]
105105 Qwen2 ( Qwen2Config ) ,
106+ #[ allow( dead_code) ]
107+ Qwen3 ( Qwen3Config ) ,
106108 #[ serde( rename = "mpnet" ) ]
107109 MPNet ( MPNetConfig ) ,
108110 #[ serde( rename( deserialize = "modernbert" ) ) ]
@@ -273,6 +275,10 @@ impl CandleBackend {
273275 "Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
274276 . to_string ( ) ,
275277 ) ) ,
278+ ( Config :: Qwen3 ( _) , Device :: Cpu | Device :: Metal ( _) ) => Err ( BackendError :: Start (
279+ "Qwen3 is only supported on Cuda devices in fp16 with flash attention enabled"
280+ . to_string ( ) ,
281+ ) ) ,
276282 ( Config :: MPNet ( config) , _) => {
277283 tracing:: info!( "Starting MPNet model on {:?}" , device) ;
278284 Ok ( Box :: new ( MPNetModel :: load ( vb, & config, model_type) . s ( ) ?) )
@@ -446,6 +452,18 @@ impl CandleBackend {
446452 FlashQwen2Model :: load ( vb, & config, model_type) . s ( ) ?,
447453 ) )
448454 }
455+ #[ cfg( feature = "cuda" ) ]
456+ ( Config :: Qwen3 ( config) , Device :: Cuda ( _) ) => {
457+ if dtype != DType :: F16
458+ || !cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
459+ {
460+ return Err ( BackendError :: Start ( "Qwen3 is only supported on Cuda devices in fp16 with flash attention v2 enabled" . to_string ( ) ) ) ;
461+ }
462+ tracing:: info!( "Starting FlashQwen3 model on {:?}" , device) ;
463+ Ok ( Box :: new (
464+ FlashQwen3Model :: load ( vb, & config, model_type) . s ( ) ?,
465+ ) )
466+ }
449467 } ;
450468
451469 Ok ( Self {
0 commit comments