Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions brainpy/running/jax_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,10 @@ def jax_parallelize_map(

res_tree = None
results = None
# Build the pmapped function once and reuse it across all chunks. Re-applying
# ``jax.pmap`` inside the loop forces a recompilation on every chunk, which is
# both slow and unnecessary since the traced function does not change.
# Build the pmapped function once and reuse it across all chunks. ``jax.pmap``
# automatically re-traces for a chunk whose leading-axis size differs from the
# device count (e.g. a trailing partial chunk), so a single cached function is
# both correct and avoids redundant re-tracing of full chunks.
pmap_func = pmap(func)
for i in range(0, num_pars[0], num_parallel):
if isinstance(arguments, dict):
Expand All @@ -145,16 +146,21 @@ def jax_parallelize_map(
else:
raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}')
res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.Array))
# Gather each chunk's output to the host. A trailing partial chunk is
# sharded on only a *subset* of devices, so leaving the outputs on-device
# makes the final concatenation fail with "Received incompatible devices
# for jitted computation". Pulling to numpy first sidesteps the placement
# conflict (it also serves the ``clear_buffer`` path for free).
Comment on lines +149 to +153

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (performance): Consider whether forcing all chunk outputs through NumPy is acceptable from a performance/placement perspective.

The new logic always converts chunk outputs to host NumPy arrays, concatenates on CPU, and only then converts back to bm.asarray when clear_buffer=False. Previously, we stayed on device and used bm.concatenate in that case. For large outputs or many chunks, this host round‑trip can significantly hurt GPU/TPU throughput.

If on‑device performance matters here, consider only host-converting the problematic partial chunk(s) or using a concat that can handle subsets of devices. Otherwise, it may be worth explicitly documenting that this path now always routes via host memory.

Suggested change
# Gather each chunk's output to the host. A trailing partial chunk is
# sharded on only a *subset* of devices, so leaving the outputs on-device
# makes the final concatenation fail with "Received incompatible devices
# for jitted computation". Pulling to numpy first sidesteps the placement
# conflict (it also serves the ``clear_buffer`` path for free).
# Gather each chunk's output to the host.
#
# A trailing partial chunk is sharded on only a *subset* of devices, so
# leaving the outputs on-device makes the final concatenation fail with
# "Received incompatible devices for jitted computation". Pulling to
# NumPy first sidesteps the placement conflict (it also serves the
# ``clear_buffer`` path for free).
#
# NOTE: This path now *always* routes chunk outputs through host memory:
# - every per-chunk result is converted to a host NumPy array here
# - we concatenate on CPU via ``np.concatenate`` below
# - when ``clear_buffer=False``, the concatenated results are then
# transferred back to device via ``bm.asarray``
#
# This trades off device-local concatenation (previous behavior used
# ``bm.concatenate`` on device arrays when ``clear_buffer=False``) for
# robustness to heterogeneous sharding, at the cost of an extra
# host/device round-trip for large outputs or many chunks.

if results is None:
results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values)
results = tuple([np.asarray(val)] for val in res_values)
else:
for j, val in enumerate(res_values):
results[j].append(np.asarray(val) if clear_buffer else val)
results[j].append(np.asarray(val))
if clear_buffer:
bm.clear_buffer_memory()
if res_tree is None:
return None
results = ([np.concatenate(res, axis=0) for res in results]
if clear_buffer else
[bm.concatenate(res, axis=0) for res in results])
results = [np.concatenate(res, axis=0) for res in results]
if not clear_buffer:
results = [bm.asarray(res) for res in results]
return tree_unflatten(res_tree, results)
126 changes: 126 additions & 0 deletions brainpy/running/jax_multiprocessing_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ``brainpy/running/jax_multiprocessing.py``.

Covers :func:`jax_vectorize_map` (``jax.vmap`` chunking) and
:func:`jax_parallelize_map` (``jax.pmap`` chunking), including:

