Skip to content

Commit 2de51dc

Browse files
committed
added a model_manifest example to the advanced algorithm
1 parent 7b495cf commit 2de51dc

File tree

2 files changed

+27
-20
lines changed

2 files changed

+27
-20
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"label_file": {
3+
"filepath": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json",
4+
"md5_hash": "c2c37ea517e94d9795004a39431a14cb",
5+
"origin_ref": "this file came from imagenet.org",
6+
"uploaded_utc": "2021-05-03-11:05"
7+
},
8+
"squeezenet": {
9+
"filepath": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth",
10+
"md5_hash": "46a44d32d2c5c07f7f66324bef4c7266",
11+
"origin_ref": "From https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth",
12+
"uploaded_utc": "2021-05-03-11:05"
13+
}
14+
}

examples/pytorch_image_classification/src/Algorithm.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from Algorithmia import client, ADK
1+
from Algorithmia import ADK
2+
import Algorithmia
23
import torch
34
from PIL import Image
45
import json
@@ -12,14 +13,10 @@ def load_labels(label_path, client):
1213
return labels
1314

1415

15-
def load_model(name, model_paths, client):
16-
if name == "squeezenet":
17-
model = models.squeezenet1_1()
18-
models.densenet121()
19-
weights = torch.load(client.file(model_paths["squeezenet"]).getFile().name)
20-
else:
21-
model = models.alexnet()
22-
weights = torch.load(client.file(model_paths["alexnet"]).getFile().name)
16+
def load_model(model_paths, client):
17+
model = models.squeezenet1_1()
18+
local_file = client.file(model_paths["filepath"]).getFile().name
19+
weights = torch.load(local_file)
2320
model.load_state_dict(weights)
2421
return model.float().eval()
2522

@@ -53,17 +50,13 @@ def infer_image(image_url, n, globals):
5350
return result
5451

5552

56-
def load():
53+
def load(manifest):
54+
5755
globals = {}
58-
globals["MODEL_PATHS"] = {
59-
"squeezenet": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth",
60-
"alexnet": "data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth",
61-
}
62-
globals["LABEL_PATHS"] = "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
63-
globals["CLIENT"] = client()
56+
client = Algorithmia.client()
6457
globals["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
65-
globals["model"] = load_model("squeezenet", globals["MODEL_PATHS"], globals["CLIENT"])
66-
globals["labels"] = load_labels(globals["LABEL_PATHS"], globals["CLIENT"])
58+
globals["model"] = load_model(manifest["squeezenet"], client)
59+
globals["labels"] = load_labels(manifest["label_file"], client)
6760
return globals
6861

6962

@@ -81,10 +74,10 @@ def apply(input, globals):
8174
row["predictions"] = infer_image(row["image_url"], n, globals)
8275
output = input["data"]
8376
else:
84-
raise Exception(""data" must be a image url or a list of image urls (with labels)")
77+
raise Exception("\"data\" must be a image url or a list of image urls (with labels)")
8578
return output
8679
else:
87-
raise Exception(""data" must be defined")
80+
raise Exception("\"data\" must be defined")
8881
else:
8982
raise Exception("input must be a json object")
9083

0 commit comments

Comments
 (0)