Skip to content

Commit 6e773bb

Browse files
committed
Use jax.extend.backend instead of jax._src.xla_bridge for registering the proxy backend.
PiperOrigin-RevId: 675731069
1 parent af2aaf0 commit 6e773bb

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pathwaysutils/proxy_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
"""Register the IFRT Proxy as a backend for JAX."""
1515

1616
import jax
17-
from jax._src import xla_bridge
18-
from jaxlib.xla_extension import ifrt_proxy
17+
from jax.extend import backend
18+
from jax.lib.xla_extension import ifrt_proxy
1919

2020

2121
def register_backend_factory():
22-
xla_bridge.register_backend_factory(
22+
backend.register_backend_factory(
2323
"proxy",
2424
lambda: ifrt_proxy.get_client(
2525
jax.config.read("jax_backend_target"),

0 commit comments

Comments
 (0)