Skip to content

Commit dc09c43

Browse files
committed
Change to use enum types
1 parent 1b68ef8 commit dc09c43

File tree

2 files changed

+49
-47
lines changed

2 files changed

+49
-47
lines changed

redisai/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
11
from .client import (
2-
Client, Tensor, ScalarTensor, BlobTensor,
3-
DEVICE_GPU, DEVICE_CPU,
4-
BACKEND_ONNX, BACKEND_TF, BACKEND_TORCH
2+
Client, Tensor, ScalarTensor, BlobTensor, DType, Device, Backend
53
)

redisai/client.py

Lines changed: 48 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from enum import Enum
12
from redis import StrictRedis
23
from ._util import to_string
34

@@ -12,35 +13,44 @@
1213
pass
1314

1415

15-
DEVICE_CPU = 'cpu'
16-
DEVICE_GPU = 'gpu'
16+
class Device(Enum):
17+
cpu = 'cpu'
18+
gpu = 'gpu'
1719

18-
BACKEND_TF = 'tf'
19-
BACKEND_TORCH = 'torch'
20-
BACKEND_ONNX = 'ort'
2120

21+
class Backend(Enum):
22+
tf = 'tf'
23+
torch = 'torch'
24+
onnx = 'ort'
25+
26+
27+
class DType(Enum):
28+
float = 'float'
29+
double = 'double'
30+
int8 = 'int8'
31+
int16 = 'int16'
32+
int32 = 'int32'
33+
int64 = 'int64'
34+
uint8 = 'uint8'
35+
uint16 = 'uint16'
36+
uint32 = 'uint32'
37+
uint64 = 'uint64'
38+
39+
# aliases
40+
float32 = 'float'
41+
float64 = 'double'
2242

23-
class Tensor(object):
24-
FLOAT = 'float'
25-
DOUBLE = 'double'
26-
INT8 = 'int8'
27-
INT16 = 'int16'
28-
INT32 = 'int32'
29-
INT64 = 'int64'
30-
UINT8 = 'uint8'
31-
UINT16 = 'uint16'
32-
UINT32 = 'uint32'
33-
UINT64 = 'uint64'
3443

44+
class Tensor(object):
3545
ARGNAME = 'VALUES'
3646

3747
def __init__(self,
38-
ttype, # type: AnyStr
48+
dtype, # type: DType
3949
shape, # type: Collection[int]
4050
value):
4151
"""
4252
Declare a tensor suitable for passing to tensorset
43-
:param ttype: The type the values should be stored as.
53+
:param dtype: The type the values should be stored as.
4454
This can be one of Tensor.FLOAT, tensor.DOUBLE, etc.
4555
:param shape: An array describing the shape of the tensor. For an
4656
image 250x250 with three channels, this would be [250, 250, 3]
@@ -51,7 +61,7 @@ def __init__(self,
5161
is correct. Your application must ensure that the ordering
5262
is always consistent.
5363
"""
54-
self.type = ttype
64+
self.type = dtype
5565
self.shape = shape
5666
self.value = value
5767
self._size = 1
@@ -72,7 +82,7 @@ def __repr__(self):
7282

7383
class ScalarTensor(Tensor):
7484
def __init__(self, dtype, *values):
75-
# type: (ScalarTensor, AnyStr, Any) -> None
85+
# type: (ScalarTensor, DType, Any) -> None
7686
"""
7787
Declare a tensor with a bunch of scalar values. This can be used
7888
to 'batch-load' several tensors.
@@ -88,13 +98,13 @@ class BlobTensor(Tensor):
8898
ARGNAME = 'BLOB'
8999

90100
def __init__(self,
91-
ttype,
101+
dtype,
92102
shape, # type: Collection[int]
93103
*blobs # type: Union[BlobTensor, ByteString]
94104
):
95105
"""
96106
Create a tensor from a binary blob
97-
:param ttype: The datatype, one of Tensor.FLOAT, Tensor.DOUBLE, etc.
107+
:param dtype: The datatype, one of Tensor.FLOAT, Tensor.DOUBLE, etc.
98108
:param shape: An array
99109
:param blobs: One or more blobs to assign to the tensor.
100110
"""
@@ -110,7 +120,7 @@ def __init__(self,
110120
blobs = bytes(blobs[0])
111121
size = 1
112122

113-
super(BlobTensor, self).__init__(ttype, shape, blobs)
123+
super(BlobTensor, self).__init__(dtype, shape, blobs)
114124
self._size = size
115125

116126
@classmethod
@@ -119,9 +129,8 @@ def from_numpy(cls, *nparrs):
119129
blobs = []
120130
for arr in nparrs:
121131
blobs.append(arr.data)
122-
return cls(
123-
BlobTensor._from_numpy_type(nparrs[0].dtype),
124-
nparrs[0].shape, *blobs)
132+
dt = DType.__members__[str(nparrs[0].dtype)]
133+
return cls(dt, nparrs[0].shape, *blobs)
125134

126135
@property
127136
def blob(self):
@@ -143,22 +152,17 @@ def _to_numpy_type(t):
143152
return mm[t]
144153
return t
145154

146-
@staticmethod
147-
def _from_numpy_type(t):
148-
t = str(t).lower()
149-
mm = {
150-
'float32': 'float',
151-
'float64': 'double',
152-
'float_': 'double'
153-
}
154-
if t in mm:
155-
return mm[t]
156-
return t
157-
158155

159156
class Client(StrictRedis):
160-
def modelset(self, name, backend, device, inputs, outputs, data):
161-
args = ['AI.MODELSET', name, backend, device, 'INPUTS']
157+
def modelset(self,
158+
name, # type: AnyStr
159+
backend, # type: Backend
160+
device, # type: Device
161+
inputs, # type: Collection[AnyStr]
162+
outputs, # type: Collection[AnyStr]
163+
data # type: ByteString
164+
):
165+
args = ['AI.MODELSET', name, backend.value, device.value, 'INPUTS']
162166
args += inputs
163167
args += ['OUTPUTS'] + outputs
164168
args += [data]
@@ -167,8 +171,8 @@ def modelset(self, name, backend, device, inputs, outputs, data):
167171
def modelget(self, name):
168172
rv = self.execute_command('AI.MODELGET', name)
169173
return {
170-
'backend': rv[0],
171-
'device': rv[1],
174+
'backend': Backend(rv[0]),
175+
'device': Device(rv[1]),
172176
'data': rv[2]
173177
}
174178

@@ -186,7 +190,7 @@ def tensorset(self, key, tensor):
186190
"""
187191
if np and isinstance(tensor, np.ndarray):
188192
tensor = BlobTensor.from_numpy(tensor)
189-
args = ['AI.TENSORSET', key, tensor.type, tensor.size]
193+
args = ['AI.TENSORSET', key, tensor.type.value, tensor.size]
190194
args += tensor.shape
191195
args += [tensor.ARGNAME]
192196
args += tensor.value
@@ -210,7 +214,7 @@ def tensorget(self, key, astype=Tensor, meta_only=False):
210214
return astype(dtype, shape, res[2])
211215

212216
def scriptset(self, name, device, script):
213-
return self.execute_command('AI.SCRIPTSET', name, device, script)
217+
return self.execute_command('AI.SCRIPTSET', name, device.value, script)
214218

215219
def scriptget(self, name):
216220
r = self.execute_command('AI.SCRIPTGET', name)

0 commit comments

Comments
 (0)