11from redis import StrictRedis
2+ from ._util import to_string
3+
4+ try :
5+ import numpy as np
6+ except ImportError :
7+ np = None
8+
9+ try :
10+ from typing import Union , Any , AnyStr , ByteString , Collection
11+ except ImportError :
12+ pass
13+
214
315DEVICE_CPU = 'cpu'
416DEVICE_GPU = 'gpu'
@@ -22,7 +34,10 @@ class Tensor(object):
2234
2335 ARGNAME = 'VALUES'
2436
25- def __init__ (self , ttype , shape , value ):
37+ def __init__ (self ,
38+ ttype , # type: AnyStr
39+ shape , # type: Collection[int]
40+ value ):
2641 """
2742 Declare a tensor suitable for passing to tensorset
2843 :param ttype: The type the values should be stored as.
@@ -57,6 +72,7 @@ def __repr__(self):
5772
5873class ScalarTensor (Tensor ):
5974 def __init__ (self , dtype , * values ):
75+ # type: (ScalarTensor, AnyStr, Any) -> None
6076 """
6177 Declare a tensor with a bunch of scalar values. This can be used
6278 to 'batch-load' several tensors.
@@ -69,10 +85,13 @@ def __init__(self, dtype, *values):
6985
7086
7187class BlobTensor (Tensor ):
72-
7388 ARGNAME = 'BLOB'
7489
75- def __init__ (self , ttype , shape , * blobs ):
90+ def __init__ (self ,
91+ ttype ,
92+ shape , # type: Collection[int]
93+ * blobs # type: Union[BlobTensor, ByteString]
94+ ):
7695 """
7796 Create a tensor from a binary blob
7897 :param ttype: The datatype, one of Tensor.FLOAT, Tensor.DOUBLE, etc.
@@ -86,14 +105,56 @@ def __init__(self, ttype, shape, *blobs):
86105 b = b .value [0 ]
87106 blobarr += b
88107 size = len (blobs )
89- blobs = blobarr
108+ blobs = bytes ( blobarr )
90109 else :
91- blobs = blobs [0 ]
110+ blobs = bytes ( blobs [0 ])
92111 size = 1
93112
94113 super (BlobTensor , self ).__init__ (ttype , shape , blobs )
95114 self ._size = size
96115
116+ @classmethod
117+ def from_numpy (cls , * nparrs ):
118+ # type: (type, np.array) -> BlobTensor
119+ blobs = []
120+ for arr in nparrs :
121+ blobs .append (arr .data )
122+ return cls (
123+ BlobTensor ._from_numpy_type (nparrs [0 ].dtype ),
124+ nparrs [0 ].shape , * blobs )
125+
126+ @property
127+ def blob (self ):
128+ return self .value [0 ]
129+
130+ def to_numpy (self ):
131+ # type: () -> np.array
132+ a = np .frombuffer (self .value [0 ], dtype = self ._to_numpy_type (self .type ))
133+ return a .reshape (self .shape )
134+
135+ @staticmethod
136+ def _to_numpy_type (t ):
137+ t = t .lower ()
138+ mm = {
139+ 'float' : 'float32' ,
140+ 'double' : 'float64'
141+ }
142+ if t in mm :
143+ return mm [t ]
144+ return t
145+
146+ @staticmethod
147+ def _from_numpy_type (t ):
148+ t = str (t ).lower ()
149+ mm = {
150+ 'float32' : 'float' ,
151+ 'float64' : 'double' ,
152+ 'float_' : 'double'
153+ }
154+ if t in mm :
155+ return mm [t ]
156+ return t
157+
97158
98159class Client (StrictRedis ):
99160 def modelset (self , name , backend , device , inputs , outputs , data ):
@@ -117,36 +178,45 @@ def modelrun(self, name, inputs, outputs):
117178 return self .execute_command (* args )
118179
119180 def tensorset (self , key , tensor ):
181+ # type: (Client, AnyStr, Union[Tensor, np.ndarray]) -> Any
120182 """
121183 Set the values of the tensor on the server using the provided Tensor object
122184 :param key: The name of the tensor
123185 :param tensor: a `Tensor` object
124- :return:
125186 """
187+ if np and isinstance (tensor , np .ndarray ):
188+ tensor = BlobTensor .from_numpy (tensor )
126189 args = ['AI.TENSORSET' , key , tensor .type , tensor .size ]
127190 args += tensor .shape
128191 args += [tensor .ARGNAME ]
129192 args += tensor .value
130- print args
131193 return self .execute_command (* args )
132194
133195 def tensorget (self , key , astype = Tensor , meta_only = False ):
196+ """
197+ Retrieve the value of a tensor from the server
198+ :param key: the name of the tensor
199+ :param astype: the resultant tensor type
200+ :param meta_only: if true, then the value is not retrieved,
201+ only the shape and the type
202+ :return: an instance of astype
203+ """
134204 argname = 'META' if meta_only else astype .ARGNAME
135205 res = self .execute_command ('AI.TENSORGET' , key , argname )
206+ dtype , shape = to_string (res [0 ]), res [1 ]
136207 if meta_only :
137- return astype (res [ 0 ], res [ 1 ] , [])
208+ return astype (dtype , shape , [])
138209 else :
139- dtype , shape , value = res
140- return astype (dtype , shape , value )
210+ return astype (dtype , shape , res [2 ])
141211
142212 def scriptset (self , name , device , script ):
143213 return self .execute_command ('AI.SCRIPTSET' , name , device , script )
144214
145215 def scriptget (self , name ):
146216 r = self .execute_command ('AI.SCRIPTGET' , name )
147217 return {
148- 'device' : r [0 ],
149- 'script' : r [1 ]
218+ 'device' : to_string ( r [0 ]) ,
219+ 'script' : to_string ( r [1 ])
150220 }
151221
152222 def scriptrun (self , name , function , inputs , outputs ):
0 commit comments