Skip to content

Commit 297d59f

Browse files
lukebaumanncopybara-github
authored andcommitted
Change from using jaxlib.xla_client.ifrt_programs to jax.extend.ifrt_programs.ifrt_programs.
PiperOrigin-RevId: 676551752
1 parent ea02e1f commit 297d59f

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pathwaysutils/plugin_executable.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@
1919

2020
import jax
2121
from jax._src.interpreters import pxla
22-
from jaxlib import xla_client
22+
from jax.extend.ifrt_programs import ifrt_programs
2323

2424

2525
class PluginExecutable:
2626
"""Class for running compiled IFRT program over the IFRT Proxy."""
2727

2828
def __init__(self, prog_str: str):
2929
ifrt_client = jax.local_devices()[0].client
30-
program = xla_client.ifrt_programs.make_plugin_program(prog_str)
31-
options = xla_client.ifrt_programs.make_plugin_compile_options()
30+
program = ifrt_programs.make_plugin_program(prog_str)
31+
options = ifrt_programs.make_plugin_compile_options()
3232
self.compiled = ifrt_client.compile_ifrt_program(program, options)
3333

3434
def call(

0 commit comments

Comments
 (0)