Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions brainpy/algorithms/offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,11 @@ def call(self, targets, inputs, outputs=None) -> ArrayType:
raise ValueError(f'Target must be a scalar, but got multiple variables: {targets.shape}. ')
targets = targets.flatten()

# initialize parameters
param = self.init_weights(inputs.shape[1], targets.shape[1])
# Initialize a 1-D parameter vector to match the flattened 1-D
# ``targets`` used below. (``init_weights`` returns a 2-D ``(n_features,
# n_out)`` array; reading ``targets.shape[1]`` after the flatten raised
# IndexError, so request a single output and squeeze to 1-D.)
param = self.init_weights(inputs.shape[1], 1).flatten()

def cond_fun(a):
i, par_old, par_new = a
Expand Down Expand Up @@ -538,8 +541,9 @@ def call(self, targets, inputs, outputs=None):
# checking
inputs = _check_data_2d_atls(bm.as_jax(inputs))
targets = _check_data_2d_atls(bm.as_jax(targets))
# solving
inputs = normalize(polynomial_features(inputs, degree=self.degree))
# solving. ``add_bias`` must match ``predict`` so the fitted weight
# length equals the feature width seen at prediction time.
inputs = normalize(polynomial_features(inputs, degree=self.degree, add_bias=self.add_bias))
return super(ElasticNetRegression, self).gradient_descent_solve(targets, inputs)

def predict(self, W, X):
Expand Down
52 changes: 34 additions & 18 deletions brainpy/algorithms/offline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,18 +154,24 @@ def test_fit_and_predict(self):


class TestLogisticRegression:
def test_call_is_currently_broken(self):
# NOTE (defect): LogisticRegression.call flattens ``targets`` to 1-D
# (offline.py line ~386) and then immediately reads ``targets.shape[1]``
# at line ~389, which raises IndexError. The closed-form (non gradient
# descent) branch is also unreachable because of this. The fit path is
# therefore broken for both gradient_descent=True and =False.
def test_call_runs_and_separates(self):
# P16-C2 (was test_call_is_currently_broken): ``LogisticRegression.call``
# used to crash with IndexError because it flattened ``targets`` to 1-D
# and then read ``targets.shape[1]``. After the fix it must run and learn
# a usable separator on a trivially separable problem.
rng = np.random.RandomState(1)
x = rng.uniform(-1, 1, size=(30, 2)).astype(np.float32)
y = (x[:, :1] > 0).astype(np.float32)
algo = offline.LogisticRegression(max_iter=50, learning_rate=0.1)
with pytest.raises(IndexError):
algo(bm.asarray(y), bm.asarray(x))
algo = offline.LogisticRegression(max_iter=200, learning_rate=0.5)
w = algo(bm.asarray(y), bm.asarray(x))
w = np.asarray(bm.as_jax(w))
# one weight per input feature (1-D parameter vector)
assert w.reshape(-1).shape == (2,)
assert np.all(np.isfinite(w))
# predictions should mostly agree with the (separable) labels
pred = np.asarray(bm.as_jax(algo.predict(bm.asarray(w), bm.asarray(x))))
acc = np.mean((pred.reshape(-1) > 0.5) == y.reshape(-1))
assert acc >= 0.8

def test_predict_applies_sigmoid(self):
# ``predict`` itself works in isolation (it does not hit the broken call).
Expand Down Expand Up @@ -209,19 +215,29 @@ def test_elastic_net_regression_fit(self):
w = np.asarray(bm.as_jax(algo(y, x)))
assert np.all(np.isfinite(w))

