Skip to content

Commit ba6cd4c

Browse files
mabunday and Mark Bunday authored
Exclude dotfiles from model files (#313)
Co-authored-by: Mark Bunday <mabunday@amazon.com>
1 parent c4f3b84 commit ba6cd4c

File tree

2 files changed: +42 -16 lines changed

src/sagemaker_xgboost_container/algorithm_mode/serve_utils.py

Lines changed: 30 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -130,31 +130,46 @@ def parse_content_data(input_data, input_content_type):
130130
return dtest, content_type
131131

132132

133+
def _get_full_model_paths(model_dir):
134+
for data_file in os.listdir(model_dir):
135+
full_model_path = os.path.join(model_dir, data_file)
136+
if os.path.isfile(full_model_path):
137+
if data_file.startswith("."):
138+
logging.warning(
139+
f"Ignoring dotfile '{full_model_path}' found in model directory"
140+
" - please exclude dotfiles from model archives"
141+
)
142+
else:
143+
yield full_model_path
144+
145+
133146
def get_loaded_booster(model_dir, ensemble=False):
134-
model_files = [data_file for data_file in os.listdir(model_dir)
135-
if os.path.isfile(os.path.join(model_dir, data_file))]
136-
model_files = model_files if ensemble else model_files[0:1]
147+
full_model_paths = list(_get_full_model_paths(model_dir))
148+
full_model_paths = full_model_paths if ensemble else full_model_paths[0:1]
137149

138150
models = []
139-
formats = []
140-
for model_file in model_files:
141-
path = os.path.join(model_dir, model_file)
142-
logging.info(f"Loading the model from {path}")
151+
model_formats = []
152+
for full_model_path in full_model_paths:
153+
logging.info(f"Loading the model from {full_model_path}")
143154
try:
144-
booster = pkl.load(open(path, 'rb'))
145-
format = PKL_FORMAT
155+
booster = pkl.load(open(full_model_path, "rb"))
156+
model_format = PKL_FORMAT
146157
except Exception as exp_pkl:
147158
try:
148159
booster = xgb.Booster()
149-
booster.load_model(path)
150-
format = XGB_FORMAT
160+
booster.load_model(full_model_path)
161+
model_format = XGB_FORMAT
151162
except Exception as exp_xgb:
152-
raise RuntimeError("Model at {} cannot be loaded:\n{}\n{}".format(path, str(exp_pkl), str(exp_xgb)))
153-
booster.set_param('nthread', 1)
163+
raise RuntimeError(
164+
f"Model {full_model_path} cannot be loaded:"
165+
f"\nPickle load error={str(exp_pkl)}"
166+
f"\nXGB load model error={str(exp_xgb)}"
167+
)
168+
booster.set_param("nthread", 1)
154169
models.append(booster)
155-
formats.append(format)
170+
model_formats.append(model_format)
156171

157-
return (models, formats) if ensemble and len(models) > 1 else (models[0], formats[0])
172+
return (models, model_formats) if ensemble and len(models) > 1 else (models[0], model_formats[0])
158173

159174

160175
def predict(model, model_format, dtest, input_content_type, objective=None):

test/unit/algorithm_mode/test_serve_utils.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
import json
1717
import os
1818

19-
from mock import MagicMock
19+
from mock import MagicMock, patch
2020
import numpy as np
2121
import pytest
2222
from sagemaker_containers.record_pb2 import Record
@@ -284,3 +284,14 @@ def test_encode_predictions_as_json_empty_list():
284284
def test_encode_predictions_as_json_non_empty_list():
285285
expected_response = json.dumps({"predictions": [{"score": 0.43861907720565796}, {"score": 0.4533972144126892}]})
286286
assert expected_response == serve_utils.encode_predictions_as_json([0.43861907720565796, 0.4533972144126892])
287+
288+
289+
@patch.object(os, 'listdir')
290+
@patch.object(os.path, 'isfile')
291+
def test_get_full_model_paths(test_isfile, test_listdir):
292+
test_isfile.return_value = True
293+
mock_directory_contents = ["xgboost-model", ".DS_STORE", ".xgboost-model", "model2"]
294+
test_listdir.return_value = mock_directory_contents
295+
model_dir = "path/to/models"
296+
model_paths = serve_utils._get_full_model_paths(model_dir)
297+
assert [f"{model_dir}/xgboost-model", f"{model_dir}/model2"] == list(model_paths)

0 commit comments

Comments
 (0)