Skip to content

Commit 7b495cf

Browse files
committed
updated readme snippets with template
1 parent d449f01 commit 7b495cf

File tree

1 file changed

+33
-32
lines changed

1 file changed

+33
-32
lines changed

README.md

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def load():
104104
# The return object from this function can be passed directly as input to your apply function.
105105
# A great example would be any model files that need to be available to this algorithm
106106
# during runtime.
107+
107108
# Any variables returned here, will be passed as the secondary argument to your 'algorithm' function
108109
globals = {}
109110
globals['payload'] = "Loading has been completed."
@@ -127,47 +128,38 @@ from PIL import Image
127128
import json
128129
from torchvision import models, transforms
129130

130-
CLIENT = client()
131-
SMID_ALGO = "algo://util/SmartImageDownloader/0.2.x"
132-
LABEL_PATH = "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
133-
MODEL_PATHS = {
134-
"squeezenet": 'data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth',
135-
'alexnet': 'data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth',
136-
}
137-
138-
139-
def load_labels():
140-
local_path = CLIENT.file(LABEL_PATH).getFile().name
131+
def load_labels(label_path, client):
132+
local_path = client.file(label_path).getFile().name
141133
with open(local_path) as f:
142134
labels = json.load(f)
143135
labels = [labels[str(k)][1] for k in range(len(labels))]
144136
return labels
145137

146138

147-
def load_model(name):
139+
def load_model(name, model_paths, client):
148140
if name == "squeezenet":
149141
model = models.squeezenet1_1()
150142
models.densenet121()
151-
weights = torch.load(CLIENT.file(MODEL_PATHS['squeezenet']).getFile().name)
143+
weights = torch.load(client.file(model_paths["squeezenet"]).getFile().name)
152144
else:
153145
model = models.alexnet()
154-
weights = torch.load(CLIENT.file(MODEL_PATHS['alexnet']).getFile().name)
146+
weights = torch.load(client.file(model_paths["alexnet"]).getFile().name)
155147
model.load_state_dict(weights)
156148
return model.float().eval()
157149

158150

159-
def get_image(image_url):
160-
input = {"image": image_url, "resize": {'width': 224, 'height': 224}}
161-
result = CLIENT.algo(SMID_ALGO).pipe(input).result["savePath"][0]
162-
local_path = CLIENT.file(result).getFile().name
151+
def get_image(image_url, smid_algo, client):
152+
input = {"image": image_url, "resize": {"width": 224, "height": 224}}
153+
result = client.algo(smid_algo).pipe(input).result["savePath"][0]
154+
local_path = client.file(result).getFile().name
163155
img_data = Image.open(local_path)
164156
return img_data
165157

166158

167159
def infer_image(image_url, n, globals):
168-
model = globals['model']
169-
labels = globals['labels']
170-
image_data = get_image(image_url)
160+
model = globals["model"]
161+
labels = globals["labels"]
162+
image_data = get_image(image_url, globals["SMID_ALGO"], globals["CLIENT"])
171163
transformed = transforms.Compose([
172164
transforms.ToTensor(),
173165
transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -186,30 +178,39 @@ def infer_image(image_url, n, globals):
186178

187179

188180
def load():
189-
globals = {'model': load_model("squeezenet"), 'labels': load_labels()}
181+
globals = {}
182+
globals["MODEL_PATHS"] = {
183+
"squeezenet": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth",
184+
"alexnet": "data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth",
185+
}
186+
globals["LABEL_PATHS"] = "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
187+
globals["CLIENT"] = client()
188+
globals["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
189+
globals["model"] = load_model("squeezenet", globals["MODEL_PATHS"], globals["CLIENT"])
190+
globals["labels"] = load_labels(globals["LABEL_PATHS"], globals["CLIENT"])
190191
return globals
191192

192193

193194
def apply(input, globals):
194195
if isinstance(input, dict):
195196
if "n" in input:
196-
n = input['n']
197+
n = input["n"]
197198
else:
198199
n = 3
199200
if "data" in input:
200-
if isinstance(input['data'], str):
201-
output = infer_image(input['data'], n, globals)
202-
elif isinstance(input['data'], list):
203-
for row in input['data']:
204-
row['predictions'] = infer_image(row['image_url'], n, globals)
205-
output = input['data']
201+
if isinstance(input["data"], str):
202+
output = infer_image(input["data"], n, globals)
203+
elif isinstance(input["data"], list):
204+
for row in input["data"]:
205+
row["predictions"] = infer_image(row["image_url"], n, globals)
206+
output = input["data"]
206207
else:
207-
raise Exception("'data' must be a image url or a list of image urls (with labels)")
208+
raise Exception(""data" must be a image url or a list of image urls (with labels)")
208209
return output
209210
else:
210-
raise Exception("'data' must be defined")
211+
raise Exception(""data" must be defined")
211212
else:
212-
raise Exception('input must be a json object')
213+
raise Exception("input must be a json object")
213214

214215

215216
algorithm = ADK(apply_func=apply, load_func=load)

0 commit comments

Comments
 (0)