@@ -103,11 +103,34 @@ pub fn sort_embeddings(embeddings: Embeddings) -> (Vec<Vec<f32>>, Vec<Vec<f32>>)
103103 ( pooled_embeddings, raw_embeddings)
104104}
105105
106+ #[ derive( Deserialize , PartialEq ) ]
107+ enum ModuleType {
108+ #[ serde( rename = "sentence_transformers.models.Dense" ) ]
109+ Dense ,
110+ #[ serde( rename = "sentence_transformers.models.Normalize" ) ]
111+ Normalize ,
112+ #[ serde( rename = "sentence_transformers.models.Pooling" ) ]
113+ Pooling ,
114+ #[ serde( rename = "sentence_transformers.models.Transformer" ) ]
115+ Transformer ,
116+ }
117+
118+ #[ derive( Deserialize ) ]
119+ struct ModuleConfig {
120+ #[ allow( dead_code) ]
121+ idx : usize ,
122+ #[ allow( dead_code) ]
123+ name : String ,
124+ path : String ,
125+ #[ serde( rename = "type" ) ]
126+ module_type : ModuleType ,
127+ }
128+
106129pub fn download_artifacts (
107130 model_id : & ' static str ,
108131 revision : Option < & ' static str > ,
109132 dense_path : Option < & ' static str > ,
110- ) -> Result < PathBuf > {
133+ ) -> Result < ( PathBuf , Option < Vec < String > > ) > {
111134 let mut builder = ApiBuilder :: from_env ( ) . with_progress ( false ) ;
112135
113136 if let Some ( cache_dir) = std:: env:: var_os ( "HUGGINGFACE_HUB_CACHE" ) {
@@ -142,41 +165,35 @@ pub fn download_artifacts(
142165 }
143166 } ;
144167
145- // Download dense path files if specified
146- if let Some ( dense_path) = dense_path {
147- let dense_config_path = format ! ( "{}/config.json" , dense_path) ;
148- match api_repo. get ( & dense_config_path) {
149- Ok ( _) => tracing:: info!( "Downloaded dense config: {}" , dense_config_path) ,
150- Err ( err) => tracing:: warn!(
151- "Could not download dense config {}: {}" ,
152- dense_config_path,
153- err
154- ) ,
155- }
156-
157- // Try to download dense model files (safetensors first, then pytorch)
158- let dense_safetensors_path = format ! ( "{}/model.safetensors" , dense_path) ;
159- match api_repo. get ( & dense_safetensors_path) {
160- Ok ( _) => tracing:: info!( "Downloaded dense safetensors: {}" , dense_safetensors_path) ,
161- Err ( _) => {
162- tracing:: warn!( "Dense safetensors not found. Trying pytorch_model.bin" ) ;
163- let dense_pytorch_path = format ! ( "{}/pytorch_model.bin" , dense_path) ;
164- match api_repo. get ( & dense_pytorch_path) {
165- Ok ( _) => {
166- tracing:: info!( "Downloaded dense pytorch model: {}" , dense_pytorch_path)
168+ let dense_paths = if let Ok ( modules_path) = api_repo. get ( "modules.json" ) {
169+ match parse_dense_paths_from_modules ( & modules_path) {
170+ Ok ( paths) => match paths. len ( ) {
171+ 0 => None ,
172+ 1 => {
173+ let path = if let Some ( path) = dense_path {
174+ path. to_string ( )
175+ } else {
176+ paths[ 0 ] . clone ( )
177+ } ;
178+
179+ download_dense_module ( & api_repo, & path) ?;
180+ Some ( vec ! [ path] )
181+ }
182+ _ => {
183+ for path in & paths {
184+ download_dense_module ( & api_repo, & path) ?;
167185 }
168- Err ( err) => tracing:: warn!(
169- "Could not download dense pytorch model {}: {}" ,
170- dense_pytorch_path,
171- err
172- ) ,
186+ Some ( paths)
173187 }
174- }
188+ } ,
189+ _ => None ,
175190 }
176- }
191+ } else {
192+ None
193+ } ;
177194
178195 let model_root = model_files[ 0 ] . parent ( ) . unwrap ( ) . to_path_buf ( ) ;
179- Ok ( model_root)
196+ Ok ( ( model_root, dense_paths ) )
180197}
181198
182199fn download_safetensors ( api : & ApiRepo ) -> Result < Vec < PathBuf > , ApiError > {
@@ -218,6 +235,38 @@ fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
218235 Ok ( safetensors_files)
219236}
220237
238+ fn parse_dense_paths_from_modules ( modules_path : & PathBuf ) -> Result < Vec < String > , std:: io:: Error > {
239+ let content = std:: fs:: read_to_string ( modules_path) ?;
240+ let modules: Vec < ModuleConfig > = serde_json:: from_str ( & content)
241+ . map_err ( |err| std:: io:: Error :: new ( std:: io:: ErrorKind :: InvalidData , err) ) ?;
242+
243+ Ok ( modules
244+ . into_iter ( )
245+ . filter ( |module| module. module_type == ModuleType :: Dense )
246+ . map ( |module| module. path )
247+ . collect :: < Vec < String > > ( ) )
248+ }
249+
250+ fn download_dense_module ( api : & ApiRepo , dense_path : & str ) -> Result < PathBuf , ApiError > {
251+ let config_file = format ! ( "{}/config.json" , dense_path) ;
252+ tracing:: info!( "Downloading `{}`" , config_file) ;
253+ let config_path = api. get ( & config_file) ?;
254+
255+ let safetensors_file = format ! ( "{}/model.safetensors" , dense_path) ;
256+ tracing:: info!( "Downloading `{}`" , safetensors_file) ;
257+ match api. get ( & safetensors_file) {
258+ Ok ( _) => { }
259+ Err ( err) => {
260+ tracing:: warn!( "Could not download `{}`: {}" , safetensors_file, err) ;
261+ let pytorch_file = format ! ( "{}/pytorch_model.bin" , dense_path) ;
262+ tracing:: info!( "Downloading `{}`" , pytorch_file) ;
263+ api. get ( & pytorch_file) ?;
264+ }
265+ }
266+
267+ Ok ( config_path. parent ( ) . unwrap ( ) . to_path_buf ( ) )
268+ }
269+
221270#[ allow( unused) ]
222271pub ( crate ) fn relative_matcher ( ) -> YamlMatcher < SnapshotScores > {
223272 YamlMatcher :: new ( )
0 commit comments