@@ -130,31 +130,46 @@ def parse_content_data(input_data, input_content_type):
130130 return dtest , content_type
131131
132132
133+ def _get_full_model_paths (model_dir ):
134+ for data_file in os .listdir (model_dir ):
135+ full_model_path = os .path .join (model_dir , data_file )
136+ if os .path .isfile (full_model_path ):
137+ if data_file .startswith ("." ):
138+ logging .warning (
139+ f"Ignoring dotfile '{ full_model_path } ' found in model directory"
140+ " - please exclude dotfiles from model archives"
141+ )
142+ else :
143+ yield full_model_path
144+
145+
133146def get_loaded_booster (model_dir , ensemble = False ):
134- model_files = [data_file for data_file in os .listdir (model_dir )
135- if os .path .isfile (os .path .join (model_dir , data_file ))]
136- model_files = model_files if ensemble else model_files [0 :1 ]
147+ full_model_paths = list (_get_full_model_paths (model_dir ))
148+ full_model_paths = full_model_paths if ensemble else full_model_paths [0 :1 ]
137149
138150 models = []
139- formats = []
140- for model_file in model_files :
141- path = os .path .join (model_dir , model_file )
142- logging .info (f"Loading the model from { path } " )
151+ model_formats = []
152+ for full_model_path in full_model_paths :
153+ logging .info (f"Loading the model from { full_model_path } " )
143154 try :
144- booster = pkl .load (open (path , 'rb' ))
145- format = PKL_FORMAT
155+ booster = pkl .load (open (full_model_path , "rb" ))
156+ model_format = PKL_FORMAT
146157 except Exception as exp_pkl :
147158 try :
148159 booster = xgb .Booster ()
149- booster .load_model (path )
150- format = XGB_FORMAT
160+ booster .load_model (full_model_path )
161+ model_format = XGB_FORMAT
151162 except Exception as exp_xgb :
152- raise RuntimeError ("Model at {} cannot be loaded:\n {}\n {}" .format (path , str (exp_pkl ), str (exp_xgb )))
153- booster .set_param ('nthread' , 1 )
163+ raise RuntimeError (
164+ f"Model { full_model_path } cannot be loaded:"
165+ f"\n Pickle load error={ str (exp_pkl )} "
166+ f"\n XGB load model error={ str (exp_xgb )} "
167+ )
168+ booster .set_param ("nthread" , 1 )
154169 models .append (booster )
155- formats .append (format )
170+ model_formats .append (model_format )
156171
157- return (models , formats ) if ensemble and len (models ) > 1 else (models [0 ], formats [0 ])
172+ return (models , model_formats ) if ensemble and len (models ) > 1 else (models [0 ], model_formats [0 ])
158173
159174
160175def predict (model , model_format , dtest , input_content_type , objective = None ):
0 commit comments