@@ -11,17 +11,17 @@ use crate::compute_cap::{
     compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
 };
 use crate::models::{
-    BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaConfig, JinaBertModel, JinaCodeConfig, JinaCodeBertModel,
-    Model, NomicBertModel, NomicConfig,
+    BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
+    JinaCodeConfig, JinaConfig, Model, NomicBertModel, NomicConfig,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
-    FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel, FlashNomicBertModel,
+    FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel,
+    FlashNomicBertModel,
 };
 use anyhow::Context;
 use candle::{DType, Device};
 use candle_nn::VarBuilder;
-use models::BertConfig;
 use nohash_hasher::BuildNoHashHasher;
 use serde::Deserialize;
 use std::collections::HashMap;
@@ -133,7 +133,9 @@ impl CandleBackend {
             }
             (Config::JinaCodeBert(config), Device::Cpu | Device::Metal(_)) => {
                 tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
-                Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?))
+                Ok(Box::new(
+                    JinaCodeBertModel::load(vb, &config, model_type).s()?,
+                ))
             }
             (
                 Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config),
@@ -171,8 +173,9 @@ impl CandleBackend {
                         Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
                     }
                 }
-            #[cfg(feature = "cuda")]
-            (Config::JinaBert(config), Device::Cuda(_)) => {
+            }
+            #[cfg(feature = "cuda")]
+            (Config::JinaBert(config), Device::Cuda(_)) => {
                 if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                     && dtype == DType::F16
                     && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
@@ -181,25 +184,32 @@ impl CandleBackend {
                     && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
                 {
                     tracing::info!("Starting FlashJinaBertModel model on {:?}", device);
-                    Ok(Box::new(FlashJinaBertModel::load(vb, &config, model_type).s()?,))
+                    Ok(Box::new(
+                        FlashJinaBertModel::load(vb, &config, model_type).s()?,
+                    ))
                 } else {
                     tracing::info!("Starting JinaBertModel model on {:?}", device);
                     Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
                 }
-            #[cfg(feature = "cuda")]
-            (Config::JinaCodeBert(config), Device::Cuda(_)) => {
-                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    && dtype == DType::F16
-                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
-                {
-                    tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device);
-                    Ok(Box::new(FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,))
-                } else {
-                    tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
-                    Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?))
+            }
+            #[cfg(feature = "cuda")]
+            (Config::JinaCodeBert(config), Device::Cuda(_)) => {
+                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
+                    && dtype == DType::F16
+                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
+                    // Allow disabling because of flash attention v1 precision problems
+                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
+                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
+                {
+                    tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device);
+                    Ok(Box::new(
+                        FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,
+                    ))
+                } else {
+                    tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
+                    Ok(Box::new(
+                        JinaCodeBertModel::load(vb, &config, model_type).s()?,
+                    ))
                 }
             }
             #[cfg(feature = "cuda")]
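
For context on the repeated `if` guard in the CUDA arms above, here is a minimal standalone sketch of that gating logic: the flash path is taken only when the flash-attn feature flags are compiled in, the dtype is F16, the position embedding type is Absolute or ALiBi, and the `USE_FLASH_ATTENTION` env var (default "True") has not been set to disable it. The `use_flash_attention` helper and the local `PositionEmbeddingType` enum are illustrative stand-ins, not part of the crate's API.

```rust
// Illustrative stand-in for the crate's `PositionEmbeddingType`.
#[derive(Debug, PartialEq)]
enum PositionEmbeddingType {
    Absolute,
    Alibi,
}

// Hypothetical helper mirroring the guard used by both CUDA arms in the diff.
fn use_flash_attention(
    flash_feature_enabled: bool, // stands in for cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
    is_f16: bool,                // stands in for dtype == DType::F16
    position_embedding_type: PositionEmbeddingType,
) -> bool {
    // Runtime opt-out exists because of flash attention v1 precision problems,
    // see https://github.com/huggingface/text-embeddings-inference/issues/37
    let env_enabled = std::env::var("USE_FLASH_ATTENTION")
        .unwrap_or("True".to_string())
        .to_lowercase()
        == "true";

    flash_feature_enabled
        && is_f16
        && (position_embedding_type == PositionEmbeddingType::Absolute
            || position_embedding_type == PositionEmbeddingType::Alibi)
        && env_enabled
}

fn main() {
    // With the feature compiled in, f16 weights, and ALiBi embeddings,
    // the flash path is taken unless USE_FLASH_ATTENTION=false is set.
    println!(
        "use flash attention: {}",
        use_flash_attention(true, true, PositionEmbeddingType::Alibi)
    );
}
```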