Skip to content

Commit 59cfc89

Browse files
lukebaumanncopybara-github
authored andcommitted
Converting timeout for read and write requests to int to match the function signature types.
PiperOrigin-RevId: 676552009
1 parent 297d59f commit 59cfc89

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

pathwaysutils/persistence/helper.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def get_write_request(
8181
location_path: str,
8282
name: str,
8383
jax_array: jax.Array,
84-
timeout: int,
84+
timeout: datetime.timedelta,
8585
) -> str:
8686
"""Returns a string representation of the plugin program which writes the given jax_array to the given location."""
8787
sharding = jax_array.sharding
@@ -102,7 +102,10 @@ def get_write_request(
102102
# pylint:enable=protected-access
103103
],
104104
},
105-
"timeout": {"seconds": timeout},
105+
"timeout": {
106+
"seconds": timeout.seconds,
107+
"nano": timeout.microseconds * 1000,
108+
},
106109
}
107110
})
108111

@@ -114,7 +117,7 @@ def get_read_request(
114117
shape: Sequence[int],
115118
sharding: jax.sharding.Sharding,
116119
devices: Sequence[jax.Device],
117-
timeout_seconds: int,
120+
timeout: datetime.timedelta,
118121
) -> str:
119122
"""Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
120123
if not isinstance(devices, np.ndarray):
@@ -130,7 +133,10 @@ def get_read_request(
130133
"devices": {
131134
"device_ids": [device.id for device in devices.flatten()]
132135
},
133-
"timeout": {"seconds": timeout_seconds},
136+
"timeout": {
137+
"seconds": timeout.seconds,
138+
"nano": timeout.microseconds * 1000,
139+
},
134140
}
135141
})
136142

@@ -142,9 +148,7 @@ def write_one_array(
142148
timeout: datetime.timedelta,
143149
):
144150
"""Creates the write array plugin program string, compiles it to an executable, calls it and returns an awaitable future."""
145-
write_request = get_write_request(
146-
location, name, value, timeout.total_seconds()
147-
)
151+
write_request = get_write_request(location, name, value, timeout)
148152
write_executable = plugin_executable.PluginExecutable(write_request)
149153
_, write_future = write_executable.call([value])
150154
return write_future
@@ -167,7 +171,7 @@ def read_one_array(
167171
shape,
168172
shardings,
169173
devices,
170-
timeout.total_seconds(),
174+
timeout,
171175
)
172176
read_executable = plugin_executable.PluginExecutable(read_request)
173177
out_aval = core.ShapedArray(shape, dtype)

0 commit comments

Comments
 (0)