Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/CI-models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ jobs:
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip install -e .
- name: Test with pytest
env:
MPLBACKEND: Agg # Use non-interactive backend for matplotlib
run: |
pytest tests/

Expand All @@ -58,6 +60,8 @@ jobs:
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip install -e .
- name: Test with pytest
env:
MPLBACKEND: Agg # Use non-interactive backend for matplotlib
run: |
pytest tests/

Expand All @@ -80,5 +84,7 @@ jobs:
python -m pip install -r requirements-dev.txt
pip install -e .
- name: Test with pytest
env:
MPLBACKEND: Agg # Use non-interactive backend for matplotlib
run: |
python -m pytest tests/
8 changes: 4 additions & 4 deletions brainpy/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,13 +595,13 @@ def _err_jit_false_branch(x):


def _cond(err_fun, pred, err_arg):
from brainpy.math.remove_vmap import remove_vmap
from brainstate.transform import unvmap

@wraps(err_fun)
def true_err_fun(arg, transforms):
err_fun(arg)

cond(remove_vmap(pred),
cond(unvmap(pred),
partial(_err_jit_true_branch, true_err_fun),
_err_jit_false_branch,
err_arg)
Expand Down Expand Up @@ -636,14 +636,14 @@ def jit_error_checking_no_args(pred: bool, err: Exception):
err: Exception
The error.
"""
from brainpy.math.remove_vmap import remove_vmap
from brainstate.transform import unvmap
from brainpy.math.interoperability import as_jax

assert isinstance(err, Exception), 'Must be instance of Exception.'

def true_err_fun(arg, transforms):
raise err

cond(remove_vmap(as_jax(pred)),
cond(unvmap(as_jax(pred)),
lambda: jax.pure_callback(true_err_fun, None),
lambda: None)
12 changes: 7 additions & 5 deletions brainpy/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,11 +597,13 @@ def while_loop(
def body(x):
r = body_fun(*x)
if r is None:
raise ValueError(
'`body_fun` of `while_loop` must return the updated operands, '
'but got `None`. Returning `None` would leave the operands unchanged '
'and the loop condition would never become False, causing an infinite loop.'
)
# Classic brainpy idiom: ``body_fun`` mutates ``Variable`` state in place
# and returns ``None`` (often with empty ``operands``). brainstate's
# ``while_loop`` tracks that state automatically and the loop condition is
# driven by the mutated state, so the operands are threaded through
# unchanged. Returning ``x`` preserves this behaviour while still allowing
# a functional ``body_fun`` to return the updated operands explicitly.
return x
return r

return brainstate.transform.while_loop(
Expand Down
5 changes: 4 additions & 1 deletion brainpy/math/object_transform/tests/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,8 +1026,11 @@ def test_debug1(self):
a = bm.random.RandomState()

def f(b):
# ``a.value`` is a typed JAX PRNG key (``key<fry>``) under modern JAX and
# cannot be used in arithmetic; read it (print) but differentiate through
# the random draw and the input ``b`` only.
print(a.value)
return a.value + b + a.random()
return b + a.random()

f = bm.vector_grad(f, argnums=0)
f(1.)
Expand Down
109 changes: 33 additions & 76 deletions brainpy/math/remove_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import jax
import jax.numpy as jnp
from brainstate.transform import unvmap

from brainstate._compatible_import import Primitive
from jax.core import ShapedArray
from jax.interpreters import batching, mlir, xla
from .ndarray import Array

__all__ = [
Expand All @@ -29,16 +25,19 @@
def remove_vmap(x, op='any'):
"""Reduce ``x`` with ``any``/``all`` *across the vmap batch axis as well*.

This is a custom primitive whose batching rule deliberately collapses the
batch axis into a single **global** scalar. That is, when called under
:func:`jax.vmap`, ``remove_vmap(x, 'any')`` returns one ``bool`` summarising
*all* batch elements together (``True`` if any element of any batch is
truthy), rather than a per-batch vector of results.
This is a thin backward-compatible alias for
:func:`brainstate.transform.unvmap`, which is the actively maintained
implementation. ``unvmap`` collapses the batch axis into a single
**global** scalar: when called under :func:`jax.vmap`,
``remove_vmap(x, 'any')`` returns one ``bool`` summarising *all* batch
elements together (``True`` if any element of any batch is truthy), rather
than a per-batch vector of results.

This is intentional: the primitive is used for global convergence / NaN-style
checks where the batch dimension must not survive the reduction. The batching
rule returns :data:`jax.interpreters.batching.not_mapped`, so the output is a
genuine unbatched scalar (it is *not* broadcast back across the batch axis).
checks where the batch dimension must not survive the reduction. Delegating to
:func:`brainstate.transform.unvmap` keeps BrainPy compatible across JAX
releases (jax ``>= 0.10`` removed ``jax.interpreters.batching.not_mapped``,
which the previous in-tree primitive relied on).

Parameters
----------
Expand All @@ -52,69 +51,27 @@ def remove_vmap(x, op='any'):
jax.Array
A scalar boolean. Under :func:`jax.vmap` it is a single global scalar,
not a per-batch result.

Raises
------
ValueError
If ``op`` is not supported by :func:`brainstate.transform.unvmap`.

See Also
--------
brainstate.transform.unvmap

Examples
--------
.. code-block:: python

>>> import jax.numpy as jnp
>>> from brainpy.math.remove_vmap import remove_vmap
>>> bool(remove_vmap(jnp.array([False, True])))
True
>>> bool(remove_vmap(jnp.array([True, False]), 'all'))
False
"""
if isinstance(x, Array):
x = x.value
if op == 'any':
return _any_without_vmap(x)
elif op == 'all':
return _all_without_vmap(x)
else:
raise ValueError(f'Do not support type: {op}')


_any_no_vmap_prim = Primitive('any_no_vmap')


def _any_without_vmap(x):
return _any_no_vmap_prim.bind(x)


def _any_without_vmap_imp(x):
return jnp.any(x)


def _any_without_vmap_abs(x):
return ShapedArray(shape=(), dtype=jnp.bool_)


def _any_without_vmap_batch(x, batch_axes):
(x,) = x
return _any_without_vmap(x), batching.not_mapped


_any_no_vmap_prim.def_impl(_any_without_vmap_imp)
_any_no_vmap_prim.def_abstract_eval(_any_without_vmap_abs)
batching.primitive_batchers[_any_no_vmap_prim] = _any_without_vmap_batch
if hasattr(xla, "lower_fun"):
xla.register_translation(_any_no_vmap_prim,
xla.lower_fun(_any_without_vmap_imp, multiple_results=False, new_style=True))
mlir.register_lowering(_any_no_vmap_prim, mlir.lower_fun(_any_without_vmap_imp, multiple_results=False))

_all_no_vmap_prim = Primitive('all_no_vmap')


def _all_without_vmap(x):
return _all_no_vmap_prim.bind(x)


def _all_without_vmap_imp(x):
return jnp.all(x)


def _all_without_vmap_abs(x):
return ShapedArray(shape=(), dtype=jnp.bool_)


def _all_without_vmap_batch(x, batch_axes):
(x,) = x
return _all_without_vmap(x), batching.not_mapped


_all_no_vmap_prim.def_impl(_all_without_vmap_imp)
_all_no_vmap_prim.def_abstract_eval(_all_without_vmap_abs)
batching.primitive_batchers[_all_no_vmap_prim] = _all_without_vmap_batch
if hasattr(xla, "lower_fun"):
xla.register_translation(_all_no_vmap_prim,
xla.lower_fun(_all_without_vmap_imp, multiple_results=False, new_style=True))
mlir.register_lowering(_all_no_vmap_prim, mlir.lower_fun(_all_without_vmap_imp, multiple_results=False))
return unvmap(x, op)
9 changes: 6 additions & 3 deletions brainpy/math/tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,10 @@ def test_chisquare1(self):
self.assertTrue(a.dtype, float)

def test_chisquare2(self):
# Array-valued ``df`` is now supported (broadcasts over the batch axis).
br.seed()
with self.assertRaises(NotImplementedError):
a = bm.random.chisquare(df=[2, 3, 4])
a = bm.random.chisquare(df=[2, 3, 4])
self.assertTupleEqual(a.shape, (3,))

def test_chisquare3(self):
br.seed()
Expand Down Expand Up @@ -451,8 +452,10 @@ def test_rayleigh(self):
self.assertTupleEqual(a.shape, (4, 2))

def test_triangular(self):
# ``triangular`` signature is ``(left, mode, right, size, ...)``; pass the
# output shape via the ``size`` keyword.
br.seed()
a = bm.random.triangular((2, 2))
a = bm.random.triangular(size=(2, 2))
self.assertTupleEqual(a.shape, (2, 2))

def test_vonmises(self):
Expand Down
13 changes: 13 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
"""Pytest configuration shared by both test roots (``tests/`` and ``brainpy/``).

Force matplotlib onto the non-interactive ``Agg`` backend so that tests which
exercise the analysis/plotting code paths (e.g. phase-plane and bifurcation
analyses that call ``pyplot.show()``) never try to open a GUI window. This keeps
the suite headless and non-blocking locally and in CI regardless of the
``MPLBACKEND`` environment variable.
"""

import matplotlib

matplotlib.use('Agg', force=True)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ dependencies = [
"numpy>=1.15",
"jax",
"tqdm",
"brainstate>=0.2.7",
"brainstate>=0.5.1",
"brainunit>=0.2.0",
"brainevent>=0.0.7",
"braintools>=0.0.9",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ numpy>=1.15
brainunit
brainevent>=0.0.7
braintools>=0.0.9
brainstate>=0.2.7
brainstate>=0.5.1
brainpy_state>=0.0.3
jax
tqdm
21 changes: 12 additions & 9 deletions tests/audit/test_object_transform_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
accept a ``Variable``/``Array`` in ``operands``), H-03 (``for_loop(jit=False)``
with a zero-length pytree operand returns ``[]`` instead of crashing),
M-03 (``scan`` returns ``(carry, ys)``), M-05 (``ifelse`` builds mutually
exclusive conditions), M-06 (``while_loop`` body returning ``None`` raises).
exclusive conditions), M-06 (``while_loop`` body returning ``None`` threads the
operands through unchanged so the canonical state-mutation idiom keeps working).
* ``function.py`` — ``Partial``/``to_object`` behaviour and L-04 (``function``
emits a ``DeprecationWarning``).
* ``_utils.py`` — ``warp_to_no_state_input_output`` strips/restores states.
Expand Down Expand Up @@ -352,16 +353,18 @@ def body_f(x, y):
assert len(res) == 2


def test_while_loop_body_returning_none_raises():
"""M-06: a ``while_loop`` body that returns ``None`` would freeze the carry
and loop forever -- it must raise a clear ``ValueError`` instead."""
def test_while_loop_body_returning_none_threads_operands():
"""M-06: a ``while_loop`` body that returns ``None`` mutates ``Variable`` state
in place (the canonical brainpy idiom, e.g. ``SpikeTimeGroup.update`` with empty
``operands``) and threads the operands through unchanged. brainstate tracks the
mutated state, which drives the loop condition, so it must NOT raise."""
a = bm.Variable(bm.zeros(1))

def body(x):
# returns None -> illegal
pass
def body():
a.value += 1.

with pytest.raises(ValueError):
bm.while_loop(body, lambda x: x < 3., 0.)
bm.while_loop(body, lambda: bm.all(a.value < 3.), ())
assert float(np.asarray(a.value[0])) == 3.0


def test_ifelse_callable_branches_mutually_exclusive():
Expand Down
Loading