Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 51 additions & 40 deletions brainpy/algorithms/offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,42 +60,46 @@ def __init__(self, name=None):
def __call__(self, targets, inputs, outputs=None):
"""The training procedure.

Parameters::
Parameters
----------

targets: ArrayType
targets : ArrayType
The 2d target data with the shape of `(num_batch, num_output)`.
inputs: ArrayType
inputs : ArrayType
The 2d input data with the shape of `(num_batch, num_input)`.
outputs: ArrayType
outputs : ArrayType
The 2d output data with the shape of `(num_batch, num_output)`.

Returns::
Returns
-------

weight: ArrayType
weight : ArrayType
The weights after fit.
"""
return self.call(targets, inputs, outputs)

def call(self, targets, inputs, outputs=None) -> ArrayType:
"""The training procedure.

Parameters::
Parameters
----------

inputs: ArrayType
inputs : ArrayType
The 3d input data with the shape of `(num_batch, num_time, num_input)`,
or, the 2d input data with the shape of `(num_time, num_input)`.

targets: ArrayType
targets : ArrayType
The 3d target data with the shape of `(num_batch, num_time, num_output)`,
or the 2d target data with the shape of `(num_time, num_output)`.

outputs: ArrayType
outputs : ArrayType
The 3d output data with the shape of `(num_batch, num_time, num_output)`,
or the 2d output data with the shape of `(num_time, num_output)`.

Returns::
Returns
-------

weight: ArrayType
weight : ArrayType
The weights after fit.
"""
raise NotImplementedError('Must implement the __call__ function by the subclass itself.')
Expand All @@ -117,11 +121,12 @@ class RegressionAlgorithm(OfflineAlgorithm):
""" Base regression model. Models the relationship between a scalar dependent variable y and the independent
variables X.

Parameters::
Parameters
----------

max_iter: int
max_iter : int
The number of training iterations the algorithm will tune the weights for.
learning_rate: float
learning_rate : float
The step length that will be used when updating the weights.
"""

Expand Down Expand Up @@ -178,9 +183,10 @@ def predict(self, W, X):
class LinearRegression(RegressionAlgorithm):
"""Training algorithm of least-square regression.

Parameters::
Parameters
----------

name: str
name : str
The name of the algorithm.
"""

Expand Down Expand Up @@ -221,20 +227,21 @@ def call(self, targets, inputs, outputs=None):
class RidgeRegression(RegressionAlgorithm):
"""Training algorithm of ridge regression.

Parameters::
Parameters
----------

alpha: float
alpha : float
The regularization coefficient.

.. versionadded:: 2.2.0

beta: float
beta : float
The regularization coefficient.

.. deprecated:: 2.2.0
Please use `alpha` to set regularization factor.

name: str
name : str
The name of the algorithm.
"""

Expand Down Expand Up @@ -295,16 +302,17 @@ def __repr__(self):
class LassoRegression(RegressionAlgorithm):
"""Lasso regression method for offline training.

Parameters::
Parameters
----------

alpha: float
alpha : float
Constant that multiplies the L1 term. Defaults to 1.0.
`alpha = 0` is equivalent to an ordinary least square.
max_iter: int
max_iter : int
The maximum number of iterations.
degree: int
degree : int
The degree of the polynomial that the independent variable X will be transformed to.
name: str
name : str
The name of the algorithm.
"""

Expand Down Expand Up @@ -350,17 +358,18 @@ def predict(self, W, X):
class LogisticRegression(RegressionAlgorithm):
"""Logistic regression method for offline training.

Parameters::
Parameters
----------

learning_rate: float
learning_rate : float
The step length that will be taken when following the negative gradient during
training.
gradient_descent: boolean
gradient_descent : boolean
True or false depending on if gradient descent should be used when training. If
false then we use batch optimization by least squares.
max_iter: int
max_iter : int
The number of iteration to optimize the parameters.
name: str
name : str
The name of the algorithm.
"""

Expand Down Expand Up @@ -498,18 +507,19 @@ def predict(self, W, X):
class ElasticNetRegression(RegressionAlgorithm):
"""

Parameters:
Parameters
----------
-----------
degree: int
degree : int
The degree of the polynomial that the independent variable X will be transformed to.
reg_factor: float
reg_factor : float
The factor that will determine the amount of regularization and feature
shrinkage.
l1_ration: float
l1_ration : float
Weighs the contribution of l1 and l2 regularization.
n_iterations: float
n_iterations : float
The number of training iterations the algorithm will tune the weights for.
learning_rate: float
learning_rate : float
The step length that will be used when updating the weights.
"""

Expand Down Expand Up @@ -563,11 +573,12 @@ def get_supported_offline_methods():
def register_offline_method(name: str, method: OfflineAlgorithm):
"""Register a new offline learning method.

Parameters::
Parameters
----------

name: str
name : str
The method name.
method: OfflineAlgorithm
method : OfflineAlgorithm
The function method.
"""
if name in name2func:
Expand Down
56 changes: 32 additions & 24 deletions brainpy/algorithms/online.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,22 @@ def __init__(self, name=None):
def __call__(self, *args, **kwargs):
"""The training procedure.

Parameters::
Parameters
----------

identifier: str
identifier : str
The variable name.
target: ArrayType
target : ArrayType
The 2d target data with the shape of `(num_batch, num_output)`.
input: ArrayType
input : ArrayType
The 2d input data with the shape of `(num_batch, num_input)`.
output: ArrayType
output : ArrayType
The 2d output data with the shape of `(num_batch, num_output)`.

Returns::
Returns
-------

weight: ArrayType
weight : ArrayType
The weights after fit.
"""
return self.call(*args, **kwargs)
Expand All @@ -69,20 +71,22 @@ def register_target(self, *args, **kwargs):
def call(self, target, input, output, identifier: str = ''):
"""The training procedure.

Parameters::
Parameters
----------

identifier: str
identifier : str
The variable name.
target: ArrayType
target : ArrayType
The 2d target data with the shape of `(num_batch, num_output)`.
input: ArrayType
input : ArrayType
The 2d input data with the shape of `(num_batch, num_input)`.
output: ArrayType
output : ArrayType
The 2d output data with the shape of `(num_batch, num_output)`.

Returns::
Returns
-------

weight: ArrayType
weight : ArrayType
The weights after fit.
"""
raise NotImplementedError('Must implement the call() function by the subclass itself.')
Expand All @@ -100,15 +104,17 @@ class RLS(OnlineAlgorithm):
contrast to other algorithms such as the least mean squares
(LMS) that aim to reduce the mean square error.

See Also::
See Also
--------

LMS, ForceLearning

Parameters::
Parameters
----------

alpha: float
alpha : float
The learning rate.
name: str
name : str
The algorithm name.

"""
Expand Down Expand Up @@ -176,11 +182,12 @@ class LMS(OnlineAlgorithm):
based on the error at the current time. It was invented in 1960 by
Stanford University professor Bernard Widrow and his first Ph.D. student, Ted Hoff.

Parameters::
Parameters
----------

alpha: float
alpha : float
The learning rate.
name: str
name : str
The target name.
"""

Expand Down Expand Up @@ -211,11 +218,12 @@ def get_supported_online_methods():
def register_online_method(name: str, method: OnlineAlgorithm):
"""Register a new oneline learning method.

Parameters::
Parameters
----------

name: str
name : str
The method name.
method: callable
method : callable
The function method.
"""
if name in name2func:
Expand Down
Loading
Loading