44import json
55from torchvision import models , transforms
66
7- CLIENT = client ()
8- SMID_ALGO = "algo://util/SmartImageDownloader/0.2.x"
9- LABEL_PATH = "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
10- MODEL_PATHS = {
11- "squeezenet" : 'data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth' ,
12- 'alexnet' : 'data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth' ,
13- }
14-
15-
16- def load_labels ():
17- local_path = CLIENT .file (LABEL_PATH ).getFile ().name
7+ def load_labels (label_path , client ):
8+ local_path = client .file (label_path ).getFile ().name
189 with open (local_path ) as f :
1910 labels = json .load (f )
2011 labels = [labels [str (k )][1 ] for k in range (len (labels ))]
2112 return labels
2213
2314
24- def load_model (name ):
15+ def load_model (name , model_paths , client ):
2516 if name == "squeezenet" :
2617 model = models .squeezenet1_1 ()
2718 models .densenet121 ()
28- weights = torch .load (CLIENT .file (MODEL_PATHS [ ' squeezenet' ]).getFile ().name )
19+ weights = torch .load (client .file (model_paths [ " squeezenet" ]).getFile ().name )
2920 else :
3021 model = models .alexnet ()
31- weights = torch .load (CLIENT .file (MODEL_PATHS [ ' alexnet' ]).getFile ().name )
22+ weights = torch .load (client .file (model_paths [ " alexnet" ]).getFile ().name )
3223 model .load_state_dict (weights )
3324 return model .float ().eval ()
3425
3526
36- def get_image (image_url ):
37- input = {"image" : image_url , "resize" : {' width' : 224 , ' height' : 224 }}
38- result = CLIENT .algo (SMID_ALGO ).pipe (input ).result ["savePath" ][0 ]
39- local_path = CLIENT .file (result ).getFile ().name
27+ def get_image (image_url , smid_algo , client ):
28+ input = {"image" : image_url , "resize" : {" width" : 224 , " height" : 224 }}
29+ result = client .algo (smid_algo ).pipe (input ).result ["savePath" ][0 ]
30+ local_path = client .file (result ).getFile ().name
4031 img_data = Image .open (local_path )
4132 return img_data
4233
4334
4435def infer_image (image_url , n , globals ):
45- model = globals [' model' ]
46- labels = globals [' labels' ]
47- image_data = get_image (image_url )
36+ model = globals [" model" ]
37+ labels = globals [" labels" ]
38+ image_data = get_image (image_url , globals [ "SMID_ALGO" ], globals [ "CLIENT" ] )
4839 transformed = transforms .Compose ([
4940 transforms .ToTensor (),
5041 transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ],
@@ -63,30 +54,39 @@ def infer_image(image_url, n, globals):
6354
6455
6556def load ():
66- globals = {'model' : load_model ("squeezenet" ), 'labels' : load_labels ()}
57+ globals = {}
58+ globals ["MODEL_PATHS" ] = {
59+ "squeezenet" : "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth" ,
60+ "alexnet" : "data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth" ,
61+ }
62+ globals ["LABEL_PATHS" ] = "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
63+ globals ["CLIENT" ] = client ()
64+ globals ["SMID_ALGO" ] = "algo://util/SmartImageDownloader/0.2.x"
65+ globals ["model" ] = load_model ("squeezenet" , globals ["MODEL_PATHS" ], globals ["CLIENT" ])
66+ globals ["labels" ] = load_labels (globals ["LABEL_PATHS" ], globals ["CLIENT" ])
6767 return globals
6868
6969
7070def apply (input , globals ):
7171 if isinstance (input , dict ):
7272 if "n" in input :
73- n = input ['n' ]
73+ n = input ["n" ]
7474 else :
7575 n = 3
7676 if "data" in input :
77- if isinstance (input [' data' ], str ):
78- output = infer_image (input [' data' ], n , globals )
79- elif isinstance (input [' data' ], list ):
80- for row in input [' data' ]:
81- row [' predictions' ] = infer_image (row [' image_url' ], n , globals )
82- output = input [' data' ]
77+ if isinstance (input [" data" ], str ):
78+ output = infer_image (input [" data" ], n , globals )
79+ elif isinstance (input [" data" ], list ):
80+ for row in input [" data" ]:
81+ row [" predictions" ] = infer_image (row [" image_url" ], n , globals )
82+ output = input [" data" ]
8383 else :
84- raise Exception ("' data' must be a image url or a list of image urls (with labels)" )
84+ raise Exception ("" data " must be a image url or a list of image urls (with labels)" )
8585 return output
8686 else :
87- raise Exception ("' data' must be defined" )
87+ raise Exception ("" data " must be defined" )
8888 else :
89- raise Exception (' input must be a json object' )
89+ raise Exception (" input must be a json object" )
9090
9191
9292algorithm = ADK (apply_func = apply , load_func = load )
0 commit comments