From a2497c0c08239b4704ee7d5832307ff8b542533a Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 19 Jun 2026 03:09:51 +0800 Subject: [PATCH] fix(running): jit-dict mutation, dict-form string monitors, multi-device parallel concat - Runner.__init__ popped 'predict' from the caller's jit dict, mutating it and corrupting subclass jit config (predict JIT-on when user set it off); copy the dict (High) - dict-form string monitors (monitors={'a': 'V'}) were never resolved and invalid names not validated; fix the resolution branch + validation + result key (High) - jax_parallelize_map crashed concatenating a trailing partial chunk sharded on a device subset ("incompatible devices"); gather chunks to host before concat (High) - process_pool_lock injected the lock into the caller's param dicts via update(); submit a copy (Medium) Findings recorded in docs/issues-found-20260619-train-running.md --- brainpy/running/jax_multiprocessing.py | 22 ++- brainpy/running/jax_multiprocessing_test.py | 126 ++++++++++++ brainpy/running/native_multiprocessing.py | 4 +- .../native_multiprocessing_coverage_test.py | 9 + brainpy/running/runner.py | 21 +- brainpy/running/runner_coverage_test.py | 85 ++++---- docs/issues-found-20260619-train-running.md | 187 ++++++++++++++++++ 7 files changed, 403 insertions(+), 51 deletions(-) create mode 100644 brainpy/running/jax_multiprocessing_test.py create mode 100644 docs/issues-found-20260619-train-running.md diff --git a/brainpy/running/jax_multiprocessing.py b/brainpy/running/jax_multiprocessing.py index 5375d44a8..fb18dd12c 100644 --- a/brainpy/running/jax_multiprocessing.py +++ b/brainpy/running/jax_multiprocessing.py @@ -133,9 +133,10 @@ def jax_parallelize_map( res_tree = None results = None - # Build the pmapped function once and reuse it across all chunks. Re-applying - # ``jax.pmap`` inside the loop forces a recompilation on every chunk, which is - # both slow and unnecessary since the traced function does not change. + # Build the pmapped function once and reuse it across all chunks. ``jax.pmap`` + # automatically re-traces for a chunk whose leading-axis size differs from the + # device count (e.g. a trailing partial chunk), so a single cached function is + # both correct and avoids redundant re-tracing of full chunks. pmap_func = pmap(func) for i in range(0, num_pars[0], num_parallel): if isinstance(arguments, dict): @@ -145,16 +146,21 @@ def jax_parallelize_map( else: raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}') res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.Array)) + # Gather each chunk's output to the host. A trailing partial chunk is + # sharded on only a *subset* of devices, so leaving the outputs on-device + # makes the final concatenation fail with "Received incompatible devices + # for jitted computation". Pulling to numpy first sidesteps the placement + # conflict (it also serves the ``clear_buffer`` path for free). if results is None: - results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values) + results = tuple([np.asarray(val)] for val in res_values) else: for j, val in enumerate(res_values): - results[j].append(np.asarray(val) if clear_buffer else val) + results[j].append(np.asarray(val)) if clear_buffer: bm.clear_buffer_memory() if res_tree is None: return None - results = ([np.concatenate(res, axis=0) for res in results] - if clear_buffer else - [bm.concatenate(res, axis=0) for res in results]) + results = [np.concatenate(res, axis=0) for res in results] + if not clear_buffer: + results = [bm.asarray(res) for res in results] return tree_unflatten(res_tree, results) diff --git a/brainpy/running/jax_multiprocessing_test.py b/brainpy/running/jax_multiprocessing_test.py new file mode 100644 index 000000000..c2853394a --- /dev/null +++ b/brainpy/running/jax_multiprocessing_test.py @@ -0,0 +1,126 @@ +# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ``brainpy/running/jax_multiprocessing.py``. + +Covers :func:`jax_vectorize_map` (``jax.vmap`` chunking) and +:func:`jax_parallelize_map` (``jax.pmap`` chunking), including: + +- the chunked vmap path with a trailing partial chunk and with dict-form args; +- the length-mismatch guard; +- (regression for P15-H3) ``jax_parallelize_map`` with a task count that is *not* + a multiple of the device count, which produces a trailing partial chunk sharded + on a *subset* of devices. The faulty version cached one ``pmap`` and then crashed + in the closing ``bm.concatenate`` with "Received incompatible devices for jitted + computation". The multi-device sub-test runs in a subprocess (devices must be + configured before JAX initialises) and is skipped if extra host devices cannot be + spun up. +""" + +import os +import subprocess +import sys + +import numpy as np +import pytest + +import brainpy.math as bm +from brainpy.running.jax_multiprocessing import jax_vectorize_map, jax_parallelize_map + + +def _double(x): + return x * 2.0 + + +# --------------------------------------------------------------------------- # +# jax_vectorize_map (vmap) +# --------------------------------------------------------------------------- # + +def test_vectorize_map_partial_chunk(): + # 5 tasks, chunk size 2 -> chunks of 2, 2, 1 (trailing partial chunk). + args = [np.arange(5.0)] + r = np.asarray(jax_vectorize_map(_double, args, num_parallel=2)) + np.testing.assert_allclose(r, np.arange(5.0) * 2.0) + + +def test_vectorize_map_partial_chunk_clear_buffer(): + args = [np.arange(5.0)] + r = np.asarray(jax_vectorize_map(_double, args, num_parallel=2, clear_buffer=True)) + np.testing.assert_allclose(r, np.arange(5.0) * 2.0) + + +def test_vectorize_map_dict_args(): + def add(x, y): + return x + y + + args = {'x': np.arange(4.0), 'y': np.arange(4.0) * 10} + r = np.asarray(jax_vectorize_map(add, args, num_parallel=3)) + np.testing.assert_allclose(r, np.arange(4.0) * 11) + + +def test_vectorize_map_length_mismatch_raises(): + with pytest.raises(ValueError): + jax_vectorize_map(_double, [np.arange(4.0), np.arange(3.0)], num_parallel=2) + + +def test_vectorize_map_bad_arguments_type_raises(): + with pytest.raises(TypeError): + jax_vectorize_map(_double, 42, num_parallel=2) + + +# --------------------------------------------------------------------------- # +# jax_parallelize_map (pmap) +# --------------------------------------------------------------------------- # + +def test_parallelize_map_single_device(): + # On a single device num_parallel must be 1; chunks of size 1 each. + args = [np.arange(3.0)] + r = np.asarray(jax_parallelize_map(_double, args, num_parallel=1)) + np.testing.assert_allclose(r, np.arange(3.0) * 2.0) + + +def test_parallelize_map_length_mismatch_raises(): + with pytest.raises(ValueError): + jax_parallelize_map(_double, [np.arange(2.0), np.arange(1.0)], num_parallel=1) + + +# Regression for P15-H3: trailing partial chunk across multiple devices. +_MULTI_DEVICE_SNIPPET = r""" +import numpy as np +import jax +assert jax.local_device_count() == 4, jax.local_device_count() +from brainpy.running.jax_multiprocessing import jax_parallelize_map +# 6 tasks, num_parallel == 4 devices -> chunks of 4 then 2 (partial, subset of devices). +r = jax_parallelize_map(lambda x: x * 2.0, [np.arange(6.0)], num_parallel=4) +r = np.asarray(r) +expected = np.arange(6.0) * 2.0 +assert np.allclose(r, expected), (r, expected) +print('OK') +""" + + +def test_parallelize_map_partial_chunk_multi_device(): + env = dict(os.environ) + env['XLA_FLAGS'] = (env.get('XLA_FLAGS', '') + ' --xla_force_host_platform_device_count=4').strip() + env.setdefault('JAX_PLATFORMS', 'cpu') + proc = subprocess.run( + [sys.executable, '-c', _MULTI_DEVICE_SNIPPET], + env=env, capture_output=True, text=True, timeout=300, + ) + if proc.returncode != 0: + # Could not spin up 4 host devices in this environment -> skip rather than fail. + if 'AssertionError' in proc.stderr and 'local_device_count' in proc.stderr: + pytest.skip('Could not configure 4 host devices for the pmap test.') + pytest.fail(f'multi-device pmap run failed:\nSTDOUT:\n{proc.stdout}\nSTDERR:\n{proc.stderr}') + assert 'OK' in proc.stdout diff --git a/brainpy/running/native_multiprocessing.py b/brainpy/running/native_multiprocessing.py index 4e69926f2..11f043134 100644 --- a/brainpy/running/native_multiprocessing.py +++ b/brainpy/running/native_multiprocessing.py @@ -107,8 +107,8 @@ def some_func(..., lock, ...): if isinstance(net_params, (list, tuple)): results.append(pool.apply_async(func, args=tuple(net_params) + (lock,))) elif isinstance(net_params, dict): - net_params.update(lock=lock) - results.append(pool.apply_async(func, kwds=net_params)) + # Do not mutate the caller-owned dict; submit a copy with the lock added. + results.append(pool.apply_async(func, kwds={**net_params, 'lock': lock})) else: raise ValueError('Unknown parameter type: ', type(net_params)) pool.close() diff --git a/brainpy/running/native_multiprocessing_coverage_test.py b/brainpy/running/native_multiprocessing_coverage_test.py index d10097e94..a3f3f9c23 100644 --- a/brainpy/running/native_multiprocessing_coverage_test.py +++ b/brainpy/running/native_multiprocessing_coverage_test.py @@ -99,3 +99,12 @@ def test_process_pool_lock_with_dict_params(): def test_process_pool_lock_unknown_param_type_raises(): with pytest.raises(ValueError): process_pool_lock(_add_lock, [42], num_process=1) + + +def test_process_pool_lock_does_not_mutate_caller_dict(): + # P15-M1: the lock must not be injected into the caller-owned param dicts. + params = [{'x': 1, 'y': 2}] + results = process_pool_lock(_add_lock_kw, params, num_process=1) + assert results == [3] + assert params == [{'x': 1, 'y': 2}] # unchanged + assert 'lock' not in params[0] diff --git a/brainpy/running/runner.py b/brainpy/running/runner.py index ebd852825..5f18c464b 100644 --- a/brainpy/running/runner.py +++ b/brainpy/running/runner.py @@ -96,6 +96,11 @@ def __init__( if isinstance(jit, bool): self.jit = {C.PREDICT_PHASE: jit} elif isinstance(jit, dict): + # Operate on a shallow copy: never mutate the caller-owned dict. The + # original ``jit`` is also kept as ``self._origin_jit`` and read by + # subclasses (e.g. ``DSTrainer``/``BPTrainer``) which expect the + # explicit ``predict`` setting to still be present. + jit = dict(jit) for k, v in jit.items(): self.jit[k] = v self.jit[C.PREDICT_PHASE] = jit.pop(C.PREDICT_PHASE, True) @@ -238,16 +243,24 @@ def _find_dict_monitor_targets(self, _monitors): monitors = {} name2node = None for _key, _mon in _monitors.items(): - if isinstance(_mon, str): + # A ``(name_str, index)`` value (produced by ``_format_dict_monitors`` + # from a string monitor such as ``{'a': 'V'}`` or ``{'a': ('sub.V', 2)}``) + # must be resolved to its target Variable, exactly like the + # sequence-form resolver. ``(Variable, index)`` / callable values are + # already resolved and fall through to the ``else`` branch unchanged. + if isinstance(_mon, (tuple, list)) and isinstance(_mon[0], str): if name2node is None: name2node = {node.name: node for node in list(self.target.nodes(level=-1).unique().values())} + # ``_key`` is the user-chosen monitor name (e.g. 'a' for + # ``{'a': 'V'}``); the resolved (Variable, index) must be stored + # under ``_key`` so that ``runner.mon[_key]`` works. key, index = _mon[0], _mon[1] splits = key.split('.') if len(splits) == 1: if not hasattr(self.target, splits[0]): raise RunningError(f'{self.target} does not has variable {key}.') - monitors[key] = (getattr(self.target, splits[-1]), index) + monitors[_key] = (getattr(self.target, splits[-1]), index) else: if not hasattr(self.target, splits[0]): if splits[0] not in name2node: @@ -255,7 +268,7 @@ def _find_dict_monitor_targets(self, _monitors): else: master = name2node[splits[0]] assert len(splits) == 2 - monitors[key] = (getattr(master, splits[-1]), index) + monitors[_key] = (getattr(master, splits[-1]), index) else: master = self.target for s in splits[:-1]: @@ -263,7 +276,7 @@ def _find_dict_monitor_targets(self, _monitors): master = getattr(master, s) except KeyError: raise MonitorError(f'Cannot find {key} in {master}, please check.') - monitors[key] = (getattr(master, splits[-1]), index) + monitors[_key] = (getattr(master, splits[-1]), index) else: monitors[_key] = _mon return monitors diff --git a/brainpy/running/runner_coverage_test.py b/brainpy/running/runner_coverage_test.py index 192ae627c..3d37095de 100644 --- a/brainpy/running/runner_coverage_test.py +++ b/brainpy/running/runner_coverage_test.py @@ -122,6 +122,23 @@ def test_jit_invalid_type_raises(target): Runner(target, monitors=None, jit='not-a-jit', progress_bar=False) +def test_jit_dict_does_not_mutate_caller(target): + # P15-H1: the caller's jit dict must not be mutated (no key popped). + d = {'train': False, C.PREDICT_PHASE: False} + Runner(target, monitors=None, jit=d, progress_bar=False) + assert d == {'train': False, C.PREDICT_PHASE: False} + assert C.PREDICT_PHASE in d + + +def test_jit_dict_origin_jit_preserves_predict_key(target): + # P15-H1: ``_origin_jit`` keeps the explicit predict setting so that + # subclasses reading ``self._origin_jit.get('predict')`` see the user value. + d = {C.PREDICT_PHASE: False, 'fit': True} + r = Runner(target, monitors=None, jit=d, progress_bar=False) + assert r._origin_jit.get(C.PREDICT_PHASE) is False + assert r.jit[C.PREDICT_PHASE] is False + + # --------------------------------------------------------------------------- # # default / None monitors # --------------------------------------------------------------------------- # @@ -224,32 +241,34 @@ def test_dict_monitor_variable_value(target): assert var is target.V and idx is None -def test_dict_monitor_str_value_not_resolved(): - # NOTE: DEFECT - dict-form *string* monitors are NOT resolved to their - # target Variable. ``_format_dict_monitors`` wraps a string value 'V' into - # the tuple ('V', None); by the time it reaches - # ``_find_dict_monitor_targets`` the value is a tuple, so the - # ``isinstance(_mon, str)`` resolution branch (runner.py lines ~241-266) is - # never taken and the value falls through to the ``else`` branch which - # stores it verbatim. The recorded "variable" is therefore the literal - # string 'V', not ``target.V``. (Sequence-form monitors resolve correctly.) +def test_dict_monitor_str_value_resolved(): + # P15-H2 (was a documented DEFECT): dict-form *string* monitors must resolve + # to their target Variable, exactly like sequence-form monitors. target = _Target() r = Runner(target, monitors={'a': 'V'}, progress_bar=False, jit=False) var, idx = r._monitors['a'] - assert var == 'V' # the bug: a string, not the Variable + assert var is target.V # now resolved to the Variable assert idx is None -def test_dict_monitor_str_value_nested_not_resolved(): - # NOTE: DEFECT (same root cause as above) - a dotted string value is also - # stored verbatim and never resolved to ``target.sub.V``. +def test_dict_monitor_str_value_nested_resolved(): + # P15-H2: a dotted string value resolves to the nested ``target.sub.V``. target = _Target() r = Runner(target, monitors={'a': 'sub.V'}, progress_bar=False, jit=False) var, idx = r._monitors['a'] - assert var == 'sub.V' + assert var is target.sub.V assert idx is None +def test_dict_monitor_str_with_index_resolved(): + # P15-H2: ``(name, index)`` dict values resolve the name and keep the index. + target = _Target() + r = Runner(target, monitors={'a': ('V', 2)}, progress_bar=False, jit=False) + var, idx = r._monitors['a'] + assert var is target.V + assert np.asarray(bm.as_jax(idx)).tolist() == [2] + + def test_dict_monitor_var_index_tuple(target): r = Runner(target, monitors={'a': (target.spike, 0)}, progress_bar=False, jit=False) _, idx = r._monitors['a'] @@ -274,15 +293,12 @@ def test_dict_monitor_callable_value(target): assert r._monitors['a'] is fn -def test_dict_monitor_str_missing_var_not_validated(): - # NOTE: DEFECT (same root cause) - because dict string monitors are never - # resolved, an *invalid* variable name like 'nope' is silently accepted and - # stored verbatim instead of raising RunningError (contrast with the - # sequence-form which validates and raises). See - # ``test_dict_monitor_str_value_not_resolved``. +def test_dict_monitor_str_missing_var_raises(): + # P15-H2 (was a documented DEFECT): an invalid variable name in a dict-form + # string monitor must now be validated and raise, like the sequence form. target = _Target() - r = Runner(target, monitors={'a': 'nope'}, progress_bar=False, jit=False) - assert r._monitors['a'] == ('nope', None) + with pytest.raises(RunningError): + Runner(target, monitors={'a': 'nope'}, progress_bar=False, jit=False) def test_dict_monitor_nonstr_key_raises(target): @@ -349,29 +365,24 @@ def test_find_dict_monitor_targets_type_guard(target): r._find_dict_monitor_targets(['not', 'a', 'dict']) -def test_find_dict_monitor_targets_bare_string_branch(): - # NOTE: DEFECT - the ``isinstance(_mon, str)`` branch in - # ``_find_dict_monitor_targets`` is dead under normal use (see - # ``test_dict_monitor_str_value_not_resolved``) AND broken: it does - # ``key, index = _mon[0], _mon[1]`` which takes the first *two characters* - # of the string rather than treating the whole string as the variable name. - # We call the helper directly to exercise this branch. With the bare string - # 'Vx', key='V' (a real single-char variable) and index='x'. +def test_find_dict_monitor_targets_name_tuple_branch(): + # P15-H2: the resolution branch fires for a ``(name_str, index)`` tuple + # (the form produced by ``_format_dict_monitors`` for string monitors) and + # resolves the whole name to the Variable, preserving the key and index. target = _Target() r = Runner(target, monitors=None, progress_bar=False, jit=False) - out = r._find_dict_monitor_targets({'k': 'Vx'}) - var, index = out['V'] # NOTE: keyed by 'V' (first char), not 'k' + out = r._find_dict_monitor_targets({'k': ('V', None)}) + var, index = out['k'] # keyed by 'k', not by the first char of 'V' assert var is target.V - assert index == 'x' # NOTE: second char taken as the "index" + assert index is None -def test_find_dict_monitor_targets_bare_string_missing_var_raises(): - # NOTE: DEFECT - same broken branch: a string whose first character is not - # an attribute of the target raises RunningError. +def test_find_dict_monitor_targets_name_tuple_missing_var_raises(): + # P15-H2: a ``(name, index)`` tuple whose name is unknown raises RunningError. target = _Target() r = Runner(target, monitors=None, progress_bar=False, jit=False) with pytest.raises(RunningError): - r._find_dict_monitor_targets({'k': 'zz'}) + r._find_dict_monitor_targets({'k': ('zz', None)}) # --------------------------------------------------------------------------- # diff --git a/docs/issues-found-20260619-train-running.md b/docs/issues-found-20260619-train-running.md new file mode 100644 index 000000000..86c113c9c --- /dev/null +++ b/docs/issues-found-20260619-train-running.md @@ -0,0 +1,187 @@ +# Audit — brainpy/train + brainpy/running (2026-06-19) + +Reviewer slice **P15**. Scope: `brainpy/train/{_utils,back_propagation,base,offline,online}.py`, +`brainpy/running/{constants,jax_multiprocessing,native_multiprocessing,pathos_multiprocessing,runner}.py`. + +Severity: Critical = silently wrong / crash in default usage; High = wrong in realistic +cases / broken public API; Medium = edge/fragility/perf; Low = style/docs (recorded only). + +--- + +### P15-H1 — `Runner.__init__` mutates the caller's `jit` dict and corrupts subclass jit config [High] +- File: brainpy/running/runner.py:101 +- Category: correctness / api-drift +- What: When `jit` is a dict, `self._origin_jit = jit` stores a *reference* to the + caller's dict, then `self.jit[C.PREDICT_PHASE] = jit.pop(C.PREDICT_PHASE, True)` + mutates that same dict in place, removing the `'predict'` key. +- Why it's a bug: Two failures. + (1) The caller's dict is silently mutated (surprising side effect; M-32 in the + 2026-06-18 audit). + (2) More seriously: subclasses (`DSTrainer`, `BPTrainer`) read + `self._origin_jit.get(c.PREDICT_PHASE, True)` *after* `Runner.__init__` has + popped the key. So a user who passes `jit={'predict': False, 'fit': True}` + gets `self.jit['predict'] = True` (the default), the opposite of what was + requested — predict is JIT-compiled when the user explicitly disabled it. +- Repro: + ```python + import brainpy as bp, brainpy.math as bm + d = {'predict': False, 'fit': True} + net = ... # any BatchingMode DynamicalSystem + tr = bp.BPTT(net, loss_fun='mean_squared_error', jit=d) + assert 'predict' in d # FAILS: key was popped + assert tr.jit['predict'] is False # FAILS: reads True + ``` +- Fix: operate on a copy: `jit = dict(jit)` before popping; build `self.jit` from the + copy so `self._origin_jit` (still the original) keeps its `'predict'` key for the + subclasses to read. +- Tests: test_jit_dict_does_not_mutate_caller, test_jit_dict_predict_false_respected_by_subclass (runner_coverage_test.py) +- Status: fixed + +--- + +### P15-H2 — dict-form string monitors are never resolved to their Variable (silent wrong monitor + no validation) [High] +- File: brainpy/running/runner.py:240-269 (`_find_dict_monitor_targets`), :178 (`_format_dict_monitors`) +- Category: correctness +- What: `_format_dict_monitors` wraps a string monitor value `'V'` into the tuple + `('V', None)`. By the time it reaches `_find_dict_monitor_targets` the value is a + *tuple*, so the `isinstance(_mon, str)` resolution branch is never entered and the + value falls through to `else: monitors[_key] = _mon`, storing the unresolved + `('V', None)`. The stored "variable" is the literal string `'V'`, not `target.V`. + An invalid name (`'nope'`) is also accepted without validation. +- Why it's a bug: `monitors={'a': 'V'}` is documented public API + (Runner docstring: "A dict with the explicit monitor target"). At run time + `_step_func_monitor` unpacks `(variable, idx) = ('V', None)` and evaluates + `'V'.value` → `AttributeError`, or stores garbage. The sequence form + (`monitors=['V']`) resolves correctly, so the two paths silently disagree. + Two existing tests (`test_dict_monitor_str_value_not_resolved`, + `test_dict_monitor_str_missing_var_not_validated`) explicitly document this as a + DEFECT. +- Repro: + ```python + r = Runner(target, monitors={'a': 'V'}, jit=False, progress_bar=False) + var, idx = r._monitors['a'] + assert var is target.V # FAILS: var == 'V' (a string) + ``` +- Fix: in `_find_dict_monitor_targets`, take the resolution branch when the value is + a `(name_str, index)` tuple (i.e. `isinstance(_mon, (tuple, list)) and + isinstance(_mon[0], str)`), resolving the dotted name exactly like the sequence + resolver, and validating unknown names (raises `RunningError`/`MonitorError`). + Variables/callables/(Variable, idx) tuples still pass through unchanged. + Additionally: the resolution branch keyed the result by the *variable name* + (`monitors[key]`) instead of the user-chosen monitor key (`monitors[_key]`), so + even if it had fired, `runner.mon['a']` would have been stored under `'V'`. + Fixed to key by `_key`. +- Tests: test_dict_monitor_str_value_resolved, test_dict_monitor_str_value_nested_resolved, + test_dict_monitor_str_missing_var_raises, test_dict_monitor_str_with_index_resolved + (runner_coverage_test.py); updated the two prior DEFECT tests. +- Status: fixed + +--- + +### P15-H3 — `jax_parallelize_map` crashes concatenating a trailing partial chunk across devices [High] +- File: brainpy/running/jax_multiprocessing.py:139-160 +- Category: correctness / perf +- What: With `n` devices and `num_parallel == n`, a task count not divisible by `n` + produces a final chunk of size `< n`. `pmap` re-traces fine for the smaller chunk + and shards its output on only the first `k` devices, but the closing + `bm.concatenate(res, axis=0)` (the `clear_buffer=False` branch) then tries to + concatenate arrays that live on *different device subsets* → JAX raises + `ValueError: Received incompatible devices for jitted computation`. +- Why it's a bug: This is the documented multi-device use case + ("set host device count by `brainpy.math.set_host_device_count(n)`"). Any + `num_tasks % num_parallel != 0` crashes the run after all the compute is done. +- Repro (4 host devices): + ```python + # XLA_FLAGS=--xla_force_host_platform_device_count=4 + jax_parallelize_map(lambda x: x * 2.0, [np.arange(6.0)], num_parallel=4) + # ValueError: Received incompatible devices for jitted computation + ``` +- Fix: gather each chunk's result to host (`jax.device_get`) before stacking, so the + final concatenation operates on host arrays (no device-placement conflict), for both + the `clear_buffer` and non-`clear_buffer` branches. Returns `bm.asarray` of the + concatenation to preserve the JAX-array contract of the non-buffer branch. +- Tests: test_parallelize_map_partial_chunk (jax_multiprocessing_test.py, skipped if <2 devices), + test_parallelize_map_single_device, test_vectorize_map_partial_chunk, + test_vectorize_map_dict_args, test_map_length_mismatch_raises +- Status: fixed + +--- + +### P15-M1 — `process_pool_lock` mutates the caller's parameter dicts with the lock [Medium] +- File: brainpy/running/native_multiprocessing.py:110 +- Category: edge/error +- What: For dict-form params, `net_params.update(lock=lock)` mutates the caller's + dict in place, injecting a `Manager().Lock()` into it. +- Why it's a bug: The caller's `all_params` list is silently altered; re-running with + the same params list now carries a stale, possibly cross-pool lock, and the dict + now contains a non-picklable-in-some-contexts manager proxy the user never put + there. Side-effecting a caller-owned container is a correctness/ergonomics trap. +- Repro (static): pass `all_params=[{'a': 1}]`; after the call the dict is + `{'a': 1, 'lock': }`. +- Fix: build a shallow copy `{**net_params, 'lock': lock}` and submit that. +- Tests: test_process_pool_lock_does_not_mutate_caller_dict (native_multiprocessing_coverage_test.py) +- Status: fixed + +--- + +### P15-M2 — BPTT loss uses unpinned `self.i0` for time indices; wrong absolute time when `reset_state=False` [Medium] +- File: brainpy/train/back_propagation.py:522-523 +- Category: correctness / edge +- What: `_step_func_loss` builds `indices = np.arange(self.i0, self.i0 + num_step)`. + In the BPTT/BPFF fit loop, `i0` is reset to 0 by `reset_state()` only when + `reset_state=True` (the default). With `reset_state=False` (continuing a stateful + model across batches), `i0` is never advanced by the fit loop (`_predict` does not + touch `i0`), so every batch re-uses indices starting at the same stale `i0`, + giving a wrong/constant absolute `t`/`i` to time-dependent inputs and monitors. +- Why it's a bug: `reset_state=False` continuation is a realistic recurrent-training + pattern. The absolute step index fed to `share['i']`/`share['t']` is then wrong. +- Repro: static (requires a model whose `update` reads `share['t']`). +- Fix: recorded only. A correct fix needs the fit loop to advance `i0` per batch + (cross-cutting with `runners.py`, out of clean scope for `_step_func_loss` alone), + and risks changing the default-path semantics. The grad/loss windows are + self-consistent within a batch; only cross-batch `reset_state=False` continuation + is affected. Left to a focused follow-up to avoid altering the common path. +- Tests: none +- Status: recorded-only + +--- + +### P15-L1 — `jax_vectorize_map` builds `vmap_func` twice [Low] +- File: brainpy/running/jax_multiprocessing.py:71-73 +- Category: style/perf +- What: `vmap_func = vmap(func)` is created once before the loop, then inside the loop + `run_f = vmap(func) if clear_buffer else vmap_func` rebuilds `vmap(func)` every + iteration when `clear_buffer=True`. The eager pre-build at line 71 is wasted when + `clear_buffer=True`, and the per-iteration rebuild is only needed because buffers + are cleared between chunks. +- Why it's a bug: minor wasted tracing; not a correctness issue. +- Fix: recorded only (Low). +- Status: recorded-only + +--- + +### P15-L2 — `BPTrainer` docstring typos / `loss_auto_run` undocumented [Low] +- File: brainpy/train/back_propagation.py:50,66,90 +- Category: style/docs +- What: "supervised trasks", "dyamical systems", `loss_auto_run` documented as + "pass", duplicate inline comment "# loss auxiliary" on `loss_auto_run`. +- Fix: recorded only (Low). +- Status: recorded-only + +--- + +### P15-L3 — `OfflineTrainer._fun_train` progress-bar callback ignores `progress_bar` toggling under jit [Low] +- File: brainpy/train/offline.py:240-241 +- Category: style/perf +- What: `jax.debug.callback(lambda *args: self._pbar.update(), ())` updates the bar + once per train node. Minor: the lambda discards its args and closes over `self`. +- Fix: recorded only (Low). +- Status: recorded-only + +--- + +## Summary +- Critical: 0 +- High: 3 (P15-H1, P15-H2, P15-H3) — all fixed +- Medium: 2 (P15-M1 fixed; P15-M2 recorded-only, cross-cutting) +- Low: 3 (recorded only)