11from enum import Enum
22from redis import StrictRedis
33from ._util import to_string
4+ import six
45
56try :
67 import numpy as np
78except ImportError :
89 np = None
910
1011try :
11- from typing import Union , Any , AnyStr , ByteString , Collection
12+ from typing import Union , Any , AnyStr , ByteString , Collection , Type
1213except ImportError :
1314 pass
1415
@@ -41,6 +42,12 @@ class DType(Enum):
4142 float64 = 'double'
4243
4344
45+ def _str_or_strlist (v ):
46+ if isinstance (v , six .string_types ):
47+ return [v ]
48+ return v
49+
50+
4451class Tensor (object ):
4552 ARGNAME = 'VALUES'
4653
@@ -158,13 +165,13 @@ def modelset(self,
158165 name , # type: AnyStr
159166 backend , # type: Backend
160167 device , # type: Device
161- inputs , # type: Collection[AnyStr]
162- outputs , # type: Collection[AnyStr]
168+ input , # type: Union[AnyStr| Collection[AnyStr] ]
169+ output , # type: Union[AnyStr| Collection[AnyStr] ]
163170 data # type: ByteString
164171 ):
165172 args = ['AI.MODELSET' , name , backend .value , device .value , 'INPUTS' ]
166- args += inputs
167- args += ['OUTPUTS' ] + outputs
173+ args += _str_or_strlist ( input )
174+ args += ['OUTPUTS' ] + _str_or_strlist ( output )
168175 args += [data ]
169176 return self .execute_command (* args )
170177
@@ -176,9 +183,14 @@ def modelget(self, name):
176183 'data' : rv [2 ]
177184 }
178185
179- def modelrun (self , name , inputs , outputs ):
186+ def modelrun (self ,
187+ name ,
188+ input , # type: Union[AnyStr|Collection[AnyStr]]
189+ output # type: Union[AnyStr|Collection[AnyStr]]
190+ ):
180191 args = ['AI.MODELRUN' , name ]
181- args += ['INPUTS' ] + inputs + ['OUTPUTS' ] + outputs
192+ args += ['INPUTS' ] + _str_or_strlist (input )
193+ args += ['OUTPUTS' ] + _str_or_strlist (output )
182194 return self .execute_command (* args )
183195
184196 def tensorset (self , key , tensor ):
@@ -196,22 +208,23 @@ def tensorset(self, key, tensor):
196208 args += tensor .value
197209 return self .execute_command (* args )
198210
199- def tensorget (self , key , astype = Tensor , meta_only = False ):
211+ def tensorget (self , key , as_type = Tensor , meta_only = False ):
212+ # type: (AnyStr, Type[Tensor], bool) -> Tensor
200213 """
201214 Retrieve the value of a tensor from the server
202215 :param key: the name of the tensor
203- :param astype : the resultant tensor type
216+ :param as_type : the resultant tensor type
204217 :param meta_only: if true, then the value is not retrieved,
205218 only the shape and the type
206- :return: an instance of astype
219+ :return: an instance of as_type
207220 """
208- argname = 'META' if meta_only else astype .ARGNAME
221+ argname = 'META' if meta_only else as_type .ARGNAME
209222 res = self .execute_command ('AI.TENSORGET' , key , argname )
210223 dtype , shape = to_string (res [0 ]), res [1 ]
211224 if meta_only :
212- return astype (dtype , shape , [])
225+ return as_type (dtype , shape , [])
213226 else :
214- return astype (dtype , shape , res [2 ])
227+ return as_type (dtype , shape , res [2 ])
215228
216229 def scriptset (self , name , device , script ):
217230 return self .execute_command ('AI.SCRIPTSET' , name , device .value , script )
@@ -223,9 +236,14 @@ def scriptget(self, name):
223236 'script' : to_string (r [1 ])
224237 }
225238
226- def scriptrun (self , name , function , inputs , outputs ):
239+ def scriptrun (self ,
240+ name ,
241+ function , # type: AnyStr
242+ input , # type: Union[AnyStr|Collection[AnyStr]]
243+ output # type: Union[AnyStr|Collection[AnyStr]]
244+ ):
227245 args = ['AI.SCRIPTRUN' , name , function , 'INPUTS' ]
228- args += inputs
246+ args += _str_or_strlist ( input )
229247 args += ['OUTPUTS' ]
230- args += outputs
248+ args += _str_or_strlist ( output )
231249 return self .execute_command (* args )
0 commit comments