Skip to content

Commit 9389097

Browse files
committed
Add nicer client API
1 parent be64f6c commit 9389097

File tree

3 files changed

+152
-32
lines changed

3 files changed

+152
-32
lines changed

redisai/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
from .client import Client, Type
2-
3-
1+
from .client import (
2+
Client, Tensor, ScalarTensor, BlobTensor,
3+
DEVICE_GPU, DEVICE_CPU,
4+
BACKEND_ONNX, BACKEND_TF, BACKEND_TORCH
5+
)

redisai/_util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import six
22

3+
34
def to_string(s):
45
if isinstance(s, six.string_types):
56
return s

redisai/client.py

Lines changed: 146 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,157 @@
11
from redis import StrictRedis
2-
from redis._compat import (long, nativestr)
3-
from enum import Enum
4-
import six
5-
6-
class Type(Enum):
7-
FLOAT=1
8-
DOUBLE=2
9-
INT8=3
10-
INT16=4
11-
INT32=5
12-
INT64=6
13-
UINT8=7
14-
UINT16=8
15-
162

17-
class Client(StrictRedis):
3+
DEVICE_CPU = 'cpu'
4+
DEVICE_GPU = 'gpu'
5+
6+
BACKEND_TF = 'tf'
7+
BACKEND_TORCH = 'torch'
8+
BACKEND_ONNX = 'ort'
9+
10+
11+
class Tensor(object):
12+
FLOAT = 'float'
13+
DOUBLE = 'double'
14+
INT8 = 'int8'
15+
INT16 = 'int16'
16+
INT32 = 'int32'
17+
INT64 = 'int64'
18+
UINT8 = 'uint8'
19+
UINT16 = 'uint16'
20+
UINT32 = 'uint32'
21+
UINT64 = 'uint64'
22+
23+
ARGNAME = 'VALUES'
24+
25+
def __init__(self, ttype, shape, value):
26+
"""
27+
Declare a tensor suitable for passing to tensorset
28+
:param ttype: The type the values should be stored as.
29+
This can be one of Tensor.FLOAT, tensor.DOUBLE, etc.
30+
:param shape: An array describing the shape of the tensor. For an
31+
image 250x250 with three channels, this would be [250, 250, 3]
32+
:param value: The value for the tensor. Can be an array.
33+
The contents must coordinate with the shape, meaning that the
34+
overall length needs to be the product of all figures in the
35+
shape. There is no verification to ensure that each dimension
36+
is correct. Your application must ensure that the ordering
37+
is always consistent.
38+
"""
39+
self.type = ttype
40+
self.shape = shape
41+
self.value = value
42+
self._size = 1
43+
if not isinstance(value, (list, tuple)):
44+
self.value = [value]
1845

19-
def __init__(self, *args, **kwargs):
46+
@property
47+
def size(self):
48+
return self._size
49+
50+
def __repr__(self):
51+
return '<{c.__class__.__name__}(shape={s} type={t}) at 0x{id:x}>'.format(
52+
c=self,
53+
s=self.shape,
54+
t=self.type,
55+
id=id(self))
56+
57+
58+
class ScalarTensor(Tensor):
59+
def __init__(self, dtype, *values):
60+
"""
61+
Declare a tensor with a bunch of scalar values. This can be used
62+
to 'batch-load' several tensors.
63+
64+
:param dtype: The datatype to store the tensor as
65+
:param values: List of values
2066
"""
21-
Create a new Client optional host and port
67+
super(ScalarTensor, self).__init__(dtype, [1], values)
68+
self._size = len(values)
69+
70+
71+
class BlobTensor(Tensor):
72+
73+
ARGNAME = 'BLOB'
2274

23-
If conn is not None, we employ an already existing redis connection
75+
def __init__(self, ttype, shape, *blobs):
2476
"""
25-
StrictRedis.__init__(self, *args, **kwargs)
26-
27-
# Set the module commands' callbacks
28-
MODULE_CALLBACKS = {
29-
'AI.TENSORSET': lambda r: r and nativestr(r) == 'OK',
77+
Create a tensor from a binary blob
78+
:param ttype: The datatype, one of Tensor.FLOAT, Tensor.DOUBLE, etc.
79+
:param shape: An array
80+
:param blobs: One or more blobs to assign to the tensor.
81+
"""
82+
if len(blobs) > 1:
83+
blobarr = bytearray()
84+
for b in blobs:
85+
if isinstance(b, BlobTensor):
86+
b = b.value[0]
87+
blobarr += b
88+
size = len(blobs)
89+
blobs = blobarr
90+
else:
91+
blobs = blobs[0]
92+
size = 1
93+
94+
super(BlobTensor, self).__init__(ttype, shape, blobs)
95+
self._size = size
96+
97+
98+
class Client(StrictRedis):
99+
def modelset(self, name, backend, device, inputs, outputs, data):
100+
args = ['AI.MODELSET', name, backend, device, 'INPUTS']
101+
args += inputs
102+
args += ['OUTPUTS'] + outputs
103+
args += [data]
104+
return self.execute_command(*args)
105+
106+
def modelget(self, name):
107+
rv = self.execute_command('AI.MODELGET', name)
108+
return {
109+
'backend': rv[0],
110+
'device': rv[1],
111+
'data': rv[2]
30112
}
31-
for k, v in six.iteritems(MODULE_CALLBACKS):
32-
self.set_response_callback(k, v)
33113

114+
def modelrun(self, name, inputs, outputs):
115+
args = ['AI.MODELRUN', name]
116+
args += ['INPUTS'] + inputs + ['OUTPUTS'] + outputs
117+
return self.execute_command(*args)
34118

35-
def tensorset(self, key, type, dimensions, tensor):
36-
args = ['AI.TENSORSET', key, type.name] + dimensions + ['VALUES'] + tensor
37-
119+
def tensorset(self, key, tensor):
120+
"""
121+
Set the values of the tensor on the server using the provided Tensor object
122+
:param key: The name of the tensor
123+
:param tensor: a `Tensor` object
124+
:return:
125+
"""
126+
args = ['AI.TENSORSET', key, tensor.type, tensor.size]
127+
args += tensor.shape
128+
args += [tensor.ARGNAME]
129+
args += tensor.value
130+
print args
38131
return self.execute_command(*args)
39132

40-
133+
def tensorget(self, key, astype=Tensor, meta_only=False):
134+
argname = 'META' if meta_only else astype.ARGNAME
135+
res = self.execute_command('AI.TENSORGET', key, argname)
136+
if meta_only:
137+
return astype(res[0], res[1], [])
138+
else:
139+
dtype, shape, value = res
140+
return astype(dtype, shape, value)
141+
142+
def scriptset(self, name, device, script):
143+
return self.execute_command('AI.SCRIPTSET', name, device, script)
144+
145+
def scriptget(self, name):
146+
r = self.execute_command('AI.SCRIPTGET', name)
147+
return {
148+
'device': r[0],
149+
'script': r[1]
150+
}
151+
152+
def scriptrun(self, name, function, inputs, outputs):
153+
args = ['AI.SCRIPTRUN', name, function, 'INPUTS']
154+
args += inputs
155+
args += ['OUTPUTS']
156+
args += outputs
157+
return self.execute_command(*args)

0 commit comments

Comments
 (0)