Skip to content

Fix CI: replace remove_vmap with brainstate.transform.unvmap#831

Merged
chaoming0625 merged 3 commits into
masterfrom
worktree-fix-remove-vmap-unvmap
Jun 18, 2026
Merged

Fix CI: replace remove_vmap with brainstate.transform.unvmap#831
chaoming0625 merged 3 commits into
masterfrom
worktree-fix-remove-vmap-unvmap

Conversation

@chaoming0625

@chaoming0625 chaoming0625 commented Jun 18, 2026

Copy link
Copy Markdown
Member

Summary

Fixes both failing CI workflows on jax 0.10.2 + brainstate 0.5.1:

  • Continuous Integration with Models (pytest tests/) — broken by the remove_vmap batching.not_mapped removal.
  • Continuous Integration (pytest brainpy/) — 9 additional pre-existing failures from brainstate 0.5.1 API drift.

1. remove_vmap → brainstate.transform.unvmap

brainpy/math/remove_vmap.py defined a custom batching primitive whose batching rule returned jax.interpreters.batching.not_mapped, which jax >= 0.10.2 removed:

AttributeError: module 'jax.interpreters.batching' has no attribute 'not_mapped'

brainstate.transform.unvmap (fixed in brainstate 0.5.1, resolves the sentinel via getattr(batching, 'not_mapped', None)) is the maintained equivalent with identical global-reduction semantics.

  • check.py: both call sites use brainstate.transform.unvmap directly.
  • math/remove_vmap.py: thin backward-compatible alias delegating to unvmap (unwraps Array); removes the broken in-tree primitive. Existing test_remove_vmap_* regression suite stays green.
  • requirements.txt / pyproject.toml: require brainstate>=0.5.1.

Fixes test_math_core_fixes.py::test_remove_vmap_under_vmap_is_global and test_net_rate_FHN.py::TestFHN.test1.

2. brainstate 0.5.1 API-drift failures in pytest brainpy/ (9 tests)

  • while_loop (controls.py) — brainstate's while_loop is driven by tracked State, so the canonical brainpy idiom (body mutates Variable state in place and returns None, often with empty operands, e.g. SpikeTimeGroup.update) is valid. When body_fun returns None, thread operands through unchanged instead of raising. Fixes test_SpikeTimeGroup (input/input_groups), TestAlpha::test_v1, TestDualExpon::test_dual_expon[_v2], TestWhile::test3.
  • test_autograd.py::TestDebug::test_debug1RandomState.value is now a typed PRNG key (key<fry>) that rejects arithmetic; differentiate through the random draw + input only.
  • test_random.pychisquare now supports array-valued df (assert shape instead of expecting NotImplementedError); triangular signature is (left, mode, right, size, ...) so pass shape via size=.

Verification

  • All 11 originally-failing tests pass.
  • pytest tests/ (audit + simulation): 1026 passed.
  • Affected brainpy/ files: 167 passed. Full pytest brainpy/ sweep in progress.
  • No batching.not_mapped usage remains in brainpy/.

brainpy/math/remove_vmap.py defined a custom batching primitive whose
batching rule returned jax.interpreters.batching.not_mapped. jax >= 0.10.2
removed that attribute, raising

    AttributeError: module 'jax.interpreters.batching' has no attribute 'not_mapped'

which broke both CI workflows (Continuous Integration and Continuous
Integration with Models). The failures surfaced in
tests/audit/test_math_core_fixes.py::test_remove_vmap_under_vmap_is_global
and tests/simulation/test_net_rate_FHN.py::TestFHN.test1 (via
jit_error -> _cond -> remove_vmap under the vmap in delay_couplings).

brainstate.transform.unvmap (fixed in brainstate 0.5.1) is the maintained
equivalent; it resolves the sentinel compatibly via
getattr(batching, 'not_mapped', None) and keeps identical global-reduction
semantics and signature.

- check.py: both call sites now use brainstate.transform.unvmap directly.
- math/remove_vmap.py: reduced to a thin backward-compatible alias that
  unwraps brainpy.math.Array then delegates to unvmap; removes the broken
  in-tree primitive. Existing test_remove_vmap_* regression suite stays green.
- requirements.txt / pyproject.toml: require brainstate>=0.5.1.
@sourcery-ai

sourcery-ai Bot commented Jun 18, 2026

Copy link
Copy Markdown

Reviewer's Guide

Replaces the custom in-tree remove_vmap batching primitive with the maintained brainstate.transform.unvmap helper, updates internal call sites to use unvmap directly, and bumps the brainstate dependency to a version that supports this behavior, fixing CI failures on newer JAX versions.

Sequence diagram for jit_error_checking_no_args using brainstate.transform.unvmap

