Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions brainpy/dyn/rates/rnncells.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ def __init__(
self.state[:] = self.state2train

def reset_state(self, batch_or_mode=None, **kwargs):
# Accept ``batch_size`` as an alias for ``batch_or_mode`` so the canonical
# ``model.reset(batch_size=...)`` convention (used by ``brainpy.dyn``
# neurons and the training tutorials) works on recurrent cells too.
if batch_or_mode is None:
batch_or_mode = kwargs.get('batch_size', None)
self.state.value = variable(self._state_initializer, batch_or_mode, self.num_out)
if self.train_state:
self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False)
Expand Down Expand Up @@ -236,6 +241,11 @@ def __init__(
self.state[:] = self.state2train

def reset_state(self, batch_or_mode=None, **kwargs):
# Accept ``batch_size`` as an alias for ``batch_or_mode`` so the canonical
# ``model.reset(batch_size=...)`` convention (used by ``brainpy.dyn``
# neurons and the training tutorials) works on recurrent cells too.
if batch_or_mode is None:
batch_or_mode = kwargs.get('batch_size', None)
self.state.value = variable(self._state_initializer, batch_or_mode, self.num_out)
if self.train_state:
self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False)
Expand Down Expand Up @@ -371,6 +381,8 @@ def __init__(
self.state[:] = self.state2train

def reset_state(self, batch_or_mode=None, **kwargs):
if batch_or_mode is None:
batch_or_mode = kwargs.get('batch_size', None)
self.state.value = variable(self._state_initializer, batch_or_mode, self.num_out * 2)
if self.train_state:
self.state2train.value = parameter(self._state_initializer, self.num_out * 2, allow_none=False)
Expand Down Expand Up @@ -522,6 +534,8 @@ def __init__(
self.reset_state()

def reset_state(self, batch_or_mode: int = 1, **kwargs):
if 'batch_size' in kwargs and kwargs['batch_size'] is not None:
batch_or_mode = kwargs['batch_size']
if self.mode.is_a(bm.NonBatchingMode):
shape = self.input_shape + (self.out_channels,)
self.h = variable_(self._state_initializer, shape)
Expand Down
22 changes: 22 additions & 0 deletions brainpy/dyn/rates/rnncells_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,28 @@ def test_rnn_cells_reset_state_int_batch(self, cls):
expected = (2, 8) if cls == 'LSTMCell' else (2, 4)
self.assertTupleEqual(tuple(cell.state.shape), expected)

@parameterized.product(cls=['RNNCell', 'GRUCell', 'LSTMCell'])
def test_rnn_cells_reset_batch_size_kwarg(self, cls):
# Regression (build_training_models tutorial): ``reset(batch_size=...)``
# — the canonical convention shared with ``brainpy.dyn`` neurons — must
# add a leading batch axis on recurrent cells. Previously the cells only
# understood ``batch_or_mode`` and silently dropped ``batch_size``,
# resetting the state to an unbatched shape and raising a MathError.
bm.random.seed()
cell = getattr(bp.dyn, cls)(num_in=3, num_out=4, mode=bm.batching_mode)
cell.reset(batch_size=5)
expected = (5, 8) if cls == 'LSTMCell' else (5, 4)
self.assertTupleEqual(tuple(cell.state.shape), expected)

@parameterized.product(cls=['RNNCell', 'GRUCell', 'LSTMCell'])
def test_rnn_cells_reset_batch_or_mode_kwarg_still_works(self, cls):
# Back-compat: the original ``batch_or_mode`` keyword must keep working.
bm.random.seed()
cell = getattr(bp.dyn, cls)(num_in=3, num_out=4, mode=bm.batching_mode)
cell.reset_state(batch_or_mode=3)
expected = (3, 8) if cls == 'LSTMCell' else (3, 4)
self.assertTupleEqual(tuple(cell.state.shape), expected)


if __name__ == '__main__':
absltest.main()
38 changes: 34 additions & 4 deletions brainpy/inputs/currents.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import brainstate
import braintools
import brainunit as u

import brainpy.math

Expand All @@ -35,6 +36,31 @@
]


def _as_hz(frequency):
"""Attach ``Hz`` units to a bare numeric frequency.

``braintools.input.sinusoidal`` / ``braintools.input.square`` require the
frequency to carry frequency (``Hz``) units; a plain number is treated as
dimensionless and rejected. A value that is already a
:class:`brainunit.Quantity` is returned unchanged.
"""
return frequency if isinstance(frequency, u.Quantity) else frequency * u.Hz


def _as_ms(value):
"""Attach ``ms`` units to a bare numeric time value.

Once ``dt`` carries time units the waveform helpers convert the oscillation
frequency against that time unit, so the remaining time arguments
(``dt``/``duration``/``t_start``/``t_end``) must be unit-carrying as well.
``None`` passes through and existing :class:`brainunit.Quantity` values are
left untouched.
"""
if value is None:
return None
return value if isinstance(value, u.Quantity) else value * u.ms


def section_input(values, durations, dt=None, return_length=False):
"""Format an input current with different sections.

Expand Down Expand Up @@ -279,8 +305,10 @@ def sinusoidal_input(amplitude, frequency, duration, dt=None, t_start=0., t_end=
Whether the sinusoid oscillates around 0 (False), or
has a positive DC bias, thus non-negative (True).
"""
with brainstate.environ.context(dt=brainpy.math.get_dt() if dt is None else dt):
return braintools.input.sinusoidal(amplitude, frequency, duration, t_start=t_start, t_end=t_end, bias=bias)
dt = brainpy.math.get_dt() if dt is None else dt
with brainstate.environ.context(dt=_as_ms(dt)):
return braintools.input.sinusoidal(amplitude, _as_hz(frequency), _as_ms(duration),
t_start=_as_ms(t_start), t_end=_as_ms(t_end), bias=bias)


def square_input(amplitude, frequency, duration, dt=None, bias=False, t_start=0., t_end=None):
Expand All @@ -305,6 +333,8 @@ def square_input(amplitude, frequency, duration, dt=None, bias=False, t_start=0.
Whether the sinusoid oscillates around 0 (False), or
has a positive DC bias, thus non-negative (True).
"""
with brainstate.environ.context(dt=brainpy.math.get_dt() if dt is None else dt):
return braintools.input.square(amplitude, frequency, duration, t_start=t_start, t_end=t_end, duty_cycle=0.5,
dt = brainpy.math.get_dt() if dt is None else dt
with brainstate.environ.context(dt=_as_ms(dt)):
return braintools.input.square(amplitude, _as_hz(frequency), _as_ms(duration),
t_start=_as_ms(t_start), t_end=_as_ms(t_end), duty_cycle=0.5,
bias=bias)
48 changes: 37 additions & 11 deletions brainpy/inputs/currents_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# ==============================================================================
from unittest import TestCase

import brainunit as u
import numpy as np

import brainpy as bp
Expand Down Expand Up @@ -81,17 +82,42 @@ def test_ou_process(self):
current7 = bp.inputs.ou_process(mean=1., sigma=0.1, tau=10., duration=duration, n=2, t_start=10., t_end=180.)
show(current7, duration, 'Ornstein-Uhlenbeck Process')

# def test_sinusoidal_input(self):
# duration = 2000 * u.ms
# current8 = bp.inputs.sinusoidal_input(amplitude=1., frequency=2.0 * u.Hz,
# duration=duration, t_start=100. * u.ms, dt=0.1 * u.ms)
# show(current8, duration, 'Sinusoidal Input')
#
# def test_square_input(self):
# duration = 2000 * u.ms
# current9 = bp.inputs.square_input(amplitude=1., frequency=2.0 * u.Hz,
# duration=duration, t_start=100 * u.ms, dt=0.1 * u.ms)
# show(current9, duration, 'Square Input')
def test_sinusoidal_input_bare_frequency(self):
# Regression: a bare numeric ``frequency`` (in Hz) must be accepted.
# ``braintools`` started requiring frequency/time arguments to carry
# units; the wrapper now attaches ``Hz``/``ms`` so the documented plain
# ``frequency=2.0`` call keeps working instead of raising
# ``AssertionError: Frequency must be in Hz``.
duration = 2000
current8 = bp.inputs.sinusoidal_input(amplitude=1., frequency=2.0,
duration=duration, t_start=100., dt=0.1)
current8 = np.asarray(current8)
self.assertEqual(current8.shape[0], int(duration / 0.1))
# amplitude 1 -> values bounded in [-1, 1]; current is zero before t_start
self.assertLessEqual(float(np.max(np.abs(current8))), 1.0 + 1e-5)
self.assertTrue(np.allclose(current8[:int(100 / 0.1)], 0.))
show(current8, duration, 'Sinusoidal Input')

def test_square_input_bare_frequency(self):
# Regression: same contract for ``square_input``.
duration = 2000
current9 = bp.inputs.square_input(amplitude=1., frequency=2.0,
duration=duration, t_start=100., dt=0.1)
current9 = np.asarray(current9)
self.assertEqual(current9.shape[0], int(duration / 0.1))
self.assertLessEqual(float(np.max(np.abs(current9))), 1.0 + 1e-5)
show(current9, duration, 'Square Input')

def test_sinusoidal_input_quantity_frequency(self):
# The unit-carrying form (``frequency=2 * u.Hz``, ``duration=... * u.ms``)
# must produce the same waveform as the bare-number form.
bare = np.asarray(bp.inputs.sinusoidal_input(amplitude=1., frequency=2.0,
duration=2000, t_start=100., dt=0.1))
quant = np.asarray(bp.inputs.sinusoidal_input(amplitude=1., frequency=2.0 * u.Hz,
duration=2000 * u.ms, t_start=100. * u.ms,
dt=0.1 * u.ms))
self.assertEqual(bare.shape, quant.shape)
self.assertTrue(np.allclose(bare, quant))

def test_general1(self):
I1 = bp.inputs.section_input(values=[0, 1, 2], durations=[10, 20, 30], dt=0.1)
Expand Down
Loading
Loading