1- import base64
21import inspect
32import json
43import os
54import sys
6- import traceback
7- import six
5+ import Algorithmia
6+ from adk .io import create_exception , format_data , format_response
7+ from adk .manifest .modeldata import ModelData
88
99
1010class ADK (object ):
11- def __init__ (self , apply_func , load_func = None ):
11+ def __init__ (self , apply_func , load_func = None , client = None ):
1212 """
1313 Creates the adk object
1414 :param apply_func: A required function that can have an arity of 1-2, depending on if loading occurs
15- :param load_func: An optional supplier function used if load time events are required, has an arity of 0.
15+ :param load_func: An optional supplier function used if load time events are required, if a model manifest is provided;
16+ the function may have a single `manifest` parameter to interact with the model manifest, otherwise must have no parameters.
17+ :param client: A Algorithmia Client instance that might be user defined,
18+ and is used for interacting with a model manifest file; if defined.
1619 """
1720 self .FIFO_PATH = "/tmp/algoout"
21+
22+ if client :
23+ self .client = client
24+ else :
25+ self .client = Algorithmia .client ()
26+
1827 apply_args , _ , _ , _ , _ , _ , _ = inspect .getfullargspec (apply_func )
28+ self .apply_arity = len (apply_args )
1929 if load_func :
2030 load_args , _ , _ , _ , _ , _ , _ = inspect .getfullargspec (load_func )
21- if len (load_args ) > 0 :
22- raise Exception ("load function must not have parameters" )
31+ self .load_arity = len (load_args )
32+ if self .load_arity != 1 :
33+ raise Exception ("load function expects 1 parameter to be used to store algorithm state" )
2334 self .load_func = load_func
2435 else :
2536 self .load_func = None
26- if len (apply_args ) > 2 or len (apply_args ) == 0 :
27- raise Exception ("apply function may have between 1 and 2 parameters, not {}" .format (len (apply_args )))
2837 self .apply_func = apply_func
2938 self .is_local = not os .path .exists (self .FIFO_PATH )
3039 self .load_result = None
3140 self .loading_exception = None
41+ self .manifest_path = "model_manifest.json.freeze"
42+ self .model_data = self .init_manifest (self .manifest_path )
43+
44+ def init_manifest (self , path ):
45+ return ModelData (self .client , path )
3246
3347 def load (self ):
3448 try :
49+ if self .model_data .available ():
50+ self .model_data .initialize ()
3551 if self .load_func :
36- self .load_result = self .load_func ()
52+ self .load_result = self .load_func (self . model_data )
3753 except Exception as e :
3854 self .loading_exception = e
3955 finally :
@@ -45,55 +61,16 @@ def load(self):
4561
4662 def apply (self , payload ):
4763 try :
48- if self .load_result :
64+ if self .load_result and self . apply_arity == 2 :
4965 apply_result = self .apply_func (payload , self .load_result )
5066 else :
5167 apply_result = self .apply_func (payload )
52- response_obj = self . format_response (apply_result )
68+ response_obj = format_response (apply_result )
5369 return response_obj
5470 except Exception as e :
55- response_obj = self . create_exception (e )
71+ response_obj = create_exception (e )
5672 return response_obj
5773
58- def format_data (self , request ):
59- if request ["content_type" ] in ["text" , "json" ]:
60- data = request ["data" ]
61- elif request ["content_type" ] == "binary" :
62- data = self .wrap_binary_data (base64 .b64decode (request ["data" ]))
63- else :
64- raise Exception ("Invalid content_type: {}" .format (request ["content_type" ]))
65- return data
66-
67- def is_binary (self , arg ):
68- if six .PY3 :
69- return isinstance (arg , base64 .bytes_types )
70-
71- return isinstance (arg , bytearray )
72-
73- def wrap_binary_data (self , data ):
74- if six .PY3 :
75- return bytes (data )
76- else :
77- return bytearray (data )
78-
79- def format_response (self , response ):
80- if self .is_binary (response ):
81- content_type = "binary"
82- response = str (base64 .b64encode (response ), "utf-8" )
83- elif isinstance (response , six .string_types ) or isinstance (response , six .text_type ):
84- content_type = "text"
85- else :
86- content_type = "json"
87- response_string = json .dumps (
88- {
89- "result" : response ,
90- "metadata" : {
91- "content_type" : content_type
92- }
93- }
94- )
95- return response_string
96-
9774 def write_to_pipe (self , payload , pprint = print ):
9875 if self .is_local :
9976 if isinstance (payload , dict ):
@@ -109,40 +86,24 @@ def write_to_pipe(self, payload, pprint=print):
10986 if os .name == "nt" :
11087 sys .stdin = payload
11188
112- def create_exception (self , exception , loading_exception = False ):
113- if hasattr (exception , "error_type" ):
114- error_type = exception .error_type
115- elif loading_exception :
116- error_type = "LoadingError"
117- else :
118- error_type = "AlgorithmError"
119- response = json .dumps ({
120- "error" : {
121- "message" : str (exception ),
122- "stacktrace" : traceback .format_exc (),
123- "error_type" : error_type ,
124- }
125- })
126- return response
127-
12889 def process_local (self , local_payload , pprint ):
12990 result = self .apply (local_payload )
13091 self .write_to_pipe (result , pprint = pprint )
13192
13293 def init (self , local_payload = None , pprint = print ):
133- self .load ()
134- if self .is_local and local_payload :
94+ self .load ()
95+ if self .is_local and local_payload :
96+ if self .loading_exception :
97+ load_error = create_exception (self .loading_exception , loading_exception = True )
98+ self .write_to_pipe (load_error , pprint = pprint )
99+ self .process_local (local_payload , pprint )
100+ else :
101+ for line in sys .stdin :
102+ request = json .loads (line )
103+ formatted_input = format_data (request )
135104 if self .loading_exception :
136- load_error = self . create_exception (self .loading_exception , loading_exception = True )
105+ load_error = create_exception (self .loading_exception , loading_exception = True )
137106 self .write_to_pipe (load_error , pprint = pprint )
138- self .process_local (local_payload , pprint )
139- else :
140- for line in sys .stdin :
141- request = json .loads (line )
142- formatted_input = self .format_data (request )
143- if self .loading_exception :
144- load_error = self .create_exception (self .loading_exception , loading_exception = True )
145- self .write_to_pipe (load_error , pprint = pprint )
146- else :
147- result = self .apply (formatted_input )
148- self .write_to_pipe (result )
107+ else :
108+ result = self .apply (formatted_input )
109+ self .write_to_pipe (result )
0 commit comments