- the chunked vmap path with a trailing partial chunk and with dict-form args;
- the length-mismatch guard;
- (regression for P15-H3) ``jax_parallelize_map`` with a task count that is *not*
a multiple of the device count, which produces a trailing partial chunk sharded
on a *subset* of devices. The faulty version cached one ``pmap`` and then crashed
in the closing ``bm.concatenate`` with "Received incompatible devices for jitted
computation". The multi-device sub-test runs in a subprocess (devices must be
configured before JAX initialises) and is skipped if extra host devices cannot be
spun up.
"""

import os
import subprocess
import sys

import numpy as np
import pytest

import brainpy.math as bm
from brainpy.running.jax_multiprocessing import jax_vectorize_map, jax_parallelize_map


def _double(x):
    """Return ``x`` multiplied by ``2.0`` (the per-task function used by the tests)."""
    return x * 2.0


# --------------------------------------------------------------------------- #
# jax_vectorize_map (vmap)
# --------------------------------------------------------------------------- #

def test_vectorize_map_partial_chunk():
    """``jax_vectorize_map`` returns every result when the last chunk is short."""
    # 5 tasks, chunk size 2 -> chunks of 2, 2, 1 (trailing partial chunk).
    args = [np.arange(5.0)]
    r = np.asarray(jax_vectorize_map(_double, args, num_parallel=2))
    np.testing.assert_allclose(r, np.arange(5.0) * 2.0)


def test_vectorize_map_partial_chunk_clear_buffer():
    """Same 5-task / chunk-size-2 case as above, but with ``clear_buffer=True``."""
    args = [np.arange(5.0)]
    r = np.asarray(jax_vectorize_map(_double, args, num_parallel=2, clear_buffer=True))
    np.testing.assert_allclose(r, np.arange(5.0) * 2.0)


def test_vectorize_map_dict_args():
    """``jax_vectorize_map`` accepts ``arguments`` given as a dict of keyword arrays."""

    def add(x, y):
        return x + y

    # The dict keys match the parameter names of ``add``.
    args = {'x': np.arange(4.0), 'y': np.arange(4.0) * 10}
    r = np.asarray(jax_vectorize_map(add, args, num_parallel=3))
    # x + 10 * x == 11 * x for every task.
    np.testing.assert_allclose(r, np.arange(4.0) * 11)


def test_vectorize_map_length_mismatch_raises():
    """Arguments of different lengths (4 vs. 3) must raise ``ValueError``."""
    with pytest.raises(ValueError):
        jax_vectorize_map(_double, [np.arange(4.0), np.arange(3.0)], num_parallel=2)


def test_vectorize_map_bad_arguments_type_raises():
    """Passing a bare int as ``arguments`` must raise ``TypeError``."""
    with pytest.raises(TypeError):
        jax_vectorize_map(_double, 42, num_parallel=2)


# --------------------------------------------------------------------------- #
# jax_parallelize_map (pmap)
# --------------------------------------------------------------------------- #

def test_parallelize_map_single_device():
    """``jax_parallelize_map`` with ``num_parallel=1`` maps each of 3 tasks in turn."""
    # On a single device num_parallel must be 1; chunks of size 1 each.
    args = [np.arange(3.0)]
    r = np.asarray(jax_parallelize_map(_double, args, num_parallel=1))
    np.testing.assert_allclose(r, np.arange(3.0) * 2.0)
Comment on lines +86 to +90

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Parallel pmap path lacks coverage for the clear_buffer branch.

jax_parallelize_map always gathers to NumPy and only wraps back into bm.asarray when clear_buffer=False. Current tests only exercise this clear_buffer=False path. Please add a test variant with clear_buffer=True, e.g.:

def test_parallelize_map_single_device_clear_buffer():
    args = [np.arange(3.0)]
    r = jax_parallelize_map(_double, args, num_parallel=1, clear_buffer=True)
    r = np.asarray(r)
    np.testing.assert_allclose(r, np.arange(3.0) * 2.0)

This will cover the clear-buffer path and verify the returned array type/shape for the pmap case.

Suggested change
def test_parallelize_map_single_device():
# On a single device num_parallel must be 1; chunks of size 1 each.
args = [np.arange(3.0)]
r = np.asarray(jax_parallelize_map(_double, args, num_parallel=1))
np.testing.assert_allclose(r, np.arange(3.0) * 2.0)
def test_parallelize_map_single_device():
# On a single device num_parallel must be 1; chunks of size 1 each.
args = [np.arange(3.0)]
r = np.asarray(jax_parallelize_map(_double, args, num_parallel=1))
np.testing.assert_allclose(r, np.arange(3.0) * 2.0)
def test_parallelize_map_single_device_clear_buffer():
# clear_buffer=True skips the bm.asarray wrap, so the result is already host NumPy; np.asarray only normalises it for comparison.
args = [np.arange(3.0)]
r = jax_parallelize_map(_double, args, num_parallel=1, clear_buffer=True)
r = np.asarray(r)
np.testing.assert_allclose(r, np.arange(3.0) * 2.0)



def test_parallelize_map_length_mismatch_raises():
    """Arguments of different lengths (2 vs. 1) must raise ``ValueError``."""
    with pytest.raises(ValueError):
        jax_parallelize_map(_double, [np.arange(2.0), np.arange(1.0)], num_parallel=1)


# Regression for P15-H3: trailing partial chunk across multiple devices.
# Marker the child process writes to stderr when the forced host-device count
# did not take effect. The parent matches on this explicit marker instead of on
# the text of an ``assert`` statement: the assertion *message* is only the
# integer count, and a traceback for ``python -c`` code is not guaranteed to
# echo the failing source line, so matching on ``'local_device_count'`` could
# miss and turn an environment limitation into a hard failure.
# NOTE: the same literal is embedded in ``_MULTI_DEVICE_SNIPPET``; keep in sync.
_DEVICE_COUNT_MARKER = 'BRAINPY_TEST_NEEDS_4_DEVICES'

# Program run in a fresh interpreter (devices must be configured before JAX
# initialises). Exits non-zero with the marker on stderr if 4 devices are not
# available; otherwise prints ``OK`` after checking the mapped result.
_MULTI_DEVICE_SNIPPET = r"""
import sys
import numpy as np
import jax
if jax.local_device_count() != 4:
    sys.stderr.write('BRAINPY_TEST_NEEDS_4_DEVICES got=%d\n' % jax.local_device_count())
    sys.exit(1)
