Skip to content

Commit 3a09bad

Browse files
lukebaumanncopybara-github
authored andcommitted
Undoes the incorrect fix to the argument and instead fixes the type hints of the function for write_executable.call
PiperOrigin-RevId: 676055225
1 parent 6e773bb commit 3a09bad

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

pathwaysutils/persistence/helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def write_one_array(
146146
location, name, value, timeout.total_seconds()
147147
)
148148
write_executable = plugin_executable.PluginExecutable(write_request)
149-
_, write_future = write_executable.call([[value]])
149+
_, write_future = write_executable.call([value])
150150
return write_future
151151

152152

pathwaysutils/plugin_executable.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import concurrent.futures
1717
import threading
18-
from typing import List, Sequence, Tuple
18+
from typing import List, Sequence, Tuple, Union
1919

2020
import jax
2121
from jax._src.interpreters import pxla
@@ -33,7 +33,7 @@ def __init__(self, prog_str: str):
3333

3434
def call(
3535
self,
36-
in_arr: Sequence[List[jax.Array]] = (),
36+
in_arr: Sequence[Union[jax.Array, List[jax.Array]]] = (),
3737
out_shardings: Sequence[jax.sharding.XLACompatibleSharding] = (),
3838
out_avals: Sequence[jax.core.ShapedArray] = (),
3939
out_committed: bool = True,

0 commit comments

Comments
 (0)