@@ -11,17 +11,17 @@ use crate::compute_cap::{
1111 compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
1212} ;
1313use crate :: models:: {
14- BertConfig , BertModel , DistilBertConfig , DistilBertModel , JinaConfig , JinaBertModel , JinaCodeConfig , JinaCodeBertModel ,
14+ BertConfig , BertModel , DistilBertConfig , DistilBertModel , JinaBertModel , JinaCodeBertModel ,
1515 Model , NomicBertModel , NomicConfig ,
1616} ;
1717#[ cfg( feature = "cuda" ) ]
1818use crate :: models:: {
19- FlashBertModel , FlashDistilBertModel , FlashJinaBertModel , FlashJinaCodeBertModel , FlashNomicBertModel ,
19+ FlashBertModel , FlashDistilBertModel , FlashJinaBertModel , FlashJinaCodeBertModel ,
20+ FlashNomicBertModel ,
2021} ;
2122use anyhow:: Context ;
2223use candle:: { DType , Device } ;
2324use candle_nn:: VarBuilder ;
24- use models:: BertConfig ;
2525use nohash_hasher:: BuildNoHashHasher ;
2626use serde:: Deserialize ;
2727use std:: collections:: HashMap ;
@@ -30,17 +30,28 @@ use text_embeddings_backend_core::{
3030 Backend , BackendError , Batch , Embedding , Embeddings , ModelType , Predictions ,
3131} ;
3232
33+ /// Distinguishes Jina models, whose configs also declare the `bert` model type, from
34+ /// regular Bert models. All three variants carry the same `BertConfig` payload.
35+ /// The variant is chosen from the `_name_or_path` field of the config. This might not be robust
36+ /// in the long run but is still better than the other options...
37+ #[ derive( Debug , Clone , PartialEq , Deserialize ) ]
38+ #[ serde( tag = "_name_or_path" ) ] // the value of this config field selects the variant
39+ pub enum BertConfigWrapper {
40+ #[ serde( rename = "jinaai/jina-bert-implementation" ) ]
41+ JinaBert ( BertConfig ) , // picked when `_name_or_path` equals the rename string above
42+ #[ serde( rename = "jinaai/jina-bert-v2-qk-post-norm" ) ]
43+ JinaCodeBert ( BertConfig ) , // picked when `_name_or_path` equals the rename string above
44+ #[ serde( untagged) ]
45+ Bert ( BertConfig ) , // fallback, tried without a tag; NOTE(review): confirm this also covers configs with no `_name_or_path` at all
46+ }
47+
3348#[ derive( Deserialize ) ]
3449#[ serde( tag = "model_type" , rename_all = "kebab-case" ) ]
3550enum Config {
36- Bert ( BertConfig ) ,
51+ Bert ( BertConfigWrapper ) ,
3752 XlmRoberta ( BertConfig ) ,
3853 Camembert ( BertConfig ) ,
3954 Roberta ( BertConfig ) ,
40- #[ serde( rename( deserialize = "jina_bert" ) ) ]
41- JinaBert ( JinaConfig ) ,
42- #[ serde( rename( deserialize = "jina_code_bert" ) ) ]
43- JinaCodeBert ( JinaCodeConfig ) ,
4455 #[ serde( rename( deserialize = "distilbert" ) ) ]
4556 DistilBert ( DistilBertConfig ) ,
4657 #[ serde( rename( deserialize = "nomic_bert" ) ) ]
@@ -76,7 +87,7 @@ impl CandleBackend {
7687 "Runtime compute cap {} is not compatible with compile time compute cap {}" ,
7788 get_runtime_compute_cap( ) . unwrap( ) ,
7889 get_compile_compute_cap( ) . unwrap( )
79- ) ) )
90+ ) ) ) ;
8091 }
8192 Err ( err) => {
8293 tracing:: warn!( "Could not find a compatible CUDA device on host: {err:?}" ) ;
@@ -123,18 +134,22 @@ impl CandleBackend {
123134 ( _, Device :: Cuda ( _) ) => Err ( BackendError :: Start (
124135 "`cuda` feature is not enabled" . to_string ( ) ,
125136 ) ) ,
126- ( Config :: Bert ( config) , Device :: Cpu | Device :: Metal ( _) ) => {
127- tracing:: info!( "Starting Bert model on {:?}" , device) ;
128- Ok ( Box :: new ( BertModel :: load ( vb, & config, model_type) . s ( ) ?) )
129- }
130- ( Config :: JinaBert ( config) , Device :: Cpu | Device :: Metal ( _) ) => {
131- tracing:: info!( "Starting JinaBertModel model on {:?}" , device) ;
132- Ok ( Box :: new ( JinaBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
133- }
134- ( Config :: JinaCodeBert ( config) , Device :: Cpu | Device :: Metal ( _) ) => {
135- tracing:: info!( "Starting JinaCodeBertModel model on {:?}" , device) ;
136- Ok ( Box :: new ( JinaCodeBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
137- }
137+ ( Config :: Bert ( config) , Device :: Cpu | Device :: Metal ( _) ) => match config {
138+ BertConfigWrapper :: JinaBert ( config) => {
139+ tracing:: info!( "Starting JinaBertModel model on {:?}" , device) ;
140+ Ok ( Box :: new ( JinaBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
141+ }
142+ BertConfigWrapper :: JinaCodeBert ( config) => {
143+ tracing:: info!( "Starting JinaCodeBert model on {:?}" , device) ;
144+ Ok ( Box :: new (
145+ JinaCodeBertModel :: load ( vb, & config, model_type) . s ( ) ?,
146+ ) )
147+ }
148+ BertConfigWrapper :: Bert ( config) => {
149+ tracing:: info!( "Starting Bert model on {:?}" , device) ;
150+ Ok ( Box :: new ( BertModel :: load ( vb, & config, model_type) . s ( ) ?) )
151+ }
152+ } ,
138153 (
139154 Config :: XlmRoberta ( config) | Config :: Camembert ( config) | Config :: Roberta ( config) ,
140155 Device :: Cpu | Device :: Metal ( _) ,
@@ -158,48 +173,45 @@ impl CandleBackend {
158173 ( Config :: Bert ( config) , Device :: Cuda ( _) ) => {
159174 if cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
160175 && dtype == DType :: F16
161- && ( ( config. position_embedding_type == PositionEmbeddingType :: Absolute ) | ( config. position_embedding_type == PositionEmbeddingType :: Alibi ) )
162176 // Allow disabling because of flash attention v1 precision problems
163177 // See: https://github.com/huggingface/text-embeddings-inference/issues/37
164178 && & std:: env:: var ( "USE_FLASH_ATTENTION" ) . unwrap_or ( "True" . to_string ( ) ) . to_lowercase ( ) == "true"
165179 {
166- if config. position_embedding_type == PositionEmbeddingType :: Alibi {
167- tracing:: info!( "Starting FlashBert model on {:?}" , device) ;
168- Ok ( Box :: new ( FlashBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
169- } else {
170- tracing:: info!( "Starting Bert model on {:?}" , device) ;
171- Ok ( Box :: new ( BertModel :: load ( vb, & config, model_type) . s ( ) ?) )
180+ match config {
181+ BertConfigWrapper :: JinaBert ( config) => {
182+ tracing:: info!( "Starting FlashJinaBert model on {:?}" , device) ;
183+ Ok ( Box :: new (
184+ FlashJinaBertModel :: load ( vb, & config, model_type) . s ( ) ?,
185+ ) )
186+ }
187+ BertConfigWrapper :: JinaCodeBert ( config) => {
188+ tracing:: info!( "Starting FlashJinaCodeBert model on {:?}" , device) ;
189+ Ok ( Box :: new (
190+ FlashJinaCodeBertModel :: load ( vb, & config, model_type) . s ( ) ?,
191+ ) )
192+ }
193+ BertConfigWrapper :: Bert ( config) => {
194+ tracing:: info!( "Starting FlashBert model on {:?}" , device) ;
195+ Ok ( Box :: new ( FlashBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
196+ }
172197 }
173- }
174- #[ cfg( feature = "cuda" ) ]
175- ( Config :: JinaBert ( config) , Device :: Cuda ( _) ) => {
176- if cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
177- && dtype == DType :: F16
178- && ( ( config. position_embedding_type == PositionEmbeddingType :: Absolute ) | ( config. position_embedding_type == PositionEmbeddingType :: Alibi ) )
179- // Allow disabling because of flash attention v1 precision problems
180- // See: https://github.com/huggingface/text-embeddings-inference/issues/37
181- && & std:: env:: var ( "USE_FLASH_ATTENTION" ) . unwrap_or ( "True" . to_string ( ) ) . to_lowercase ( ) == "true"
182- {
183- tracing:: info!( "Starting FlashJinaBertModel model on {:?}" , device) ;
184- Ok ( Box :: new ( FlashJinaBertModel :: load ( vb, & config, model_type) . s ( ) ?, ) )
185198 } else {
186- tracing:: info!( "Starting JinaBertModel model on {:?}" , device) ;
187- Ok ( Box :: new ( JinaBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
188- }
189- #[ cfg( feature = "cuda" ) ]
190- ( Config :: JinaCodeBert ( config) , Device :: Cuda ( _) ) => {
191- if cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
192- && dtype == DType :: F16
193- && ( ( config. position_embedding_type == PositionEmbeddingType :: Absolute ) | ( config. position_embedding_type == PositionEmbeddingType :: Alibi ) )
194- // Allow disabling because of flash attention v1 precision problems
195- // See: https://github.com/huggingface/text-embeddings-inference/issues/37
196- && & std:: env:: var ( "USE_FLASH_ATTENTION" ) . unwrap_or ( "True" . to_string ( ) ) . to_lowercase ( ) == "true"
197- {
198- tracing:: info!( "Starting FlashJinaCodeBertModel model on {:?}" , device) ;
199- Ok ( Box :: new ( FlashJinaCodeBertModel :: load ( vb, & config, model_type) . s ( ) ?, ) )
200- } else {
201- tracing:: info!( "Starting JinaCodeBertModel model on {:?}" , device) ;
202- Ok ( Box :: new ( JinaCodeBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
199+ match config {
200+ BertConfigWrapper :: JinaBert ( config) => {
201+ tracing:: info!( "Starting JinaBertModel model on {:?}" , device) ;
202+ Ok ( Box :: new ( JinaBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
203+ }
204+ BertConfigWrapper :: JinaCodeBert ( config) => {
205+ tracing:: info!( "Starting JinaCodeBert model on {:?}" , device) ;
206+ Ok ( Box :: new (
207+ JinaCodeBertModel :: load ( vb, & config, model_type) . s ( ) ?,
208+ ) )
209+ }
210+ BertConfigWrapper :: Bert ( config) => {
211+ tracing:: info!( "Starting Bert model on {:?}" , device) ;
212+ Ok ( Box :: new ( BertModel :: load ( vb, & config, model_type) . s ( ) ?) )
213+ }
214+ }
203215 }
204216 }
205217 #[ cfg( feature = "cuda" ) ]
@@ -209,7 +221,6 @@ impl CandleBackend {
209221 ) => {
210222 if cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
211223 && dtype == DType :: F16
212- && ( ( config. position_embedding_type == PositionEmbeddingType :: Absolute ) | ( config. position_embedding_type == PositionEmbeddingType :: Alibi ) )
213224 // Allow disabling because of flash attention v1 precision problems
214225 // See: https://github.com/huggingface/text-embeddings-inference/issues/37
215226 && & std:: env:: var ( "USE_FLASH_ATTENTION" ) . unwrap_or ( "True" . to_string ( ) ) . to_lowercase ( ) == "true"
0 commit comments