fix(running): jit-dict mutation, dict string monitors, multi-device parallel concat#852

Merged
chaoming0625 merged 1 commit into master from fix/audit-20260619-train-running on Jun 18, 2026

Conversation

@chaoming0625 chaoming0625 (Member) commented Jun 18, 2026

Fresh review of brainpy/train + brainpy/running.

  • High: Runner.__init__ popped 'predict' from the caller's jit dict, mutating it and corrupting subclass jit config (predict JIT-on even when disabled); now copies.
  • High: dict-form string monitors (monitors={'a': 'V'}) were never resolved (stored the literal string) and invalid names weren't validated; fixed resolution + validation + result key.
  • High: jax_parallelize_map crashed concatenating a trailing partial chunk sharded on a device subset; now gathers to host before concat.
  • Medium: process_pool_lock mutated caller param dicts via .update(); now submits a copy.

(Recorded for follow-up: BPTT i0 pinning under reset_state=False continuation — cross-cutting, left unfixed. RLS/ridge/GD math lives in algorithms/, handled by P16.) In-scope: 108 passed. Findings: docs/issues-found-20260619-train-running.md.
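
The first item is easiest to see in isolation. The following is a minimal sketch, not the actual `Runner` code: `PoppingRunner`, `CopyingRunner`, and `subclass_view` are hypothetical stand-ins, and the default of `True` for a missing `'predict'` key is an assumption made for illustration.

```python
class PoppingRunner:
    """Hypothetical stand-in for the pre-fix behaviour."""

    def __init__(self, jit):
        self._origin_jit = jit  # same object the caller passed in
        # pop() removes 'predict' from the caller's dict as a side effect.
        self.jit = {'predict': jit.pop('predict', True)}
        self.jit.update(jit)


class CopyingRunner:
    """Hypothetical stand-in for the fix: work on a shallow copy."""

    def __init__(self, jit):
        self._origin_jit = jit  # untouched, so subclasses can still read it
        jit = dict(jit)         # shallow copy; the caller's dict is never popped
        self.jit = {'predict': jit.pop('predict', True)}
        self.jit.update(jit)


def subclass_view(runner):
    # Mimics a subclass reading the user's predict setting after __init__.
    return runner._origin_jit.get('predict', True)


before_fix = {'predict': False, 'fit': True}
popped_view = subclass_view(PoppingRunner(before_fix))

after_fix = {'predict': False, 'fit': True}
copied_view = subclass_view(CopyingRunner(after_fix))
```

With the popping version the caller's dict loses its `'predict'` key, so the subclass falls back to the default and sees predict JIT as enabled. With the copy, both the caller's dict and the subclass's view keep `False`.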

Summary by Sourcery

Fix multiple correctness issues in training/running utilities around JIT configuration, monitor resolution, and JAX multi-device parallelism, and document the audit findings.

Bug Fixes:

  • Prevent Runner from mutating caller-provided JIT configuration dicts and ensure the original predict-phase setting remains visible to subclasses.
  • Resolve dict-form string monitors to their target variables (including dotted names and indexed forms), validate invalid names, and keep monitor results keyed by the user-specified name.
  • Avoid device-placement errors in jax_parallelize_map by gathering per-chunk outputs to host before concatenation, including for trailing partial chunks on a subset of devices.
  • Stop process_pool_lock from mutating caller parameter dicts when injecting the multiprocessing lock.
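
The "trailing partial chunk" in the third item appears whenever the number of inputs is not a multiple of the chunk size. A small sketch, assuming inputs are split into consecutive chunks of `num_parallel` elements (`chunk_sizes` is a hypothetical helper, not part of BrainPy):

```python
def chunk_sizes(num_inputs, num_parallel):
    """Sizes of consecutive chunks when splitting num_inputs items."""
    return [
        min(num_parallel, num_inputs - start)
        for start in range(0, num_inputs, num_parallel)
    ]


with_partial = chunk_sizes(10, 4)     # last chunk is smaller than the rest
without_partial = chunk_sizes(8, 4)   # inputs divide evenly
```

With 10 inputs and a chunk size of 4, the sizes are 4, 4, and 2, so the last chunk would occupy only 2 of 4 devices. With 8 inputs every chunk is full and the placement conflict would not arise.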

Enhancements:

  • Refine dict-monitor target resolution so name-based tuples are correctly resolved while already-resolved values pass through unchanged.
  • Standardize jax_parallelize_map outputs to concatenate on NumPy arrays and convert back to BrainPy arrays when buffers are not cleared.
  • Add regression and coverage tests for Runner JIT dict handling, dict-form monitors, JAX vmap/pmap chunking (including multi-device partial chunks), and native multiprocessing lock handling.
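
The dict-monitor behaviour described above can be sketched roughly as follows. This is an illustration only: `resolve_dict_monitors`, `_lookup`, and the local `MonitorError` class are hypothetical, and plain attribute lookup stands in for however the library actually locates Variables on the target.

```python
from types import SimpleNamespace


class MonitorError(Exception):
    """Local stand-in for the library's monitor error type."""


def _lookup(target, dotted):
    # Walk a dotted name such as 'layer.spike' one attribute at a time.
    obj = target
    for part in dotted.split('.'):
        if not hasattr(obj, part):
            raise MonitorError(f"Cannot find {dotted!r} on the target.")
        obj = getattr(obj, part)
    return obj


def resolve_dict_monitors(target, monitors):
    resolved = {}
    for key, mon in monitors.items():
        if isinstance(mon, str):
            # Plain name: resolve it, no index.
            resolved[key] = (_lookup(target, mon), None)
        elif (isinstance(mon, (tuple, list)) and len(mon) == 2
              and isinstance(mon[0], str)):
            # (name, index): resolve the name and keep the index.
            resolved[key] = (_lookup(target, mon[0]), mon[1])
        else:
            # Already-resolved values pass through unchanged.
            resolved[key] = mon
    return resolved


net = SimpleNamespace(V=[0.1, 0.2, 0.3],
                      layer=SimpleNamespace(spike=[0, 1, 0]))
mons = resolve_dict_monitors(net, {'a': 'V', 'b': ('layer.spike', 1)})
```

The results stay keyed by the user's names (`'a'`, `'b'`) rather than by the variable names, and an unknown name raises instead of being stored as a literal string.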

Documentation:

  • Add an audit report documenting identified issues and their status for brainpy/train and brainpy/running modules.

…ice parallel concat

- Runner.__init__ popped 'predict' from the caller's jit dict, mutating it and
  corrupting subclass jit config (predict JIT-on when user set it off); copy the dict (High)
- dict-form string monitors (monitors={'a': 'V'}) were never resolved and invalid
  names not validated; fix the resolution branch + validation + result key (High)
- jax_parallelize_map crashed concatenating a trailing partial chunk sharded on a
  device subset ("incompatible devices"); gather chunks to host before concat (High)
- process_pool_lock injected the lock into the caller's param dicts via update();
  submit a copy (Medium)

Findings recorded in docs/issues-found-20260619-train-running.md
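
The gather-then-concatenate pattern from the third item can be sketched as below. This is a simplified illustration under stated assumptions: `map_in_chunks` and `double_and_shift` are hypothetical, and plain NumPy arrays stand in for device arrays, so the sketch shows the bookkeeping but does not reproduce the "incompatible devices" failure itself.

```python
import numpy as np


def map_in_chunks(func, xs, chunk_size):
    """Apply func chunk by chunk and concatenate each output on the host."""
    per_output = None
    for start in range(0, len(xs), chunk_size):
        outputs = func(xs[start:start + chunk_size])  # tuple of arrays
        # Pull every output to a host NumPy array before storing it.
        host_outputs = [np.asarray(out) for out in outputs]
        if per_output is None:
            per_output = [[out] for out in host_outputs]
        else:
            for j, out in enumerate(host_outputs):
                per_output[j].append(out)
    if per_output is None:  # no inputs, nothing to concatenate
        return ()
    return tuple(np.concatenate(parts, axis=0) for parts in per_output)


def double_and_shift(x):
    # Toy per-chunk function returning two outputs.
    return x * 2.0, x + 1.0


# Five inputs with chunk_size=2 give chunks of 2, 2, and a trailing 1.
doubled, shifted = map_in_chunks(double_and_shift, np.arange(5.0), chunk_size=2)
```

Because every stored piece is already a host array, the final concatenation does not depend on where each chunk was computed.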
@chaoming0625 chaoming0625 merged commit 4ec07f4 into master Jun 18, 2026
2 of 5 checks passed
@sourcery-ai sourcery-ai Bot commented Jun 18, 2026

Reviewer's Guide

Fixes several correctness issues in brainpy running infrastructure: prevents unintended mutation of caller dicts in Runner.jit and process_pool_lock, correctly resolves dict-form string monitors to underlying Variables with validation, and makes jax_parallelize_map robust for trailing partial chunks across multiple devices, adding focused regression tests and audit documentation.

File-Level Changes

Change Details Files
Stop Runner.__init__ from mutating caller-provided jit dictionaries and ensure subclasses can still read the original predict setting.
  • In Runner.__init__, when jit is a dict, create a shallow copy before processing to avoid mutating the caller’s dict.
  • Populate self.jit from the copied dict, then derive the predict phase entry without popping from the original.
  • Preserve the caller’s original jit dict in self._origin_jit so subclasses reading self._origin_jit.get('predict') see the user-specified value.
  • Add coverage tests ensuring the caller jit dict remains unchanged and _origin_jit preserves the explicit predict setting.
Files: brainpy/running/runner.py, brainpy/running/runner_coverage_test.py
Fix dict-form string monitors so they resolve to Variables, validate names, and keep the user’s monitor key.
  • Change _find_dict_monitor_targets to treat (name_str, index) pairs as unresolved string monitors that must be looked up on the target, mirroring sequence-form resolution.
  • Resolve dotted names (including nested node names) to the correct Variable and preserve the index, raising RunningError/MonitorError for invalid names.
  • Store resolved monitors under the original dict key (_key) rather than under the variable name so runner.mon[key] works as documented.
  • Update and extend tests to assert proper resolution of simple, nested, and indexed string monitors, and to require errors for unknown variable names.
Files: brainpy/running/runner.py, brainpy/running/runner_coverage_test.py
Make jax_parallelize_map handle trailing partial chunks on a subset of devices by gathering chunk outputs to host before concatenation.
  • Retain a single cached pmap(func) but clarify via comments that pmap will re-trace as needed for partial chunks with smaller leading dimensions.
  • After each chunk run, flatten tree leaves and convert them to NumPy arrays immediately, building per-leaf lists of host arrays regardless of clear_buffer.
  • Always concatenate results using np.concatenate on host arrays, then convert back to bm.asarray when clear_buffer is False to preserve API expectations.
  • Add a dedicated jax_multiprocessing_test module that exercises vectorized and parallel paths, including partial chunks, dict arguments, bad types, and a multi-device regression case in a subprocess configured with multiple host devices.
Files: brainpy/running/jax_multiprocessing.py, brainpy/running/jax_multiprocessing_test.py
Prevent process_pool_lock from mutating caller-owned parameter dicts when injecting the lock argument.
  • Change the dict-handling branch of process_pool_lock to construct a new kwds dict with the lock added, instead of calling update on the original dict.
  • Ensure apply_async is called with the copied dict so caller data structures remain unchanged after the pool run.
  • Add a regression test verifying that the original params list and its dict elements are not modified and that the lock key is absent from caller data.
Files: brainpy/running/native_multiprocessing.py, brainpy/running/native_multiprocessing_coverage_test.py
Record the audit results for the train/running review, including fixed issues and remaining follow-ups.
  • Add an audit markdown file summarizing all identified issues in brainpy/train and brainpy/running, with IDs, severities, locations, and fix status.
  • Document the now-fixed jit dict mutation, dict monitor resolution, jax_parallelize_map partial-chunk crash, and process_pool_lock mutation.
  • Record additional medium/low issues that are left for follow-up, such as BPTT i0 pinning semantics and minor performance/style/documentation items.
Files: docs/issues-found-20260619-train-running.md
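
The `process_pool_lock` change listed above comes down to building a new keyword dict instead of updating the caller's. A minimal sketch, assuming the lock is injected under the key `lock`; `kwds_by_update` and `kwds_by_copy` are hypothetical helpers, and `threading.Lock` stands in for the lock shared with worker processes:

```python
import threading


def kwds_by_update(param, lock):
    # Pre-fix pattern: the caller's dict gains a 'lock' entry.
    param.update(lock=lock)
    return param


def kwds_by_copy(param, lock):
    # Post-fix pattern: build a new dict and leave the caller's dict alone.
    return {**param, 'lock': lock}


lock = threading.Lock()  # stand-in for the real multiprocessing lock

mutated = {'a': 1}
kwds_by_update(mutated, lock)

untouched = {'a': 1}
submitted = kwds_by_copy(untouched, lock)
```

Only the copied dict would be handed to the pool, so the caller's parameter list can be reused or inspected afterwards without a stray lock entry.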

@sourcery-ai sourcery-ai Bot left a comment

Hey - I've found 2 issues, and left some high level feedback:

  • In _find_dict_monitor_targets, the (tuple, list) branch assumes _mon[0] and _mon[1] exist; consider explicitly validating the value length (e.g., exactly 2) and raising a clear error if a user passes a 1‑element tuple/list to avoid confusing index errors.
  • In jax_parallelize_map, you currently rely on np.asarray(val) to pull chunk results to host; using jax.device_get (on the full r pytree) would make the intent clearer and be more robust if non‑array leaves or custom array types are ever introduced.
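
The length check proposed in the first point could look something like this. It is a sketch of the suggestion, not code from the PR: `split_monitor_pair` is a hypothetical helper, and `ValueError` is used here only as a placeholder for whichever error type the library prefers.

```python
def split_monitor_pair(key, mon):
    """Return (name, index) or raise a descriptive error."""
    if not isinstance(mon, (tuple, list)) or len(mon) != 2:
        raise ValueError(
            f"Monitor {key!r} must be a (name, index) pair of length 2, "
            f"but got {mon!r}."
        )
    name, index = mon
    return name, index


try:
    split_monitor_pair('a', ('V',))  # 1-element tuple: rejected up front
    message = None
except ValueError as err:
    message = str(err)
```

Failing early with the monitor key in the message replaces the bare `IndexError` a user would otherwise get from indexing a 1-element tuple.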
Prompt for AI Agents
Please address the comments from this code review:

## Overall Comments
- In `_find_dict_monitor_targets`, the `(tuple, list)` branch assumes `_mon[0]` and `_mon[1]` exist; consider explicitly validating the value length (e.g., exactly 2) and raising a clear error if a user passes a 1‑element tuple/list to avoid confusing index errors.
- In `jax_parallelize_map`, you currently rely on `np.asarray(val)` to pull chunk results to host; using `jax.device_get` (on the full `r` pytree) would make the intent clearer and be more robust if non‑array leaves or custom array types are ever introduced.

## Individual Comments

### Comment 1
<location path="brainpy/running/jax_multiprocessing.py" line_range="149-153" />
<code_context>
+        # conflict (it also serves the ``clear_buffer`` path for free).
         if results is None:
-            results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values)
+            results = tuple([np.asarray(val)] for val in res_values)
         else:
             for j, val in enumerate(res_values):
-                results[j].append(np.asarray(val) if clear_buffer else val)
+                results[j].append(np.asarray(val))
         if clear_buffer:
             bm.clear_buffer_memory()
</code_context>
<issue_to_address>
**suggestion (performance):** Consider whether forcing all chunk outputs through NumPy is acceptable from a performance/placement perspective.

The new logic always converts chunk outputs to host NumPy arrays, concatenates on CPU, and only then converts back to `bm.asarray` when `clear_buffer=False`. Previously, we stayed on device and used `bm.concatenate` in that case. For large outputs or many chunks, this host round‑trip can significantly hurt GPU/TPU throughput.

If on‑device performance matters here, consider only host-converting the problematic partial chunk(s) or using a concat that can handle subsets of devices. Otherwise, it may be worth explicitly documenting that this path now always routes via host memory.

```suggestion
        # Gather each chunk's output to the host.
        #
        # A trailing partial chunk is sharded on only a *subset* of devices, so
        # leaving the outputs on-device makes the final concatenation fail with
        # "Received incompatible devices for jitted computation". Pulling to
        # NumPy first sidesteps the placement conflict (it also serves the
        # ``clear_buffer`` path for free).
        #
        # NOTE: This path now *always* routes chunk outputs through host memory:
        #   - every per-chunk result is converted to a host NumPy array here
        #   - we concatenate on CPU via ``np.concatenate`` below
        #   - when ``clear_buffer=False``, the concatenated results are then
        #     transferred back to device via ``bm.asarray``
        #
        # This trades off device-local concatenation (previous behavior used
        # ``bm.concatenate`` on device arrays when ``clear_buffer=False``) for
        # robustness to heterogeneous sharding, at the cost of an extra
        # host/device round-trip for large outputs or many chunks.
```
</issue_to_address>

### Comment 2
<location path="brainpy/running/jax_multiprocessing_test.py" line_range="86-90" />
<code_context>
+# jax_parallelize_map (pmap)
+# --------------------------------------------------------------------------- #
+
+def test_parallelize_map_single_device():
+    # On a single device num_parallel must be 1; chunks of size 1 each.
+    args = [np.arange(3.0)]
+    r = np.asarray(jax_parallelize_map(_double, args, num_parallel=1))
+    np.testing.assert_allclose(r, np.arange(3.0) * 2.0)
+
+
</code_context>
<issue_to_address>
**suggestion (testing):** Parallel pmap path lacks coverage for the clear_buffer branch.

`jax_parallelize_map` always gathers to NumPy and only wraps back into `bm.asarray` when `clear_buffer=False`. Current tests only exercise this `clear_buffer=False` path. Please add a test variant with `clear_buffer=True`, e.g.:

```python
def test_parallelize_map_single_device_clear_buffer():
    args = [np.arange(3.0)]
    r = jax_parallelize_map(_double, args, num_parallel=1, clear_buffer=True)
    r = np.asarray(r)
    np.testing.assert_allclose(r, np.arange(3.0) * 2.0)
```

This will cover the clear-buffer path and verify the returned array type/shape for the pmap case.

```suggestion
def test_parallelize_map_single_device():
    # On a single device num_parallel must be 1; chunks of size 1 each.
    args = [np.arange(3.0)]
    r = np.asarray(jax_parallelize_map(_double, args, num_parallel=1))
    np.testing.assert_allclose(r, np.arange(3.0) * 2.0)


def test_parallelize_map_single_device_clear_buffer():
    # clear_buffer=True keeps the host-side NumPy result; np.asarray makes the comparison explicit.
    args = [np.arange(3.0)]
    r = jax_parallelize_map(_double, args, num_parallel=1, clear_buffer=True)
    r = np.asarray(r)
    np.testing.assert_allclose(r, np.arange(3.0) * 2.0)
```
</issue_to_address>

Comment on lines +149 to +153
# Gather each chunk's output to the host. A trailing partial chunk is
# sharded on only a *subset* of devices, so leaving the outputs on-device
# makes the final concatenation fail with "Received incompatible devices
# for jitted computation". Pulling to numpy first sidesteps the placement
# conflict (it also serves the ``clear_buffer`` path for free).

suggestion (performance): Consider whether forcing all chunk outputs through NumPy is acceptable from a performance/placement perspective.

The new logic always converts chunk outputs to host NumPy arrays, concatenates on CPU, and only then converts back to bm.asarray when clear_buffer=False. Previously, we stayed on device and used bm.concatenate in that case. For large outputs or many chunks, this host round‑trip can significantly hurt GPU/TPU throughput.

If on‑device performance matters here, consider only host-converting the problematic partial chunk(s) or using a concat that can handle subsets of devices. Otherwise, it may be worth explicitly documenting that this path now always routes via host memory.

Suggested change
-        # Gather each chunk's output to the host. A trailing partial chunk is
-        # sharded on only a *subset* of devices, so leaving the outputs on-device
-        # makes the final concatenation fail with "Received incompatible devices
-        # for jitted computation". Pulling to numpy first sidesteps the placement
-        # conflict (it also serves the ``clear_buffer`` path for free).
+        # Gather each chunk's output to the host.
+        #
+        # A trailing partial chunk is sharded on only a *subset* of devices, so
+        # leaving the outputs on-device makes the final concatenation fail with
+        # "Received incompatible devices for jitted computation". Pulling to
+        # NumPy first sidesteps the placement conflict (it also serves the
+        # ``clear_buffer`` path for free).
+        #
+        # NOTE: This path now *always* routes chunk outputs through host memory:
+        #   - every per-chunk result is converted to a host NumPy array here
+        #   - we concatenate on CPU via ``np.concatenate`` below
+        #   - when ``clear_buffer=False``, the concatenated results are then
+        #     transferred back to device via ``bm.asarray``
+        #
+        # This trades off device-local concatenation (previous behavior used
+        # ``bm.concatenate`` on device arrays when ``clear_buffer=False``) for
+        # robustness to heterogeneous sharding, at the cost of an extra
+        # host/device round-trip for large outputs or many chunks.

Comment on lines +86 to +90
def test_parallelize_map_single_device():
    # On a single device num_parallel must be 1; chunks of size 1 each.
    args = [np.arange(3.0)]
    r = np.asarray(jax_parallelize_map(_double, args, num_parallel=1))
    np.testing.assert_allclose(r, np.arange(3.0) * 2.0)

suggestion (testing): Parallel pmap path lacks coverage for the clear_buffer branch.

jax_parallelize_map always gathers to NumPy and only wraps back into bm.asarray when clear_buffer=False. Current tests only exercise this clear_buffer=False path. Please add a test variant with clear_buffer=True, e.g.:

def test_parallelize_map_single_device_clear_buffer():
    args = [np.arange(3.0)]
    r = jax_parallelize_map(_double, args, num_parallel=1, clear_buffer=True)
    r = np.asarray(r)
    np.testing.assert_allclose(r, np.arange(3.0) * 2.0)

This will cover the clear-buffer path and verify the returned array type/shape for the pmap case.

Suggested change
-def test_parallelize_map_single_device():
-    # On a single device num_parallel must be 1; chunks of size 1 each.
-    args = [np.arange(3.0)]
-    r = np.asarray(jax_parallelize_map(_double, args, num_parallel=1))
-    np.testing.assert_allclose(r, np.arange(3.0) * 2.0)
+def test_parallelize_map_single_device():
+    # On a single device num_parallel must be 1; chunks of size 1 each.
+    args = [np.arange(3.0)]
+    r = np.asarray(jax_parallelize_map(_double, args, num_parallel=1))
+    np.testing.assert_allclose(r, np.arange(3.0) * 2.0)
+
+
+def test_parallelize_map_single_device_clear_buffer():
+    # clear_buffer=True keeps the host-side NumPy result; np.asarray makes the comparison explicit.
+    args = [np.arange(3.0)]
+    r = jax_parallelize_map(_double, args, num_parallel=1, clear_buffer=True)
+    r = np.asarray(r)
+    np.testing.assert_allclose(r, np.arange(3.0) * 2.0)

@github-actions github-actions Bot added the documentation and tests labels Jun 18, 2026