def test_elastic_net_predict_bias_mismatch_is_broken(self):
# NOTE (defect): ElasticNetRegression.call builds features with
# ``polynomial_features(inputs, degree=self.degree)`` which defaults to
# add_bias=True, while ``predict`` calls it with add_bias=self.add_bias
# (default False). The resulting feature width differs from the fitted
# weight length, so predicting on freshly-built features raises a
# shape-mismatch TypeError from jnp.dot.
def test_elastic_net_train_predict_consistent(self):
# P16-H1 (was test_elastic_net_predict_bias_mismatch_is_broken):
# ``call`` used to build features with the default add_bias=True while
# ``predict`` used add_bias=self.add_bias (default False), giving a
# train/predict feature-count mismatch that crashed jnp.dot. After the
# fix, training and prediction must use identical feature construction.
x, y = _xy(slope=2.0)
algo = offline.ElasticNetRegression(alpha=0.01, degree=2, l1_ratio=0.5,
max_iter=50, learning_rate=0.001)
w = bm.asarray(np.asarray(bm.as_jax(algo(y, x))))
with pytest.raises(TypeError):
algo.predict(w, x)
pred = np.asarray(bm.as_jax(algo.predict(w, x)))
assert pred.shape[0] == np.asarray(bm.as_jax(x)).shape[0]
assert np.all(np.isfinite(pred))

def test_elastic_net_add_bias_true_consistent(self):
# The fix must also hold when add_bias=True is requested explicitly.
x, y = _xy(slope=2.0)
algo = offline.ElasticNetRegression(alpha=0.01, degree=2, l1_ratio=0.5,
add_bias=True, max_iter=50,
learning_rate=0.001)
w = bm.asarray(np.asarray(bm.as_jax(algo(y, x))))
pred = np.asarray(bm.as_jax(algo.predict(w, x)))
assert np.all(np.isfinite(pred))


class TestRegistry:
Expand Down
6 changes: 5 additions & 1 deletion brainpy/algorithms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ def polynomial_features(X, degree: int, add_bias: bool = True):
return bm.insert(X, 0, 1, axis=1) if add_bias else X
if add_bias:
n_features += 1
X_new = bm.zeros((n_samples, 1 + n_features + len(combinations)))
# ``n_features`` already accounts for the bias slot (when ``add_bias``); the
# design matrix is exactly the (bias +) original features plus the degree>=2
# interaction terms. The previous extra leading ``1 +`` left a dead all-zero
# trailing column and over-counted the feature dimension by one.
X_new = bm.zeros((n_samples, n_features + len(combinations)))
if add_bias:
X_new[:, 0] = 1
X_new[:, 1:n_features] = X
Expand Down
18 changes: 9 additions & 9 deletions brainpy/algorithms/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,25 +97,25 @@ def test_degree1_shortcircuit_without_bias(self):
def test_degree2_with_bias(self):
X = bm.asarray([[2.0, 3.0]])
out = np.asarray(bm.as_jax(utils.polynomial_features(X, degree=2, add_bias=True)))
# X_new has shape (n_samples, 1 + n_features + len(combinations)).
# With add_bias, n_features is bumped to 3 (2 + 1), combos == 3, so the
# allocated width is 1 + 3 + 3 == 7. NOTE: the leading "1 +" plus the
# bumped n_features leaves one all-zero trailing column unused, i.e. the
# output is wider than the mathematically expected 6 columns.
assert out.shape == (1, 7)
# P16-M2: width is exactly 1 bias + 2 linear + 3 interaction == 6
# (previously a dead all-zero trailing column made it 7).
assert out.shape == (1, 6)
assert out[0, 0] == 1.0
# the linear features should appear
assert 2.0 in out[0] and 3.0 in out[0]
# interactions: 4 (=2^2), 6 (=2*3), 9 (=3^2)
for v in (4.0, 6.0, 9.0):
assert np.any(np.isclose(out[0], v))
# no dead all-zero column anymore
assert not np.any(np.all(out == 0, axis=0))

def test_degree2_without_bias(self):
X = bm.asarray([[2.0, 3.0]])
out = np.asarray(bm.as_jax(utils.polynomial_features(X, degree=2, add_bias=False)))
# 1 + 2 linear + 3 interaction -> 6 cols (again one extra leading column;
# see NOTE in test_degree2_with_bias).
assert out.shape == (1, 6)
# P16-M2: 2 linear + 3 interaction -> 5 cols (previously 6 with a dead
# leading allocation slot).
assert out.shape == (1, 5)
assert not np.any(np.all(out == 0, axis=0))


