1- from Algorithmia import client , ADK
1+ from Algorithmia import ADK
2+ import Algorithmia
23import torch
34from PIL import Image
45import json
@@ -12,14 +13,10 @@ def load_labels(label_path, client):
1213 return labels
1314
1415
15- def load_model (name , model_paths , client ):
16- if name == "squeezenet" :
17- model = models .squeezenet1_1 ()
18- models .densenet121 ()
19- weights = torch .load (client .file (model_paths ["squeezenet" ]).getFile ().name )
20- else :
21- model = models .alexnet ()
22- weights = torch .load (client .file (model_paths ["alexnet" ]).getFile ().name )
16+ def load_model (model_paths , client ):
17+ model = models .squeezenet1_1 ()
18+ local_file = client .file (model_paths ["filepath" ]).getFile ().name
19+ weights = torch .load (local_file )
2320 model .load_state_dict (weights )
2421 return model .float ().eval ()
2522
@@ -53,17 +50,13 @@ def infer_image(image_url, n, globals):
5350 return result
5451
5552
56- def load ():
53+ def load (manifest ):
54+
5755 globals = {}
58- globals ["MODEL_PATHS" ] = {
59- "squeezenet" : "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth" ,
60- "alexnet" : "data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth" ,
61- }
62- globals ["LABEL_PATHS" ] = "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
63- globals ["CLIENT" ] = client ()
56+ client = Algorithmia .client ()
6457 globals ["SMID_ALGO" ] = "algo://util/SmartImageDownloader/0.2.x"
65- globals ["model" ] = load_model ("squeezenet" , globals [ "MODEL_PATHS" ], globals [ "CLIENT" ] )
66- globals ["labels" ] = load_labels (globals [ "LABEL_PATHS " ], globals [ "CLIENT" ] )
58+ globals ["model" ] = load_model (manifest [ "squeezenet" ], client )
59+ globals ["labels" ] = load_labels (manifest [ "label_file " ], client )
6760 return globals
6861
6962
@@ -81,10 +74,10 @@ def apply(input, globals):
8174 row ["predictions" ] = infer_image (row ["image_url" ], n , globals )
8275 output = input ["data" ]
8376 else :
84- raise Exception ("" data " must be a image url or a list of image urls (with labels)" )
77+ raise Exception ("\ " data\ " must be a image url or a list of image urls (with labels)" )
8578 return output
8679 else :
87- raise Exception ("" data " must be defined" )
80+ raise Exception ("\ " data\ " must be defined" )
8881 else :
8982 raise Exception ("input must be a json object" )
9083
0 commit comments