@@ -104,6 +104,7 @@ def load():
104104 # The return object from this function can be passed directly as input to your apply function.
105105 # A great example would be any model files that need to be available to this algorithm
106106 # during runtime.
107+
107108 # Any variables returned here, will be passed as the secondary argument to your 'algorithm' function
108109 globals = {}
109110 globals [' payload' ] = " Loading has been completed."
@@ -121,53 +122,41 @@ algorithm.init("Algorithmia")
121122## [ pytorch based image classification] ( examples/pytorch_image_classification )
122123<!-- embedme examples/pytorch_image_classification/src/Algorithm.py -->
123124``` python
124- from Algorithmia import client, ADK
125+ from Algorithmia import ADK
126+ import Algorithmia
125127import torch
126128from PIL import Image
127129import json
128130from torchvision import models, transforms
129131
130- CLIENT = client()
131- SMID_ALGO = " algo://util/SmartImageDownloader/0.2.x"
132- LABEL_PATH = " data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json"
133- MODEL_PATHS = {
134- " squeezenet" : ' data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth' ,
135- ' alexnet' : ' data://AlgorithmiaSE/image_cassification_demo/alexnet-owt-4df8aa71.pth' ,
136- }
137-
138-
139- def load_labels ():
140- local_path = CLIENT .file(LABEL_PATH ).getFile().name
132+ def load_labels (label_path , client ):
133+ local_path = client.file(label_path).getFile().name
141134 with open (local_path) as f:
142135 labels = json.load(f)
143136 labels = [labels[str (k)][1 ] for k in range (len (labels))]
144137 return labels
145138
146139
147- def load_model (name ):
148- if name == " squeezenet" :
149- model = models.squeezenet1_1()
150- models.densenet121()
151- weights = torch.load(CLIENT .file(MODEL_PATHS [' squeezenet' ]).getFile().name)
152- else :
153- model = models.alexnet()
154- weights = torch.load(CLIENT .file(MODEL_PATHS [' alexnet' ]).getFile().name)
140+ def load_model (model_paths , client ):
141+ model = models.squeezenet1_1()
142+ local_file = client.file(model_paths[" filepath" ]).getFile().name
143+ weights = torch.load(local_file)
155144 model.load_state_dict(weights)
156145 return model.float().eval()
157146
158147
159- def get_image (image_url ):
160- input = {" image" : image_url, " resize" : {' width' : 224 , ' height' : 224 }}
161- result = CLIENT .algo(SMID_ALGO ).pipe(input ).result[" savePath" ][0 ]
162- local_path = CLIENT .file(result).getFile().name
148+ def get_image (image_url , smid_algo , client ):
149+ input = {" image" : image_url, " resize" : {" width" : 224 , " height" : 224 }}
150+ result = client .algo(smid_algo ).pipe(input ).result[" savePath" ][0 ]
151+ local_path = client .file(result).getFile().name
163152 img_data = Image.open(local_path)
164153 return img_data
165154
166155
167156def infer_image (image_url , n , globals ):
168- model = globals [' model' ]
169- labels = globals [' labels' ]
170- image_data = get_image(image_url)
157+ model = globals [" model" ]
158+ labels = globals [" labels" ]
159+ image_data = get_image(image_url, globals [ " SMID_ALGO " ], globals [ " CLIENT " ] )
171160 transformed = transforms.Compose([
172161 transforms.ToTensor(),
173162 transforms.Normalize(mean = [0.485 , 0.456 , 0.406 ],
@@ -185,31 +174,36 @@ def infer_image(image_url, n, globals):
185174 return result
186175
187176
188- def load ():
189- globals = {' model' : load_model(" squeezenet" ), ' labels' : load_labels()}
177+ def load (manifest ):
178+
179+ globals = {}
180+ client = Algorithmia.client()
181+ globals [" SMID_ALGO" ] = " algo://util/SmartImageDownloader/0.2.x"
182+ globals [" model" ] = load_model(manifest[" squeezenet" ], client)
183+ globals [" labels" ] = load_labels(manifest[" label_file" ], client)
190184 return globals
191185
192186
193187def apply (input , globals ):
194188 if isinstance (input , dict ):
195189 if " n" in input :
196- n = input [' n ' ]
190+ n = input [" n " ]
197191 else :
198192 n = 3
199193 if " data" in input :
200- if isinstance (input [' data' ], str ):
201- output = infer_image(input [' data' ], n, globals )
202- elif isinstance (input [' data' ], list ):
203- for row in input [' data' ]:
204- row[' predictions' ] = infer_image(row[' image_url' ], n, globals )
205- output = input [' data' ]
194+ if isinstance (input [" data" ], str ):
195+ output = infer_image(input [" data" ], n, globals )
196+ elif isinstance (input [" data" ], list ):
197+ for row in input [" data" ]:
198+ row[" predictions" ] = infer_image(row[" image_url" ], n, globals )
199+ output = input [" data" ]
206200 else :
207- raise Exception (" ' data' must be a image url or a list of image urls (with labels)" )
201+ raise Exception (" \" data\" must be a image url or a list of image urls (with labels)" )
208202 return output
209203 else :
210- raise Exception (" ' data' must be defined" )
204+ raise Exception (" \" data\" must be defined" )
211205 else :
212- raise Exception (' input must be a json object' )
206+ raise Exception (" input must be a json object" )
213207
214208
215209algorithm = ADK(apply_func = apply, load_func = load)
0 commit comments