|
21 | 21 |
|
22 | 22 | import jax |
23 | 23 | from jax import core |
24 | | -from jax.lib import xla_client as xc |
25 | 24 | import numpy as np |
26 | 25 | from pathwaysutils import plugin_executable |
27 | 26 |
|
28 | 27 |
|
29 | | -def dtype_to_etype(dtype: np.dtype) -> xc.PrimitiveType: |
| 28 | +def dtype_to_xla_primitive_type_str(dtype: np.dtype) -> str: |
30 | 29 | """Converts a numpy dtype to an xla PrimitiveType.""" |
31 | 30 | if dtype == np.dtype("bfloat16"): |
32 | | - return xc.PrimitiveType.BF16 |
| 31 | + return "BF16" |
33 | 32 | elif dtype == np.dtype("float32"): |
34 | | - return xc.PrimitiveType.F32 |
| 33 | + return "F32" |
35 | 34 | elif dtype == np.dtype("float64"): |
36 | | - return xc.PrimitiveType.F64 |
| 35 | + return "F64" |
37 | 36 | elif dtype == np.dtype("int8"): |
38 | | - return xc.PrimitiveType.S8 |
| 37 | + return "S8" |
39 | 38 | elif dtype == np.dtype("int16"): |
40 | | - return xc.PrimitiveType.S16 |
| 39 | + return "S16" |
41 | 40 | elif dtype == np.dtype("int32"): |
42 | | - return xc.PrimitiveType.S32 |
| 41 | + return "S32" |
43 | 42 | elif dtype == np.dtype("int64"): |
44 | | - return xc.PrimitiveType.S64 |
| 43 | + return "S64" |
45 | 44 | elif dtype == np.dtype("uint8"): |
46 | | - return xc.PrimitiveType.U8 |
| 45 | + return "U8" |
47 | 46 | elif dtype == np.dtype("uint16"): |
48 | | - return xc.PrimitiveType.U16 |
| 47 | + return "U16" |
49 | 48 | elif dtype == np.dtype("uint32"): |
50 | | - return xc.PrimitiveType.U32 |
| 49 | + return "U32" |
51 | 50 | elif dtype == np.dtype("uint64"): |
52 | | - return xc.PrimitiveType.U64 |
| 51 | + return "U64" |
53 | 52 | else: |
54 | 53 | raise ValueError(f"Unsupported dtype: {dtype}") |
55 | 54 |
|
@@ -91,19 +90,15 @@ def get_hlo_sharding_string( |
91 | 90 | ) |
92 | 91 |
|
93 | 92 |
|
94 | | -def get_shape_string( |
| 93 | +def get_shape_info( |
95 | 94 | dtype: np.dtype, |
96 | | - shape: Sequence[int], |
97 | | -) -> str: |
98 | | - """Serializes the shape, encodes it to base64 and returns the base-64 as an utf-8 string.""" |
99 | | - return base64_utf8_stringify( |
100 | | - xc.Shape.array_shape( |
101 | | - xc.PrimitiveType(dtype_to_etype(dtype)), |
102 | | - shape, |
103 | | - ) |
104 | | - .with_major_to_minor_layout_if_absent() |
105 | | - .to_serialized_proto() |
106 | | - ) |
| 95 | + dimensions: Sequence[int], |
| 96 | +) -> dict[str, Union[Sequence[int], str]]: |
| 97 | + """Returns shape info in the format expected by read requests.""" |
| 98 | + return { |
| 99 | + "xla_primitive_type_str": dtype_to_xla_primitive_type_str(dtype), |
| 100 | + "dimensions": dimensions, |
| 101 | + } |
107 | 102 |
|
108 | 103 |
|
109 | 104 | def get_write_request( |
@@ -188,7 +183,7 @@ def get_read_request( |
188 | 183 | d = { |
189 | 184 | "persistenceReadRequest": { |
190 | 185 | "b64_location": string_to_base64(location_path), |
191 | | - "b64_shape_proto_string": get_shape_string(dtype, shape), |
| 186 | + "shape": get_shape_info(dtype, shape), |
192 | 187 | "b64_name": string_to_base64(name), |
193 | 188 | "b64_hlo_sharding_string": get_hlo_sharding_string( |
194 | 189 | sharding, len(shape) |
|
0 commit comments