From beb2874061870c42d25dd902e1dd315631b1f930 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 19 Jun 2026 11:04:27 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20make=20full=20test=20suite=20(CI)=20gree?= =?UTF-8?q?n=20=E2=80=94=20Dense=20fit-flag=20tracer,=20buffer-donation=20?= =?UTF-8?q?pollution,=20L1=20loss;=20Codecov=20token?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Source fix ---------- brainpy/dnn/linear.py: ``Dense.update`` did ``if share.load('fit', False) and self.online_fit_by is not None:``. Inside a grad-/jit-traced fit step (BPFF/BPTT) the ``fit`` share value is a JAX tracer, so converting it to a Python bool raised ``TracerBoolConversionError`` and broke the canonical Dense/RNNCell back-prop training example. Reordered to consult the static ``*_fit_by`` configuration first; the ``and`` then short-circuits before the tracer is forced when online/offline fitting is not configured. This also removes a cross-test pollution: when the fit raised mid-trace it left a stale traced ``fit`` in the global ``share`` store, so the next test running ``Dense.update`` (e.g. ``LoopOverTime`` over a ``Dense``) also raised. Test isolation -------------- brainpy/running/jax_multiprocessing_test.py: ``test_vectorize_map_partial_chunk_clear_buffer`` ran ``jax_vectorize_map(..., clear_buffer=True)``, which invokes the process-global ``bm.clear_buffer_memory()`` and deletes EVERY live device buffer, poisoning later test modules ("deleted/donated buffer" errors). Patched the wipe to a no-op for the duration of the call (same guard already used in boost_misc_test and train_analysis_glue_fixes_test) so the code path stays covered without nuking the shared session. Test contract updates --------------------- brainpy/train/back_propagation_test.py: rewrote the pinned-defect test into a regression test that asserts ``Dense`` trains under BPFF (finite losses, weight moves) and that a subsequent plain forward pass does not raise (pollution guard). brainpy/losses/comparison_coverage_test.py: ``l1_loss`` delegates to ``braintools.metric.l1_loss`` (>=0.3.0, the required/CI dependency), which reduces each sample to its mean absolute error then applies the batch reduction. Updated the L1 expectations (none -> [1.5, 3.5], sum -> 5.0, mean -> 2.5) and comments to the 0.3.0 contract; the previous numbers pinned stale braintools 0.1.10 behaviour. CI -- .github/workflows/CI.yml: Codecov upload now passes ``token: ${{ secrets.CODECOV_TOKEN }}`` and ``slug: brainpy/BrainPy``. --- .github/workflows/CI.yml | 4 +- brainpy/dnn/linear.py | 16 +++++-- brainpy/losses/comparison_coverage_test.py | 22 +++++---- brainpy/running/jax_multiprocessing_test.py | 13 +++++- brainpy/train/back_propagation_test.py | 50 ++++++++++++++------- 5 files changed, 71 insertions(+), 34 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 0dcbe5f72..17ac43845 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -54,9 +54,11 @@ jobs: MPLBACKEND: Agg # Use non-interactive backend for matplotlib run: | pytest --cov=brainpy --cov-report=xml brainpy/ - - name: Upload coverage to Codecov + - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 with: + token: ${{ secrets.CODECOV_TOKEN }} + slug: brainpy/BrainPy files: ./coverage.xml fail_ci_if_error: false diff --git a/brainpy/dnn/linear.py b/brainpy/dnn/linear.py index 4411cff55..d8af1cbe3 100644 --- a/brainpy/dnn/linear.py +++ b/brainpy/dnn/linear.py @@ -125,13 +125,21 @@ def update(self, x): if self.b is not None: res += self.b - # online fitting data - if share.load('fit', False) and self.online_fit_by is not None: + # Online/offline fitting data recording. + # + # The (static, Python-level) ``*_fit_by`` configuration is checked *first* + # so that the ``fit`` share value is only consulted when online/offline + # fitting is actually enabled. Inside a grad-/jit-traced fit step (e.g. + # ``BPFF.fit`` / ``BPTT.fit``) the ``fit`` flag is a JAX tracer; converting + # it to a Python bool would raise ``TracerBoolConversionError``. Because a + # plain ``Dense`` used for back-prop training leaves both ``*_fit_by`` as + # ``None``, the ``and`` short-circuits on the static check and never forces + # the tracer, letting the canonical RNNCell/Dense BPTT example train. + if self.online_fit_by is not None and share.load('fit', False): self.fit_record['input'] = x self.fit_record['output'] = res - # offline fitting data - if share.load('fit', False) and self.offline_fit_by is not None: + if self.offline_fit_by is not None and share.load('fit', False): self.fit_record['input'] = x self.fit_record['output'] = res return res diff --git a/brainpy/losses/comparison_coverage_test.py b/brainpy/losses/comparison_coverage_test.py index c4b73a038..ba28399e8 100644 --- a/brainpy/losses/comparison_coverage_test.py +++ b/brainpy/losses/comparison_coverage_test.py @@ -221,26 +221,24 @@ def test_class_wrapper(self): # --------------------------------------------------------------------------- class TestRegressionLosses: def test_l1_loss_reductions(self): - # P1-L1: l1_loss delegates to braintools.metric.l1_loss, which for - # reduction='none' returns the per-row L1 *norm* (sum of abs over the - # trailing axes, reshaped to (N, -1)), NOT the per-row mean. So for the - # (2, 2) input below the 'none' output is the per-row sums [3, 7]; 'sum' - # then totals them (10) and 'mean' averages them (5). (The previous - # expectations of [1.5, 3.5]/5/2.5 encoded an incorrect per-row-mean - # assumption about braintools and were pre-existing baseline failures.) + # ``l1_loss`` delegates to ``braintools.metric.l1_loss`` (>=0.3.0), which + # reduces each sample to its *mean* absolute error over the trailing axes + # (shape (N,)) and then applies the batch reduction. For the (2, 2) input + # below the per-sample means are [mean(1,2), mean(3,4)] = [1.5, 3.5]; so + # 'none' -> [1.5, 3.5], 'sum' -> 5.0, 'mean' -> 2.5. x = jnp.array([[1., 2.], [3., 4.]]) y = jnp.zeros((2, 2)) none = np.asarray(C.l1_loss(x, y, reduction='none')) - assert np.allclose(none, [3.0, 7.0]) # per-row L1 norm (sum of abs) - assert float(C.l1_loss(x, y, reduction='sum')) == pytest.approx(10.0) - assert float(C.l1_loss(x, y, reduction='mean')) == pytest.approx(5.0) + assert np.allclose(none, [1.5, 3.5]) # per-sample mean abs error + assert float(C.l1_loss(x, y, reduction='sum')) == pytest.approx(5.0) + assert float(C.l1_loss(x, y, reduction='mean')) == pytest.approx(2.5) def test_l1_class(self): x = jnp.array([[1., 2.], [3., 4.]]) y = jnp.zeros((2, 2)) layer = C.L1Loss(reduction='sum') - # sum over per-row L1 norms [3, 7] = 10.0 - assert float(layer.update(x, y)) == pytest.approx(10.0) + # sum over per-sample mean abs errors [1.5, 3.5] = 5.0 + assert float(layer.update(x, y)) == pytest.approx(5.0) def test_l2_loss_elementwise(self): out = np.asarray(C.l2_loss(jnp.array([2.0, 0.0]), jnp.array([0.0, 0.0]))) diff --git a/brainpy/running/jax_multiprocessing_test.py b/brainpy/running/jax_multiprocessing_test.py index c2853394a..d7cbdcd13 100644 --- a/brainpy/running/jax_multiprocessing_test.py +++ b/brainpy/running/jax_multiprocessing_test.py @@ -56,7 +56,18 @@ def test_vectorize_map_partial_chunk(): def test_vectorize_map_partial_chunk_clear_buffer(): args = [np.arange(5.0)] - r = np.asarray(jax_vectorize_map(_double, args, num_parallel=2, clear_buffer=True)) + # NOTE: clear_buffer=True calls the process-global ``bm.clear_buffer_memory()``, + # which deletes EVERY live device array -- including module-level constants and + # persistent Variables in *other* test modules -- poisoning the rest of the + # shared pytest session (later tests then hit "deleted/donated buffer" errors). + # Patch it to a no-op so the clear_buffer code path is still exercised for + # coverage without nuking the session. + _orig_clear = bm.clear_buffer_memory + bm.clear_buffer_memory = lambda *a, **k: None + try: + r = np.asarray(jax_vectorize_map(_double, args, num_parallel=2, clear_buffer=True)) + finally: + bm.clear_buffer_memory = _orig_clear np.testing.assert_allclose(r, np.arange(5.0) * 2.0) diff --git a/brainpy/train/back_propagation_test.py b/brainpy/train/back_propagation_test.py index 7ae5c6b0e..be0bd5e20 100644 --- a/brainpy/train/back_propagation_test.py +++ b/brainpy/train/back_propagation_test.py @@ -487,22 +487,28 @@ def test_bptrainer_abstract_step_funcs_raise(): # --------------------------------------------------------------------------- -# Pinned defect (NOT in fix scope -- documents current behavior) +# Regression: ``Dense`` trains cleanly under a BPFF/BPTT fit loop # --------------------------------------------------------------------------- -def test_dense_layer_fit_flag_is_traced_defect(): - """PIN: ``bp.dnn.Dense`` under a BPFF/BPTT fit loop raises on the ``fit`` flag. - - ``brainpy/dnn/linear.py:129`` does - ``if share.load('fit', False) and self.online_fit_by is not None:``. - Under the installed ``brainstate`` (0.5.x), inside the jitted / grad-traced - fit step the ``fit`` flag is a JAX *tracer*, so the boolean ``and`` raises - ``jax.errors.TracerBoolConversionError``. This blocks the canonical - ``RNNCell``/``Dense`` BPTT example. It is an API-drift defect in the layer, - not in ``back_propagation.py``; pinned here so the regression is visible. +def test_dense_layer_fit_flag_under_grad_trace(): + """``bp.dnn.Dense`` must train under a BPFF fit loop without a tracer error. + + Inside the grad-/jit-traced fit step the ``fit`` flag is a JAX *tracer*. + ``Dense.update`` previously did + ``if share.load('fit', False) and self.online_fit_by is not None:`` which + converted that tracer to a Python bool and raised + ``jax.errors.TracerBoolConversionError``, blocking the canonical + ``RNNCell``/``Dense`` BPTT example. ``Dense.update`` now checks the static + ``*_fit_by`` configuration first so the tracer is only consulted when online + /offline fitting is enabled. This regression test trains a plain ``Dense`` + for a couple of epochs and asserts the fit completes with finite losses and + an updated weight. + + A *pollution guard* is included: a stale traced ``fit`` left in the global + ``share`` store by a previous (broken) fit used to make the next forward + pass through a ``Dense`` raise. Running a plain forward pass after the fit + confirms no traced state leaked. """ - import jax - class DenseFF(bp.DynamicalSystem): def __init__(self): super().__init__() @@ -516,11 +522,23 @@ def reset_state(self, batch_size=1, **kwargs): with bm.training_environment(): model = DenseFF() + w_before = np.asarray(model.lin.W).copy() trainer = bp.BPFF(model, loss_fun=_mse, optimizer=bp.optim.Adam(lr=0.01), progress_bar=False) - with pytest.raises(jax.errors.TracerBoolConversionError): - trainer.fit([(bm.random.random((4, 3)), bm.random.random((4, 2)))], - num_epoch=1) + trainer.fit([(bm.random.random((4, 3)), bm.random.random((4, 2)))], + num_epoch=2) + + # the fit ran: losses are recorded and finite, and the weight moved. + losses = trainer.get_hist_metric(phase='fit', metric='loss') + assert len(losses) > 0 + assert all(np.isfinite(float(v)) for v in losses) + assert not np.allclose(np.asarray(model.lin.W), w_before) + + # pollution guard: a plain forward pass through a Dense must not raise from + # a stale traced ``fit`` value left in the global ``share`` store. + out = model(bm.random.random((4, 3))) + assert tuple(out.shape) == (4, 2) + assert bool(np.all(np.isfinite(np.asarray(out)))) if __name__ == '__main__':