class TestNormalize:
Expand Down
9 changes: 7 additions & 2 deletions brainpy/connect/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,15 +689,20 @@ def coo2csr(coo, num_pre):
post_ids = post_ids[sort_ids]
indices = post_ids
unique_pre_ids, pre_count = onp.unique(pre_ids, return_counts=True)
final_pre_count = onp.zeros(num_pre, dtype=jnp.uint32)
# Use the connection index dtype (not uint32) so the assignment below does
# not trigger an int->uint scatter dtype mismatch.
final_pre_count = onp.zeros(num_pre, dtype=get_idx_type())
final_pre_count[unique_pre_ids] = pre_count
else:
sort_ids = onp.argsort(bm.as_jax(pre_ids))
post_ids = bm.as_jax(post_ids)
post_ids = post_ids[sort_ids]
indices = post_ids
unique_pre_ids, pre_count = jnp.unique(pre_ids, return_counts=True)
final_pre_count = bm.zeros(num_pre, dtype=jnp.uint32)
# Use the connection index dtype (not uint32) so the in-place update below
# does not trigger an int->uint scatter dtype mismatch (FutureWarning that
# becomes an error in future JAX releases).
final_pre_count = bm.zeros(num_pre, dtype=get_idx_type())
final_pre_count[unique_pre_ids] = pre_count
final_pre_count = bm.as_jax(final_pre_count)
indptr = final_pre_count.cumsum()
Expand Down
2 changes: 1 addition & 1 deletion brainpy/connect/custom_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(self, indices, inptr, **kwargs):
self.max_post = self.indices.max()

def build_csr(self):
if self.pre_num != self.pre_num:
if self.pre_num != self.inptr.size - 1:
raise ConnectorError(f'(pre_size, post_size) is inconsistent with '
f'the shape of the sparse matrix.')
if self.post_num <= self.max_post:
Expand Down
39 changes: 39 additions & 0 deletions brainpy/connect/custom_conn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,45 @@ def test_MatConn2(self):
conn(pre_size=5, post_size=1)


class TestCSRConn(TestCase):
def _csr(self):
# 3 pre-synaptic neurons (indptr has 4 entries), max post id == 2
indices = np.array([0, 1, 2, 0, 1], dtype=np.int32)
indptr = np.array([0, 2, 3, 5], dtype=np.int32)
return indices, indptr

def test_csrconn_consistent_ok(self):
# P16-H2: a CSRConn whose declared pre_size matches the indptr length
# must build without error.
indices, indptr = self._csr()
conn = bp.conn.CSRConn(indices, indptr)
ind, ip = conn.require(3, 3, 'csr')
assert np.array_equal(np.asarray(ind), indices)
assert np.array_equal(np.asarray(ip), indptr)

def test_csrconn_inconsistent_pre_num_raises(self):
# P16-H2: previously the guard ``self.pre_num != self.pre_num`` was a
# tautology (always False), so an inconsistent pre_size silently produced
# a malformed CSR. It must now raise.
indices, indptr = self._csr() # indptr implies 3 pre
conn = bp.conn.CSRConn(indices, indptr)
with pytest.raises(bp.errors.ConnectorError):
conn.require(5, 3, 'csr') # pre=5 inconsistent with indptr (3)

def test_coo2csr_no_dtype_warning(self):
# P16-M1: coo2csr must not emit the int32->uint32 scatter FutureWarning
# (which is slated to become an error in future JAX).
import warnings
import jax.numpy as jnp
from brainpy.connect.base import coo2csr
pre = jnp.array([0, 0, 1, 2, 2, 2])
post = jnp.array([1, 2, 0, 0, 1, 2])
with warnings.catch_warnings():
warnings.simplefilter('error', FutureWarning)
ind, indptr = coo2csr((pre, post), 3)
assert np.array_equal(np.asarray(indptr), np.array([0, 2, 3, 6]))


class TestSparseMatConn(TestCase):
def test_sparseMatConn(self):
conn_mat = np.random.randint(2, size=(5, 3), dtype=bp.math.bool_)
Expand Down
4 changes: 2 additions & 2 deletions brainpy/inputs/currents.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def spike_current(*args, **kwargs):
warnings.warn('Please use "brainpy.inputs.spike_input()" instead. '
'"brainpy.inputs.spike_current()" is deprecated since version 2.1.13.',
DeprecationWarning)
return constant_input(*args, **kwargs)
return spike_input(*args, **kwargs)


