Fix CI: replace remove_vmap with brainstate.transform.unvmap#831
Merged
Conversation
brainpy/math/remove_vmap.py defined a custom batching primitive whose
batching rule returned jax.interpreters.batching.not_mapped. jax >= 0.10.2
removed that attribute, raising
AttributeError: module 'jax.interpreters.batching' has no attribute 'not_mapped'
which broke both CI workflows (Continuous Integration and Continuous
Integration with Models). The failures surfaced in
tests/audit/test_math_core_fixes.py::test_remove_vmap_under_vmap_is_global
and tests/simulation/test_net_rate_FHN.py::TestFHN.test1 (via
jit_error -> _cond -> remove_vmap under the vmap in delay_couplings).
brainstate.transform.unvmap (fixed in brainstate 0.5.1) is the maintained
equivalent; it resolves the sentinel compatibly via
getattr(batching, 'not_mapped', None) and keeps identical global-reduction
semantics and signature.
- check.py: both call sites now use brainstate.transform.unvmap directly.
- math/remove_vmap.py: reduced to a thin backward-compatible alias that
unwraps brainpy.math.Array then delegates to unvmap; removes the broken
in-tree primitive. Existing test_remove_vmap_* regression suite stays green.
- requirements.txt / pyproject.toml: require brainstate>=0.5.1.
Reviewer's GuideReplaces the custom in-tree remove_vmap batching primitive with the maintained brainstate.transform.unvmap helper, updates internal call sites to use unvmap directly, and bumps the brainstate dependency to a version that supports this behavior, fixing CI failures on newer JAX versions. Sequence diagram for jit_error_checking_no_args using brainstate.transform.unvmapsequenceDiagram
participant Caller
participant jit_error_checking_no_args
participant as_jax
participant unvmap
participant cond
participant jax_pure_callback
Caller->>jit_error_checking_no_args: jit_error_checking_no_args(pred, err)
jit_error_checking_no_args->>as_jax: as_jax(pred)
as_jax-->>jit_error_checking_no_args: pred_jax
jit_error_checking_no_args->>unvmap: unvmap(pred_jax)
unvmap-->>jit_error_checking_no_args: global_pred
jit_error_checking_no_args->>cond: cond(global_pred, true_branch, false_branch)
alt [global_pred is True]
cond->>jax_pure_callback: jax.pure_callback(true_err_fun, None)
else [global_pred is False]
cond->>Caller: return None
end
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
There was a problem hiding this comment.
Hey - I've left some high level feedback:
- In
brainpy.check, consider consistently using thebrainpy.math.remove_vmapalias instead of importingbrainstate.transform.unvmapdirectly so the implementation can be swapped centrally without touching multiple call sites. - The new
unvmapimports inside_condandjit_error_checking_no_argsare done within the functions; moving these imports to module scope would simplify the code path and avoid repeated imports at runtime.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In `brainpy.check`, consider consistently using the `brainpy.math.remove_vmap` alias instead of importing `brainstate.transform.unvmap` directly so the implementation can be swapped centrally without touching multiple call sites.
- The new `unvmap` imports inside `_cond` and `jit_error_checking_no_args` are done within the functions; moving these imports to module scope would simplify the code path and avoid repeated imports at runtime.Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
These 9 in-package test failures pre-existed on master once brainstate 0.5.1 is installed; they are unrelated to remove_vmap. while_loop (controls.py): brainstate 0.5.1's while_loop is driven by tracked State, so the canonical brainpy idiom -- body mutates Variable state in place and returns None (often with empty operands, e.g. SpikeTimeGroup.update) -- is valid again. Restore it: when body_fun returns None, thread the operands through unchanged instead of raising. Fixes test_SpikeTimeGroup (input/ input_groups), TestAlpha::test_v1, TestDualExpon::test_dual_expon[_v2], TestWhile::test3. test_autograd.py::TestDebug::test_debug1: RandomState.value is now a typed JAX PRNG key (key<fry>) which rejects arithmetic; differentiate through the random draw and input only (drop 'a.value +'). test_random.py: chisquare now supports array-valued df (assert shape instead of expecting NotImplementedError); triangular signature is (left, mode, right, size, ...) so pass the shape via size=.
- tests/audit/test_object_transform_fixes.py: #830 added M-06 asserting a None-returning while_loop body raises, but that broke the canonical brainpy idiom (body mutates Variable state, returns None) used by real models like SpikeTimeGroup. Now that the wrapper threads operands through unchanged, rewrite the test to assert that idiom works (state-driven termination) and update the module docstring. - conftest.py (new, repo root): force matplotlib onto the non-interactive Agg backend for both test roots (tests/ and brainpy/) so analysis tests that call pyplot.show() never open GUI windows, locally or in CI. - CI-models.yml: set MPLBACKEND=Agg on the pytest steps to match CI.yml.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fixes both failing CI workflows on jax
0.10.2+ brainstate0.5.1:pytest tests/) — broken by theremove_vmapbatching.not_mappedremoval.pytest brainpy/) — 9 additional pre-existing failures from brainstate 0.5.1 API drift.1. remove_vmap → brainstate.transform.unvmap
brainpy/math/remove_vmap.pydefined a custom batching primitive whose batching rule returnedjax.interpreters.batching.not_mapped, which jax>= 0.10.2removed:brainstate.transform.unvmap(fixed in brainstate 0.5.1, resolves the sentinel viagetattr(batching, 'not_mapped', None)) is the maintained equivalent with identical global-reduction semantics.check.py: both call sites usebrainstate.transform.unvmapdirectly.math/remove_vmap.py: thin backward-compatible alias delegating tounvmap(unwrapsArray); removes the broken in-tree primitive. Existingtest_remove_vmap_*regression suite stays green.requirements.txt/pyproject.toml: requirebrainstate>=0.5.1.Fixes
test_math_core_fixes.py::test_remove_vmap_under_vmap_is_globalandtest_net_rate_FHN.py::TestFHN.test1.2. brainstate 0.5.1 API-drift failures in
pytest brainpy/(9 tests)while_loop(controls.py) — brainstate'swhile_loopis driven by tracked State, so the canonical brainpy idiom (body mutatesVariablestate in place and returnsNone, often with empty operands, e.g.SpikeTimeGroup.update) is valid. Whenbody_funreturnsNone, thread operands through unchanged instead of raising. Fixestest_SpikeTimeGroup(input/input_groups),TestAlpha::test_v1,TestDualExpon::test_dual_expon[_v2],TestWhile::test3.test_autograd.py::TestDebug::test_debug1—RandomState.valueis now a typed PRNG key (key<fry>) that rejects arithmetic; differentiate through the random draw + input only.test_random.py—chisquarenow supports array-valueddf(assert shape instead of expectingNotImplementedError);triangularsignature is(left, mode, right, size, ...)so pass shape viasize=.Verification
pytest tests/(audit + simulation): 1026 passed.brainpy/files: 167 passed. Fullpytest brainpy/sweep in progress.batching.not_mappedusage remains inbrainpy/.