@@ -81,7 +81,7 @@ def get_write_request(
8181 location_path : str ,
8282 name : str ,
8383 jax_array : jax .Array ,
84- timeout : int ,
84+ timeout : datetime . timedelta ,
8585) -> str :
8686 """Returns a string representation of the plugin program which writes the given jax_array to the given location."""
8787 sharding = jax_array .sharding
@@ -102,7 +102,10 @@ def get_write_request(
102102 # pylint:enable=protected-access
103103 ],
104104 },
105- "timeout" : {"seconds" : timeout },
105+ "timeout" : {
106+ "seconds" : timeout .seconds ,
107+ "nano" : timeout .microseconds * 1000 ,
108+ },
106109 }
107110 })
108111
@@ -114,7 +117,7 @@ def get_read_request(
114117 shape : Sequence [int ],
115118 sharding : jax .sharding .Sharding ,
116119 devices : Sequence [jax .Device ],
117- timeout_seconds : int ,
120+ timeout : datetime . timedelta ,
118121) -> str :
119122 """Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
120123 if not isinstance (devices , np .ndarray ):
@@ -130,7 +133,10 @@ def get_read_request(
130133 "devices" : {
131134 "device_ids" : [device .id for device in devices .flatten ()]
132135 },
133- "timeout" : {"seconds" : timeout_seconds },
136+ "timeout" : {
137+ "seconds" : timeout .seconds ,
138+ "nano" : timeout .microseconds * 1000 ,
139+ },
134140 }
135141 })
136142
@@ -142,9 +148,7 @@ def write_one_array(
142148 timeout : datetime .timedelta ,
143149):
144150 """Creates the write array plugin program string, compiles it to an executable, calls it and returns an awaitable future."""
145- write_request = get_write_request (
146- location , name , value , timeout .total_seconds ()
147- )
151+ write_request = get_write_request (location , name , value , timeout )
148152 write_executable = plugin_executable .PluginExecutable (write_request )
149153 _ , write_future = write_executable .call ([value ])
150154 return write_future
@@ -167,7 +171,7 @@ def read_one_array(
167171 shape ,
168172 shardings ,
169173 devices ,
170- timeout . total_seconds () ,
174+ timeout ,
171175 )
172176 read_executable = plugin_executable .PluginExecutable (read_request )
173177 out_aval = core .ShapedArray (shape , dtype )
0 commit comments