1+ from enum import Enum
12from redis import StrictRedis
23from ._util import to_string
34
1213 pass
1314
1415
15- DEVICE_CPU = 'cpu'
16- DEVICE_GPU = 'gpu'
16+ class Device (Enum ):
17+ cpu = 'cpu'
18+ gpu = 'gpu'
1719
18- BACKEND_TF = 'tf'
19- BACKEND_TORCH = 'torch'
20- BACKEND_ONNX = 'ort'
2120
21+ class Backend (Enum ):
22+ tf = 'tf'
23+ torch = 'torch'
24+ onnx = 'ort'
25+
26+
27+ class DType (Enum ):
28+ float = 'float'
29+ double = 'double'
30+ int8 = 'int8'
31+ int16 = 'int16'
32+ int32 = 'int32'
33+ int64 = 'int64'
34+ uint8 = 'uint8'
35+ uint16 = 'uint16'
36+ uint32 = 'uint32'
37+ uint64 = 'uint64'
38+
39+ # aliases
40+ float32 = 'float'
41+ float64 = 'double'
2242
23- class Tensor (object ):
24- FLOAT = 'float'
25- DOUBLE = 'double'
26- INT8 = 'int8'
27- INT16 = 'int16'
28- INT32 = 'int32'
29- INT64 = 'int64'
30- UINT8 = 'uint8'
31- UINT16 = 'uint16'
32- UINT32 = 'uint32'
33- UINT64 = 'uint64'
3443
44+ class Tensor (object ):
3545 ARGNAME = 'VALUES'
3646
3747 def __init__ (self ,
38- ttype , # type: AnyStr
48+ dtype , # type: DType
3949 shape , # type: Collection[int]
4050 value ):
4151 """
4252 Declare a tensor suitable for passing to tensorset
43- :param ttype : The type the values should be stored as.
53+ :param dtype : The type the values should be stored as.
4454 This can be one of Tensor.FLOAT, tensor.DOUBLE, etc.
4555 :param shape: An array describing the shape of the tensor. For an
4656 image 250x250 with three channels, this would be [250, 250, 3]
@@ -51,7 +61,7 @@ def __init__(self,
5161 is correct. Your application must ensure that the ordering
5262 is always consistent.
5363 """
54- self .type = ttype
64+ self .type = dtype
5565 self .shape = shape
5666 self .value = value
5767 self ._size = 1
@@ -72,7 +82,7 @@ def __repr__(self):
7282
7383class ScalarTensor (Tensor ):
7484 def __init__ (self , dtype , * values ):
75- # type: (ScalarTensor, AnyStr , Any) -> None
85+ # type: (ScalarTensor, DType , Any) -> None
7686 """
7787 Declare a tensor with a bunch of scalar values. This can be used
7888 to 'batch-load' several tensors.
@@ -88,13 +98,13 @@ class BlobTensor(Tensor):
8898 ARGNAME = 'BLOB'
8999
90100 def __init__ (self ,
91- ttype ,
101+ dtype ,
92102 shape , # type: Collection[int]
93103 * blobs # type: Union[BlobTensor, ByteString]
94104 ):
95105 """
96106 Create a tensor from a binary blob
97- :param ttype : The datatype, one of Tensor.FLOAT, Tensor.DOUBLE, etc.
107+ :param dtype : The datatype, one of Tensor.FLOAT, Tensor.DOUBLE, etc.
98108 :param shape: An array
99109 :param blobs: One or more blobs to assign to the tensor.
100110 """
@@ -110,7 +120,7 @@ def __init__(self,
110120 blobs = bytes (blobs [0 ])
111121 size = 1
112122
113- super (BlobTensor , self ).__init__ (ttype , shape , blobs )
123+ super (BlobTensor , self ).__init__ (dtype , shape , blobs )
114124 self ._size = size
115125
116126 @classmethod
@@ -119,9 +129,8 @@ def from_numpy(cls, *nparrs):
119129 blobs = []
120130 for arr in nparrs :
121131 blobs .append (arr .data )
122- return cls (
123- BlobTensor ._from_numpy_type (nparrs [0 ].dtype ),
124- nparrs [0 ].shape , * blobs )
132+ dt = DType .__members__ [str (nparrs [0 ].dtype )]
133+ return cls (dt , nparrs [0 ].shape , * blobs )
125134
126135 @property
127136 def blob (self ):
@@ -143,22 +152,17 @@ def _to_numpy_type(t):
143152 return mm [t ]
144153 return t
145154
146- @staticmethod
147- def _from_numpy_type (t ):
148- t = str (t ).lower ()
149- mm = {
150- 'float32' : 'float' ,
151- 'float64' : 'double' ,
152- 'float_' : 'double'
153- }
154- if t in mm :
155- return mm [t ]
156- return t
157-
158155
159156class Client (StrictRedis ):
160- def modelset (self , name , backend , device , inputs , outputs , data ):
161- args = ['AI.MODELSET' , name , backend , device , 'INPUTS' ]
157+ def modelset (self ,
158+ name , # type: AnyStr
159+ backend , # type: Backend
160+ device , # type: Device
161+ inputs , # type: Collection[AnyStr]
162+ outputs , # type: Collection[AnyStr]
163+ data # type: ByteString
164+ ):
165+ args = ['AI.MODELSET' , name , backend .value , device .value , 'INPUTS' ]
162166 args += inputs
163167 args += ['OUTPUTS' ] + outputs
164168 args += [data ]
@@ -167,8 +171,8 @@ def modelset(self, name, backend, device, inputs, outputs, data):
167171 def modelget (self , name ):
168172 rv = self .execute_command ('AI.MODELGET' , name )
169173 return {
170- 'backend' : rv [0 ],
171- 'device' : rv [1 ],
174+ 'backend' : Backend ( rv [0 ]) ,
175+ 'device' : Device ( rv [1 ]) ,
172176 'data' : rv [2 ]
173177 }
174178
@@ -186,7 +190,7 @@ def tensorset(self, key, tensor):
186190 """
187191 if np and isinstance (tensor , np .ndarray ):
188192 tensor = BlobTensor .from_numpy (tensor )
189- args = ['AI.TENSORSET' , key , tensor .type , tensor .size ]
193+ args = ['AI.TENSORSET' , key , tensor .type . value , tensor .size ]
190194 args += tensor .shape
191195 args += [tensor .ARGNAME ]
192196 args += tensor .value
@@ -210,7 +214,7 @@ def tensorget(self, key, astype=Tensor, meta_only=False):
210214 return astype (dtype , shape , res [2 ])
211215
212216 def scriptset (self , name , device , script ):
213- return self .execute_command ('AI.SCRIPTSET' , name , device , script )
217+ return self .execute_command ('AI.SCRIPTSET' , name , device . value , script )
214218
215219 def scriptget (self , name ):
216220 r = self .execute_command ('AI.SCRIPTGET' , name )
0 commit comments