@@ -104,6 +104,7 @@ def load():
104104 # The return object from this function can be passed directly as input to your apply function.
105105 # A great example would be any model files that need to be available to this algorithm
106106 # during runtime.
107+
107108 # Any variables returned here, will be passed as the secondary argument to your 'algorithm' function
108109 globals = {}
109110 globals [' payload' ] = " Loading has been completed."
@@ -127,47 +128,38 @@ from PIL import Image
127128import json
128129from torchvision import models, transforms
129130
130- CLIENT = client()
131- SMID_ALGO = " algo://util/SmartImageDownloader/0.2.x"
132- LABEL_PATH = " data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
133- MODEL_PATHS = {
134- " squeezenet" : ' data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth' ,
135- ' alexnet' : ' data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth' ,
136- }
137-
138-
139- def load_labels ():
140- local_path = CLIENT .file(LABEL_PATH ).getFile().name
131+ def load_labels (label_path , client ):
132+ local_path = client.file(label_path).getFile().name
141133 with open (local_path) as f:
142134 labels = json.load(f)
143135 labels = [labels[str (k)][1 ] for k in range (len (labels))]
144136 return labels
145137
146138
147- def load_model (name ):
139+ def load_model (name , model_paths , client ):
148140 if name == " squeezenet" :
149141 model = models.squeezenet1_1()
150142 models.densenet121()
151- weights = torch.load(CLIENT .file(MODEL_PATHS [ ' squeezenet' ]).getFile().name)
143+ weights = torch.load(client .file(model_paths[ " squeezenet" ]).getFile().name)
152144 else :
153145 model = models.alexnet()
154- weights = torch.load(CLIENT .file(MODEL_PATHS [ ' alexnet' ]).getFile().name)
146+ weights = torch.load(client .file(model_paths[ " alexnet" ]).getFile().name)
155147 model.load_state_dict(weights)
156148 return model.float().eval()
157149
158150
159- def get_image (image_url ):
160- input = {" image" : image_url, " resize" : {' width' : 224 , ' height' : 224 }}
161- result = CLIENT .algo(SMID_ALGO ).pipe(input ).result[" savePath" ][0 ]
162- local_path = CLIENT .file(result).getFile().name
151+ def get_image (image_url , smid_algo , client ):
152+ input = {" image" : image_url, " resize" : {" width" : 224 , " height" : 224 }}
153+ result = client .algo(smid_algo ).pipe(input ).result[" savePath" ][0 ]
154+ local_path = client .file(result).getFile().name
163155 img_data = Image.open(local_path)
164156 return img_data
165157
166158
167159def infer_image (image_url , n , globals ):
168- model = globals [' model' ]
169- labels = globals [' labels' ]
170- image_data = get_image(image_url)
160+ model = globals [" model" ]
161+ labels = globals [" labels" ]
162+ image_data = get_image(image_url, globals [ " SMID_ALGO " ], globals [ " CLIENT " ] )
171163 transformed = transforms.Compose([
172164 transforms.ToTensor(),
173165 transforms.Normalize(mean = [0.485 , 0.456 , 0.406 ],
@@ -186,30 +178,39 @@ def infer_image(image_url, n, globals):
186178
187179
188180def load ():
189- globals = {' model' : load_model(" squeezenet" ), ' labels' : load_labels()}
181+ globals = {}
182+ globals [" MODEL_PATHS" ] = {
183+ " squeezenet" : " data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth" ,
184+ " alexnet" : " data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth" ,
185+ }
186+ globals [" LABEL_PATHS" ] = " data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
187+ globals [" CLIENT" ] = client()
188+ globals [" SMID_ALGO" ] = " algo://util/SmartImageDownloader/0.2.x"
189+ globals [" model" ] = load_model(" squeezenet" , globals [" MODEL_PATHS" ], globals [" CLIENT" ])
190+ globals [" labels" ] = load_labels(globals [" LABEL_PATHS" ], globals [" CLIENT" ])
190191 return globals
191192
192193
193194def apply (input , globals ):
194195 if isinstance (input , dict ):
195196 if " n" in input :
196- n = input [' n ' ]
197+ n = input [" n " ]
197198 else :
198199 n = 3
199200 if " data" in input :
200- if isinstance (input [' data' ], str ):
201- output = infer_image(input [' data' ], n, globals )
202- elif isinstance (input [' data' ], list ):
203- for row in input [' data' ]:
204- row[' predictions' ] = infer_image(row[' image_url' ], n, globals )
205- output = input [' data' ]
201+ if isinstance (input [" data" ], str ):
202+ output = infer_image(input [" data" ], n, globals )
203+ elif isinstance (input [" data" ], list ):
204+ for row in input [" data" ]:
205+ row[" predictions" ] = infer_image(row[" image_url" ], n, globals )
206+ output = input [" data" ]
206207 else :
207- raise Exception (" ' data' must be a image url or a list of image urls (with labels)" )
208+ raise Exception (" " data" must be a image url or a list of image urls (with labels)" )
208209 return output
209210 else :
210- raise Exception (" ' data' must be defined" )
211+ raise Exception (" " data" must be defined" )
211212 else :
212- raise Exception (' input must be a json object' )
213+ raise Exception (" input must be a json object" )
213214
214215
215216algorithm = ADK(apply_func = apply, load_func = load)
0 commit comments