from brainpy.running.jax_multiprocessing import jax_parallelize_map
# 6 tasks, num_parallel == 4 devices -> chunks of 4 then 2 (partial, subset of devices).
r = jax_parallelize_map(lambda x: x * 2.0, [np.arange(6.0)], num_parallel=4)
r = np.asarray(r)
expected = np.arange(6.0) * 2.0
assert np.allclose(r, expected), (r, expected)
print('OK')
"""


def test_parallelize_map_partial_chunk_multi_device():
    """Run the 6-task / 4-device pmap case in a subprocess and expect ``OK``.

    The child is started with ``--xla_force_host_platform_device_count=4``
    appended to ``XLA_FLAGS`` and ``JAX_PLATFORMS`` defaulting to ``cpu``.
    The test is skipped when the child reports that it could not obtain 4
    devices, and fails (with the child's stdout/stderr) on any other non-zero
    exit.
    """
    env = dict(os.environ)
    env['XLA_FLAGS'] = (env.get('XLA_FLAGS', '') + ' --xla_force_host_platform_device_count=4').strip()
    env.setdefault('JAX_PLATFORMS', 'cpu')
    proc = subprocess.run(
        [sys.executable, '-c', _MULTI_DEVICE_SNIPPET],
        env=env, capture_output=True, text=True, timeout=300,
    )
    if proc.returncode != 0:
        # Could not spin up 4 host devices in this environment -> skip rather than fail.
        if _DEVICE_COUNT_MARKER in proc.stderr:
            pytest.skip('Could not configure 4 host devices for the pmap test.')
        pytest.fail(f'multi-device pmap run failed:\nSTDOUT:\n{proc.stdout}\nSTDERR:\n{proc.stderr}')
    assert 'OK' in proc.stdout
4 changes: 2 additions & 2 deletions brainpy/running/native_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def some_func(..., lock, ...):
if isinstance(net_params, (list, tuple)):
results.append(pool.apply_async(func, args=tuple(net_params) + (lock,)))
elif isinstance(net_params, dict):
net_params.update(lock=lock)
results.append(pool.apply_async(func, kwds=net_params))
# Do not mutate the caller-owned dict; submit a copy with the lock added.
results.append(pool.apply_async(func, kwds={**net_params, 'lock': lock}))
else:
raise ValueError('Unknown parameter type: ', type(net_params))
pool.close()
Expand Down
9 changes: 9 additions & 0 deletions brainpy/running/native_multiprocessing_coverage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,12 @@ def test_process_pool_lock_with_dict_params():
def test_process_pool_lock_unknown_param_type_raises():
with pytest.raises(ValueError):
process_pool_lock(_add_lock, [42], num_process=1)


def test_process_pool_lock_does_not_mutate_caller_dict():
# P15-M1: the lock must not be injected into the caller-owned param dicts.
params = [{'x': 1, 'y': 2}]
results = process_pool_lock(_add_lock_kw, params, num_process=1)
assert results == [3]
assert params == [{'x': 1, 'y': 2}] # unchanged
assert 'lock' not in params[0]
21 changes: 17 additions & 4 deletions brainpy/running/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ def __init__(
if isinstance(jit, bool):
self.jit = {C.PREDICT_PHASE: jit}
elif isinstance(jit, dict):
# Operate on a shallow copy: never mutate the caller-owned dict. The
# original ``jit`` is also kept as ``self._origin_jit`` and read by
# subclasses (e.g. ``DSTrainer``/``BPTrainer``) which expect the
# explicit ``predict`` setting to still be present.
jit = dict(jit)
for k, v in jit.items():
self.jit[k] = v
self.jit[C.PREDICT_PHASE] = jit.pop(C.PREDICT_PHASE, True)
Expand Down Expand Up @@ -238,32 +243,40 @@ def _find_dict_monitor_targets(self, _monitors):
monitors = {}
name2node = None
for _key, _mon in _monitors.items():
if isinstance(_mon, str):
# A ``(name_str, index)`` value (produced by ``_format_dict_monitors``
# from a string monitor such as ``{'a': 'V'}`` or ``{'a': ('sub.V', 2)}``)
# must be resolved to its target Variable, exactly like the
# sequence-form resolver. ``(Variable, index)`` / callable values are
# already resolved and fall through to the ``else`` branch unchanged.
if isinstance(_mon, (tuple, list)) and isinstance(_mon[0], str):
if name2node is None:
name2node = {node.name: node for node in list(self.target.nodes(level=-1).unique().values())}

# ``_key`` is the user-chosen monitor name (e.g. 'a' for
# ``{'a': 'V'}``); the resolved (Variable, index) must be stored
# under ``_key`` so that ``runner.mon[_key]`` works.
key, index = _mon[0], _mon[1]
splits = key.split('.')
if len(splits) == 1:
if not hasattr(self.target, splits[0]):
raise RunningError(f'{self.target} does not has variable {key}.')
monitors[key] = (getattr(self.target, splits[-1]), index)
monitors[_key] = (getattr(self.target, splits[-1]), index)
else:
if not hasattr(self.target, splits[0]):
if splits[0] not in name2node:
raise MonitorError(f'Cannot find target {key} in monitor of {self.target}, please check.')
else:
master = name2node[splits[0]]
assert len(splits) == 2
monitors[key] = (getattr(master, splits[-1]), index)
monitors[_key] = (getattr(master, splits[-1]), index)
else:
master = self.target
for s in splits[:-1]:
try:
master = getattr(master, s)
except KeyError:
raise MonitorError(f'Cannot find {key} in {master}, please check.')
monitors[key] = (getattr(master, splits[-1]), index)
monitors[_key] = (getattr(master, splits[-1]), index)
else:
monitors[_key] = _mon
return monitors
Expand Down
85 changes: 48 additions & 37 deletions brainpy/running/runner_coverage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,23 @@ def test_jit_invalid_type_raises(target):
Runner(target, monitors=None, jit='not-a-jit', progress_bar=False)


def test_jit_dict_does_not_mutate_caller(target):
# P15-H1: the caller's jit dict must not be mutated (no key popped).
d = {'train': False, C.PREDICT_PHASE: False}
Runner(target, monitors=None, jit=d, progress_bar=False)
assert d == {'train': False, C.PREDICT_PHASE: False}
assert C.PREDICT_PHASE in d


def test_jit_dict_origin_jit_preserves_predict_key(target):
# P15-H1: ``_origin_jit`` keeps the explicit predict setting so that
# subclasses reading ``self._origin_jit.get('predict')`` see the user value.
d = {C.PREDICT_PHASE: False, 'fit': True}
r = Runner(target, monitors=None, jit=d, progress_bar=False)
assert r._origin_jit.get(C.PREDICT_PHASE) is False
assert r.jit[C.PREDICT_PHASE] is False


# --------------------------------------------------------------------------- #
# default / None monitors
# --------------------------------------------------------------------------- #
Expand Down Expand Up @@ -224,32 +241,34 @@ def test_dict_monitor_variable_value(target):
assert var is target.V and idx is None


def test_dict_monitor_str_value_not_resolved():
# NOTE: DEFECT - dict-form *string* monitors are NOT resolved to their
# target Variable. ``_format_dict_monitors`` wraps a string value 'V' into
# the tuple ('V', None); by the time it reaches
# ``_find_dict_monitor_targets`` the value is a tuple, so the
# ``isinstance(_mon, str)`` resolution branch (runner.py lines ~241-266) is
# never taken and the value falls through to the ``else`` branch which
# stores it verbatim. The recorded "variable" is therefore the literal
# string 'V', not ``target.V``. (Sequence-form monitors resolve correctly.)
def test_dict_monitor_str_value_resolved():
# P15-H2 (was a documented DEFECT): dict-form *string* monitors must resolve
# to their target Variable, exactly like sequence-form monitors.
target = _Target()
r = Runner(target, monitors={'a': 'V'}, progress_bar=False, jit=False)
var, idx = r._monitors['a']
assert var == 'V' # the bug: a string, not the Variable
assert var is target.V # now resolved to the Variable
assert idx is None


def test_dict_monitor_str_value_nested_not_resolved():
# NOTE: DEFECT (same root cause as above) - a dotted string value is also
# stored verbatim and never resolved to ``target.sub.V``.
def test_dict_monitor_str_value_nested_resolved():
# P15-H2: a dotted string value resolves to the nested ``target.sub.V``.
target = _Target()
r = Runner(target, monitors={'a': 'sub.V'}, progress_bar=False, jit=False)
var, idx = r._monitors['a']
assert var == 'sub.V'
assert var is target.sub.V
assert idx is None


def test_dict_monitor_str_with_index_resolved():
# P15-H2: ``(name, index)`` dict values resolve the name and keep the index.
target = _Target()
r = Runner(target, monitors={'a': ('V', 2)}, progress_bar=False, jit=False)
var, idx = r._monitors['a']
assert var is target.V
assert np.asarray(bm.as_jax(idx)).tolist() == [2]


def test_dict_monitor_var_index_tuple(target):
r = Runner(target, monitors={'a': (target.spike, 0)}, progress_bar=False, jit=False)
_, idx = r._monitors['a']
Expand All @@ -274,15 +293,12 @@ def test_dict_monitor_callable_value(target):
assert r._monitors['a'] is fn


def test_dict_monitor_str_missing_var_not_validated():
# NOTE: DEFECT (same root cause) - because dict string monitors are never
# resolved, an *invalid* variable name like 'nope' is silently accepted and
# stored verbatim instead of raising RunningError (contrast with the
# sequence-form which validates and raises). See
# ``test_dict_monitor_str_value_not_resolved``.
def test_dict_monitor_str_missing_var_raises():
# P15-H2 (was a documented DEFECT): an invalid variable name in a dict-form
# string monitor must now be validated and raise, like the sequence form.
target = _Target()
r = Runner(target, monitors={'a': 'nope'}, progress_bar=False, jit=False)
assert r._monitors['a'] == ('nope', None)
with pytest.raises(RunningError):
Runner(target, monitors={'a': 'nope'}, progress_bar=False, jit=False)


def test_dict_monitor_nonstr_key_raises(target):
Expand Down Expand Up @@ -349,29 +365,24 @@ def test_find_dict_monitor_targets_type_guard(target):
r._find_dict_monitor_targets(['not', 'a', 'dict'])


def test_find_dict_monitor_targets_bare_string_branch():
# NOTE: DEFECT - the ``isinstance(_mon, str)`` branch in
# ``_find_dict_monitor_targets`` is dead under normal use (see
# ``test_dict_monitor_str_value_not_resolved``) AND broken: it does
# ``key, index = _mon[0], _mon[1]`` which takes the first *two characters*
# of the string rather than treating the whole string as the variable name.
# We call the helper directly to exercise this branch. With the bare string
# 'Vx', key='V' (a real single-char variable) and index='x'.
def test_find_dict_monitor_targets_name_tuple_branch():
# P15-H2: the resolution branch fires for a ``(name_str, index)`` tuple
# (the form produced by ``_format_dict_monitors`` for string monitors) and
# resolves the whole name to the Variable, preserving the key and index.
target = _Target()
r = Runner(target, monitors=None, progress_bar=False, jit=False)
out = r._find_dict_monitor_targets({'k': 'Vx'})
var, index = out['V'] # NOTE: keyed by 'V' (first char), not 'k'
out = r._find_dict_monitor_targets({'k': ('V', None)})
var, index = out['k'] # keyed by 'k', not by the first char of 'V'
assert var is target.V
assert index == 'x' # NOTE: second char taken as the "index"
assert index is None


def test_find_dict_monitor_targets_bare_string_missing_var_raises():
# NOTE: DEFECT - same broken branch: a string whose first character is not
# an attribute of the target raises RunningError.
def test_find_dict_monitor_targets_name_tuple_missing_var_raises():
# P15-H2: a ``(name, index)`` tuple whose name is unknown raises RunningError.
target = _Target()
r = Runner(target, monitors=None, progress_bar=False, jit=False)
with pytest.raises(RunningError):
r._find_dict_monitor_targets({'k': 'zz'})
r._find_dict_monitor_targets({'k': ('zz', None)})


# --------------------------------------------------------------------------- #
Expand Down
Loading
Loading