We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
jaxlib.xla_client.ifrt_programs
jax.extend.ifrt_programs.ifrt_programs
1 parent ea02e1f commit 297d59fCopy full SHA for 297d59f
pathwaysutils/plugin_executable.py
@@ -19,16 +19,16 @@
19
20
import jax
21
from jax._src.interpreters import pxla
22
-from jaxlib import xla_client
+from jax.extend.ifrt_programs import ifrt_programs
23
24
25
class PluginExecutable:
26
"""Class for running compiled IFRT program over the IFRT Proxy."""
27
28
def __init__(self, prog_str: str):
29
ifrt_client = jax.local_devices()[0].client
30
- program = xla_client.ifrt_programs.make_plugin_program(prog_str)
31
- options = xla_client.ifrt_programs.make_plugin_compile_options()
+ program = ifrt_programs.make_plugin_program(prog_str)
+ options = ifrt_programs.make_plugin_compile_options()
32
self.compiled = ifrt_client.compile_ifrt_program(program, options)
33
34
def call(
0 commit comments