From 3f270e34fcdaa87d67b6886be8baad93e8d1aa54 Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 18 Jun 2026 15:55:37 +0800 Subject: [PATCH 1/3] Fix CI: replace remove_vmap with brainstate.transform.unvmap brainpy/math/remove_vmap.py defined a custom batching primitive whose batching rule returned jax.interpreters.batching.not_mapped. jax >= 0.10.2 removed that attribute, raising AttributeError: module 'jax.interpreters.batching' has no attribute 'not_mapped' which broke both CI workflows (Continuous Integration and Continuous Integration with Models). The failures surfaced in tests/audit/test_math_core_fixes.py::test_remove_vmap_under_vmap_is_global and tests/simulation/test_net_rate_FHN.py::TestFHN.test1 (via jit_error -> _cond -> remove_vmap under the vmap in delay_couplings). brainstate.transform.unvmap (fixed in brainstate 0.5.1) is the maintained equivalent; it resolves the sentinel compatibly via getattr(batching, 'not_mapped', None) and keeps identical global-reduction semantics and signature. - check.py: both call sites now use brainstate.transform.unvmap directly. - math/remove_vmap.py: reduced to a thin backward-compatible alias that unwraps brainpy.math.Array then delegates to unvmap; removes the broken in-tree primitive. Existing test_remove_vmap_* regression suite stays green. - requirements.txt / pyproject.toml: require brainstate>=0.5.1. --- brainpy/check.py | 8 +-- brainpy/math/remove_vmap.py | 109 +++++++++++------------------------- pyproject.toml | 2 +- requirements.txt | 2 +- 4 files changed, 39 insertions(+), 82 deletions(-) diff --git a/brainpy/check.py b/brainpy/check.py index 30dbc3174..3af3d74fa 100644 --- a/brainpy/check.py +++ b/brainpy/check.py @@ -595,13 +595,13 @@ def _err_jit_false_branch(x): def _cond(err_fun, pred, err_arg): - from brainpy.math.remove_vmap import remove_vmap + from brainstate.transform import unvmap @wraps(err_fun) def true_err_fun(arg, transforms): err_fun(arg) - cond(remove_vmap(pred), + cond(unvmap(pred), partial(_err_jit_true_branch, true_err_fun), _err_jit_false_branch, err_arg) @@ -636,7 +636,7 @@ def jit_error_checking_no_args(pred: bool, err: Exception): err: Exception The error. """ - from brainpy.math.remove_vmap import remove_vmap + from brainstate.transform import unvmap from brainpy.math.interoperability import as_jax assert isinstance(err, Exception), 'Must be instance of Exception.' @@ -644,6 +644,6 @@ def jit_error_checking_no_args(pred: bool, err: Exception): def true_err_fun(arg, transforms): raise err - cond(remove_vmap(as_jax(pred)), + cond(unvmap(as_jax(pred)), lambda: jax.pure_callback(true_err_fun, None), lambda: None) diff --git a/brainpy/math/remove_vmap.py b/brainpy/math/remove_vmap.py index 103ef1925..b06126d15 100644 --- a/brainpy/math/remove_vmap.py +++ b/brainpy/math/remove_vmap.py @@ -13,12 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -import jax -import jax.numpy as jnp +from brainstate.transform import unvmap -from brainstate._compatible_import import Primitive -from jax.core import ShapedArray -from jax.interpreters import batching, mlir, xla from .ndarray import Array __all__ = [ @@ -29,16 +25,19 @@ def remove_vmap(x, op='any'): """Reduce ``x`` with ``any``/``all`` *across the vmap batch axis as well*. - This is a custom primitive whose batching rule deliberately collapses the - batch axis into a single **global** scalar. That is, when called under - :func:`jax.vmap`, ``remove_vmap(x, 'any')`` returns one ``bool`` summarising - *all* batch elements together (``True`` if any element of any batch is - truthy), rather than a per-batch vector of results. + This is a thin backward-compatible alias for + :func:`brainstate.transform.unvmap`, which is the actively maintained + implementation. ``unvmap`` collapses the batch axis into a single + **global** scalar: when called under :func:`jax.vmap`, + ``remove_vmap(x, 'any')`` returns one ``bool`` summarising *all* batch + elements together (``True`` if any element of any batch is truthy), rather + than a per-batch vector of results. This is intentional: the primitive is used for global convergence / NaN-style - checks where the batch dimension must not survive the reduction. The batching - rule returns :data:`jax.interpreters.batching.not_mapped`, so the output is a - genuine unbatched scalar (it is *not* broadcast back across the batch axis). + checks where the batch dimension must not survive the reduction. Delegating to + :func:`brainstate.transform.unvmap` keeps BrainPy compatible across JAX + releases (jax ``>= 0.10`` removed ``jax.interpreters.batching.not_mapped``, + which the previous in-tree primitive relied on). Parameters ---------- @@ -52,69 +51,27 @@ def remove_vmap(x, op='any'): jax.Array A scalar boolean. Under :func:`jax.vmap` it is a single global scalar, not a per-batch result. + + Raises + ------ + ValueError + If ``op`` is not supported by :func:`brainstate.transform.unvmap`. + + See Also + -------- + brainstate.transform.unvmap + + Examples + -------- + .. code-block:: python + + >>> import jax.numpy as jnp + >>> from brainpy.math.remove_vmap import remove_vmap + >>> bool(remove_vmap(jnp.array([False, True]))) + True + >>> bool(remove_vmap(jnp.array([True, False]), 'all')) + False """ if isinstance(x, Array): x = x.value - if op == 'any': - return _any_without_vmap(x) - elif op == 'all': - return _all_without_vmap(x) - else: - raise ValueError(f'Do not support type: {op}') - - -_any_no_vmap_prim = Primitive('any_no_vmap') - - -def _any_without_vmap(x): - return _any_no_vmap_prim.bind(x) - - -def _any_without_vmap_imp(x): - return jnp.any(x) - - -def _any_without_vmap_abs(x): - return ShapedArray(shape=(), dtype=jnp.bool_) - - -def _any_without_vmap_batch(x, batch_axes): - (x,) = x - return _any_without_vmap(x), batching.not_mapped - - -_any_no_vmap_prim.def_impl(_any_without_vmap_imp) -_any_no_vmap_prim.def_abstract_eval(_any_without_vmap_abs) -batching.primitive_batchers[_any_no_vmap_prim] = _any_without_vmap_batch -if hasattr(xla, "lower_fun"): - xla.register_translation(_any_no_vmap_prim, - xla.lower_fun(_any_without_vmap_imp, multiple_results=False, new_style=True)) -mlir.register_lowering(_any_no_vmap_prim, mlir.lower_fun(_any_without_vmap_imp, multiple_results=False)) - -_all_no_vmap_prim = Primitive('all_no_vmap') - - -def _all_without_vmap(x): - return _all_no_vmap_prim.bind(x) - - -def _all_without_vmap_imp(x): - return jnp.all(x) - - -def _all_without_vmap_abs(x): - return ShapedArray(shape=(), dtype=jnp.bool_) - - -def _all_without_vmap_batch(x, batch_axes): - (x,) = x - return _all_without_vmap(x), batching.not_mapped - - -_all_no_vmap_prim.def_impl(_all_without_vmap_imp) -_all_no_vmap_prim.def_abstract_eval(_all_without_vmap_abs) -batching.primitive_batchers[_all_no_vmap_prim] = _all_without_vmap_batch -if hasattr(xla, "lower_fun"): - xla.register_translation(_all_no_vmap_prim, - xla.lower_fun(_all_without_vmap_imp, multiple_results=False, new_style=True)) -mlir.register_lowering(_all_no_vmap_prim, mlir.lower_fun(_all_without_vmap_imp, multiple_results=False)) + return unvmap(x, op) diff --git a/pyproject.toml b/pyproject.toml index 3449c0264..0edf82774 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "numpy>=1.15", "jax", "tqdm", - "brainstate>=0.2.7", + "brainstate>=0.5.1", "brainunit>=0.2.0", "brainevent>=0.0.7", "braintools>=0.0.9", diff --git a/requirements.txt b/requirements.txt index 09fbf6cd4..e02aa230d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ numpy>=1.15 brainunit brainevent>=0.0.7 braintools>=0.0.9 -brainstate>=0.2.7 +brainstate>=0.5.1 brainpy_state>=0.0.3 jax tqdm From e656dcc3a6d30d9c9d0601407f53631740e5e3b2 Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 18 Jun 2026 17:19:33 +0800 Subject: [PATCH 2/3] Fix pytest brainpy/ failures from brainstate 0.5.1 API drift These 9 in-package test failures pre-existed on master once brainstate 0.5.1 is installed; they are unrelated to remove_vmap. while_loop (controls.py): brainstate 0.5.1's while_loop is driven by tracked State, so the canonical brainpy idiom -- body mutates Variable state in place and returns None (often with empty operands, e.g. SpikeTimeGroup.update) -- is valid again. Restore it: when body_fun returns None, thread the operands through unchanged instead of raising. Fixes test_SpikeTimeGroup (input/ input_groups), TestAlpha::test_v1, TestDualExpon::test_dual_expon[_v2], TestWhile::test3. test_autograd.py::TestDebug::test_debug1: RandomState.value is now a typed JAX PRNG key (key) which rejects arithmetic; differentiate through the random draw and input only (drop 'a.value +'). test_random.py: chisquare now supports array-valued df (assert shape instead of expecting NotImplementedError); triangular signature is (left, mode, right, size, ...) so pass the shape via size=. --- brainpy/math/object_transform/controls.py | 12 +++++++----- brainpy/math/object_transform/tests/test_autograd.py | 5 ++++- brainpy/math/tests/test_random.py | 9 ++++++--- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/brainpy/math/object_transform/controls.py b/brainpy/math/object_transform/controls.py index 4e8942b4e..4253dc48b 100644 --- a/brainpy/math/object_transform/controls.py +++ b/brainpy/math/object_transform/controls.py @@ -597,11 +597,13 @@ def while_loop( def body(x): r = body_fun(*x) if r is None: - raise ValueError( - '`body_fun` of `while_loop` must return the updated operands, ' - 'but got `None`. Returning `None` would leave the operands unchanged ' - 'and the loop condition would never become False, causing an infinite loop.' - ) + # Classic brainpy idiom: ``body_fun`` mutates ``Variable`` state in place + # and returns ``None`` (often with empty ``operands``). brainstate's + # ``while_loop`` tracks that state automatically and the loop condition is + # driven by the mutated state, so the operands are threaded through + # unchanged. Returning ``x`` preserves this behaviour while still allowing + # a functional ``body_fun`` to return the updated operands explicitly. + return x return r return brainstate.transform.while_loop( diff --git a/brainpy/math/object_transform/tests/test_autograd.py b/brainpy/math/object_transform/tests/test_autograd.py index feb7cd190..4fbfe9645 100644 --- a/brainpy/math/object_transform/tests/test_autograd.py +++ b/brainpy/math/object_transform/tests/test_autograd.py @@ -1026,8 +1026,11 @@ def test_debug1(self): a = bm.random.RandomState() def f(b): + # ``a.value`` is a typed JAX PRNG key (``key``) under modern JAX and + # cannot be used in arithmetic; read it (print) but differentiate through + # the random draw and the input ``b`` only. print(a.value) - return a.value + b + a.random() + return b + a.random() f = bm.vector_grad(f, argnums=0) f(1.) diff --git a/brainpy/math/tests/test_random.py b/brainpy/math/tests/test_random.py index a7e09d498..3d039a7b0 100644 --- a/brainpy/math/tests/test_random.py +++ b/brainpy/math/tests/test_random.py @@ -312,9 +312,10 @@ def test_chisquare1(self): self.assertTrue(a.dtype, float) def test_chisquare2(self): + # Array-valued ``df`` is now supported (broadcasts over the batch axis). br.seed() - with self.assertRaises(NotImplementedError): - a = bm.random.chisquare(df=[2, 3, 4]) + a = bm.random.chisquare(df=[2, 3, 4]) + self.assertTupleEqual(a.shape, (3,)) def test_chisquare3(self): br.seed() @@ -451,8 +452,10 @@ def test_rayleigh(self): self.assertTupleEqual(a.shape, (4, 2)) def test_triangular(self): + # ``triangular`` signature is ``(left, mode, right, size, ...)``; pass the + # output shape via the ``size`` keyword. br.seed() - a = bm.random.triangular((2, 2)) + a = bm.random.triangular(size=(2, 2)) self.assertTupleEqual(a.shape, (2, 2)) def test_vonmises(self): From 092d2f10b7556b6fe3d47da303099cabfe8f6662 Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 18 Jun 2026 17:34:14 +0800 Subject: [PATCH 3/3] Update while_loop audit test for restored idiom; force matplotlib Agg - tests/audit/test_object_transform_fixes.py: #830 added M-06 asserting a None-returning while_loop body raises, but that broke the canonical brainpy idiom (body mutates Variable state, returns None) used by real models like SpikeTimeGroup. Now that the wrapper threads operands through unchanged, rewrite the test to assert that idiom works (state-driven termination) and update the module docstring. - conftest.py (new, repo root): force matplotlib onto the non-interactive Agg backend for both test roots (tests/ and brainpy/) so analysis tests that call pyplot.show() never open GUI windows, locally or in CI. - CI-models.yml: set MPLBACKEND=Agg on the pytest steps to match CI.yml. --- .github/workflows/CI-models.yml | 6 ++++++ conftest.py | 13 +++++++++++++ tests/audit/test_object_transform_fixes.py | 21 ++++++++++++--------- 3 files changed, 31 insertions(+), 9 deletions(-) create mode 100644 conftest.py diff --git a/.github/workflows/CI-models.yml b/.github/workflows/CI-models.yml index f1755260d..6bedcafaa 100644 --- a/.github/workflows/CI-models.yml +++ b/.github/workflows/CI-models.yml @@ -36,6 +36,8 @@ jobs: if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip install -e . - name: Test with pytest + env: + MPLBACKEND: Agg # Use non-interactive backend for matplotlib run: | pytest tests/ @@ -58,6 +60,8 @@ jobs: if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip install -e . - name: Test with pytest + env: + MPLBACKEND: Agg # Use non-interactive backend for matplotlib run: | pytest tests/ @@ -80,5 +84,7 @@ jobs: python -m pip install -r requirements-dev.txt pip install -e . - name: Test with pytest + env: + MPLBACKEND: Agg # Use non-interactive backend for matplotlib run: | python -m pytest tests/ diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000..e3ca3d3f4 --- /dev/null +++ b/conftest.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +"""Pytest configuration shared by both test roots (``tests/`` and ``brainpy/``). + +Force matplotlib onto the non-interactive ``Agg`` backend so that tests which +exercise the analysis/plotting code paths (e.g. phase-plane and bifurcation +analyses that call ``pyplot.show()``) never try to open a GUI window. This keeps +the suite headless and non-blocking locally and in CI regardless of the +``MPLBACKEND`` environment variable. +""" + +import matplotlib + +matplotlib.use('Agg', force=True) diff --git a/tests/audit/test_object_transform_fixes.py b/tests/audit/test_object_transform_fixes.py index 9e690216d..5e239dbee 100644 --- a/tests/audit/test_object_transform_fixes.py +++ b/tests/audit/test_object_transform_fixes.py @@ -15,7 +15,8 @@ accept a ``Variable``/``Array`` in ``operands``), H-03 (``for_loop(jit=False)`` with a zero-length pytree operand returns ``[]`` instead of crashing), M-03 (``scan`` returns ``(carry, ys)``), M-05 (``ifelse`` builds mutually - exclusive conditions), M-06 (``while_loop`` body returning ``None`` raises). + exclusive conditions), M-06 (``while_loop`` body returning ``None`` threads the + operands through unchanged so the canonical state-mutation idiom keeps working). * ``function.py`` — ``Partial``/``to_object`` behaviour and L-04 (``function`` emits a ``DeprecationWarning``). * ``_utils.py`` — ``warp_to_no_state_input_output`` strips/restores states. @@ -352,16 +353,18 @@ def body_f(x, y): assert len(res) == 2 -def test_while_loop_body_returning_none_raises(): - """M-06: a ``while_loop`` body that returns ``None`` would freeze the carry - and loop forever -- it must raise a clear ``ValueError`` instead.""" +def test_while_loop_body_returning_none_threads_operands(): + """M-06: a ``while_loop`` body that returns ``None`` mutates ``Variable`` state + in place (the canonical brainpy idiom, e.g. ``SpikeTimeGroup.update`` with empty + ``operands``) and threads the operands through unchanged. brainstate tracks the + mutated state, which drives the loop condition, so it must NOT raise.""" + a = bm.Variable(bm.zeros(1)) - def body(x): - # returns None -> illegal - pass + def body(): + a.value += 1. - with pytest.raises(ValueError): - bm.while_loop(body, lambda x: x < 3., 0.) + bm.while_loop(body, lambda: bm.all(a.value < 3.), ()) + assert float(np.asarray(a.value[0])) == 3.0 def test_ifelse_callable_branches_mutually_exclusive():