diff --git a/setup.py b/setup.py index dc5e4f7f..341f2568 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="xlb", - version="0.3.1", + version="0.3.2", description="XLB: Accelerated Lattice Boltzmann (XLB) for Physics-based ML", long_description=open("README.md").read(), long_description_content_type="text/markdown", @@ -19,11 +19,11 @@ "numpy-stl>=3.1.2", "pydantic>=2.9.1", "ruff>=0.14.1", - "jax>=0.8.0", # Base JAX CPU-only requirement + "jax>=0.8.2", # Base JAX CPU-only requirement ], extras_require={ - "cuda": ["jax[cuda13]>=0.8.0"], # For CUDA installations - "tpu": ["jax[tpu]>=0.8.0"], # For TPU installations + "cuda": ["jax[cuda13]>=0.8.2"], # For CUDA installations (pip install -U "jax[cuda13]") + "tpu": ["jax[tpu]>=0.8.2"], # For TPU installations "test": ["pytest>=8.0.0"], }, python_requires=">=3.11", diff --git a/xlb/distribute/distribute.py b/xlb/distribute/distribute.py index c62b9153..1fc9138e 100644 --- a/xlb/distribute/distribute.py +++ b/xlb/distribute/distribute.py @@ -3,7 +3,7 @@ from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition.boundary_condition import ImplementationStep from jax import lax -from jax.experimental.shard_map import shard_map +from jax import shard_map from jax import jit @@ -72,7 +72,7 @@ def _wrapped_operator(*args): mesh=grid.global_mesh, in_specs=in_specs, out_specs=out_specs, - check_rep=False, + check_vma=False, ) return distributed_operator(*args) diff --git a/xlb/operator/parallel_operator.py b/xlb/operator/parallel_operator.py index 9f9b5c53..e0505f03 100644 --- a/xlb/operator/parallel_operator.py +++ b/xlb/operator/parallel_operator.py @@ -1,4 +1,4 @@ -from jax.experimental.shard_map import shard_map +from jax import shard_map from jax.sharding import PartitionSpec as P from jax import lax @@ -47,7 +47,7 @@ def __call__(self, f): mesh=self.grid.global_mesh, in_specs=in_specs, out_specs=out_specs, - check_rep=False, + check_vma=False, )(f) return f diff --git a/xlb/operator/stepper/ibm_stepper.py b/xlb/operator/stepper/ibm_stepper.py index d6b22196..8507b7bf 100644 --- a/xlb/operator/stepper/ibm_stepper.py +++ b/xlb/operator/stepper/ibm_stepper.py @@ -8,7 +8,7 @@ from xlb.operator import Operator from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry from xlb.operator.stepper.nse_stepper import IncompressibleNavierStokesStepper -from warp.utils import ScopedTimer +from warp import ScopedTimer class IBMStepper(IncompressibleNavierStokesStepper):