diff --git a/.github/workflows/CI-models.yml b/.github/workflows/CI-models.yml index f1755260d..6bedcafaa 100644 --- a/.github/workflows/CI-models.yml +++ b/.github/workflows/CI-models.yml @@ -36,6 +36,8 @@ jobs: if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip install -e . - name: Test with pytest + env: + MPLBACKEND: Agg # Use non-interactive backend for matplotlib run: | pytest tests/ @@ -58,6 +60,8 @@ jobs: if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip install -e . - name: Test with pytest + env: + MPLBACKEND: Agg # Use non-interactive backend for matplotlib run: | pytest tests/ @@ -80,5 +84,7 @@ jobs: python -m pip install -r requirements-dev.txt pip install -e . - name: Test with pytest + env: + MPLBACKEND: Agg # Use non-interactive backend for matplotlib run: | python -m pytest tests/ diff --git a/brainpy/check.py b/brainpy/check.py index 30dbc3174..3af3d74fa 100644 --- a/brainpy/check.py +++ b/brainpy/check.py @@ -595,13 +595,13 @@ def _err_jit_false_branch(x): def _cond(err_fun, pred, err_arg): - from brainpy.math.remove_vmap import remove_vmap + from brainstate.transform import unvmap @wraps(err_fun) def true_err_fun(arg, transforms): err_fun(arg) - cond(remove_vmap(pred), + cond(unvmap(pred), partial(_err_jit_true_branch, true_err_fun), _err_jit_false_branch, err_arg) @@ -636,7 +636,7 @@ def jit_error_checking_no_args(pred: bool, err: Exception): err: Exception The error. """ - from brainpy.math.remove_vmap import remove_vmap + from brainstate.transform import unvmap from brainpy.math.interoperability import as_jax assert isinstance(err, Exception), 'Must be instance of Exception.' @@ -644,6 +644,6 @@ def jit_error_checking_no_args(pred: bool, err: Exception): def true_err_fun(arg, transforms): raise err - cond(remove_vmap(as_jax(pred)), + cond(unvmap(as_jax(pred)), lambda: jax.pure_callback(true_err_fun, None), lambda: None) diff --git a/brainpy/math/object_transform/controls.py b/brainpy/math/object_transform/controls.py index 4e8942b4e..4253dc48b 100644 --- a/brainpy/math/object_transform/controls.py +++ b/brainpy/math/object_transform/controls.py @@ -597,11 +597,13 @@ def while_loop( def body(x): r = body_fun(*x) if r is None: - raise ValueError( - '`body_fun` of `while_loop` must return the updated operands, ' - 'but got `None`. Returning `None` would leave the operands unchanged ' - 'and the loop condition would never become False, causing an infinite loop.' - ) + # Classic brainpy idiom: ``body_fun`` mutates ``Variable`` state in place + # and returns ``None`` (often with empty ``operands``). brainstate's + # ``while_loop`` tracks that state automatically and the loop condition is + # driven by the mutated state, so the operands are threaded through + # unchanged. Returning ``x`` preserves this behaviour while still allowing + # a functional ``body_fun`` to return the updated operands explicitly. + return x return r return brainstate.transform.while_loop( diff --git a/brainpy/math/object_transform/tests/test_autograd.py b/brainpy/math/object_transform/tests/test_autograd.py index feb7cd190..4fbfe9645 100644 --- a/brainpy/math/object_transform/tests/test_autograd.py +++ b/brainpy/math/object_transform/tests/test_autograd.py @@ -1026,8 +1026,11 @@ def test_debug1(self): a = bm.random.RandomState() def f(b): + # ``a.value`` is a typed JAX PRNG key (``key``) under modern JAX and + # cannot be used in arithmetic; read it (print) but differentiate through + # the random draw and the input ``b`` only. print(a.value) - return a.value + b + a.random() + return b + a.random() f = bm.vector_grad(f, argnums=0) f(1.) diff --git a/brainpy/math/remove_vmap.py b/brainpy/math/remove_vmap.py index 103ef1925..b06126d15 100644 --- a/brainpy/math/remove_vmap.py +++ b/brainpy/math/remove_vmap.py @@ -13,12 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -import jax -import jax.numpy as jnp +from brainstate.transform import unvmap -from brainstate._compatible_import import Primitive -from jax.core import ShapedArray -from jax.interpreters import batching, mlir, xla from .ndarray import Array __all__ = [ @@ -29,16 +25,19 @@ def remove_vmap(x, op='any'): """Reduce ``x`` with ``any``/``all`` *across the vmap batch axis as well*. - This is a custom primitive whose batching rule deliberately collapses the - batch axis into a single **global** scalar. That is, when called under - :func:`jax.vmap`, ``remove_vmap(x, 'any')`` returns one ``bool`` summarising - *all* batch elements together (``True`` if any element of any batch is - truthy), rather than a per-batch vector of results. + This is a thin backward-compatible alias for + :func:`brainstate.transform.unvmap`, which is the actively maintained + implementation. ``unvmap`` collapses the batch axis into a single + **global** scalar: when called under :func:`jax.vmap`, + ``remove_vmap(x, 'any')`` returns one ``bool`` summarising *all* batch + elements together (``True`` if any element of any batch is truthy), rather + than a per-batch vector of results. This is intentional: the primitive is used for global convergence / NaN-style - checks where the batch dimension must not survive the reduction. The batching - rule returns :data:`jax.interpreters.batching.not_mapped`, so the output is a - genuine unbatched scalar (it is *not* broadcast back across the batch axis). + checks where the batch dimension must not survive the reduction. Delegating to + :func:`brainstate.transform.unvmap` keeps BrainPy compatible across JAX + releases (jax ``>= 0.10`` removed ``jax.interpreters.batching.not_mapped``, + which the previous in-tree primitive relied on). Parameters ---------- @@ -52,69 +51,27 @@ def remove_vmap(x, op='any'): jax.Array A scalar boolean. Under :func:`jax.vmap` it is a single global scalar, not a per-batch result. + + Raises + ------ + ValueError + If ``op`` is not supported by :func:`brainstate.transform.unvmap`. + + See Also + -------- + brainstate.transform.unvmap + + Examples + -------- + .. code-block:: python + + >>> import jax.numpy as jnp + >>> from brainpy.math.remove_vmap import remove_vmap + >>> bool(remove_vmap(jnp.array([False, True]))) + True + >>> bool(remove_vmap(jnp.array([True, False]), 'all')) + False """ if isinstance(x, Array): x = x.value - if op == 'any': - return _any_without_vmap(x) - elif op == 'all': - return _all_without_vmap(x) - else: - raise ValueError(f'Do not support type: {op}') - - -_any_no_vmap_prim = Primitive('any_no_vmap') - - -def _any_without_vmap(x): - return _any_no_vmap_prim.bind(x) - - -def _any_without_vmap_imp(x): - return jnp.any(x) - - -def _any_without_vmap_abs(x): - return ShapedArray(shape=(), dtype=jnp.bool_) - - -def _any_without_vmap_batch(x, batch_axes): - (x,) = x - return _any_without_vmap(x), batching.not_mapped - - -_any_no_vmap_prim.def_impl(_any_without_vmap_imp) -_any_no_vmap_prim.def_abstract_eval(_any_without_vmap_abs) -batching.primitive_batchers[_any_no_vmap_prim] = _any_without_vmap_batch -if hasattr(xla, "lower_fun"): - xla.register_translation(_any_no_vmap_prim, - xla.lower_fun(_any_without_vmap_imp, multiple_results=False, new_style=True)) -mlir.register_lowering(_any_no_vmap_prim, mlir.lower_fun(_any_without_vmap_imp, multiple_results=False)) - -_all_no_vmap_prim = Primitive('all_no_vmap') - - -def _all_without_vmap(x): - return _all_no_vmap_prim.bind(x) - - -def _all_without_vmap_imp(x): - return jnp.all(x) - - -def _all_without_vmap_abs(x): - return ShapedArray(shape=(), dtype=jnp.bool_) - - -def _all_without_vmap_batch(x, batch_axes): - (x,) = x - return _all_without_vmap(x), batching.not_mapped - - -_all_no_vmap_prim.def_impl(_all_without_vmap_imp) -_all_no_vmap_prim.def_abstract_eval(_all_without_vmap_abs) -batching.primitive_batchers[_all_no_vmap_prim] = _all_without_vmap_batch -if hasattr(xla, "lower_fun"): - xla.register_translation(_all_no_vmap_prim, - xla.lower_fun(_all_without_vmap_imp, multiple_results=False, new_style=True)) -mlir.register_lowering(_all_no_vmap_prim, mlir.lower_fun(_all_without_vmap_imp, multiple_results=False)) + return unvmap(x, op) diff --git a/brainpy/math/tests/test_random.py b/brainpy/math/tests/test_random.py index a7e09d498..3d039a7b0 100644 --- a/brainpy/math/tests/test_random.py +++ b/brainpy/math/tests/test_random.py @@ -312,9 +312,10 @@ def test_chisquare1(self): self.assertTrue(a.dtype, float) def test_chisquare2(self): + # Array-valued ``df`` is now supported (broadcasts over the batch axis). br.seed() - with self.assertRaises(NotImplementedError): - a = bm.random.chisquare(df=[2, 3, 4]) + a = bm.random.chisquare(df=[2, 3, 4]) + self.assertTupleEqual(a.shape, (3,)) def test_chisquare3(self): br.seed() @@ -451,8 +452,10 @@ def test_rayleigh(self): self.assertTupleEqual(a.shape, (4, 2)) def test_triangular(self): + # ``triangular`` signature is ``(left, mode, right, size, ...)``; pass the + # output shape via the ``size`` keyword. br.seed() - a = bm.random.triangular((2, 2)) + a = bm.random.triangular(size=(2, 2)) self.assertTupleEqual(a.shape, (2, 2)) def test_vonmises(self): diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000..e3ca3d3f4 --- /dev/null +++ b/conftest.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +"""Pytest configuration shared by both test roots (``tests/`` and ``brainpy/``). + +Force matplotlib onto the non-interactive ``Agg`` backend so that tests which +exercise the analysis/plotting code paths (e.g. phase-plane and bifurcation +analyses that call ``pyplot.show()``) never try to open a GUI window. This keeps +the suite headless and non-blocking locally and in CI regardless of the +``MPLBACKEND`` environment variable. +""" + +import matplotlib + +matplotlib.use('Agg', force=True) diff --git a/pyproject.toml b/pyproject.toml index 3449c0264..0edf82774 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "numpy>=1.15", "jax", "tqdm", - "brainstate>=0.2.7", + "brainstate>=0.5.1", "brainunit>=0.2.0", "brainevent>=0.0.7", "braintools>=0.0.9", diff --git a/requirements.txt b/requirements.txt index 09fbf6cd4..e02aa230d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ numpy>=1.15 brainunit brainevent>=0.0.7 braintools>=0.0.9 -brainstate>=0.2.7 +brainstate>=0.5.1 brainpy_state>=0.0.3 jax tqdm diff --git a/tests/audit/test_object_transform_fixes.py b/tests/audit/test_object_transform_fixes.py index 9e690216d..5e239dbee 100644 --- a/tests/audit/test_object_transform_fixes.py +++ b/tests/audit/test_object_transform_fixes.py @@ -15,7 +15,8 @@ accept a ``Variable``/``Array`` in ``operands``), H-03 (``for_loop(jit=False)`` with a zero-length pytree operand returns ``[]`` instead of crashing), M-03 (``scan`` returns ``(carry, ys)``), M-05 (``ifelse`` builds mutually - exclusive conditions), M-06 (``while_loop`` body returning ``None`` raises). + exclusive conditions), M-06 (``while_loop`` body returning ``None`` threads the + operands through unchanged so the canonical state-mutation idiom keeps working). * ``function.py`` — ``Partial``/``to_object`` behaviour and L-04 (``function`` emits a ``DeprecationWarning``). * ``_utils.py`` — ``warp_to_no_state_input_output`` strips/restores states. @@ -352,16 +353,18 @@ def body_f(x, y): assert len(res) == 2 -def test_while_loop_body_returning_none_raises(): - """M-06: a ``while_loop`` body that returns ``None`` would freeze the carry - and loop forever -- it must raise a clear ``ValueError`` instead.""" +def test_while_loop_body_returning_none_threads_operands(): + """M-06: a ``while_loop`` body that returns ``None`` mutates ``Variable`` state + in place (the canonical brainpy idiom, e.g. ``SpikeTimeGroup.update`` with empty + ``operands``) and threads the operands through unchanged. brainstate tracks the + mutated state, which drives the loop condition, so it must NOT raise.""" + a = bm.Variable(bm.zeros(1)) - def body(x): - # returns None -> illegal - pass + def body(): + a.value += 1. - with pytest.raises(ValueError): - bm.while_loop(body, lambda x: x < 3., 0.) + bm.while_loop(body, lambda: bm.all(a.value < 3.), ()) + assert float(np.asarray(a.value[0])) == 3.0 def test_ifelse_callable_branches_mutually_exclusive():