Skip to content

Commit d449f01

Browse files
authored
moved all global functionality into load function dictionary
1 parent 128da0a commit d449f01

File tree

1 file changed

+32
-32
lines changed

1 file changed

+32
-32
lines changed

examples/pytorch_image_classification/src/Algorithm.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,47 +4,38 @@
44
import json
55
from torchvision import models, transforms
66

7-
CLIENT = client()
8-
SMID_ALGO = "algo://util/SmartImageDownloader/0.2.x"
9-
LABEL_PATH = "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
10-
MODEL_PATHS = {
11-
"squeezenet": 'data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth',
12-
'alexnet': 'data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth',
13-
}
14-
15-
16-
def load_labels():
17-
local_path = CLIENT.file(LABEL_PATH).getFile().name
7+
def load_labels(label_path, client):
8+
local_path = client.file(label_path).getFile().name
189
with open(local_path) as f:
1910
labels = json.load(f)
2011
labels = [labels[str(k)][1] for k in range(len(labels))]
2112
return labels
2213

2314

24-
def load_model(name):
15+
def load_model(name, model_paths, client):
2516
if name == "squeezenet":
2617
model = models.squeezenet1_1()
2718
models.densenet121()
28-
weights = torch.load(CLIENT.file(MODEL_PATHS['squeezenet']).getFile().name)
19+
weights = torch.load(client.file(model_paths["squeezenet"]).getFile().name)
2920
else:
3021
model = models.alexnet()
31-
weights = torch.load(CLIENT.file(MODEL_PATHS['alexnet']).getFile().name)
22+
weights = torch.load(client.file(model_paths["alexnet"]).getFile().name)
3223
model.load_state_dict(weights)
3324
return model.float().eval()
3425

3526

36-
def get_image(image_url):
37-
input = {"image": image_url, "resize": {'width': 224, 'height': 224}}
38-
result = CLIENT.algo(SMID_ALGO).pipe(input).result["savePath"][0]
39-
local_path = CLIENT.file(result).getFile().name
27+
def get_image(image_url, smid_algo, client):
28+
input = {"image": image_url, "resize": {"width": 224, "height": 224}}
29+
result = client.algo(smid_algo).pipe(input).result["savePath"][0]
30+
local_path = client.file(result).getFile().name
4031
img_data = Image.open(local_path)
4132
return img_data
4233

4334

4435
def infer_image(image_url, n, globals):
45-
model = globals['model']
46-
labels = globals['labels']
47-
image_data = get_image(image_url)
36+
model = globals["model"]
37+
labels = globals["labels"]
38+
image_data = get_image(image_url, globals["SMID_ALGO"], globals["CLIENT"])
4839
transformed = transforms.Compose([
4940
transforms.ToTensor(),
5041
transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -63,30 +54,39 @@ def infer_image(image_url, n, globals):
6354

6455

6556
def load():
66-
globals = {'model': load_model("squeezenet"), 'labels': load_labels()}
57+
globals = {}
58+
globals["MODEL_PATHS"] = {
59+
"squeezenet": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth",
60+
"alexnet": "data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth",
61+
}
62+
globals["LABEL_PATHS"] = "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
63+
globals["CLIENT"] = client()
64+
globals["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
65+
globals["model"] = load_model("squeezenet", globals["MODEL_PATHS"], globals["CLIENT"])
66+
globals["labels"] = load_labels(globals["LABEL_PATHS"], globals["CLIENT"])
6767
return globals
6868

6969

7070
def apply(input, globals):
7171
if isinstance(input, dict):
7272
if "n" in input:
73-
n = input['n']
73+
n = input["n"]
7474
else:
7575
n = 3
7676
if "data" in input:
77-
if isinstance(input['data'], str):
78-
output = infer_image(input['data'], n, globals)
79-
elif isinstance(input['data'], list):
80-
for row in input['data']:
81-
row['predictions'] = infer_image(row['image_url'], n, globals)
82-
output = input['data']
77+
if isinstance(input["data"], str):
78+
output = infer_image(input["data"], n, globals)
79+
elif isinstance(input["data"], list):
80+
for row in input["data"]:
81+
row["predictions"] = infer_image(row["image_url"], n, globals)
82+
output = input["data"]
8383
else:
84-
raise Exception("'data' must be a image url or a list of image urls (with labels)")
84+
raise Exception(""data" must be a image url or a list of image urls (with labels)")
8585
return output
8686
else:
87-
raise Exception("'data' must be defined")
87+
raise Exception(""data" must be defined")
8888
else:
89-
raise Exception('input must be a json object')
89+
raise Exception("input must be a json object")
9090

9191

9292
algorithm = ADK(apply_func=apply, load_func=load)

0 commit comments

Comments
 (0)