Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="xlb",
version="0.3.1",
version="0.3.2",
description="XLB: Accelerated Lattice Boltzmann (XLB) for Physics-based ML",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
Expand All @@ -19,11 +19,11 @@
"numpy-stl>=3.1.2",
"pydantic>=2.9.1",
"ruff>=0.14.1",
"jax>=0.8.0", # Base JAX CPU-only requirement
"jax>=0.8.2", # Base JAX CPU-only requirement
],
extras_require={
"cuda": ["jax[cuda13]>=0.8.0"], # For CUDA installations
"tpu": ["jax[tpu]>=0.8.0"], # For TPU installations
"cuda": ["jax[cuda13]>=0.8.2"], # For CUDA installations (pip install -U "jax[cuda13]")
"tpu": ["jax[tpu]>=0.8.2"], # For TPU installations
"test": ["pytest>=8.0.0"],
},
python_requires=">=3.11",
Expand Down
4 changes: 2 additions & 2 deletions xlb/distribute/distribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition.boundary_condition import ImplementationStep
from jax import lax
from jax.experimental.shard_map import shard_map
from jax import shard_map
from jax import jit


Expand Down Expand Up @@ -72,7 +72,7 @@ def _wrapped_operator(*args):
mesh=grid.global_mesh,
in_specs=in_specs,
out_specs=out_specs,
check_rep=False,
check_vma=False,
)
return distributed_operator(*args)

Expand Down
4 changes: 2 additions & 2 deletions xlb/operator/parallel_operator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from jax.experimental.shard_map import shard_map
from jax import shard_map
from jax.sharding import PartitionSpec as P
from jax import lax

Expand Down Expand Up @@ -47,7 +47,7 @@ def __call__(self, f):
mesh=self.grid.global_mesh,
in_specs=in_specs,
out_specs=out_specs,
check_rep=False,
check_vma=False,
)(f)
return f

Expand Down
2 changes: 1 addition & 1 deletion xlb/operator/stepper/ibm_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from xlb.operator import Operator
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry
from xlb.operator.stepper.nse_stepper import IncompressibleNavierStokesStepper
from warp.utils import ScopedTimer
from warp import ScopedTimer


class IBMStepper(IncompressibleNavierStokesStepper):
Expand Down