fix(running): jit-dict mutation, dict string monitors, multi-device parallel concat #852
…ice parallel concat
- Runner.__init__ popped 'predict' from the caller's jit dict, mutating it and
corrupting subclass jit config (predict JIT-on when user set it off); copy the dict (High)
- dict-form string monitors (monitors={'a': 'V'}) were never resolved and invalid
names not validated; fix the resolution branch + validation + result key (High)
- jax_parallelize_map crashed concatenating a trailing partial chunk sharded on a
device subset ("incompatible devices"); gather chunks to host before concat (High)
- process_pool_lock injected the lock into the caller's param dicts via update();
submit a copy (Medium)
Findings recorded in docs/issues-found-20260619-train-running.md
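The two caller-dict mutation fixes above share one pattern: copy before modifying. A minimal sketch, assuming a simplified runner and a simplified lock-injection helper; `SketchRunner`, `submit_with_lock`, and the `'lock'` key are hypothetical names for illustration, not brainpy's actual code:

```python
from typing import Any, Dict, Union


class SketchRunner:
    """Hypothetical stand-in for a runner that accepts a ``jit`` setting."""

    def __init__(self, jit: Union[bool, Dict[str, bool]] = True):
        if isinstance(jit, bool):
            self.jit = {"predict": jit}
        else:
            # Copy first so that pop() below edits our dict, not the caller's.
            jit = dict(jit)
            predict_jit = jit.pop("predict", True)
            self.jit = {"predict": predict_jit, **jit}


def submit_with_lock(params: Dict[str, Any], lock: Any) -> Dict[str, Any]:
    """Return a new kwargs dict carrying the lock; ``params`` is left untouched."""
    return {**params, "lock": lock}


user_jit = {"predict": False, "fit": True}
runner = SketchRunner(jit=user_jit)
```

After construction, `user_jit` still contains its `'predict'` entry, so a subclass or a second runner that reads the same dict sees the setting the user wrote.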
Reviewer's Guide
Fixes several correctness issues in brainpy running infrastructure: prevents unintended mutation of caller dicts in Runner.jit and process_pool_lock, correctly resolves dict-form string monitors to underlying Variables with validation, and makes jax_parallelize_map robust for trailing partial chunks across multiple devices, adding focused regression tests and audit documentation.
File-Level Changes
Hey - I've found 2 issues, and left some high level feedback:
- In `_find_dict_monitor_targets`, the `(tuple, list)` branch assumes `_mon[0]` and `_mon[1]` exist; consider explicitly validating the value length (e.g., exactly 2) and raising a clear error if a user passes a 1‑element tuple/list to avoid confusing index errors.
- In `jax_parallelize_map`, you currently rely on `np.asarray(val)` to pull chunk results to host; using `jax.device_get` (on the full `r` pytree) would make the intent clearer and be more robust if non‑array leaves or custom array types are ever introduced.
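The length check suggested in the first comment could look roughly like the sketch below. It is an illustration under assumptions, not the body of `_find_dict_monitor_targets`: `resolve_dict_monitors` and `_lookup` are hypothetical helpers, and the target is modeled as any object whose attributes are the monitorable variables.

```python
from types import SimpleNamespace
from typing import Any, Dict, Tuple


def _lookup(target: Any, key: str, name: str) -> Any:
    """Resolve an attribute name on the target, failing with a clear message."""
    if not hasattr(target, name):
        raise AttributeError(f"monitor {key!r}: target has no attribute {name!r}")
    return getattr(target, name)


def resolve_dict_monitors(target: Any, monitors: Dict[str, Any]) -> Dict[str, Tuple[Any, Any]]:
    """Map each result key to a (variable, index) pair; index is None when absent."""
    resolved = {}
    for key, mon in monitors.items():
        if isinstance(mon, str):
            # String form: resolve the name instead of storing the literal string.
            resolved[key] = (_lookup(target, key, mon), None)
        elif isinstance(mon, (tuple, list)):
            # Validate the length up front instead of hitting an IndexError later.
            if len(mon) != 2:
                raise ValueError(
                    f"monitor {key!r}: expected (variable, index), got {len(mon)} element(s)"
                )
            var, index = mon
            if isinstance(var, str):
                var = _lookup(target, key, var)
            resolved[key] = (var, index)
        else:
            # Already a variable-like object; keep it as-is.
            resolved[key] = (mon, None)
    return resolved


model = SimpleNamespace(V=[0.0, 1.0, 2.0])
resolved = resolve_dict_monitors(model, {"a": "V", "b": ("V", [0, 2])})
```

Results stay keyed by the user's own keys (`'a'`, `'b'`), and a malformed entry such as `{'c': ('V',)}` fails with a message that names the offending key.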
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In `_find_dict_monitor_targets`, the `(tuple, list)` branch assumes `_mon[0]` and `_mon[1]` exist; consider explicitly validating the value length (e.g., exactly 2) and raising a clear error if a user passes a 1‑element tuple/list to avoid confusing index errors.
- In `jax_parallelize_map`, you currently rely on `np.asarray(val)` to pull chunk results to host; using `jax.device_get` (on the full `r` pytree) would make the intent clearer and be more robust if non‑array leaves or custom array types are ever introduced.
## Individual Comments
### Comment 1
<location path="brainpy/running/jax_multiprocessing.py" line_range="149-153" />
<code_context>
+ # conflict (it also serves the ``clear_buffer`` path for free).
if results is None:
- results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values)
+ results = tuple([np.asarray(val)] for val in res_values)
else:
for j, val in enumerate(res_values):
- results[j].append(np.asarray(val) if clear_buffer else val)
+ results[j].append(np.asarray(val))
if clear_buffer:
bm.clear_buffer_memory()
</code_context>
<issue_to_address>
**suggestion (performance):** Consider whether forcing all chunk outputs through NumPy is acceptable from a performance/placement perspective.
The new logic always converts chunk outputs to host NumPy arrays, concatenates on CPU, and only then converts back to `bm.asarray` when `clear_buffer=False`. Previously, we stayed on device and used `bm.concatenate` in that case. For large outputs or many chunks, this host round‑trip can significantly hurt GPU/TPU throughput.
If on‑device performance matters here, consider only host-converting the problematic partial chunk(s) or using a concat that can handle subsets of devices. Otherwise, it may be worth explicitly documenting that this path now always routes via host memory.
```suggestion
# Gather each chunk's output to the host.
#
# A trailing partial chunk is sharded on only a *subset* of devices, so
# leaving the outputs on-device makes the final concatenation fail with
# "Received incompatible devices for jitted computation". Pulling to
# NumPy first sidesteps the placement conflict (it also serves the
# ``clear_buffer`` path for free).
#
# NOTE: This path now *always* routes chunk outputs through host memory:
# - every per-chunk result is converted to a host NumPy array here
# - we concatenate on CPU via ``np.concatenate`` below
# - when ``clear_buffer=False``, the concatenated results are then
# transferred back to device via ``bm.asarray``
#
# This trades off device-local concatenation (previous behavior used
# ``bm.concatenate`` on device arrays when ``clear_buffer=False``) for
# robustness to heterogeneous sharding, at the cost of an extra
# host/device round-trip for large outputs or many chunks.
```
</issue_to_address>
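For reference, the host-gather approach discussed in this comment can be sketched with `jax.device_get` applied to each chunk's whole output pytree, followed by a leaf-wise `np.concatenate`. This is a standalone illustration, not the code in `jax_multiprocessing.py`; `gather_and_concat` is a hypothetical helper, and it assumes every chunk returns a pytree with the same structure.

```python
import jax
import jax.numpy as jnp
import numpy as np


def gather_and_concat(chunk_results):
    """Concatenate per-chunk pytrees leaf-by-leaf on the host."""
    # device_get pulls every array leaf of each chunk to host NumPy arrays,
    # regardless of which devices that chunk was placed on.
    host_chunks = [jax.device_get(r) for r in chunk_results]
    # Concatenate matching leaves across chunks along the leading axis.
    return jax.tree_util.tree_map(
        lambda *leaves: np.concatenate(leaves, axis=0), *host_chunks
    )


# Two chunks; the second is a smaller trailing partial chunk.
chunks = [
    (jnp.ones((2, 3)), jnp.zeros((2,))),
    (jnp.ones((1, 3)), jnp.zeros((1,))),
]
merged = gather_and_concat(chunks)
```

The concatenation itself runs in NumPy on the host, so device placement of the individual chunks no longer matters; the cost is the host round-trip described above.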
### Comment 2
<location path="brainpy/running/jax_multiprocessing_test.py" line_range="86-90" />
<code_context>
+# jax_parallelize_map (pmap)
+# --------------------------------------------------------------------------- #
+
+def test_parallelize_map_single_device():
+ # On a single device num_parallel must be 1; chunks of size 1 each.
+ args = [np.arange(3.0)]
+ r = np.asarray(jax_parallelize_map(_double, args, num_parallel=1))
+ np.testing.assert_allclose(r, np.arange(3.0) * 2.0)
+
+
</code_context>
<issue_to_address>
**suggestion (testing):** Parallel pmap path lacks coverage for the clear_buffer branch.
`jax_parallelize_map` always gathers to NumPy and only wraps back into `bm.asarray` when `clear_buffer=False`. Current tests only exercise this `clear_buffer=False` path. Please add a test variant with `clear_buffer=True`, e.g.:
```python
def test_parallelize_map_single_device_clear_buffer():
    args = [np.arange(3.0)]
    r = jax_parallelize_map(_double, args, num_parallel=1, clear_buffer=True)
    r = np.asarray(r)
    np.testing.assert_allclose(r, np.arange(3.0) * 2.0)
```
This will cover the clear-buffer path and verify the returned array type/shape for the pmap case.
```suggestion
def test_parallelize_map_single_device():
    # On a single device num_parallel must be 1; chunks of size 1 each.
    args = [np.arange(3.0)]
    r = np.asarray(jax_parallelize_map(_double, args, num_parallel=1))
    np.testing.assert_allclose(r, np.arange(3.0) * 2.0)


def test_parallelize_map_single_device_clear_buffer():
    # clear_buffer=True skips the bm.asarray wrap; normalize to NumPy for comparison.
    args = [np.arange(3.0)]
    r = jax_parallelize_map(_double, args, num_parallel=1, clear_buffer=True)
    r = np.asarray(r)
    np.testing.assert_allclose(r, np.arange(3.0) * 2.0)
```
</issue_to_address>
Fresh review of `brainpy/train` + `brainpy/running`.

- `Runner.__init__` popped `'predict'` from the caller's `jit` dict, mutating it and corrupting subclass jit config (`predict` JIT-on even when disabled); now copies.
- Dict-form string monitors (`monitors={'a': 'V'}`) were never resolved (stored the literal string) and invalid names weren't validated; fixed resolution + validation + result key.
- `jax_parallelize_map` crashed concatenating a trailing partial chunk sharded on a device subset; now gathers to host before concat.
- `process_pool_lock` mutated caller param dicts via `.update()`; now submits a copy.

(Recorded for follow-up: BPTT `i0` pinning under `reset_state=False` continuation is cross-cutting and was left unfixed. RLS/ridge/GD math lives in `algorithms/`, handled by P16.) In-scope: 108 passed. Findings: `docs/issues-found-20260619-train-running.md`.

Summary by Sourcery
Fix multiple correctness issues in training/running utilities around JIT configuration, monitor resolution, and JAX multi-device parallelism, and document the audit findings.
Bug Fixes:
Enhancements:
Documentation: