Multi-res grid refinement + Neon backend support #159

Open
hsalehipour wants to merge 265 commits into Autodesk:main from hsalehipour:dev

Conversation

@hsalehipour
Collaborator

Contributing Guidelines

Description

Grid refinement capability is now supported in XLB through the Neon backend. The Neon backend provides full support for dense grids on multi-GPU systems, as well as multi-resolution grids on single GPUs. All newly introduced functionalities have been carefully tested and optimized. This represents a major enhancement to the library and involves substantial additions and improvements to the codebase.

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update

How Has This Been Tested?

  • All pytest tests pass
============================================= test session starts ==============================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /home/max/repo/test/XLB
collected 93 items                                                                                             

tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_jax.py ....                                 [  4%]
tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py ....                                [  8%]
tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_jax.py ......               [ 15%]
tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py ......              [ 21%]
tests/boundary_conditions/mask/test_bc_indices_masker_jax.py .......                                     [ 29%]
tests/boundary_conditions/mask/test_bc_indices_masker_warp.py ......                                     [ 35%]
tests/grids/test_grid_jax.py .......                                                                     [ 43%]
tests/grids/test_grid_warp.py ....                                                                       [ 47%]
tests/kernels/collision/test_bgk_collision_jax.py ......                                                 [ 53%]
tests/kernels/collision/test_bgk_collision_warp.py ......                                                [ 60%]
tests/kernels/equilibrium/test_equilibrium_jax.py ......                                                 [ 66%]
tests/kernels/equilibrium/test_equilibrium_warp.py ......                                                [ 73%]
tests/kernels/macroscopic/test_macroscopic_jax.py ......                                                 [ 79%]
tests/kernels/macroscopic/test_macroscopic_warp.py .......                                               [ 87%]
tests/kernels/stream/test_stream_jax.py ......                                                           [ 93%]
tests/kernels/stream/test_stream_warp.py ......                                                          [100%]

======================================== 93 passed in 248.34s (0:04:08) ========================================

Linting and Code Formatting

Make sure the code follows the project's linting and formatting standards. This project uses Ruff for linting.

To run Ruff, execute the following command from the root of the repository:

ruff check .
  • Ruff passes

massimim and others added 27 commits November 13, 2025 08:09
Simplifies the `add_to_app` method in the multiresolution stepper.
It now leverages keyword arguments and introspection for more flexible and maintainable operator calls.
This change enhances code readability and reduces the risk of errors when adding new operators.
(perf) Introduce two new scheduling strategies for the MRES algorithm
Merged and resolved conflicts of the latest XLB/main into dev
…So no need to further multiply by rho in KBC.
…mplifying the function signatures across multiple classes.
* Fixed some runtime bugs

* fixed some naming/spelling errors

* removed some debugging comments.

* Introduced a new file `cell_type.py` containing boundary-mask constants for fluid voxels, which replace previously hardcoded values.

* Applied renaming of 254 to SFV to function names
- Unified multi-resolution recursion builder in `simulation_manager.py` to streamline the construction of simulation steps.
- Refactored nse_multires_stepper for improved clarity
- Updated performance optimization handling in `multires_momentum_transfer.py` to support multiple fusion strategies.
…ine and clarify the implementation of multi-resolution streaming steps.
(refactoring) Cleaning up multi-res stepper.
… multi-res by ensuring consistent use of `store_dtype` and `compute_dtype`.
Fixed mixed precision handling of the Neon backend
@hsalehipour hsalehipour requested a review from mehdiataei March 13, 2026 22:13
* (build) Introducing Neon backend as an optional installation parameter.

* (install) new installation mode for neon backend.

* (build) Add ARM support for Neon wheel resolution

* (documentation) Fixes to README and AUTHORS

* (ruff) fixes to the style

* (documentation) fix list of supported python versions
Contributor

@mehdiataei mehdiataei left a comment

Thank you @massimim, this looks like a strong contribution. I left code-level comments in this round, but I also wanted to share a few higher-level suggestions.

  1. I think framing Neon as a new compute "backend" may be slightly misleading. Conceptually, Neon here does not seem fully parallel to JAX. The implementation largely reuses Warp functionals and then executes them through Neon handles, containers, and skeletons. In that sense, Neon feels more like an execution/runtime layer on top of Warp code generation than a standalone compute backend.

I think a different framing could make the design clearer:

  • If Neon is fundamentally "Warp math + Neon execution", I would be hesitant to model it as a third peer backend throughout the operator hierarchy.
  • Instead, I would consider splitting the abstraction into:
    • kernel / math backend: JAX vs Warp
    • execution runtime: direct Warp launch vs Neon container/skeleton launch

I think this would make Neon more generic and would better highlight its real strength: the execution model and skeleton abstraction, rather than presenting it as a bespoke backend. It may also make the integration easier to extend and adopt. Several of the current issues feel like symptoms of the abstraction boundary being one layer off. This likely needs some careful design thought, but I would strongly encourage it. To me, the more compelling framing is that Neon provides a skeleton/runtime that Warp kernels can target.
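A minimal sketch of that two-axis split (the `KernelBackend` / `ExecutionRuntime` names and the `resolve` helper are hypothetical illustrations, not existing XLB API):

```python
from enum import Enum, auto

class KernelBackend(Enum):
    """Where kernels/math are generated (hypothetical axis 1)."""
    JAX = auto()
    WARP = auto()

class ExecutionRuntime(Enum):
    """How compiled kernels are launched (hypothetical axis 2)."""
    DIRECT = auto()  # plain Warp/JAX launch
    NEON = auto()    # Neon container/skeleton launch

def resolve(kernel: KernelBackend, runtime: ExecutionRuntime) -> str:
    """Validate a (kernel, runtime) pair. Since Neon reuses Warp functionals,
    it only pairs with the Warp kernel backend in this sketch."""
    if runtime is ExecutionRuntime.NEON and kernel is not KernelBackend.WARP:
        raise ValueError("the Neon runtime currently requires Warp kernels")
    return f"{kernel.name}+{runtime.name}"
```

With this split, "Neon support" becomes one more runtime that Warp kernels can target, rather than a third peer of JAX and Warp in every operator.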

  2. The multires implementation also feels too monolithic. Kernels, schedule planning, state ownership, and runtime graph compilation all live in roughly the same layer, which makes the system harder to reason about, test, and extend.

One possible improvement would be to introduce a typed MultiresPlan / Schedule layer that represents the recursive timestep as explicit operations, then have a separate Neon graph builder that lowers that plan into containers/skeletons. I would also keep simulation state in a manager and keep kernels separate from schedule construction.
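One possible shape for such a plan layer, sketched with hypothetical names (`PlanOp`, `MultiresPlan`, and `build_recursive_plan` are illustrations, not XLB code):

```python
from __future__ import annotations
from dataclasses import dataclass, field

@dataclass(frozen=True)
class PlanOp:
    """One explicit operation in the recursive timestep (hypothetical type)."""
    name: str   # e.g. "collide", "stream", "coalesce", "explode"
    level: int  # refinement level the op runs on

@dataclass
class MultiresPlan:
    """Explicit schedule that a separate Neon graph builder could later
    lower into containers/skeletons."""
    ops: list[PlanOp] = field(default_factory=list)

def build_recursive_plan(num_levels: int,
                         plan: MultiresPlan | None = None,
                         level: int = 0) -> MultiresPlan:
    """Classic multires LBM recursion: one coarse step wraps two finer steps."""
    plan = plan or MultiresPlan()
    plan.ops.append(PlanOp("collide", level))
    if level + 1 < num_levels:
        build_recursive_plan(num_levels, plan, level + 1)
        build_recursive_plan(num_levels, plan, level + 1)
    plan.ops.append(PlanOp("stream", level))
    return plan
```

Because the plan is plain data, it can be unit-tested and inspected without touching Neon at all; only the lowering step needs the runtime.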

  3. The topology and coordinate model feels too implicit at the moment, which makes it harder to debug and reuse. An explicit MultiresTopology or LevelInfo abstraction could help a lot, with methods such as:
  • level_shape(level)
  • global_bounds(level)
  • to_global(level, coords)
  • face_indices(level, side)
  • active_indices(level)

I think making those concepts explicit would improve both clarity and correctness.
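A sketch of what such an abstraction could look like, covering a few of the methods listed above and assuming a 2:1 refinement ratio between levels (class and field names are hypothetical):

```python
from __future__ import annotations
from dataclasses import dataclass

@dataclass(frozen=True)
class MultiresTopology:
    """Explicit level/coordinate model (hypothetical; assumes level 0 is the
    finest level and each coarser level halves the resolution per axis)."""
    base_shape: tuple[int, int, int]  # finest-level (level 0) grid shape

    def level_shape(self, level: int) -> tuple[int, ...]:
        # Shape of the grid at a given refinement level.
        return tuple(s // (2 ** level) for s in self.base_shape)

    def to_global(self, level: int, coords: tuple[int, ...]) -> tuple[int, ...]:
        # Map level-local coordinates to finest-level (global) coordinates.
        return tuple(c * (2 ** level) for c in coords)

    def global_bounds(self, level: int) -> tuple[tuple[int, ...], tuple[int, ...]]:
        # Inclusive lower corner and exclusive upper corner in global coords.
        return (0, 0, 0), self.to_global(level, self.level_shape(level))
```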

  4. For the new BC, I suggest clarifying the Reynolds numbers it has been validated for.

"jax>=0.8.2", # Base JAX CPU-only requirement
],
extras_require={
"warp": ["warp-lang>=1.10.0"], # Warp backend (single-GPU); included by default for full backend support
Contributor

The new default install path is broken: pip install xlb no longer guarantees that import xlb succeeds.

Repro:

uv venv .venv --python 3.12
uv pip install --python .venv/bin/python .
source .venv/bin/activate && python -c "import xlb"

I suggest either adding warp-lang to the base install or fully decoupling top-level imports from Warp, so that a minimal CPU/JAX install can import successfully.
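A common pattern for the decoupling option, sketched here as a hypothetical guard for xlb/__init__.py (`HAS_WARP` and `require_warp` are illustrative names, not existing XLB API):

```python
# Keep the top-level `import xlb` working even when optional backends
# (warp-lang, neon) are not installed.
try:
    import warp  # noqa: F401
    HAS_WARP = True
except ImportError:
    HAS_WARP = False

def require_warp():
    """Call this from Warp-only code paths instead of importing warp at
    module scope; gives a clear error for a minimal CPU/JAX install."""
    if not HAS_WARP:
        raise ImportError(
            "This feature needs the Warp backend: pip install 'xlb[warp]'"
        )
    import warp
    return warp
```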

},
python_requires=">=3.11",
dependency_links=["https://storage.googleapis.com/jax-releases/libtpu_releases.html"],
cmdclass={"install": InstallWithNeonHooks},
Contributor

This will bake in Python 3.12. It may break if the user has a different Python version.

I think we need to either publish version-specific XLB wheels per interpreter or move the Neon installation to a runtime script instead.

"warp": ["warp-lang>=1.10.0"], # Warp backend (single-GPU); included by default for full backend support
"cuda": ["jax[cuda13]>=0.8.2"], # For CUDA installations (pip install -U "jax[cuda13]")
"tpu": ["jax[tpu]>=0.8.2"], # For TPU installations
"neon": [_neon_wheel_requirement()],
Contributor

Another issue I found:

If warp-lang is already installed, installing with [neon] will not remove or replace it, so the two packages will conflict.

Check:

uv venv .venv --python 3.12
uv pip install --python .venv/bin/python warp-lang
uv pip install --python .venv/bin/python '.[neon]'

nvtx.pop_range()

@Operator.register_backend(ComputeBackend.NEON)
def neon_launch(self, f_0, f_1, bc_mask, missing_mask, omega, timestep):
Contributor

Calling MultiresIncompressibleNavierStokesStepper.neon_launch()

will hit TypeError: 'dict' object is not callable via the registered NEON neon_launch.

I think you can just remove neon_launch(); it is not used in any execution path.



class MultiresSimulationManager(MultiresIncompressibleNavierStokesStepper):
"""Orchestrates multi-resolution LBM simulations on the Neon backend.
Contributor

MultiresSimulationManager(..., force_vector=...) will fail. This should be handled explicitly: either implement the forcedCollision (which should be easy) or add an explicit check for it.
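The explicit-check option could be as small as this (simplified stand-in class; the real constructor signature may differ):

```python
class MultiresSimulationManager:
    """Simplified stand-in for the real class; only the guard is shown."""

    def __init__(self, force_vector=None, **kwargs):
        # Hypothetical explicit check: fail loudly until a forced-collision
        # operator exists on the multires Neon path.
        if force_vector is not None:
            raise NotImplementedError(
                "force_vector is not yet supported by the multires Neon stepper"
            )
        self.force_vector = None
```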

"0.238" : { "x-velocity" : [24.405,24.168,22.782,20.196,16.970,13.937,12.137,11.757,12.851,14.649,16.780,18.995,21.070,23.335,25.280,27.468,29.262,30.832,32.133,33.102,33.856,34.473,34.922,35.340,35.698,36.039,36.336,36.629,36.906,37.193,37.454,37.691,37.929,38.329,38.611,38.875,39.126,39.414,39.677,39.917,40.097,40.259,40.380,40.478,40.568], "height" : [0.028,0.038,0.048,0.058,0.068,0.078,0.088,0.098,0.108,0.118,0.128,0.138,0.148,0.158,0.168,0.178,0.188,0.198,0.208,0.218,0.228,0.238,0.248,0.258,0.268,0.278,0.288,0.298,0.308,0.318,0.328,0.338,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
"0.288" : { "x-velocity" : [21.489,22.225,22.127,21.456,20.404,19.743,19.541,19.909,21.002,22.381,24.018,25.670,27.421,28.998,30.371,31.523,32.406,33.111,33.670,34.155,34.532,34.893,35.240,35.567,35.875,36.158,36.437,36.708,36.974,37.230,37.473,37.709,37.932,38.266,38.515,38.773,39.008,39.270,39.562,39.782,39.962,40.148,40.266,40.369,40.475], "height" : [0.028,0.038,0.048,0.058,0.068,0.078,0.088,0.098,0.108,0.118,0.128,0.138,0.148,0.158,0.168,0.178,0.188,0.198,0.208,0.218,0.228,0.238,0.248,0.258,0.268,0.278,0.288,0.298,0.308,0.318,0.328,0.338,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}
}
}
(no newline at end of file)
Contributor

Please fix the trailing whitespace.

velocity_set = velocity_set or DefaultConfig.velocity_set
if compute_backend == ComputeBackend.WARP:
from xlb.grid.warp_grid import WarpGrid

Contributor

You can delete the WarpGrid import.

self.bk = neon.Backend(runtime=neon.Backend.Runtime.stream, dev_idx_list=dev_idx_list)
self.bk.info_print()
self.grid = neon.dense.dGrid(backend=self.bk, dim=self.dim, sparsity=None, stencil=self.neon_stencil)
pass
Contributor

Can you remove these stray pass statements? Same for initialize_backend().

self.neon_stencil.append([xval, yval, zval])

self.bk = neon.Backend(runtime=neon.Backend.Runtime.stream, dev_idx_list=dev_idx_list)
self.bk.info_print()
Contributor

Can you hide these print statements behind a debug flag? There are a few more in grid.
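One lightweight way to gate them, assuming an environment-variable switch (the XLB_DEBUG name is hypothetical, not existing XLB API):

```python
import os

# Hypothetical debug switch read once at import time.
XLB_DEBUG = os.environ.get("XLB_DEBUG", "0") == "1"

def debug_info(backend, enabled: bool = XLB_DEBUG) -> None:
    """Emit Neon's backend info only when debugging is enabled."""
    if enabled:
        backend.info_print()
```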

return self.velocity_set

def _initialize_backend(self):
# FIXME@max: for now we hardcode the number of devices to 0
Contributor

Maybe remove this
