Skip to content

Commit ea02e1f

Browse files
lukebaumanncopybara-github
authored andcommitted
Changing from deprecated jax.sharding.XLACompatibleSharding to jax.sharding.Sharding per deprecation warning: jax.sharding.XLACompatibleSharding is deprecated. Use jax.sharding.Sharding instead.
PiperOrigin-RevId: 676460570
1 parent b3344a3 commit ea02e1f

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

pathwaysutils/persistence/helper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def string_to_base64(text: str) -> str:
5050

5151

5252
def get_hlo_sharding_string(
53-
sharding: jax.sharding.XLACompatibleSharding,
53+
sharding: jax.sharding.Sharding,
5454
num_dimensions: int,
5555
) -> str:
5656
"""Serializes the sharding to an hlo-sharding, encodes it to base64 and returns the base-64 as an utf-8 string."""
@@ -85,7 +85,7 @@ def get_write_request(
8585
) -> str:
8686
"""Returns a string representation of the plugin program which writes the given jax_array to the given location."""
8787
sharding = jax_array.sharding
88-
assert isinstance(sharding, jax.sharding.XLACompatibleSharding), sharding
88+
assert isinstance(sharding, jax.sharding.Sharding), sharding
8989
return json.dumps({
9090
"persistenceWriteRequest": {
9191
"b64_location": string_to_base64(location_path),
@@ -112,7 +112,7 @@ def get_read_request(
112112
name: str,
113113
dtype: np.dtype,
114114
shape: Sequence[int],
115-
sharding: jax.sharding.XLACompatibleSharding,
115+
sharding: jax.sharding.Sharding,
116116
devices: Sequence[jax.Device],
117117
timeout_seconds: int,
118118
) -> str:
@@ -155,7 +155,7 @@ def read_one_array(
155155
name: str,
156156
dtype: np.dtype,
157157
shape: Sequence[int],
158-
shardings: jax.sharding.XLACompatibleSharding,
158+
shardings: jax.sharding.Sharding,
159159
devices: Union[Sequence[jax.Device], np.ndarray],
160160
timeout: datetime.timedelta,
161161
):

pathwaysutils/plugin_executable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(self, prog_str: str):
3434
def call(
3535
self,
3636
in_arr: Sequence[Union[jax.Array, List[jax.Array]]] = (),
37-
out_shardings: Sequence[jax.sharding.XLACompatibleSharding] = (),
37+
out_shardings: Sequence[jax.sharding.Sharding] = (),
3838
out_avals: Sequence[jax.core.ShapedArray] = (),
3939
out_committed: bool = True,
4040
) -> Tuple[Sequence[jax.Array], concurrent.futures.Future[None]]:

0 commit comments

Comments
 (0)