@@ -122,7 +122,8 @@ algorithm.init("Algorithmia")
122122## [ pytorch based image classification] ( examples/pytorch_image_classification )
123123<!-- embedme examples/pytorch_image_classification/src/Algorithm.py -->
124124``` python
125- from Algorithmia import client, ADK
125+ from Algorithmia import ADK
126+ import Algorithmia
126127import torch
127128from PIL import Image
128129import json
@@ -136,14 +137,10 @@ def load_labels(label_path, client):
136137 return labels
137138
138139
139- def load_model (name , model_paths , client ):
140- if name == " squeezenet" :
141- model = models.squeezenet1_1()
142- models.densenet121()
143- weights = torch.load(client.file(model_paths[" squeezenet" ]).getFile().name)
144- else :
145- model = models.alexnet()
146- weights = torch.load(client.file(model_paths[" alexnet" ]).getFile().name)
140+ def load_model (model_paths , client ):
141+ model = models.squeezenet1_1()
142+ local_file = client.file(model_paths[" filepath" ]).getFile().name
143+ weights = torch.load(local_file)
147144 model.load_state_dict(weights)
148145 return model.float().eval()
149146
@@ -177,17 +174,13 @@ def infer_image(image_url, n, globals):
177174 return result
178175
179176
180- def load ():
177+ def load (manifest ):
178+
181179 globals = {}
182- globals [" MODEL_PATHS" ] = {
183- " squeezenet" : " data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth" ,
184- " alexnet" : " data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth" ,
185- }
186- globals [" LABEL_PATHS" ] = " data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
187- globals [" CLIENT" ] = client()
180+ client = Algorithmia.client()
188181 globals [" SMID_ALGO" ] = " algo://util/SmartImageDownloader/0.2.x"
189- globals [" model" ] = load_model(" squeezenet" , globals [ " MODEL_PATHS " ], globals [ " CLIENT " ] )
190- globals [" labels" ] = load_labels(globals [ " LABEL_PATHS " ], globals [ " CLIENT " ] )
182+ globals [" model" ] = load_model(manifest[ " squeezenet" ], client )
183+ globals [" labels" ] = load_labels(manifest[ " label_file " ], client )
191184 return globals
192185
193186
@@ -205,10 +198,10 @@ def apply(input, globals):
205198 row[" predictions" ] = infer_image(row[" image_url" ], n, globals )
206199 output = input [" data" ]
207200 else :
208- raise Exception (" " data" must be a image url or a list of image urls (with labels)" )
201+ raise Exception (" \ " data\ " must be a image url or a list of image urls (with labels)" )
209202 return output
210203 else :
211- raise Exception (" " data" must be defined" )
204+ raise Exception (" \ " data\ " must be defined" )
212205 else :
213206 raise Exception (" input must be a json object" )
214207
0 commit comments