sequenceDiagram
  participant Caller
  participant jit_error_checking_no_args
  participant as_jax
  participant unvmap
  participant cond
  participant jax_pure_callback

  Caller->>jit_error_checking_no_args: jit_error_checking_no_args(pred, err)
  jit_error_checking_no_args->>as_jax: as_jax(pred)
  as_jax-->>jit_error_checking_no_args: pred_jax
  jit_error_checking_no_args->>unvmap: unvmap(pred_jax)
  unvmap-->>jit_error_checking_no_args: global_pred
  jit_error_checking_no_args->>cond: cond(global_pred, true_branch, false_branch)
  alt [global_pred is True]
    cond->>jax_pure_callback: jax.pure_callback(true_err_fun, None)
  else [global_pred is False]
    cond->>Caller: return None
  end
Loading

File-Level Changes

Change Details Files
Replace in-tree remove_vmap primitive with a thin wrapper around brainstate.transform.unvmap for JAX compatibility.
  • Remove custom any/all primitives and their batching, abstract eval, and lowering implementations that depended on jax.interpreters.batching.not_mapped.
  • Introduce brainstate.transform.unvmap import and delegate remove_vmap(x, op) to unvmap after unwrapping brainpy.math.Array.
  • Update remove_vmap docstring to describe the delegation to unvmap, its semantics, supported ops, and examples.
brainpy/math/remove_vmap.py
Update internal error-checking paths to use unvmap directly instead of remove_vmap.
  • Change _cond helper to import unvmap from brainstate.transform and call cond(unvmap(pred), ...) instead of cond(remove_vmap(pred), ...).
  • Change jit_error_checking_no_args to import unvmap and call cond(unvmap(as_jax(pred)), ...) instead of using remove_vmap.
brainpy/check.py
Bump brainstate dependency to a version that provides the fixed unvmap implementation.
  • Increase brainstate minimum version requirement from >=0.2.7 to >=0.5.1 in pyproject configuration.
  • Mirror the brainstate version bump in requirements.txt to keep development and runtime dependencies aligned.
pyproject.toml
requirements.txt

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

@github-actions github-actions Bot added dependencies Pull requests that update a dependency file build labels Jun 18, 2026

@sourcery-ai sourcery-ai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've left some high level feedback:

  • In brainpy.check, consider consistently using the brainpy.math.remove_vmap alias instead of importing brainstate.transform.unvmap directly so the implementation can be swapped centrally without touching multiple call sites.
  • The new unvmap imports inside _cond and jit_error_checking_no_args are done within the functions; moving these imports to module scope would simplify the code path and avoid repeated imports at runtime.
Prompt for AI Agents
Please address the comments from this code review:

## Overall Comments
- In `brainpy.check`, consider consistently using the `brainpy.math.remove_vmap` alias instead of importing `brainstate.transform.unvmap` directly so the implementation can be swapped centrally without touching multiple call sites.
- The new `unvmap` imports inside `_cond` and `jit_error_checking_no_args` are done within the functions; moving these imports to module scope would simplify the code path and avoid repeated imports at runtime.

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

These 9 in-package test failures pre-existed on master once brainstate 0.5.1
is installed; they are unrelated to remove_vmap.

while_loop (controls.py): brainstate 0.5.1's while_loop is driven by tracked
State, so the canonical brainpy idiom -- body mutates Variable state in place
and returns None (often with empty operands, e.g. SpikeTimeGroup.update) -- is
valid again. Restore it: when body_fun returns None, thread the operands
through unchanged instead of raising. Fixes test_SpikeTimeGroup (input/
input_groups), TestAlpha::test_v1, TestDualExpon::test_dual_expon[_v2],
TestWhile::test3.

test_autograd.py::TestDebug::test_debug1: RandomState.value is now a typed JAX
PRNG key (key<fry>) which rejects arithmetic; differentiate through the random
draw and input only (drop 'a.value +').

test_random.py: chisquare now supports array-valued df (assert shape instead of
expecting NotImplementedError); triangular signature is (left, mode, right,
size, ...) so pass the shape via size=.
@github-actions github-actions Bot added the tests label Jun 18, 2026
- tests/audit/test_object_transform_fixes.py: #830 added M-06 asserting a
  None-returning while_loop body raises, but that broke the canonical brainpy
  idiom (body mutates Variable state, returns None) used by real models like
  SpikeTimeGroup. Now that the wrapper threads operands through unchanged,
  rewrite the test to assert that idiom works (state-driven termination) and
  update the module docstring.
- conftest.py (new, repo root): force matplotlib onto the non-interactive Agg
  backend for both test roots (tests/ and brainpy/) so analysis tests that call
  pyplot.show() never open GUI windows, locally or in CI.
- CI-models.yml: set MPLBACKEND=Agg on the pytest steps to match CI.yml.
@github-actions github-actions Bot added the ci-cd label Jun 18, 2026
@chaoming0625 chaoming0625 merged commit bf077f5 into master Jun 18, 2026
18 checks passed
@chaoming0625 chaoming0625 deleted the worktree-fix-remove-vmap-unvmap branch June 18, 2026 09:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

build ci-cd dependencies Pull requests that update a dependency file tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant