Skip to content

Commit 1b68ef8

Browse files
committed
numpy support, py3 support
1 parent 9389097 commit 1b68ef8

File tree

1 file changed

+82
-12
lines changed

1 file changed

+82
-12
lines changed

redisai/client.py

Lines changed: 82 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,16 @@
11
from redis import StrictRedis
2+
from ._util import to_string
3+
4+
try:
5+
import numpy as np
6+
except ImportError:
7+
np = None
8+
9+
try:
10+
from typing import Union, Any, AnyStr, ByteString, Collection
11+
except ImportError:
12+
pass
13+
214

315
DEVICE_CPU = 'cpu'
416
DEVICE_GPU = 'gpu'
@@ -22,7 +34,10 @@ class Tensor(object):
2234

2335
ARGNAME = 'VALUES'
2436

25-
def __init__(self, ttype, shape, value):
37+
def __init__(self,
38+
ttype, # type: AnyStr
39+
shape, # type: Collection[int]
40+
value):
2641
"""
2742
Declare a tensor suitable for passing to tensorset
2843
:param ttype: The type the values should be stored as.
@@ -57,6 +72,7 @@ def __repr__(self):
5772

5873
class ScalarTensor(Tensor):
5974
def __init__(self, dtype, *values):
75+
# type: (ScalarTensor, AnyStr, Any) -> None
6076
"""
6177
Declare a tensor with a bunch of scalar values. This can be used
6278
to 'batch-load' several tensors.
@@ -69,10 +85,13 @@ def __init__(self, dtype, *values):
6985

7086

7187
class BlobTensor(Tensor):
72-
7388
ARGNAME = 'BLOB'
7489

75-
def __init__(self, ttype, shape, *blobs):
90+
def __init__(self,
91+
ttype,
92+
shape, # type: Collection[int]
93+
*blobs # type: Union[BlobTensor, ByteString]
94+
):
7695
"""
7796
Create a tensor from a binary blob
7897
:param ttype: The datatype, one of Tensor.FLOAT, Tensor.DOUBLE, etc.
@@ -86,14 +105,56 @@ def __init__(self, ttype, shape, *blobs):
86105
b = b.value[0]
87106
blobarr += b
88107
size = len(blobs)
89-
blobs = blobarr
108+
blobs = bytes(blobarr)
90109
else:
91-
blobs = blobs[0]
110+
blobs = bytes(blobs[0])
92111
size = 1
93112

94113
super(BlobTensor, self).__init__(ttype, shape, blobs)
95114
self._size = size
96115

116+
@classmethod
117+
def from_numpy(cls, *nparrs):
118+
# type: (type, np.array) -> BlobTensor
119+
blobs = []
120+
for arr in nparrs:
121+
blobs.append(arr.data)
122+
return cls(
123+
BlobTensor._from_numpy_type(nparrs[0].dtype),
124+
nparrs[0].shape, *blobs)
125+
126+
@property
127+
def blob(self):
128+
return self.value[0]
129+
130+
def to_numpy(self):
131+
# type: () -> np.array
132+
a = np.frombuffer(self.value[0], dtype=self._to_numpy_type(self.type))
133+
return a.reshape(self.shape)
134+
135+
@staticmethod
136+
def _to_numpy_type(t):
137+
t = t.lower()
138+
mm = {
139+
'float': 'float32',
140+
'double': 'float64'
141+
}
142+
if t in mm:
143+
return mm[t]
144+
return t
145+
146+
@staticmethod
147+
def _from_numpy_type(t):
148+
t = str(t).lower()
149+
mm = {
150+
'float32': 'float',
151+
'float64': 'double',
152+
'float_': 'double'
153+
}
154+
if t in mm:
155+
return mm[t]
156+
return t
157+
97158

98159
class Client(StrictRedis):
99160
def modelset(self, name, backend, device, inputs, outputs, data):
@@ -117,36 +178,45 @@ def modelrun(self, name, inputs, outputs):
117178
return self.execute_command(*args)
118179

119180
def tensorset(self, key, tensor):
181+
# type: (Client, AnyStr, Union[Tensor, np.ndarray]) -> Any
120182
"""
121183
Set the values of the tensor on the server using the provided Tensor object
122184
:param key: The name of the tensor
123185
:param tensor: a `Tensor` object
124-
:return:
125186
"""
187+
if np and isinstance(tensor, np.ndarray):
188+
tensor = BlobTensor.from_numpy(tensor)
126189
args = ['AI.TENSORSET', key, tensor.type, tensor.size]
127190
args += tensor.shape
128191
args += [tensor.ARGNAME]
129192
args += tensor.value
130-
print args
131193
return self.execute_command(*args)
132194

133195
def tensorget(self, key, astype=Tensor, meta_only=False):
196+
"""
197+
Retrieve the value of a tensor from the server
198+
:param key: the name of the tensor
199+
:param astype: the resultant tensor type
200+
:param meta_only: if true, then the value is not retrieved,
201+
only the shape and the type
202+
:return: an instance of astype
203+
"""
134204
argname = 'META' if meta_only else astype.ARGNAME
135205
res = self.execute_command('AI.TENSORGET', key, argname)
206+
dtype, shape = to_string(res[0]), res[1]
136207
if meta_only:
137-
return astype(res[0], res[1], [])
208+
return astype(dtype, shape, [])
138209
else:
139-
dtype, shape, value = res
140-
return astype(dtype, shape, value)
210+
return astype(dtype, shape, res[2])
141211

142212
def scriptset(self, name, device, script):
143213
return self.execute_command('AI.SCRIPTSET', name, device, script)
144214

145215
def scriptget(self, name):
146216
r = self.execute_command('AI.SCRIPTGET', name)
147217
return {
148-
'device': r[0],
149-
'script': r[1]
218+
'device': to_string(r[0]),
219+
'script': to_string(r[1])
150220
}
151221

152222
def scriptrun(self, name, function, inputs, outputs):

0 commit comments

Comments
 (0)