def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None):
Expand Down Expand Up @@ -189,7 +189,7 @@ def ramp_current(*args, **kwargs):
warnings.warn('Please use "brainpy.inputs.ramp_input()" instead. '
'"brainpy.inputs.ramp_current()" is deprecated since version 2.1.13.',
DeprecationWarning)
return constant_input(*args, **kwargs)
return ramp_input(*args, **kwargs)


def wiener_process(duration, dt=None, n=1, t_start=0., t_end=None, seed=None):
Expand Down
20 changes: 11 additions & 9 deletions brainpy/inputs/currents_coverage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,27 @@ def test_constant_current_warns_and_delegates(self):
self.assertEqual(duration, 200)

def test_spike_current_warns(self):
# NOTE: ``spike_current`` is documented as a spike-input shim but its
# body actually delegates to ``constant_input`` (copy/paste defect in
# the deprecation wrapper), so it must be fed constant_input-style
# ``[(value, duration), ...]`` pairs rather than spike arguments.
# P16-C1: ``spike_current`` now correctly delegates to ``spike_input``
# (it previously delegated to ``constant_input`` and crashed on spike
# arguments). It must warn AND accept spike-style arguments.
kwargs = dict(sp_times=[10, 20, 30], sp_lens=1., sp_sizes=0.5, duration=40.)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
out = bp.inputs.spike_current([(0, 50), (1, 50)])
out = bp.inputs.spike_current(**kwargs)
self.assertTrue(any(issubclass(x.category, DeprecationWarning) for x in w))
self.assertIsNotNone(out)
self.assertTrue(np.array_equal(np.asarray(out), np.asarray(bp.inputs.spike_input(**kwargs))))

def test_ramp_current_warns(self):
# NOTE: like ``spike_current``, ``ramp_current`` also delegates to
# ``constant_input`` rather than ``ramp_input`` (same wrapper defect),
# so it expects ``[(value, duration), ...]`` pairs.
# P16-C1: ``ramp_current`` now correctly delegates to ``ramp_input``
# (it previously delegated to ``constant_input``). It must warn AND
# accept ramp-style arguments.
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
out = bp.inputs.ramp_current([(0, 50), (1, 50)])
out = bp.inputs.ramp_current(0., 1., 100.)
self.assertTrue(any(issubclass(x.category, DeprecationWarning) for x in w))
self.assertIsNotNone(out)
self.assertTrue(np.array_equal(np.asarray(out), np.asarray(bp.inputs.ramp_input(0., 1., 100.))))


@unittest.skipUnless(HAS_BRAINUNIT, 'brainunit required for unit-aware inputs')
Expand Down
21 changes: 21 additions & 0 deletions brainpy/inputs/currents_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,24 @@ def test_general2(self):
bp.math.random.random((3, 10))],
durations=[100, 300, 100])
self.assertTrue(current.shape == (5000, 3, 10))

def test_spike_current_alias(self):
# P16-C1: the deprecated ``spike_current`` alias must forward to
# ``spike_input`` (it used to forward to ``constant_input`` and crash).
import warnings
kwargs = dict(sp_times=[10, 20, 30], sp_lens=1., sp_sizes=0.5, duration=40.)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
aliased = bp.inputs.spike_current(**kwargs)
direct = bp.inputs.spike_input(**kwargs)
self.assertTrue(np.array_equal(np.asarray(aliased), np.asarray(direct)))

def test_ramp_current_alias(self):
# P16-C1: the deprecated ``ramp_current`` alias must forward to
# ``ramp_input`` (it used to forward to ``constant_input`` and crash).
import warnings
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
aliased = bp.inputs.ramp_current(0., 1., 100.)
direct = bp.inputs.ramp_input(0., 1., 100.)
self.assertTrue(np.array_equal(np.asarray(aliased), np.asarray(direct)))
Loading
Loading