From 9cb865d84cf3a431225dd2182081089e0c95df90 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 19 Jun 2026 12:44:59 +0800 Subject: [PATCH] docs: convert public-API docstrings to NumPy style + enforce convention Convert all Google-style (`Args:`/`Returns:`) and broken-RST (`Parameters::`/`Returns::`/`:param:`) docstrings of public APIs to NumPy-doc style across 142 modules. Reformat the shared parameter fragments in `brainpy/dyn/_docs.py` and align their `%s`/`{}` injection indentation (lif.py, base.py) so napoleon parses every injected `Parameters` section correctly. Config: - docs/conf.py: pin napoleon to NumPy style (google off, numpy on, plus use_param/use_rtype/preprocess_types). - pyproject.toml: record `[tool.pydocstyle] convention = "numpy"` as the single source of truth for the docstring convention. Format-only change: with docstrings stripped, the AST of every one of the 141 touched source modules is byte-identical to before. napoleon parses all 1256 public `Parameters` sections with zero errors and zero empty sections. --- brainpy/algorithms/offline.py | 91 +-- brainpy/algorithms/online.py | 56 +- brainpy/analysis/highdim/slow_points.py | 73 +-- brainpy/analysis/lowdim/lowdim_analyzer.py | 19 +- brainpy/analysis/lowdim/lowdim_bifurcation.py | 28 +- brainpy/analysis/lowdim/lowdim_phase_plane.py | 17 +- brainpy/analysis/stability.py | 9 +- brainpy/analysis/utils/measurement.py | 22 +- brainpy/analysis/utils/optimization.py | 21 +- brainpy/analysis/utils/others.py | 18 +- brainpy/check.py | 68 ++- brainpy/checkpoints.py | 30 +- brainpy/connect/base.py | 35 +- brainpy/connect/random_conn.py | 79 ++- brainpy/connect/regular_conn.py | 13 +- brainpy/context.py | 12 +- brainpy/delay.py | 149 +++-- brainpy/dnn/activations.py | 277 +++++---- brainpy/dnn/conv.py | 398 +++++++------ brainpy/dnn/dropout.py | 16 +- brainpy/dnn/function.py | 46 +- brainpy/dnn/interoperation_flax.py | 30 +- brainpy/dnn/linear.py | 399 ++++++++----- brainpy/dnn/normalization.py | 143 ++--- brainpy/dnn/pooling.py | 531 +++++++++--------- brainpy/dyn/_docs.py | 80 ++- brainpy/dyn/channels/calcium.py | 78 ++- .../channels/hyperpolarization_activated.py | 9 +- brainpy/dyn/channels/leaky.py | 3 +- brainpy/dyn/channels/potassium.py | 297 ++++++---- brainpy/dyn/channels/potassium_calcium.py | 6 +- .../channels/potassium_calcium_compatible.py | 6 +- brainpy/dyn/channels/potassium_compatible.py | 150 +++-- brainpy/dyn/channels/sodium.py | 54 +- brainpy/dyn/channels/sodium_compatible.py | 54 +- brainpy/dyn/ions/base.py | 74 ++- brainpy/dyn/ions/calcium.py | 23 +- brainpy/dyn/neurons/base.py | 7 +- brainpy/dyn/neurons/hh.py | 155 ++--- brainpy/dyn/neurons/lif.py | 199 ++++--- brainpy/dyn/others/common.py | 28 +- brainpy/dyn/others/input.py | 26 +- brainpy/dyn/others/noise.py | 15 +- brainpy/dyn/outs/outputs.py | 39 +- brainpy/dyn/projections/align_post.py | 99 ++-- brainpy/dyn/projections/align_pre.py | 108 ++-- brainpy/dyn/projections/conn.py | 3 +- brainpy/dyn/projections/delta.py | 36 +- brainpy/dyn/projections/inputs.py | 25 +- brainpy/dyn/projections/plasticity.py | 45 +- brainpy/dyn/projections/vanilla.py | 18 +- brainpy/dyn/rates/nvar.py | 19 +- brainpy/dyn/rates/populations.py | 86 +-- brainpy/dyn/rates/reservoir.py | 20 +- brainpy/dyn/rates/rnncells.py | 187 +++--- brainpy/dyn/synapses/abstract_models.py | 79 ++- brainpy/dyn/synapses/bio_models.py | 57 +- brainpy/dyn/synapses/delay_couplings.py | 50 +- .../dynold/experimental/abstract_synapses.py | 58 +- brainpy/dynold/experimental/others.py | 9 +- brainpy/dynold/experimental/syn_outs.py | 31 +- brainpy/dynold/experimental/syn_plasticity.py | 24 +- brainpy/dynold/neurons/biological_models.py | 123 ++-- brainpy/dynold/neurons/fractional_models.py | 21 +- brainpy/dynold/neurons/reduced_models.py | 53 +- brainpy/dynold/synapses/abstract_models.py | 110 ++-- brainpy/dynold/synapses/base.py | 11 +- brainpy/dynold/synapses/biological_models.py | 60 +- brainpy/dynold/synouts/conductances.py | 18 +- brainpy/dynold/synouts/ions.py | 13 +- .../dynold/synplast/short_term_plasticity.py | 24 +- brainpy/dynsys.py | 252 +++++---- brainpy/encoding/stateful_encoding.py | 119 ++-- brainpy/encoding/stateless_encoding.py | 99 ++-- brainpy/helpers.py | 40 +- brainpy/initialize/decay_inits.py | 12 +- brainpy/initialize/generic.py | 83 +-- brainpy/initialize/random_inits.py | 61 +- brainpy/initialize/regular_inits.py | 14 +- brainpy/inputs/currents.py | 94 ++-- brainpy/integrators/fde/Caputo.py | 38 +- brainpy/integrators/fde/GL.py | 19 +- brainpy/integrators/fde/base.py | 9 +- brainpy/integrators/fde/generic.py | 29 +- brainpy/integrators/joint_eq.py | 5 +- brainpy/integrators/ode/adaptive_rk.py | 21 +- brainpy/integrators/ode/base.py | 9 +- brainpy/integrators/ode/explicit_rk.py | 27 +- brainpy/integrators/ode/exponential.py | 6 +- brainpy/integrators/ode/generic.py | 36 +- brainpy/integrators/runner.py | 37 +- brainpy/integrators/sde/generic.py | 19 +- brainpy/integrators/sde/normal.py | 12 +- brainpy/integrators/sde/srk_scalar.py | 6 +- brainpy/integrators/utils.py | 6 +- brainpy/losses/comparison.py | 521 ++++++++++------- brainpy/losses/regularization.py | 59 +- brainpy/math/activations.py | 237 ++++---- brainpy/math/compat_numpy.py | 36 +- brainpy/math/compat_pytorch.py | 69 +-- brainpy/math/compat_tensorflow.py | 330 ++++++----- brainpy/math/delayvars.py | 59 +- brainpy/math/environment.py | 191 ++++--- brainpy/math/event/csr_matmat.py | 20 +- brainpy/math/event/csr_matvec.py | 18 +- brainpy/math/interoperability.py | 30 +- brainpy/math/jitconn/matvec.py | 130 +++-- brainpy/math/ndarray.py | 23 +- brainpy/math/object_transform/_utils.py | 6 +- brainpy/math/object_transform/autograd.py | 71 ++- brainpy/math/object_transform/base.py | 193 ++++--- brainpy/math/object_transform/collectors.py | 23 +- brainpy/math/object_transform/controls.py | 102 ++-- brainpy/math/object_transform/function.py | 32 +- brainpy/math/object_transform/jit.py | 18 +- brainpy/math/object_transform/variables.py | 15 +- brainpy/math/others.py | 64 ++- brainpy/math/pre_syn_post.py | 205 ++++--- brainpy/math/scales.py | 16 +- brainpy/math/sharding.py | 93 +-- brainpy/math/sparse/csr_mm.py | 20 +- brainpy/math/sparse/csr_mv.py | 18 +- brainpy/math/sparse/jax_prim.py | 35 +- brainpy/measure.py | 14 +- brainpy/mixin.py | 74 ++- brainpy/optim/optimizer.py | 122 ++-- brainpy/optim/scheduler.py | 57 +- brainpy/runners.py | 55 +- brainpy/running/jax_multiprocessing.py | 32 +- brainpy/running/native_multiprocessing.py | 14 +- brainpy/running/pathos_multiprocessing.py | 62 +- brainpy/running/runner.py | 15 +- brainpy/tools/codes.py | 24 +- brainpy/tools/dicts.py | 23 +- brainpy/tools/functions.py | 3 +- brainpy/tools/others.py | 6 +- brainpy/tools/progress.py | 172 ++++-- brainpy/train/back_propagation.py | 75 +-- brainpy/train/base.py | 23 +- brainpy/train/offline.py | 41 +- brainpy/train/online.py | 37 +- brainpy/transform.py | 34 +- docs/conf.py | 13 + pyproject.toml | 10 + 144 files changed, 5737 insertions(+), 4126 deletions(-) diff --git a/brainpy/algorithms/offline.py b/brainpy/algorithms/offline.py index 4f31093bb..9e36a3eb3 100644 --- a/brainpy/algorithms/offline.py +++ b/brainpy/algorithms/offline.py @@ -60,18 +60,20 @@ def __init__(self, name=None): def __call__(self, targets, inputs, outputs=None): """The training procedure. - Parameters:: + Parameters + ---------- - targets: ArrayType + targets : ArrayType The 2d target data with the shape of `(num_batch, num_output)`. - inputs: ArrayType + inputs : ArrayType The 2d input data with the shape of `(num_batch, num_input)`. - outputs: ArrayType + outputs : ArrayType The 2d output data with the shape of `(num_batch, num_output)`. - Returns:: + Returns + ------- - weight: ArrayType + weight : ArrayType The weights after fit. """ return self.call(targets, inputs, outputs) @@ -79,23 +81,25 @@ def __call__(self, targets, inputs, outputs=None): def call(self, targets, inputs, outputs=None) -> ArrayType: """The training procedure. - Parameters:: + Parameters + ---------- - inputs: ArrayType + inputs : ArrayType The 3d input data with the shape of `(num_batch, num_time, num_input)`, or, the 2d input data with the shape of `(num_time, num_input)`. - targets: ArrayType + targets : ArrayType The 3d target data with the shape of `(num_batch, num_time, num_output)`, or the 2d target data with the shape of `(num_time, num_output)`. - outputs: ArrayType + outputs : ArrayType The 3d output data with the shape of `(num_batch, num_time, num_output)`, or the 2d output data with the shape of `(num_time, num_output)`. - Returns:: + Returns + ------- - weight: ArrayType + weight : ArrayType The weights after fit. """ raise NotImplementedError('Must implement the __call__ function by the subclass itself.') @@ -117,11 +121,12 @@ class RegressionAlgorithm(OfflineAlgorithm): """ Base regression model. Models the relationship between a scalar dependent variable y and the independent variables X. - Parameters:: + Parameters + ---------- - max_iter: int + max_iter : int The number of training iterations the algorithm will tune the weights for. - learning_rate: float + learning_rate : float The step length that will be used when updating the weights. """ @@ -178,9 +183,10 @@ def predict(self, W, X): class LinearRegression(RegressionAlgorithm): """Training algorithm of least-square regression. - Parameters:: + Parameters + ---------- - name: str + name : str The name of the algorithm. """ @@ -221,20 +227,21 @@ def call(self, targets, inputs, outputs=None): class RidgeRegression(RegressionAlgorithm): """Training algorithm of ridge regression. - Parameters:: + Parameters + ---------- - alpha: float + alpha : float The regularization coefficient. .. versionadded:: 2.2.0 - beta: float + beta : float The regularization coefficient. .. deprecated:: 2.2.0 Please use `alpha` to set regularization factor. - name: str + name : str The name of the algorithm. """ @@ -295,16 +302,17 @@ def __repr__(self): class LassoRegression(RegressionAlgorithm): """Lasso regression method for offline training. - Parameters:: + Parameters + ---------- - alpha: float + alpha : float Constant that multiplies the L1 term. Defaults to 1.0. `alpha = 0` is equivalent to an ordinary least square. - max_iter: int + max_iter : int The maximum number of iterations. - degree: int + degree : int The degree of the polynomial that the independent variable X will be transformed to. - name: str + name : str The name of the algorithm. """ @@ -350,17 +358,18 @@ def predict(self, W, X): class LogisticRegression(RegressionAlgorithm): """Logistic regression method for offline training. - Parameters:: + Parameters + ---------- - learning_rate: float + learning_rate : float The step length that will be taken when following the negative gradient during training. - gradient_descent: boolean + gradient_descent : boolean True or false depending on if gradient descent should be used when training. If false then we use batch optimization by least squares. - max_iter: int + max_iter : int The number of iteration to optimize the parameters. - name: str + name : str The name of the algorithm. """ @@ -498,18 +507,19 @@ def predict(self, W, X): class ElasticNetRegression(RegressionAlgorithm): """ - Parameters: + Parameters + ---------- ----------- - degree: int + degree : int The degree of the polynomial that the independent variable X will be transformed to. - reg_factor: float + reg_factor : float The factor that will determine the amount of regularization and feature shrinkage. - l1_ration: float + l1_ration : float Weighs the contribution of l1 and l2 regularization. - n_iterations: float + n_iterations : float The number of training iterations the algorithm will tune the weights for. - learning_rate: float + learning_rate : float The step length that will be used when updating the weights. """ @@ -563,11 +573,12 @@ def get_supported_offline_methods(): def register_offline_method(name: str, method: OfflineAlgorithm): """Register a new offline learning method. - Parameters:: + Parameters + ---------- - name: str + name : str The method name. - method: OfflineAlgorithm + method : OfflineAlgorithm The function method. """ if name in name2func: diff --git a/brainpy/algorithms/online.py b/brainpy/algorithms/online.py index 9dd54c35d..26bf3b0ca 100644 --- a/brainpy/algorithms/online.py +++ b/brainpy/algorithms/online.py @@ -45,20 +45,22 @@ def __init__(self, name=None): def __call__(self, *args, **kwargs): """The training procedure. - Parameters:: + Parameters + ---------- - identifier: str + identifier : str The variable name. - target: ArrayType + target : ArrayType The 2d target data with the shape of `(num_batch, num_output)`. - input: ArrayType + input : ArrayType The 2d input data with the shape of `(num_batch, num_input)`. - output: ArrayType + output : ArrayType The 2d output data with the shape of `(num_batch, num_output)`. - Returns:: + Returns + ------- - weight: ArrayType + weight : ArrayType The weights after fit. """ return self.call(*args, **kwargs) @@ -69,20 +71,22 @@ def register_target(self, *args, **kwargs): def call(self, target, input, output, identifier: str = ''): """The training procedure. - Parameters:: + Parameters + ---------- - identifier: str + identifier : str The variable name. - target: ArrayType + target : ArrayType The 2d target data with the shape of `(num_batch, num_output)`. - input: ArrayType + input : ArrayType The 2d input data with the shape of `(num_batch, num_input)`. - output: ArrayType + output : ArrayType The 2d output data with the shape of `(num_batch, num_output)`. - Returns:: + Returns + ------- - weight: ArrayType + weight : ArrayType The weights after fit. """ raise NotImplementedError('Must implement the call() function by the subclass itself.') @@ -100,15 +104,17 @@ class RLS(OnlineAlgorithm): contrast to other algorithms such as the least mean squares (LMS) that aim to reduce the mean square error. - See Also:: + See Also + -------- LMS, ForceLearning - Parameters:: + Parameters + ---------- - alpha: float + alpha : float The learning rate. - name: str + name : str The algorithm name. """ @@ -176,11 +182,12 @@ class LMS(OnlineAlgorithm): based on the error at the current time. It was invented in 1960 by Stanford University professor Bernard Widrow and his first Ph.D. student, Ted Hoff. - Parameters:: + Parameters + ---------- - alpha: float + alpha : float The learning rate. - name: str + name : str The target name. """ @@ -211,11 +218,12 @@ def get_supported_online_methods(): def register_online_method(name: str, method: OnlineAlgorithm): """Register a new oneline learning method. - Parameters:: + Parameters + ---------- - name: str + name : str The method name. - method: callable + method : callable The function method. """ if name in name2func: diff --git a/brainpy/analysis/highdim/slow_points.py b/brainpy/analysis/highdim/slow_points.py index b1ec7f217..e398e965d 100644 --- a/brainpy/analysis/highdim/slow_points.py +++ b/brainpy/analysis/highdim/slow_points.py @@ -57,7 +57,8 @@ class SlowPointFinder(base.DSAnalyzer): - exclude any non-unique fixed points according to a tolerance - exclude any far-away "outlier" fixed points - Parameters:: + Parameters + ---------- f_cell : callable, function, DynamicalSystem The target of computing the recurrent units. @@ -71,7 +72,7 @@ class SlowPointFinder(base.DSAnalyzer): verbose : bool Whether output the optimization progress. - f_loss: callable + f_loss : callable The loss function. - If ``f_type`` is `"discrete"`, the loss function must receive three arguments, i.e., ``loss(outputs, targets, axis)``. @@ -79,29 +80,29 @@ class SlowPointFinder(base.DSAnalyzer): arguments, i.e., ``loss(outputs, axis)``. .. versionadded:: 2.2.0 - t: float + t : float Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. The time to evaluate the fixed points. Default is 0. .. versionadded:: 2.2.0 - dt: float + dt : float Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. The numerical integration step, which can be used when . The default is given by `brainpy.math.get_dt()`. .. versionadded:: 2.2.0 - inputs: sequence, callable + inputs : sequence, callable Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. Same as ``inputs`` in :py:class:`~.DSRunner`. .. versionadded:: 2.2.0 - excluded_vars: sequence, dict + excluded_vars : sequence, dict Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. The excluded variables (can be a sequence of `Variable` instances). These variables will not be included for optimization of fixed points. .. versionadded:: 2.2.0 - target_vars: dict + target_vars : dict Parameter for ``f_cell`` is instance of :py:class:`~.DynamicalSystem`. The target variables (can be a dict of `Variable` instances). These variables will be included for optimization of fixed points. @@ -114,7 +115,7 @@ class SlowPointFinder(base.DSAnalyzer): .. deprecated:: 2.2.0 Has been removed. Please use ``f_loss`` to set different loss function. - fun_inputs: callable + fun_inputs : callable .. deprecated:: 2.3.1 Will be removed since version 2.4.0. @@ -318,13 +319,14 @@ def find_fps_with_gd_method( ): """Optimize fixed points with gradient descent methods. - Parameters:: + Parameters + ---------- candidates : ArrayType, dict The array with the shape of (batch size, state dim) of hidden states of RNN to start training for fixed points. - tolerance: float + tolerance : float The loss threshold during optimization num_opt : int @@ -333,7 +335,7 @@ def find_fps_with_gd_method( num_batch : int Print training information during optimization every so often. - optimizer: optim.Optimizer + optimizer : optim.Optimizer The optimizer instance. .. versionadded:: 2.1.2 @@ -423,11 +425,12 @@ def find_fps_with_opt_solver( ): """Optimize fixed points with nonlinear optimization solvers. - Parameters:: + Parameters + ---------- - candidates: ArrayType, dict + candidates : ArrayType, dict The candidate (initial) fixed points. - opt_solver: str + opt_solver : str The solver of the optimization. """ # optimization function @@ -468,9 +471,10 @@ def find_fps_with_opt_solver( def filter_loss(self, tolerance: float = 1e-5): """Filter fixed points whose speed larger than a given tolerance. - Parameters:: + Parameters + ---------- - tolerance: float + tolerance : float Discard fixed points with squared speed larger than this value. """ if self.verbose: @@ -493,9 +497,10 @@ def filter_loss(self, tolerance: float = 1e-5): def keep_unique(self, tolerance: float = 2.5e-2): """Filter unique fixed points by choosing a representative within tolerance. - Parameters:: + Parameters + ---------- - tolerance: float + tolerance : float Tolerance for determination of identical fixed points. """ if self.verbose: @@ -515,9 +520,10 @@ def keep_unique(self, tolerance: float = 2.5e-2): def exclude_outliers(self, tolerance: float = 1e0): """Exclude points whose closest neighbor is further than threshold. - Parameters:: + Parameters + ---------- - tolerance: float + tolerance : float Any point whose closest fixed point is greater than tol is an outlier. """ if self.verbose: @@ -560,19 +566,20 @@ def compute_jacobians( ): """Compute the Jacobian matrices at the points. - Parameters:: + Parameters + ---------- - points: np.ndarray, bm.ArrayType, jax.ndarray + points : np.ndarray, bm.ArrayType, jax.ndarray The fixed points with the shape of (num_point, num_dim). - stack_dict_var: bool + stack_dict_var : bool Stack dictionary variables to calculate Jacobian matrix? - plot: bool + plot : bool Plot the decomposition results of the Jacobian matrix. - num_col: int + num_col : int The number of the figure column. - len_col: int + len_col : int The length of each column. - len_row: int + len_row : int The length of each row. """ # check data @@ -620,16 +627,18 @@ def compute_jacobians( def decompose_eigenvalues(matrices, sort_by='magnitude', do_compute_lefts=False): """Compute the eigenvalues of the matrices. - Parameters:: + Parameters + ---------- - matrices: np.ndarray, bm.ArrayType, jax.ndarray + matrices : np.ndarray, bm.ArrayType, jax.ndarray A 3D array with the shape of (num_matrices, dim, dim). - sort_by: str + sort_by : str The method of sorting. - do_compute_lefts: bool + do_compute_lefts : bool Compute the left eigenvectors? Requires a pseudo-inverse call. - Returns:: + Returns + ------- decompositions : list A list of dictionaries with sorted eigenvalues components: diff --git a/brainpy/analysis/lowdim/lowdim_analyzer.py b/brainpy/analysis/lowdim/lowdim_analyzer.py index 02f8dfca9..01ad05271 100644 --- a/brainpy/analysis/lowdim/lowdim_analyzer.py +++ b/brainpy/analysis/lowdim/lowdim_analyzer.py @@ -53,7 +53,8 @@ class LowDimAnalyzer(DSAnalyzer): .. note:: ``LowDimAnalyzer`` cannot analyze dynamical system depends on time :math:`t`. - Parameters:: + Parameters + ---------- model : Any, ODEIntegrator, sequence of ODEIntegrator, DynamicalSystem A model of the population, the integrator function, @@ -78,7 +79,7 @@ class LowDimAnalyzer(DSAnalyzer): - Moreover, you can also set ``resolutions={var1: Array([...]), var2: 0.1}`` to specify the search points need to explore for variable `var1`. This will be useful to set sense search points at some inflection points. - lim_scale: float + lim_scale : float The axis limit scale factor. Default is 1.05. The setting means the axes will be clipped to ``[var_min * (1-lim_scale)/2, var_max * (var_max-1)/2]``. options : optional, dict @@ -360,14 +361,16 @@ def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_ >>> all_par1.append(jnp.ones_like(xs) * p1) >>> all_par2.append(jnp.ones_like(xs) * p2) - Parameters:: + Parameters + ---------- candidates args tol_aux loss_screen - Returns:: + Returns + ------- """ # candidates: xs, a vector with the length of self.resolutions[self.x_var] @@ -927,9 +930,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7, >>> all_par1.append(jnp.ones_like(nullcline_points) * p1) >>> all_par2.append(jnp.ones_like(nullcline_points) * p2) - Parameters:: + Parameters + ---------- - candidates: np.ndarray, jnp.ndarray + candidates : np.ndarray, jnp.ndarray The candidate points (batched) to optimize, like the nullcline points. args : tuple The parameters (batched). @@ -937,7 +941,8 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7, tol_unique : float tol_opt_candidate : float, optional - Returns:: + Returns + ------- res : tuple The fixed point results. diff --git a/brainpy/analysis/lowdim/lowdim_bifurcation.py b/brainpy/analysis/lowdim/lowdim_bifurcation.py index 663c8b62a..d39933c0c 100644 --- a/brainpy/analysis/lowdim/lowdim_bifurcation.py +++ b/brainpy/analysis/lowdim/lowdim_bifurcation.py @@ -187,44 +187,46 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False, select_candidates='aux_rank', num_rank=100): r"""Make the bifurcation analysis. - Parameters:: + Parameters + ---------- - with_plot: bool + with_plot : bool Whether plot the bifurcation figure. - show: bool + show : bool Whether show the figure. - with_return: bool + with_return : bool Whether return the computed bifurcation results. - tol_aux: float + tol_aux : float The loss tolerance of auxiliary function :math:`f_{aux}` to confirm the fixed point. Default is 1e-7. Once :math:`f_{aux}(x_1) < \mathrm{tol\_aux}`, :math:`x_1` will be a fixed point. - tol_unique: float + tol_unique : float The tolerance of distance between candidate fixed points to confirm they are the same. Default is 1e-2. If :math:`|x_1 - x_2| > \mathrm{tol\_unique}`, then :math:`x_1` and :math:`x_2` are unique fixed points. Otherwise, :math:`x_1` and :math:`x_2` will be treated as a same fixed point. - tol_opt_candidate: float, optional + tol_opt_candidate : float, optional The tolerance of auxiliary function :math:`f_{aux}` to select candidate initial points for fixed point optimization. - num_par_segments: int, sequence of int + num_par_segments : int, sequence of int How to segment parameters. - num_fp_segment: int + num_fp_segment : int How to segment fixed points. - nullcline_aux_filter: float + nullcline_aux_filter : float The - select_candidates: str + select_candidates : str The method to select candidate fixed points. It can be: - ``fx-nullcline``: use the points of fx-nullcline. - ``fy-nullcline``: use the points of fy-nullcline. - ``nullclines``: use the points in both of fx-nullcline and fy-nullcline. - ``aux_rank``: use the minimal value of points for the auxiliary function. - num_rank: int + num_rank : int The number of candidates to be used to optimize the fixed points. rank to use. - Returns:: + Returns + ------- results : tuple Return a tuple of analyzed results: diff --git a/brainpy/analysis/lowdim/lowdim_phase_plane.py b/brainpy/analysis/lowdim/lowdim_phase_plane.py index f0a1d45b2..3bb6b68d8 100644 --- a/brainpy/analysis/lowdim/lowdim_phase_plane.py +++ b/brainpy/analysis/lowdim/lowdim_phase_plane.py @@ -42,7 +42,8 @@ class PhasePlane1D(Num1DAnalyzer): - Vector fields - Fixed points - Parameters:: + Parameters + ---------- model : Any A model of the population, the integrator function, @@ -138,7 +139,8 @@ def plot_fixed_point(self, show=False, with_plot=True, with_return=False): class PhasePlane2D(Num2DAnalyzer): """Phase plane analyzer for 2D dynamical system. - Parameters:: + Parameters + ---------- model : Any A model of the population, the integrator function, @@ -184,9 +186,10 @@ def plot_vector_field(self, with_plot=True, with_return=False, plot_method='streamplot', plot_style=None, show=False): """Plot the vector field. - Parameters:: + Parameters + ---------- - with_plot: bool + with_plot : bool with_return : bool show : bool plot_method : str @@ -382,7 +385,8 @@ def plot_trajectory(self, initials, duration, plot_durations=None, axes='v-v', dt=None, show=False, with_plot=True, with_return=False, **kwargs): """Plot trajectories according to the settings. - Parameters:: + Parameters + ---------- initials : list, tuple, dict The initial value setting of the targets. It can be a tuple/list of floats to specify @@ -478,7 +482,8 @@ def plot_trajectory(self, initials, duration, plot_durations=None, axes='v-v', def plot_limit_cycle_by_sim(self, initials, duration, tol=0.01, show=False, dt=None): """Plot trajectories according to the settings. - Parameters:: + Parameters + ---------- initials : list, tuple The initial value setting of the targets. diff --git a/brainpy/analysis/stability.py b/brainpy/analysis/stability.py index 7d799fb5c..49c588e5c 100644 --- a/brainpy/analysis/stability.py +++ b/brainpy/analysis/stability.py @@ -92,17 +92,20 @@ def stability_analysis(derivatives): The analysis is referred to [1]_. - Parameters:: + Parameters + ---------- derivatives : float, tuple, list, np.ndarray The derivative of the f. - Returns:: + Returns + ------- fp_type : str The type of the fixed point. - References:: + References + ---------- .. [1] http://www.egwald.ca/nonlineardynamics/twodimensionaldynamics.php diff --git a/brainpy/analysis/utils/measurement.py b/brainpy/analysis/utils/measurement.py index d4bc7c01d..1f67b4908 100644 --- a/brainpy/analysis/utils/measurement.py +++ b/brainpy/analysis/utils/measurement.py @@ -59,14 +59,16 @@ def euclidean_distance(points: np.ndarray, num_point=None): >>> from scipy.spatial.distance import squareform, pdist >>> f = lambda points: squareform(pdist(points, metric="euclidean")) - Parameters:: + Parameters + ---------- - points: ArrayType + points : ArrayType The points. - Returns:: + Returns + ------- - dist_matrix: jnp.ndarray + dist_matrix : jnp.ndarray The distance matrix. """ @@ -106,15 +108,17 @@ def euclidean_distance_jax(points: Union[jnp.ndarray, bm.ndarray], num_point=Non >>> from scipy.spatial.distance import squareform, pdist >>> f = lambda points: squareform(pdist(points, metric="euclidean")) - Parameters:: + Parameters + ---------- - points: ArrayType + points : ArrayType The points. - num_point: int + num_point : int - Returns:: + Returns + ------- - dist_matrix: ArrayType + dist_matrix : ArrayType The distance matrix. """ if isinstance(points, dict): diff --git a/brainpy/analysis/utils/optimization.py b/brainpy/analysis/utils/optimization.py index 6c6755922..7bd1f84b6 100644 --- a/brainpy/analysis/utils/optimization.py +++ b/brainpy/analysis/utils/optimization.py @@ -237,18 +237,19 @@ def scipy_minimize_with_jax(fun, x0, """ A simple wrapper for scipy.optimize.minimize using JAX. - Parameters:: + Parameters + ---------- - fun: function + fun : function The objective function to be minimized, written in JAX code so that it is automatically differentiable. It is of type, ```fun: x, *args -> float``` where `x` is a PyTree and args is a tuple of the fixed parameters needed to completely specify the function. - x0: jnp.ndarray + x0 : jnp.ndarray Initial guess represented as a JAX PyTree. - args: tuple, optional. + args : tuple, optional. Extra arguments passed to the objective function and its derivative. Must consist of valid JAX types; e.g. the leaves of the PyTree must be floats. @@ -334,7 +335,8 @@ def scipy_minimize_with_jax(fun, x0, ```callback(xk)``` where `xk` is the current parameter vector, represented as a PyTree. - Returns:: + Returns + ------- res : The optimization result represented as a ``OptimizeResult`` object. Important attributes are: @@ -431,7 +433,8 @@ def numpy_brentq(f, a, b, args=(), xtol=2e-14, maxiter=200, rtol=4 * np.finfo(fl Uses the classic Brent's method to find a zero of the function `f` on the sign changing interval [a , b]. - Parameters:: + Parameters + ---------- f : callable Python function returning a number. `f` must be continuous. @@ -553,14 +556,16 @@ def numpy_brentq(f, a, b, args=(), xtol=2e-14, maxiter=200, rtol=4 * np.finfo(fl def find_root_of_1d_numpy(f, f_points, args=(), tol=1e-8): """Find the roots of the given function by numerical methods. - Parameters:: + Parameters + ---------- f : callable The function. f_points : np.ndarray, list, tuple The value points. - Returns:: + Returns + ------- roots : list The roots. diff --git a/brainpy/analysis/utils/others.py b/brainpy/analysis/utils/others.py index bf7ff3abc..1c3b03334 100644 --- a/brainpy/analysis/utils/others.py +++ b/brainpy/analysis/utils/others.py @@ -104,14 +104,16 @@ def keep_unique(candidates: Union[np.ndarray, Dict[str, np.ndarray]], tolerance: float = 2.5e-2): """Filter unique fixed points by choosing a representative within tolerance. - Parameters:: + Parameters + ---------- - candidates: np.ndarray, dict + candidates : np.ndarray, dict The fixed points with the shape of (num_point, num_dim). - tolerance: float + tolerance : float tolerance. - Returns:: + Returns + ------- fps_and_ids : tuple A 2-tuple of (kept fixed points, ids of kept fixed points). @@ -150,12 +152,14 @@ def keep_unique(candidates: Union[np.ndarray, Dict[str, np.ndarray]], def keep_unique_jax(candidates, tolerance=2.5e-2): """Filter unique fixed points by choosing a representative within tolerance. - Parameters:: + Parameters + ---------- - candidates: Tesnor + candidates : Tesnor The fixed points with the shape of (num_point, num_dim). - Returns:: + Returns + ------- fps_and_ids : tuple A 2-tuple of (kept fixed points, ids of kept fixed points). diff --git a/brainpy/check.py b/brainpy/check.py index b3724d629..151cc187a 100644 --- a/brainpy/check.py +++ b/brainpy/check.py @@ -130,13 +130,15 @@ def is_shape_broadcastable(shapes, free_axes=(), return_format_shapes=False): See https://numpy.org/doc/stable/reference/generated/numpy.broadcast.html for more details. - Parameters:: + Parameters + ---------- shapes free_axes return_format_shapes - Returns:: + Returns + ------- """ max_dim = max([len(shape) for shape in shapes]) @@ -335,17 +337,18 @@ def is_float( ) -> float: """Check float type. - Parameters:: + Parameters + ---------- - value: Any - name: optional, str - min_bound: optional, float + value : Any + name : optional, str + min_bound : optional, float The allowed minimum value. - max_bound: optional, float + max_bound : optional, float The allowed maximum value. - allow_none: bool + allow_none : bool Whether allow the value is None. - allow_int: bool + allow_int : bool Whether allow the value be an integer. """ if name is None: name = '' @@ -375,15 +378,16 @@ def is_float( def is_integer(value: int, name=None, min_bound=None, max_bound=None, allow_none=False): """Check integer type. - Parameters:: + Parameters + ---------- - value: int, optional - name: optional, str - min_bound: optional, int + value : int, optional + name : optional, str + min_bound : optional, int The allowed minimum value. - max_bound: optional, int + max_bound : optional, int The allowed maximum value. - allow_none: bool + allow_none : bool Whether allow the value is None. """ if name is None: name = '' @@ -462,13 +466,14 @@ def is_subclass( - the instance of ``B`` or ``C`` will also success to pass the check. - the instance of ``A`` will success to pass the check too. - Parameters:: + Parameters + ---------- - instance: Any + instance : Any The instance in the inheritance hierarchy tree. - supported_types: type, list of type, tuple of type + supported_types : type, list of type, tuple of type All types that are supported. - name: str + name : str The checking target name. """ mode_type = type(instance) @@ -510,13 +515,14 @@ def is_instance( - the instance of ``A`` or ``C`` or ``G`` will fail to pass the check. - the instance of ``B`` or ``D`` or ``E`` or ``F`` will success to pass the check. - Parameters:: + Parameters + ---------- - instance: Any + instance : Any The instance in the inheritance hierarchy tree. - supported_types: type, list of type, tuple of type + supported_types : type, list of type, tuple of type All types that are supported. - name: str + name : str The checking target name. """ if not name: @@ -610,13 +616,14 @@ def true_err_fun(arg, transforms): def jit_error(pred, err_fun, err_arg=None): """Check errors in a jit function. - Parameters:: + Parameters + ---------- - pred: bool, Array + pred : bool, Array The boolean prediction. - err_fun: callable + err_fun : callable The error function, which raise errors. - err_arg: any + err_arg : any The arguments which passed into `err_f`. """ from brainpy.math.interoperability import as_jax @@ -639,11 +646,12 @@ def jit_error(pred, err_fun, err_arg=None): def jit_error_checking_no_args(pred: bool, err: Exception): """Check errors in a jit function. - Parameters:: + Parameters + ---------- - pred: bool + pred : bool The boolean prediction. - err: Exception + err : Exception The error. """ from brainstate.transform import unvmap diff --git a/brainpy/checkpoints.py b/brainpy/checkpoints.py index 44c262bc6..03f06e091 100644 --- a/brainpy/checkpoints.py +++ b/brainpy/checkpoints.py @@ -60,25 +60,27 @@ def save_pytree( commit will happen inside an async callback, which can be explicitly waited by calling `async_manager.wait_previous_save()`. - Parameters:: + Parameters + ---------- - filename: str + filename : str str or pathlib-like path to store checkpoint files in. - target: Any + target : Any serializable flax object, usually a flax optimizer. - overwrite: bool + overwrite : bool overwrite existing checkpoint files if a checkpoint at the current or a later step already exits (default: False). - async_manager: optional, AsyncManager + async_manager : optional, AsyncManager if defined, the save will run without blocking the main thread. Only works for single host. Note that an ongoing save will still block subsequent saves, to make sure overwrite/keep logic works correctly. - verbose: bool + verbose : bool Whether output the print information. - Returns:: + Returns + ------- - out: str + out : str Filename of saved checkpoint. """ return braintools.file.msgpack_save( @@ -97,16 +99,18 @@ def load_pytree( ) -> PyTree: """Load the checkpoint from the given checkpoint path. - Parameters:: + Parameters + ---------- - filename: str + filename : str checkpoint file or directory of checkpoints to restore from. - parallel: bool + parallel : bool whether to load seekable checkpoints in parallel, for speed. - Returns:: + Returns + ------- - out: Any + out : Any Restored `target` updated from checkpoint file, or if no step specified and no checkpoint files present, returns the passed-in `target` unchanged. If a file path is specified and is not found, the passed-in `target` will be diff --git a/brainpy/connect/base.py b/brainpy/connect/base.py index f1a57b6e3..eb8ae21e1 100644 --- a/brainpy/connect/base.py +++ b/brainpy/connect/base.py @@ -101,7 +101,8 @@ def set_default_dtype(mat_dtype=None, idx_dtype=None): [0., 1., 0., 1.], [0., 0., 1., 0.]], dtype=float32) - Parameters:: + Parameters + ---------- mat_dtype : type The default dtype for connection matrix. @@ -186,14 +187,16 @@ def __repr__(self): def __call__(self, pre_size, post_size): """Create the concrete connections between two end objects. - Parameters:: + Parameters + ---------- pre_size : int, tuple of int, list of int The size of the pre-synaptic group. post_size : int, tuple of int, list of int The size of the post-synaptic group. - Returns:: + Returns + ------- conn : TwoEndConnector Return the self. @@ -214,7 +217,8 @@ def __call__(self, pre_size, post_size): def _reset_conn(self, pre_size, post_size): """Reset connection attributes. - Parameters:: + Parameters + ---------- pre_size : int, tuple of int, list of int The size of the pre-synaptic group. @@ -395,7 +399,8 @@ def _make_returns(self, structures, conn_data): def require(self, *structures): """Require all the connection data needed. - Examples:: + Examples + -------- >>> import brainpy as bp >>> conn = bp.connect.FixedProb(0.1) @@ -526,9 +531,10 @@ def build_conn(self): - ``build_coo()``: build a coo sparse connection data. - ``build_conn()``: deprecated. - Returns:: + Returns + ------- - conn: tuple, dict + conn : tuple, dict A tuple with two elements: connection type (str) and connection data. For example: ``return 'csr', (ind, indptr)`` Or a dict with three elements: csr, mat and coo. For example: @@ -549,9 +555,10 @@ def build_mat(self): - ``build_coo()``: build a coo sparse connection data. - ``build_conn()``: deprecated. - Returns:: + Returns + ------- - conn: Array + conn : Array A binary matrix with the shape ``(num_pre, num_post)``. """ pass @@ -560,9 +567,10 @@ def build_mat(self): def build_csr(self): """Build a csr sparse connection data. - Returns:: + Returns + ------- - conn: tuple + conn : tuple A tuple denoting the ``(indices, indptr)``. """ pass @@ -571,9 +579,10 @@ def build_csr(self): def build_coo(self): """Build a coo sparse connection data. - Returns:: + Returns + ------- - conn: tuple + conn : tuple A tuple denoting the ``(pre_ids, post_ids)``. """ pass diff --git a/brainpy/connect/random_conn.py b/brainpy/connect/random_conn.py index 91640a2bd..b096e834a 100644 --- a/brainpy/connect/random_conn.py +++ b/brainpy/connect/random_conn.py @@ -43,15 +43,16 @@ class FixedProb(TwoEndConnector): """Connect the post-synaptic neurons with fixed probability. - Parameters:: + Parameters + ---------- - prob: float + prob : float The conn probability. - pre_ratio: float + pre_ratio : float The ratio of pre-synaptic neurons to connect. include_self : bool Whether create (i, i) conn? - allow_multi_conn: bool + allow_multi_conn : bool Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? .. versionadded:: 2.2.3.2 @@ -158,13 +159,14 @@ def build_mat(self): class FixedTotalNum(TwoEndConnector): """Connect the synaptic neurons with fixed total number. - Parameters:: + Parameters + ---------- num : float,int The conn total number. allow_multi_conn : bool, optional Whether allow one pre-synaptic neuron connects to multiple post-synaptic neurons. - seed: int, optional + seed : int, optional The random number seed. """ @@ -229,7 +231,8 @@ def __repr__(self): class FixedPreNum(FixedNum): """Connect a fixed number pf pre-synaptic neurons for each post-synaptic neuron. - Parameters:: + Parameters + ---------- num : float, int The conn probability (if "num" is float) or the fixed number of @@ -238,7 +241,7 @@ class FixedPreNum(FixedNum): Whether create (i, i) conn ? seed : None, int Seed the random generator. - allow_multi_conn: bool + allow_multi_conn : bool Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? .. versionadded:: 2.2.3.2 @@ -289,7 +292,8 @@ def single_conn(): class FixedPostNum(FixedNum): """Connect the fixed number of post-synaptic neurons for each pre-synaptic neuron. - Parameters:: + Parameters + ---------- num : float, int The conn probability (if "num" is float) or the fixed number of @@ -298,7 +302,7 @@ class FixedPostNum(FixedNum): Whether create (i, i) conn ? seed : None, int Seed the random generator. - allow_multi_conn: bool + allow_multi_conn : bool Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? .. versionadded:: 2.2.3.2 @@ -389,7 +393,8 @@ class GaussianProb(OneEndConnector): where :math:`v_k^i` is the :math:`i`-th neuron's encoded value at dimension :math:`k`. - Parameters:: + Parameters + ---------- sigma : float Width of the Gaussian function. @@ -528,7 +533,8 @@ def build_mat(self, isOptimized=True): class SmallWorld(TwoEndConnector): r"""Build a Watts–Strogatz small-world graph. - Parameters:: + Parameters + ---------- num_neighbor : int Each node is joined with its `k` nearest neighbors in a ring @@ -540,7 +546,8 @@ class SmallWorld(TwoEndConnector): include_self : bool Whether include the node self. - Notes:: + Notes + ----- First create a ring over :math:`num\_node` nodes [1]_. Then each node in the ring is joined to its :math:`num\_neighbor` nearest neighbors (or :math:`num\_neighbor - 1` neighbors @@ -549,7 +556,8 @@ class SmallWorld(TwoEndConnector): :math:`num\_neighbor` nearest neighbors" with probability :math:`prob` replace it with a new edge :math:`(u, w)` with uniformly random choice of existing node :math:`w`. - References:: + References + ---------- .. [1] Duncan J. Watts and Steven H. Strogatz, Collective dynamics of small-world networks, @@ -677,19 +685,22 @@ class ScaleFreeBA(TwoEndConnector): :math:`m` edges that are preferentially attached to existing nodes with high degree. - Parameters:: + Parameters + ---------- m : int Number of edges to attach from a new node to existing nodes seed : integer, random_state, or None (default) Indicator of random number generation state. - Raises:: + Raises + ------ ConnectorError If `m` does not satisfy ``1 <= m < n``. - References:: + References + ---------- .. [1] A. L. Barabási and R. Albert "Emergence of scaling in random networks", Science 286, pp 509-512, 1999. @@ -789,7 +800,8 @@ class ScaleFreeBADual(TwoEndConnector): edges (with probability :math:`p`) or :math:`m_2` edges (with probability :math:`1-p`) that are preferentially attached to existing nodes with high degree. - Parameters:: + Parameters + ---------- m1 : int Number of edges to attach from a new node to existing nodes with probability :math:`p` @@ -800,12 +812,14 @@ class ScaleFreeBADual(TwoEndConnector): seed : integer, random_state, or None (default) Indicator of random number generation state. - Raises:: + Raises + ------ ConnectorError If `m1` and `m2` do not satisfy ``1 <= m1,m2 < n`` or `p` does not satisfy ``0 <= p <= 1``. - References:: + References + ---------- .. [1] N. Moshiri "The dual-Barabasi-Albert model", arXiv:1810.10538. """ @@ -915,7 +929,8 @@ class PowerLaw(TwoEndConnector): """Holme and Kim algorithm for growing graphs with powerlaw degree distribution and approximate average clustering. - Parameters:: + Parameters + ---------- m : int the number of random edges to add for each new node @@ -924,7 +939,8 @@ class PowerLaw(TwoEndConnector): seed : integer, random_state, or None (default) Indicator of random number generation state. - Notes:: + Notes + ----- The average clustering has a hard time getting above a certain cutoff that depends on :math:`m`. This cutoff is often quite low. The @@ -942,13 +958,15 @@ class PowerLaw(TwoEndConnector): since the initial :math:`m` nodes may not be all linked to a new node on the first iteration like the BA model. - Raises:: + Raises + ------ ConnectorError If :math:`m` does not satisfy :math:`1 <= m <= n` or :math:`p` does not satisfy :math:`0 <= p <= 1`. - References:: + References + ---------- .. [1] P. Holme and B. J. Kim, "Growing scale-free networks with tunable clustering", @@ -1086,17 +1104,18 @@ class ProbDist(TwoEndConnector): .. versionadded:: 2.1.13 - Parameters:: + Parameters + ---------- - dist: float, int + dist : float, int The maximum distance between two points. - prob: float + prob : float The connection probability, within 0. and 1. - pre_ratio: float + pre_ratio : float The ratio of pre-synaptic neurons to connect. - seed: optional, int + seed : optional, int The random seed. - include_self: bool + include_self : bool Whether include the point at the same position. """ diff --git a/brainpy/connect/regular_conn.py b/brainpy/connect/regular_conn.py index 4cc2a9178..7739f1073 100644 --- a/brainpy/connect/regular_conn.py +++ b/brainpy/connect/regular_conn.py @@ -198,7 +198,8 @@ def f_connect(pre_id): class GridFour(GridConn): """The nearest four neighbors connection method. - Parameters:: + Parameters + ---------- periodic_boundary : bool Whether the neuron encode the value space with the periodic boundary. @@ -236,7 +237,8 @@ def _select_dist(self, dist: jnp.ndarray) -> jnp.ndarray: class GridN(GridConn): """The nearest (2*N+1) * (2*N+1) neighbors conn method. - Parameters:: + Parameters + ---------- N : int Extend of the conn scope. For example: @@ -252,7 +254,7 @@ class GridN(GridConn): [x x x x x] include_self : bool Whether create (i, i) conn ? - periodic_boundary: bool + periodic_boundary : bool Whether the neuron encode the value space with the periodic boundary. .. versionadded:: 2.2.3.2 """ @@ -291,11 +293,12 @@ def _select_dist(self, dist: jnp.ndarray) -> jnp.ndarray: class GridEight(GridN): """The nearest eight neighbors conn method. - Parameters:: + Parameters + ---------- include_self : bool Whether create (i, i) conn ? - periodic_boundary: bool + periodic_boundary : bool Whether the neurons encode the value space with the periodic boundary. .. versionadded:: 2.2.3.2 """ diff --git a/brainpy/context.py b/brainpy/context.py index 0ef8c42c4..759d188ef 100644 --- a/brainpy/context.py +++ b/brainpy/context.py @@ -54,10 +54,14 @@ def set_dt(self, dt: Union[int, float]): def load(self, key, value: Any = None, desc: str = None): """Load the shared data by the ``key``. - Args: - key (str): the key to indicate the data. - value (Any): the default value when ``key`` is not defined in the shared. - desc: (str): the description of the key. + Parameters + ---------- + key : str + the key to indicate the data. + value : Any + the default value when ``key`` is not defined in the shared. + desc : str + the description of the key. """ return brainstate.environ.get(key, value, desc, env=env) diff --git a/brainpy/delay.py b/brainpy/delay.py index 04270253c..29e14b2ce 100644 --- a/brainpy/delay.py +++ b/brainpy/delay.py @@ -59,12 +59,18 @@ def _get_delay(delay_time, delay_step): class Delay(DynamicalSystem, ParamDesc): """Base class for delay variables. - Args: - time: The delay time. - init: The initial delay data. - method: The delay method. Can be ``rotation`` and ``concat``. - name: The delay name. - mode: The computing mode. + Parameters + ---------- + time + The delay time. + init + The initial delay data. + method + The delay method. Can be ``rotation`` and ``concat``. + name + The delay name. + mode + The computing mode. """ max_time: float @@ -127,35 +133,44 @@ def register_entry( ) -> 'Delay': """Register an entry to access the data. - Args: - entry: str. The entry to access the delay data. - delay_time: The delay time of the entry (can be a float). - delay_step: The delay step of the entry (must be an int). ``delay_step = delay_time / dt``. - - Returns: - Return the self. + Parameters + ---------- + entry : str + The entry to access the delay data. + delay_time + The delay time of the entry (can be a float). + delay_step + The delay step of the entry (must be an int). ``delay_step = delay_time / dt``. + + Returns + ------- + Return the self. """ raise NotImplementedError def at(self, entry: str, *indices) -> bm.Array: """Get the data at the given entry. - Args: - entry: str. The entry to access the data. - *indices: The slicing indices. + Parameters + ---------- + entry : str + The entry to access the data. + *indices + The slicing indices. - Returns: - The data. + Returns + ------- + The data. """ raise NotImplementedError def retrieve(self, delay_step, *indices): """Retrieve the delay data according to the delay length. - Parameters:: - - delay_step: int, ArrayType - The delay length used to retrieve the data. + Parameters + ---------- + delay_step : int, ArrayType + The delay length used to retrieve the data. """ raise NotImplementedError() @@ -185,10 +200,14 @@ class VarDelay(Delay): delay = length-1 data delay = length data ] - Args: - target: Variable. The delay target. - time: int, float. The delay time. - init: Any. The delay data. It can be a Python number, like float, int, boolean values. + Parameters + ---------- + target : Variable + The delay target. + time : int, float + The delay time. + init : Any + The delay data. It can be a Python number, like float, int, boolean values. It can also be arrays. Or a callable function or instance of ``Connector``. Note that ``initial_delay_data`` should be arranged as the following way:: @@ -198,10 +217,14 @@ class VarDelay(Delay): ... .... delay = length-1 data delay = length data ] - entries: optional, dict. The delay access entries. - name: str. The delay name. - method: str. The method used for updating delay. Default None. - mode: Mode. The computing mode. Default None. + entries : optional, dict + The delay access entries. + name : str + The delay name. + method : str + The method used for updating delay. Default None. + mode : Mode + The computing mode. Default None. """ @@ -272,13 +295,18 @@ def register_entry( ) -> 'Delay': """Register an entry to access the data. - Args: - entry: str. The entry to access the delay data. - delay_time: The delay time of the entry (can be a float). - delay_step: The delay step of the entry (must be an int). ``delat_step = delay_time / dt``. - - Returns: - Return the self. + Parameters + ---------- + entry : str + The entry to access the delay data. + delay_time + The delay time of the entry (can be a float). + delay_step + The delay step of the entry (must be an int). ``delat_step = delay_time / dt``. + + Returns + ------- + Return the self. """ if entry in self._registered_entries: raise KeyError(f'Entry {entry} has been registered. ' @@ -304,12 +332,16 @@ def register_entry( def at(self, entry: str, *indices) -> bm.Array: """Get the data at the given entry. - Args: - entry: str. The entry to access the data. - *indices: The slicing indices. Not include the slice at the batch dimension. + Parameters + ---------- + entry : str + The entry to access the data. + *indices + The slicing indices. Not include the slice at the batch dimension. - Returns: - The data. + Returns + ------- + The data. """ assert isinstance(entry, str), 'entry should be a string for describing the ' if entry not in self._registered_entries: @@ -345,10 +377,10 @@ def _check_delay(self, delay_len): def retrieve(self, delay_step, *indices): """Retrieve the delay data according to the delay length. - Parameters:: - - delay_step: int, Array - The delay length used to retrieve the data. + Parameters + ---------- + delay_step : int, Array + The delay length used to retrieve the data. """ assert self.data is not None assert delay_step is not None @@ -520,12 +552,16 @@ def reset_state(self, *args, **kwargs): def init_delay_by_return(info: Union[bm.Variable, ReturnInfo], initial_delay_data=None) -> Delay: """Initialize a delay class by the return info (usually is created by ``.return_info()`` function). - Args: - info: the return information. - initial_delay_data: The initial delay data. + Parameters + ---------- + info + the return information. + initial_delay_data + The initial delay data. - Returns: - The decay instance. + Returns + ------- + The decay instance. """ if isinstance(info, bm.Variable): return VarDelay(info, init=initial_delay_data) @@ -568,11 +604,14 @@ def init_delay_by_return(info: Union[bm.Variable, ReturnInfo], initial_delay_dat def register_delay_by_return(target: JointType[DynamicalSystem, SupportAutoDelay]): """Register delay class for the given target. - Args: - target: The target class to register delay. + Parameters + ---------- + target + The target class to register delay. - Returns: - The delay registered for the given target. + Returns + ------- + The delay registered for the given target. """ if not target.has_aft_update(delay_identifier): delay_ins = init_delay_by_return(target.return_info()) diff --git a/brainpy/dnn/activations.py b/brainpy/dnn/activations.py index 6f798fad5..f554ee28c 100644 --- a/brainpy/dnn/activations.py +++ b/brainpy/dnn/activations.py @@ -47,16 +47,21 @@ class Threshold(Layer): \text{value}, &\text{ otherwise } \end{cases} - Args: - threshold: The value to threshold at - value: The value to replace with - inplace: can optionally do the operation in-place. Default: ``False`` + Parameters + ---------- + threshold : float + The value to threshold at + value : float + The value to replace with + inplace : bool + can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -92,14 +97,17 @@ class ReLU(Layer): :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)` - Args: - inplace: can optionally do the operation in-place. Default: ``False`` + Parameters + ---------- + inplace : bool + can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -152,16 +160,21 @@ class RReLU(Layer): See: https://arxiv.org/pdf/1505.00853.pdf - Args: - lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}` - upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}` - inplace: can optionally do the operation in-place. Default: ``False`` + Parameters + ---------- + lower : float + lower bound of the uniform distribution. Default: :math:`\frac{1}{8}` + upper : float + upper bound of the uniform distribution. Default: :math:`\frac{1}{3}` + inplace : bool + can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -210,10 +223,14 @@ class Hardtanh(Layer): x & \text{ otherwise } \\ \end{cases} - Args: - min_val: minimum value of the linear region range. Default: -1 - max_val: maximum value of the linear region range. Default: 1 - inplace: can optionally do the operation in-place. Default: ``False`` + Parameters + ---------- + min_val : float + minimum value of the linear region range. Default: -1 + max_val : float + maximum value of the linear region range. Default: 1 + inplace : bool + can optionally do the operation in-place. Default: ``False`` Keyword arguments :attr:`min_value` and :attr:`max_value` have been deprecated in favor of :attr:`min_val` and :attr:`max_val`. @@ -222,7 +239,8 @@ class Hardtanh(Layer): - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -265,14 +283,17 @@ class ReLU6(Hardtanh): .. math:: \text{ReLU6}(x) = \min(\max(0,x), 6) - Args: - inplace: can optionally do the operation in-place. Default: ``False`` + Parameters + ---------- + inplace : bool + can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -300,7 +321,8 @@ class Sigmoid(Layer): - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -325,14 +347,17 @@ class Hardsigmoid(Layer): x / 6 + 1 / 2 & \text{otherwise} \end{cases} - Args: - inplace: can optionally do the operation in-place. Default: ``False`` + Parameters + ---------- + inplace : bool + can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -365,7 +390,8 @@ class Tanh(Layer): - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -392,14 +418,18 @@ class SiLU(Layer): in Reinforcement Learning `_ and `Swish: a Self-Gated Activation Function `_ where the SiLU was experimented with later. - Args: - inplace: can optionally do the operation in-place. Default: ``False`` + + Parameters + ---------- + inplace : bool + can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -432,14 +462,17 @@ class Mish(Layer): .. note:: See `Mish: A Self Regularized Non-Monotonic Neural Activation Function `_ - Args: - inplace: can optionally do the operation in-place. Default: ``False`` + Parameters + ---------- + inplace : bool + can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -475,14 +508,17 @@ class Hardswish(Layer): x \cdot (x + 3) /6 & \text{otherwise} \end{cases} - Args: - inplace: can optionally do the operation in-place. Default: ``False`` + Parameters + ---------- + inplace : bool + can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -515,15 +551,19 @@ class ELU(Layer): \alpha * (\exp(x) - 1), & \text{ if } x \leq 0 \end{cases} - Args: - alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0 - inplace: can optionally do the operation in-place. Default: ``False`` + Parameters + ---------- + alpha : float + the :math:`\alpha` value for the ELU formulation. Default: 1.0 + inplace : bool + can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -556,15 +596,19 @@ class CELU(Layer): More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ . - Args: - alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0 - inplace: can optionally do the operation in-place. Default: ``False`` + Parameters + ---------- + alpha : float + the :math:`\alpha` value for the CELU formulation. Default: 1.0 + inplace : bool + can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -603,14 +647,17 @@ class SELU(Layer): More details can be found in the paper `Self-Normalizing Neural Networks`_ . - Args: - inplace (bool, optional): can optionally do the operation in-place. Default: ``False`` + Parameters + ---------- + inplace : bool, optional + can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -640,15 +687,18 @@ class GLU(Layer): :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half of the input matrices and :math:`b` is the second half. - Args: - dim (int): the dimension on which to split the input. Default: -1 + Parameters + ---------- + dim : int + the dimension on which to split the input. Default: -1 Shape: - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional dimensions - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -681,15 +731,18 @@ class GELU(Layer): .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3))) - Args: - approximate (str, optional): the gelu approximation algorithm to use: - ``'none'`` | ``'tanh'``. Default: ``'none'`` + Parameters + ---------- + approximate : str, optional + the gelu approximation algorithm to use: + ``'none'`` | ``'tanh'``. Default: ``'none'`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -724,14 +777,17 @@ class Hardshrink(Layer): 0, & \text{ otherwise } \end{cases} - Args: - lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5 + Parameters + ---------- + lambd : float + the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5 Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -769,17 +825,21 @@ class LeakyReLU(Layer): \text{negative\_slope} \times x, & \text{ otherwise } \end{cases} - Args: - negative_slope: Controls the angle of the negative slope (which is used for - negative input values). Default: 1e-2 - inplace: can optionally do the operation in-place. Default: ``False`` + Parameters + ---------- + negative_slope : float + Controls the angle of the negative slope (which is used for + negative input values). Default: 1e-2 + inplace : bool + can optionally do the operation in-place. Default: ``False`` Shape: - Input: :math:`(*)` where `*` means, any number of additional dimensions - Output: :math:`(*)`, same shape as the input - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -814,7 +874,8 @@ class LogSigmoid(Layer): - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -837,15 +898,19 @@ class Softplus(Layer): For numerical stability the implementation reverts to the linear function when :math:`input \times \beta > threshold`. - Args: - beta: the :math:`\beta` value for the Softplus formulation. Default: 1 - threshold: values above this revert to a linear function. Default: 20 + Parameters + ---------- + beta : float + the :math:`\beta` value for the Softplus formulation. Default: 1 + threshold : float + values above this revert to a linear function. Default: 20 Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -880,14 +945,17 @@ class Softshrink(Layer): 0, & \text{ otherwise } \end{cases} - Args: - lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5 + Parameters + ---------- + lambd : float + the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5 Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -936,21 +1004,26 @@ class PReLU(Layer): Channel dim is the 2nd dim of input. When input has dims < 2, then there is no channel dim and the number of channels = 1. - Args: - num_parameters (int): number of :math:`a` to learn. - Although it takes an int as input, there is only two values are legitimate: - 1, or the number of channels at input. Default: 1 - init (float): the initial value of :math:`a`. Default: 0.25 + Parameters + ---------- + num_parameters : int + number of :math:`a` to learn. + Although it takes an int as input, there is only two values are legitimate: + 1, or the number of channels at input. Default: 1 + init : float + the initial value of :math:`a`. Default: 0.25 Shape: - Input: :math:`( *)` where `*` means, any number of additional dimensions. - Output: :math:`(*)`, same shape as the input. - Attributes: + Attributes + ---------- weight (Tensor): the learnable weights of shape (:attr:`num_parameters`). - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -983,7 +1056,8 @@ class Softsign(Layer): - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -1006,7 +1080,8 @@ class Tanhshrink(Layer): - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -1034,15 +1109,19 @@ class Softmin(Layer): dimensions - Output: :math:`(*)`, same shape as the input - Args: - dim (int): A dimension along which Softmin will be computed (so every slice - along dim will sum to 1). + Parameters + ---------- + dim : int + A dimension along which Softmin will be computed (so every slice + along dim will sum to 1). - Returns: + Returns + ------- a Tensor of the same dimension and shape as the input, with values in the range [0, 1] - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -1082,20 +1161,24 @@ class Softmax(Layer): dimensions - Output: :math:`(*)`, same shape as the input - Returns: + Returns + ------- a Tensor of the same dimension and shape as the input with values in the range [0, 1] - Args: - dim (int): A dimension along which Softmax will be computed (so every slice - along dim will sum to 1). + Parameters + ---------- + dim : int + A dimension along which Softmax will be computed (so every slice + along dim will sum to 1). .. note:: This module doesn't work directly with NLLLoss, which expects the Log to be computed between the Softmax and itself. Use `LogSoftmax` instead (it's faster and has better numerical properties). - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -1128,11 +1211,13 @@ class Softmax2d(Layer): - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`. - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input) - Returns: + Returns + ------- a Tensor of the same dimension and shape as the input with values in the range [0, 1] - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -1159,14 +1244,18 @@ class LogSoftmax(Layer): dimensions - Output: :math:`(*)`, same shape as the input - Args: - dim (int): A dimension along which LogSoftmax will be computed. + Parameters + ---------- + dim : int + A dimension along which LogSoftmax will be computed. - Returns: + Returns + ------- a Tensor of the same dimension and shape as the input with values in the range [-inf, 0) - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm diff --git a/brainpy/dnn/conv.py b/brainpy/dnn/conv.py index af4ae41df..a01bd84cb 100644 --- a/brainpy/dnn/conv.py +++ b/brainpy/dnn/conv.py @@ -53,46 +53,46 @@ def to_dimension_numbers(num_spatial_dims: int, class _GeneralConv(Layer): """Apply a convolution to the inputs. - Parameters:: - - num_spatial_dims: int - The number of spatial dimensions of the input. - in_channels: int - The number of input channels. - out_channels: int - The number of output channels. - kernel_size: int, sequence of int - The shape of the convolutional kernel. - For 1D convolution, the kernel size can be passed as an integer. - For all other cases, it must be a sequence of integers. - stride: int, sequence of int - An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). - padding: str, int, sequence of int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, - high)` integer pairs that give the padding to apply before and after each - spatial dimension. - lhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `inputs` - (default: 1). Convolution with input dilation `d` is equivalent to - transposed convolution with stride `d`. - rhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of the convolution - kernel (default: 1). Convolution with kernel dilation - is also known as 'atrous convolution'. - groups: int - If specified, divides the input features into groups. default 1. - w_initializer: Callable, ArrayType, Initializer - The initializer for the convolutional kernel. - b_initializer: Optional, Callable, ArrayType, Initializer - The initializer for the bias. - mask: ArrayType, Optional - The optional mask of the weights. - mode: Mode - The computation mode of the current object. Default it is `training`. - name: str, Optional - The name of the object. + Parameters + ---------- + num_spatial_dims : int + The number of spatial dimensions of the input. + in_channels : int + The number of input channels. + out_channels : int + The number of output channels. + kernel_size : int, sequence of int + The shape of the convolutional kernel. + For 1D convolution, the kernel size can be passed as an integer. + For all other cases, it must be a sequence of integers. + stride : int, sequence of int + An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). + padding : str, int, sequence of int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. + lhs_dilation : int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of `inputs` + (default: 1). Convolution with input dilation `d` is equivalent to + transposed convolution with stride `d`. + rhs_dilation : int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel (default: 1). Convolution with kernel dilation + is also known as 'atrous convolution'. + groups : int + If specified, divides the input features into groups. default 1. + w_initializer : Callable, ArrayType, Initializer + The initializer for the convolutional kernel. + b_initializer : Optional, Callable, ArrayType, Initializer + The initializer for the bias. + mask : ArrayType, Optional + The optional mask of the weights. + mode : Mode + The computation mode of the current object. Default it is `training`. + name : str, Optional + The name of the object. """ supported_modes = (bm.TrainingMode, bm.BatchingMode, bm.NonBatchingMode) @@ -207,44 +207,44 @@ class Conv1d(_GeneralConv): The input should a 2d array with the shape of ``[H, C]``, or a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size. - Parameters:: - - in_channels: int - The number of input channels. - out_channels: int - The number of output channels. - kernel_size: int, sequence of int - The shape of the convolutional kernel. - For 1D convolution, the kernel size can be passed as an integer. - For all other cases, it must be a sequence of integers. - strides: int, sequence of int - An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). - padding: str, int, sequence of int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, - high)` integer pairs that give the padding to apply before and after each - spatial dimension. - lhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `inputs` - (default: 1). Convolution with input dilation `d` is equivalent to - transposed convolution with stride `d`. - rhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of the convolution - kernel (default: 1). Convolution with kernel dilation - is also known as 'atrous convolution'. - groups: int - If specified, divides the input features into groups. default 1. - w_initializer: Callable, ArrayType, Initializer - The initializer for the convolutional kernel. - b_initializer: Callable, ArrayType, Initializer - The initializer for the bias. - mask: ArrayType, Optional - The optional mask of the weights. - mode: Mode - The computation mode of the current object. Default it is `training`. - name: str, Optional - The name of the object. + Parameters + ---------- + in_channels : int + The number of input channels. + out_channels : int + The number of output channels. + kernel_size : int, sequence of int + The shape of the convolutional kernel. + For 1D convolution, the kernel size can be passed as an integer. + For all other cases, it must be a sequence of integers. + strides : int, sequence of int + An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). + padding : str, int, sequence of int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. + lhs_dilation : int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of `inputs` + (default: 1). Convolution with input dilation `d` is equivalent to + transposed convolution with stride `d`. + rhs_dilation : int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel (default: 1). Convolution with kernel dilation + is also known as 'atrous convolution'. + groups : int + If specified, divides the input features into groups. default 1. + w_initializer : Callable, ArrayType, Initializer + The initializer for the convolutional kernel. + b_initializer : Callable, ArrayType, Initializer + The initializer for the bias. + mask : ArrayType, Optional + The optional mask of the weights. + mode : Mode + The computation mode of the current object. Default it is `training`. + name : str, Optional + The name of the object. """ def __init__( @@ -302,44 +302,44 @@ class Conv2d(_GeneralConv): The input should a 3d array with the shape of ``[H, W, C]``, or a 4d array with the shape of ``[B, H, W, C]``. - Parameters:: - - in_channels: int - The number of input channels. - out_channels: int - The number of output channels. - kernel_size: int, sequence of int - The shape of the convolutional kernel. - For 1D convolution, the kernel size can be passed as an integer. - For all other cases, it must be a sequence of integers. - stride: int, sequence of int - An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). - padding: str, int, sequence of int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, - high)` integer pairs that give the padding to apply before and after each - spatial dimension. - lhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `inputs` - (default: 1). Convolution with input dilation `d` is equivalent to - transposed convolution with stride `d`. - rhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of the convolution - kernel (default: 1). Convolution with kernel dilation - is also known as 'atrous convolution'. - groups: int - If specified, divides the input features into groups. default 1. - w_initializer: Callable, ArrayType, Initializer - The initializer for the convolutional kernel. - b_initializer: Callable, ArrayType, Initializer - The initializer for the bias. - mask: ArrayType, Optional - The optional mask of the weights. - mode: Mode - The computation mode of the current object. Default it is `training`. - name: str, Optional - The name of the object. + Parameters + ---------- + in_channels : int + The number of input channels. + out_channels : int + The number of output channels. + kernel_size : int, sequence of int + The shape of the convolutional kernel. + For 1D convolution, the kernel size can be passed as an integer. + For all other cases, it must be a sequence of integers. + stride : int, sequence of int + An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). + padding : str, int, sequence of int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. + lhs_dilation : int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of `inputs` + (default: 1). Convolution with input dilation `d` is equivalent to + transposed convolution with stride `d`. + rhs_dilation : int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel (default: 1). Convolution with kernel dilation + is also known as 'atrous convolution'. + groups : int + If specified, divides the input features into groups. default 1. + w_initializer : Callable, ArrayType, Initializer + The initializer for the convolutional kernel. + b_initializer : Callable, ArrayType, Initializer + The initializer for the bias. + mask : ArrayType, Optional + The optional mask of the weights. + mode : Mode + The computation mode of the current object. Default it is `training`. + name : str, Optional + The name of the object. """ @@ -398,44 +398,44 @@ class Conv3d(_GeneralConv): The input should a 3d array with the shape of ``[H, W, D, C]``, or a 4d array with the shape of ``[B, H, W, D, C]``. - Parameters:: - - in_channels: int - The number of input channels. - out_channels: int - The number of output channels. - kernel_size: int, sequence of int - The shape of the convolutional kernel. - For 1D convolution, the kernel size can be passed as an integer. - For all other cases, it must be a sequence of integers. - stride: int, sequence of int - An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). - padding: str, int, sequence of int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, - high)` integer pairs that give the padding to apply before and after each - spatial dimension. - lhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of `inputs` - (default: 1). Convolution with input dilation `d` is equivalent to - transposed convolution with stride `d`. - rhs_dilation: int, sequence of int - An integer or a sequence of `n` integers, giving the - dilation factor to apply in each spatial dimension of the convolution - kernel (default: 1). Convolution with kernel dilation - is also known as 'atrous convolution'. - groups: int - If specified, divides the input features into groups. default 1. - w_initializer: Callable, ArrayType, Initializer - The initializer for the convolutional kernel. - b_initializer: Callable, ArrayType, Initializer - The initializer for the bias. - mask: ArrayType, Optional - The optional mask of the weights. - mode: Mode - The computation mode of the current object. Default it is `training`. - name: str, Optional - The name of the object. + Parameters + ---------- + in_channels : int + The number of input channels. + out_channels : int + The number of output channels. + kernel_size : int, sequence of int + The shape of the convolutional kernel. + For 1D convolution, the kernel size can be passed as an integer. + For all other cases, it must be a sequence of integers. + stride : int, sequence of int + An integer or a sequence of `n` integers, representing the inter-window strides (default: 1). + padding : str, int, sequence of int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low, + high)` integer pairs that give the padding to apply before and after each + spatial dimension. + lhs_dilation : int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of `inputs` + (default: 1). Convolution with input dilation `d` is equivalent to + transposed convolution with stride `d`. + rhs_dilation : int, sequence of int + An integer or a sequence of `n` integers, giving the + dilation factor to apply in each spatial dimension of the convolution + kernel (default: 1). Convolution with kernel dilation + is also known as 'atrous convolution'. + groups : int + If specified, divides the input features into groups. default 1. + w_initializer : Callable, ArrayType, Initializer + The initializer for the convolutional kernel. + b_initializer : Callable, ArrayType, Initializer + The initializer for the bias. + mask : ArrayType, Optional + The optional mask of the weights. + mode : Mode + The computation mode of the current object. Default it is `training`. + name : str, Optional + The name of the object. """ @@ -606,20 +606,30 @@ def __init__( ): """Initializes the module. - Args: - in_channels: Number of input channels. - out_channels: Number of output channels. - kernel_size: The shape of the kernel. Either an integer or a sequence of + Parameters + ---------- + in_channels + Number of input channels. + out_channels + Number of output channels. + kernel_size + The shape of the kernel. Either an integer or a sequence of length 1. - stride: Optional stride for the kernel. Either an integer or a sequence of + stride + Optional stride for the kernel. Either an integer or a sequence of length 1. Defaults to 1. - padding: Optional padding algorithm. Either ``VALID`` or ``SAME``. + padding + Optional padding algorithm. Either ``VALID`` or ``SAME``. Defaults to ``SAME``. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution. - w_initializer: Optional weight initialization. By default, truncated normal. - b_initializer: Optional bias initialization. By default, zeros. - mask: Optional mask of the weights. - name: The name of the module. + w_initializer + Optional weight initialization. By default, truncated normal. + b_initializer + Optional bias initialization. By default, zeros. + mask + Optional mask of the weights. + name + The name of the module. """ super().__init__( num_spatial_dims=1, @@ -663,20 +673,30 @@ def __init__( ): """Initializes the module. - Args: - in_channels: Number of input channels. - out_channels: Number of output channels. - kernel_size: The shape of the kernel. Either an integer or a sequence of + Parameters + ---------- + in_channels + Number of input channels. + out_channels + Number of output channels. + kernel_size + The shape of the kernel. Either an integer or a sequence of length 2. - stride: Optional stride for the kernel. Either an integer or a sequence of + stride + Optional stride for the kernel. Either an integer or a sequence of length 2. Defaults to 1. - padding: Optional padding algorithm. Either ``VALID`` or ``SAME``. + padding + Optional padding algorithm. Either ``VALID`` or ``SAME``. Defaults to ``SAME``. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution. - w_initializer: Optional weight initialization. By default, truncated normal. - b_initializer: Optional bias initialization. By default, zeros. - mask: Optional mask of the weights. - name: The name of the module. + w_initializer + Optional weight initialization. By default, truncated normal. + b_initializer + Optional bias initialization. By default, zeros. + mask + Optional mask of the weights. + name + The name of the module. """ super().__init__( num_spatial_dims=2, @@ -720,20 +740,30 @@ def __init__( ): """Initializes the module. - Args: - in_channels: Number of input channels. - out_channels: Number of output channels. - kernel_size: The shape of the kernel. Either an integer or a sequence of + Parameters + ---------- + in_channels + Number of input channels. + out_channels + Number of output channels. + kernel_size + The shape of the kernel. Either an integer or a sequence of length 3. - stride: Optional stride for the kernel. Either an integer or a sequence of + stride + Optional stride for the kernel. Either an integer or a sequence of length 3. Defaults to 1. - padding: Optional padding algorithm. Either ``VALID`` or ``SAME``. + padding + Optional padding algorithm. Either ``VALID`` or ``SAME``. Defaults to ``SAME``. See: https://www.tensorflow.org/xla/operation_semantics#conv_convolution. - w_initializer: Optional weight initialization. By default, truncated normal. - b_initializer: Optional bias initialization. By default, zeros. - mask: Optional mask of the weights. - name: The name of the module. + w_initializer + Optional weight initialization. By default, truncated normal. + b_initializer + Optional bias initialization. By default, zeros. + mask + Optional mask of the weights. + name + The name of the module. """ super().__init__( num_spatial_dims=3, diff --git a/brainpy/dnn/dropout.py b/brainpy/dnn/dropout.py index 1b61eae73..a3c06c71d 100644 --- a/brainpy/dnn/dropout.py +++ b/brainpy/dnn/dropout.py @@ -33,15 +33,21 @@ class Dropout(Layer): This layer is active only during training (``mode=brainpy.math.training_mode``). In other circumstances it is a no-op. + Parameters + ---------- + prob : float + Probability to keep element of the tensor. + mode : Mode + The computation mode of the object. + name : str + The name of the dynamic system. + + References + ---------- .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. - Args: - prob: Probability to keep element of the tensor. - mode: Mode. The computation mode of the object. - name: str. The name of the dynamic system. - """ def __init__( diff --git a/brainpy/dnn/function.py b/brainpy/dnn/function.py index 321a51900..c5c10a57d 100644 --- a/brainpy/dnn/function.py +++ b/brainpy/dnn/function.py @@ -29,14 +29,14 @@ class Activation(Layer): r"""Applies an activation function to the inputs - Parameters: + Parameters ---------- - activate_fun: Callable, function - The function of Activation - name: str, Optional - The name of the object - mode: Mode - Enable training this node or not. (default True). + activate_fun : Callable, function + The function of Activation + name : str, Optional + The name of the object + mode : Mode + Enable training this node or not. (default True). """ update_style = 'x' @@ -65,13 +65,19 @@ class Flatten(Layer): number of dimensions including none. - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`. - Args: - start_dim: first dim to flatten (default = 1). - end_dim: last dim to flatten (default = -1). - name: str, Optional. The name of the object. - mode: Mode. Enable training this node or not. (default True). - - Examples:: + Parameters + ---------- + start_dim : int + first dim to flatten (default = 1). + end_dim : int + last dim to flatten (default = -1). + name : str, Optional + The name of the object. + mode : Mode + Enable training this node or not. (default True). + + Examples + -------- >>> import brainpy.math as bm >>> inp = bm.random.randn(32, 1, 5, 5) >>> # With default parameters @@ -126,11 +132,15 @@ class Unflatten(Layer): - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`. - Args: - dim: int, Dimension to be unflattened. - sizes: Sequence of int. New shape of the unflattened dimension. + Parameters + ---------- + dim : int + Dimension to be unflattened. + sizes : Sequence of int + New shape of the unflattened dimension. - Examples: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm >>> input = bm.random.randn(2, 50) diff --git a/brainpy/dnn/interoperation_flax.py b/brainpy/dnn/interoperation_flax.py index 6c50a5e14..ce2e986be 100644 --- a/brainpy/dnn/interoperation_flax.py +++ b/brainpy/dnn/interoperation_flax.py @@ -52,14 +52,14 @@ class FromFlax(Layer): """ Transform a Flax module as a BrainPy :py:class:`~.DynamicalSystem`. - Parameters:: - - flax_module: Any - The flax Module. - module_args: Any - The module arguments, used to initialize model parameters. - module_kwargs: Any - The module arguments, used to initialize model parameters. + Parameters + ---------- + flax_module : Any + The flax Module. + module_args : Any + The module arguments, used to initialize model parameters. + module_kwargs : Any + The module arguments, used to initialize model parameters. """ def __init__(self, flax_module, *module_args, **module_kwargs): @@ -110,14 +110,18 @@ def setup(self): def __call__(self, carry, *inputs): """A recurrent cell that transformed from a BrainPy :py:class:`~.DynamicalSystem`. - Args: - carry: the hidden state of the transformed recurrent cell, initialized using + Parameters + ---------- + carry + the hidden state of the transformed recurrent cell, initialized using `.initialize_carry()` function in which the original `.reset_state()` is called. - inputs: an ndarray with the input for the current time step. All + inputs + an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions. - Returns: - A tuple with the new carry and the output. + Returns + ------- + A tuple with the new carry and the output. """ # shared arguments i, t = carry[1], carry[2] diff --git a/brainpy/dnn/linear.py b/brainpy/dnn/linear.py index d8af1cbe3..05645aaf4 100644 --- a/brainpy/dnn/linear.py +++ b/brainpy/dnn/linear.py @@ -58,18 +58,18 @@ class Dense(Layer, SupportSTDP, SupportOnline, SupportOffline): y = x \cdot weight + b - Parameters:: - - num_in: int - The number of the input feature. A positive integer. - num_out: int - The number of the output features. A positive integer. - W_initializer: optional, Initializer - The weight initialization. - b_initializer: optional, Initializer - The bias initialization. - mode: Mode - Enable training this node or not. (default True) + Parameters + ---------- + num_in : int + The number of the input feature. A positive integer. + num_out : int + The number of the output features. A positive integer. + W_initializer : optional, Initializer + The weight initialization. + b_initializer : optional, Initializer + The bias initialization. + mode : Mode + Enable training this node or not. (default True) """ def __init__( @@ -267,14 +267,22 @@ def update(self, x): class AllToAll(Layer, SupportSTDP): """Synaptic matrix multiplication with All2All connections. - Args: - num_pre: int. The number of neurons in the presynaptic neuron group. - num_post: int. The number of neurons in the postsynaptic neuron group. - weight: The synaptic weights. - sharding: The sharding strategy. - include_self: bool. Whether connect the neuron with at the same position. - mode: Mode. The computing mode. - name: str. The object name. + Parameters + ---------- + num_pre : int + The number of neurons in the presynaptic neuron group. + num_post : int + The number of neurons in the postsynaptic neuron group. + weight : Union[float, ArrayType, Callable] + The synaptic weights. + sharding : Optional[Sharding] + The sharding strategy. + include_self : bool + Whether connect the neuron with at the same position. + mode : Mode + The computing mode. + name : str + The object name. """ def __init__( @@ -353,12 +361,18 @@ def stdp_update( class OneToOne(Layer, SupportSTDP): """Synaptic matrix multiplication with One2One connection. - Args: - num: int. The number of neurons. - weight: The synaptic weight. - sharding: The sharding strategy. - mode: The computing mode. - name: The object name. + Parameters + ---------- + num : int + The number of neurons. + weight : Union[float, ArrayType, Callable] + The synaptic weight. + sharding : Optional[Sharding] + The sharding strategy. + mode : Optional[bm.Mode] + The computing mode. + name : Optional[str] + The object name. """ @@ -423,13 +437,20 @@ class MaskedLinear(Layer, SupportSTDP): >>> l = bp.dnn.MaskedLinear(bp.conn.FixedProb(0.1, pre=100, post=100), >>> weight=0.1) - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - mask_fun: Masking function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. + Parameters + ---------- + conn : TwoEndConnector + The connection. + weight : Union[float, ArrayType, Callable] + Synaptic weights. Can be a scalar, array, or callable function. + mask_fun : Callable + Masking function. + sharding : Optional[Sharding] + The sharding strategy. + mode : Optional[bm.Mode] + The synaptic computing mode. + name : Optional[str] + The synapse model name. """ def __init__( @@ -562,12 +583,18 @@ class CSRLinear(_CSRLayer): where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, :math:`M` the synaptic weight using a CSR sparse matrix. - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. + Parameters + ---------- + conn : TwoEndConnector + The connection. + weight : Union[float, ArrayType, Callable] + Synaptic weights. Can be a scalar, array, or callable function. + sharding : Optional[Sharding] + The sharding strategy. + mode : Optional[bm.Mode] + The synaptic computing mode. + name : Optional[str] + The synapse model name. """ def __init__( @@ -612,12 +639,18 @@ class EventCSRLinear(_CSRLayer): where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, :math:`M` the synaptic weight using a CSR sparse matrix. - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. + Parameters + ---------- + conn : TwoEndConnector + The connection. + weight : Union[float, ArrayType, Callable] + Synaptic weights. Can be a scalar, array, or callable function. + sharding : Optional[Sharding] + The sharding strategy. + mode : Optional[bm.Mode] + The synaptic computing mode. + name : Optional[str] + The synapse model name. """ def __init__( @@ -662,12 +695,18 @@ class CSCLinear(Layer): where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, :math:`M` the synaptic weight using a CSC sparse matrix. - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. + Parameters + ---------- + conn : TwoEndConnector + The connection. + weight : Union[float, ArrayType, Callable] + Synaptic weights. Can be a scalar, array, or callable function. + sharding : Optional[Sharding] + The sharding strategy. + mode : Optional[bm.Mode] + The synaptic computing mode. + name : Optional[str] + The synapse model name. """ def __init__( @@ -697,12 +736,18 @@ class BcsrMM(Layer): where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, :math:`M` the synaptic weight using a BCSR sparse matrix. - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. + Parameters + ---------- + conn : TwoEndConnector + The connection. + weight : Union[float, ArrayType, Callable] + Synaptic weights. Can be a scalar, array, or callable function. + sharding : Optional[Sharding] + The sharding strategy. + mode : Optional[bm.Mode] + The synaptic computing mode. + name : Optional[str] + The synapse model name. """ def __init__( @@ -732,12 +777,18 @@ class BcscMM(Layer): where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, :math:`M` the synaptic weight using a BCSC sparse matrix. - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. + Parameters + ---------- + conn : TwoEndConnector + The connection. + weight : Union[float, ArrayType, Callable] + Synaptic weights. Can be a scalar, array, or callable function. + sharding : Optional[Sharding] + The sharding strategy. + mode : Optional[bm.Mode] + The synaptic computing mode. + name : Optional[str] + The synapse model name. """ def __init__( @@ -798,18 +849,29 @@ class JitFPHomoLinear(JitFPHomoLayer): Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, and at each connection, the synaptic value is the same :math:`weight`. - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - weight: float. The synaptic value at each position. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. + Parameters + ---------- + num_in : int + The number of the input feature. A positive integer. + num_out : int + The number of the input feature. A positive integer. + prob : float + The connectivity probability. + weight : float + The synaptic value at each position. + seed : int + The random seed used to keep the reproducibility of the connectivity. + transpose : bool + Transpose the JIT matrix or not. Default False. + atomic : bool + Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding : Optional[Sharding] + The sharding strategy. + mode : Optional[bm.Mode] + The synaptic computing mode. + name : Optional[str] + The synapse model name. """ def __init__( @@ -877,19 +939,31 @@ class JitFPUniformLinear(JitFPUniformLayer): Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - w_low: float. The lowest value of the uniform distribution. - w_high: float. The highest value of the uniform distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. + Parameters + ---------- + num_in : int + The number of the input feature. A positive integer. + num_out : int + The number of the input feature. A positive integer. + prob : float + The connectivity probability. + w_low : float + The lowest value of the uniform distribution. + w_high : float + The highest value of the uniform distribution. + seed : int + The random seed used to keep the reproducibility of the connectivity. + transpose : bool + Transpose the JIT matrix or not. Default False. + atomic : bool + Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding : Optional[Sharding] + The sharding strategy. + mode : Optional[bm.Mode] + The synaptic computing mode. + name : Optional[str] + The synapse model name. """ def __init__( @@ -957,19 +1031,31 @@ class JitFPNormalLinear(JitFPNormalLayer): Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - w_mu: float. The center of the normal distribution. - w_sigma: float. The standard variance of the normal distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. + Parameters + ---------- + num_in : int + The number of the input feature. A positive integer. + num_out : int + The number of the input feature. A positive integer. + prob : float + The connectivity probability. + w_mu : float + The center of the normal distribution. + w_sigma : float + The standard variance of the normal distribution. + seed : int + The random seed used to keep the reproducibility of the connectivity. + transpose : bool + Transpose the JIT matrix or not. Default False. + atomic : bool + Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding : Optional[Sharding] + The sharding strategy. + mode : Optional[bm.Mode] + The synaptic computing mode. + name : Optional[str] + The synapse model name. """ def __init__( @@ -1037,18 +1123,29 @@ class EventJitFPHomoLinear(JitFPHomoLayer): Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, and at each connection, the synaptic value is the same :math:`weight`. - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - weight: float. The synaptic value at each position. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. + Parameters + ---------- + num_in : int + The number of the input feature. A positive integer. + num_out : int + The number of the input feature. A positive integer. + prob : float + The connectivity probability. + weight : float + The synaptic value at each position. + seed : int + The random seed used to keep the reproducibility of the connectivity. + transpose : bool + Transpose the JIT matrix or not. Default False. + atomic : bool + Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding : Optional[Sharding] + The sharding strategy. + mode : Optional[bm.Mode] + The synaptic computing mode. + name : Optional[str] + The synapse model name. """ def __init__( @@ -1116,19 +1213,31 @@ class EventJitFPUniformLinear(JitFPUniformLayer): Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - w_low: float. The lowest value of the uniform distribution. - w_high: float. The highest value of the uniform distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. + Parameters + ---------- + num_in : int + The number of the input feature. A positive integer. + num_out : int + The number of the input feature. A positive integer. + prob : float + The connectivity probability. + w_low : float + The lowest value of the uniform distribution. + w_high : float + The highest value of the uniform distribution. + seed : int + The random seed used to keep the reproducibility of the connectivity. + transpose : bool + Transpose the JIT matrix or not. Default False. + atomic : bool + Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding : Optional[Sharding] + The sharding strategy. + mode : Optional[bm.Mode] + The synaptic computing mode. + name : Optional[str] + The synapse model name. """ def __init__( @@ -1196,19 +1305,31 @@ class EventJitFPNormalLinear(JitFPNormalLayer): Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - w_mu: float. The center of the normal distribution. - w_sigma: float. The standard variance of the normal distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. + Parameters + ---------- + num_in : int + The number of the input feature. A positive integer. + num_out : int + The number of the input feature. A positive integer. + prob : float + The connectivity probability. + w_mu : float + The center of the normal distribution. + w_sigma : float + The standard variance of the normal distribution. + seed : int + The random seed used to keep the reproducibility of the connectivity. + transpose : bool + Transpose the JIT matrix or not. Default False. + atomic : bool + Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding : Optional[Sharding] + The sharding strategy. + mode : Optional[bm.Mode] + The synaptic computing mode. + name : Optional[str] + The synapse model name. """ def __init__( diff --git a/brainpy/dnn/normalization.py b/brainpy/dnn/normalization.py index d0c5bd640..6f5b3b1d6 100644 --- a/brainpy/dnn/normalization.py +++ b/brainpy/dnn/normalization.py @@ -65,34 +65,36 @@ class BatchNorm(Layer): where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. - Parameters:: + Parameters + ---------- - num_features: int + num_features : int ``C`` from an expected input of size ``(..., C)``. - axis: int, tuple, list + axis : int, tuple, list Axes where the data will be normalized. The feature (channel) axis should be excluded. - momentum: float + momentum : float The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99 - epsilon: float + epsilon : float A value added to the denominator for numerical stability. Default: 1e-5 - affine: bool + affine : bool A boolean value that when set to ``True``, this module has learnable affine parameters. Default: ``True`` - bias_initializer: Initializer, ArrayType, Callable + bias_initializer : Initializer, ArrayType, Callable An initializer generating the original translation matrix - scale_initializer: Initializer, ArrayType, Callable + scale_initializer : Initializer, ArrayType, Callable An initializer generating the original scaling matrix - axis_name: optional, str, sequence of str + axis_name : optional, str, sequence of str If not ``None``, it should be a string (or sequence of strings) representing the axis name(s) over which this module is being run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this argument means that batch statistics are calculated across all replicas on the named axes. - axis_index_groups: optional, sequence + axis_index_groups : optional, sequence Specifies how devices are grouped. Valid only within ``jax.pmap`` collectives. - References:: + References + ---------- .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag. @@ -218,34 +220,36 @@ class BatchNorm1d(BatchNorm): where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. - Parameters:: + Parameters + ---------- - num_features: int + num_features : int ``C`` from an expected input of size ``(B, L, C)``. - axis: int, tuple, list + axis : int, tuple, list axes where the data will be normalized. The feature (channel) axis should be excluded. - epsilon: float + epsilon : float A value added to the denominator for numerical stability. Default: 1e-5 - momentum: float + momentum : float The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99 - affine: bool + affine : bool A boolean value that when set to ``True``, this module has learnable affine parameters. Default: ``True`` - bias_initializer: Initializer, ArrayType, Callable + bias_initializer : Initializer, ArrayType, Callable an initializer generating the original translation matrix - scale_initializer: Initializer, ArrayType, Callable + scale_initializer : Initializer, ArrayType, Callable an initializer generating the original scaling matrix - axis_name: optional, str, sequence of str + axis_name : optional, str, sequence of str If not ``None``, it should be a string (or sequence of strings) representing the axis name(s) over which this module is being run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this argument means that batch statistics are calculated across all replicas on the named axes. - axis_index_groups: optional, sequence + axis_index_groups : optional, sequence Specifies how devices are grouped. Valid only within ``jax.pmap`` collectives. - References:: + References + ---------- .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag. @@ -301,34 +305,36 @@ class BatchNorm2d(BatchNorm): where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. - Parameters:: + Parameters + ---------- - num_features: int + num_features : int ``C`` from an expected input of size ``(B, H, W, C)``. - axis: int, tuple, list + axis : int, tuple, list axes where the data will be normalized. The feature (channel) axis should be excluded. - epsilon: float + epsilon : float a value added to the denominator for numerical stability. Default: 1e-5 - momentum: float + momentum : float The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99 - affine: bool + affine : bool A boolean value that when set to ``True``, this module has learnable affine parameters. Default: ``True`` - bias_initializer: Initializer, ArrayType, Callable + bias_initializer : Initializer, ArrayType, Callable an initializer generating the original translation matrix - scale_initializer: Initializer, ArrayType, Callable + scale_initializer : Initializer, ArrayType, Callable an initializer generating the original scaling matrix - axis_name: optional, str, sequence of str + axis_name : optional, str, sequence of str If not ``None``, it should be a string (or sequence of strings) representing the axis name(s) over which this module is being run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this argument means that batch statistics are calculated across all replicas on the named axes. - axis_index_groups: optional, sequence + axis_index_groups : optional, sequence Specifies how devices are grouped. Valid only within ``jax.pmap`` collectives. - References:: + References + ---------- .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag. @@ -384,34 +390,36 @@ class BatchNorm3d(BatchNorm): where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. - Parameters:: + Parameters + ---------- - num_features: int + num_features : int ``C`` from an expected input of size ``(B, H, W, D, C)``. - axis: int, tuple, list + axis : int, tuple, list axes where the data will be normalized. The feature (channel) axis should be excluded. - epsilon: float + epsilon : float a value added to the denominator for numerical stability. Default: 1e-5 - momentum: float + momentum : float The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99 - affine: bool + affine : bool A boolean value that when set to ``True``, this module has learnable affine parameters. Default: ``True`` - bias_initializer: Initializer, ArrayType, Callable + bias_initializer : Initializer, ArrayType, Callable an initializer generating the original translation matrix - scale_initializer: Initializer, ArrayType, Callable + scale_initializer : Initializer, ArrayType, Callable an initializer generating the original scaling matrix - axis_name: optional, str, sequence of str + axis_name : optional, str, sequence of str If not ``None``, it should be a string (or sequence of strings) representing the axis name(s) over which this module is being run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this argument means that batch statistics are calculated across all replicas on the named axes. - axis_index_groups: optional, sequence + axis_index_groups : optional, sequence Specifies how devices are grouped. Valid only within ``jax.pmap`` collectives. - References:: + References + ---------- .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag. @@ -463,9 +471,10 @@ class LayerNorm(Layer): scale and bias to a whole example/whole channel, please use GroupNorm/ InstanceNorm. - Parameters:: + Parameters + ---------- - normalized_shape: int, sequence of int + normalized_shape : int, sequence of int The input shape from an expected input of size .. math:: @@ -474,18 +483,19 @@ class LayerNorm(Layer): If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size. - epsilon: float + epsilon : float a value added to the denominator for numerical stability. Default: 1e-5 - bias_initializer: Initializer, ArrayType, Callable + bias_initializer : Initializer, ArrayType, Callable an initializer generating the original translation matrix - scale_initializer: Initializer, ArrayType, Callable + scale_initializer : Initializer, ArrayType, Callable an initializer generating the original scaling matrix - elementwise_affine: bool + elementwise_affine : bool A boolean value that when set to ``True``, this module has learnable per-element affine parameters initialized to ones (for weights) and zeros (for biases). Default: ``True``. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -571,24 +581,26 @@ class GroupNorm(Layer): The shape of the data should be (b, d1, d2, ..., c), where `d` denotes the batch size and `c` denotes the feature (channel) size. - Parameters:: + Parameters + ---------- - num_groups: int + num_groups : int The number of groups. It should be a factor of the number of channels. - num_channels: int + num_channels : int The number of channels expected in input. - epsilon: float + epsilon : float a value added to the denominator for numerical stability. Default: 1e-5 - affine: bool + affine : bool A boolean value that when set to ``True``, this module has learnable per-channel affine parameters initialized to ones (for weights) and zeros (for biases). Default: ``True``. - bias_initializer: Initializer, ArrayType, Callable + bias_initializer : Initializer, ArrayType, Callable An initializer generating the original translation matrix - scale_initializer: Initializer, ArrayType, Callable + scale_initializer : Initializer, ArrayType, Callable An initializer generating the original scaling matrix - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -666,19 +678,20 @@ class InstanceNorm(GroupNorm): This layer normalizes the data within each feature. It can be regarded as a group normalization layer in which `group_size` equals to 1. - Parameters:: + Parameters + ---------- - num_channels: int + num_channels : int The number of channels expected in input. - epsilon: float + epsilon : float a value added to the denominator for numerical stability. Default: 1e-5 - affine: bool + affine : bool A boolean value that when set to ``True``, this module has learnable per-channel affine parameters initialized to ones (for weights) and zeros (for biases). Default: ``True``. - bias_initializer: Initializer, ArrayType, Callable + bias_initializer : Initializer, ArrayType, Callable an initializer generating the original translation matrix - scale_initializer: Initializer, ArrayType, Callable + scale_initializer : Initializer, ArrayType, Callable an initializer generating the original scaling matrix """ diff --git a/brainpy/dnn/pooling.py b/brainpy/dnn/pooling.py index e06749cc3..ba09538e5 100644 --- a/brainpy/dnn/pooling.py +++ b/brainpy/dnn/pooling.py @@ -44,23 +44,23 @@ class Pool(Layer): """Pooling functions are implemented using the ReduceWindow XLA op. - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped, - used to infer ``kernel_size`` or ``stride`` if they are an integer. - mode: Mode - The computation mode. - name: optional, str - The object name. + Parameters + ---------- + kernel_size : int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride : int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding : str, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped, + used to infer ``kernel_size`` or ``stride`` if they are an integer. + mode : Mode + The computation mode. + name : optional, str + The object name. """ @@ -148,23 +148,23 @@ def _infer_shape(self, class MaxPool(Pool): """Pools the input by taking the maximum over a window. - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped, - used to infer ``kernel_size`` or ``stride`` if they are an integer. - mode: Mode - The computation mode. - name: optional, str - The object name. + Parameters + ---------- + kernel_size : int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride : int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding : str, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped, + used to infer ``kernel_size`` or ``stride`` if they are an integer. + mode : Mode + The computation mode. + name : optional, str + The object name. """ @@ -190,23 +190,23 @@ def __init__( class MinPool(Pool): """Pools the input by taking the minimum over a window. - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped, - used to infer ``kernel_size`` or ``stride`` if they are an integer. - mode: Mode - The computation mode. - name: optional, str - The object name. + Parameters + ---------- + kernel_size : int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride : int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding : str, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped, + used to infer ``kernel_size`` or ``stride`` if they are an integer. + mode : Mode + The computation mode. + name : optional, str + The object name. """ @@ -232,23 +232,23 @@ def __init__( class AvgPool(Pool): """Pools the input by taking the average over a window. - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped, - used to infer ``kernel_size`` or ``stride`` if they are an integer. - mode: Mode - The computation mode. - name: optional, str - The object name. + Parameters + ---------- + kernel_size : int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride : int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding : str, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped, + used to infer ``kernel_size`` or ``stride`` if they are an integer. + mode : Mode + The computation mode. + name : optional, str + The object name. """ def __init__( @@ -409,23 +409,23 @@ class MaxPool1d(_MaxPoolNd): """Applies a 1D max pooling over an input signal composed of several input planes. - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - mode: Mode - The computation mode. - name: optional, str - The object name. + Parameters + ---------- + kernel_size : int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride : int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding : str, int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + mode : Mode + The computation mode. + name : optional, str + The object name. """ @@ -453,23 +453,23 @@ class MaxPool2d(_MaxPoolNd): """Applies a 1D max pooling over an input signal composed of several input planes. - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - mode: Mode - The computation mode. - name: optional, str - The object name. + Parameters + ---------- + kernel_size : int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride : int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding : str, int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + mode : Mode + The computation mode. + name : optional, str + The object name. """ @@ -496,23 +496,23 @@ class MaxPool3d(_MaxPoolNd): """Applies a 1D max pooling over an input signal composed of several input planes. - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - mode: Mode - The computation mode. - name: optional, str - The object name. + Parameters + ---------- + kernel_size : int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride : int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding : str, int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + mode : Mode + The computation mode. + name : optional, str + The object name. """ @@ -573,23 +573,23 @@ class AvgPool1d(_AvgPoolNd): """Applies a 1D average pooling over an input signal composed of several input planes. - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - mode: Mode - The computation mode. - name: optional, str - The object name. + Parameters + ---------- + kernel_size : int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride : int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding : str, int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + mode : Mode + The computation mode. + name : optional, str + The object name. """ @@ -617,23 +617,23 @@ class AvgPool2d(_AvgPoolNd): """Applies a 2D average pooling over an input signal composed of several input planes. - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - mode: Mode - The computation mode. - name: optional, str - The object name. + Parameters + ---------- + kernel_size : int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride : int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding : str, int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + mode : Mode + The computation mode. + name : optional, str + The object name. """ def __init__( @@ -660,23 +660,23 @@ class AvgPool3d(_AvgPoolNd): """Applies a 3D average pooling over an input signal composed of several input planes. - Parameters:: - - kernel_size: int, sequence of int - An integer, or a sequence of integers defining the window to reduce over. - stride: int, sequence of int - An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). - padding: str, int, sequence of tuple - Either the string `'SAME'`, the string `'VALID'`, or a sequence - of n `(low, high)` integer pairs that give the padding to apply before - and after each spatial dimension. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - mode: Mode - The computation mode. - name: optional, str - The object name. + Parameters + ---------- + kernel_size : int, sequence of int + An integer, or a sequence of integers defining the window to reduce over. + stride : int, sequence of int + An integer, or a sequence of integers, representing the inter-window stride (default: `(1, ..., 1)`). + padding : str, int, sequence of tuple + Either the string `'SAME'`, the string `'VALID'`, or a sequence + of n `(low, high)` integer pairs that give the padding to apply before + and after each spatial dimension. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + mode : Mode + The computation mode. + name : optional, str + The object name. """ @@ -703,13 +703,18 @@ def __init__( def _adaptive_pool1d(x, target_size: int, operation: Callable): """Adaptive pool 1D. - Args: - x: The input. Should be a JAX array of shape `(dim,)`. - target_size: The shape of the output after the pooling operation `(target_size,)`. - operation: The pooling operation to be performed on the input array. - - Returns: - A JAX array of shape `(target_size, )`. + Parameters + ---------- + x + The input. Should be a JAX array of shape `(dim,)`. + target_size : int + The shape of the output after the pooling operation `(target_size,)`. + operation : Callable + The pooling operation to be performed on the input array. + + Returns + ------- + A JAX array of shape `(target_size, )`. """ x = bm.as_jax(x) size = jnp.size(x) @@ -735,21 +740,21 @@ def _generate_vmap(fun: Callable, map_axes: List[int]): class AdaptivePool(Layer): """General N dimensional adaptive down-sampling to a target shape. - Parameters:: - - target_shape: int, sequence of int - The target output shape. - num_spatial_dims: int - The number of spatial dimensions. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - operation: Callable - The down-sampling operation. - name: str - The class name. - mode: Mode - The computing mode. + Parameters + ---------- + target_shape : int, sequence of int + The target output shape. + num_spatial_dims : int + The number of spatial dimensions. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + operation : Callable + The down-sampling operation. + name : str + The class name. + mode : Mode + The computing mode. """ def __init__( @@ -776,11 +781,11 @@ def __init__( def update(self, x): """Input-output mapping. - Parameters:: - - x: Array - Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)` - or `(..., dim_1, dim_2)`. + Parameters + ---------- + x : Array + Inputs. Should be a JAX array of shape `(..., dim_1, dim_2, channels)` + or `(..., dim_1, dim_2)`. """ x = bm.as_jax(x) @@ -815,17 +820,17 @@ def update(self, x): class AdaptiveAvgPool1d(AdaptivePool): """Adaptive one-dimensional average down-sampling. - Parameters:: - - target_shape: int, sequence of int - The target output shape. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - name: str - The class name. - mode: Mode - The computing mode. + Parameters + ---------- + target_shape : int, sequence of int + The target output shape. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + name : str + The class name. + mode : Mode + The computing mode. """ def __init__(self, @@ -845,17 +850,17 @@ class AdaptiveAvgPool2d(AdaptivePool): """Adaptive two-dimensional average down-sampling. - Parameters:: - - target_shape: int, sequence of int - The target output shape. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - name: str - The class name. - mode: Mode - The computing mode. + Parameters + ---------- + target_shape : int, sequence of int + The target output shape. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + name : str + The class name. + mode : Mode + The computing mode. """ def __init__(self, @@ -875,17 +880,17 @@ class AdaptiveAvgPool3d(AdaptivePool): """Adaptive three-dimensional average down-sampling. - Parameters:: - - target_shape: int, sequence of int - The target output shape. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - name: str - The class name. - mode: Mode - The computing mode. + Parameters + ---------- + target_shape : int, sequence of int + The target output shape. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + name : str + The class name. + mode : Mode + The computing mode. """ def __init__(self, @@ -904,17 +909,17 @@ def __init__(self, class AdaptiveMaxPool1d(AdaptivePool): """Adaptive one-dimensional maximum down-sampling. - Parameters:: - - target_shape: int, sequence of int - The target output shape. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - name: str - The class name. - mode: Mode - The computing mode. + Parameters + ---------- + target_shape : int, sequence of int + The target output shape. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + name : str + The class name. + mode : Mode + The computing mode. """ def __init__(self, @@ -933,17 +938,17 @@ def __init__(self, class AdaptiveMaxPool2d(AdaptivePool): """Adaptive two-dimensional maximum down-sampling. - Parameters:: - - target_shape: int, sequence of int - The target output shape. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - name: str - The class name. - mode: Mode - The computing mode. + Parameters + ---------- + target_shape : int, sequence of int + The target output shape. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + name : str + The class name. + mode : Mode + The computing mode. """ def __init__(self, @@ -962,17 +967,17 @@ def __init__(self, class AdaptiveMaxPool3d(AdaptivePool): """Adaptive three-dimensional maximum down-sampling. - Parameters:: - - target_shape: int, sequence of int - The target output shape. - channel_axis: int, optional - Axis of the spatial channels for which pooling is skipped. - If ``None``, there is no channel axis. - name: str - The class name. - mode: Mode - The computing mode. + Parameters + ---------- + target_shape : int, sequence of int + The target output shape. + channel_axis : int, optional + Axis of the spatial channels for which pooling is skipped. + If ``None``, there is no channel axis. + name : str + The class name. + mode : Mode + The computing mode. """ def __init__(self, diff --git a/brainpy/dyn/_docs.py b/brainpy/dyn/_docs.py index 214f46d32..aeb82d941 100644 --- a/brainpy/dyn/_docs.py +++ b/brainpy/dyn/_docs.py @@ -12,43 +12,64 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -pneu_doc = ''' - size: int, or sequence of int. The neuronal population size. - sharding: The sharding strategy. - keep_size: bool. Keep the neuron group size. - mode: Mode. The computing mode. - name: str. The group name. +pneu_doc = ''' +size : int, or sequence of int + The neuronal population size. +sharding + The sharding strategy. +keep_size : bool + Keep the neuron group size. +mode : Mode + The computing mode. +name : str + The group name. '''.strip() dpneu_doc = ''' - spk_fun: callable. The spike activation function. - detach_spk: bool. - method: str. The numerical integration method. - spk_type: The spike data type. - spk_reset: The way to reset the membrane potential when the neuron generates spikes. - This parameter only works when the computing mode is ``TrainingMode``. - It can be ``soft`` and ``hard``. Default is ``soft``. +spk_fun : callable + The spike activation function. +detach_spk : bool +method : str + The numerical integration method. +spk_type + The spike data type. +spk_reset + The way to reset the membrane potential when the neuron generates spikes. + This parameter only works when the computing mode is ``TrainingMode``. + It can be ``soft`` and ``hard``. Default is ``soft``. '''.strip() ref_doc = ''' - tau_ref: float, ArrayType, callable. Refractory period length (ms). - has_ref_var: bool. Whether has the refractory variable. Default is ``False``. +tau_ref : float, ArrayType, callable + Refractory period length (ms). +has_ref_var : bool + Whether has the refractory variable. Default is ``False``. '''.strip() if_doc = ''' - V_rest: float, ArrayType, callable. Resting membrane potential. - R: float, ArrayType, callable. Membrane resistance. - tau: float, ArrayType, callable. Membrane time constant. - V_initializer: ArrayType, callable. The initializer of membrane potential. +V_rest : float, ArrayType, callable + Resting membrane potential. +R : float, ArrayType, callable + Membrane resistance. +tau : float, ArrayType, callable + Membrane time constant. +V_initializer : ArrayType, callable + The initializer of membrane potential. '''.strip() lif_doc = ''' - V_rest: float, ArrayType, callable. Resting membrane potential. - V_reset: float, ArrayType, callable. Reset potential after spike. - V_th: float, ArrayType, callable. Threshold potential of spike. - R: float, ArrayType, callable. Membrane resistance. - tau: float, ArrayType, callable. Membrane time constant. - V_initializer: ArrayType, callable. The initializer of membrane potential. +V_rest : float, ArrayType, callable + Resting membrane potential. +V_reset : float, ArrayType, callable + Reset potential after spike. +V_th : float, ArrayType, callable + Threshold potential of spike. +R : float, ArrayType, callable + Membrane resistance. +tau : float, ArrayType, callable + Membrane time constant. +V_initializer : ArrayType, callable + The initializer of membrane potential. '''.strip() ltc_doc = 'with liquid time-constant' @@ -94,9 +115,12 @@ ''' dual_exp_args = ''' - tau_decay: float, ArrayArray, Callable. The time constant of the synaptic decay phase. [ms] - tau_rise: float, ArrayArray, Callable. The time constant of the synaptic rise phase. [ms] - A: float. The normalization factor. Default None. +tau_decay : float, ArrayArray, Callable + The time constant of the synaptic decay phase. [ms] +tau_rise : float, ArrayArray, Callable + The time constant of the synaptic rise phase. [ms] +A : float + The normalization factor. Default None. ''' diff --git a/brainpy/dyn/channels/calcium.py b/brainpy/dyn/channels/calcium.py index b2bf3ca5e..4c95b1cb3 100644 --- a/brainpy/dyn/channels/calcium.py +++ b/brainpy/dyn/channels/calcium.py @@ -89,15 +89,16 @@ class _ICa_p2q_ss(CalciumChannel): where :math:`\phi_p` and :math:`\phi_q` are temperature-dependent factors, :math:`E_{Ca}` is the reversal potential of Calcium channel. - Parameters:: + Parameters + ---------- - size: int, tuple of int + size : int, tuple of int The size of the simulation target. - keep_size: bool + keep_size : bool Keep size or flatten the size? - method: str + method : str The numerical method - name: str + name : str The name of the object. g_max : float, ArrayType, Callable, Initializer The maximum conductance. @@ -182,15 +183,16 @@ class _ICa_p2q_markov(CalciumChannel): where :math:`\phi_p` and :math:`\phi_q` are temperature-dependent factors, :math:`E_{Ca}` is the reversal potential of Calcium channel. - Parameters:: + Parameters + ---------- - size: int, tuple of int + size : int, tuple of int The size of the simulation target. - keep_size: bool + keep_size : bool Keep size or flatten the size? - method: str + method : str The numerical method - name: str + name : str The name of the object. g_max : float, ArrayType, Callable, Initializer The maximum conductance. @@ -281,7 +283,8 @@ class ICaN_IS2008(CalciumChannel): where :math:`\phi` is the temperature factor. - Parameters:: + Parameters + ---------- g_max : float The maximal conductance density (:math:`mS/cm^2`). @@ -290,7 +293,8 @@ class ICaN_IS2008(CalciumChannel): phi : float The temperature factor. - References:: + References + ---------- .. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818. @@ -368,7 +372,8 @@ class ICaT_HM1992(_ICa_p2q_ss): are temperature-dependent factors (:math:`T` is the temperature in Celsius), :math:`E_{Ca}` is the reversal potential of Calcium channel. - Parameters:: + Parameters + ---------- T : float, ArrayType The temperature. @@ -385,12 +390,14 @@ class ICaT_HM1992(_ICa_p2q_ss): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [1] Huguenard JR, McCormick DA (1992) Simulation of the currents involved in rhythmic oscillations in thalamic relay neurons. J Neurophysiol 68:1373–1383. - See Also:: + See Also + -------- ICa_p2q_form """ @@ -464,7 +471,8 @@ class ICaT_HP1992(_ICa_p2q_ss): are temperature-dependent factors (:math:`T` is the temperature in Celsius), :math:`E_{Ca}` is the reversal potential of Calcium channel. - Parameters:: + Parameters + ---------- T : float, ArrayType The temperature. @@ -481,13 +489,15 @@ class ICaT_HP1992(_ICa_p2q_ss): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [1] Huguenard JR, Prince DA (1992) A novel T-type current underlies prolonged Ca2+- dependent burst firing in GABAergic neurons of rat thalamic reticular nucleus. J Neurosci 12: 3804–3817. - See Also:: + See Also + -------- ICa_p2q_form """ @@ -564,7 +574,8 @@ class ICaHT_HM1992(_ICa_p2q_ss): are temperature-dependent factors (:math:`T` is the temperature in Celsius), :math:`E_{Ca}` is the reversal potential of Calcium channel. - Parameters:: + Parameters + ---------- T : float, ArrayType The temperature. @@ -577,12 +588,14 @@ class ICaHT_HM1992(_ICa_p2q_ss): V_sh : float, ArrayType, Initializer, Callable The membrane potential shift. - References:: + References + ---------- .. [1] Huguenard JR, McCormick DA (1992) Simulation of the currents involved in rhythmic oscillations in thalamic relay neurons. J Neurophysiol 68:1373–1383. - See Also:: + See Also + -------- ICa_p2q_form """ @@ -656,15 +669,16 @@ class ICaHT_Re1993(_ICa_p2q_markov): \beta_{r} &=\frac{0.0065}{\exp [(-15-V+V_{sh}) / 28]+1}, \end{aligned} - Parameters:: + Parameters + ---------- - size: int, tuple of int + size : int, tuple of int The size of the simulation target. - keep_size: bool + keep_size : bool Keep size or flatten the size? - method: str + method : str The numerical method - name: str + name : str The name of the object. g_max : float, ArrayType, Callable, Initializer The maximum conductance. @@ -683,7 +697,8 @@ class ICaHT_Re1993(_ICa_p2q_markov): The temperature factor for channel :math:`q`. If `None`, :math:`\phi_q = \mathrm{T_base_q}^{\frac{T-23}{10}}`. - References:: + References + ---------- .. [1] Reuveni, I., et al. "Stepwise repolarization from Ca2+ plateaus in neocortical pyramidal cells: evidence for nonhomogeneous @@ -757,7 +772,8 @@ class ICaL_IS2008(_ICa_p2q_ss): are temperature-dependent factors (:math:`T` is the temperature in Celsius), :math:`E_{Ca}` is the reversal potential of Calcium channel. - Parameters:: + Parameters + ---------- T : float The temperature. @@ -770,13 +786,15 @@ class ICaL_IS2008(_ICa_p2q_ss): V_sh : float The membrane potential shift. - References:: + References + ---------- .. [1] Inoue, Tsuyoshi, and Ben W. Strowbridge. "Transient activity induces a long-lasting increase in the excitability of olfactory bulb interneurons." Journal of neurophysiology 99, no. 1 (2008): 187-199. - See Also:: + See Also + -------- ICa_p2q_form """ diff --git a/brainpy/dyn/channels/hyperpolarization_activated.py b/brainpy/dyn/channels/hyperpolarization_activated.py index 6cd852de3..878e59ded 100644 --- a/brainpy/dyn/channels/hyperpolarization_activated.py +++ b/brainpy/dyn/channels/hyperpolarization_activated.py @@ -51,7 +51,8 @@ class Ih_HM1992(IonChannel): where :math:`\phi=1` is a temperature-dependent factor. - Parameters:: + Parameters + ---------- g_max : float The maximal conductance density (:math:`mS/cm^2`). @@ -60,7 +61,8 @@ class Ih_HM1992(IonChannel): phi : float The temperature-dependent factor. - References:: + References + ---------- .. [1] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay neurons." Journal @@ -164,7 +166,8 @@ class Ih_De1996(IonChannel): and the temperature regulating factor :math:`\phi=2^{(T-24)/10}`. - References:: + References + ---------- .. [1] Destexhe, Alain, et al. "Ionic mechanisms underlying synchronized oscillations and propagating waves in a model of ferret thalamic diff --git a/brainpy/dyn/channels/leaky.py b/brainpy/dyn/channels/leaky.py index cca1549aa..4923d76c3 100644 --- a/brainpy/dyn/channels/leaky.py +++ b/brainpy/dyn/channels/leaky.py @@ -44,7 +44,8 @@ def reset_state(self, V, batch_size=None): class IL(LeakyChannel): """The leakage channel current. - Parameters:: + Parameters + ---------- g_max : float The leakage conductance. diff --git a/brainpy/dyn/channels/potassium.py b/brainpy/dyn/channels/potassium.py index 5f2f224e7..4c332eff2 100644 --- a/brainpy/dyn/channels/potassium.py +++ b/brainpy/dyn/channels/potassium.py @@ -92,11 +92,12 @@ class _IK_p4_markov_v2(PotassiumChannel): where :math:`\phi` is a temperature-dependent factor. - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The object size. - keep_size: bool + keep_size : bool Whether we use `size` to initialize the variable. Otherwise, variable shape will be initialized as `num`. g_max : float, ArrayType, Initializer, Callable @@ -105,9 +106,9 @@ class _IK_p4_markov_v2(PotassiumChannel): The reversal potential (mV). phi : float, ArrayType, Initializer, Callable The temperature-dependent factor. - method: str + method : str The numerical integration method. - name: str + name : str The object name. """ @@ -177,11 +178,12 @@ class IKDR_Ba2002v2(_IK_p4_markov_v2): where :math:`\phi` is a temperature-dependent factor, which is given by :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The object size. - keep_size: bool + keep_size : bool Whether we use `size` to initialize the variable. Otherwise, variable shape will be initialized as `num`. g_max : float, ArrayType, Initializer, Callable @@ -194,12 +196,13 @@ class IKDR_Ba2002v2(_IK_p4_markov_v2): The temperature (Celsius, :math:`^{\circ}C`). V_sh : float, ArrayType, Initializer, Callable The shift of the membrane potential to spike. - method: str + method : str The numerical integration method. - name: str + name : str The object name. - References:: + References + ---------- .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. @@ -259,25 +262,28 @@ class IK_TM1991v2(_IK_p4_markov_v2): where :math:`V_{sh}` is the membrane shift (default -63 mV), and :math:`\phi` is the temperature-dependent factor (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). E : float, ArrayType, Initializer, Callable The reversal potential (mV). - method: str + method : str The numerical integration method. - name: str + name : str The object name. - References:: + References + ---------- .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. Vol. 777. Cambridge University Press, 1991. - See Also:: + See Also + -------- INa_TM1991 """ @@ -328,26 +334,29 @@ class IK_HH1952v2(_IK_p4_markov_v2): where :math:`V_{sh}` is the membrane shift (default -45 mV), and :math:`\phi` is the temperature-dependent factor (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). E : float, ArrayType, Initializer, Callable The reversal potential (mV). - method: str + method : str The numerical integration method. - name: str + name : str The object name. - References:: + References + ---------- .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of membrane current and its application to conduction and excitation in nerve." The Journal of physiology 117.4 (1952): 500. - See Also:: + See Also + -------- INa_HH1952 """ @@ -396,13 +405,14 @@ class _IKA_p4q_ss_v2(PotassiumChannel): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -413,7 +423,8 @@ class _IKA_p4q_ss_v2(PotassiumChannel): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -502,13 +513,14 @@ class IKA1_HM1992v2(_IKA_p4q_ss_v2): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -521,7 +533,8 @@ class IKA1_HM1992v2(_IKA_p4q_ss_v2): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -530,7 +543,8 @@ class IKA1_HM1992v2(_IKA_p4q_ss_v2): TEA-sensitive K current in acutely isolated rat thalamic relay neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - See Also:: + See Also + -------- IKA2_HM1992 """ @@ -595,13 +609,14 @@ class IKA2_HM1992v2(_IKA_p4q_ss_v2): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -614,7 +629,8 @@ class IKA2_HM1992v2(_IKA_p4q_ss_v2): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -623,7 +639,8 @@ class IKA2_HM1992v2(_IKA_p4q_ss_v2): TEA-sensitive K current in acutely isolated rat thalamic relay neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - See Also:: + See Also + -------- IKA1_HM1992 """ @@ -683,13 +700,14 @@ class _IKK2_pq_ss_v2(PotassiumChannel): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -700,7 +718,8 @@ class _IKK2_pq_ss_v2(PotassiumChannel): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -789,13 +808,14 @@ class IKK2A_HM1992v2(_IKK2_pq_ss_v2): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -808,7 +828,8 @@ class IKK2A_HM1992v2(_IKK2_pq_ss_v2): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -878,13 +899,14 @@ class IKK2B_HM1992v2(_IKK2_pq_ss_v2): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -897,7 +919,8 @@ class IKK2B_HM1992v2(_IKK2_pq_ss_v2): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -966,13 +989,14 @@ class IKNI_Ya1989v2(PotassiumChannel): where :math:`\bar{g}_{M}` was :math:`0.004 \mathrm{mS} / \mathrm{cm}^{2}` and :math:`\tau_{\max }=4 \mathrm{~s}`, unless stated otherwise. - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -982,10 +1006,11 @@ class IKNI_Ya1989v2(PotassiumChannel): The membrane potential shift. phi_p : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`p`. - tau_max: float, ArrayType, Callable, Initializer + tau_max : float, ArrayType, Callable, Initializer The :math:`tau_{\max}` parameter. - References:: + References + ---------- .. [1] Yamada, Walter M. "Multiple channels and calcium dynamics." Methods in neuronal modeling (1989): 97-133. @@ -1059,11 +1084,12 @@ class _IK_p4_markov(PotassiumChannel): where :math:`\phi` is a temperature-dependent factor. - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The object size. - keep_size: bool + keep_size : bool Whether we use `size` to initialize the variable. Otherwise, variable shape will be initialized as `num`. g_max : float, ArrayType, Initializer, Callable @@ -1072,9 +1098,9 @@ class _IK_p4_markov(PotassiumChannel): The reversal potential (mV). phi : float, ArrayType, Initializer, Callable The temperature-dependent factor. - method: str + method : str The numerical integration method. - name: str + name : str The object name. """ @@ -1147,11 +1173,12 @@ class IKDR_Ba2002(_IK_p4_markov): where :math:`\phi` is a temperature-dependent factor, which is given by :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The object size. - keep_size: bool + keep_size : bool Whether we use `size` to initialize the variable. Otherwise, variable shape will be initialized as `num`. g_max : float, ArrayType, Initializer, Callable @@ -1164,12 +1191,13 @@ class IKDR_Ba2002(_IK_p4_markov): The temperature (Celsius, :math:`^{\circ}C`). V_sh : float, ArrayType, Initializer, Callable The shift of the membrane potential to spike. - method: str + method : str The numerical integration method. - name: str + name : str The object name. - References:: + References + ---------- .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. @@ -1231,25 +1259,28 @@ class IK_TM1991(_IK_p4_markov): where :math:`V_{sh}` is the membrane shift (default -63 mV), and :math:`\phi` is the temperature-dependent factor (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). E : float, ArrayType, Initializer, Callable The reversal potential (mV). - method: str + method : str The numerical integration method. - name: str + name : str The object name. - References:: + References + ---------- .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. Vol. 777. Cambridge University Press, 1991. - See Also:: + See Also + -------- INa_TM1991 """ @@ -1302,26 +1333,29 @@ class IK_HH1952(_IK_p4_markov): where :math:`V_{sh}` is the membrane shift (default -45 mV), and :math:`\phi` is the temperature-dependent factor (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). E : float, ArrayType, Initializer, Callable The reversal potential (mV). - method: str + method : str The numerical integration method. - name: str + name : str The object name. - References:: + References + ---------- .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of membrane current and its application to conduction and excitation in nerve." The Journal of physiology 117.4 (1952): 500. - See Also:: + See Also + -------- INa_HH1952 """ @@ -1372,13 +1406,14 @@ class _IKA_p4q_ss(PotassiumChannel): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -1389,7 +1424,8 @@ class _IKA_p4q_ss(PotassiumChannel): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -1481,13 +1517,14 @@ class IKA1_HM1992(_IKA_p4q_ss): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -1500,7 +1537,8 @@ class IKA1_HM1992(_IKA_p4q_ss): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -1509,7 +1547,8 @@ class IKA1_HM1992(_IKA_p4q_ss): TEA-sensitive K current in acutely isolated rat thalamic relay neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - See Also:: + See Also + -------- IKA2_HM1992 """ @@ -1576,13 +1615,14 @@ class IKA2_HM1992(_IKA_p4q_ss): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -1595,7 +1635,8 @@ class IKA2_HM1992(_IKA_p4q_ss): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -1604,7 +1645,8 @@ class IKA2_HM1992(_IKA_p4q_ss): TEA-sensitive K current in acutely isolated rat thalamic relay neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - See Also:: + See Also + -------- IKA1_HM1992 """ @@ -1666,13 +1708,14 @@ class _IKK2_pq_ss(PotassiumChannel): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -1683,7 +1726,8 @@ class _IKK2_pq_ss(PotassiumChannel): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -1775,13 +1819,14 @@ class IKK2A_HM1992(_IKK2_pq_ss): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -1794,7 +1839,8 @@ class IKK2A_HM1992(_IKK2_pq_ss): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -1866,13 +1912,14 @@ class IKK2B_HM1992(_IKK2_pq_ss): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -1885,7 +1932,8 @@ class IKK2B_HM1992(_IKK2_pq_ss): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -1956,13 +2004,14 @@ class IKNI_Ya1989(PotassiumChannel): where :math:`\bar{g}_{M}` was :math:`0.004 \mathrm{mS} / \mathrm{cm}^{2}` and :math:`\tau_{\max }=4 \mathrm{~s}`, unless stated otherwise. - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -1972,10 +2021,11 @@ class IKNI_Ya1989(PotassiumChannel): The membrane potential shift. phi_p : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`p`. - tau_max: float, ArrayType, Callable, Initializer + tau_max : float, ArrayType, Callable, Initializer The :math:`tau_{\max}` parameter. - References:: + References + ---------- .. [1] Yamada, Walter M. "Multiple channels and calcium dynamics." Methods in neuronal modeling (1989): 97-133. @@ -2040,7 +2090,8 @@ def f_p_tau(self, V): class IK_Leak(PotassiumChannel): """The potassium leak channel current. - Parameters:: + Parameters + ---------- g_max : float The potassium leakage conductance which is modulated by both diff --git a/brainpy/dyn/channels/potassium_calcium.py b/brainpy/dyn/channels/potassium_calcium.py index 858c73757..b510833d7 100644 --- a/brainpy/dyn/channels/potassium_calcium.py +++ b/brainpy/dyn/channels/potassium_calcium.py @@ -72,12 +72,14 @@ class IAHP_De1994v2(KCaChannel): :math:`\beta=0.03 \mathrm{~ms}^{-1}` yielded AHPs very similar to those RE cells recorded in vivo and in vitro. - Parameters:: + Parameters + ---------- g_max : float The maximal conductance density (:math:`mS/cm^2`). - References:: + References + ---------- .. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818. diff --git a/brainpy/dyn/channels/potassium_calcium_compatible.py b/brainpy/dyn/channels/potassium_calcium_compatible.py index 503f2c705..e41c66222 100644 --- a/brainpy/dyn/channels/potassium_calcium_compatible.py +++ b/brainpy/dyn/channels/potassium_calcium_compatible.py @@ -66,14 +66,16 @@ class IAHP_De1994(IonChannel): :math:`\beta=0.03 \mathrm{~ms}^{-1}` yielded AHPs very similar to those RE cells recorded in vivo and in vitro. - Parameters:: + Parameters + ---------- g_max : float The maximal conductance density (:math:`mS/cm^2`). E : float The reversal potential (mV). - References:: + References + ---------- .. [1] Destexhe, Alain, et al. "A model of spindle rhythmicity in the isolated thalamic reticular nucleus." Journal of neurophysiology 72.2 (1994): 803-818. diff --git a/brainpy/dyn/channels/potassium_compatible.py b/brainpy/dyn/channels/potassium_compatible.py index 1f707bd65..21c75a76c 100644 --- a/brainpy/dyn/channels/potassium_compatible.py +++ b/brainpy/dyn/channels/potassium_compatible.py @@ -69,11 +69,12 @@ class _IK_p4_markov(IonChannel): where :math:`\phi` is a temperature-dependent factor. - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The object size. - keep_size: bool + keep_size : bool Whether we use `size` to initialize the variable. Otherwise, variable shape will be initialized as `num`. g_max : float, ArrayType, Initializer, Callable @@ -82,9 +83,9 @@ class _IK_p4_markov(IonChannel): The reversal potential (mV). phi : float, ArrayType, Initializer, Callable The temperature-dependent factor. - method: str + method : str The numerical integration method. - name: str + name : str The object name. """ @@ -157,11 +158,12 @@ class IKDR_Ba2002(_IK_p4_markov): where :math:`\phi` is a temperature-dependent factor, which is given by :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The object size. - keep_size: bool + keep_size : bool Whether we use `size` to initialize the variable. Otherwise, variable shape will be initialized as `num`. g_max : float, ArrayType, Initializer, Callable @@ -174,12 +176,13 @@ class IKDR_Ba2002(_IK_p4_markov): The temperature (Celsius, :math:`^{\circ}C`). V_sh : float, ArrayType, Initializer, Callable The shift of the membrane potential to spike. - method: str + method : str The numerical integration method. - name: str + name : str The object name. - References:: + References + ---------- .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. @@ -240,25 +243,28 @@ class IK_TM1991(_IK_p4_markov): where :math:`V_{sh}` is the membrane shift (default -63 mV), and :math:`\phi` is the temperature-dependent factor (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). E : float, ArrayType, Initializer, Callable The reversal potential (mV). - method: str + method : str The numerical integration method. - name: str + name : str The object name. - References:: + References + ---------- .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. Vol. 777. Cambridge University Press, 1991. - See Also:: + See Also + -------- INa_TM1991 """ @@ -310,26 +316,29 @@ class IK_HH1952(_IK_p4_markov): where :math:`V_{sh}` is the membrane shift (default -45 mV), and :math:`\phi` is the temperature-dependent factor (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). E : float, ArrayType, Initializer, Callable The reversal potential (mV). - method: str + method : str The numerical integration method. - name: str + name : str The object name. - References:: + References + ---------- .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of membrane current and its application to conduction and excitation in nerve." The Journal of physiology 117.4 (1952): 500. - See Also:: + See Also + -------- INa_HH1952 """ @@ -379,13 +388,14 @@ class _IKA_p4q_ss(IonChannel): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -396,7 +406,8 @@ class _IKA_p4q_ss(IonChannel): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -488,13 +499,14 @@ class IKA1_HM1992(_IKA_p4q_ss): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -507,7 +519,8 @@ class IKA1_HM1992(_IKA_p4q_ss): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -516,7 +529,8 @@ class IKA1_HM1992(_IKA_p4q_ss): TEA-sensitive K current in acutely isolated rat thalamic relay neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - See Also:: + See Also + -------- IKA2_HM1992 """ @@ -583,13 +597,14 @@ class IKA2_HM1992(_IKA_p4q_ss): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -602,7 +617,8 @@ class IKA2_HM1992(_IKA_p4q_ss): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -611,7 +627,8 @@ class IKA2_HM1992(_IKA_p4q_ss): TEA-sensitive K current in acutely isolated rat thalamic relay neurons." Journal of neurophysiology 66.4 (1991): 1316-1328. - See Also:: + See Also + -------- IKA1_HM1992 """ @@ -673,13 +690,14 @@ class _IKK2_pq_ss(IonChannel): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -690,7 +708,8 @@ class _IKK2_pq_ss(IonChannel): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -782,13 +801,14 @@ class IKK2A_HM1992(_IKK2_pq_ss): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -801,7 +821,8 @@ class IKK2A_HM1992(_IKK2_pq_ss): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -873,13 +894,14 @@ class IKK2B_HM1992(_IKK2_pq_ss): where :math:`\phi_p` and :math:`\phi_q` are the temperature dependent factors (default 1.). - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -892,7 +914,8 @@ class IKK2B_HM1992(_IKK2_pq_ss): phi_q : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`q`. - References:: + References + ---------- .. [2] Huguenard, John R., and David A. McCormick. "Simulation of the currents involved in rhythmic oscillations in thalamic relay @@ -963,13 +986,14 @@ class IKNI_Ya1989(IonChannel): where :math:`\bar{g}_{M}` was :math:`0.004 \mathrm{mS} / \mathrm{cm}^{2}` and :math:`\tau_{\max }=4 \mathrm{~s}`, unless stated otherwise. - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The geometry size. - method: str + method : str The numerical integration method. - name: str + name : str The object name. g_max : float, ArrayType, Initializer, Callable The maximal conductance density (:math:`mS/cm^2`). @@ -979,10 +1003,11 @@ class IKNI_Ya1989(IonChannel): The membrane potential shift. phi_p : optional, float, ArrayType, Callable, Initializer The temperature factor for channel :math:`p`. - tau_max: float, ArrayType, Callable, Initializer + tau_max : float, ArrayType, Callable, Initializer The :math:`tau_{\max}` parameter. - References:: + References + ---------- .. [1] Yamada, Walter M. "Multiple channels and calcium dynamics." Methods in neuronal modeling (1989): 97-133. @@ -1047,7 +1072,8 @@ def f_p_tau(self, V): class IKL(IonChannel): """The potassium leak channel current. - Parameters:: + Parameters + ---------- g_max : float The potassium leakage conductance which is modulated by both diff --git a/brainpy/dyn/channels/sodium.py b/brainpy/dyn/channels/sodium.py index b1f9700e7..a54fe3e15 100644 --- a/brainpy/dyn/channels/sodium.py +++ b/brainpy/dyn/channels/sodium.py @@ -85,7 +85,8 @@ class _INa_p3q_markov_v2(SodiumChannel): where :math:`\phi` is a temperature-dependent factor. - Parameters:: + Parameters + ---------- g_max : float, ArrayType, Callable, Initializer The maximal conductance density (:math:`mS/cm^2`). @@ -93,9 +94,9 @@ class _INa_p3q_markov_v2(SodiumChannel): The reversal potential (mV). phi : float, ArrayType, Callable, Initializer The temperature-dependent factor. - method: str + method : str The numerical method - name: str + name : str The name of the object. """ @@ -184,7 +185,8 @@ class INa_Ba2002v2(_INa_p3q_markov_v2): where :math:`\phi` is a temperature-dependent factor, which is given by :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). - Parameters:: + Parameters + ---------- g_max : float, ArrayType, Callable, Initializer The maximal conductance density (:math:`mS/cm^2`). @@ -195,12 +197,14 @@ class INa_Ba2002v2(_INa_p3q_markov_v2): V_sh : float, ArrayType, Callable, Initializer The shift of the membrane potential to spike. - References:: + References + ---------- .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. - See Also:: + See Also + -------- INa_TM1991 """ @@ -265,29 +269,32 @@ class INa_TM1991v2(_INa_p3q_markov_v2): where :math:`V_{sh}` is the membrane shift (default -63 mV), and :math:`\phi` is the temperature-dependent factor (default 1.). - Parameters:: + Parameters + ---------- - size: int, tuple of int + size : int, tuple of int The size of the simulation target. - keep_size: bool + keep_size : bool Keep size or flatten the size? - method: str + method : str The numerical method - name: str + name : str The name of the object. g_max : float, ArrayType, Callable, Initializer The maximal conductance density (:math:`mS/cm^2`). E : float, ArrayType, Callable, Initializer The reversal potential (mV). - V_sh: float, ArrayType, Callable, Initializer + V_sh : float, ArrayType, Callable, Initializer The membrane shift. - References:: + References + ---------- .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. Vol. 777. Cambridge University Press, 1991. - See Also:: + See Also + -------- INa_Ba2002 """ @@ -351,30 +358,33 @@ class INa_HH1952v2(_INa_p3q_markov_v2): where :math:`V_{sh}` is the membrane shift (default -45 mV), and :math:`\phi` is the temperature-dependent factor (default 1.). - Parameters:: + Parameters + ---------- - size: int, tuple of int + size : int, tuple of int The size of the simulation target. - keep_size: bool + keep_size : bool Keep size or flatten the size? - method: str + method : str The numerical method - name: str + name : str The name of the object. g_max : float, ArrayType, Callable, Initializer The maximal conductance density (:math:`mS/cm^2`). E : float, ArrayType, Callable, Initializer The reversal potential (mV). - V_sh: float, ArrayType, Callable, Initializer + V_sh : float, ArrayType, Callable, Initializer The membrane shift. - References:: + References + ---------- .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of membrane current and its application to conduction and excitation in nerve." The Journal of physiology 117.4 (1952): 500. - See Also:: + See Also + -------- IK_HH1952 """ diff --git a/brainpy/dyn/channels/sodium_compatible.py b/brainpy/dyn/channels/sodium_compatible.py index f18000213..091cade7c 100644 --- a/brainpy/dyn/channels/sodium_compatible.py +++ b/brainpy/dyn/channels/sodium_compatible.py @@ -63,7 +63,8 @@ class _INa_p3q_markov(IonChannel): where :math:`\phi` is a temperature-dependent factor. - Parameters:: + Parameters + ---------- g_max : float, ArrayType, Callable, Initializer The maximal conductance density (:math:`mS/cm^2`). @@ -71,9 +72,9 @@ class _INa_p3q_markov(IonChannel): The reversal potential (mV). phi : float, ArrayType, Callable, Initializer The temperature-dependent factor. - method: str + method : str The numerical method - name: str + name : str The name of the object. """ @@ -165,7 +166,8 @@ class INa_Ba2002(_INa_p3q_markov): where :math:`\phi` is a temperature-dependent factor, which is given by :math:`\phi=3^{\frac{T-36}{10}}` (:math:`T` is the temperature in Celsius). - Parameters:: + Parameters + ---------- g_max : float, ArrayType, Callable, Initializer The maximal conductance density (:math:`mS/cm^2`). @@ -176,12 +178,14 @@ class INa_Ba2002(_INa_p3q_markov): V_sh : float, ArrayType, Callable, Initializer The shift of the membrane potential to spike. - References:: + References + ---------- .. [1] Bazhenov, Maxim, et al. "Model of thalamocortical slow-wave sleep oscillations and transitions to activated states." Journal of neuroscience 22.19 (2002): 8691-8704. - See Also:: + See Also + -------- INa_TM1991 """ @@ -246,29 +250,32 @@ class INa_TM1991(_INa_p3q_markov): where :math:`V_{sh}` is the membrane shift (default -63 mV), and :math:`\phi` is the temperature-dependent factor (default 1.). - Parameters:: + Parameters + ---------- - size: int, tuple of int + size : int, tuple of int The size of the simulation target. - keep_size: bool + keep_size : bool Keep size or flatten the size? - method: str + method : str The numerical method - name: str + name : str The name of the object. g_max : float, ArrayType, Callable, Initializer The maximal conductance density (:math:`mS/cm^2`). E : float, ArrayType, Callable, Initializer The reversal potential (mV). - V_sh: float, ArrayType, Callable, Initializer + V_sh : float, ArrayType, Callable, Initializer The membrane shift. - References:: + References + ---------- .. [1] Traub, Roger D., and Richard Miles. Neuronal networks of the hippocampus. Vol. 777. Cambridge University Press, 1991. - See Also:: + See Also + -------- INa_Ba2002 """ @@ -332,30 +339,33 @@ class INa_HH1952(_INa_p3q_markov): where :math:`V_{sh}` is the membrane shift (default -45 mV), and :math:`\phi` is the temperature-dependent factor (default 1.). - Parameters:: + Parameters + ---------- - size: int, tuple of int + size : int, tuple of int The size of the simulation target. - keep_size: bool + keep_size : bool Keep size or flatten the size? - method: str + method : str The numerical method - name: str + name : str The name of the object. g_max : float, ArrayType, Callable, Initializer The maximal conductance density (:math:`mS/cm^2`). E : float, ArrayType, Callable, Initializer The reversal potential (mV). - V_sh: float, ArrayType, Callable, Initializer + V_sh : float, ArrayType, Callable, Initializer The membrane shift. - References:: + References + ---------- .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of membrane current and its application to conduction and excitation in nerve." The Journal of physiology 117.4 (1952): 500. - See Also:: + See Also + -------- IK_HH1952 """ diff --git a/brainpy/dyn/ions/base.py b/brainpy/dyn/ions/base.py index 64526003c..2a8b6a90a 100644 --- a/brainpy/dyn/ions/base.py +++ b/brainpy/dyn/ions/base.py @@ -33,9 +33,12 @@ class MixIons(IonChaDyn, Container, TreeNode): """Mixing Ions. - Args: - ions: Instances of ions. This option defines the master types of all children objects. - channels: Instance of channels. + Parameters + ---------- + ions + Instances of ions. This option defines the master types of all children objects. + channels + Instance of channels. """ master_type = HHTypedNeuron @@ -64,11 +67,14 @@ def update(self, V): def current(self, V): """Generate ion channel current. - Args: - V: The membrane potential. + Parameters + ---------- + V + The membrane potential. - Returns: - Current. + Returns + ------- + Current. """ nodes = tuple(self.nodes(level=1, include_self=False).unique().subset(IonChaDyn).values()) self.check_hierarchies(self._ion_classes, *nodes) @@ -100,8 +106,10 @@ def check_hierarchy(self, roots, leaf): def add_elem(self, *elems, **elements): """Add new elements. - Args: - elements: children objects. + Parameters + ---------- + elements + children objects. """ self.check_hierarchies(self._ion_classes, *elems, **elements) self.children.update(self.format_elements(IonChaDyn, *elems, **elements)) @@ -136,11 +144,14 @@ def _check_master_type(self, leaf): def mix_ions(*ions) -> MixIons: """Create mixed ions. - Args: - ions: Ion instances. + Parameters + ---------- + ions + Ion instances. - Returns: - Instance of MixIons. + Returns + ------- + Instance of MixIons. """ for ion in ions: assert isinstance(ion, Ion), f'Must be instance of {Ion.__name__}. But got {type(ion)}' @@ -151,11 +162,16 @@ def mix_ions(*ions) -> MixIons: class Ion(IonChaDyn, Container, TreeNode): """The brainpy_object calcium dynamics. - Args: - size: The size of the simulation target. - method: The numerical integration method. - name: The name of the object. - channels: The calcium dependent channels. + Parameters + ---------- + size : Shape + The size of the simulation target. + method : str + The numerical integration method. + name : Optional[str] + The name of the object. + channels + The calcium dependent channels. """ '''The type of the master object.''' @@ -189,14 +205,20 @@ def update(self, V): def current(self, V, C=None, E=None, external: bool = False): """Generate ion channel current. - Args: - V: The membrane potential. - C: The given ion concentration. - E: The given reversal potential. - external: Include the external current. - - Returns: - Current. + Parameters + ---------- + V + The membrane potential. + C + The given ion concentration. + E + The given reversal potential. + external : bool + Include the external current. + + Returns + ------- + Current. """ C = self.C if (C is None) else C E = self.E if (E is None) else E diff --git a/brainpy/dyn/ions/calcium.py b/brainpy/dyn/ions/calcium.py index 6c6995897..51457ac6f 100644 --- a/brainpy/dyn/ions/calcium.py +++ b/brainpy/dyn/ions/calcium.py @@ -73,21 +73,22 @@ def reset_state(self, V, C_Ca=None, E_Ca=None, batch_size=None): class CalciumDyna(Calcium): """Calcium ion flow with dynamics. - Parameters:: + Parameters + ---------- - size: int, tuple of int + size : int, tuple of int The ion size. - keep_size: bool + keep_size : bool Keep the geometry size. - C0: float, ArrayType, Initializer, Callable + C0 : float, ArrayType, Initializer, Callable The Calcium concentration outside of membrane. - T: float, ArrayType, Initializer, Callable + T : float, ArrayType, Initializer, Callable The temperature. - C_initializer: Initializer, Callable, ArrayType + C_initializer : Initializer, Callable, ArrayType The initializer for Calcium concentration. - method: str + method : str The numerical method. - name: str + name : str The ion name. """ R = 8.31441 # gas constant, J*mol-1*K-1 @@ -228,7 +229,8 @@ class CalciumDetailed(CalciumDyna): :math:`F=96,489 \mathrm{C} / \mathrm{mol}`, and :math:`\left[\mathrm{Ca}^{2+}\right]_{0}=2 \mathrm{mM}`. - Parameters:: + Parameters + ---------- d : float The thickness of the peri-membrane "shell". @@ -243,7 +245,8 @@ class CalciumDetailed(CalciumDyna): R : float The gas constant. (:math:` J*mol^{-1}*K^{-1}`) - References:: + References + ---------- .. [1] Destexhe, Alain, Agnessa Babloyantz, and Terrence J. Sejnowski. "Ionic mechanisms for intrinsic slow oscillations in thalamic diff --git a/brainpy/dyn/neurons/base.py b/brainpy/dyn/neurons/base.py index fbb25bd6e..5162a4b80 100644 --- a/brainpy/dyn/neurons/base.py +++ b/brainpy/dyn/neurons/base.py @@ -25,9 +25,10 @@ class GradNeuDyn(NeuDyn): """Differentiable and Parallelizable Neuron Group. - Args: - {pneu} - {dpneu} + Parameters + ---------- + {pneu} + {dpneu} """ supported_modes = (bm.TrainingMode, bm.NonBatchingMode) diff --git a/brainpy/dyn/neurons/hh.py b/brainpy/dyn/neurons/hh.py index d369cd3fc..34b0b6f98 100644 --- a/brainpy/dyn/neurons/hh.py +++ b/brainpy/dyn/neurons/hh.py @@ -77,11 +77,12 @@ class CondNeuGroupLTC(HHTypedNeuron, Container, TreeNode): .. versionadded:: 2.1.9 Modeling the conductance-based neuron model. - Parameters:: + Parameters + ---------- size : int, sequence of int The network size of this neuron group. - method: str + method : str The numerical integration method. name : optional, str The neuron group name. @@ -249,7 +250,8 @@ class HHLTC(NeuDyn): such as limit cycles, can be proven to exist. - References:: + References + ---------- .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of membrane current and its application to conduction and excitation @@ -279,37 +281,38 @@ class HHLTC(NeuDyn): - Parameters:: + Parameters + ---------- - size: sequence of int, int + size : sequence of int, int The size of the neuron group. - ENa: float, ArrayType, Initializer, callable + ENa : float, ArrayType, Initializer, callable The reversal potential of sodium. Default is 50 mV. - gNa: float, ArrayType, Initializer, callable + gNa : float, ArrayType, Initializer, callable The maximum conductance of sodium channel. Default is 120 msiemens. - EK: float, ArrayType, Initializer, callable + EK : float, ArrayType, Initializer, callable The reversal potential of potassium. Default is -77 mV. - gK: float, ArrayType, Initializer, callable + gK : float, ArrayType, Initializer, callable The maximum conductance of potassium channel. Default is 36 msiemens. - EL: float, ArrayType, Initializer, callable + EL : float, ArrayType, Initializer, callable The reversal potential of learky channel. Default is -54.387 mV. - gL: float, ArrayType, Initializer, callable + gL : float, ArrayType, Initializer, callable The conductance of learky channel. Default is 0.03 msiemens. - V_th: float, ArrayType, Initializer, callable + V_th : float, ArrayType, Initializer, callable The threshold of the membrane spike. Default is 20 mV. - C: float, ArrayType, Initializer, callable + C : float, ArrayType, Initializer, callable The membrane capacitance. Default is 1 ufarad. - V_initializer: ArrayType, Initializer, callable + V_initializer : ArrayType, Initializer, callable The initializer of membrane potential. - m_initializer: ArrayType, Initializer, callable + m_initializer : ArrayType, Initializer, callable The initializer of m channel. - h_initializer: ArrayType, Initializer, callable + h_initializer : ArrayType, Initializer, callable The initializer of h channel. - n_initializer: ArrayType, Initializer, callable + n_initializer : ArrayType, Initializer, callable The initializer of n channel. - method: str + method : str The numerical integration method. - name: str + name : str The group name. @@ -481,7 +484,8 @@ class HH(HHLTC): &\beta_n(V) = 0.125 \exp(\frac{-(V + 65)} {80}) - References:: + References + ---------- .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of membrane current and its application to conduction and excitation @@ -517,37 +521,38 @@ class HH(HHLTC): The illustrated example of HH neuron model please see `this notebook <../neurons/HH_model.ipynb>`_. - Parameters:: + Parameters + ---------- - size: sequence of int, int + size : sequence of int, int The size of the neuron group. - ENa: float, ArrayType, Initializer, callable + ENa : float, ArrayType, Initializer, callable The reversal potential of sodium. Default is 50 mV. - gNa: float, ArrayType, Initializer, callable + gNa : float, ArrayType, Initializer, callable The maximum conductance of sodium channel. Default is 120 msiemens. - EK: float, ArrayType, Initializer, callable + EK : float, ArrayType, Initializer, callable The reversal potential of potassium. Default is -77 mV. - gK: float, ArrayType, Initializer, callable + gK : float, ArrayType, Initializer, callable The maximum conductance of potassium channel. Default is 36 msiemens. - EL: float, ArrayType, Initializer, callable + EL : float, ArrayType, Initializer, callable The reversal potential of learky channel. Default is -54.387 mV. - gL: float, ArrayType, Initializer, callable + gL : float, ArrayType, Initializer, callable The conductance of learky channel. Default is 0.03 msiemens. - V_th: float, ArrayType, Initializer, callable + V_th : float, ArrayType, Initializer, callable The threshold of the membrane spike. Default is 20 mV. - C: float, ArrayType, Initializer, callable + C : float, ArrayType, Initializer, callable The membrane capacitance. Default is 1 ufarad. - V_initializer: ArrayType, Initializer, callable + V_initializer : ArrayType, Initializer, callable The initializer of membrane potential. - m_initializer: ArrayType, Initializer, callable + m_initializer : ArrayType, Initializer, callable The initializer of m channel. - h_initializer: ArrayType, Initializer, callable + h_initializer : ArrayType, Initializer, callable The initializer of h channel. - n_initializer: ArrayType, Initializer, callable + n_initializer : ArrayType, Initializer, callable The initializer of n channel. - method: str + method : str The numerical integration method. - name: str + name : str The group name. """ @@ -619,7 +624,8 @@ class MorrisLecarLTC(NeuDyn): V_th 10 mV The spike threshold. ============= ============== ======== ======================================================= - References:: + References + ---------- .. [4] Lecar, Harold. "Morris-lecar model." Scholarpedia 2.10 (2007): 1333. .. [5] http://www.scholarpedia.org/article/Morris-Lecar_model @@ -785,7 +791,8 @@ class MorrisLecar(MorrisLecarLTC): V_th 10 mV The spike threshold. ============= ============== ======== ======================================================= - References:: + References + ---------- .. [4] Lecar, Harold. "Morris-lecar model." Scholarpedia 2.10 (2007): 1333. .. [5] http://www.scholarpedia.org/article/Morris-Lecar_model @@ -850,7 +857,8 @@ class WangBuzsakiHHLTC(NeuDyn): :math:`E_{\mathrm{K}}=-90 \mathrm{mV}`. - References:: + References + ---------- .. [9] Wang, X.J. and Buzsaki, G., (1996) Gamma oscillation by synaptic inhibition in a hippocampal interneuronal network model. Journal of @@ -875,37 +883,38 @@ class WangBuzsakiHHLTC(NeuDyn): plt.tight_layout() plt.show() - Parameters:: + Parameters + ---------- - size: sequence of int, int + size : sequence of int, int The size of the neuron group. - ENa: float, ArrayType, Initializer, callable + ENa : float, ArrayType, Initializer, callable The reversal potential of sodium. Default is 50 mV. - gNa: float, ArrayType, Initializer, callable + gNa : float, ArrayType, Initializer, callable The maximum conductance of sodium channel. Default is 120 msiemens. - EK: float, ArrayType, Initializer, callable + EK : float, ArrayType, Initializer, callable The reversal potential of potassium. Default is -77 mV. - gK: float, ArrayType, Initializer, callable + gK : float, ArrayType, Initializer, callable The maximum conductance of potassium channel. Default is 36 msiemens. - EL: float, ArrayType, Initializer, callable + EL : float, ArrayType, Initializer, callable The reversal potential of learky channel. Default is -54.387 mV. - gL: float, ArrayType, Initializer, callable + gL : float, ArrayType, Initializer, callable The conductance of learky channel. Default is 0.03 msiemens. - V_th: float, ArrayType, Initializer, callable + V_th : float, ArrayType, Initializer, callable The threshold of the membrane spike. Default is 20 mV. - C: float, ArrayType, Initializer, callable + C : float, ArrayType, Initializer, callable The membrane capacitance. Default is 1 ufarad. - phi: float, ArrayType, Initializer, callable + phi : float, ArrayType, Initializer, callable The temperature regulator constant. - V_initializer: ArrayType, Initializer, callable + V_initializer : ArrayType, Initializer, callable The initializer of membrane potential. - h_initializer: ArrayType, Initializer, callable + h_initializer : ArrayType, Initializer, callable The initializer of h channel. - n_initializer: ArrayType, Initializer, callable + n_initializer : ArrayType, Initializer, callable The initializer of n channel. - method: str + method : str The numerical integration method. - name: str + name : str The group name. @@ -1072,7 +1081,8 @@ class WangBuzsakiHH(WangBuzsakiHHLTC): :math:`E_{\mathrm{K}}=-90 \mathrm{mV}`. - References:: + References + ---------- .. [9] Wang, X.J. and Buzsaki, G., (1996) Gamma oscillation by synaptic inhibition in a hippocampal interneuronal network model. Journal of @@ -1097,37 +1107,38 @@ class WangBuzsakiHH(WangBuzsakiHHLTC): plt.tight_layout() plt.show() - Parameters:: + Parameters + ---------- - size: sequence of int, int + size : sequence of int, int The size of the neuron group. - ENa: float, ArrayType, Initializer, callable + ENa : float, ArrayType, Initializer, callable The reversal potential of sodium. Default is 50 mV. - gNa: float, ArrayType, Initializer, callable + gNa : float, ArrayType, Initializer, callable The maximum conductance of sodium channel. Default is 120 msiemens. - EK: float, ArrayType, Initializer, callable + EK : float, ArrayType, Initializer, callable The reversal potential of potassium. Default is -77 mV. - gK: float, ArrayType, Initializer, callable + gK : float, ArrayType, Initializer, callable The maximum conductance of potassium channel. Default is 36 msiemens. - EL: float, ArrayType, Initializer, callable + EL : float, ArrayType, Initializer, callable The reversal potential of learky channel. Default is -54.387 mV. - gL: float, ArrayType, Initializer, callable + gL : float, ArrayType, Initializer, callable The conductance of learky channel. Default is 0.03 msiemens. - V_th: float, ArrayType, Initializer, callable + V_th : float, ArrayType, Initializer, callable The threshold of the membrane spike. Default is 20 mV. - C: float, ArrayType, Initializer, callable + C : float, ArrayType, Initializer, callable The membrane capacitance. Default is 1 ufarad. - phi: float, ArrayType, Initializer, callable + phi : float, ArrayType, Initializer, callable The temperature regulator constant. - V_initializer: ArrayType, Initializer, callable + V_initializer : ArrayType, Initializer, callable The initializer of membrane potential. - h_initializer: ArrayType, Initializer, callable + h_initializer : ArrayType, Initializer, callable The initializer of h channel. - n_initializer: ArrayType, Initializer, callable + n_initializer : ArrayType, Initializer, callable The initializer of n channel. - method: str + method : str The numerical integration method. - name: str + name : str The group name. """ diff --git a/brainpy/dyn/neurons/lif.py b/brainpy/dyn/neurons/lif.py index d4fe42936..053a652cb 100644 --- a/brainpy/dyn/neurons/lif.py +++ b/brainpy/dyn/neurons/lif.py @@ -77,10 +77,11 @@ class IFLTC(GradNeuDyn): resistance. - Args: - %s - %s - %s + Parameters + ---------- + %s + %s + %s """ def __init__( @@ -205,10 +206,11 @@ class LifLTC(GradNeuDyn): bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) - Args: - %s - %s - %s + Parameters + ---------- + %s + %s + %s """ @@ -348,10 +350,11 @@ class Lif(LifLTC): bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) - Args: - %s - %s - %s + Parameters + ---------- + %s + %s + %s """ @@ -407,11 +410,12 @@ class LifRefLTC(LifLTC): bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) - Args: - %s - %s - %s - %s + Parameters + ---------- + %s + %s + %s + %s """ @@ -567,11 +571,12 @@ class LifRef(LifRefLTC): bp.visualize.line_plot(runner.mon['ts'], runner.mon['V'], show=True) - Args: - %s - %s - %s - %s + Parameters + ---------- + %s + %s + %s + %s """ @@ -908,9 +913,10 @@ class ExpIF(ExpIFLTC): - Args: - %s - %s + Parameters + ---------- + %s + %s """ def derivative(self, V, t, I): @@ -1032,10 +1038,11 @@ class ExpIFRefLTC(ExpIFLTC): - Args: - %s - %s - %s + Parameters + ---------- + %s + %s + %s """ @@ -1272,10 +1279,11 @@ class ExpIFRef(ExpIFRefLTC): - Args: - %s - %s - %s + Parameters + ---------- + %s + %s + %s """ def derivative(self, V, t, I): @@ -1616,9 +1624,10 @@ class AdExIF(AdExIFLTC): - Args: - %s - %s + Parameters + ---------- + %s + %s """ def dV(self, V, t, w, I): @@ -1728,10 +1737,11 @@ class AdExIFRefLTC(AdExIFLTC): - Args: - %s - %s - %s + Parameters + ---------- + %s + %s + %s """ def __init__( @@ -1965,10 +1975,11 @@ class AdExIFRef(AdExIFRefLTC): t_last_spike -1e7 Last spike time stamp. ================== ================= ========================================================= - Args: - %s - %s - %s + Parameters + ---------- + %s + %s + %s """ def dV(self, V, t, w, I): @@ -2239,9 +2250,10 @@ class QuaIF(QuaIFLTC): - Args: - %s - %s + Parameters + ---------- + %s + %s """ def derivative(self, V, t, I): @@ -2326,10 +2338,11 @@ class QuaIFRefLTC(QuaIFLTC): t_last_spike -1e7 Last spike time stamp. ================== ================= ========================================================= - Args: - %s - %s - %s + Parameters + ---------- + %s + %s + %s """ def __init__( @@ -2529,10 +2542,11 @@ class QuaIFRef(QuaIFRefLTC): t_last_spike -1e7 Last spike time stamp. ================== ================= ========================================================= - Args: - %s - %s - %s + Parameters + ---------- + %s + %s + %s """ def derivative(self, V, t, I): @@ -2843,9 +2857,10 @@ class AdQuaIF(AdQuaIFLTC): - Args: - %s - %s + Parameters + ---------- + %s + %s """ def dV(self, V, t, w, I): @@ -2941,10 +2956,11 @@ class AdQuaIFRefLTC(AdQuaIFLTC): t_last_spike -1e7 Last spike time stamp. ================== ================= ========================================================== - Args: - %s - %s - %s + Parameters + ---------- + %s + %s + %s """ def __init__( @@ -3166,10 +3182,11 @@ class AdQuaIFRef(AdQuaIFRefLTC): - Args: - %s - %s - %s + Parameters + ---------- + %s + %s + %s """ def dV(self, V, t, w, I): @@ -3556,9 +3573,10 @@ class Gif(GifLTC): - Args: - %s - %s + Parameters + ---------- + %s + %s """ def dV(self, V, t, I1, I2, I): @@ -3681,10 +3699,11 @@ class GifRefLTC(GifLTC): - Args: - %s - %s - %s + Parameters + ---------- + %s + %s + %s """ def __init__( @@ -3958,10 +3977,11 @@ class GifRef(GifRefLTC): - Args: - %s - %s - %s + Parameters + ---------- + %s + %s + %s """ def dV(self, V, t, I1, I2, I): @@ -4281,9 +4301,10 @@ class Izhikevich(IzhikevichLTC): ================== ================= ========================================================= - Args: - %s - %s + Parameters + ---------- + %s + %s """ @@ -4384,10 +4405,11 @@ class IzhikevichRefLTC(IzhikevichLTC): - Args: - %s - %s - %s + Parameters + ---------- + %s + %s + %s """ @@ -4616,10 +4638,11 @@ class IzhikevichRef(IzhikevichRefLTC): - Args: - %s - %s - %s + Parameters + ---------- + %s + %s + %s """ def dV(self, V, t, u, I): diff --git a/brainpy/dyn/others/common.py b/brainpy/dyn/others/common.py index 881c9f080..661edff19 100644 --- a/brainpy/dyn/others/common.py +++ b/brainpy/dyn/others/common.py @@ -42,11 +42,15 @@ class Leaky(NeuDyn): x(t + \Delta t) = \exp{-\Delta t/\tau} x(t) + I - Args: - tau: float, ArrayType, Initializer, callable. Membrane time constant. - method: str. The numerical integration method. Default "exp_auto". - init_var: Initialize the variable or not. - %s + Parameters + ---------- + tau : float, ArrayType, Initializer, callable + Membrane time constant. + method : str + The numerical integration method. Default "exp_auto". + init_var + Initialize the variable or not. + %s """ supported_modes = (bm.TrainingMode, bm.NonBatchingMode) @@ -114,11 +118,15 @@ class Integrator(NeuDyn): where :math:`x` is the integrator value, and :math:`\tau` is the time constant. - Args: - tau: float, ArrayType, Initializer, callable. Membrane time constant. - method: str. The numerical integration method. Default "exp_auto". - x_initializer: ArrayType, Initializer, callable. The initializer of :math:`x`. - %s + Parameters + ---------- + tau : float, ArrayType, Initializer, callable + Membrane time constant. + method : str + The numerical integration method. Default "exp_auto". + x_initializer : ArrayType, Initializer, callable + The initializer of :math:`x`. + %s """ supported_modes = (bm.TrainingMode, bm.NonBatchingMode) diff --git a/brainpy/dyn/others/input.py b/brainpy/dyn/others/input.py index 1272eaa5a..58b9fa770 100644 --- a/brainpy/dyn/others/input.py +++ b/brainpy/dyn/others/input.py @@ -38,11 +38,12 @@ class InputGroup(NeuDyn): """Input neuron group for place holder. - Args: - size: int, tuple of int - keep_size: bool - mode: Mode - name: str + Parameters + ---------- + size : int, tuple of int + keep_size : bool + mode : Mode + name : str """ def __init__( @@ -72,11 +73,12 @@ def reset_state(self, batch_or_mode=None, **kwargs): class OutputGroup(NeuDyn): """Output neuron group for place holder. - Args: - size: int, tuple of int - keep_size: bool - mode: Mode - name: str + Parameters + ---------- + size : int, tuple of int + keep_size : bool + mode : Mode + name : str """ def __init__( @@ -119,8 +121,8 @@ class SpikeTimeGroup(NeuDyn): >>> # at 30 ms, neuron 1 fires. >>> SpikeTimeGroup(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1]) - Parameters:: - + Parameters + ---------- size : int, tuple, list The neuron group geometry. indices : list, tuple, ArrayType diff --git a/brainpy/dyn/others/noise.py b/brainpy/dyn/others/noise.py index cf30f23a8..ae705dc82 100644 --- a/brainpy/dyn/others/noise.py +++ b/brainpy/dyn/others/noise.py @@ -41,19 +41,20 @@ class OUProcess(NeuDyn): where :math:`\theta >0` and :math:`\sigma >0` are parameters and :math:`W_{t}` denotes the Wiener process. - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The model size. - mean: Parameter + mean : Parameter The noise mean value. - sigma: Parameter + sigma : Parameter The noise amplitude. - tau: Parameter + tau : Parameter The decay time constant. - method: str + method : str The numerical integration method for stochastic differential equation. - name: str + name : str The model name. """ diff --git a/brainpy/dyn/outs/outputs.py b/brainpy/dyn/outs/outputs.py index 7457b4dbc..f15bbac3d 100644 --- a/brainpy/dyn/outs/outputs.py +++ b/brainpy/dyn/outs/outputs.py @@ -36,18 +36,20 @@ class COBA(SynOut): I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) - Parameters:: + Parameters + ---------- - E: float, ArrayType, ndarray + E : float, ArrayType, ndarray The reversal potential. - sharding: sequence of str + sharding : sequence of str The axis names for variable for parallelization. - name: str + name : str The model name. - scaling: brainpy.Scaling + scaling : brainpy.Scaling The scaling object. - See Also:: + See Also + -------- CUBA """ @@ -77,14 +79,16 @@ class CUBA(SynOut): I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t) - Parameters:: + Parameters + ---------- - name: str + name : str The model name. - scaling: brainpy.Scaling + scaling : brainpy.Scaling The scaling object. - See Also:: + See Also + -------- COBA """ @@ -117,19 +121,20 @@ class MgBlock(SynOut): Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration. - Parameters:: + Parameters + ---------- - E: float, ArrayType + E : float, ArrayType The reversal potential for the synaptic current. [mV] - alpha: float, ArrayType + alpha : float, ArrayType Binding constant. Default 0.062 - beta: float, ArrayType + beta : float, ArrayType Unbinding constant. Default 3.57 - cc_Mg: float, ArrayType + cc_Mg : float, ArrayType Concentration of Magnesium ion. Default 1.2 [mM]. - sharding: sequence of str + sharding : sequence of str The axis names for variable for parallelization. - name: str + name : str The model name. """ diff --git a/brainpy/dyn/projections/align_post.py b/brainpy/dyn/projections/align_post.py index 2f0248d3b..45a457eb7 100644 --- a/brainpy/dyn/projections/align_post.py +++ b/brainpy/dyn/projections/align_post.py @@ -112,14 +112,22 @@ def update(self, input): spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) bp.visualize.raster_plot(indices, spks, show=True) - Args: - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - out_label: str. The prefix of the output function. - name: str. The projection name. - mode: Mode. The computing mode. + Parameters + ---------- + comm : DynamicalSystem + The synaptic communication. + syn : ParamDescriber[JointType[DynamicalSystem, AlignPost]] + The synaptic dynamics. + out : ParamDescriber[JointType[DynamicalSystem, BindCondData]] + The synaptic output. + post : DynamicalSystem + The post-synaptic neuron group. + out_label : str + The prefix of the output function. + name : str + The projection name. + mode : Mode + The computing mode. """ def __init__( @@ -234,15 +242,24 @@ def update(self, inp): spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) bp.visualize.raster_plot(indices, spks, show=True) - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. + Parameters + ---------- + pre : JointType[DynamicalSystem, SupportAutoDelay] + The pre-synaptic neuron group. + delay : Union[None, int, float] + The synaptic delay. + comm : DynamicalSystem + The synaptic communication. + syn : ParamDescriber[JointType[DynamicalSystem, AlignPost]] + The synaptic dynamics. + out : ParamDescriber[JointType[DynamicalSystem, BindCondData]] + The synaptic output. + post : DynamicalSystem + The post-synaptic neuron group. + name : str + The projection name. + mode : Mode + The computing mode. """ def __init__( @@ -339,13 +356,20 @@ def update(self, input): bp.visualize.raster_plot(indices, spks, show=True) - Args: - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. + Parameters + ---------- + comm : DynamicalSystem + The synaptic communication. + syn : JointType[DynamicalSystem, AlignPost] + The synaptic dynamics. + out : JointType[DynamicalSystem, BindCondData] + The synaptic output. + post : DynamicalSystem + The post-synaptic neuron group. + name : str + The projection name. + mode : Mode + The computing mode. """ def __init__( @@ -457,15 +481,24 @@ def update(self, inp): bp.visualize.raster_plot(indices, spks, show=True) - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. + Parameters + ---------- + pre : JointType[DynamicalSystem, SupportAutoDelay] + The pre-synaptic neuron group. + delay : Union[None, int, float] + The synaptic delay. + comm : DynamicalSystem + The synaptic communication. + syn : JointType[DynamicalSystem, AlignPost] + The synaptic dynamics. + out : JointType[DynamicalSystem, BindCondData] + The synaptic output. + post : DynamicalSystem + The post-synaptic neuron group. + name : str + The projection name. + mode : Mode + The computing mode. """ def __init__( diff --git a/brainpy/dyn/projections/align_pre.py b/brainpy/dyn/projections/align_pre.py index 13a8f7d51..af665b439 100644 --- a/brainpy/dyn/projections/align_pre.py +++ b/brainpy/dyn/projections/align_pre.py @@ -151,15 +151,24 @@ def update(self, inp): bp.visualize.raster_plot(indices, spks, show=True) - Args: - pre: The pre-synaptic neuron group. - syn: The synaptic dynamics. - delay: The synaptic delay. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. + Parameters + ---------- + pre : DynamicalSystem + The pre-synaptic neuron group. + syn : ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]] + The synaptic dynamics. + delay : Union[None, int, float] + The synaptic delay. + comm : DynamicalSystem + The synaptic communication. + out : JointType[DynamicalSystem, BindCondData] + The synaptic output. + post : DynamicalSystem + The post-synaptic neuron group. + name : str + The projection name. + mode : Mode + The computing mode. """ def __init__( @@ -288,15 +297,24 @@ def update(self, inp): bp.visualize.raster_plot(indices, spks, show=True) - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - syn: The synaptic dynamics. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. + Parameters + ---------- + pre : JointType[DynamicalSystem, SupportAutoDelay] + The pre-synaptic neuron group. + delay : Union[None, int, float] + The synaptic delay. + syn : ParamDescriber[DynamicalSystem] + The synaptic dynamics. + comm : DynamicalSystem + The synaptic communication. + out : JointType[DynamicalSystem, BindCondData] + The synaptic output. + post : DynamicalSystem + The post-synaptic neuron group. + name : str + The projection name. + mode : Mode + The computing mode. """ def __init__( @@ -420,15 +438,24 @@ def update(self, inp): bp.visualize.raster_plot(indices, spks, show=True) - Args: - pre: The pre-synaptic neuron group. - syn: The synaptic dynamics. - delay: The synaptic delay. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. + Parameters + ---------- + pre : DynamicalSystem + The pre-synaptic neuron group. + syn : JointType[DynamicalSystem, SupportAutoDelay] + The synaptic dynamics. + delay : Union[None, int, float] + The synaptic delay. + comm : DynamicalSystem + The synaptic communication. + out : JointType[DynamicalSystem, BindCondData] + The synaptic output. + post : DynamicalSystem + The post-synaptic neuron group. + name : str + The projection name. + mode : Mode + The computing mode. """ def __init__( @@ -555,15 +582,24 @@ def update(self, inp): bp.visualize.raster_plot(indices, spks, show=True) - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - syn: The synaptic dynamics. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. + Parameters + ---------- + pre : JointType[DynamicalSystem, SupportAutoDelay] + The pre-synaptic neuron group. + delay : Union[None, int, float] + The synaptic delay. + syn : DynamicalSystem + The synaptic dynamics. + comm : DynamicalSystem + The synaptic communication. + out : JointType[DynamicalSystem, BindCondData] + The synaptic output. + post : DynamicalSystem + The post-synaptic neuron group. + name : str + The projection name. + mode : Mode + The computing mode. """ def __init__( diff --git a/brainpy/dyn/projections/conn.py b/brainpy/dyn/projections/conn.py index 9b26c51ad..ee2b1157c 100644 --- a/brainpy/dyn/projections/conn.py +++ b/brainpy/dyn/projections/conn.py @@ -30,7 +30,8 @@ class SynConn(Projection): """Base class to model two-end synaptic connections. - Parameters:: + Parameters + ---------- pre : NeuGroup Pre-synaptic neuron group. diff --git a/brainpy/dyn/projections/delta.py b/brainpy/dyn/projections/delta.py index fdcb5cde2..7a5d152da 100644 --- a/brainpy/dyn/projections/delta.py +++ b/brainpy/dyn/projections/delta.py @@ -84,11 +84,16 @@ def update(self): vs = bm.for_loop(net.step_run, indices, progress_bar=True) bp.visualize.line_plot(indices, vs, show=True) - Args: - comm: DynamicalSystem. The synaptic communication. - post: DynamicalSystem. The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. + Parameters + ---------- + comm : DynamicalSystem + The synaptic communication. + post : DynamicalSystem + The post-synaptic neuron group. + name : str + The projection name. + mode : Mode + The computing mode. """ def __init__( @@ -174,13 +179,20 @@ def update(self): bp.visualize.line_plot(indices, vs, show=True) - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - comm: DynamicalSystem. The synaptic communication. - post: DynamicalSystem. The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. + Parameters + ---------- + pre : JointType[DynamicalSystem, SupportAutoDelay] + The pre-synaptic neuron group. + delay : Union[None, int, float] + The synaptic delay. + comm : DynamicalSystem + The synaptic communication. + post : DynamicalSystem + The post-synaptic neuron group. + name : str + The projection name. + mode : Mode + The computing mode. """ def __init__( diff --git a/brainpy/dyn/projections/inputs.py b/brainpy/dyn/projections/inputs.py index 39197c4fb..ebdc1b8ac 100644 --- a/brainpy/dyn/projections/inputs.py +++ b/brainpy/dyn/projections/inputs.py @@ -32,7 +32,9 @@ class InputVar(Dynamic, SupportAutoDelay): """Define an input variable. - Example:: + Examples + -------- + :: import brainpy as bp @@ -127,13 +129,20 @@ class PoissonInput(Projection): All neurons in the target variable receive independent realizations of Poisson spike trains. - Args: - target_var: The variable that is targeted by this input. Should be an instance of :py:class:`~.Variable`. - num_input: The number of inputs. - freq: The frequency of each of the inputs. Must be a scalar. - weight: The synaptic weight. Must be a scalar. - name: The target name. - mode: The computing mode. + Parameters + ---------- + target_var : bm.Variable + The variable that is targeted by this input. Should be an instance of :py:class:`~.Variable`. + num_input : int + The number of inputs. + freq : Union[int, float] + The frequency of each of the inputs. Must be a scalar. + weight : Union[int, float] + The synaptic weight. Must be a scalar. + name : str + The target name. + mode : Mode + The computing mode. """ def __init__( diff --git a/brainpy/dyn/projections/plasticity.py b/brainpy/dyn/projections/plasticity.py index 03fa3d509..5329a26ab 100644 --- a/brainpy/dyn/projections/plasticity.py +++ b/brainpy/dyn/projections/plasticity.py @@ -119,21 +119,36 @@ def run(i, I_pre, I_post): indices = bm.arange(0, duration, bm.dt) pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) - Args: - tau_s: float. The time constant of :math:`A_{pre}`. - tau_t: float. The time constant of :math:`A_{post}`. - A1: float. The increment of :math:`A_{pre}` produced by a spike. Must be a positive value. - A2: float. The increment of :math:`A_{post}` produced by a spike. Must be a positive value. - W_max: float. The maximum weight. - W_min: float. The minimum weight. - pre: DynamicalSystem. The pre-synaptic neuron group. - delay: int, float. The pre spike delay length. (ms) - syn: DynamicalSystem. The synapse model. - comm: DynamicalSystem. The communication model, for example, dense or sparse connection layers. - out: DynamicalSystem. The synaptic current output models. - post: DynamicalSystem. The post-synaptic neuron group. - out_label: str. The output label. - name: str. The model name. + Parameters + ---------- + tau_s : float + The time constant of :math:`A_{pre}`. + tau_t : float + The time constant of :math:`A_{post}`. + A1 : float + The increment of :math:`A_{pre}` produced by a spike. Must be a positive value. + A2 : float + The increment of :math:`A_{post}` produced by a spike. Must be a positive value. + W_max : float + The maximum weight. + W_min : float + The minimum weight. + pre : DynamicalSystem + The pre-synaptic neuron group. + delay : int, float + The pre spike delay length. (ms) + syn : DynamicalSystem + The synapse model. + comm : DynamicalSystem + The communication model, for example, dense or sparse connection layers. + out : DynamicalSystem + The synaptic current output models. + post : DynamicalSystem + The post-synaptic neuron group. + out_label : str + The output label. + name : str + The model name. """ def __init__( diff --git a/brainpy/dyn/projections/vanilla.py b/brainpy/dyn/projections/vanilla.py index f35ac7622..bf7abc62d 100644 --- a/brainpy/dyn/projections/vanilla.py +++ b/brainpy/dyn/projections/vanilla.py @@ -60,12 +60,18 @@ def update(self, input): bp.visualize.raster_plot(indices, spks, show=True) - Args: - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. + Parameters + ---------- + comm : DynamicalSystem + The synaptic communication. + out : JointType[DynamicalSystem, BindCondData] + The synaptic output. + post : DynamicalSystem + The post-synaptic neuron group. + name : str + The projection name. + mode : Mode + The computing mode. """ def __init__( diff --git a/brainpy/dyn/rates/nvar.py b/brainpy/dyn/rates/nvar.py index 14a3ccf89..ad437d30d 100644 --- a/brainpy/dyn/rates/nvar.py +++ b/brainpy/dyn/rates/nvar.py @@ -55,18 +55,20 @@ class NVAR(Layer): - it supports batch size, - it supports multiple orders, - Parameters:: + Parameters + ---------- - delay: int + delay : int The number of delay step. - order: int, sequence of int + order : int, sequence of int The nonlinear order. - stride: int + stride : int The stride to sample linear part vector in the delays. - constant: optional, float + constant : optional, float The constant value. - References:: + References + ---------- .. [1] Gauthier, D.J., Bollt, E., Griffith, A. et al. Next generation reservoir computing. Nat Commun 12, 5564 (2021). @@ -183,9 +185,10 @@ def update(self, x): def get_feature_names(self, for_plot=False) -> List[str]: """Get output feature names for transformation. - Parameters:: + Parameters + ---------- - for_plot: bool + for_plot : bool Use the feature names for plotting or not? (Default False) """ if for_plot: diff --git a/brainpy/dyn/rates/populations.py b/brainpy/dyn/rates/populations.py index efa82a92e..06e00b932 100644 --- a/brainpy/dyn/rates/populations.py +++ b/brainpy/dyn/rates/populations.py @@ -54,25 +54,27 @@ class FHN(RateModel): \frac{dx}{dt} = -\alpha V^3 + \beta V^2 + \gamma V - w + I_{ext}\\ \tau \frac{dy}{dt} = (V - \delta - \epsilon w) - Parameters:: + Parameters + ---------- - size: Shape + size : Shape The model size. - x_ou_mean: Parameter + x_ou_mean : Parameter The noise mean of the :math:`x` variable, [mV/ms] - y_ou_mean: Parameter + y_ou_mean : Parameter The noise mean of the :math:`y` variable, [mV/ms]. - x_ou_sigma: Parameter + x_ou_sigma : Parameter The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. - y_ou_sigma: Parameter + y_ou_sigma : Parameter The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. - x_ou_tau: Parameter + x_ou_tau : Parameter The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. - y_ou_tau: Parameter + y_ou_tau : Parameter The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. - References:: + References + ---------- .. [1] Kostova, T., Ravindran, R., & Schonbek, M. (2004). FitzHugh–Nagumo revisited: Types of bifurcations, periodical forcing and stability @@ -252,22 +254,24 @@ class FeedbackFHN(RateModel): when negative, it is a inhibitory feedback. ============= ============== ======== ======================== - Parameters:: + Parameters + ---------- - x_ou_mean: Parameter + x_ou_mean : Parameter The noise mean of the :math:`x` variable, [mV/ms] - y_ou_mean: Parameter + y_ou_mean : Parameter The noise mean of the :math:`y` variable, [mV/ms]. - x_ou_sigma: Parameter + x_ou_sigma : Parameter The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. - y_ou_sigma: Parameter + y_ou_sigma : Parameter The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. - x_ou_tau: Parameter + x_ou_tau : Parameter The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. - y_ou_tau: Parameter + y_ou_tau : Parameter The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. - References:: + References + ---------- .. [4] Plant, Richard E. (1981). *A FitzHugh Differential-Difference Equation Modeling Recurrent Neural Feedback. SIAM Journal on @@ -448,23 +452,25 @@ class QIF(RateModel): J 15 \ the strength of the recurrent coupling inside the population ============= ============== ======== ======================== - Parameters:: + Parameters + ---------- - x_ou_mean: Parameter + x_ou_mean : Parameter The noise mean of the :math:`x` variable, [mV/ms] - y_ou_mean: Parameter + y_ou_mean : Parameter The noise mean of the :math:`y` variable, [mV/ms]. - x_ou_sigma: Parameter + x_ou_sigma : Parameter The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. - y_ou_sigma: Parameter + y_ou_sigma : Parameter The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. - x_ou_tau: Parameter + x_ou_tau : Parameter The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. - y_ou_tau: Parameter + y_ou_tau : Parameter The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. - References:: + References + ---------- .. [5] E. Montbrió, D. Pazó, A. Roxin (2015) Macroscopic description for networks of spiking neurons. Physical Review X, 5:021028, @@ -614,19 +620,20 @@ class StuartLandauOscillator(RateModel): \frac{dx}{dt} = (a - x^2 - y^2) * x - w*y + I^x_{ext} \\ \frac{dy}{dt} = (a - x^2 - y^2) * y + w*x + I^y_{ext} - Parameters:: + Parameters + ---------- - x_ou_mean: Parameter + x_ou_mean : Parameter The noise mean of the :math:`x` variable, [mV/ms] - y_ou_mean: Parameter + y_ou_mean : Parameter The noise mean of the :math:`y` variable, [mV/ms]. - x_ou_sigma: Parameter + x_ou_sigma : Parameter The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. - y_ou_sigma: Parameter + y_ou_sigma : Parameter The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. - x_ou_tau: Parameter + x_ou_tau : Parameter The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. - y_ou_tau: Parameter + y_ou_tau : Parameter The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. """ @@ -764,19 +771,20 @@ class WilsonCowanModel(RateModel): """Wilson-Cowan population model. - Parameters:: + Parameters + ---------- - x_ou_mean: Parameter + x_ou_mean : Parameter The noise mean of the :math:`x` variable, [mV/ms] - y_ou_mean: Parameter + y_ou_mean : Parameter The noise mean of the :math:`y` variable, [mV/ms]. - x_ou_sigma: Parameter + x_ou_sigma : Parameter The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. - y_ou_sigma: Parameter + y_ou_sigma : Parameter The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. - x_ou_tau: Parameter + x_ou_tau : Parameter The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. - y_ou_tau: Parameter + y_ou_tau : Parameter The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. diff --git a/brainpy/dyn/rates/reservoir.py b/brainpy/dyn/rates/reservoir.py index 57912b0b8..1a7ef14ae 100644 --- a/brainpy/dyn/rates/reservoir.py +++ b/brainpy/dyn/rates/reservoir.py @@ -32,19 +32,20 @@ class Reservoir(Layer): r"""Reservoir node, a pool of leaky-integrator neurons with random recurrent connections [1]_. - Parameters:: + Parameters + ---------- - input_shape: int, tuple of int + input_shape : int, tuple of int The input shape. - num_out: int + num_out : int The number of reservoir nodes. - Win_initializer: Initializer + Win_initializer : Initializer The initialization method for the feedforward connections. - Wrec_initializer: Initializer + Wrec_initializer : Initializer The initialization method for the recurrent connections. - b_initializer: optional, ArrayType, Initializer + b_initializer : optional, ArrayType, Initializer The initialization method for the bias. - leaky_rate: float + leaky_rate : float A float between 0 and 1. activation : str, callable, optional Reservoir activation function. @@ -78,7 +79,7 @@ class Reservoir(Layer): Connectivity of recurrent weights matrix, i.e. ratio of reservoir neurons connected to other reservoir neurons, including themselves. Must be in [0, 1], by default 0.1 - comp_type: str + comp_type : str The connectivity type, can be "dense" or "sparse", "jit". - ``"dense"`` means the connectivity matrix is a dense matrix. @@ -94,7 +95,8 @@ class Reservoir(Layer): distribution (see :py:class:`brainpy.math.random.RandomState`), by default "normal". - References:: + References + ---------- .. [1] Lukoševičius, Mantas. "A practical guide to applying echo state networks." Neural networks: Tricks of the trade. Springer, Berlin, Heidelberg, 2012. 659-686. diff --git a/brainpy/dyn/rates/rnncells.py b/brainpy/dyn/rates/rnncells.py index b07c4507d..35ecc8d17 100644 --- a/brainpy/dyn/rates/rnncells.py +++ b/brainpy/dyn/rates/rnncells.py @@ -51,23 +51,23 @@ class RNNCell(Layer): The output is equal to the new state, :math:`h_t`. - Parameters:: - - num_in: int - The dimension of the input vector - num_out: int - The number of hidden unit in the node. - state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The state initializer. - Wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The input weight initializer. - Wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The hidden weight initializer. - b_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray - The bias weight initializer. - activation: str, callable - The activation function. It can be a string or a callable function. - See ``brainpy.math.activations`` for more details. + Parameters + ---------- + num_in : int + The dimension of the input vector + num_out : int + The number of hidden unit in the node. + state_initializer : callable, Initializer, bm.ndarray, jax.numpy.ndarray + The state initializer. + Wi_initializer : callable, Initializer, bm.ndarray, jax.numpy.ndarray + The input weight initializer. + Wh_initializer : callable, Initializer, bm.ndarray, jax.numpy.ndarray + The hidden weight initializer. + b_initializer : optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray + The bias weight initializer. + activation : str, callable + The activation function. It can be a string or a callable function. + See ``brainpy.math.activations`` for more details. """ @@ -160,26 +160,26 @@ class GRUCell(Layer): Warning: Backwards compatibility of GRU weights is currently unsupported. - Parameters:: - - num_in: int - The dimension of the input vector - num_out: int - The number of hidden unit in the node. - state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The state initializer. - Wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The input weight initializer. - Wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The hidden weight initializer. - b_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray - The bias weight initializer. - activation: str, callable - The activation function. It can be a string or a callable function. - See ``brainpy.math.activations`` for more details. - - References:: - + Parameters + ---------- + num_in : int + The dimension of the input vector + num_out : int + The number of hidden unit in the node. + state_initializer : callable, Initializer, bm.ndarray, jax.numpy.ndarray + The state initializer. + Wi_initializer : callable, Initializer, bm.ndarray, jax.numpy.ndarray + The input weight initializer. + Wh_initializer : callable, Initializer, bm.ndarray, jax.numpy.ndarray + The hidden weight initializer. + b_initializer : optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray + The bias weight initializer. + activation : str, callable + The activation function. It can be a string or a callable function. + See ``brainpy.math.activations`` for more details. + + References + ---------- .. [1] Chung, J., Gulcehre, C., Cho, K. and Bengio, Y., 2014. Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555. @@ -285,33 +285,32 @@ class LSTMCell(Layer): The output is equal to the new hidden, :math:`h_t`. - Notes:: - + Parameters + ---------- + num_in : int + The dimension of the input vector + num_out : int + The number of hidden unit in the node. + state_initializer : callable, Initializer, bm.ndarray, jax.numpy.ndarray + The state initializer. + Wi_initializer : callable, Initializer, bm.ndarray, jax.numpy.ndarray + The input weight initializer. + Wh_initializer : callable, Initializer, bm.ndarray, jax.numpy.ndarray + The hidden weight initializer. + b_initializer : optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray + The bias weight initializer. + activation : str, callable + The activation function. It can be a string or a callable function. + See ``brainpy.math.activations`` for more details. + + Notes + ----- Forget gate initialization: Following (Jozefowicz, et al., 2015) [2]_ we add 1.0 to :math:`b_f` after initialization in order to reduce the scale of forgetting in the beginning of the training. - - Parameters:: - - num_in: int - The dimension of the input vector - num_out: int - The number of hidden unit in the node. - state_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The state initializer. - Wi_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The input weight initializer. - Wh_initializer: callable, Initializer, bm.ndarray, jax.numpy.ndarray - The hidden weight initializer. - b_initializer: optional, callable, Initializer, bm.ndarray, jax.numpy.ndarray - The bias weight initializer. - activation: str, callable - The activation function. It can be a string or a callable function. - See ``brainpy.math.activations`` for more details. - - References:: - + References + ---------- .. [1] Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. "Recurrent neural network regularization." arXiv preprint arXiv:1409.2329 (2014). .. [2] Jozefowicz, Rafal, Wojciech Zaremba, and Ilya Sutskever. "An empirical @@ -436,8 +435,9 @@ class _ConvNDLSTMCell(Layer): The output is equal to the new hidden state, :math:`h_t`. - Notes: - Forget gate initialization: + Notes + ----- + Forget gate initialization: Following :cite:`jozefowicz2015empirical` we add 1.0 to :math:`b_f` after initialization in order to reduce the scale of forgetting in the beginning of the training. @@ -470,14 +470,20 @@ def __init__( ): """Constructs a convolutional LSTM. - Args: - num_spatial_dims: Number of spatial dimensions of the input. - input_shape: Shape of the inputs excluding batch size. - out_channels: Number of output channels. - kernel_size: Sequence of kernel sizes (of length ``num_spatial_dims``), + Parameters + ---------- + num_spatial_dims : int + Number of spatial dimensions of the input. + input_shape : Tuple[int, ...] + Shape of the inputs excluding batch size. + out_channels : int + Number of output channels. + kernel_size : Union[int, Sequence[int]] + Sequence of kernel sizes (of length ``num_spatial_dims``), or an int. ``kernel_shape`` will be expanded to define a kernel size in all dimensions. - name: Name of the module. + name : Optional[str] + Name of the module. """ super().__init__(name=name, mode=mode) @@ -577,13 +583,18 @@ def __init__( Output: [Batch_Size, Output_Data_Size, Output_Channel_Size] - Args: - input_shape: Shape of the inputs excluding batch size. - out_channels: Number of output channels. - kernel_size: Sequence of kernel sizes (of length 1), or an int. + Parameters + ---------- + input_shape : Tuple[int, ...] + Shape of the inputs excluding batch size. + out_channels : int + Number of output channels. + kernel_size : Union[int, Sequence[int]] + Sequence of kernel sizes (of length 1), or an int. ``kernel_shape`` will be expanded to define a kernel size in all dimensions. - name: Name of the module. + name : Optional[str] + Name of the module. """ super().__init__( num_spatial_dims=1, @@ -638,13 +649,18 @@ def __init__( Output: [Batch_Size, Output_Data_Size_Dim1,Output_Data_Size_Dim2 , Output_Channel_Size] - Args: - input_shape: Shape of the inputs excluding batch size. - out_channels: Number of output channels. - kernel_size: Sequence of kernel sizes (of length 2), or an int. + Parameters + ---------- + input_shape : Tuple[int, ...] + Shape of the inputs excluding batch size. + out_channels : int + Number of output channels. + kernel_size : Union[int, Sequence[int]] + Sequence of kernel sizes (of length 2), or an int. ``kernel_shape`` will be expanded to define a kernel size in all dimensions. - name: Name of the module. + name : Optional[str] + Name of the module. """ super().__init__( num_spatial_dims=2, @@ -699,13 +715,18 @@ def __init__( Output: [Batch_Size, Output_Data_Size_Dim1,Output_Data_Size_Dim2,Output_Data_Size_Dim3,Output_Channel_Size] - Args: - input_shape: Shape of the inputs excluding batch size. - out_channels: Number of output channels. - kernel_size: Sequence of kernel sizes (of length 3), or an int. + Parameters + ---------- + input_shape : Tuple[int, ...] + Shape of the inputs excluding batch size. + out_channels : int + Number of output channels. + kernel_size : Union[int, Sequence[int]] + Sequence of kernel sizes (of length 3), or an int. ``kernel_shape`` will be expanded to define a kernel size in all dimensions. - name: Name of the module. + name : Optional[str] + Name of the module. """ super().__init__( num_spatial_dims=3, diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index 8f17633f6..ca8b2ffaf 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -100,9 +100,11 @@ def __init__(self, pre, post, delay, prob, g_max, tau, E): ) - Args: - tau: float. The time constant of decay. [ms] - %s + Parameters + ---------- + tau : float + The time constant of decay. [ms] + %s """ def __init__( @@ -272,17 +274,19 @@ def update(self): plt.title('Post V') plt.show() - See Also: - DualExponV2 + See Also + -------- + DualExponV2 .. note:: The implementation of this model can only be used in ``AlignPre`` projections. One the contrary, to seek the ``AlignPost`` projection, please use ``DualExponV2``. - Args: - %s - %s + Parameters + ---------- + %s + %s """ def __init__( @@ -428,12 +432,14 @@ def __init__(self, pre, post, delay, prob, g_max, tau_decay, tau_rise, E): post=post, ) - See Also: - DualExpon + See Also + -------- + DualExpon - Args: - %s - %s + Parameters + ---------- + %s + %s """ def __init__( @@ -555,9 +561,11 @@ def update(self): plt.show() - Args: - %s - tau_decay: float, ArrayType, Callable. The time constant [ms] of the synaptic decay phase. + Parameters + ---------- + %s + tau_decay : float, ArrayType, Callable + The time constant [ms] of the synaptic decay phase. """ def __init__( @@ -735,11 +743,15 @@ def update(self): England journal of medicine, 361(3), p.302. .. [4] https://en.wikipedia.org/wiki/NMDA_receptor - Args: - tau_decay: float, ArrayType, Callable. The time constant of the synaptic decay phase. Default 100 [ms] - tau_rise: float, ArrayType, Callable. The time constant of the synaptic rise phase. Default 2 [ms] - a: float, ArrayType, Callable. Default 0.5 ms^-1. - %s + Parameters + ---------- + tau_decay : float, ArrayType, Callable + The time constant of the synaptic decay phase. Default 100 [ms] + tau_rise : float, ArrayType, Callable + The time constant of the synaptic rise phase. Default 2 [ms] + a : float, ArrayType, Callable + Default 0.5 ms^-1. + %s """ def __init__( @@ -801,10 +813,13 @@ class STD(SynDyn): %s - Args: - tau: float, ArrayType, Callable. The time constant of recovery of the synaptic vesicles. - U: float, ArrayType, Callable. The fraction of resources used per action potential. - %s + Parameters + ---------- + tau : float, ArrayType, Callable + The time constant of recovery of the synaptic vesicles. + U : float, ArrayType, Callable + The fraction of resources used per action potential. + %s """ def __init__( @@ -863,11 +878,15 @@ class STP(SynDyn): %s - Args: - tau_f: float, ArrayType, Callable. The time constant of short-term facilitation. - tau_d: float, ArrayType, Callable. The time constant of short-term depression. - U: float, ArrayType, Callable. The fraction of resources used per action potential. - %s + Parameters + ---------- + tau_f : float, ArrayType, Callable + The time constant of short-term facilitation. + tau_d : float, ArrayType, Callable + The time constant of short-term depression. + U : float, ArrayType, Callable + The fraction of resources used per action potential. + %s """ def __init__( diff --git a/brainpy/dyn/synapses/bio_models.py b/brainpy/dyn/synapses/bio_models.py index 070bbc6e2..e8e83820c 100644 --- a/brainpy/dyn/synapses/bio_models.py +++ b/brainpy/dyn/synapses/bio_models.py @@ -132,13 +132,18 @@ def update(self): and implications for stimulus processing[J]. Proceedings of the National Academy of Sciences, 2012, 109(45): 18553-18558. - Args: - alpha: float, ArrayType, Callable. Binding constant. - beta: float, ArrayType, Callable. Unbinding constant. - T: float, ArrayType, Callable. Transmitter concentration when synapse is triggered by + Parameters + ---------- + alpha : float, ArrayType, Callable + Binding constant. + beta : float, ArrayType, Callable + Unbinding constant. + T : float, ArrayType, Callable + Transmitter concentration when synapse is triggered by a pre-synaptic spike.. Default 1 [mM]. - T_dur: float, ArrayType, Callable. Transmitter concentration duration time after being triggered. Default 1 [ms] - %s + T_dur : float, ArrayType, Callable + Transmitter concentration duration time after being triggered. Default 1 [ms] + %s """ supported_modes = (bm.NonBatchingMode, bm.BatchingMode) @@ -282,14 +287,19 @@ def update(self): on the integrative properties of neocortical pyramidal neurons in vivo." Journal of neurophysiology 81.4 (1999): 1531-1547. - Args: - alpha: float, ArrayType, Callable. Binding constant. Default 0.062 - beta: float, ArrayType, Callable. Unbinding constant. Default 3.57 - T: float, ArrayType, Callable. Transmitter concentration when synapse is triggered by + Parameters + ---------- + alpha : float, ArrayType, Callable + Binding constant. Default 0.062 + beta : float, ArrayType, Callable + Unbinding constant. Default 3.57 + T : float, ArrayType, Callable + Transmitter concentration when synapse is triggered by a pre-synaptic spike.. Default 1 [mM]. - T_dur: float, ArrayType, Callable. Transmitter concentration duration time + T_dur : float, ArrayType, Callable + Transmitter concentration duration time after being triggered. Default 1 [ms] - %s + %s """ def __init__( @@ -442,15 +452,22 @@ def update(self): .. [4] https://en.wikipedia.org/wiki/NMDA_receptor - Args: - alpha1: float, ArrayType, Callable. The conversion rate of g from inactive to active. Default 2 ms^-1. - beta1: float, ArrayType, Callable. The conversion rate of g from active to inactive. Default 0.01 ms^-1. - alpha2: float, ArrayType, Callable. The conversion rate of x from inactive to active. Default 1 ms^-1. - beta2: float, ArrayType, Callable. The conversion rate of x from active to inactive. Default 0.5 ms^-1. - T: float, ArrayType, Callable. Transmitter concentration when synapse is + Parameters + ---------- + alpha1 : float, ArrayType, Callable + The conversion rate of g from inactive to active. Default 2 ms^-1. + beta1 : float, ArrayType, Callable + The conversion rate of g from active to inactive. Default 0.01 ms^-1. + alpha2 : float, ArrayType, Callable + The conversion rate of x from inactive to active. Default 1 ms^-1. + beta2 : float, ArrayType, Callable + The conversion rate of x from active to inactive. Default 0.5 ms^-1. + T : float, ArrayType, Callable + Transmitter concentration when synapse is triggered by a pre-synaptic spike. Default 1 [mM]. - T_dur: float, ArrayType, Callable. Transmitter concentration duration time after being triggered. Default 1 [ms] - %s + T_dur : float, ArrayType, Callable + Transmitter concentration duration time after being triggered. Default 1 [ms] + %s """ supported_modes = (bm.NonBatchingMode, bm.BatchingMode) diff --git a/brainpy/dyn/synapses/delay_couplings.py b/brainpy/dyn/synapses/delay_couplings.py index f317c1bea..f0c6b0e06 100644 --- a/brainpy/dyn/synapses/delay_couplings.py +++ b/brainpy/dyn/synapses/delay_couplings.py @@ -35,19 +35,20 @@ class DelayCoupling(Projection): """Delay coupling. - Parameters:: + Parameters + ---------- - delay_var: Variable + delay_var : Variable The delay variable. - var_to_output: Variable, sequence of Variable + var_to_output : Variable, sequence of Variable The target variables to output. - conn_mat: ArrayType + conn_mat : ArrayType The connection matrix. - required_shape: sequence of int + required_shape : sequence of int The required shape of `(pre, post)`. - delay_steps: int, ArrayType + delay_steps : int, ArrayType The matrix of delay time steps. Must be int. - initial_delay_data: Initializer, Callable + initial_delay_data : Initializer, Callable The initializer of the initial delay data. """ @@ -134,7 +135,8 @@ class DiffusiveCoupling(DelayCoupling): target_var += coupling - Examples:: + Examples + -------- >>> import brainpy as bp >>> from brainpy import rates @@ -144,21 +146,22 @@ class DiffusiveCoupling(DelayCoupling): >>> initial_delay_data=bp.init.Uniform(0, 0.05)) >>> net = bp.Network(areas, conn) - Parameters:: + Parameters + ---------- - coupling_var1: Variable + coupling_var1 : Variable The first coupling variable, used for delay. - coupling_var2: Variable + coupling_var2 : Variable Another coupling variable. - var_to_output: Variable, sequence of Variable + var_to_output : Variable, sequence of Variable The target variables to output. - conn_mat: ArrayType + conn_mat : ArrayType The connection matrix. - delay_steps: int, ArrayType + delay_steps : int, ArrayType The matrix of delay time steps. Must be int. - initial_delay_data: Initializer, Callable + initial_delay_data : Initializer, Callable The initializer of the initial delay data. - name: str + name : str The name of the model. """ @@ -239,19 +242,20 @@ class AdditiveCoupling(DelayCoupling): coupling = g * delayed_coupling_var target_var += coupling - Parameters:: + Parameters + ---------- - coupling_var: Variable + coupling_var : Variable The coupling variable, used for delay. - var_to_output: Variable, sequence of Variable + var_to_output : Variable, sequence of Variable The target variables to output. - conn_mat: ArrayType + conn_mat : ArrayType The connection matrix. - delay_steps: int, ArrayType + delay_steps : int, ArrayType The matrix of delay time steps. Must be int. - initial_delay_data: Initializer, Callable + initial_delay_data : Initializer, Callable The initializer of the initial delay data. - name: str + name : str The name of the model. """ diff --git a/brainpy/dynold/experimental/abstract_synapses.py b/brainpy/dynold/experimental/abstract_synapses.py index abdc05e9a..769fe7568 100644 --- a/brainpy/dynold/experimental/abstract_synapses.py +++ b/brainpy/dynold/experimental/abstract_synapses.py @@ -56,23 +56,25 @@ class Exponential(SynConnNS): where :math:`\mathrm{STP}` is used to model the short-term plasticity effect. - Parameters:: + Parameters + ---------- - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + conn : optional, ArrayType, dict of (str, ndarray), TwoEndConnector The synaptic connections. - comp_method: str + comp_method : str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `sparse`. - tau: float, ArrayType + tau : float, ArrayType The time constant of decay. [ms] - g_max: float, ArrayType, Initializer, Callable + g_max : float, ArrayType, Initializer, Callable The synaptic strength (the maximum conductance). Default is 1. - name: str + name : str The name of this synaptic projection. - method: str + method : str The numerical integration methods. - References:: + References + ---------- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. "The Synapse." Principles of Computational Modelling in Neuroscience. @@ -198,25 +200,27 @@ class DualExponential(SynConnNS): where :math:`\mathrm{STP}` is used to model the short-term plasticity effect of synapses. - Parameters:: + Parameters + ---------- - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + conn : optional, ArrayType, dict of (str, ndarray), TwoEndConnector The synaptic connections. - comp_method: str + comp_method : str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `sparse`. - tau_decay: float, ArrayArray, ndarray + tau_decay : float, ArrayArray, ndarray The time constant of the synaptic decay phase. [ms] - tau_rise: float, ArrayArray, ndarray + tau_rise : float, ArrayArray, ndarray The time constant of the synaptic rise phase. [ms] - g_max: float, ArrayType, Initializer, Callable + g_max : float, ArrayType, Initializer, Callable The synaptic strength (the maximum conductance). Default is 1. - name: str + name : str The name of this synaptic projection. - method: str + method : str The numerical integration methods. - References:: + References + ---------- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. "The Synapse." Principles of Computational Modelling in Neuroscience. @@ -362,25 +366,27 @@ class Alpha(DualExponential): >>> plt.legend() >>> plt.show() - Parameters:: + Parameters + ---------- - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + conn : optional, ArrayType, dict of (str, ndarray), TwoEndConnector The synaptic connections. - comp_method: str + comp_method : str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `sparse`. - delay_step: int, ArrayType, Initializer, Callable + delay_step : int, ArrayType, Initializer, Callable The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - tau_decay: float, ArrayType + tau_decay : float, ArrayType The time constant of the synaptic decay phase. [ms] - g_max: float, ArrayType, Initializer, Callable + g_max : float, ArrayType, Initializer, Callable The synaptic strength (the maximum conductance). Default is 1. - name: str + name : str The name of this synaptic projection. - method: str + method : str The numerical integration methods. - References:: + References + ---------- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. "The Synapse." Principles of Computational Modelling in Neuroscience. diff --git a/brainpy/dynold/experimental/others.py b/brainpy/dynold/experimental/others.py index 461a49308..045ca8498 100644 --- a/brainpy/dynold/experimental/others.py +++ b/brainpy/dynold/experimental/others.py @@ -31,13 +31,14 @@ class PoissonInput(DynamicalSystem): All neurons in the target variable receive independent realizations of Poisson spike trains. - Parameters:: + Parameters + ---------- - num_input: int + num_input : int The number of inputs. - freq: float + freq : float The frequency of each of the inputs. Must be a scalar. - weight: float + weight : float The synaptic weight. Must be a scalar. """ diff --git a/brainpy/dynold/experimental/syn_outs.py b/brainpy/dynold/experimental/syn_outs.py index 338771336..531ec1140 100644 --- a/brainpy/dynold/experimental/syn_outs.py +++ b/brainpy/dynold/experimental/syn_outs.py @@ -35,14 +35,16 @@ class COBA(SynOutNS): I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) - Parameters:: + Parameters + ---------- - E: float, ArrayType, ndarray + E : float, ArrayType, ndarray The reversal potential. - name: str + name : str The model name. - See Also:: + See Also + -------- CUBA """ @@ -64,13 +66,15 @@ class CUBA(SynOutNS): I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t) - Parameters:: + Parameters + ---------- - name: str + name : str The model name. - See Also:: + See Also + -------- COBA """ @@ -99,17 +103,18 @@ class MgBlock(SynOutNS): Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration. - Parameters:: + Parameters + ---------- - E: float, ArrayType + E : float, ArrayType The reversal potential for the synaptic current. [mV] - alpha: float, ArrayType + alpha : float, ArrayType Binding constant. Default 0.062 - beta: float, ArrayType + beta : float, ArrayType Unbinding constant. Default 3.57 - cc_Mg: float, ArrayType + cc_Mg : float, ArrayType Concentration of Magnesium ion. Default 1.2 [mM]. - name: str + name : str The model name. """ diff --git a/brainpy/dynold/experimental/syn_plasticity.py b/brainpy/dynold/experimental/syn_plasticity.py index 1d31b9dbb..abb1abd91 100644 --- a/brainpy/dynold/experimental/syn_plasticity.py +++ b/brainpy/dynold/experimental/syn_plasticity.py @@ -52,14 +52,16 @@ class STD(SynSTPNS): where :math:`U` is the fraction of resources used per action potential, :math:`\tau` is the time constant of recovery of the synaptic vesicles. - Parameters:: + Parameters + ---------- - tau: float + tau : float The time constant of recovery of the synaptic vesicles. - U: float + U : float The fraction of resources used per action potential. - See Also:: + See Also + -------- STP """ @@ -127,18 +129,20 @@ class STP(SynSTPNS): variables just before the arrival of the spike, and :math:`u^+` refers to the moment just after the spike. - Parameters:: + Parameters + ---------- - tau_f: float + tau_f : float The time constant of short-term facilitation. - tau_d: float + tau_d : float The time constant of short-term depression. - U: float + U : float The fraction of resources used per action potential. - method: str + method : str The numerical integral method. - See Also:: + See Also + -------- STD """ diff --git a/brainpy/dynold/neurons/biological_models.py b/brainpy/dynold/neurons/biological_models.py index ddaca03da..cb5f3bdfe 100644 --- a/brainpy/dynold/neurons/biological_models.py +++ b/brainpy/dynold/neurons/biological_models.py @@ -160,40 +160,42 @@ class HH(hh.HH): >>> plt.yticks([]) >>> plt.show() - Parameters:: + Parameters + ---------- - size: sequence of int, int + size : sequence of int, int The size of the neuron group. - ENa: float, ArrayType, Initializer, callable + ENa : float, ArrayType, Initializer, callable The reversal potential of sodium. Default is 50 mV. - gNa: float, ArrayType, Initializer, callable + gNa : float, ArrayType, Initializer, callable The maximum conductance of sodium channel. Default is 120 msiemens. - EK: float, ArrayType, Initializer, callable + EK : float, ArrayType, Initializer, callable The reversal potential of potassium. Default is -77 mV. - gK: float, ArrayType, Initializer, callable + gK : float, ArrayType, Initializer, callable The maximum conductance of potassium channel. Default is 36 msiemens. - EL: float, ArrayType, Initializer, callable + EL : float, ArrayType, Initializer, callable The reversal potential of learky channel. Default is -54.387 mV. - gL: float, ArrayType, Initializer, callable + gL : float, ArrayType, Initializer, callable The conductance of learky channel. Default is 0.03 msiemens. - V_th: float, ArrayType, Initializer, callable + V_th : float, ArrayType, Initializer, callable The threshold of the membrane spike. Default is 20 mV. - C: float, ArrayType, Initializer, callable + C : float, ArrayType, Initializer, callable The membrane capacitance. Default is 1 ufarad. - V_initializer: ArrayType, Initializer, callable + V_initializer : ArrayType, Initializer, callable The initializer of membrane potential. - m_initializer: ArrayType, Initializer, callable + m_initializer : ArrayType, Initializer, callable The initializer of m channel. - h_initializer: ArrayType, Initializer, callable + h_initializer : ArrayType, Initializer, callable The initializer of h channel. - n_initializer: ArrayType, Initializer, callable + n_initializer : ArrayType, Initializer, callable The initializer of n channel. - method: str + method : str The numerical integration method. - name: str + name : str The group name. - References:: + References + ---------- .. [1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of membrane current and its application to conduction and excitation @@ -298,7 +300,8 @@ class MorrisLecar(hh.MorrisLecar): V_th 10 mV The spike threshold. ============= ============== ======== ======================================================= - References:: + References + ---------- .. [4] Lecar, Harold. "Morris-lecar model." Scholarpedia 2.10 (2007): 1333. .. [5] http://www.scholarpedia.org/article/Morris-Lecar_model @@ -427,52 +430,54 @@ class PinskyRinzelModel(NeuDyn): Values for these parameters, and these function definitions, are taken from Traub et al, 1991. - Parameters:: + Parameters + ---------- - size: sequence of int, int + size : sequence of int, int The size of the neuron group. - gNa: float, ArrayType, Initializer, callable + gNa : float, ArrayType, Initializer, callable The maximum conductance of sodium channel. - gK: float, ArrayType, Initializer, callable + gK : float, ArrayType, Initializer, callable The maximum conductance of potassium delayed-rectifier channel. - gCa: float, ArrayType, Initializer, callable + gCa : float, ArrayType, Initializer, callable The maximum conductance of calcium channel. - gAHP: float, ArrayType, Initializer, callable + gAHP : float, ArrayType, Initializer, callable The maximum conductance of potassium after-hyper-polarization channel. - gC: float, ArrayType, Initializer, callable + gC : float, ArrayType, Initializer, callable The maximum conductance of calcium activated potassium channel. - gL: float, ArrayType, Initializer, callable + gL : float, ArrayType, Initializer, callable The conductance of leaky channel. - ENa: float, ArrayType, Initializer, callable + ENa : float, ArrayType, Initializer, callable The reversal potential of sodium channel. - EK: float, ArrayType, Initializer, callable + EK : float, ArrayType, Initializer, callable The reversal potential of potassium delayed-rectifier channel. - ECa: float, ArrayType, Initializer, callable + ECa : float, ArrayType, Initializer, callable The reversal potential of calcium channel. - EL: float, ArrayType, Initializer, callable + EL : float, ArrayType, Initializer, callable The reversal potential of leaky channel. - gc: float, ArrayType, Initializer, callable + gc : float, ArrayType, Initializer, callable The coupling strength between the soma and dendrite. - V_th: float, ArrayType, Initializer, callable + V_th : float, ArrayType, Initializer, callable The threshold of the membrane spike. - Cm: float, ArrayType, Initializer, callable + Cm : float, ArrayType, Initializer, callable The threshold of the membrane spike. - A: float, ArrayType, Initializer, callable + A : float, ArrayType, Initializer, callable The total cell membrane area, which is normalized to 1. - p: float, ArrayType, Initializer, callable + p : float, ArrayType, Initializer, callable The proportion of cell area taken up by the soma. - Vs_initializer: ArrayType, Initializer, callable + Vs_initializer : ArrayType, Initializer, callable The initializer of somatic membrane potential. - Vd_initializer: ArrayType, Initializer, callable + Vd_initializer : ArrayType, Initializer, callable The initializer of dendritic membrane potential. - Ca_initializer: ArrayType, Initializer, callable + Ca_initializer : ArrayType, Initializer, callable The initializer of Calcium concentration. - method: str + method : str The numerical integration method. - name: str + name : str The group name. - References:: + References + ---------- .. [7] Pinsky, Paul F., and John Rinzel. "Intrinsic and network rhythmogenesis in a reduced Traub model for CA3 neurons." @@ -766,40 +771,42 @@ class WangBuzsakiModel(hh.WangBuzsakiHH): :math:`E_{\mathrm{K}}=-90 \mathrm{mV}`. - Parameters:: + Parameters + ---------- - size: sequence of int, int + size : sequence of int, int The size of the neuron group. - ENa: float, ArrayType, Initializer, callable + ENa : float, ArrayType, Initializer, callable The reversal potential of sodium. Default is 50 mV. - gNa: float, ArrayType, Initializer, callable + gNa : float, ArrayType, Initializer, callable The maximum conductance of sodium channel. Default is 120 msiemens. - EK: float, ArrayType, Initializer, callable + EK : float, ArrayType, Initializer, callable The reversal potential of potassium. Default is -77 mV. - gK: float, ArrayType, Initializer, callable + gK : float, ArrayType, Initializer, callable The maximum conductance of potassium channel. Default is 36 msiemens. - EL: float, ArrayType, Initializer, callable + EL : float, ArrayType, Initializer, callable The reversal potential of learky channel. Default is -54.387 mV. - gL: float, ArrayType, Initializer, callable + gL : float, ArrayType, Initializer, callable The conductance of learky channel. Default is 0.03 msiemens. - V_th: float, ArrayType, Initializer, callable + V_th : float, ArrayType, Initializer, callable The threshold of the membrane spike. Default is 20 mV. - C: float, ArrayType, Initializer, callable + C : float, ArrayType, Initializer, callable The membrane capacitance. Default is 1 ufarad. - phi: float, ArrayType, Initializer, callable + phi : float, ArrayType, Initializer, callable The temperature regulator constant. - V_initializer: ArrayType, Initializer, callable + V_initializer : ArrayType, Initializer, callable The initializer of membrane potential. - h_initializer: ArrayType, Initializer, callable + h_initializer : ArrayType, Initializer, callable The initializer of h channel. - n_initializer: ArrayType, Initializer, callable + n_initializer : ArrayType, Initializer, callable The initializer of n channel. - method: str + method : str The numerical integration method. - name: str + name : str The group name. - References:: + References + ---------- .. [9] Wang, X.J. and Buzsaki, G., (1996) Gamma oscillation by synaptic inhibition in a hippocampal interneuronal network model. Journal of diff --git a/brainpy/dynold/neurons/fractional_models.py b/brainpy/dynold/neurons/fractional_models.py index b6a4cef53..1b9a47090 100644 --- a/brainpy/dynold/neurons/fractional_models.py +++ b/brainpy/dynold/neurons/fractional_models.py @@ -73,21 +73,24 @@ class FractionalFHR(FractionalNeuron): relatively fixed time of bursting duration. With the increasing of :math:`a`, the interburst intervals become shorter and periodic bursting changes to tonic spiking. - Examples:: + Examples + -------- - [(Mondal, et, al., 2019): Fractional-order FitzHugh-Rinzel bursting neuron model](https://brainpy-examples.readthedocs.io/en/latest/neurons/2019_Fractional_order_FHR_model.html) - Parameters:: + Parameters + ---------- - size: int, sequence of int + size : int, sequence of int The size of the neuron group. - alpha: float, tensor + alpha : float, tensor The fractional order. - num_memory: int + num_memory : int The total number of the short memory. - References:: + References + ---------- .. [1] Mondal, A., Sharma, S.K., Upadhyay, R.K. *et al.* Firing activities of a fractional-order FitzHugh-Rinzel bursting neuron model and its coupled dynamics. *Sci Rep* **9,** 15721 (2019). https://doi.org/10.1038/s41598-019-52061-4 """ @@ -235,12 +238,14 @@ class FractionalIzhikevich(FractionalNeuron): in mV. When the spike reaches its peak value, the membrane voltage :math:`v` and the recovery variable :math:`u` are reset according to the above condition. - Examples:: + Examples + -------- - [(Teka, et. al, 2018): Fractional-order Izhikevich neuron model](https://brainpy-examples.readthedocs.io/en/latest/neurons/2018_Fractional_Izhikevich_model.html) - References:: + References + ---------- .. [10] Teka, Wondimu W., Ranjit Kumar Upadhyay, and Argha Mondal. "Spiking and bursting patterns of fractional-order Izhikevich model." Communications diff --git a/brainpy/dynold/neurons/reduced_models.py b/brainpy/dynold/neurons/reduced_models.py index 9ce2783fa..71f229000 100644 --- a/brainpy/dynold/neurons/reduced_models.py +++ b/brainpy/dynold/neurons/reduced_models.py @@ -62,23 +62,24 @@ class LeakyIntegrator(NeuDyn): membrane potential, :math:`\tau` is the time constant, and :math:`R` is the resistance. - Parameters:: + Parameters + ---------- - size: sequence of int, int + size : sequence of int, int The size of the neuron group. - V_rest: float, ArrayType, Initializer, callable + V_rest : float, ArrayType, Initializer, callable Resting membrane potential. - R: float, ArrayType, Initializer, callable + R : float, ArrayType, Initializer, callable Membrane resistance. - tau: float, ArrayType, Initializer, callable + tau : float, ArrayType, Initializer, callable Membrane time constant. - V_initializer: ArrayType, Initializer, callable + V_initializer : ArrayType, Initializer, callable The initializer of membrane potential. - noise: ArrayType, Initializer, callable + noise : ArrayType, Initializer, callable The noise added onto the membrane potential - method: str + method : str The numerical integration method. - name: str + name : str The group name. """ @@ -176,32 +177,34 @@ class LIF(lif.LifRef): - `(Brette, Romain. 2004) LIF phase locking `_ - Parameters:: + Parameters + ---------- - size: sequence of int, int + size : sequence of int, int The size of the neuron group. - V_rest: float, ArrayType, Initializer, callable + V_rest : float, ArrayType, Initializer, callable Resting membrane potential. - V_reset: float, ArrayType, Initializer, callable + V_reset : float, ArrayType, Initializer, callable Reset potential after spike. - V_th: float, ArrayType, Initializer, callable + V_th : float, ArrayType, Initializer, callable Threshold potential of spike. - R: float, ArrayType, Initializer, callable + R : float, ArrayType, Initializer, callable Membrane resistance. - tau: float, ArrayType, Initializer, callable + tau : float, ArrayType, Initializer, callable Membrane time constant. - tau_ref: float, ArrayType, Initializer, callable + tau_ref : float, ArrayType, Initializer, callable Refractory period length.(ms) - V_initializer: ArrayType, Initializer, callable + V_initializer : ArrayType, Initializer, callable The initializer of membrane potential. - noise: ArrayType, Initializer, callable + noise : ArrayType, Initializer, callable The noise added onto the membrane potential - method: str + method : str The numerical integration method. - name: str + name : str The group name. - References:: + References + ---------- .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. @@ -1282,7 +1285,8 @@ class ALIFBellec2020(NeuDyn): a \gets a + 1 - References:: + References + ---------- .. [1] Bellec, Guillaume, et al. "A solution to the learning dilemma for recurrent networks of spiking neurons." @@ -1446,7 +1450,8 @@ class LIF_SFA_Bellec2020(NeuDyn): a \gets a + 1 - References:: + References + ---------- .. [1] Bellec, Guillaume, et al. "A solution to the learning dilemma for recurrent networks of spiking neurons." diff --git a/brainpy/dynold/synapses/abstract_models.py b/brainpy/dynold/synapses/abstract_models.py index 2d50700a8..55eabe776 100644 --- a/brainpy/dynold/synapses/abstract_models.py +++ b/brainpy/dynold/synapses/abstract_models.py @@ -75,22 +75,23 @@ class Delta(TwoEndConn): >>> plt.legend() >>> plt.show() - Parameters:: + Parameters + ---------- - pre: NeuDyn + pre : NeuDyn The pre-synaptic neuron group. - post: NeuDyn + post : NeuDyn The post-synaptic neuron group. - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + conn : optional, ArrayType, dict of (str, ndarray), TwoEndConnector The synaptic connections. - comp_method: str + comp_method : str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `sparse`. - delay_step: int, ArrayType, Initializer, Callable + delay_step : int, ArrayType, Initializer, Callable The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - g_max: float, ArrayType, Initializer, Callable + g_max : float, ArrayType, Initializer, Callable The synaptic strength. Default is 1. - post_ref_key: str + post_ref_key : str Whether the post-synaptic group has refractory period. """ @@ -218,26 +219,27 @@ class Exponential(TwoEndConn): >>> plt.legend() >>> plt.show() - Parameters:: + Parameters + ---------- - pre: NeuGroup + pre : NeuGroup The pre-synaptic neuron group. - post: NeuGroup + post : NeuGroup The post-synaptic neuron group. - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + conn : optional, ArrayType, dict of (str, ndarray), TwoEndConnector The synaptic connections. - comp_method: str + comp_method : str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `sparse`. - delay_step: int, ArrayType, Initializer, Callable + delay_step : int, ArrayType, Initializer, Callable The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - tau: float, ArrayType + tau : float, ArrayType The time constant of decay. [ms] - g_max: float, ArrayType, Initializer, Callable + g_max : float, ArrayType, Initializer, Callable The synaptic strength (the maximum conductance). Default is 1. - name: str + name : str The name of this synaptic projection. - method: str + method : str The numerical integration methods. """ @@ -356,28 +358,29 @@ class DualExponential(_TwoEndConnAlignPre): >>> plt.legend() >>> plt.show() - Parameters:: + Parameters + ---------- - pre: NeuDyn + pre : NeuDyn The pre-synaptic neuron group. - post: NeuDyn + post : NeuDyn The post-synaptic neuron group. - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + conn : optional, ArrayType, dict of (str, ndarray), TwoEndConnector The synaptic connections. - comp_method: str + comp_method : str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `sparse`. - delay_step: int, ArrayType, Initializer, Callable + delay_step : int, ArrayType, Initializer, Callable The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - tau_decay: float, ArrayArray, ndarray + tau_decay : float, ArrayArray, ndarray The time constant of the synaptic decay phase. [ms] - tau_rise: float, ArrayArray, ndarray + tau_rise : float, ArrayArray, ndarray The time constant of the synaptic rise phase. [ms] - g_max: float, ArrayType, Initializer, Callable + g_max : float, ArrayType, Initializer, Callable The synaptic strength (the maximum conductance). Default is 1. - name: str + name : str The name of this synaptic projection. - method: str + method : str The numerical integration methods. """ @@ -477,26 +480,27 @@ class Alpha(_TwoEndConnAlignPre): >>> plt.legend() >>> plt.show() - Parameters:: + Parameters + ---------- - pre: NeuDyn + pre : NeuDyn The pre-synaptic neuron group. - post: NeuDyn + post : NeuDyn The post-synaptic neuron group. - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + conn : optional, ArrayType, dict of (str, ndarray), TwoEndConnector The synaptic connections. - comp_method: str + comp_method : str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `sparse`. - delay_step: int, ArrayType, Initializer, Callable + delay_step : int, ArrayType, Initializer, Callable The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - tau_decay: float, ArrayType + tau_decay : float, ArrayType The time constant of the synaptic decay phase. [ms] - g_max: float, ArrayType, Initializer, Callable + g_max : float, ArrayType, Initializer, Callable The synaptic strength (the maximum conductance). Default is 1. - name: str + name : str The name of this synaptic projection. - method: str + method : str The numerical integration methods. """ @@ -639,33 +643,35 @@ class NMDA(_TwoEndConnAlignPre): >>> plt.legend() >>> plt.show() - Parameters:: + Parameters + ---------- - pre: NeuDyn + pre : NeuDyn The pre-synaptic neuron group. - post: NeuDyn + post : NeuDyn The post-synaptic neuron group. - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + conn : optional, ArrayType, dict of (str, ndarray), TwoEndConnector The synaptic connections. - comp_method: str + comp_method : str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `dense`. - delay_step: int, ArrayType, Initializer, Callable + delay_step : int, ArrayType, Initializer, Callable The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - g_max: float, ArrayType, Initializer, Callable + g_max : float, ArrayType, Initializer, Callable The synaptic strength (the maximum conductance). Default is 1. - tau_decay: float, ArrayType + tau_decay : float, ArrayType The time constant of the synaptic decay phase. Default 100 [ms] - tau_rise: float, ArrayType + tau_rise : float, ArrayType The time constant of the synaptic rise phase. Default 2 [ms] - a: float, ArrayType + a : float, ArrayType Default 0.5 ms^-1. - name: str + name : str The name of this synaptic projection. - method: str + method : str The numerical integration methods. - References:: + References + ---------- .. [1] Brunel N, Wang X J. Effects of neuromodulation in a cortical network model of object working memory dominated diff --git a/brainpy/dynold/synapses/base.py b/brainpy/dynold/synapses/base.py index eaf3ad6af..fd7b557ac 100644 --- a/brainpy/dynold/synapses/base.py +++ b/brainpy/dynold/synapses/base.py @@ -138,7 +138,8 @@ def clone(self): class TwoEndConn(SynConn): """Base class to model synaptic connections. - Parameters:: + Parameters + ---------- pre : NeuGroup Pre-synaptic neuron group. @@ -146,25 +147,25 @@ class TwoEndConn(SynConn): Post-synaptic neuron group. conn : optional, ndarray, ArrayType, dict, TwoEndConnector The connection method between pre- and post-synaptic groups. - output: Optional, SynOutput + output : Optional, SynOutput The output for the synaptic current. .. versionadded:: 2.1.13 The output component for a two-end connection model. - stp: Optional, SynSTP + stp : Optional, SynSTP The short-term plasticity model for the synaptic variables. .. versionadded:: 2.1.13 The short-term plasticity component for a two-end connection model. - ltp: Optional, SynLTP + ltp : Optional, SynLTP The long-term plasticity model for the synaptic variables. .. versionadded:: 2.1.13 The long-term plasticity component for a two-end connection model. - name: Optional, str + name : Optional, str The name of the dynamic system. """ diff --git a/brainpy/dynold/synapses/biological_models.py b/brainpy/dynold/synapses/biological_models.py index c6a3ff18d..4b1fdc9ff 100644 --- a/brainpy/dynold/synapses/biological_models.py +++ b/brainpy/dynold/synapses/biological_models.py @@ -115,36 +115,38 @@ class GABAa(AMPA): - `Gamma oscillation network model `_ - Parameters:: + Parameters + ---------- - pre: NeuDyn + pre : NeuDyn The pre-synaptic neuron group. - post: NeuDyn + post : NeuDyn The post-synaptic neuron group. - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + conn : optional, ArrayType, dict of (str, ndarray), TwoEndConnector The synaptic connections. - comp_method: str + comp_method : str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `dense`. - delay_step: int, ArrayType, Callable + delay_step : int, ArrayType, Callable The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - g_max: float, ArrayType, Callable + g_max : float, ArrayType, Callable The synaptic strength (the maximum conductance). Default is 1. - alpha: float, ArrayType + alpha : float, ArrayType Binding constant. Default 0.062 - beta: float, ArrayType + beta : float, ArrayType Unbinding constant. Default 3.57 - T: float, ArrayType + T : float, ArrayType Transmitter concentration when synapse is triggered by a pre-synaptic spike.. Default 1 [mM]. - T_duration: float, ArrayType + T_duration : float, ArrayType Transmitter concentration duration time after being triggered. Default 1 [ms] - name: str + name : str The name of this synaptic projection. - method: str + method : str The numerical integration methods. - References:: + References + ---------- .. [1] Destexhe, Alain, and Denis Paré. "Impact of network activity on the integrative properties of neocortical pyramidal neurons @@ -266,35 +268,37 @@ class BioNMDA(_TwoEndConnAlignPre): >>> plt.legend() >>> plt.show() - Parameters:: + Parameters + ---------- - pre: NeuDyn + pre : NeuDyn The pre-synaptic neuron group. - post: NeuDyn + post : NeuDyn The post-synaptic neuron group. - conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + conn : optional, ArrayType, dict of (str, ndarray), TwoEndConnector The synaptic connections. - comp_method: str + comp_method : str The connection type used for model speed optimization. It can be `sparse` and `dense`. The default is `dense`. - delay_step: int, ArrayType, Callable + delay_step : int, ArrayType, Callable The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. - g_max: float, ArrayType, Callable + g_max : float, ArrayType, Callable The synaptic strength (the maximum conductance). Default is 1. - alpha1: float, ArrayType + alpha1 : float, ArrayType The conversion rate of g from inactive to active. Default 2 ms^-1. - beta1: float, ArrayType + beta1 : float, ArrayType The conversion rate of g from active to inactive. Default 0.01 ms^-1. - alpha2: float, ArrayType + alpha2 : float, ArrayType The conversion rate of x from inactive to active. Default 1 ms^-1. - beta2: float, ArrayType + beta2 : float, ArrayType The conversion rate of x from active to inactive. Default 0.5 ms^-1. - name: str + name : str The name of this synaptic projection. - method: str + method : str The numerical integration methods. - References:: + References + ---------- .. [1] Devaney A J . Mathematical Foundations of Neuroscience[M]. Springer New York, 2010: 162. diff --git a/brainpy/dynold/synouts/conductances.py b/brainpy/dynold/synouts/conductances.py index 84f2b94d8..71685d5df 100644 --- a/brainpy/dynold/synouts/conductances.py +++ b/brainpy/dynold/synouts/conductances.py @@ -35,13 +35,15 @@ class CUBA(_SynOut): I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t) - Parameters:: + Parameters + ---------- - name: str + name : str The model name. - See Also:: + See Also + -------- COBA """ @@ -67,14 +69,16 @@ class COBA(_SynOut): I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) - Parameters:: + Parameters + ---------- - E: float, ArrayType, ndarray, callable, Initializer + E : float, ArrayType, ndarray, callable, Initializer The reversal potential. - name: str + name : str The model name. - See Also:: + See Also + -------- CUBA """ diff --git a/brainpy/dynold/synouts/ions.py b/brainpy/dynold/synouts/ions.py index 29b8bdedf..3b83597b2 100644 --- a/brainpy/dynold/synouts/ions.py +++ b/brainpy/dynold/synouts/ions.py @@ -44,17 +44,18 @@ class MgBlock(_SynOut): Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration. - Parameters:: + Parameters + ---------- - E: float, ArrayType, callable, Initializer + E : float, ArrayType, callable, Initializer The reversal potential for the synaptic current. [mV] - alpha: float, ArrayType + alpha : float, ArrayType Binding constant. Default 0.062 - beta: float, ArrayType, callable, Initializer + beta : float, ArrayType, callable, Initializer Unbinding constant. Default 3.57 - cc_Mg: float, ArrayType, callable, Initializer + cc_Mg : float, ArrayType, callable, Initializer Concentration of Magnesium ion. Default 1.2 [mM]. - name: str + name : str The model name. """ diff --git a/brainpy/dynold/synplast/short_term_plasticity.py b/brainpy/dynold/synplast/short_term_plasticity.py index 61c0dba70..b61d6fbfc 100644 --- a/brainpy/dynold/synplast/short_term_plasticity.py +++ b/brainpy/dynold/synplast/short_term_plasticity.py @@ -52,14 +52,16 @@ class STD(_SynSTP): where :math:`U` is the fraction of resources used per action potential, :math:`\tau` is the time constant of recovery of the synaptic vesicles. - Parameters:: + Parameters + ---------- - tau: float + tau : float The time constant of recovery of the synaptic vesicles. - U: float + U : float The fraction of resources used per action potential. - See Also:: + See Also + -------- STP """ @@ -137,18 +139,20 @@ class STP(_SynSTP): variables just before the arrival of the spike, and :math:`u^+` refers to the moment just after the spike. - Parameters:: + Parameters + ---------- - tau_f: float + tau_f : float The time constant of short-term facilitation. - tau_d: float + tau_d : float The time constant of short-term depression. - U: float + U : float The fraction of resources used per action potential. - method: str + method : str The numerical integral method. - See Also:: + See Also + -------- STD """ diff --git a/brainpy/dynsys.py b/brainpy/dynsys.py index a27e98922..abe411083 100644 --- a/brainpy/dynsys.py +++ b/brainpy/dynsys.py @@ -60,14 +60,21 @@ def register_delay( ): """Register delay variable. - Args: - identifier: str. The delay access name. - delay_target: The target variable for delay. - delay_step: The delay time step. - initial_delay_data: The initializer for the delay data. - - Returns: - delay_pos: The position of the delay. + Parameters + ---------- + identifier : str + The delay access name. + delay_target : bm.Variable + The target variable for delay. + delay_step : Optional[Union[int, ArrayType, Callable]] + The delay time step. + initial_delay_data : Union[Callable, ArrayType, numbers.Number] + The initializer for the delay data. + + Returns + ------- + delay_pos + The position of the delay. """ _delay_identifier, _init_delay_by_return = _get_delay_tool() assert isinstance(self, DynamicalSystem), f'self must be an instance of {DynamicalSystem.__name__}' @@ -87,19 +94,19 @@ def get_delay_data( ): """Get delay data according to the provided delay steps. - Parameters:: - - identifier: str - The delay variable name. - delay_pos: str - The delay length. - indices: optional, int, slice, ArrayType - The indices of the delay. - - Returns:: - - delay_data: ArrayType - The delay data at the given time. + Parameters + ---------- + identifier : str + The delay variable name. + delay_pos : str + The delay length. + indices : optional, int, slice, ArrayType + The indices of the delay. + + Returns + ------- + delay_data : ArrayType + The delay data at the given time. """ _delay_identifier, _init_delay_by_return = _get_delay_tool() _delay_identifier = _delay_identifier + identifier @@ -113,10 +120,10 @@ def update_local_delays(self, nodes: Union[Sequence, Dict] = None): For example, in a network model, - Parameters:: - - nodes: sequence, dict - The nodes to update their delay variables. + Parameters + ---------- + nodes : sequence, dict + The nodes to update their delay variables. """ warnings.warn('.update_local_delays() has been removed since brainpy>=2.4.6', DeprecationWarning) @@ -124,10 +131,10 @@ def update_local_delays(self, nodes: Union[Sequence, Dict] = None): def reset_local_delays(self, nodes: Union[Sequence, Dict] = None): """Reset local delay variables. - Parameters:: - - nodes: sequence, dict - The nodes to Reset their delay variables. + Parameters + ---------- + nodes : sequence, dict + The nodes to Reset their delay variables. """ warnings.warn('.reset_local_delays() has been removed since brainpy>=2.4.6', DeprecationWarning) @@ -167,12 +174,12 @@ class DynamicalSystem(bm.BrainPyObject, DelayRegister, SupportInputProj): - ``.update_local_delays()`` - ``.reset_local_delays()`` - Parameters:: - + Parameters + ---------- name : optional, str - The name of the dynamical system. - mode: optional, Mode - The model computation mode. It should be an instance of :py:class:`~.Mode`. + The name of the dynamical system. + mode : optional, Mode + The model computation mode. It should be an instance of :py:class:`~.Mode`. """ supported_modes: Optional[Sequence[bm.Mode]] = None @@ -303,13 +310,19 @@ def step_run(self, i, *args, **kwargs): This function can be directly applied to run the dynamical system. Particularly, ``i`` denotes the running index. - Args: - i: The current running index. - *args: The arguments of ``update()`` function. - **kwargs: The arguments of ``update()`` function. - - Returns: - out: The update function returns. + Parameters + ---------- + i + The current running index. + *args + The arguments of ``update()`` function. + **kwargs + The arguments of ``update()`` function. + + Returns + ------- + out + The update function returns. """ global clear_input if clear_input is None: @@ -323,13 +336,19 @@ def step_run(self, i, *args, **kwargs): def jit_step_run(self, i, *args, **kwargs): """The jitted step function for running. - Args: - i: The current running index. - *args: The arguments of ``update()`` function. - **kwargs: The arguments of ``update()`` function. - - Returns: - out: The update function returns. + Parameters + ---------- + i + The current running index. + *args + The arguments of ``update()`` function. + **kwargs + The arguments of ``update()`` function. + + Returns + ------- + out + The update function returns. """ return self.step_run(i, *args, **kwargs) @@ -354,11 +373,16 @@ def register_local_delay( ): """Register local relay at the given delay time. - Args: - var_name: str. The name of the delay target variable. - delay_name: str. The name of the current delay data. - delay_time: The delay time. Float. - delay_step: The delay step. Int. ``delay_step`` and ``delay_time`` are exclusive. ``delay_step = delay_time / dt``. + Parameters + ---------- + var_name : str + The name of the delay target variable. + delay_name : str + The name of the current delay data. + delay_time + The delay time. Float. + delay_step + The delay step. Int. ``delay_step`` and ``delay_time`` are exclusive. ``delay_step = delay_time / dt``. """ delay_identifier, init_delay_by_return = _get_delay_tool() delay_identifier = delay_identifier + var_name @@ -379,12 +403,16 @@ def register_local_delay( def get_local_delay(self, var_name, delay_name): """Get the delay at the given identifier (`name`). - Args: - var_name: The name of the target delay variable. - delay_name: The identifier of the delay. + Parameters + ---------- + var_name + The name of the target delay variable. + delay_name + The identifier of the delay. - Returns: - The delayed data at the given delay position. + Returns + ------- + The delayed data at the given delay position. """ delay_identifier, init_delay_by_return = _get_delay_tool() delay_identifier = delay_identifier + var_name @@ -536,12 +564,15 @@ def __call__(self, *args, **kwargs): def __rrshift__(self, other): """Support using right shift operator to call modules. - Examples:: + Examples + -------- - >>> import brainpy as bp - >>> x = bp.math.random.rand((10, 10)) - >>> l = bp.layers.Activation(bm.tanh) - >>> y = x >> l + .. code-block:: python + + >>> import brainpy as bp + >>> x = bp.math.random.rand((10, 10)) + >>> l = bp.layers.Activation(bm.tanh) + >>> y = x >> l """ return self.__call__(other) @@ -549,12 +580,18 @@ def __rrshift__(self, other): class DynSysGroup(DynamicalSystem, Container): """A group of :py:class:`~.DynamicalSystem`s in which the updating order does not matter. - Args: - children_as_tuple: The children objects. - children_as_dict: The children objects. - name: The object name. - mode: The mode which controls the model computation. - child_type: The type of the children object. Default is :py:class:`DynamicalSystem`. + Parameters + ---------- + children_as_tuple + The children objects. + children_as_dict + The children objects. + name + The object name. + mode + The mode which controls the model computation. + child_type + The type of the children object. Default is :py:class:`DynamicalSystem`. """ def __init__( @@ -619,30 +656,37 @@ class Sequential(DynamicalSystem, SupportAutoDelay, Container): On the other hand, the layers in a ``Sequential`` are connected in a cascading way. - Examples:: + Parameters + ---------- + modules_as_tuple + The children modules. + modules_as_dict + The children modules. + name + The object name. + mode + The object computing context/mode. Default is ``None``. - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> - >>> # composing ANN models - >>> l = bp.Sequential(bp.layers.Dense(100, 10), - >>> bm.relu, - >>> bp.layers.Dense(10, 2)) - >>> l(bm.random.random((256, 100))) - >>> - >>> # Using Sequential with Dict. This is functionally the - >>> # same as the above code - >>> l = bp.Sequential(l1=bp.layers.Dense(100, 10), - >>> l2=bm.relu, - >>> l3=bp.layers.Dense(10, 2)) - >>> l(bm.random.random((256, 100))) - - - Args: - modules_as_tuple: The children modules. - modules_as_dict: The children modules. - name: The object name. - mode: The object computing context/mode. Default is ``None``. + Examples + -------- + + .. code-block:: python + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> + >>> # composing ANN models + >>> l = bp.Sequential(bp.layers.Dense(100, 10), + >>> bm.relu, + >>> bp.layers.Dense(10, 2)) + >>> l(bm.random.random((256, 100))) + >>> + >>> # Using Sequential with Dict. This is functionally the + >>> # same as the above code + >>> l = bp.Sequential(l1=bp.layers.Dense(100, 10), + >>> l2=bm.relu, + >>> l3=bp.layers.Dense(10, 2)) + >>> l(bm.random.random((256, 100))) """ def __init__( @@ -696,9 +740,12 @@ def __repr__(self): class Projection(DynamicalSystem): """Base class to model synaptic projections. - Args: - name: The name of the dynamic system. - mode: The computing mode. It should be an instance of :py:class:`~.Mode`. + Parameters + ---------- + name + The name of the dynamic system. + mode + The computing mode. It should be an instance of :py:class:`~.Mode`. """ def update(self, *args, **kwargs): @@ -728,11 +775,16 @@ class Dynamic(DynamicalSystem): - ``num``: the flattened number of neurons in the group. For example, `size=(10, )` => \ `num=10`, `size=(10, 10)` => `num=100`, `size=(10, 15, 4)` => `num=600`. - Args: - size: The neuron group geometry. - name: The name of the dynamic system. - keep_size: Whether keep the geometry information. - mode: The computing mode. + Parameters + ---------- + size + The neuron group geometry. + name + The name of the dynamic system. + keep_size + Whether keep the geometry information. + mode + The computing mode. """ def __init__( diff --git a/brainpy/encoding/stateful_encoding.py b/brainpy/encoding/stateful_encoding.py index 4d3bebb4a..4cd4b0c14 100644 --- a/brainpy/encoding/stateful_encoding.py +++ b/brainpy/encoding/stateful_encoding.py @@ -36,19 +36,19 @@ class WeightedPhaseEncoder(Encoder): more information into the spikes. This is the major difference from a conventional rate coding scheme that assigns the same weight to every spike [1]_. - Parameters:: - - min_val: float - The minimal value in the given data `x`, used to the data normalization. - max_val: float - The maximum value in the given data `x`, used to the data normalization. - num_phase: int - The number of the encoding period. - weight_fun: Callable - The function to generate weight at the phase :math:`i`. - - References:: - + Parameters + ---------- + min_val : float + The minimal value in the given data `x`, used to the data normalization. + max_val : float + The maximum value in the given data `x`, used to the data normalization. + num_phase : int + The number of the encoding period. + weight_fun : Callable + The function to generate weight at the phase :math:`i`. + + References + ---------- .. [1] Kim, Jaehyun et al. “Deep neural networks with weighted spikes.” Neurocomputing 311 (2018): 373-386. """ @@ -69,17 +69,17 @@ def __init__(self, def __call__(self, x: ArrayType, num_step: int): """Encoding function. - Parameters:: - - x: ArrayType - The input rate value. - num_step: int - The number of time steps. - - Returns:: - - out: ArrayType - The encoded spike train. + Parameters + ---------- + x : ArrayType + The input rate value. + num_step : int + The number of time steps. + + Returns + ------- + out : ArrayType + The encoded spike train. """ # normalize all input signals to fit into the range [1, 1-2^K] x = (x - self.min_val) * self.scale @@ -108,38 +108,48 @@ class LatencyEncoder(Encoder): A larger ``x`` will cause the earlier firing time. - Example:: - - >>> a = bm.array([0.02, 0.5, 1]) - >>> encoder = LatencyEncoder(method='linear', normalize=True) - >>> encoder.multi_steps(a, n_time=5) - Array([[0., 0., 1.], - [0., 0., 0.], - [0., 1., 0.], - [0., 0., 0.], - [1., 0., 0.]]) - - - Args: - min_val: float. The minimal value in the given data `x`, used to the data normalization. - max_val: float. The maximum value in the given data `x`, used to the data normalization. - method: str. How to convert intensity to firing time. Currently, we support `linear` or `log`. + Parameters + ---------- + min_val : float + The minimal value in the given data `x`, used to the data normalization. + max_val : float + The maximum value in the given data `x`, used to the data normalization. + method : str + How to convert intensity to firing time. Currently, we support `linear` or `log`. - If ``method='linear'``, the firing rate is calculated as :math:`t_f(x) = (\text{num_period} - 1)(1 - x)`. - If ``method='log'``, the firing rate is calculated as :math:`t_f(x) = (\text{num_period} - 1) - ln(\alpha * x + 1)`, where :math:`\alpha` satisfies :math:`t_f(1) = \text{num_period} - 1`. - threshold: float. Input features below the threhold will fire at the + threshold : float + Input features below the threhold will fire at the final time step unless ``clip=True`` in which case they will not fire at all, defaults to ``0.01``. - clip: bool. Option to remove spikes from features that fall - below the threshold, defaults to ``False``. - tau: float. RC Time constant for LIF model used to calculate + clip : bool + Option to remove spikes from features that fall + below the threshold, defaults to ``False``. + tau : float + RC Time constant for LIF model used to calculate firing time, defaults to ``1``. - normalize: bool. Option to normalize the latency code such that + normalize : bool + Option to normalize the latency code such that the final spike(s) occur within num_steps, defaults to ``False``. - epsilon: float. A tiny positive value to avoid rounding errors when + epsilon : float + A tiny positive value to avoid rounding errors when using torch.arange, defaults to ``1e-7``. + + Examples + -------- + .. code-block:: python + + >>> a = bm.array([0.02, 0.5, 1]) + >>> encoder = LatencyEncoder(method='linear', normalize=True) + >>> encoder.multi_steps(a, n_time=5) + Array([[0., 0., 1.], + [0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.], + [1., 0., 0.]]) """ def __init__( @@ -179,12 +189,17 @@ def multi_steps(self, data, n_time: Optional[float] = None): Ensuring x in [0., 1.]. - Args: - data: The rate-based input. - n_time: float. The total time to generate data. If None, use ``tau`` instead. - - Returns: - out: array. The output spiking trains. + Parameters + ---------- + data + The rate-based input. + n_time : float + The total time to generate data. If None, use ``tau`` instead. + + Returns + ------- + out : array + The output spiking trains. """ if n_time is None: n_time = self.tau diff --git a/brainpy/encoding/stateless_encoding.py b/brainpy/encoding/stateless_encoding.py index d358b2c52..54de1b383 100644 --- a/brainpy/encoding/stateless_encoding.py +++ b/brainpy/encoding/stateless_encoding.py @@ -36,7 +36,22 @@ class PoissonEncoder(Encoder): spikes whose firing probability is :math:`x_{\text{normalize}}`. - Examples:: + Parameters + ---------- + min_val : float + The minimal value in the given data `x`, used to the data normalization. + max_val : float + The maximum value in the given data `x`, used to the data normalization. + gain : float + Scale input features by the gain, defaults to ``1``. + offset : float + Shift input features by the offset, defaults to ``0``. + first_spk_time : float + The time to first spike, defaults to ``0``. + + Examples + -------- + .. code-block:: python import brainpy as bp import brainpy.math as bm @@ -51,14 +66,6 @@ class PoissonEncoder(Encoder): # or, encode the image at multiple times once spikes = encoder.multi_steps(img, n_time=10.) - - - Args: - min_val: float. The minimal value in the given data `x`, used to the data normalization. - max_val: float. The maximum value in the given data `x`, used to the data normalization. - gain: float. Scale input features by the gain, defaults to ``1``. - offset: float. Shift input features by the offset, defaults to ``0``. - first_spk_time: float. The time to first spike, defaults to ``0``. """ def __init__( @@ -81,12 +88,17 @@ def __init__( def single_step(self, x, i_step: int = None): """Generate spikes at the single step according to the inputs. - Args: - x: Array. The rate input. - i_step: int. The time step to generate spikes. - - Returns: - out: Array. The encoded spike train. + Parameters + ---------- + x : Array + The rate input. + i_step : int + The time step to generate spikes. + + Returns + ------- + out : Array + The encoded spike train. """ # Draw a single Bernoulli sample for one step. (Delegating to # ``multi_steps`` with ``n_time=None`` would crash on ``int(None / dt)``, @@ -108,17 +120,22 @@ def _normalize(self, x): def multi_steps(self, x, n_time: Optional[float]): """Generate spikes at multiple steps according to the inputs. - Args: - x: Array. The rate input. - n_time: float. Encode rate values as spike trains in the given time length. + Parameters + ---------- + x : Array + The rate input. + n_time : float + Encode rate values as spike trains in the given time length. ``n_time`` is converted into the ``n_step`` according to `n_step = int(n_time / brainpy.math.dt)`. - If ``n_time=None``, encode the rate values at the current time step. Users should repeatedly call it to encode `x` as a spike train. - Else, given the ``x`` with shape ``(S, ...)``, the encoded spike train is the array with shape ``(n_step, S, ...)``. - Returns: - out: Array. The encoded spike train. + Returns + ------- + out : Array + The encoded spike train. """ # ``n_time=None`` means "encode the current single step" (see docstring); # only convert to a step count when an actual duration is given. @@ -144,7 +161,24 @@ class DiffEncoder(Encoder): Optionally include `off_spikes` for negative changes. - Example:: + Parameters + ---------- + threshold : float + Input features with a change greater than the thresold + across one timestep will generate a spike, defaults to ``0.1``. + padding : bool + Used to change how the first time step of spikes are + measured. If ``True``, the first time step will be repeated with itself + resulting in ``0``'s for the output spikes. + If ``False``, the first time step will be padded with ``0``'s, defaults + to ``False``. + off_spike : bool + If ``True``, negative spikes for changes less than + ``-threshold``, defaults to ``False``. + + Examples + -------- + .. code-block:: python >>> a = bm.array([1, 2, 2.9, 3, 3.9]) >>> encoder = DiffEncoder(threshold=1) @@ -163,17 +197,6 @@ class DiffEncoder(Encoder): >>> encoder = DiffEncoder(threshold=1, padding=True, off_spike=True) >>> encoder.multi_steps(b) Array([ 0., 1., -1., 1., 0.]) - - Args: - threshold: float. Input features with a change greater than the thresold - across one timestep will generate a spike, defaults to ``0.1``. - padding: bool. Used to change how the first time step of spikes are - measured. If ``True``, the first time step will be repeated with itself - resulting in ``0``'s for the output spikes. - If ``False``, the first time step will be padded with ``0``'s, defaults - to ``False``. - off_spike: bool. If ``True``, negative spikes for changes less than - ``-threshold``, defaults to ``False``. """ def __init__( @@ -194,11 +217,15 @@ def single_step(self, *args, **kwargs): def multi_steps(self, x): """Encoding multistep inputs with the spiking trains. - Args: - x: Array. The array with the shape of `(num_step, ....)`. + Parameters + ---------- + x : Array + The array with the shape of `(num_step, ....)`. - Returns: - out: Array. The spike train. + Returns + ------- + out : Array + The spike train. """ if self.padding: diff = bm.diff(x, axis=0, prepend=x[:1]) diff --git a/brainpy/helpers.py b/brainpy/helpers.py index 5a7f7c623..7c6a35d50 100644 --- a/brainpy/helpers.py +++ b/brainpy/helpers.py @@ -60,8 +60,10 @@ def reset_state(target: DynamicalSystem, *args, **kwargs): See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details. - Args: - target: The target DynamicalSystem. + Parameters + ---------- + target + The target DynamicalSystem. """ dynsys.the_top_layer_reset_state = False @@ -89,8 +91,10 @@ def reset_state(target: DynamicalSystem, *args, **kwargs): def clear_input(target: DynamicalSystem, *args, **kwargs): """Clear all inputs in the given target. - Args: - target:The target DynamicalSystem. + Parameters + ---------- + target + The target DynamicalSystem. """ for node in target.nodes().subset(DynamicalSystem).not_subset(DynView).unique().values(): @@ -101,16 +105,19 @@ def load_state(target: DynamicalSystem, state_dict: Dict, **kwargs): """Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. - Args: - target: DynamicalSystem. The dynamical system to load its states. - state_dict: dict. A dict containing parameters and persistent buffers. + Parameters + ---------- + target : DynamicalSystem + The dynamical system to load its states. + state_dict : dict + A dict containing parameters and persistent buffers. - Returns: + Returns ------- - ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: - * **missing_keys** is a list of str containing the missing keys - * **unexpected_keys** is a list of str containing the unexpected keys + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys """ nodes = target.nodes().subset(DynamicalSystem).not_subset(DynView).unique() missing_keys = [] @@ -130,11 +137,14 @@ def load_state(target: DynamicalSystem, state_dict: Dict, **kwargs): def save_state(target: DynamicalSystem, **kwargs) -> Dict: """Save all states in the ``target`` as a dictionary for later disk serialization. - Args: - target: DynamicalSystem. The node to save its states. + Parameters + ---------- + target : DynamicalSystem + The node to save its states. - Returns: - Dict. The state dict for serialization. + Returns + ------- + Dict. The state dict for serialization. """ nodes = target.nodes().subset(DynamicalSystem).not_subset(DynView).unique() # retrieve all nodes return {key: node.save_state(**kwargs) for key, node in nodes.items()} diff --git a/brainpy/initialize/decay_inits.py b/brainpy/initialize/decay_inits.py index 2bd26407d..3762f294e 100644 --- a/brainpy/initialize/decay_inits.py +++ b/brainpy/initialize/decay_inits.py @@ -57,7 +57,8 @@ class GaussianDecay(_IntraLayerInitializer): where :math:`v_k^i` is the $i$-th neuron's encoded value at dimension $k$. - Parameters:: + Parameters + ---------- sigma : float Width of the Gaussian function. @@ -98,7 +99,8 @@ def __init__(self, sigma, max_w, min_w=None, encoding_values=None, def __call__(self, shape, dtype=None): """Build the weights. - Parameters:: + Parameters + ---------- shape : tuple of int, list of int, int The network shape. Note, this is not the weight shape. @@ -230,7 +232,8 @@ class DOGDecay(_IntraLayerInitializer): where weights smaller than :math:`0.005 * max(w_{max}, w_{min})` are not created and self-connections are avoided by default (parameter allow_self_connections). - Parameters:: + Parameters + ---------- sigmas : tuple Widths of the positive and negative Gaussian functions. @@ -271,7 +274,8 @@ def __init__(self, sigmas, max_ws, min_w=None, encoding_values=None, def __call__(self, shape, dtype=None): """Build the weights. - Parameters:: + Parameters + ---------- shape : tuple of int, list of int, int The network shape. Note, this is not the weight shape. diff --git a/brainpy/initialize/generic.py b/brainpy/initialize/generic.py index e73182c19..6c8919e67 100644 --- a/brainpy/initialize/generic.py +++ b/brainpy/initialize/generic.py @@ -56,29 +56,32 @@ def parameter( ): """Initialize parameters. - Parameters:: + Parameters + ---------- - param: callable, Initializer, bm.ndarray, jnp.ndarray, onp.ndarray, float, int, bool + param : callable, Initializer, bm.ndarray, jnp.ndarray, onp.ndarray, float, int, bool The initialization of the parameter. - If it is None, the created parameter will be None. - If it is a callable function :math:`f`, the ``f(size)`` will be returned. - If it is an instance of :py:class:`brainpy.init.Initializer``, the ``f(size)`` will be returned. - If it is a tensor, then this function check whether ``tensor.shape`` is equal to the given ``size``. - sizes: int, sequence of int + sizes : int, sequence of int The shape of the parameter. - allow_none: bool + allow_none : bool Whether allow the parameter is None. - allow_scalar: bool + allow_scalar : bool Whether allow the parameter is a scalar value. - sharding: Sharding + sharding : Sharding The axes for automatic array sharding. - Returns:: + Returns + ------- - param: ArrayType, float, int, bool, None + param : ArrayType, float, int, bool, None The initialized parameter. - See Also:: + See Also + -------- variable_, noise, delay """ @@ -125,7 +128,8 @@ def variable_( ): """Initialize a :math:`~.Variable` from a callable function or a data. - See Also:: + See Also + -------- variable @@ -148,30 +152,33 @@ def variable( ): """Initialize variables. - Parameters:: + Parameters + ---------- - init: callable, ArrayType + init : callable, ArrayType The data to be initialized as a ``Variable``. - batch_or_mode: int, bool, Mode, optional + batch_or_mode : int, bool, Mode, optional The batch size, mode ``Mode``, boolean state. This is used to specify the batch size of this variable. If it is a boolean or an instance of ``Mode``, the batch size will be 1. If it is None, the variable has no batch axis. - sizes: Shape + sizes : Shape The shape of the variable. - batch_axis: int + batch_axis : int The batch axis. - axis_names: sequence of str + axis_names : sequence of str The name for each axis. These names should match the given ``axes``. - batch_axis_name: str + batch_axis_name : str The name for the batch axis. The name will be used if ``batch_size_or_mode`` is given. - Returns:: + Returns + ------- - variable: bm.Variable + variable : bm.Variable The target ``Variable`` instance. - See Also:: + See Also + -------- variable_, parameter, noise, delay @@ -233,22 +240,25 @@ def noise( ) -> Optional[Callable]: """Initialize a noise function. - Parameters:: + Parameters + ---------- - noises: Any - size: Shape + noises : Any + size : Shape The size of the noise. - num_vars: int + num_vars : int The number of variables. - noise_idx: int + noise_idx : int The index of the current noise among all noise variables. - Returns:: + Returns + ------- - noise_func: function, None + noise_func : function, None The noise function. - See Also:: + See Also + -------- variable_, parameter, delay @@ -273,21 +283,24 @@ def delay( ): """Initialize delay variable. - Parameters:: + Parameters + ---------- - delay_step: int, ndarray, ArrayType + delay_step : int, ndarray, ArrayType The number of delay steps. It can an integer of an array of integers. - delay_target: ndarray, ArrayType + delay_target : ndarray, ArrayType The target variable to delay. - delay_data: optional, ndarray, ArrayType + delay_data : optional, ndarray, ArrayType The initial delay data. - Returns:: + Returns + ------- - info: tuple + info : tuple The triple of delay type, delay steps, and delay variable. - See Also:: + See Also + -------- variable_, parameter, noise """ diff --git a/brainpy/initialize/random_inits.py b/brainpy/initialize/random_inits.py index 19059fee5..a3d0f8060 100644 --- a/brainpy/initialize/random_inits.py +++ b/brainpy/initialize/random_inits.py @@ -62,9 +62,12 @@ def calculate_gain(nonlinearity, param=None): In contrast, the default gain for ``SELU`` sacrifices the normalisation effect for more stable gradient flow in rectangular layers. - Args: - nonlinearity: the non-linear function (`nn.functional` name) - param: optional parameter for the non-linear function + Parameters + ---------- + nonlinearity + the non-linear function (`nn.functional` name) + param + optional parameter for the non-linear function .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html """ @@ -114,10 +117,10 @@ def _compute_fans(shape, in_axis=-2, out_axis=-1): class Normal(_InterLayerInitializer): """Initialize weights with normal distribution. - Parameters:: - + Parameters + ---------- scale : float - The gain of the derivation of the normal distribution. + The gain of the derivation of the normal distribution. """ @@ -139,20 +142,20 @@ def __repr__(self): class TruncatedNormal(_InterLayerInitializer): """Initialize weights with truncated normal distribution. - Parameters:: - + Parameters + ---------- loc : float, ndarray - Mean ("centre") of the distribution before truncating. Note that - the mean of the truncated distribution will not be exactly equal - to ``loc``. + Mean ("centre") of the distribution before truncating. Note that + the mean of the truncated distribution will not be exactly equal + to ``loc``. scale : float - The standard deviation of the normal distribution before truncating. + The standard deviation of the normal distribution before truncating. lower : float, ndarray - A float or array of floats representing the lower bound for + A float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with ``upper``. upper : float, ndarray - A float or array of floats representing the upper bound for - truncation. Must be broadcast-compatible with ``lower``. + A float or array of floats representing the upper bound for + truncation. Must be broadcast-compatible with ``lower``. """ @@ -183,12 +186,12 @@ def __repr__(self): class Gamma(_InterLayerInitializer): """Initialize weights with Gamma distribution. - Parameters:: - - shape: float, Array - Shape parameter. - scale: float, Array - The gain of the derivation of the Gamma distribution. + Parameters + ---------- + shape : float, Array + Shape parameter. + scale : float, Array + The gain of the derivation of the Gamma distribution. """ @@ -208,10 +211,10 @@ def __repr__(self): class Exponential(_InterLayerInitializer): """Initialize weights with Gamma distribution. - Parameters:: - - scale: float, Array - The gain of the derivation of the Exponential distribution. + Parameters + ---------- + scale : float, Array + The gain of the derivation of the Exponential distribution. """ @@ -230,12 +233,12 @@ def __repr__(self): class Uniform(_InterLayerInitializer): """Initialize weights with uniform distribution. - Parameters:: - + Parameters + ---------- min_val : float - The lower limit of the uniform distribution. + The lower limit of the uniform distribution. max_val : float - The upper limit of the uniform distribution. + The upper limit of the uniform distribution. """ def __init__(self, min_val: float = 0., max_val: float = 1., seed=None): diff --git a/brainpy/initialize/regular_inits.py b/brainpy/initialize/regular_inits.py index 57071ab75..86c548919 100644 --- a/brainpy/initialize/regular_inits.py +++ b/brainpy/initialize/regular_inits.py @@ -45,7 +45,8 @@ class Constant(_InterLayerInitializer): Initialize the weights with the given values. - Parameters:: + Parameters + ---------- value : float, int, bm.ndarray The value to specify. @@ -74,17 +75,20 @@ class Identity(_InterLayerInitializer): This initializer was proposed in (Le, et al., 2015) [1]_. - Parameters:: + Parameters + ---------- value : float The optional scaling factor. - Returns:: + Returns + ------- - shape: tuple of int + shape : tuple of int The weight shape/size. - References:: + References + ---------- .. [1] Le, Quoc V., Navdeep Jaitly, and Geoffrey E. Hinton. "A simple way to initialize recurrent networks of rectified linear units." arXiv preprint diff --git a/brainpy/inputs/currents.py b/brainpy/inputs/currents.py index cf708f559..b37156b83 100644 --- a/brainpy/inputs/currents.py +++ b/brainpy/inputs/currents.py @@ -45,7 +45,8 @@ def section_input(values, durations, dt=None, return_length=False): >>> section_input(values=[0, 1], durations=[100, 100]) - Parameters:: + Parameters + ---------- values : list, np.ndarray The current values for each period duration. @@ -56,7 +57,8 @@ def section_input(values, durations, dt=None, return_length=False): return_length : bool Return the final duration length. - Returns:: + Returns + ------- current_and_duration """ @@ -76,7 +78,8 @@ def constant_input(I_and_duration, dt=None): >>> constant_input([(0, 100), (1, 100)]) >>> constant_input([(bm.zeros(100), 100), (bm.random.rand(100), 100)]) - Parameters:: + Parameters + ---------- I_and_duration : list This parameter receives the current size and the current @@ -84,7 +87,8 @@ def constant_input(I_and_duration, dt=None): dt : float Default is None. - Returns:: + Returns + ------- current_and_duration : tuple (The formatted current, total duration) @@ -119,7 +123,8 @@ def spike_input(sp_times, sp_lens, sp_sizes, duration, dt=None): >>> sp_sizes=0.5, # can be a list to specify the current size at each point >>> duration=400.) - Parameters:: + Parameters + ---------- sp_times : list, tuple The spike time-points. Must be an iterable object. @@ -132,7 +137,8 @@ def spike_input(sp_times, sp_lens, sp_sizes, duration, dt=None): dt : float The default is None. - Returns:: + Returns + ------- current : bm.ndarray The formatted input current. @@ -156,7 +162,8 @@ def spike_current(*args, **kwargs): def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None): """Get the gradually changed input current. - Parameters:: + Parameters + ---------- c_start : float The minimum (or maximum) current size. @@ -171,7 +178,8 @@ def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None): dt : float, int, optional The numerical precision. - Returns:: + Returns + ------- current : bm.ndarray The formatted current @@ -196,19 +204,20 @@ def wiener_process(duration, dt=None, n=1, t_start=0., t_end=None, seed=None): """Stimulus sampled from a Wiener process, i.e. drawn from standard normal distribution N(0, sqrt(dt)). - Parameters:: + Parameters + ---------- - duration: float + duration : float The input duration. - dt: float + dt : float The numerical precision. - n: int + n : int The variable number. - t_start: float + t_start : float The start time. - t_end: float + t_end : float The end time. - seed: int + seed : int The noise seed. """ with brainstate.environ.context(dt=brainpy.math.get_dt() if dt is None else dt): @@ -222,25 +231,26 @@ def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None, dX = (mu - X)/\tau * dt + \sigma*dW - Parameters:: + Parameters + ---------- - mean: float + mean : float Drift of the OU process. - sigma: float + sigma : float Standard deviation of the Wiener process, i.e. strength of the noise. - tau: float + tau : float Timescale of the OU process, in ms. - duration: float + duration : float The input duration. - dt: float + dt : float The numerical precision. - n: int + n : int The variable number. - t_start: float + t_start : float The start time. - t_end: float + t_end : float The end time. - seed: optional, int + seed : optional, int The random seed. """ with brainstate.environ.context(dt=brainpy.math.get_dt() if dt is None else dt): @@ -250,21 +260,22 @@ def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None, def sinusoidal_input(amplitude, frequency, duration, dt=None, t_start=0., t_end=None, bias=False): """Sinusoidal input. - Parameters:: + Parameters + ---------- - amplitude: float + amplitude : float Amplitude of the sinusoid. - frequency: float + frequency : float Frequency of the sinus oscillation, in Hz - duration: float + duration : float The input duration. - t_start: float + t_start : float The start time. - t_end: float + t_end : float The end time. - dt: float + dt : float The numerical precision. - bias: bool + bias : bool Whether the sinusoid oscillates around 0 (False), or has a positive DC bias, thus non-negative (True). """ @@ -275,21 +286,22 @@ def sinusoidal_input(amplitude, frequency, duration, dt=None, t_start=0., t_end= def square_input(amplitude, frequency, duration, dt=None, bias=False, t_start=0., t_end=None): """Oscillatory square input. - Parameters:: + Parameters + ---------- - amplitude: float + amplitude : float Amplitude of the square oscillation. - frequency: float + frequency : float Frequency of the square oscillation, in Hz. - duration: float + duration : float The input duration. - t_start: float + t_start : float The start time. - t_end: float + t_end : float The end time. - dt: float + dt : float The numerical precision. - bias: bool + bias : bool Whether the sinusoid oscillates around 0 (False), or has a positive DC bias, thus non-negative (True). """ diff --git a/brainpy/integrators/fde/Caputo.py b/brainpy/integrators/fde/Caputo.py index 57d6ca37f..9fb7bb21e 100644 --- a/brainpy/integrators/fde/Caputo.py +++ b/brainpy/integrators/fde/Caputo.py @@ -77,7 +77,8 @@ class CaputoEuler(FDEIntegrator): b_{j, k+1}=\frac{h^{\alpha}}{\alpha}\left((k+1-j)^{\alpha}-(k-j)^{\alpha}\right). - Examples:: + Examples + -------- >>> import brainpy as bp >>> @@ -100,22 +101,24 @@ class CaputoEuler(FDEIntegrator): >>> plt.show() - Parameters:: + Parameters + ---------- f : callable The derivative function. - alpha: int, float, jnp.ndarray, bm.ndarray, sequence + alpha : int, float, jnp.ndarray, bm.ndarray, sequence The fractional-order of the derivative function. Should be in the range of ``(0., 1.)``. - num_memory: int + num_memory : int The total time step of the simulation. - inits: sequence + inits : sequence A sequence of the initial values for variables. - dt: float, int + dt : float, int The numerical precision. - name: str + name : str The integrator name. - References:: + References + ---------- .. [1] Li, Changpin, and Fanhai Zeng. "The finite difference methods for fractional ordinary differential equations." Numerical Functional Analysis and @@ -280,7 +283,8 @@ class CaputoL1Schema(FDEIntegrator): from the first order dynamics. - Examples:: + Examples + -------- >>> import brainpy as bp >>> @@ -303,22 +307,24 @@ class CaputoL1Schema(FDEIntegrator): >>> plt.show() - Parameters:: + Parameters + ---------- f : callable The derivative function. - alpha: int, float, jnp.ndarray, bm.ndarray, sequence + alpha : int, float, jnp.ndarray, bm.ndarray, sequence The fractional-order of the derivative function. Should be in the range of ``(0., 1.]``. - num_memory: int + num_memory : int The total time step of the simulation. - inits: sequence + inits : sequence A sequence of the initial values for variables. - dt: float, int + dt : float, int The numerical precision. - name: str + name : str The integrator name. - References:: + References + ---------- .. [3] Oldham, K., & Spanier, J. (1974). The fractional calculus theory and applications of differentiation and integration to arbitrary diff --git a/brainpy/integrators/fde/GL.py b/brainpy/integrators/fde/GL.py index 76c3fc555..12e110305 100644 --- a/brainpy/integrators/fde/GL.py +++ b/brainpy/integrators/fde/GL.py @@ -80,7 +80,8 @@ class GLShortMemory(FDEIntegrator): is the memory window with a width defined by :math:`M=\frac{L_{m}}{h}`. As was reported in [2]_, the accuracy increases by increaing the width of memory window. - Examples:: + Examples + -------- >>> import brainpy as bp >>> @@ -106,25 +107,27 @@ class GLShortMemory(FDEIntegrator): >>> plt.show() - Parameters:: + Parameters + ---------- f : callable The derivative function. - alpha: int, float, jnp.ndarray, bm.ndarray, sequence + alpha : int, float, jnp.ndarray, bm.ndarray, sequence The fractional-order of the derivative function. Should be in the range of ``(0., 1.)``. - num_memory: int + num_memory : int The length of the short memory. .. versionchanged:: 2.1.11 - inits: sequence + inits : sequence A sequence of the initial values for variables. - dt: float, int + dt : float, int The numerical precision. - name: str + name : str The integrator name. - References:: + References + ---------- .. [1] Clemente-López, D., et al. "Efficient computation of the Grünwald-Letnikov method for arm-based implementations of diff --git a/brainpy/integrators/fde/base.py b/brainpy/integrators/fde/base.py index 53f61d782..b8c05b199 100644 --- a/brainpy/integrators/fde/base.py +++ b/brainpy/integrators/fde/base.py @@ -31,15 +31,16 @@ class FDEIntegrator(Integrator): """Numerical integrator for fractional differential equations (FEDs). - Parameters:: + Parameters + ---------- f : callable The derivative function. - alpha: int, float, jnp.ndarray, bm.ndarray, sequence + alpha : int, float, jnp.ndarray, bm.ndarray, sequence The fractional-order of the derivative function. - dt: float, int + dt : float, int The numerical precision. - name: str + name : str The integrator name. """ diff --git a/brainpy/integrators/fde/generic.py b/brainpy/integrators/fde/generic.py index 2d496c543..f6f52b5bd 100644 --- a/brainpy/integrators/fde/generic.py +++ b/brainpy/integrators/fde/generic.py @@ -38,24 +38,26 @@ def fdeint( ): """Numerical integration for FDEs. - Parameters:: + Parameters + ---------- f : callable, function The derivative function. method : str The shortcut name of the numerical integrator. - alpha: int, float, jnp.ndarray, bm.ndarray, sequence + alpha : int, float, jnp.ndarray, bm.ndarray, sequence The fractional-order of the derivative function. Should be in the range of ``(0., 1.]``. - num_memory: int + num_memory : int The number of the memory length. - inits: sequence + inits : sequence A sequence of the initial values for variables. - dt: float, int + dt : float, int The numerical precision. - name: str + name : str The integrator name. - Returns:: + Returns + ------- integral : FDEIntegrator The numerical solver of `f`. @@ -74,7 +76,8 @@ def fdeint( def set_default_fdeint(method): """Set the default FDE numerical integrator method for fractional differential equations. - Parameters:: + Parameters + ---------- method : str, callable Numerical integrator method. @@ -91,7 +94,8 @@ def set_default_fdeint(method): def get_default_fdeint(): """Get the default FDE numerical integrator method. - Returns:: + Returns + ------- method : str The default numerical integrator method. @@ -102,11 +106,12 @@ def get_default_fdeint(): def register_fde_integrator(name, integrator): """Register a new FDE integrator. - Parameters:: + Parameters + ---------- - name: ste + name : ste The integrator name. - integrator: type + integrator : type The integrator. """ if name in name2method: diff --git a/brainpy/integrators/joint_eq.py b/brainpy/integrators/joint_eq.py index a4954963d..e34618e06 100644 --- a/brainpy/integrators/joint_eq.py +++ b/brainpy/integrators/joint_eq.py @@ -128,9 +128,10 @@ def brainpy_itg_of_ode0_joint_eq(V, u, t, Iext, dt=0.1): >>> eq2 = bp.JointEq(eqs=(eq, dw)) - Parameters:: + Parameters + ---------- - *eqs : + *eqs The elements of derivative function to compose. """ diff --git a/brainpy/integrators/ode/adaptive_rk.py b/brainpy/integrators/ode/adaptive_rk.py index 77f3f2f1a..08bca532a 100644 --- a/brainpy/integrators/ode/adaptive_rk.py +++ b/brainpy/integrators/ode/adaptive_rk.py @@ -119,7 +119,8 @@ class AdaptiveRKIntegrator(ODEIntegrator): & b_1^* & b_2^* & \dots & b_s^*\\ \end{array} - Parameters:: + Parameters + ---------- f : callable The derivative function. @@ -258,7 +259,8 @@ class RKF12(AdaptiveRKIntegrator): & 1 / 256 & 255 / 256 & 0 \end{array} - References:: + References + ---------- .. [1] Fehlberg, E. (1969-07-01). "Low-order classical Runge-Kutta formulas with stepsize control and their application to some heat @@ -302,7 +304,8 @@ class RKF45(AdaptiveRKIntegrator): & 25 / 216 & 0 & 1408 / 2565 & 2197 / 4104 & -1 / 5 & 0 \end{array} - References:: + References + ---------- .. [1] https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta%E2%80%93Fehlberg_method .. [2] Erwin Fehlberg (1969). Low-order classical Runge-Kutta formulas with step @@ -358,7 +361,8 @@ class DormandPrince(AdaptiveRKIntegrator): & 5179 / 57600 & 0 & 7571 / 16695 & 393 / 640 & -92097 / 339200 & 187 / 2100 & 1 / 40 \end{array} - References:: + References + ---------- .. [1] https://en.wikipedia.org/wiki/Dormand%E2%80%93Prince_method .. [2] Dormand, J. R.; Prince, P. J. (1980), "A family of embedded Runge-Kutta formulae", @@ -409,7 +413,8 @@ class CashKarp(AdaptiveRKIntegrator): & 2825 / 27648 & 0 & 18575 / 48384 & 13525 / 55296 & 277 / 14336 & 1 / 4 \end{array} - References:: + References + ---------- .. [1] https://en.wikipedia.org/wiki/Cash%E2%80%93Karp_method .. [2] J. R. Cash, A. H. Karp. "A variable order Runge-Kutta method for initial value @@ -457,7 +462,8 @@ class BogackiShampine(AdaptiveRKIntegrator): & 7 / 24 & 1 / 4 & 1 / 3 & 1 / 8 \end{array} - References:: + References + ---------- .. [1] https://en.wikipedia.org/wiki/Bogacki%E2%80%93Shampine_method .. [2] Bogacki, Przemysław; Shampine, Lawrence F. (1989), "A 3(2) pair of Runge–Kutta @@ -516,7 +522,8 @@ class DOP853(AdaptiveRKIntegrator): DOP853 is an explicit Runge-Kutta method of order 8(5,3) due to Dormand & Prince (with stepsize control and dense output). - References:: + References + ---------- .. [1] E. Hairer, S.P. Norsett and G. Wanner, "Solving ordinary Differential Equations I. Nonstiff Problems", 2nd edition. Springer Series in Computational Mathematics, diff --git a/brainpy/integrators/ode/base.py b/brainpy/integrators/ode/base.py index 22ed7c8d5..cc6c5b091 100644 --- a/brainpy/integrators/ode/base.py +++ b/brainpy/integrators/ode/base.py @@ -37,15 +37,16 @@ def f_names(f): class ODEIntegrator(Integrator): """Numerical Integrator for Ordinary Differential Equations (ODEs). - Parameters:: + Parameters + ---------- f : callable The derivative function. - var_type: str + var_type : str The type for each variable. - dt: float, int + dt : float, int The numerical precision. - name: str + name : str The integrator name. """ diff --git a/brainpy/integrators/ode/explicit_rk.py b/brainpy/integrators/ode/explicit_rk.py index 9b425a911..720202e26 100644 --- a/brainpy/integrators/ode/explicit_rk.py +++ b/brainpy/integrators/ode/explicit_rk.py @@ -132,7 +132,8 @@ class ExplicitRKIntegrator(ODEIntegrator): \hline & b_{1} & b_{2} & \ldots & b_{s} \end{array} - Parameters:: + Parameters + ---------- f : callable The derivative function. @@ -260,7 +261,8 @@ class Euler(ExplicitRKIntegrator): and accuracy limits its popularity mainly to use as a simple introductory example of a numeric solution method. - References:: + References + ---------- .. [1] W. H.; Flannery, B. P.; Teukolsky, S. A.; and Vetterling, W. T. Numerical Recipes in FORTRAN: The Art of Scientific @@ -359,7 +361,8 @@ class MidPoint(ExplicitRKIntegrator): Note that the red chord is not exactly parallel to the green segment (the true tangent), due to the error in estimating the value of :math:`y(t)` at the midpoint. - References:: + References + ---------- .. [1] Süli, Endre, and David F. Mayers. An Introduction to Numerical Analysis. no. 1, 2003. .. [2] https://en.wikipedia.org/wiki/Midpoint_method @@ -427,7 +430,8 @@ class Heun2(ExplicitRKIntegrator): {\text{Slope}}_{\text{ideal}}=&{\frac {1}{2}}({\text{Slope}}_{\text{left}}+{\text{Slope}}_{\text{right}}) \end{aligned} - References:: + References + ---------- .. [1] Süli, Endre, and David F. Mayers. An Introduction to Numerical Analysis. no. 1, 2003. """ @@ -581,7 +585,8 @@ class RK2(ExplicitRKIntegrator): \hline & 1 - {1 \over 2 * \beta} & {1 \over 2 * \beta} \end{array} - References:: + References + ---------- .. [1] Chapra, Steven C., and Raymond P. Canale. Numerical methods for engineers. Vol. 1221. New York: Mcgraw-hill, 2011. @@ -697,7 +702,8 @@ class Ralston3(ExplicitRKIntegrator): \hline & 2 / 9 & 1 / 3 & 4 / 9 \end{array} - References:: + References + ---------- .. [1] Ralston, Anthony (1962). "Runge-Kutta Methods with Minimum Error Bounds". Math. Comput. 16 (80): 431–437. doi:10.1090/S0025-5718-1962-0150954-0 @@ -788,7 +794,8 @@ class RK4(ExplicitRKIntegrator): \hline & 1 / 6 & 1 / 3 & 1 / 3 & 1 / 6 \end{array} - References:: + References + ---------- .. [1] Lambert, J. D. and Lambert, D. Ch. 5 in Numerical Methods for Ordinary Differential Systems: The Initial Value Problem. New York: Wiley, 1991. @@ -826,7 +833,8 @@ class Ralston4(ExplicitRKIntegrator): \hline & .17476028 & -.55148066 & 1.20553560 & .17118478 \end{array} - References:: + References + ---------- .. [1] Ralston, Anthony (1962). "Runge-Kutta Methods with Minimum Error Bounds". Math. Comput. 16 (80): 431–437. doi:10.1090/S0025-5718-1962-0150954-0 @@ -868,7 +876,8 @@ class RK4Rule38(ExplicitRKIntegrator): \end{array} - References:: + References + ---------- .. [1] Hairer, Ernst; Nørsett, Syvert Paul; Wanner, Gerhard (1993), Solving ordinary differential equations I: Nonstiff problems, diff --git a/brainpy/integrators/ode/exponential.py b/brainpy/integrators/ode/exponential.py index 6d4ff9890..ebe1f4e58 100644 --- a/brainpy/integrators/ode/exponential.py +++ b/brainpy/integrators/ode/exponential.py @@ -140,7 +140,8 @@ class ExponentialEuler(ODEIntegrator): to automatically infer the linear part of the given function. Therefore, it has minimal constraints on your derivative function. Arbitrary complex functions can be numerically integrated with this method. - Examples:: + Examples + -------- Here is an example uses ``ExponentialEuler`` to implement HH neuron model. @@ -285,7 +286,8 @@ class ExponentialEuler(ODEIntegrator): >>> run(100) >>> bp.visualize.line_plot(run.mon.ts, run.mon.V, legend='V', show=True) - Parameters:: + Parameters + ---------- f : function, joint_eq.JointEq The derivative function. diff --git a/brainpy/integrators/ode/generic.py b/brainpy/integrators/ode/generic.py index 8ee6c0998..18fa0a6cb 100644 --- a/brainpy/integrators/ode/generic.py +++ b/brainpy/integrators/ode/generic.py @@ -44,7 +44,8 @@ def odeint( ): """Numerical integration for ODEs. - Examples:: + Examples + -------- .. plot:: :include-source: True @@ -69,28 +70,30 @@ def odeint( >>> plt.show() - Parameters:: + Parameters + ---------- f : callable, function The derivative function. method : str The shortcut name of the numerical integrator. - var_type: str + var_type : str The type of the variable defined in the equation. - dt: float + dt : float The numerical integration precision. - name: str + name : str The integrator node. - state_delays: dict + state_delays : dict The state delay variable. - show_code: bool + show_code : bool Show the formated code. - adaptive: bool + adaptive : bool The use adaptive mode. - tol: float + tol : float The tolerence to adapt new step size. - Returns:: + Returns + ------- integral : ODEIntegrator The numerical solver of `f`. @@ -123,7 +126,8 @@ def odeint( def set_default_odeint(method): """Set the default ODE numerical integrator method for differential equations. - Parameters:: + Parameters + ---------- method : str, callable Numerical integrator method. @@ -140,7 +144,8 @@ def set_default_odeint(method): def get_default_odeint(): """Get the default ODE numerical integrator method. - Returns:: + Returns + ------- method : str The default numerical integrator method. @@ -151,10 +156,11 @@ def get_default_odeint(): def register_ode_integrator(name, integrator): """Register a new ODE integrator. - Parameters:: + Parameters + ---------- - name: ste - integrator: type + name : ste + integrator : type """ if name in name2method: raise ValueError(f'"{name}" has been registered in ODE integrators.') diff --git a/brainpy/integrators/runner.py b/brainpy/integrators/runner.py index f50f72a8b..28aeb2fba 100644 --- a/brainpy/integrators/runner.py +++ b/brainpy/integrators/runner.py @@ -38,7 +38,8 @@ class IntegratorRunner(Runner): """Structural runner for numerical integrators in brainpy. - Examples:: + Examples + -------- Example to run an ODE integrator, @@ -112,22 +113,23 @@ def __init__( ): """Initialization of structural runner for integrators. - Parameters:: + Parameters + ---------- - target: Integrator + target : Integrator The target to run. - monitors: sequence of str + monitors : sequence of str The variables to monitor. - fun_monitors: dict + fun_monitors : dict The monitors with callable functions. .. deprecated:: 2.3.1 - inits: sequence, dict + inits : sequence, dict The initial value of variables. With this parameter, you can easily control the number of variables to simulate. For example, if one of the variable has the shape of 10, then all variables will be an instance of :py:class:`brainpy.math.Variable` with the shape of :math:`(10,)`. - args: dict + args : dict The equation arguments to update. Note that if one of the arguments are heterogeneous (i.e., a tensor), it means we should run multiple trials. However, you can set the number @@ -137,7 +139,7 @@ def __init__( .. deprecated:: 2.3.1 Will be removed after version 2.4.0. - dyn_args: dict + dyn_args : dict The dynamically changed arguments. This means this argument can control the argument dynamically changed. For example, if you want to inject a time varied currents into the HH neuron model, you can pack the currents @@ -146,11 +148,11 @@ def __init__( .. deprecated:: 2.3.1 Will be removed after version 2.4.0. - dt: float, int - dyn_vars: dict - jit: bool - progress_bar: bool - numpy_mon_after_run: bool + dt : float, int + dyn_vars : dict + jit : bool + progress_bar : bool + numpy_mon_after_run : bool """ if not isinstance(target, Integrator): @@ -286,19 +288,20 @@ def run( ): """The running function. - Parameters:: + Parameters + ---------- duration : float, int, tuple, list The running duration. start_t : float, optional The start time to simulate. - eval_time: bool + eval_time : bool Evaluate the running time or not? - args: dict + args : dict The equation arguments to update. .. versionadded:: 2.3.1 - dyn_args: dict + dyn_args : dict The dynamically changed arguments over time. The size of first dimension should be equal to the running ``duration``. diff --git a/brainpy/integrators/sde/generic.py b/brainpy/integrators/sde/generic.py index 8649ae56a..3d2e28d04 100644 --- a/brainpy/integrators/sde/generic.py +++ b/brainpy/integrators/sde/generic.py @@ -45,14 +45,16 @@ def sdeint( ): """Numerical integration for SDEs. - Parameters:: + Parameters + ---------- f : callable, function The derivative function. method : str The shortcut name of the numerical integrator. - Returns:: + Returns + ------- integral : SDEIntegrator The numerical solver of `f`. @@ -102,7 +104,8 @@ def sdeint( def set_default_sdeint(method): """Set the default SDE numerical integrator method for differential equations. - Parameters:: + Parameters + ---------- method : str, callable Numerical integrator method. @@ -119,7 +122,8 @@ def set_default_sdeint(method): def get_default_sdeint(): """Get the default SDE numerical integrator method. - Returns:: + Returns + ------- method : str The default numerical integrator method. @@ -130,10 +134,11 @@ def get_default_sdeint(): def register_sde_integrator(name, integrator): """Register a new SDE integrator. - Parameters:: + Parameters + ---------- - name: ste - integrator: type + name : ste + integrator : type """ if name in name2method: raise ValueError(f'"{name}" has been registered in SDE integrators.') diff --git a/brainpy/integrators/sde/normal.py b/brainpy/integrators/sde/normal.py index b75d9258a..a4a2b722f 100644 --- a/brainpy/integrators/sde/normal.py +++ b/brainpy/integrators/sde/normal.py @@ -88,7 +88,8 @@ class Euler(SDEIntegrator): \end{aligned} - See Also:: + See Also + -------- Heun @@ -213,7 +214,8 @@ class Heun(Euler): \end{aligned} - See Also:: + See Also + -------- Euler @@ -543,13 +545,15 @@ class ExponentialEuler(SDEIntegrator): where :math:`\varphi(z)=\frac{e^{z}-1}{z}`. - References:: + References + ---------- .. [1] Erdoğan, Utku, and Gabriel J. Lord. "A new class of exponential integrators for stochastic differential equations with multiplicative noise." arXiv preprint arXiv:1608.07096 (2016). - See Also:: + See Also + -------- Euler, Heun, Milstein """ diff --git a/brainpy/integrators/sde/srk_scalar.py b/brainpy/integrators/sde/srk_scalar.py index c21dc1a2a..deb39f22d 100644 --- a/brainpy/integrators/sde/srk_scalar.py +++ b/brainpy/integrators/sde/srk_scalar.py @@ -95,7 +95,8 @@ class SRK1W1(SDEIntegrator): \end{array} - References:: + References + ---------- .. [1] Rößler, Andreas. "Strong and weak approximation methods for stochastic differential equations—some recent developments." Recent developments in applied probability and @@ -214,7 +215,8 @@ class SRK2W1(SDEIntegrator): \end{array} - References:: + References + ---------- .. [1] Rößler, Andreas. "Strong and weak approximation methods for stochastic differential equations—some recent developments." Recent developments in applied probability and diff --git a/brainpy/integrators/utils.py b/brainpy/integrators/utils.py index 7703a0967..c72a47710 100644 --- a/brainpy/integrators/utils.py +++ b/brainpy/integrators/utils.py @@ -64,12 +64,14 @@ def get_args(f): >>> get_args(scope['f5']) (['a', 'b'], ['t', '*args'], ['a', 'b', 't', '*args']) - Parameters:: + Parameters + ---------- f : callable The function. - Returns:: + Returns + ------- args : tuple The variable names, the other arguments, and the original args. diff --git a/brainpy/losses/comparison.py b/brainpy/losses/comparison.py index bfe60cb50..a3aa7223d 100644 --- a/brainpy/losses/comparison.py +++ b/brainpy/losses/comparison.py @@ -126,33 +126,40 @@ class CrossEntropyLoss(WeightedLoss): indices, as this allows for optimized computation. Consider providing `target` as class probabilities only when a single class label per minibatch item is too restrictive. - Args: - weight (Tensor, optional): a manual rescaling weight given to each class. - If given, has to be a Tensor of size `C` - size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, - the losses are averaged over each loss element in the batch. Note that for - some losses, there are multiple elements per sample. If the field :attr:`size_average` - is set to ``False``, the losses are instead summed for each minibatch. Ignored - when :attr:`reduce` is ``False``. Default: ``True`` - ignore_index (int, optional): Specifies a target value that is ignored - and does not contribute to the input gradient. When :attr:`size_average` is - ``True``, the loss is averaged over non-ignored targets. Note that - :attr:`ignore_index` is only applicable when the target contains class indices. - reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the - losses are averaged or summed over observations for each minibatch depending - on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per - batch element instead and ignores :attr:`size_average`. Default: ``True`` - reduction (str, optional): Specifies the reduction to apply to the output: - ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will - be applied, ``'mean'``: the weighted mean of the output is taken, - ``'sum'``: the output will be summed. Note: :attr:`size_average` - and :attr:`reduce` are in the process of being deprecated, and in - the meantime, specifying either of those two args will override - :attr:`reduction`. Default: ``'mean'`` - label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount - of smoothing when computing the loss, where 0.0 means no smoothing. The targets - become a mixture of the original ground truth and a uniform distribution as described in - `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. + Parameters + ---------- + weight : Tensor, optional + a manual rescaling weight given to each class. + If given, has to be a Tensor of size `C` + size_average : bool, optional + Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + ignore_index : int, optional + Specifies a target value that is ignored + and does not contribute to the input gradient. When :attr:`size_average` is + ``True``, the loss is averaged over non-ignored targets. Note that + :attr:`ignore_index` is only applicable when the target contains class indices. + reduce : bool, optional + Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction : str, optional + Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` + label_smoothing : float, optional + A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`. Shape: - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` @@ -172,7 +179,9 @@ class probabilities only when a single class label per minibatch item is too res N ={} & \text{batch size} \\ \end{aligned} - Examples:: + Examples + -------- + .. code-block:: python >>> # Example of target with class indices >>> loss = nn.CrossEntropyLoss() @@ -234,31 +243,31 @@ def cross_entropy_loss(predicts, targets, weight=None, reduction='mean', an input of size :math:`(d_1, d_2, ..., d_K, minibatch, C)` with :math:`K \geq 1`, where :math:`K` is the number of dimensions, and a target of appropriate shape. - Parameters:: - + Parameters + ---------- predicts : ArrayType - :math:`(N, C)` where `C = number of classes`, or - :math:`(d_1, d_2, ..., d_K, N, C)` with :math:`K \geq 1` - in the case of `K`-dimensional loss. + :math:`(N, C)` where `C = number of classes`, or + :math:`(d_1, d_2, ..., d_K, N, C)` with :math:`K \geq 1` + in the case of `K`-dimensional loss. targets : ArrayType - :math:`(N, C)` or :math:`(N)` where each value is - :math:`0 \leq \text{targets}[i] \leq C-1`, or - :math:`(d_1, d_2, ..., d_K, N, C)` or :math:`(d_1, d_2, ..., d_K, N)` - with :math:`K \geq 1` in the case of K-dimensional loss. + :math:`(N, C)` or :math:`(N)` where each value is + :math:`0 \leq \text{targets}[i] \leq C-1`, or + :math:`(d_1, d_2, ..., d_K, N, C)` or :math:`(d_1, d_2, ..., d_K, N)` + with :math:`K \geq 1` in the case of K-dimensional loss. weight : ArrayType, optional - A manual rescaling weight given to each class. If given, has to be an array of size `C`. + A manual rescaling weight given to each class. If given, has to be an array of size `C`. reduction : str, optional - Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. - - ``'none'``: no reduction will be applied, - - ``'mean'``: the weighted mean of the output is taken, - - ``'sum'``: the output will be summed. - - Returns:: + Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. + - ``'none'``: no reduction will be applied, + - ``'mean'``: the weighted mean of the output is taken, + - ``'sum'``: the output will be summed. + Returns + ------- output : scalar, ArrayType - If :attr:`reduction` is ``'none'``, then the same size as the target: - :math:`(N)`, or :math:`(d_1, d_2, ..., d_K, N)` with :math:`K \geq 1` - in the case of K-dimensional loss. + If :attr:`reduction` is ``'none'``, then the same size as the target: + :math:`(N)`, or :math:`(d_1, d_2, ..., d_K, N)` with :math:`K \geq 1` + in the case of K-dimensional loss. """ def _cel(_pred, _tar): @@ -311,12 +320,16 @@ def _cel(_pred, _tar): def cross_entropy_sparse(predicts, targets): r"""Computes the softmax cross-entropy loss. - Args: - predicts: (batch, ..., #class) tensor of logits. - targets: (batch, ...) integer tensor of label indexes in {0, ...,#nclass-1} or just a single integer. + Parameters + ---------- + predicts + (batch, ..., #class) tensor of logits. + targets + (batch, ...) integer tensor of label indexes in {0, ...,#nclass-1} or just a single integer. - Returns: - (batch, ...) tensor of the cross-entropy for each entry. + Returns + ------- + (batch, ...) tensor of the cross-entropy for each entry. """ def crs(_prd, _tar): @@ -333,12 +346,16 @@ def crs(_prd, _tar): def cross_entropy_sigmoid(predicts, targets): """Computes the sigmoid cross-entropy loss. - Args: - predicts: (batch, ..., #class) tensor of logits. - targets: (batch, ..., #class) tensor of label probabilities (e.g. labels.sum(axis=-1) must be 1) + Parameters + ---------- + predicts + (batch, ..., #class) tensor of logits. + targets + (batch, ..., #class) tensor of label probabilities (e.g. labels.sum(axis=-1) must be 1) - Returns: - (batch, ...) tensor of the cross-entropies for each entry. + Returns + ------- + (batch, ...) tensor of the cross-entropies for each entry. """ r = tree_map( lambda pred, tar: bm.as_jax( @@ -395,14 +412,16 @@ class NLLLoss(Loss): \text{if reduction} = \text{`sum'.} \end{cases} - Args: - reduction (str, optional): Specifies the reduction to apply to the output: - ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will - be applied, ``'mean'``: the weighted mean of the output is taken, - ``'sum'``: the output will be summed. Note: :attr:`size_average` - and :attr:`reduce` are in the process of being deprecated, and in - the meantime, specifying either of those two args will override - :attr:`reduction`. Default: ``'mean'`` + Parameters + ---------- + reduction : str, optional + Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` Shape: - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, or @@ -470,14 +489,16 @@ def nll_loss(input, target, reduction: str = 'mean'): \text{if reduction} = \text{`sum'.} \end{cases} - Args: - reduction (str, optional): Specifies the reduction to apply to the output: - ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will - be applied, ``'mean'``: the weighted mean of the output is taken, - ``'sum'``: the output will be summed. Note: :attr:`size_average` - and :attr:`reduce` are in the process of being deprecated, and in - the meantime, specifying either of those two args will override - :attr:`reduction`. Default: ``'mean'`` + Parameters + ---------- + reduction : str, optional + Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` Shape: - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, or @@ -539,13 +560,15 @@ class L1Loss(Loss): Supports real-valued and complex-valued inputs. - Args: - reduction (str, optional): Specifies the reduction to apply to the output: - ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, - ``'mean'``: the sum of the output will be divided by the number of - elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` - and :attr:`reduce` are in the process of being deprecated, and in the meantime, - specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + Parameters + ---------- + reduction : str, optional + Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. @@ -553,7 +576,9 @@ class L1Loss(Loss): - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- + .. code-block:: python >>> loss = nn.L1Loss() >>> input = bm.random.randn(3, 5) @@ -599,23 +624,23 @@ def l1_loss(logits, targets, reduction='mean'): Supports real-valued and complex-valued inputs. - Parameters:: - + Parameters + ---------- logits : ArrayType - :math:`(N, *)` where :math:`*` means, any number of additional dimensions. + :math:`(N, *)` where :math:`*` means, any number of additional dimensions. targets : ArrayType - :math:`(N, *)`, same shape as the input. + :math:`(N, *)`, same shape as the input. reduction : str - Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. - Default: ``'mean'``. - - ``'none'``: no reduction will be applied, - - ``'mean'``: the sum of the output will be divided by the number of elements in the output, - - ``'sum'``: the output will be summed. Note: :attr:`size_average` - - Returns:: - + Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. + Default: ``'mean'``. + - ``'none'``: no reduction will be applied, + - ``'mean'``: the sum of the output will be divided by the number of elements in the output, + - ``'sum'``: the output will be summed. Note: :attr:`size_average` + + Returns + ------- output : scalar. - If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same shape as the input. + If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same shape as the input. """ r = tree_map(lambda pred, tar: _bt_metric.l1_loss(pred, tar, reduction=reduction), @@ -629,20 +654,20 @@ def l2_loss(predicts, targets): The 0.5 term is standard in "Pattern Recognition and Machine Learning" by Bishop [1]_, but not "The Elements of Statistical Learning" by Tibshirani. - Parameters:: - - predicts: ArrayType - A vector of arbitrary shape. - targets: ArrayType - A vector of shape compatible with predictions. - - Returns:: + Parameters + ---------- + predicts : ArrayType + A vector of arbitrary shape. + targets : ArrayType + A vector of shape compatible with predictions. + Returns + ------- loss : float - A scalar value containing the l2 loss. - - References:: + A scalar value containing the l2 loss. + References + ---------- .. [1] Bishop, Christopher M. 2006. Pattern Recognition and Machine Learning. """ r = tree_map(lambda pred, tar: _bt_metric.l2_loss(pred, tar), @@ -664,13 +689,18 @@ def update(self, input, target): def mean_absolute_error(x, y, axis=None, reduction: str = 'mean'): r"""Computes the mean absolute error between x and y. - Args: - x: a tensor of shape (d0, .. dN-1). - y: a tensor of shape (d0, .. dN-1). - axis: a sequence of the dimensions to keep, use `None` to return a scalar value. - - Returns: - tensor of shape (d_i, ..., for i in keep_axis) containing the mean absolute error. + Parameters + ---------- + x + a tensor of shape (d0, .. dN-1). + y + a tensor of shape (d0, .. dN-1). + axis + a sequence of the dimensions to keep, use `None` to return a scalar value. + + Returns + ------- + tensor of shape (d_i, ..., for i in keep_axis) containing the mean absolute error. """ r = tree_map(lambda a, b: _bt_metric.absolute_error(a, b, axis=axis, reduction=reduction), x, @@ -706,19 +736,23 @@ class MSELoss(Loss): The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. - Args: - reduction (str, optional): Specifies the reduction to apply to the output: - ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, - ``'mean'``: the sum of the output will be divided by the number of - elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` - and :attr:`reduce` are in the process of being deprecated, and in the meantime, - specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` + Parameters + ---------- + reduction : str, optional + Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, + ``'mean'``: the sum of the output will be divided by the number of + elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in the meantime, + specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Target: :math:`(*)`, same shape as the input. - Examples:: + Examples + -------- + .. code-block:: python >>> loss = nn.MSELoss() >>> input = torch.randn(3, 5, requires_grad=True) @@ -738,13 +772,18 @@ def update(self, input: ArrayType, target: ArrayType) -> ArrayType: def mean_squared_error(predicts, targets, axis=None, reduction: str = 'mean'): r"""Computes the mean squared error between x and y. - Args: - predicts: a tensor of shape (d0, .. dN-1). - targets: a tensor of shape (d0, .. dN-1). - axis: a sequence of the dimensions to keep, use `None` to return a scalar value. - - Returns: - tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error. + Parameters + ---------- + predicts + a tensor of shape (d0, .. dN-1). + targets + a tensor of shape (d0, .. dN-1). + axis + a sequence of the dimensions to keep, use `None` to return a scalar value. + + Returns + ------- + tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error. """ r = tree_map(lambda a, b: _bt_metric.squared_error(a, b, axis=axis, reduction=reduction), predicts, @@ -756,13 +795,18 @@ def mean_squared_error(predicts, targets, axis=None, reduction: str = 'mean'): def mean_squared_log_error(predicts, targets, axis=None, reduction: str = 'mean'): r"""Computes the mean squared logarithmic error between y_true and y_pred. - Args: - targets: a tensor of shape (d0, .. dN-1). - predicts: a tensor of shape (d0, .. dN-1). - keep_axis: a sequence of the dimensions to keep, use `None` to return a scalar value. - - Returns: - tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error. + Parameters + ---------- + targets + a tensor of shape (d0, .. dN-1). + predicts + a tensor of shape (d0, .. dN-1). + keep_axis + a sequence of the dimensions to keep, use `None` to return a scalar value. + + Returns + ------- + tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared error. """ r = tree_map(lambda a, b: _reduce((jnp.log1p(a) - jnp.log1p(b)) ** 2, reduction, axis=axis), predicts, @@ -778,22 +822,22 @@ def huber_loss(predicts, targets, delta: float = 1.0): If gradient descent is applied to the `huber loss`, it is equivalent to clipping gradients of an `l2_loss` to `[-delta, delta]` in the backward pass. - Parameters:: - - predicts: ArrayType - predictions - targets: ArrayType - ground truth - delta: float - radius of quadratic behavior - - Returns:: + Parameters + ---------- + predicts : ArrayType + predictions + targets : ArrayType + ground truth + delta : float + radius of quadratic behavior + Returns + ------- loss : float - The loss value. - - References:: + The loss value. + References + ---------- .. [1] https://en.wikipedia.org/wiki/Huber_loss """ @@ -805,12 +849,16 @@ def huber_loss(predicts, targets, delta: float = 1.0): def binary_logistic_loss(predicts: float, targets: int, ) -> float: """Binary logistic loss. - Args: - targets: ground-truth integer label (0 or 1). - predicts: score produced by the model (float). + Parameters + ---------- + targets + ground-truth integer label (0 or 1). + predicts + score produced by the model (float). - Returns: - loss value + Returns + ------- + loss value """ # Softplus is the Fenchel conjugate of the Fermi-Dirac negentropy on [0, 1]. # softplus = proba * logit - xlogx(proba) - xlogx(1 - proba), @@ -825,12 +873,16 @@ def binary_logistic_loss(predicts: float, targets: int, ) -> float: def multiclass_logistic_loss(label: int, logits: jnp.ndarray) -> float: """Multiclass logistic loss. - Args: - label: ground-truth integer label, between 0 and n_classes - 1. - logits: scores produced by the model, shape = (n_classes, ). + Parameters + ---------- + label : int + ground-truth integer label, between 0 and n_classes - 1. + logits : jnp.ndarray + scores produced by the model, shape = (n_classes, ). - Returns: - loss value + Returns + ------- + loss value """ def loss(pred, tar): @@ -849,15 +901,20 @@ def sigmoid_binary_cross_entropy(logits, labels): not mutually exclusive. This may be used for multilabel image classification for instance a model may predict that an image contains both a cat and a dog. - References: - [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) + Parameters + ---------- + logits + unnormalized log probabilities. + labels + the probability for that class. - Args: - logits: unnormalized log probabilities. - labels: the probability for that class. + Returns + ------- + a sigmoid cross entropy loss. - Returns: - a sigmoid cross entropy loss. + References + ---------- + [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) """ r = tree_map(lambda pred, tar: _bt_metric.sigmoid_binary_cross_entropy(pred, tar), @@ -872,16 +929,21 @@ def softmax_cross_entropy(logits, labels): For example, each CIFAR-10 image is labeled with one and only one label: an image can be a dog or a truck, but not both. - References: - [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) - - Args: - logits: unnormalized log probabilities. - labels: a valid probability distribution (non-negative, sum to 1), e.g a + Parameters + ---------- + logits + unnormalized log probabilities. + labels + a valid probability distribution (non-negative, sum to 1), e.g a one hot encoding of which class is the correct one for each input. - Returns: - the cross entropy loss. + Returns + ------- + the cross entropy loss. + + References + ---------- + [Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html) """ r = tree_map(lambda pred, tar: _bt_metric.softmax_cross_entropy(pred, tar), logits, @@ -896,16 +958,21 @@ def log_cosh_loss(predicts, targets): log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)` for large x. It is a twice differentiable alternative to the Huber loss. - References: - [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym) - - Args: - predicts: a vector of arbitrary shape. - targets: a vector of shape compatible with predictions; if not provided + Parameters + ---------- + predicts + a vector of arbitrary shape. + targets + a vector of shape compatible with predictions; if not provided then it is assumed to be zero. - Returns: - the log-cosh loss. + Returns + ------- + the log-cosh loss. + + References + ---------- + [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym) """ r = tree_map(lambda pred, tar: _bt_metric.log_cosh(pred, tar), @@ -941,35 +1008,44 @@ def ctc_loss_with_forward_probs( [Graves et al, 2006] that is blank-inserted representations of ``labels``. The return values are the logarithms of the above probabilities. - References: - [Graves et al, 2006](https://dl.acm.org/doi/abs/10.1145/1143844.1143891) - - Args: - logits: (B, T, K)-array containing logits of each class where B denotes + Parameters + ---------- + logits : ArrayType + (B, T, K)-array containing logits of each class where B denotes the batch size, T denotes the max time frames in ``logits``, and K denotes the number of classes including a class for blanks. - logit_paddings: (B, T)-array. Padding indicators for ``logits``. Each + logit_paddings : ArrayType + (B, T)-array. Padding indicators for ``logits``. Each element must be either 1.0 or 0.0, and ``logitpaddings[b, t] == 1.0`` denotes that ``logits[b, t, :]`` are padded values. - labels: (B, N)-array containing reference integer labels where N denotes + labels : ArrayType + (B, N)-array containing reference integer labels where N denotes the max time frames in the label sequence. - label_paddings: (B, N)-array. Padding indicators for ``labels``. Each + label_paddings : ArrayType + (B, N)-array. Padding indicators for ``labels``. Each element must be either 1.0 or 0.0, and ``labelpaddings[b, n] == 1.0`` denotes that ``labels[b, n]`` is a padded label. In the current implementation, ``labels`` must be right-padded, i.e. each row ``labelpaddings[b, :]`` must be repetition of zeroes, followed by repetition of ones. - blank_id: Id for blank token. ``logits[b, :, blank_id]`` are used as + blank_id : int + Id for blank token. ``logits[b, :, blank_id]`` are used as probabilities of blank symbols. - log_epsilon: Numerically-stable approximation of log(+0). - - Returns: - A tuple ``(loss_value, logalpha_blank, logalpha_nonblank)``. Here, - ``loss_value`` is a (B,)-array containing the loss values for each sequence - in the batch, ``logalpha_blank`` and ``logalpha_nonblank`` are - (T, B, N+1)-arrays where the (t, b, n)-th element denotes - \log \alpha_B(t, n) and \log \alpha_L(t, n), respectively, for ``b``-th - sequence in the batch. + log_epsilon : float + Numerically-stable approximation of log(+0). + + Returns + ------- + A tuple ``(loss_value, logalpha_blank, logalpha_nonblank)``. Here, + ``loss_value`` is a (B,)-array containing the loss values for each sequence + in the batch, ``logalpha_blank`` and ``logalpha_nonblank`` are + (T, B, N+1)-arrays where the (t, b, n)-th element denotes + \log \alpha_B(t, n) and \log \alpha_L(t, n), respectively, for ``b``-th + sequence in the batch. + + References + ---------- + [Graves et al, 2006](https://dl.acm.org/doi/abs/10.1145/1143844.1143891) """ return _bt_metric.ctc_loss_with_forward_probs( logits, logit_paddings, labels, label_paddings, @@ -986,27 +1062,35 @@ def ctc_loss(logits: ArrayType, See docstring for ``ctc_loss_with_forward_probs`` for details. - Args: - logits: (B, T, K)-array containing logits of each class where B denotes + Parameters + ---------- + logits : ArrayType + (B, T, K)-array containing logits of each class where B denotes the batch size, T denotes the max time frames in ``logits``, and K denotes the number of classes including a class for blanks. - logit_paddings: (B, T)-array. Padding indicators for ``logits``. Each + logit_paddings : ArrayType + (B, T)-array. Padding indicators for ``logits``. Each element must be either 1.0 or 0.0, and ``logitpaddings[b, t] == 1.0`` denotes that ``logits[b, t, :]`` are padded values. - labels: (B, N)-array containing reference integer labels where N denotes + labels : ArrayType + (B, N)-array containing reference integer labels where N denotes the max time frames in the label sequence. - label_paddings: (B, N)-array. Padding indicators for ``labels``. Each + label_paddings : ArrayType + (B, N)-array. Padding indicators for ``labels``. Each element must be either 1.0 or 0.0, and ``labelpaddings[b, n] == 1.0`` denotes that ``labels[b, n]`` is a padded label. In the current implementation, ``labels`` must be right-padded, i.e. each row ``labelpaddings[b, :]`` must be repetition of zeroes, followed by repetition of ones. - blank_id: Id for blank token. ``logits[b, :, blank_id]`` are used as + blank_id : int + Id for blank token. ``logits[b, :, blank_id]`` are used as probabilities of blank symbols. - log_epsilon: Numerically-stable approximation of log(+0). + log_epsilon : float + Numerically-stable approximation of log(+0). - Returns: - (B,)-array containing loss values for each sequence in the batch. + Returns + ------- + (B,)-array containing loss values for each sequence in the batch. """ return _bt_metric.ctc_loss( logits, logit_paddings, labels, label_paddings, @@ -1028,12 +1112,18 @@ def multi_margin_loss(predicts, targets, margin=1.0, p=1, reduction='mean'): and :math:`i \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}` and :math:`i \neq y`. - Args: - predicts: :math:`(N, C)` where `C = number of classes`. - target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`. - margin (float, optional): Has a default value of :math:`1`. - p (float, optional): Has a default value of :math:`1`. - reduction (str, optional): Specifies the reduction to apply to the output: + Parameters + ---------- + predicts + :math:`(N, C)` where `C = number of classes`. + target + :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`. + margin : float, optional + Has a default value of :math:`1`. + p : float, optional + Has a default value of :math:`1`. + reduction : str, optional + Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. @@ -1041,8 +1131,9 @@ def multi_margin_loss(predicts, targets, margin=1.0, p=1, reduction='mean'): and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` - Returns: - a scalar representing the multi-class margin loss. If `reduction` is ``'none'``, then :math:`(N)`. + Returns + ------- + a scalar representing the multi-class margin loss. If `reduction` is ``'none'``, then :math:`(N)`. """ assert p == 1 or p == 2, 'p should be 1 or 2' # Convert to plain JAX arrays: under JAX >= 0.9 implicit __jax_array__ diff --git a/brainpy/losses/regularization.py b/brainpy/losses/regularization.py index 5bb1b1842..2d62258a9 100644 --- a/brainpy/losses/regularization.py +++ b/brainpy/losses/regularization.py @@ -32,11 +32,14 @@ def l2_norm(x, axis=None): """Computes the L2 loss. - Args: - x: n-dimensional tensor of floats. - - Returns: - scalar tensor containing the l2 loss of x. + Parameters + ---------- + x + n-dimensional tensor of floats. + + Returns + ------- + scalar tensor containing the l2 loss of x. """ leaves, _ = tree_flatten(x) return jnp.sqrt(jnp.sum(jnp.asarray([jnp.vdot(x, x) for x in leaves]), axis=axis)) @@ -45,8 +48,9 @@ def l2_norm(x, axis=None): def mean_absolute(outputs, axis=None): r"""Computes the mean absolute error between x and y. - Returns: - tensor of shape (d_i, ..., for i in keep_axis) containing the mean absolute error. + Returns + ------- + tensor of shape (d_i, ..., for i in keep_axis) containing the mean absolute error. """ r = tree_map(lambda a: _bt_metric.absolute_error(a, None, axis=axis, reduction='mean'), outputs, is_leaf=_is_leaf) @@ -64,12 +68,19 @@ def log_cosh(errors): log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)` for large x. It is a twice differentiable alternative to the Huber loss. - References: - [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym) - Args: - errors: a vector of arbitrary shape. - Returns: - the log-cosh loss. + + Parameters + ---------- + errors + a vector of arbitrary shape. + + Returns + ------- + the log-cosh loss. + + References + ---------- + [Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym) """ r = tree_map(lambda a: _bt_metric.log_cosh(a), errors, is_leaf=_is_leaf) @@ -81,14 +92,22 @@ def smooth_labels(labels, alpha: float) -> jnp.ndarray: Label smoothing is often used in combination with a cross-entropy loss. Smoothed labels favour small logit gaps, and it has been shown that this can provide better model calibration by preventing overconfident predictions. - References: - [Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf) - Args: - labels: one hot labels to be smoothed. - alpha: the smoothing factor, the greedy category with be assigned + + Parameters + ---------- + labels + one hot labels to be smoothed. + alpha : float + the smoothing factor, the greedy category with be assigned probability `(1-alpha) + alpha / num_categories` - Returns: - a smoothed version of the one hot input labels. + + Returns + ------- + a smoothed version of the one hot input labels. + + References + ---------- + [Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf) """ r = tree_map(lambda tar: _bt_metric.smooth_labels(tar, alpha), labels, is_leaf=lambda x: isinstance(x, bm.Array)) diff --git a/brainpy/math/activations.py b/brainpy/math/activations.py index 37287ff84..3d6c11d9e 100644 --- a/brainpy/math/activations.py +++ b/brainpy/math/activations.py @@ -107,12 +107,12 @@ def celu(x, alpha=1.0): `Continuously Differentiable Exponential Linear Units `_. - Parameters:: - + Parameters + ---------- x : ArrayType - The input array. + The input array. alpha : ndarray, float - The default is 1.0. + The default is 1.0. """ x = x.value if isinstance(x, Array) else x alpha = alpha.value if isinstance(alpha, Array) else alpha @@ -130,12 +130,12 @@ def elu(x, alpha=1.0): \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases} - Parameters:: - - x: JaxArray, jnp.ndarray - The input array. + Parameters + ---------- + x : JaxArray, jnp.ndarray + The input array. alpha : scalar or Array - default: 1.0. + default: 1.0. """ x = x.value if isinstance(x, Array) else x alpha = alpha.value if isinstance(alpha, Array) else alpha @@ -161,12 +161,12 @@ def gelu(x, approximate=True): For more information, see `Gaussian Error Linear Units (GELUs) `_, section 2. - Parameters:: - - x: ArrayType - The input array. - approximate: bool - whether to use the approximate or exact formulation. + Parameters + ---------- + x : ArrayType + The input array. + approximate : bool + whether to use the approximate or exact formulation. """ x = x.value if isinstance(x, Array) else x # Promote integer / boolean inputs to a floating dtype before computing. @@ -190,12 +190,12 @@ def gelu(x, approximate=True): def glu(x, axis=-1): r"""Gated linear unit activation function. - Parameters:: - - x: ArrayType - The input array. - axis: int - The axis along which the split should be computed (default: -1) + Parameters + ---------- + x : ArrayType + The input array. + axis : int + The axis along which the split should be computed (default: -1) """ size = x.shape[axis] assert size % 2 == 0, "axis size must be divisible by 2" @@ -216,14 +216,14 @@ def hard_tanh(x, min_val=- 1.0, max_val=1.0): 1, & 1 < x \end{cases} - Parameters:: - - x: ArrayType - The input array. - min_val: float - minimum value of the linear region range. Default: -1 - max_val: float - maximum value of the linear region range. Default: 1 + Parameters + ---------- + x : ArrayType + The input array. + min_val : float + minimum value of the linear region range. Default: -1 + max_val : float + maximum value of the linear region range. Default: 1 """ x = x.value if isinstance(x, Array) else x return jnp.where(x > max_val, max_val, jnp.where(x < min_val, min_val, x)) @@ -237,10 +237,10 @@ def hard_sigmoid(x): .. math:: \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6} - Parameters:: - - x: ArrayType - The input array. + Parameters + ---------- + x : ArrayType + The input array. """ return relu6(x + 3.) / 6. @@ -263,10 +263,10 @@ def hard_silu(x): .. math:: \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x) - Parameters:: - - x: ArrayType - The input array. + Parameters + ---------- + x : ArrayType + The input array. """ return x * hard_sigmoid(x) @@ -287,9 +287,13 @@ def hard_shrink(x, lambd=0.5): 0, & \text{ otherwise } \end{cases} - Args: - lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5 + Parameters + ---------- + lambd + the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5 + Notes + ----- Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. @@ -312,12 +316,12 @@ def leaky_relu(x, negative_slope=1e-2): where :math:`\alpha` = :code:`negative_slope`. - Parameters:: - - x: ArrayType - The input array. + Parameters + ---------- + x : ArrayType + The input array. negative_slope : float - The scalar specifying the negative slope (default: 0.01) + The scalar specifying the negative slope (default: 0.01) """ x = x.value if isinstance(x, Array) else x return jnp.where(x >= 0, x, negative_slope * x) @@ -337,11 +341,14 @@ def softplus(x, beta: float = 1., threshold: float = 20.): For numerical stability the implementation reverts to the linear function when :math:`input \times \beta > threshold`. - Parameters:: - - x: The input array. - beta: the :math:`\beta` value for the Softplus formulation. Default: 1. - threshold: values above this revert to a linear function. Default: 20. + Parameters + ---------- + x + The input array. + beta : float + the :math:`\beta` value for the Softplus formulation. Default: 1. + threshold : float + values above this revert to a linear function. Default: 20. """ x = x.value if isinstance(x, Array) else x @@ -356,10 +363,10 @@ def log_sigmoid(x): .. math:: \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x}) - Parameters:: - - x: ArrayType - The input array. + Parameters + ---------- + x : ArrayType + The input array. """ return -softplus(-x) @@ -375,9 +382,13 @@ def soft_shrink(x, lambd=0.5): 0, & \text{ otherwise } \end{cases} - Args: - lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5 + Parameters + ---------- + lambd + the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5 + Notes + ----- Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. @@ -396,13 +407,13 @@ def log_softmax(x, axis=-1): \mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right) - Parameters:: - - x: ArrayType - The input array. - axis: int - The axis or axes along which the :code:`log_softmax` should be - computed. Either an integer or a tuple of integers. + Parameters + ---------- + x : ArrayType + The input array. + axis : int + The axis or axes along which the :code:`log_softmax` should be + computed. Either an integer or a tuple of integers. """ x = x.value if isinstance(x, Array) else x shifted = x - jax.lax.stop_gradient(x.max(axis, keepdims=True)) @@ -442,12 +453,17 @@ def one_hot(x, num_classes, *, dtype=None, axis=-1): Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32) - Args: - x: A tensor of indices. - num_classes: Number of classes in the one-hot dimension. - dtype: optional, a float dtype for the returned values (default float64 if + Parameters + ---------- + x + A tensor of indices. + num_classes + Number of classes in the one-hot dimension. + dtype + optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32). - axis: the axis or axes along which the function should be + axis + the axis or axes along which the function should be computed. """ num_classes = jax.core.concrete_or_error( @@ -511,8 +527,10 @@ def _relu(x: Array) -> Array: `Numerical influence of ReLU’(0) on backpropagation `_. - Args: - x : input array + Parameters + ---------- + x + input array """ return jnp.maximum(x, 0) @@ -529,10 +547,10 @@ def relu6(x): .. math:: \mathrm{relu6}(x) = \min(\max(x, 0), 6) - Parameters:: - - x: ArrayType - The input array. + Parameters + ---------- + x : ArrayType + The input array. """ x = x.value if isinstance(x, Array) else x return jnp.minimum(jnp.maximum(x, 0), 6.) @@ -558,10 +576,15 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333, ): See: https://arxiv.org/pdf/1505.00853.pdf - Args: - lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}` - upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}` + Parameters + ---------- + lower + lower bound of the uniform distribution. Default: :math:`\frac{1}{8}` + upper + upper bound of the uniform distribution. Default: :math:`\frac{1}{3}` + Notes + ----- Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. @@ -605,10 +628,10 @@ def sigmoid(x): .. math:: \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}} - Parameters:: - - x: ArrayType - The input array. + Parameters + ---------- + x : ArrayType + The input array. """ x = x.value if isinstance(x, Array) else x return jax.scipy.special.expit(x) @@ -622,10 +645,10 @@ def soft_sign(x): .. math:: \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1} - Parameters:: - - x: ArrayType - The input array. + Parameters + ---------- + x : ArrayType + The input array. """ x = x.value if isinstance(x, Array) else x return x / (jnp.abs(x) + 1) @@ -640,14 +663,14 @@ def softmax(x, axis=-1): .. math :: \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} - Parameters:: - - x: ArrayType - The input array. - axis: int - The axis or axes along which the softmax should be computed. The - softmax output summed across these dimensions should sum to :math:`1`. - Either an integer or a tuple of integers. + Parameters + ---------- + x : ArrayType + The input array. + axis : int + The axis or axes along which the softmax should be computed. The + softmax output summed across these dimensions should sum to :math:`1`. + Either an integer or a tuple of integers. """ x = x.value if isinstance(x, Array) else x unnormalized = jnp.exp(x - jax.lax.stop_gradient(x.max(axis, keepdims=True))) @@ -664,14 +687,18 @@ def softmin(x, axis=-1): .. math:: \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)} + Parameters + ---------- + axis : int + A dimension along which Softmin will be computed (so every slice + along dim will sum to 1). + + Notes + ----- Shape: - Input: :math:`(*)` where `*` means, any number of additional dimensions - Output: :math:`(*)`, same shape as the input - - Args: - axis (int): A dimension along which Softmin will be computed (so every slice - along dim will sum to 1). """ x = x.value if isinstance(x, Array) else x neg_x = -x @@ -690,10 +717,10 @@ def silu(x): .. math:: \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}} - Parameters:: - - x: ArrayType - The input array. + Parameters + ---------- + x : ArrayType + The input array. """ x = x.value if isinstance(x, Array) else x return x * sigmoid(x) @@ -713,6 +740,8 @@ def mish(x): .. note:: See `Mish: A Self Regularized Non-Monotonic Neural Activation Function `_ + Notes + ----- Shape: - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. @@ -739,10 +768,10 @@ def selu(x): `Self-Normalizing Neural Networks `_. - Parameters:: - - x: ArrayType - The input array. + Parameters + ---------- + x : ArrayType + The input array. """ alpha = 1.6732632423543772848170429916717 scale = 1.0507009873554804934193349852946 diff --git a/brainpy/math/compat_numpy.py b/brainpy/math/compat_numpy.py index 02e5de9f3..d8355311e 100644 --- a/brainpy/math/compat_numpy.py +++ b/brainpy/math/compat_numpy.py @@ -287,21 +287,25 @@ def msort(a): """ Return a copy of an array sorted along the first axis. - Parameters:: + Parameters + ---------- a : array_like Array to be sorted. - Returns:: + Returns + ------- sorted_array : ndarray Array of the same type and shape as `a`. - See Also:: + See Also + -------- sort - Notes:: + Notes + ----- ``brainpy.math.msort(a)`` is equivalent to ``brainpy.math.sort(a, axis=0)``. @@ -466,24 +470,28 @@ def shape(a): """ Return the shape of an array. - Parameters:: + Parameters + ---------- a : array_like Input array. - Returns:: + Returns + ------- shape : tuple of ints The elements of the shape tuple give the lengths of the corresponding array dimensions. - See Also:: + See Also + -------- len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with ``N>=1``. ndarray.shape : Equivalent array method. - Examples:: + Examples + -------- >>> import brainpy >>> brainpy.math.shape(brainpy.math.eye(3)) @@ -506,7 +514,8 @@ def size(a, axis=None): """ Return the number of elements along a given axis. - Parameters:: + Parameters + ---------- a : array_like Input data. @@ -514,18 +523,21 @@ def size(a, axis=None): Axis along which the elements are counted. By default, give the total number of elements. - Returns:: + Returns + ------- element_count : int Number of elements along the specified axis. - See Also:: + See Also + -------- shape : dimensions of array Array.shape : dimensions of array Array.size : number of elements in array - Examples:: + Examples + -------- >>> import brainpy >>> a = brainpy.math.array([[1,2,3], [4,5,6]]) diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py index 71e8671e6..8dab8add9 100644 --- a/brainpy/math/compat_pytorch.py +++ b/brainpy/math/compat_pytorch.py @@ -66,18 +66,18 @@ def flatten(input: Union[jax.Array, Array], .. note:: Flattening a zero-dimensional tensor will return a one-dimensional view. - Parameters:: - - input: Array - The input array. - start_dim: int - the first dim to flatten - end_dim: int - the last dim to flatten - - Returns:: - - out: Array + Parameters + ---------- + input : Array + The input array. + start_dim : int + the first dim to flatten + end_dim : int + the last dim to flatten + + Returns + ------- + out : Array """ input = _as_jax_array_(input) shape = input.shape @@ -105,17 +105,22 @@ def unflatten(x: Union[jax.Array, Array], dim: int, sizes: Sequence[int]) -> Arr """ Expands a dimension of the input tensor over multiple dimensions. - Args: - x: input tensor. - dim: Dimension to be unflattened, specified as an index into ``x.shape``. - sizes: New shape of the unflattened dimension. One of its elements can be -1 - in which case the corresponding output dimension is inferred. - Otherwise, the product of ``sizes`` must equal ``input.shape[dim]``. - - Returns: - A tensor with the same data as ``input``, but with ``dim`` split into multiple dimensions. - The returned tensor has one more dimension than the input tensor. - The returned tensor shares the same underlying data with this tensor. + Parameters + ---------- + x : Union[jax.Array, Array] + input tensor. + dim : int + Dimension to be unflattened, specified as an index into ``x.shape``. + sizes : Sequence[int] + New shape of the unflattened dimension. One of its elements can be -1 + in which case the corresponding output dimension is inferred. + Otherwise, the product of ``sizes`` must equal ``input.shape[dim]``. + + Returns + ------- + A tensor with the same data as ``input``, but with ``dim`` split into multiple dimensions. + The returned tensor has one more dimension than the input tensor. + The returned tensor shares the same underlying data with this tensor. """ x = _as_jax_array_(x) ndim = x.ndim @@ -140,16 +145,16 @@ def unsqueeze(x: Union[jax.Array, Array], dim: int) -> Array: A dim value within the range ``[-input.dim() - 1, input.dim() + 1)`` can be used. Negative dim will correspond to unsqueeze() applied at ``dim = dim + input.dim() + 1``. - Parameters:: - - x: Array - The input Array - dim: int - The index at which to insert the singleton dimension - - Returns:: + Parameters + ---------- + x : Array + The input Array + dim : int + The index at which to insert the singleton dimension - out: Array + Returns + ------- + out : Array """ x = _as_jax_array_(x) r = jnp.expand_dims(x, dim) diff --git a/brainpy/math/compat_tensorflow.py b/brainpy/math/compat_tensorflow.py index 89fe406ad..fdd580a2e 100644 --- a/brainpy/math/compat_tensorflow.py +++ b/brainpy/math/compat_tensorflow.py @@ -66,15 +66,20 @@ def reduce_logsumexp(input_tensor, axis=None, keepdims=False): overflows caused by taking the exp of large inputs and underflows caused by taking the log of small inputs. - Args: - input_tensor: The tensor to reduce. Should have numeric type. - axis: The dimensions to reduce. If `None` (the default), reduces all + Parameters + ---------- + input_tensor + The tensor to reduce. Should have numeric type. + axis + The dimensions to reduce. If `None` (the default), reduces all dimensions. Must be in the range `[-rank(input_tensor), rank(input_tensor))`. - keepdims: If true, retains reduced dimensions with length 1. + keepdims + If true, retains reduced dimensions with length 1. - Returns: - The reduced tensor. + Returns + ------- + The reduced tensor. """ r = jax.scipy.special.logsumexp(_as_jax_array_(input_tensor), axis=axis, keepdims=keepdims) return _return(r) @@ -91,15 +96,20 @@ def reduce_euclidean_norm(input_tensor, axis=None, keepdims=False): If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. - Args: - input_tensor: The tensor to reduce. Should have numeric type. - axis: The dimensions to reduce. If `None` (the default), reduces all + Parameters + ---------- + input_tensor + The tensor to reduce. Should have numeric type. + axis + The dimensions to reduce. If `None` (the default), reduces all dimensions. Must be in the range `[-rank(input_tensor), rank(input_tensor))`. - keepdims: If true, retains reduced dimensions with length 1. + keepdims + If true, retains reduced dimensions with length 1. - Returns: - The reduced tensor, of the same dtype as the input_tensor. + Returns + ------- + The reduced tensor, of the same dtype as the input_tensor. """ r = jnp.linalg.norm(_as_jax_array_(input_tensor), axis=axis, keepdims=keepdims) return _return(r) @@ -118,15 +128,20 @@ def reduce_max(input_tensor, axis=None, keepdims=False): If `axis` is None, all dimensions are reduced, and a tensor with a single element is returned. - Args: - input_tensor: The tensor to reduce. Should have real numeric type. - axis: The dimensions to reduce. If `None` (the default), reduces all + Parameters + ---------- + input_tensor + The tensor to reduce. Should have real numeric type. + axis + The dimensions to reduce. If `None` (the default), reduces all dimensions. Must be in the range `[-rank(input_tensor), rank(input_tensor))`. - keepdims: If true, retains reduced dimensions with length 1. + keepdims + If true, retains reduced dimensions with length 1. - Returns: - The reduced tensor. + Returns + ------- + The reduced tensor. """ return _return(jnp.max(_as_jax_array_(input_tensor), axis=axis, keepdims=keepdims)) @@ -239,39 +254,39 @@ def segment_sum(data: Union[Array, jnp.ndarray], mode: Optional[lax.GatherScatterMode] = None) -> Array: """``segment_sum`` operator for brainpy `Array` and `Variable`. - Parameters:: - - data: Array - An array with the values to be reduced. - segment_ids: Array - An array with integer dtype that indicates the segments of - `data` (along its leading axis) to be summed. Values can be repeated and - need not be sorted. - num_segments: Optional, int - An int with nonnegative value indicating the number - of segments. The default is set to be the minimum number of segments that - would support all indices in ``segment_ids``, calculated as - ``max(segment_ids) + 1``. - Since `num_segments` determines the size of the output, a static value - must be provided to use ``segment_sum`` in a ``jit``-compiled function. - indices_are_sorted: bool - whether ``segment_ids`` is known to be sorted. - unique_indices: bool - whether `segment_ids` is known to be free of duplicates. - bucket_size: int - Size of bucket to group indices into. ``segment_sum`` is - performed on each bucket separately to improve numerical stability of - addition. Default ``None`` means no bucketing. - mode: lax.GatherScatterMode - A :class:`jax.lax.GatherScatterMode` value describing how - out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. - - Returns:: - - output: Array - An array with shape :code:`(num_segments,) + data.shape[1:]` representing the - segment sums. + Parameters + ---------- + data : Array + An array with the values to be reduced. + segment_ids : Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. Values can be repeated and + need not be sorted. + num_segments : Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. + Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted : bool + whether ``segment_ids`` is known to be sorted. + unique_indices : bool + whether `segment_ids` is known to be free of duplicates. + bucket_size : int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + mode : lax.GatherScatterMode + A :class:`jax.lax.GatherScatterMode` value describing how + out-of-bounds indices should be handled. By default, values outside of the + range [0, num_segments) are dropped and do not contribute to the sum. + + Returns + ------- + output : Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. """ return _return(jax.ops.segment_sum(as_jax(data), as_jax(segment_ids), @@ -291,39 +306,39 @@ def segment_prod(data: Union[Array, jnp.ndarray], mode: Optional[lax.GatherScatterMode] = None) -> Array: """``segment_prod`` operator for brainpy `Array` and `Variable`. - Parameters:: - - data: Array - An array with the values to be reduced. - segment_ids: Array - An array with integer dtype that indicates the segments of - `data` (along its leading axis) to be summed. Values can be repeated and - need not be sorted. - num_segments: Optional, int - An int with nonnegative value indicating the number - of segments. The default is set to be the minimum number of segments that - would support all indices in ``segment_ids``, calculated as - ``max(segment_ids) + 1``. - Since `num_segments` determines the size of the output, a static value - must be provided to use ``segment_sum`` in a ``jit``-compiled function. - indices_are_sorted: bool - whether ``segment_ids`` is known to be sorted. - unique_indices: bool - whether `segment_ids` is known to be free of duplicates. - bucket_size: int - Size of bucket to group indices into. ``segment_sum`` is - performed on each bucket separately to improve numerical stability of - addition. Default ``None`` means no bucketing. - mode: lax.GatherScatterMode - A :class:`jax.lax.GatherScatterMode` value describing how - out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. - - Returns:: - - output: Array - An array with shape :code:`(num_segments,) + data.shape[1:]` representing the - segment sums. + Parameters + ---------- + data : Array + An array with the values to be reduced. + segment_ids : Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. Values can be repeated and + need not be sorted. + num_segments : Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. + Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted : bool + whether ``segment_ids`` is known to be sorted. + unique_indices : bool + whether `segment_ids` is known to be free of duplicates. + bucket_size : int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + mode : lax.GatherScatterMode + A :class:`jax.lax.GatherScatterMode` value describing how + out-of-bounds indices should be handled. By default, values outside of the + range [0, num_segments) are dropped and do not contribute to the sum. + + Returns + ------- + output : Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. """ return _return(jax.ops.segment_prod(as_jax(data), as_jax(segment_ids), @@ -343,39 +358,39 @@ def segment_max(data: Union[Array, jnp.ndarray], mode: Optional[lax.GatherScatterMode] = None) -> Array: """``segment_max`` operator for brainpy `Array` and `Variable`. - Parameters:: - - data: Array - An array with the values to be reduced. - segment_ids: Array - An array with integer dtype that indicates the segments of - `data` (along its leading axis) to be summed. Values can be repeated and - need not be sorted. - num_segments: Optional, int - An int with nonnegative value indicating the number - of segments. The default is set to be the minimum number of segments that - would support all indices in ``segment_ids``, calculated as - ``max(segment_ids) + 1``. - Since `num_segments` determines the size of the output, a static value - must be provided to use ``segment_sum`` in a ``jit``-compiled function. - indices_are_sorted: bool - whether ``segment_ids`` is known to be sorted. - unique_indices: bool - whether `segment_ids` is known to be free of duplicates. - bucket_size: int - Size of bucket to group indices into. ``segment_sum`` is - performed on each bucket separately to improve numerical stability of - addition. Default ``None`` means no bucketing. - mode: lax.GatherScatterMode - A :class:`jax.lax.GatherScatterMode` value describing how - out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. - - Returns:: - - output: Array - An array with shape :code:`(num_segments,) + data.shape[1:]` representing the - segment sums. + Parameters + ---------- + data : Array + An array with the values to be reduced. + segment_ids : Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. Values can be repeated and + need not be sorted. + num_segments : Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. + Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted : bool + whether ``segment_ids`` is known to be sorted. + unique_indices : bool + whether `segment_ids` is known to be free of duplicates. + bucket_size : int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + mode : lax.GatherScatterMode + A :class:`jax.lax.GatherScatterMode` value describing how + out-of-bounds indices should be handled. By default, values outside of the + range [0, num_segments) are dropped and do not contribute to the sum. + + Returns + ------- + output : Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. """ return _return(jax.ops.segment_max(as_jax(data), as_jax(segment_ids), @@ -395,39 +410,39 @@ def segment_min(data: Union[Array, jnp.ndarray], mode: Optional[lax.GatherScatterMode] = None) -> Array: """``segment_min`` operator for brainpy `Array` and `Variable`. - Parameters:: - - data: Array - An array with the values to be reduced. - segment_ids: Array - An array with integer dtype that indicates the segments of - `data` (along its leading axis) to be summed. Values can be repeated and - need not be sorted. - num_segments: Optional, int - An int with nonnegative value indicating the number - of segments. The default is set to be the minimum number of segments that - would support all indices in ``segment_ids``, calculated as - ``max(segment_ids) + 1``. - Since `num_segments` determines the size of the output, a static value - must be provided to use ``segment_sum`` in a ``jit``-compiled function. - indices_are_sorted: bool - whether ``segment_ids`` is known to be sorted. - unique_indices: bool - whether `segment_ids` is known to be free of duplicates. - bucket_size: int - Size of bucket to group indices into. ``segment_sum`` is - performed on each bucket separately to improve numerical stability of - addition. Default ``None`` means no bucketing. - mode: lax.GatherScatterMode - A :class:`jax.lax.GatherScatterMode` value describing how - out-of-bounds indices should be handled. By default, values outside of the - range [0, num_segments) are dropped and do not contribute to the sum. - - Returns:: - - output: Array - An array with shape :code:`(num_segments,) + data.shape[1:]` representing the - segment sums. + Parameters + ---------- + data : Array + An array with the values to be reduced. + segment_ids : Array + An array with integer dtype that indicates the segments of + `data` (along its leading axis) to be summed. Values can be repeated and + need not be sorted. + num_segments : Optional, int + An int with nonnegative value indicating the number + of segments. The default is set to be the minimum number of segments that + would support all indices in ``segment_ids``, calculated as + ``max(segment_ids) + 1``. + Since `num_segments` determines the size of the output, a static value + must be provided to use ``segment_sum`` in a ``jit``-compiled function. + indices_are_sorted : bool + whether ``segment_ids`` is known to be sorted. + unique_indices : bool + whether `segment_ids` is known to be free of duplicates. + bucket_size : int + Size of bucket to group indices into. ``segment_sum`` is + performed on each bucket separately to improve numerical stability of + addition. Default ``None`` means no bucketing. + mode : lax.GatherScatterMode + A :class:`jax.lax.GatherScatterMode` value describing how + out-of-bounds indices should be handled. By default, values outside of the + range [0, num_segments) are dropped and do not contribute to the sum. + + Returns + ------- + output : Array + An array with shape :code:`(num_segments,) + data.shape[1:]` representing the + segment sums. """ return _return(jax.ops.segment_min(as_jax(data), as_jax(segment_ids), @@ -455,14 +470,19 @@ def cast(x, dtype): Note casting nan and inf values to integral types has undefined behavior. - Args: - x: A `Array`. It could be `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, + Parameters + ---------- + x + A `Array`. It could be `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`, `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`. - dtype: The destination type. The list of supported dtypes is the same as + dtype + The destination type. The list of supported dtypes is the same as `x`. - Returns: - A `Array` with same shape as `x` and same type as `dtype`. + + Returns + ------- + A `Array` with same shape as `x` and same type as `dtype`. """ return asarray(x, dtype=dtype) diff --git a/brainpy/math/delayvars.py b/brainpy/math/delayvars.py index 10d69b420..09250402c 100644 --- a/brainpy/math/delayvars.py +++ b/brainpy/math/delayvars.py @@ -111,25 +111,26 @@ class TimeDelay(AbstractDelay): [[-0.8] [-0.8]]] - Parameters:: + Parameters + ---------- - delay_target: ArrayType + delay_target : ArrayType The initial delay data. - t0: float, int + t0 : float, int The zero time. - delay_len: float, int + delay_len : float, int The maximum delay length. - dt: float, int + dt : float, int The time precesion. - before_t0: callable, bm.ndarray, jnp.ndarray, float, int + before_t0 : callable, bm.ndarray, jnp.ndarray, float, int The delay data before ::math`t_0`. - when `before_t0` is a function, it should receive a time argument `t` - when `before_to` is a tensor, it should be a tensor with shape of :math:`(num_delay, ...)`, where the longest delay data is aranged in the first index. - name: str + name : str The delay instance name. - interp_method: str + interp_method : str The way to deal with the delay at the time which is not integer times of the time step. For exameple, if the time step ``dt=0.1``, the time delay length ``delay\_len=1.``, when users require the delay data at ``t-0.53``, we can deal this situation with @@ -143,7 +144,8 @@ class TimeDelay(AbstractDelay): .. versionadded:: 2.1.1 - See Also:: + See Also + -------- LengthDelay """ @@ -215,15 +217,16 @@ def reset(self, before_t0=None): """Reset the delay variable. - Parameters:: + Parameters + ---------- - delay_target: ArrayType + delay_target : ArrayType The delay target. - delay_len: float, int + delay_len : float, int The maximum delay length. The unit is the time. - t0: int, float + t0 : int, float The zero time. - before_t0: callable, int, float, ArrayType + before_t0 : callable, int, float, ArrayType The data before t0. - when ``before_t0`` is a function, it should receive a time argument ``t`` (mirroring the behaviour of ``__init__``). @@ -324,13 +327,14 @@ class NeuTimeDelay(TimeDelay): class LengthDelay(AbstractDelay): """Delay variable which has a fixed delay length. - Parameters:: + Parameters + ---------- - delay_target: int, sequence of int + delay_target : int, sequence of int The initial delay data. - delay_len: int + delay_len : int The maximum delay length. - initial_delay_data: Any + initial_delay_data : Any The delay data. It can be a Python number, like float, int, boolean values. It can also be arrays. Or a callable function or instance of ``Connector``. A callable will be invoked as ``initial_delay_data(shape, dtype=...)`` when its @@ -358,14 +362,15 @@ class LengthDelay(AbstractDelay): delay = 1 data ] - name: str + name : str The delay object name. - batch_axis: int + batch_axis : int The batch axis. If not provided, it will be inferred from the `delay_target`. - update_method: str + update_method : str The method used for updating delay. - See Also:: + See Also + -------- TimeDelay """ @@ -481,9 +486,10 @@ def __call__(self, delay_len, *indices): def retrieve(self, delay_len, *indices): """Retrieve the delay data acoording to the delay length. - Parameters:: + Parameters + ---------- - delay_len: int, ArrayType + delay_len : int, ArrayType The delay length used to retrieve the data. """ if check.is_checking(): @@ -511,9 +517,10 @@ def retrieve(self, delay_len, *indices): def update(self, value: Union[numbers.Number, Array, jax.Array] = None): """Update delay variable with the new data. - Parameters:: + Parameters + ---------- - value: Any + value : Any The value of the latest data, used to update this delay variable. """ if value is None: diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py index 3bf28d251..ba4c5b268 100644 --- a/brainpy/math/environment.py +++ b/brainpy/math/environment.py @@ -365,28 +365,28 @@ def set( ): """Set the default computation environment. - Parameters:: - - mode: Mode - The computing mode. - membrane_scaling: Scaling - The numerical membrane_scaling. - dt: float - The numerical integration precision. - x64: bool - Enable x64 computation. - complex_: type - The complex data type. - float_ - The floating data type. - int_ - The integer data type. - bool_ - The bool data type. - bp_object_as_pytree: bool - Whether to register brainpy object as pytree. - numpy_func_return: str - The array to return in all numpy functions. Support 'bp_array' and 'jax_array'. + Parameters + ---------- + mode : Mode + The computing mode. + membrane_scaling : Scaling + The numerical membrane_scaling. + dt : float + The numerical integration precision. + x64 : bool + Enable x64 computation. + complex_ : type + The complex data type. + float_ : type + The floating data type. + int_ : type + The integer data type. + bool_ : type + The bool data type. + bp_object_as_pytree : bool + Whether to register brainpy object as pytree. + numpy_func_return : str + The array to return in all numpy functions. Support 'bp_array' and 'jax_array'. """ # Validate all arguments BEFORE mutating any global state, so that an # invalid argument cannot leave the environment in a half-updated state. @@ -475,10 +475,10 @@ def dftype(): def set_float(dtype: type): """Set global default float type. - Parameters:: - - dtype: type - The float type. + Parameters + ---------- + dtype : type + The float type. """ defaults.float_ = dtype @@ -486,10 +486,10 @@ def set_float(dtype: type): def get_float(): """Get the default float data type. - Returns:: - - dftype: type - The default float data type. + Returns + ------- + dftype : type + The default float data type. """ return defaults.float_ @@ -497,10 +497,10 @@ def get_float(): def set_int(dtype: type): """Set global default integer type. - Parameters:: - - dtype: type - The integer type. + Parameters + ---------- + dtype : type + The integer type. """ defaults.int_ = dtype @@ -508,10 +508,10 @@ def set_int(dtype: type): def get_int(): """Get the default int data type. - Returns:: - - dftype: type - The default int data type. + Returns + ------- + dftype : type + The default int data type. """ return defaults.int_ @@ -519,10 +519,10 @@ def get_int(): def set_bool(dtype: type): """Set global default boolean type. - Parameters:: - - dtype: type - The bool type. + Parameters + ---------- + dtype : type + The bool type. """ defaults.bool_ = dtype @@ -530,10 +530,10 @@ def set_bool(dtype: type): def get_bool(): """Get the default boolean data type. - Returns:: - - dftype: type - The default bool data type. + Returns + ------- + dftype : type + The default bool data type. """ return defaults.bool_ @@ -541,10 +541,10 @@ def get_bool(): def set_complex(dtype: type): """Set global default complex type. - Parameters:: - - dtype: type - The complex type. + Parameters + ---------- + dtype : type + The complex type. """ defaults.complex_ = dtype @@ -552,10 +552,10 @@ def set_complex(dtype: type): def get_complex(): """Get the default complex data type. - Returns:: - - dftype: type - The default complex data type. + Returns + ------- + dftype : type + The default complex data type. """ return defaults.complex_ @@ -566,8 +566,8 @@ def get_complex(): def set_dt(dt): """Set the default numerical integrator precision. - Parameters:: - + Parameters + ---------- dt : float Numerical integration precision. """ @@ -578,8 +578,8 @@ def set_dt(dt): def get_dt(): """Get the numerical integrator precision. - Returns:: - + Returns + ------- dt : float Numerical integration precision. """ @@ -589,10 +589,10 @@ def get_dt(): def set_mode(mode: modes.Mode): """Set the default computing mode. - Parameters:: - - mode: Mode - The instance of :py:class:`~.Mode`. + Parameters + ---------- + mode : Mode + The instance of :py:class:`~.Mode`. """ if not isinstance(mode, modes.Mode): raise TypeError(f'Must be instance of brainpy.math.Mode. ' @@ -603,10 +603,10 @@ def set_mode(mode: modes.Mode): def get_mode() -> modes.Mode: """Get the default computing mode. - References:: - - mode: Mode - The default computing mode. + References + ---------- + mode : Mode + The default computing mode. """ return defaults.mode @@ -614,10 +614,10 @@ def get_mode() -> modes.Mode: def set_membrane_scaling(membrane_scaling: scales.Scaling): """Set the default computing membrane_scaling. - Parameters:: - - scaling: Scaling - The instance of :py:class:`~.Scaling`. + Parameters + ---------- + scaling : Scaling + The instance of :py:class:`~.Scaling`. """ if not isinstance(membrane_scaling, scales.Scaling): raise TypeError(f'Must be instance of brainpy.math.Scaling. ' @@ -628,10 +628,10 @@ def set_membrane_scaling(membrane_scaling: scales.Scaling): def get_membrane_scaling() -> scales.Scaling: """Get the default computing membrane_scaling. - Returns:: - - membrane_scaling: Scaling - The default computing membrane_scaling. + Returns + ------- + membrane_scaling : Scaling + The default computing membrane_scaling. """ return defaults.membrane_scaling @@ -684,10 +684,10 @@ def set_platform(platform: str): def get_platform() -> str: """Get the computing platform. - Returns:: - - platform: str - Either 'cpu', 'gpu' or 'tpu'. + Returns + ------- + platform : str + Either 'cpu', 'gpu' or 'tpu'. """ return devices()[0].platform @@ -709,7 +709,10 @@ def set_host_device_count(n): know through our issue or forum page. More information is available in this `JAX issue `_. - :param int n: number of devices to use. + Parameters + ---------- + n : int + number of devices to use. """ xla_flags = os.getenv("XLA_FLAGS", "") xla_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split() @@ -733,18 +736,18 @@ def clear_buffer_memory( This operation may cause errors when you use a deleted buffer. Therefore, regenerate data always. - Parameters:: - - platform: str - The device to clear its memory. - array: bool - Clear all buffer array. Default is True. - compilation: bool - Clear compilation cache. Default is False. - transform: bool - Clear transform cache. Default is True. - object_name: bool - Clear name cache. Default is True. + Parameters + ---------- + platform : str + The device to clear its memory. + array : bool + Clear all buffer array. Default is True. + compilation : bool + Clear compilation cache. Default is False. + transform : bool + Clear transform cache. Default is True. + object_name : bool + Clear name cache. Default is True. """ from brainstate._compatible_import import get_backend @@ -769,8 +772,10 @@ def disable_gpu_memory_preallocation(release_memory: bool = True): GPU memory fragmentation, meaning a JAX program that uses most of the available GPU memory may OOM with preallocation disabled. - Args: - release_memory: bool. Whether we release memory during the computation. + Parameters + ---------- + release_memory : bool + Whether we release memory during the computation. """ os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' if release_memory: diff --git a/brainpy/math/event/csr_matmat.py b/brainpy/math/event/csr_matmat.py index a8084328d..02303da5b 100644 --- a/brainpy/math/event/csr_matmat.py +++ b/brainpy/math/event/csr_matmat.py @@ -36,18 +36,20 @@ def csrmm( ): """Product of CSR sparse matrix and a dense event matrix. - Args: - data : array of shape ``(nse,)``, float. - indices : array of shape ``(nse,)`` - indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype`` - B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and + Parameters + ---------- + data : array of shape ``(nse,)``, float. + indices : array of shape ``(nse,)`` + indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype`` + B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and dtype ``data.dtype`` - shape : length-2 tuple representing the matrix shape - transpose : boolean specifying whether to transpose the sparse matrix + shape : length-2 tuple representing the matrix shape + transpose : boolean specifying whether to transpose the sparse matrix before computing. - Returns: - C : array of shape ``(shape[1] if transpose else shape[0], cols)`` + Returns + ------- + C : array of shape ``(shape[1] if transpose else shape[0], cols)`` representing the matrix-matrix product product. """ if isinstance(data, Array): diff --git a/brainpy/math/event/csr_matvec.py b/brainpy/math/event/csr_matvec.py index 8be452447..dcac8974e 100644 --- a/brainpy/math/event/csr_matvec.py +++ b/brainpy/math/event/csr_matvec.py @@ -49,26 +49,28 @@ def csrmv( This function supports JAX transformations, including `jit()`, `grad()`, `vmap()` and `pmap()`. - Parameters:: + Parameters + ---------- - data: ndarray, float + data : ndarray, float An array of shape ``(nse,)``. - indices: ndarray + indices : ndarray An array of shape ``(nse,)``. - indptr: ndarray + indptr : ndarray An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - events: ndarray + events : ndarray An array of shape ``(shape[0] if transpose else shape[1],)`` and dtype ``data.dtype``. - shape: tuple + shape : tuple A length-2 tuple representing the matrix shape. - transpose: bool + transpose : bool A boolean specifying whether to transpose the sparse matrix before computing. If ``transpose=True``, the operator will compute based on the event-driven property of the ``events`` vector. - Returns:: + Returns + ------- y : Array The array of shape ``(shape[1] if transpose else shape[0],)`` representing diff --git a/brainpy/math/interoperability.py b/brainpy/math/interoperability.py index c5287e342..824322af7 100644 --- a/brainpy/math/interoperability.py +++ b/brainpy/math/interoperability.py @@ -39,16 +39,18 @@ def is_bp_array(x): def as_device_array(tensor, dtype=None): """Convert the input to a ``jax.numpy.DeviceArray``. - Parameters:: + Parameters + ---------- - tensor: array_like + tensor : array_like Input data, in any form that can be converted to an array. This includes lists, lists of tuples, tuples, tuples of tuples, tuples of lists, ArrayType. - dtype: data-type, optional + dtype : data-type, optional By default, the data-type is inferred from the input data. - Returns:: + Returns + ------- out : ArrayType Array interpretation of `tensor`. No copy is performed if the input @@ -70,16 +72,18 @@ def as_device_array(tensor, dtype=None): def as_ndarray(tensor, dtype=None): """Convert the input to a ``numpy.ndarray``. - Parameters:: + Parameters + ---------- - tensor: array_like + tensor : array_like Input data, in any form that can be converted to an array. This includes lists, lists of tuples, tuples, tuples of tuples, tuples of lists, ArrayType. - dtype: data-type, optional + dtype : data-type, optional By default, the data-type is inferred from the input data. - Returns:: + Returns + ------- out : ndarray Array interpretation of `tensor`. No copy is performed if the input @@ -97,16 +101,18 @@ def as_ndarray(tensor, dtype=None): def as_variable(tensor, dtype=None): """Convert the input to a ``brainpy.math.Variable``. - Parameters:: + Parameters + ---------- - tensor: array_like + tensor : array_like Input data, in any form that can be converted to an array. This includes lists, lists of tuples, tuples, tuples of tuples, tuples of lists, ArrayType. - dtype: data-type, optional + dtype : data-type, optional By default, the data-type is inferred from the input data. - Returns:: + Returns + ------- out : ndarray Array interpretation of `tensor`. No copy is performed if the input diff --git a/brainpy/math/jitconn/matvec.py b/brainpy/math/jitconn/matvec.py index 89442333c..deaa33d05 100644 --- a/brainpy/math/jitconn/matvec.py +++ b/brainpy/math/jitconn/matvec.py @@ -65,17 +65,18 @@ def mv_prob_homo( matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of the speed compared with ``outdim_parallel=False``. - Parameters:: + Parameters + ---------- - vector: Array, ndarray + vector : Array, ndarray The vector. - weight: float + weight : float The value of the random matrix. - conn_prob: float + conn_prob : float The connection probability. - shape: tuple of int + shape : tuple of int The matrix shape. - seed: int + seed : int The random number generation seed. .. warning:: @@ -88,16 +89,17 @@ def mv_prob_homo( jitted call reuses that single seed. For reproducible and correct behaviour under JAX transformations, always pass an explicit integer ``seed``. - transpose: bool + transpose : bool Transpose the random matrix or not. - outdim_parallel: bool + outdim_parallel : bool Perform the parallel random generations along the out dimension or not. It can be used to set the just-in-time generated :math:M^T: is the same as the just-in-time generated :math:`M` when ``transpose=True``. - Returns:: + Returns + ------- - out: Array, ndarray + out : Array, ndarray The output of :math:`y = M @ v`. """ if seed is None: @@ -149,19 +151,20 @@ def mv_prob_uniform( matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of the speed compared with ``outdim_parallel=False``. - Parameters:: + Parameters + ---------- - vector: Array, ndarray + vector : Array, ndarray The vector. - w_low: float + w_low : float Lower boundary of the output interval. - w_high: float + w_high : float Upper boundary of the output interval. - conn_prob: float + conn_prob : float The connection probability. - shape: tuple of int + shape : tuple of int The matrix shape. - seed: int + seed : int The random number generation seed. .. warning:: @@ -174,16 +177,17 @@ def mv_prob_uniform( jitted call reuses that single seed. For reproducible and correct behaviour under JAX transformations, always pass an explicit integer ``seed``. - transpose: bool + transpose : bool Transpose the random matrix or not. - outdim_parallel: bool + outdim_parallel : bool Perform the parallel random generations along the out dimension or not. It can be used to set the just-in-time generated :math:M^T: is the same as the just-in-time generated :math:`M` when ``transpose=True``. - Returns:: + Returns + ------- - out: Array, ndarray + out : Array, ndarray The output of :math:`y = M @ v`. """ if seed is None: @@ -237,19 +241,20 @@ def mv_prob_normal( matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of the speed compared with ``outdim_parallel=False``. - Parameters:: + Parameters + ---------- - vector: Array, ndarray + vector : Array, ndarray The vector. - w_mu: float + w_mu : float Mean (centre) of the distribution. - w_sigma: float + w_sigma : float Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float + conn_prob : float The connection probability. - shape: tuple of int + shape : tuple of int The matrix shape. - seed: int + seed : int The random number generation seed. .. warning:: @@ -262,16 +267,17 @@ def mv_prob_normal( jitted call reuses that single seed. For reproducible and correct behaviour under JAX transformations, always pass an explicit integer ``seed``. - transpose: bool + transpose : bool Transpose the random matrix or not. - outdim_parallel: bool + outdim_parallel : bool Perform the parallel random generations along the out dimension or not. It can be used to set the just-in-time generated :math:M^T: is the same as the just-in-time generated :math:`M` when ``transpose=True``. - Returns:: + Returns + ------- - out: Array, ndarray + out : Array, ndarray The output of :math:`y = M @ v`. """ if seed is None: @@ -301,13 +307,14 @@ def get_homo_weight_matrix( ) -> jax.Array: r"""Get the connection matrix :math:`M` with a connection probability `conn_prob`. - Parameters:: + Parameters + ---------- - conn_prob: float + conn_prob : float The connection probability. - shape: tuple of int + shape : tuple of int The matrix shape. - seed: int + seed : int The random number generation seed. .. warning:: @@ -320,16 +327,17 @@ def get_homo_weight_matrix( jitted call reuses that single seed. For reproducible and correct behaviour under JAX transformations, always pass an explicit integer ``seed``. - transpose: bool + transpose : bool Transpose the random matrix or not. - outdim_parallel: bool + outdim_parallel : bool Perform the parallel random generations along the out dimension or not. It can be used to set the just-in-time generated :math:M^T: is the same as the just-in-time generated :math:`M` when ``transpose=True``. - Returns:: + Returns + ------- - out: Array, ndarray + out : Array, ndarray The connection matrix :math:`M`. """ if seed is None: @@ -352,17 +360,18 @@ def get_uniform_weight_matrix( ) -> jax.Array: r"""Get the weight matrix :math:`M` with a uniform distribution for its value. - Parameters:: + Parameters + ---------- - w_low: float + w_low : float Lower boundary of the output interval. - w_high: float + w_high : float Upper boundary of the output interval. - conn_prob: float + conn_prob : float The connection probability. - shape: tuple of int + shape : tuple of int The matrix shape. - seed: int + seed : int The random number generation seed. .. warning:: @@ -375,16 +384,17 @@ def get_uniform_weight_matrix( jitted call reuses that single seed. For reproducible and correct behaviour under JAX transformations, always pass an explicit integer ``seed``. - transpose: bool + transpose : bool Transpose the random matrix or not. - outdim_parallel: bool + outdim_parallel : bool Perform the parallel random generations along the out dimension or not. It can be used to set the just-in-time generated :math:M^T: is the same as the just-in-time generated :math:`M` when ``transpose=True``. - Returns:: + Returns + ------- - out: Array, ndarray + out : Array, ndarray The weight matrix :math:`M`. """ if seed is None: @@ -412,15 +422,16 @@ def get_normal_weight_matrix( ) -> jax.Array: r"""Get the weight matrix :math:`M` with a normal distribution for its value. - Parameters:: + Parameters + ---------- - w_mu: float + w_mu : float Mean (centre) of the distribution. - w_sigma: float + w_sigma : float Standard deviation (spread or “width”) of the distribution. Must be non-negative. - shape: tuple of int + shape : tuple of int The matrix shape. - seed: int + seed : int The random number generation seed. .. warning:: @@ -433,16 +444,17 @@ def get_normal_weight_matrix( jitted call reuses that single seed. For reproducible and correct behaviour under JAX transformations, always pass an explicit integer ``seed``. - transpose: bool + transpose : bool Transpose the random matrix or not. - outdim_parallel: bool + outdim_parallel : bool Perform the parallel random generations along the out dimension or not. It can be used to set the just-in-time generated :math:M^T: is the same as the just-in-time generated :math:`M` when ``transpose=True``. - Returns:: + Returns + ------- - out: Array, ndarray + out : Array, ndarray The weight matrix :math:`M`. """ if seed is None: diff --git a/brainpy/math/ndarray.py b/brainpy/math/ndarray.py index 1d4e66f35..cdafb34cc 100644 --- a/brainpy/math/ndarray.py +++ b/brainpy/math/ndarray.py @@ -200,8 +200,10 @@ def device_buffer(self): def fill_(self, fill_value): """Fill the array with a scalar value. - Args: - fill_value: the scalar value to fill the array. + Parameters + ---------- + fill_value + the scalar value to fill the array. """ if isinstance(fill_value, Array): fill_value = fill_value.value @@ -231,10 +233,14 @@ class ShardedArray(Array): A drawback of sharding is that the data may not be evenly distributed on shards. - Args: - value: the array value. - dtype: the array type. - keep_sharding: keep the array sharding information using ``jax.lax.with_sharding_constraint``. Default True. + Parameters + ---------- + value + the array value. + dtype : Any + the array type. + keep_sharding : bool + keep the array sharding information using ``jax.lax.with_sharding_constraint``. Default True. """ __slots__ = ('_value', '_keep_sharding') @@ -265,8 +271,9 @@ def tree_unflatten(cls, aux_data, flat_contents): def value(self): """The value stored in this array. - Returns: - The stored data. + Returns + ------- + The stored data. """ v = self._value # Keep sharding constraints, but only for genuinely multi-device diff --git a/brainpy/math/object_transform/_utils.py b/brainpy/math/object_transform/_utils.py index c4b32d71a..f49cac96d 100644 --- a/brainpy/math/object_transform/_utils.py +++ b/brainpy/math/object_transform/_utils.py @@ -69,8 +69,10 @@ def fn(inputs, states): where `inputs` is a list of all input arguments, and `states` is a list of all state arguments. - Args: - fn: The function to be decorated. + Parameters + ---------- + fn + The function to be decorated. """ if isinstance(fn, brainstate.typing.Missing): diff --git a/brainpy/math/object_transform/autograd.py b/brainpy/math/object_transform/autograd.py index a64b675c8..089649b91 100644 --- a/brainpy/math/object_transform/autograd.py +++ b/brainpy/math/object_transform/autograd.py @@ -100,7 +100,8 @@ def grad( >>> f_grad = bm.grad(f, grad_vars=f.x, argnums=(0, 1)) - Examples:: + Examples + -------- Grad for a pure function: @@ -109,7 +110,8 @@ def grad( >>> print(grad_tanh(0.2)) 0.961043 - Parameters:: + Parameters + ---------- func : callable, function, BrainPyObject Function to be differentiated. Its arguments at positions specified by @@ -121,21 +123,22 @@ def grad( The variables in ``func`` to take their gradients. argnums : optional, integer or sequence of integers Specifies which positional argument(s) to differentiate with respect to (default 0). - has_aux: optional, bool + has_aux : optional, bool Indicates whether ``fun`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. return_value : bool Whether return the loss value. - holomorphic: optional, bool + holomorphic : optional, bool Indicates whether ``fun`` is promised to be holomorphic. If True, inputs and outputs must be complex. Default False. - allow_int: optional, bool + allow_int : optional, bool Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False. - Returns:: + Returns + ------- func : GradientTransform A function with the same arguments as ``fun``, that evaluates the gradient @@ -201,31 +204,33 @@ def jacrev( - "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``. - Parameters:: + Parameters + ---------- - func: Function whose Jacobian is to be computed. + func : Function whose Jacobian is to be computed. grad_vars : optional, ArrayType, sequence of ArrayType, dict The variables in ``func`` to take their gradients. - has_aux: optional, bool + has_aux : optional, bool Indicates whether ``fun`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. return_value : bool Whether return the loss value. - argnums: Optional, integer or sequence of integers. + argnums : Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default ``0``). - holomorphic: Optional, bool. + holomorphic : Optional, bool. Indicates whether ``fun`` is promised to be holomorphic. Default False. - allow_int: Optional, bool. + allow_int : Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False. - Returns:: + Returns + ------- - fun: GradientTransform + fun : GradientTransform The transformed object. """ @@ -277,25 +282,27 @@ def jacfwd( - "has_aux=False" + "return_value=True" => ``((var_grads, arg_grads), loss_value)``. - "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``. - Parameters:: + Parameters + ---------- - func: Function whose Jacobian is to be computed. + func : Function whose Jacobian is to be computed. grad_vars : optional, ArrayType, sequence of ArrayType, dict The variables in ``func`` to take their gradients. - has_aux: optional, bool + has_aux : optional, bool Indicates whether ``fun`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. return_value : bool Whether return the loss value. - argnums: Optional, integer or sequence of integers. Specifies which + argnums : Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default ``0``). - holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be + holomorphic : Optional, bool. Indicates whether ``fun`` is promised to be holomorphic. Default False. - Returns:: + Returns + ------- - obj: GradientTransform + obj : GradientTransform The transformed object. """ return brainstate.transform.jacfwd( @@ -318,7 +325,8 @@ def hessian( ): """Hessian of ``func`` as a dense array. - Parameters:: + Parameters + ---------- func : callable, function Function whose Hessian is to be computed. Its arguments at positions @@ -327,7 +335,7 @@ def hessian( containers thereof. grad_vars : optional, ArrayCollector, sequence of ArrayType The variables required to compute their gradients. - argnums: Optional, integer or sequence of integers + argnums : Optional, integer or sequence of integers Specifies which positional argument(s) to differentiate with respect to (default ``0``). holomorphic : bool Indicates whether ``fun`` is promised to be holomorphic. Default False. @@ -336,9 +344,10 @@ def hessian( considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. - Returns:: + Returns + ------- - obj: ObjectTransform + obj : ObjectTransform The transformed object. """ @@ -383,22 +392,24 @@ def vector_grad( - "has_aux=True" + "return_value=True" => ``((var_grads, arg_grads), loss_value, aux_data)``. - Parameters:: + Parameters + ---------- - func: Callable + func : Callable Function whose gradient is to be computed. grad_vars : optional, ArrayType, sequence of ArrayType, dict The variables in ``func`` to take their gradients. - has_aux: optional, bool + has_aux : optional, bool Indicates whether ``fun`` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. return_value : bool Whether return the loss value. - argnums: Optional, integer or sequence of integers. Specifies which + argnums : Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default ``0``). - Returns:: + Returns + ------- func : GradientTransform The vector gradient function. diff --git a/brainpy/math/object_transform/base.py b/brainpy/math/object_transform/base.py index 4c61425fd..df80b7880 100644 --- a/brainpy/math/object_transform/base.py +++ b/brainpy/math/object_transform/base.py @@ -154,17 +154,27 @@ def tracing_variable( handled by ``brainstate`` directly. Calling this method always raises :class:`NotImplementedError`. - Args: - name: str. The variable name. - init: callable, Array. The data to be initialized as a ``Variable``. - shape: int, sequence of int. The shape of the variable. - batch_or_mode: int, bool, Mode. The batch size of this variable. - batch_axis: int. The batch axis, if batch size is given. - axis_names: sequence of str. The name for each axis. - batch_axis_name: str. The name for the batch axis. - - Raises: - NotImplementedError: Always, because this feature is unsupported since 3.0.0. + Parameters + ---------- + name : str + The variable name. + init : callable, Array + The data to be initialized as a ``Variable``. + shape : int, sequence of int + The shape of the variable. + batch_or_mode : int, bool, Mode + The batch size of this variable. + batch_axis : int + The batch axis, if batch size is given. + axis_names : sequence of str + The name for each axis. + batch_axis_name : str + The name for the batch axis. + + Raises + ------ + NotImplementedError + Always, because this feature is unsupported since 3.0.0. """ raise NotImplementedError( 'Since 3.0.0, brainpy is rewritten with brainstate. The feature tracing_variable is no longer supported. ' @@ -175,9 +185,12 @@ def __setattr__(self, key: str, value: Any) -> None: .. versionadded:: 2.3.1 - Args: - key: str. The attribute. - value: Any. The value. + Parameters + ---------- + key : str + The attribute. + value : Any + The value. """ if key in self.__dict__: val = self.__dict__[key] @@ -193,10 +206,10 @@ def tree_flatten(self): .. versionadded:: 2.3.1 - Returns:: - - res: tuple - A tuple of dynamical values and static values. + Returns + ------- + res : tuple + A tuple of dynamical values and static values. """ dynamic_names = [] dynamic_values = [] @@ -309,21 +322,21 @@ def vars( ): """Collect all variables in this node and the children nodes. - Parameters:: - + Parameters + ---------- method : str - The method to access the variables. - level: int - The hierarchy level to find variables. - include_self: bool - Whether include the variables in the self. - exclude_types: tuple of type - The type to exclude. - - Returns:: - + The method to access the variables. + level : int + The hierarchy level to find variables. + include_self : bool + Whether include the variables in the self. + exclude_types : tuple of type + The type to exclude. + + Returns + ------- gather : ArrayCollector - The collection contained (the path, the variable). + The collection contained (the path, the variable). """ if exclude_types is None: exclude_types = (VariableView,) @@ -351,19 +364,19 @@ def vars( def train_vars(self, method='absolute', level=-1, include_self=True): """The shortcut for retrieving all trainable variables. - Parameters:: - + Parameters + ---------- method : str - The method to access the variables. Support 'absolute' and 'relative'. - level: int - The hierarchy level to find TrainVar instances. - include_self: bool - Whether include the TrainVar instances in the self. - - Returns:: - + The method to access the variables. Support 'absolute' and 'relative'. + level : int + The hierarchy level to find TrainVar instances. + include_self : bool + Whether include the TrainVar instances in the self. + + Returns + ------- gather : ArrayCollector - The collection contained (the path, the trainable variable). + The collection contained (the path, the trainable variable). """ return self.vars(method=method, level=level, include_self=include_self).subset(TrainVar) @@ -439,37 +452,37 @@ def _find_nodes(self, method='absolute', level=-1, include_self=True, _lid=0, _p def nodes(self, method='absolute', level=-1, include_self=True): """Collect all children nodes. - Parameters:: - + Parameters + ---------- method : str - The method to access the nodes. - level: int - The hierarchy level to find nodes. - include_self: bool - Whether include the self. - - Returns:: - + The method to access the nodes. + level : int + The hierarchy level to find nodes. + include_self : bool + Whether include the self. + + Returns + ------- gather : Collector - The collection contained (the path, the node). + The collection contained (the path, the node). """ return self._find_nodes(method=method, level=level, include_self=include_self) def unique_name(self, name=None, type_=None): """Get the unique name for this object. - Parameters:: - + Parameters + ---------- name : str, optional - The expected name. If None, the default unique name will be returned. - Otherwise, the provided name will be checked to guarantee its uniqueness. + The expected name. If None, the default unique name will be returned. + Otherwise, the provided name will be checked to guarantee its uniqueness. type_ : str, optional - The name of this class, used for object naming. - - Returns:: + The name of this class, used for object naming. + Returns + ------- name : str - The unique name for this object. + The unique name for this object. """ if name is None: if type_ is None: @@ -506,10 +519,10 @@ def __load_state__(self, state_dict: Dict, **kwargs) -> Optional[Tuple[Sequence[ def state_dict(self, **kwargs) -> dict: """Returns a dictionary containing a whole state of the module. - Returns:: - - out: dict - A dictionary containing a whole state of the module. + Returns + ------- + out : dict + A dictionary containing a whole state of the module. """ nodes = self.nodes() # retrieve all nodes return {key: node.save_state(**kwargs) for key, node in nodes.items()} @@ -524,22 +537,22 @@ def load_state_dict( """Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. - Parameters:: - - state_dict: dict - A dict containing parameters and persistent buffers. - warn: bool - Warnings when there are missing keys or unexpected keys in the external ``state_dict``. - compatible: bool - The version of API for compatibility. - - Returns:: - - out: StateLoadResult - ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: - - * **missing_keys** is a list of str containing the missing keys - * **unexpected_keys** is a list of str containing the unexpected keys + Parameters + ---------- + state_dict : dict + A dict containing parameters and persistent buffers. + warn : bool + Warnings when there are missing keys or unexpected keys in the external ``state_dict``. + compatible : bool + The version of API for compatibility. + + Returns + ------- + out : StateLoadResult + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys """ if compatible == 'v1': variables = self.vars().unique() @@ -571,8 +584,10 @@ def load_state_dict( def to(self, device: Optional[Any]): """Moves all variables into the given device. - Args: - device: The device. + Parameters + ---------- + device : Optional[Any] + The device. """ # Iterate over the actual ``Variable`` instances (not the nested # ``state_dict`` mapping). Iterating ``state_dict()`` would yield @@ -617,16 +632,16 @@ def _add_node1(self, k, v, _paths, gather, nodes): class FunAsObject(BrainPyObject): """Transform a Python function as a :py:class:`~.BrainPyObject`. - Parameters:: - + Parameters + ---------- target : callable - The function to wrap. + The function to wrap. child_objs : optional, BrainPyObject, sequence of BrainPyObject, dict - The nodes in the defined function ``f``. + The nodes in the defined function ``f``. dyn_vars : optional, Variable, sequence of Variable, dict - The dynamically changed variables. + The dynamically changed variables. name : optional, str - The function name. + The function name. """ def __init__( diff --git a/brainpy/math/object_transform/collectors.py b/brainpy/math/object_transform/collectors.py index 7ca342e22..411e6f42b 100644 --- a/brainpy/math/object_transform/collectors.py +++ b/brainpy/math/object_transform/collectors.py @@ -59,14 +59,16 @@ def update(self, other, **kwargs): def __add__(self, other): """Merging two dicts. - Parameters:: + Parameters + ---------- - other: dict + other : dict The other dict instance. - Returns:: + Returns + ------- - gather: Collector + gather : Collector The new collector. """ gather = type(self)(self) @@ -76,14 +78,16 @@ def __add__(self, other): def __sub__(self, other: Union[Dict, Sequence]): """Remove other item in the collector. - Parameters:: + Parameters + ---------- - other: dict, sequence + other : dict, sequence The items to remove. - Returns:: + Returns + ------- - gather: Collector + gather : Collector The new collector. """ if not isinstance(other, (dict, tuple, list)): @@ -152,7 +156,8 @@ def subset(self, var_type): >>> # get all ODE integrators >>> some_collector.subset(bp.ode.ODEIntegrator) - Parameters:: + Parameters + ---------- var_type : type The type/class to match. diff --git a/brainpy/math/object_transform/controls.py b/brainpy/math/object_transform/controls.py index dd3f0aeee..2fda05e16 100644 --- a/brainpy/math/object_transform/controls.py +++ b/brainpy/math/object_transform/controls.py @@ -115,36 +115,38 @@ def cond( >>> a, b Variable([1., 1.], dtype=float32), Variable([0., 0.], dtype=float32) - Parameters:: + Parameters + ---------- - pred: bool + pred : bool Boolean scalar type, indicating which branch function to apply. - true_fun: callable, ArrayType, float, int, bool + true_fun : callable, ArrayType, float, int, bool Function to be applied if ``pred`` is True. This function must receive one arguement for ``operands``. - false_fun: callable, ArrayType, float, int, bool + false_fun : callable, ArrayType, float, int, bool Function to be applied if ``pred`` is False. This function must receive one arguement for ``operands``. - operands: Any + operands : Any Operands (A) input to branching function depending on ``pred``. The type can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof. - dyn_vars: optional, Variable, sequence of Variable, dict + dyn_vars : optional, Variable, sequence of Variable, dict The dynamically changed variables. .. deprecated:: 2.4.0 No longer need to provide ``dyn_vars``. This function is capable of automatically collecting the dynamical variables used in the target ``func``. - child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject + child_objs : optional, dict, sequence of BrainPyObject, BrainPyObject The children objects used in the target function. .. deprecated:: 2.4.0 No longer need to provide ``dyn_vars``. This function is capable of automatically collecting the dynamical variables used in the target ``func``. - Returns:: + Returns + ------- - res: Any + res : Any The conditional results. """ if not isinstance(operands, (tuple, list)): @@ -177,7 +179,8 @@ def ifelse( ): """``If-else`` control flows looks like native Pythonic programming. - Examples:: + Examples + -------- >>> import brainpy.math as bm >>> def f(a): @@ -196,35 +199,37 @@ def ifelse( >>> f(3) 3 - Parameters:: + Parameters + ---------- - conditions: bool, sequence of bool + conditions : bool, sequence of bool The boolean conditions. - branches: Any + branches : Any The branches, at least has two elements. Elements can be functions, arrays, or numbers. The number of ``branches`` and ``conditions`` has the relationship of `len(branches) == len(conditions) + 1`. Each branch should receive one arguement for ``operands``. - operands: optional, Any + operands : optional, Any The operands for each branch. - show_code: bool + show_code : bool Whether show the formatted code. - dyn_vars: Variable, sequence of Variable, dict + dyn_vars : Variable, sequence of Variable, dict The dynamically changed variables. .. deprecated:: 2.4.0 No longer need to provide ``dyn_vars``. This function is capable of automatically collecting the dynamical variables used in the target ``func``. - child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject + child_objs : optional, dict, sequence of BrainPyObject, BrainPyObject The children objects used in the target function. .. deprecated:: 2.4.0 No longer need to provide ``dyn_vars``. This function is capable of automatically collecting the dynamical variables used in the target ``func``. - Returns:: + Returns + ------- - res: Any + res : Any The results of the control flow. """ if operands is None: @@ -338,28 +343,29 @@ def for_loop( [16.] [20.]] - Parameters:: + Parameters + ---------- - body_fun: callable + body_fun : callable A Python function to be scanned. This function accepts one argument and returns one output. The argument denotes a slice of ``operands`` along its leading axis, and that output represents a slice of the return value. - operands: Any + operands : Any The value over which to scan along the leading axis, where ``operands`` can be an array or any pytree (nested Python tuple/list/dict) thereof with consistent leading axis sizes. If body function `body_func` receives multiple arguments, `operands` should be a tuple/list whose length is equal to the number of arguments. - reverse: bool + reverse : bool Optional boolean specifying whether to run the scan iteration forward (the default) or in reverse, equivalent to reversing the leading axes of the arrays in both ``xs`` and in ``ys``. - unroll: int + unroll : int Optional positive int specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop. - jit: bool + jit : bool Whether to just-in-time compile the function. Set to ``False`` to disable JIT compilation. .. note:: @@ -367,7 +373,7 @@ def for_loop( manager. Consequently it has no effect when ``for_loop`` is called inside an enclosing trace (e.g. within another jitted/scanned function): JAX is already tracing, so the loop runs as a compiled ``scan`` regardless of this flag. - progress_bar: bool, ProgressBar, int + progress_bar : bool, ProgressBar, int Whether and how to display a progress bar during execution: - ``False`` (default): No progress bar @@ -396,13 +402,13 @@ def for_loop( .. versionadded:: 2.4.2 .. versionchanged:: 2.7.3 Now accepts ProgressBar instances and integers for advanced customization. - dyn_vars: Variable, sequence of Variable, dict + dyn_vars : Variable, sequence of Variable, dict The instances of :py:class:`~.Variable`. .. deprecated:: 2.4.0 No longer need to provide ``dyn_vars``. This function is capable of automatically collecting the dynamical variables used in the target ``func``. - child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject + child_objs : optional, dict, sequence of BrainPyObject, BrainPyObject The children objects used in the target function. .. versionadded:: 2.3.1 @@ -411,9 +417,10 @@ def for_loop( No longer need to provide ``child_objs``. This function is capable of automatically collecting the children objects used in the target ``func``. - Returns:: + Returns + ------- - outs: Any + outs : Any The stacked outputs of ``body_fun`` when scanned over the leading axis of the inputs. """ if not isinstance(operands, (tuple, list)): @@ -479,35 +486,36 @@ def scan( All returns in body function will be gathered as the return of the whole loop. - Parameters:: + Parameters + ---------- - body_fun: callable + body_fun : callable A Python function to be scanned. This function accepts one argument and returns one output. The argument denotes a slice of ``operands`` along its leading axis, and that output represents a slice of the return value. - init: Any + init : Any An initial loop carry value of type ``c``, which can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof, representing the initial loop carry value. This value must have the same structure as the first element of the pair returned by ``body_fun``. - operands: Any + operands : Any The value over which to scan along the leading axis, where ``operands`` can be an array or any pytree (nested Python tuple/list/dict) thereof with consistent leading axis sizes. If body function `body_func` receives multiple arguments, `operands` should be a tuple/list whose length is equal to the number of arguments. - remat: bool + remat : bool Make ``fun`` recompute internal linearization points when differentiated. - reverse: bool + reverse : bool Optional boolean specifying whether to run the scan iteration forward (the default) or in reverse, equivalent to reversing the leading axes of the arrays in both ``xs`` and in ``ys``. - unroll: int + unroll : int Optional positive int specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop. - progress_bar: bool, ProgressBar, int + progress_bar : bool, ProgressBar, int Whether and how to display a progress bar during execution: - ``False`` (default): No progress bar @@ -521,9 +529,10 @@ def scan( .. versionchanged:: 2.7.3 Now accepts ProgressBar instances and integers for advanced customization. - Returns:: + Returns + ------- - outs: tuple + outs : tuple A two-element tuple ``(final_carry, stacked_ys)``: - ``final_carry``: the loop carry value returned by the last iteration of @@ -586,22 +595,23 @@ def while_loop( .. versionadded:: 2.1.11 - Parameters:: + Parameters + ---------- - body_fun: callable + body_fun : callable A function which define the updating logic. It receives one argument for ``operands``, without returns. - cond_fun: callable + cond_fun : callable A function which define the stop condition. It receives one argument for ``operands``, with one boolean value return. - operands: Any + operands : Any The operands for ``body_fun`` and ``cond_fun`` functions. - dyn_vars: Variable, sequence of Variable, dict + dyn_vars : Variable, sequence of Variable, dict The dynamically changed variables. .. deprecated:: 2.4.0 No longer need to provide ``dyn_vars``. This function is capable of automatically collecting the dynamical variables used in the target ``func``. - child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject + child_objs : optional, dict, sequence of BrainPyObject, BrainPyObject The children objects used in the target function. .. deprecated:: 2.4.0 diff --git a/brainpy/math/object_transform/function.py b/brainpy/math/object_transform/function.py index 631295879..46c5dc05b 100644 --- a/brainpy/math/object_transform/function.py +++ b/brainpy/math/object_transform/function.py @@ -90,20 +90,22 @@ def to_object( ): """Transform a Python function to :py:class:`~.BrainPyObject`. - Parameters:: + Parameters + ---------- - f: function, callable + f : function, callable The python function. - child_objs: Callable, BrainPyObject, sequence of BrainPyObject, dict of BrainPyObject + child_objs : Callable, BrainPyObject, sequence of BrainPyObject, dict of BrainPyObject The children objects used in this Python function. - dyn_vars: Variable, sequence of Variable, dict of Variable + dyn_vars : Variable, sequence of Variable, dict of Variable The `Variable` instance used in the Python function. - name: str + name : str The name of the created ``BrainPyObject``. - Returns:: + Returns + ------- - func: FunAsObject + func : FunAsObject The instance of ``BrainPyObject``. """ @@ -130,20 +132,22 @@ def function( .. deprecated:: 2.3.0 Using :py:func:`~.to_object` instead. - Parameters:: + Parameters + ---------- - f: function, callable + f : function, callable The python function. - nodes: Callable, BrainPyObject, sequence of BrainPyObject, dict of BrainPyObject + nodes : Callable, BrainPyObject, sequence of BrainPyObject, dict of BrainPyObject The children objects used in this Python function. - dyn_vars: Variable, sequence of Variable, dict of Variable + dyn_vars : Variable, sequence of Variable, dict of Variable The `Variable` instance used in the Python function. - name: str + name : str The name of the created ``BrainPyObject``. - Returns:: + Returns + ------- - func: FunAsObject + func : FunAsObject The instance of ``BrainPyObject``. """ warnings.warn('`brainpy.math.function()` is deprecated; use `brainpy.math.to_object()` instead. ' diff --git a/brainpy/math/object_transform/jit.py b/brainpy/math/object_transform/jit.py index 0861000a3..800e115e6 100644 --- a/brainpy/math/object_transform/jit.py +++ b/brainpy/math/object_transform/jit.py @@ -99,7 +99,8 @@ def jit( but it can also JIT compile a :py:class:`brainpy.DynamicalSystem`, or a :py:class:`brainpy.BrainPyObject` object. - Examples:: + Examples + -------- You can JIT any object in which all dynamical variables are defined as :py:class:`~.Variable`. @@ -122,11 +123,13 @@ def jit( >>> return lmbda * bp.math.where(x > 0, x, alpha * bp.math.exp(x) - alpha) - Parameters:: + Parameters + ---------- {jit_par} - Returns:: + Returns + ------- func : JITTransform A callable jitted function, set up for just-in-time compilation. @@ -169,7 +172,8 @@ def cls_jit( ) -> Callable: """Just-in-time compile a function and then the jitted function as the bound method for a class. - Examples:: + Examples + -------- This transformation can be put on any class function. For example, @@ -191,11 +195,13 @@ def cls_jit( >>> program = SomeProgram() >>> program() - Parameters:: + Parameters + ---------- {jit_pars} - Returns:: + Returns + ------- func : JITTransform A callable jitted function, set up for just-in-time compilation. diff --git a/brainpy/math/object_transform/variables.py b/brainpy/math/object_transform/variables.py index 74543abfe..797c4e1f7 100644 --- a/brainpy/math/object_transform/variables.py +++ b/brainpy/math/object_transform/variables.py @@ -52,11 +52,16 @@ class Variable(brainstate.State, Array): Note that when initializing a `Variable` by the data shape, all values in this `Variable` will be initialized as zeros. - Args: - value_or_size: Shape, Array, int. The value or the size of the value. - dtype: Any. The type of the data. - batch_axis: optional, int. The batch axis. - axis_names: sequence of str. The name for each axis. + Parameters + ---------- + value_or_size : Shape, Array, int + The value or the size of the value. + dtype : Any + The type of the data. + batch_axis : optional, int + The batch axis. + axis_names : sequence of str + The name for each axis. """ def __init__( diff --git a/brainpy/math/others.py b/brainpy/math/others.py index 63a349eb4..ffc09f959 100644 --- a/brainpy/math/others.py +++ b/brainpy/math/others.py @@ -45,23 +45,23 @@ def shared_args_over_time(num_step: Optional[int] = None, include_dt: bool = True): """Form a shared argument over time for the inference of a :py:class:`~.DynamicalSystem`. - Parameters:: - - num_step: int - The number of time step. Provide either ``duration`` or ``num_step``. - duration: float - The total duration. Provide either ``duration`` or ``num_step``. - dt: float - The duration for each time step. - t0: float - The start time. - include_dt: bool - Produce the time steps at every time step. - - Returns:: - - shared: DotDict - The shared arguments over the given time. + Parameters + ---------- + num_step : int + The number of time step. Provide either ``duration`` or ``num_step``. + duration : float + The total duration. Provide either ``duration`` or ``num_step``. + dt : float + The duration for each time step. + t0 : float + The start time. + include_dt : bool + Produce the time steps at every time step. + + Returns + ------- + shared : DotDict + The shared arguments over the given time. """ dt = get_dt() if dt is None else dt check.is_float(dt, 'dt', allow_none=False) @@ -80,15 +80,15 @@ def shared_args_over_time(num_step: Optional[int] = None, def remove_diag(arr): """Remove the diagonal of the matrix. - Parameters:: + Parameters + ---------- + arr : ArrayType + The matrix with the shape of `(M, N)`. - arr: ArrayType - The matrix with the shape of `(M, N)`. - - Returns:: - - arr: Array - The matrix without diagonal which has the shape of `(M, N-1)`. + Returns + ------- + arr : Array + The matrix without diagonal which has the shape of `(M, N-1)`. """ if arr.ndim != 2: raise ValueError(f'Only support 2D matrix, while we got a {arr.ndim}D array.') @@ -141,12 +141,16 @@ def exprel(x, threshold: float = None): suffer from catastrophic loss of precision. ``exprel(x)`` is implemented to avoid the loss of precision that occurs when ``x`` is near zero. - Args: - x: ndarray. Input array. ``x`` must contain real numbers. - threshold: float. + Parameters + ---------- + x + ndarray. Input array. ``x`` must contain real numbers. + threshold : float + float. - Returns: - ``(exp(x) - 1)/x``, computed element-wise. + Returns + ------- + ``(exp(x) - 1)/x``, computed element-wise. """ x = as_jax(x) if threshold is None: diff --git a/brainpy/math/pre_syn_post.py b/brainpy/math/pre_syn_post.py index 0b6bcff80..40056206d 100644 --- a/brainpy/math/pre_syn_post.py +++ b/brainpy/math/pre_syn_post.py @@ -85,20 +85,22 @@ def pre2post_event_sum(events, post_val[post_ids[j]] += values[j] - Parameters:: + Parameters + ---------- - events: ArrayType + events : ArrayType The events, must be bool. - pre2post: tuple of ArrayType, tuple of ArrayType + pre2post : tuple of ArrayType, tuple of ArrayType A tuple contains the connection information of pre-to-post. - post_num: int + post_num : int The number of post-synaptic group. - values: float, ArrayType + values : float, ArrayType The value to make summation. - Returns:: + Returns + ------- - out: ArrayType + out : ArrayType A tensor with the shape of ``post_num``. """ indices, idnptr = pre2post @@ -126,20 +128,22 @@ def pre2post_sum(pre_values, post_num, post_ids, pre_ids=None): for i, j in zip(pre_ids, post_ids): post_val[j] += pre_values[pre_ids[i]] - Parameters:: + Parameters + ---------- - pre_values: float, ArrayType + pre_values : float, ArrayType The pre-synaptic values. - post_ids: ArrayType + post_ids : ArrayType The connected post-synaptic neuron ids. - post_num: int + post_num : int Output dimension. The number of post-synaptic neurons. - pre_ids: optional, ArrayType + pre_ids : optional, ArrayType The connected pre-synaptic neuron ids. - Returns:: + Returns + ------- - post_val: ArrayType + post_val : ArrayType The value with the size of post-synaptic neurons. """ out = jnp.zeros(post_num) @@ -164,20 +168,22 @@ def pre2post_prod(pre_values, post_num, post_ids, pre_ids=None): for i, j in zip(pre_ids, post_ids): post_val[j] *= pre_values[pre_ids[i]] - Parameters:: + Parameters + ---------- - pre_values: float, ArrayType + pre_values : float, ArrayType The pre-synaptic values. - pre_ids: ArrayType + pre_ids : ArrayType The connected pre-synaptic neuron ids. - post_ids: ArrayType + post_ids : ArrayType The connected post-synaptic neuron ids. - post_num: int + post_num : int Output dimension. The number of post-synaptic neurons. - Returns:: + Returns + ------- - post_val: ArrayType + post_val : ArrayType The value with the size of post-synaptic neurons. """ out = jnp.zeros(post_num) @@ -202,20 +208,22 @@ def pre2post_min(pre_values, post_num, post_ids, pre_ids=None): for i, j in zip(pre_ids, post_ids): post_val[j] = np.minimum(post_val[j], pre_values[pre_ids[i]]) - Parameters:: + Parameters + ---------- - pre_values: float, ArrayType + pre_values : float, ArrayType The pre-synaptic values. - pre_ids: ArrayType + pre_ids : ArrayType The connected pre-synaptic neuron ids. - post_ids: ArrayType + post_ids : ArrayType The connected post-synaptic neuron ids. - post_num: int + post_num : int Output dimension. The number of post-synaptic neurons. - Returns:: + Returns + ------- - post_val: ArrayType + post_val : ArrayType The value with the size of post-synaptic neurons. """ out = jnp.zeros(post_num) @@ -240,20 +248,22 @@ def pre2post_max(pre_values, post_num, post_ids, pre_ids=None): for i, j in zip(pre_ids, post_ids): post_val[j] = np.maximum(post_val[j], pre_values[pre_ids[i]]) - Parameters:: + Parameters + ---------- - pre_values: float, ArrayType + pre_values : float, ArrayType The pre-synaptic values. - pre_ids: ArrayType + pre_ids : ArrayType The connected pre-synaptic neuron ids. - post_ids: ArrayType + post_ids : ArrayType The connected post-synaptic neuron ids. - post_num: int + post_num : int Output dimension. The number of post-synaptic neurons. - Returns:: + Returns + ------- - post_val: ArrayType + post_val : ArrayType The value with the size of post-synaptic neurons. """ out = jnp.zeros(post_num) @@ -269,23 +279,26 @@ def pre2post_max(pre_values, post_num, post_ids, pre_ids=None): def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None): """The pre-to-post synaptic mean computation. - Parameters:: + Parameters + ---------- - pre_values: float, ArrayType + pre_values : float, ArrayType The pre-synaptic values. - pre_ids: ArrayType + pre_ids : ArrayType The connected pre-synaptic neuron ids. - post_ids: ArrayType + post_ids : ArrayType The connected post-synaptic neuron ids. - post_num: int + post_num : int Output dimension. The number of post-synaptic neurons. - Returns:: + Returns + ------- - post_val: ArrayType + post_val : ArrayType The value with the size of post-synaptic neurons. - Notes:: + Notes + ----- When ``pre_values`` is a scalar, every connection carries the same constant value, so the per-post mean is simply that constant. In this case the function @@ -326,16 +339,18 @@ def pre2syn(pre_values, pre_ids): for syn_i, pre_i in enumerate(pre_ids): syn_val[i] = pre_values[pre_i] - Parameters:: + Parameters + ---------- - pre_values: float, ArrayType + pre_values : float, ArrayType The pre-synaptic value. - pre_ids: ArrayType + pre_ids : ArrayType The pre-synaptic neuron index. - Returns:: + Returns + ------- - syn_val: ArrayType + syn_val : ArrayType The synaptic value. """ pre_values = as_jax(pre_values) @@ -364,18 +379,20 @@ def syn2post_sum(syn_values, post_ids, post_num: int, indices_are_sorted=False): for syn_i, post_i in enumerate(post_ids): post_val[post_i] += syn_values[syn_i] - Parameters:: + Parameters + ---------- - syn_values: ArrayType + syn_values : ArrayType The synaptic values. - post_ids: ArrayType + post_ids : ArrayType The post-synaptic neuron ids. - post_num: int + post_num : int The number of the post-synaptic neurons. - Returns:: + Returns + ------- - post_val: ArrayType + post_val : ArrayType The post-synaptic value. """ post_ids = as_jax(post_ids) @@ -400,22 +417,24 @@ def syn2post_prod(syn_values, post_ids, post_num: int, indices_are_sorted=False) for syn_i, post_i in enumerate(post_ids): post_val[post_i] *= syn_values[syn_i] - Parameters:: + Parameters + ---------- - syn_values: ArrayType + syn_values : ArrayType The synaptic values. - post_ids: ArrayType + post_ids : ArrayType The post-synaptic neuron ids. If ``post_ids`` is generated by ``brainpy.conn.TwoEndConnector``, then it has sorted indices. Otherwise, this function cannot guarantee indices are sorted. You's better set ``indices_are_sorted=False``. - post_num: int + post_num : int The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. + indices_are_sorted : whether ``post_ids`` is known to be sorted. - Returns:: + Returns + ------- - post_val: ArrayType + post_val : ArrayType The post-synaptic value. """ post_ids = as_jax(post_ids) @@ -437,22 +456,24 @@ def syn2post_max(syn_values, post_ids, post_num: int, indices_are_sorted=False): for syn_i, post_i in enumerate(post_ids): post_val[post_i] = np.maximum(post_val[post_i], syn_values[syn_i]) - Parameters:: + Parameters + ---------- - syn_values: ArrayType + syn_values : ArrayType The synaptic values. - post_ids: ArrayType + post_ids : ArrayType The post-synaptic neuron ids. If ``post_ids`` is generated by ``brainpy.conn.TwoEndConnector``, then it has sorted indices. Otherwise, this function cannot guarantee indices are sorted. You's better set ``indices_are_sorted=False``. - post_num: int + post_num : int The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. + indices_are_sorted : whether ``post_ids`` is known to be sorted. - Returns:: + Returns + ------- - post_val: ArrayType + post_val : ArrayType The post-synaptic value. """ post_ids = as_jax(post_ids) @@ -474,22 +495,24 @@ def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=False): for syn_i, post_i in enumerate(post_ids): post_val[post_i] = np.minimum(post_val[post_i], syn_values[syn_i]) - Parameters:: + Parameters + ---------- - syn_values: ArrayType + syn_values : ArrayType The synaptic values. - post_ids: ArrayType + post_ids : ArrayType The post-synaptic neuron ids. If ``post_ids`` is generated by ``brainpy.conn.TwoEndConnector``, then it has sorted indices. Otherwise, this function cannot guarantee indices are sorted. You's better set ``indices_are_sorted=False``. - post_num: int + post_num : int The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. + indices_are_sorted : whether ``post_ids`` is known to be sorted. - Returns:: + Returns + ------- - post_val: ArrayType + post_val : ArrayType The post-synaptic value. """ post_ids = as_jax(post_ids) @@ -502,22 +525,24 @@ def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=False): def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=False): """The syn-to-post mean computation. - Parameters:: + Parameters + ---------- - syn_values: ArrayType + syn_values : ArrayType The synaptic values. - post_ids: ArrayType + post_ids : ArrayType The post-synaptic neuron ids. If ``post_ids`` is generated by ``brainpy.conn.TwoEndConnector``, then it has sorted indices. Otherwise, this function cannot guarantee indices are sorted. You's better set ``indices_are_sorted=False``. - post_num: int + post_num : int The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. + indices_are_sorted : whether ``post_ids`` is known to be sorted. - Returns:: + Returns + ------- - post_val: ArrayType + post_val : ArrayType The post-synaptic value. """ post_ids = as_jax(post_ids) @@ -534,22 +559,24 @@ def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=False) def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=False): """The syn-to-post softmax computation. - Parameters:: + Parameters + ---------- - syn_values: ArrayType + syn_values : ArrayType The synaptic values. - post_ids: ArrayType + post_ids : ArrayType The post-synaptic neuron ids. If ``post_ids`` is generated by ``brainpy.conn.TwoEndConnector``, then it has sorted indices. Otherwise, this function cannot guarantee indices are sorted. You's better set ``indices_are_sorted=False``. - post_num: int + post_num : int The number of the post-synaptic neurons. - indices_are_sorted: whether ``post_ids`` is known to be sorted. + indices_are_sorted : whether ``post_ids`` is known to be sorted. - Returns:: + Returns + ------- - post_val: ArrayType + post_val : ArrayType The post-synaptic value. """ post_ids = as_jax(post_ids) diff --git a/brainpy/math/scales.py b/brainpy/math/scales.py index 5fd931906..2ec82f521 100644 --- a/brainpy/math/scales.py +++ b/brainpy/math/scales.py @@ -34,12 +34,16 @@ def transform( ) -> 'Scaling': """Transform the membrane potential range to a ``Scaling`` instance. - Args: - V_range: [V_min, V_max] - scaled_V_range: [scaled_V_min, scaled_V_max] - - Returns: - The instanced scaling object. + Parameters + ---------- + V_range : Sequence[Union[float, int]] + [V_min, V_max] + scaled_V_range : Sequence[Union[float, int]] + [scaled_V_min, scaled_V_max] + + Returns + ------- + The instanced scaling object. """ V_min, V_max = V_range scaled_V_min, scaled_V_max = scaled_V_range diff --git a/brainpy/math/sharding.py b/brainpy/math/sharding.py index 3ef62e2c6..b1fad5b89 100644 --- a/brainpy/math/sharding.py +++ b/brainpy/math/sharding.py @@ -86,12 +86,16 @@ def _device_put(x: Union[Array, jax.Array, np.ndarray], Note that this function can only transfer ``brainpy.math.Array``, ``jax.Array``, and ``numpy.ndarray``. Other value will be directly returned. - Args: - x: The input array. - device: The given device. - - Returns: - A copy of ``x`` that resides on ``device``. + Parameters + ---------- + x : Union[Array, jax.Array, np.ndarray] + The input array. + device : Union[None, jax.Device, Sharding] + The given device. + + Returns + ------- + A copy of ``x`` that resides on ``device``. """ if isinstance(x, Array): x.value = jax.device_put(x.value, device=device) @@ -110,12 +114,16 @@ def get_sharding( ) -> Optional[NamedSharding]: """Get sharding according to the given axes information. - Args: - axis_names: list of str, or tuple of str. The name for each axis in the array. - mesh: Mesh. The given device mesh. + Parameters + ---------- + axis_names : Optional[Sequence[str]] + list of str, or tuple of str. The name for each axis in the array. + mesh : Optional[Mesh] + Mesh. The given device mesh. - Returns: - The instance of NamedSharding. + Returns + ------- + The instance of NamedSharding. """ if axis_names is None: return None @@ -147,13 +155,18 @@ def partition_by_axname( ): """Put the given arrays into the mesh devices. - Args: - x: any. Any array. - axis_names: sequence of str. The name for each axis in the array. - mesh: Mesh. The given device mesh. - - Returns: - The re-sharded arrays. + Parameters + ---------- + x : Any + any. Any array. + axis_names : Optional[Sequence[str]] + sequence of str. The name for each axis in the array. + mesh : Optional[Mesh] + Mesh. The given device mesh. + + Returns + ------- + The re-sharded arrays. """ if axis_names is None: return x @@ -181,12 +194,16 @@ def partition_by_sharding( ): """Partition inputs with the given sharding strategy. - Args: - x: The input arrays. It can be a pyTree of arrays. - sharding: The `jax.sharding.Sharding` instance. + Parameters + ---------- + x : Any + The input arrays. It can be a pyTree of arrays. + sharding : Optional[Sharding] + The `jax.sharding.Sharding` instance. - Returns: - The sharded ``x``, which has been partitioned by the given sharding stragety. + Returns + ------- + The sharded ``x``, which has been partitioned by the given sharding stragety. """ if sharding is None: return x @@ -204,13 +221,17 @@ def partition( ): """Partition the input arrays onto devices by the given sharding strategies. - Args: - x: Any input arrays. It can also be a PyTree of arrays. - sharding: The sharding strategy. - - Returns: - The partitioned arrays. - Notably, the + Parameters + ---------- + x : Any + Any input arrays. It can also be a PyTree of arrays. + sharding : Optional[Union[Sequence[str], jax.Device, Sharding]] + The sharding strategy. + + Returns + ------- + The partitioned arrays. + Notably, the """ if sharding is None: return x @@ -239,10 +260,14 @@ def _keep_constraint(x: Any): def keep_constraint(x: Any): """Keep the sharding constraint of the given inputs during computation. - Args: - x: Any. + Parameters + ---------- + x : Any + Any. - Returns: - constraint_x: Same as ``x``. + Returns + ------- + constraint_x + Same as ``x``. """ return jax.tree_util.tree_map(_keep_constraint, x, is_leaf=is_bp_array) diff --git a/brainpy/math/sparse/csr_mm.py b/brainpy/math/sparse/csr_mm.py index 97c212f47..4da5d54d0 100644 --- a/brainpy/math/sparse/csr_mm.py +++ b/brainpy/math/sparse/csr_mm.py @@ -37,18 +37,20 @@ def csrmm( """ Product of CSR sparse matrix and a dense matrix. - Args: - data : array of shape ``(nse,)``. - indices : array of shape ``(nse,)`` - indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype`` - B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and + Parameters + ---------- + data : array of shape ``(nse,)``. + indices : array of shape ``(nse,)`` + indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype`` + B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and dtype ``data.dtype`` - shape : length-2 tuple representing the matrix shape - transpose : boolean specifying whether to transpose the sparse matrix + shape : length-2 tuple representing the matrix shape + transpose : boolean specifying whether to transpose the sparse matrix before computing. - Returns: - C : array of shape ``(shape[1] if transpose else shape[0], cols)`` + Returns + ------- + C : array of shape ``(shape[1] if transpose else shape[0], cols)`` representing the matrix-matrix product. """ if isinstance(data, Array): diff --git a/brainpy/math/sparse/csr_mv.py b/brainpy/math/sparse/csr_mv.py index 039a7229f..31590284e 100644 --- a/brainpy/math/sparse/csr_mv.py +++ b/brainpy/math/sparse/csr_mv.py @@ -39,24 +39,26 @@ def csrmv( This function supports JAX transformations, including `jit()`, `grad()`, `vmap()` and `pmap()`. - Parameters:: + Parameters + ---------- - data: ndarray, float + data : ndarray, float An array of shape ``(nse,)``. - indices: ndarray + indices : ndarray An array of shape ``(nse,)``. - indptr: ndarray + indptr : ndarray An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - vector: ndarray + vector : ndarray An array of shape ``(shape[0] if transpose else shape[1],)`` and dtype ``data.dtype``. - shape: tuple of int + shape : tuple of int A length-2 tuple representing the matrix shape. - transpose: bool + transpose : bool A boolean specifying whether to transpose the sparse matrix before computing. - Returns:: + Returns + ------- y : ndarry The array of shape ``(shape[1] if transpose else shape[0],)`` representing diff --git a/brainpy/math/sparse/jax_prim.py b/brainpy/math/sparse/jax_prim.py index 0960f6b8b..4ac9df38e 100644 --- a/brainpy/math/sparse/jax_prim.py +++ b/brainpy/math/sparse/jax_prim.py @@ -35,14 +35,16 @@ def _matmul_with_left_sparse( Y = M_{\mathrm{sparse}} @ M_{\mathrm{dense}} - Parameters:: + Parameters + ---------- - sparse: dict + sparse : dict The sparse matrix with shape of :math:`(N, M)`. - dense: ArrayType + dense : ArrayType The dense matrix with the shape of :math:`(M, K)`. - Returns:: + Returns + ------- matrix A tensor the the shape of :math:`(N, K)`. @@ -75,14 +77,16 @@ def _matmul_with_right_sparse( Y = M_{\mathrm{dense}} @ M_{\mathrm{sparse}} - Parameters:: + Parameters + ---------- - dense: ArrayType + dense : ArrayType The dense matrix with the shape of :math:`(N, M)`. - sparse: dict + sparse : dict The sparse matrix with shape of :math:`(M, K)`. - Returns:: + Returns + ------- matrix A tensor the the shape of :math:`(N, K)`. @@ -117,7 +121,8 @@ def seg_matmul(A, B): where :math:`A` or :math:`B` is a sparse matrix. :math:`A` and :math:`B` cannot be both sparse. - Examples:: + Examples + -------- >>> import brainpy.math as bm @@ -150,16 +155,18 @@ def seg_matmul(A, B): ArrayType([[0.438388 , 1.4346815 , 0. , 2.361964 ], [0.9171978 , 1.1214957 , 0. , 0.90534496]], dtype=float32) - Parameters:: + Parameters + ---------- - A: tensor, sequence + A : tensor, sequence The dense or sparse matrix with the shape of :math:`(N, M)`. - B: tensor, sequence + B : tensor, sequence The dense or sparse matrix with the shape of :math:`(M, K)`. - Returns:: + Returns + ------- - results: ArrayType + results : ArrayType The tensor with the shape of :math:`(N, K)`. """ if isinstance(A, dict): diff --git a/brainpy/measure.py b/brainpy/measure.py index 4afd2a3a5..c3611a1be 100644 --- a/brainpy/measure.py +++ b/brainpy/measure.py @@ -36,14 +36,16 @@ def raster_plot(sp_matrix, times): """Get spike raster plot which displays the spiking activity of a group of neurons over time. - Parameters:: + Parameters + ---------- sp_matrix : bnp.ndarray The matrix which record spiking activities. times : bnp.ndarray The time steps. - Returns:: + Returns + ------- raster_plot : tuple Include (neuron index, spike time). @@ -68,7 +70,8 @@ def firing_rate(spikes, width, dt=None, numpy=True): v_k = {n_k^{sp} \over T} - Parameters:: + Parameters + ---------- spikes : ndarray The spike matrix which record spiking activities. @@ -76,11 +79,12 @@ def firing_rate(spikes, width, dt=None, numpy=True): The width of the ``window`` in millisecond. dt : float, optional The sample rate. - numpy: bool + numpy : bool Whether we use numpy array as the functional output. If ``False``, this function can be JIT compiled. - Returns:: + Returns + ------- rate : ndarray The population rate in Hz, smoothed with the given window. diff --git a/brainpy/mixin.py b/brainpy/mixin.py index d9afc05d5..88de4229d 100644 --- a/brainpy/mixin.py +++ b/brainpy/mixin.py @@ -318,8 +318,10 @@ def add_elem(self, *elems, **elements): >>> obj = Container() >>> obj.add_elem(a=1.) - Args: - elements: children objects. + Parameters + ---------- + elements + children objects. """ self.children.update(self.format_elements(object, *elems, **elements)) @@ -373,12 +375,17 @@ class SupportInputProj(MixIn): def add_inp_fun(self, key: str, fun: Callable, label: Optional[str] = None, category: str = 'current'): """Add an input function. - Args: - key: str. The dict key. - fun: Callable. The function to generate inputs. - label: str. The input label. - category: str. The input category, should be ``current`` (the current) or - ``delta`` (the delta synapse, indicating the delta function). + Parameters + ---------- + key : str + The dict key. + fun : Callable + The function to generate inputs. + label : str + The input label. + category : str + The input category, should be ``current`` (the current) or + ``delta`` (the delta synapse, indicating the delta function). """ if not callable(fun): raise TypeError('Must be a function.') @@ -398,11 +405,14 @@ def add_inp_fun(self, key: str, fun: Callable, label: Optional[str] = None, cate def get_inp_fun(self, key: str): """Get the input function. - Args: - key: str. The key. + Parameters + ---------- + key : str + The key. - Returns: - The input function which generates currents. + Returns + ------- + The input function which generates currents. """ if key in self.current_inputs: return self.current_inputs[key] @@ -414,14 +424,20 @@ def get_inp_fun(self, key: str): def sum_current_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs): """Summarize all current inputs by the defined input functions ``.current_inputs``. - Args: - *args: The arguments for input functions. - init: The initial input data. - label: str. The input label. - **kwargs: The arguments for input functions. + Parameters + ---------- + *args + The arguments for input functions. + init + The initial input data. + label : str + The input label. + **kwargs + The arguments for input functions. - Returns: - The total currents. + Returns + ------- + The total currents. """ if label is None: for key, out in self.current_inputs.items(): @@ -436,14 +452,20 @@ def sum_current_inputs(self, *args, init: Any = 0., label: Optional[str] = None, def sum_delta_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs): """Summarize all delta inputs by the defined input functions ``.delta_inputs``. - Args: - *args: The arguments for input functions. - init: The initial input data. - label: str. The input label. - **kwargs: The arguments for input functions. + Parameters + ---------- + *args + The arguments for input functions. + init + The initial input data. + label : str + The input label. + **kwargs + The arguments for input functions. - Returns: - The total currents. + Returns + ------- + The total currents. """ if label is None: for key, out in self.delta_inputs.items(): diff --git a/brainpy/optim/optimizer.py b/brainpy/optim/optimizer.py index 3a6871cd7..25f9505f3 100644 --- a/brainpy/optim/optimizer.py +++ b/brainpy/optim/optimizer.py @@ -42,9 +42,10 @@ class Optimizer(BrainPyObject): """Base Optimizer Class. - Parameters:: + Parameters + ---------- - lr: float, Scheduler + lr : float, Scheduler learning rate. """ @@ -110,9 +111,10 @@ class SGD(CommonOpt): \theta = \theta - \eta \cdot \nabla_\theta J(\theta; x; y) - Parameters:: + Parameters + ---------- - lr: float, Scheduler + lr : float, Scheduler learning rate. """ @@ -165,12 +167,14 @@ class Momentum(CommonOpt): \end{split} \end{align} - Parameters:: + Parameters + ---------- - lr: float, Scheduler + lr : float, Scheduler learning rate. - References:: + References + ---------- .. [1] Qian, N. (1999). On the momentum term in gradient descent learning algorithms. Neural Networks : The Official Journal of the International @@ -231,12 +235,14 @@ class MomentumNesterov(CommonOpt): \end{split} \end{align} - Parameters:: + Parameters + ---------- - lr: float, Scheduler + lr : float, Scheduler learning rate. - References:: + References + ---------- .. [2] Nesterov, Y. (1983). A method for unconstrained convex minimization problem with the rate of convergence o(1/k2). Doklady ANSSSR (translated as Soviet.Math.Docl.), vol. 269, pp. 543– 547. @@ -303,12 +309,14 @@ class Adagrad(CommonOpt): This in turn causes the learning rate to shrink and eventually become infinitesimally small, at which point the algorithm is no longer able to acquire additional knowledge. - Parameters:: + Parameters + ---------- - lr: float, Scheduler + lr : float, Scheduler learning rate. - References:: + References + ---------- .. [3] Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. Journal of Machine Learning Research, 12, 2121–2159. Retrieved from http://jmlr.org/papers/v12/duchi11a.html @@ -387,12 +395,14 @@ class Adadelta(CommonOpt): keep it at this value. epsilon is important for the very first update (so the numerator does not become 0). - Parameters:: + Parameters + ---------- - lr: float, Scheduler + lr : float, Scheduler learning rate. - References:: + References + ---------- .. [4] Zeiler, M. D. (2012). ADADELTA: An Adaptive Learning Rate Method. Retrieved from http://arxiv.org/abs/1212.5701 @@ -466,12 +476,14 @@ class RMSProp(CommonOpt): The centered version additionally maintains a moving average of the gradients, and uses that average to estimate the variance. - Parameters:: + Parameters + ---------- - lr: float, Scheduler + lr : float, Scheduler learning rate. - References:: + References + ---------- .. [5] Tieleman, T. and Hinton, G. (2012): Neural Networks for Machine Learning, Lecture 6.5 - rmsprop. @@ -530,23 +542,25 @@ class Adam(CommonOpt): individual adaptive learning rates for different parameters from estimates of first- and second-order moments of the gradients. - Parameters:: + Parameters + ---------- - lr: float, Scheduler + lr : float, Scheduler learning rate. - beta1: optional, float + beta1 : optional, float A positive scalar value for beta_1, the exponential decay rate for the first moment estimates (default 0.9). - beta2: optional, float + beta2 : optional, float A positive scalar value for beta_2, the exponential decay rate for the second moment estimates (default 0.999). - eps: optional, float + eps : optional, float A positive scalar value for epsilon, a small constant for numerical stability (default 1e-8). name : optional, str The optimizer name. - References:: + References + ---------- .. [6] Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980. """ @@ -632,20 +646,22 @@ class LARS(CommonOpt): m_{t} = \beta_{1}m_{t-1} + \left(1-\beta_{1}\right)\left(g_{t} + \lambda{x_{t}}\right) \\ x_{t+1}^{\left(i\right)} = x_{t}^{\left(i\right)} - \eta_{t}\frac{\phi\left(|| x_{t}^{\left(i\right)} ||\right)}{|| m_{t}^{\left(i\right)} || }m_{t}^{\left(i\right)} - Parameters:: + Parameters + ---------- - lr: float, Scheduler + lr : float, Scheduler learning rate. - momentum: float + momentum : float coefficient used for the moving average of the gradient. - weight_decay: float + weight_decay : float weight decay coefficient. - tc: float + tc : float trust coefficient eta ( < 1) for trust ratio computation. - eps: float + eps : float epsilon used for trust ratio computation. - References:: + References + ---------- .. [1] You, Yang, Igor Gitman and Boris Ginsburg. “Large Batch Training of Convolutional Networks.” arXiv: Computer Vision and Pattern Recognition (2017): n. pag. """ @@ -712,9 +728,10 @@ class Adan(CommonOpt): \end{aligned} \end{equation} - Parameters:: + Parameters + ---------- - lr: float, Scheduler + lr : float, Scheduler learning rate. Can be much higher than Adam, up to 5-10x. (default: 1e-3) betas : tuple Coefficients used for computing running averages of gradient and its norm. (default: (0.02, 0.08, 0.01)) @@ -722,7 +739,7 @@ class Adan(CommonOpt): The term added to the denominator to improve numerical stability. (default: 1e-8) weight_decay : float decoupled weight decay (L2 penalty) (default: 0) - no_prox: bool + no_prox : bool how to perform the decoupled weight decay (default: False). It determines the update rule of parameters with weight decay. By default, Adan updates the parameters in the way presented in Algorithm 1 in the paper: @@ -735,7 +752,8 @@ class Adan(CommonOpt): .. math:: \boldsymbol{\theta}_{k+1} = ( 1-\lambda \eta)\boldsymbol{\theta}_k - \boldsymbol{\eta}_k \circ (\mathbf{m}_k+(1-{\color{blue}\beta_2})\mathbf{v}_k). - References:: + References + ---------- .. [1] Xie, Xingyu, Pan Zhou, Huan Li, Zhouchen Lin and Shuicheng Yan. “Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing @@ -893,29 +911,31 @@ class AdamW(CommonOpt): \end{aligned} - Parameters:: + Parameters + ---------- - lr: float, Scheduler + lr : float, Scheduler learning rate. - beta1: optional, float + beta1 : optional, float A positive scalar value for beta_1, the exponential decay rate for the first moment estimates. Generally close to 1. - beta2: optional, float + beta2 : optional, float A positive scalar value for beta_2, the exponential decay rate for the second moment estimates. Generally close to 1. - eps: optional, float + eps : optional, float A positive scalar value for epsilon, a small constant for numerical stability. - weight_decay: float + weight_decay : float Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. - amsgrad: bool + amsgrad : bool whether to use the AMSGrad variant of this algorithm from the paper `On the Convergence of Adam and Beyond`. name : optional, str The optimizer name. - References:: + References + ---------- .. [1] Loshchilov, Ilya and Frank Hutter. “Decoupled Weight Decay Regularization.” International Conference on Learning Representations (2019). @@ -1031,21 +1051,23 @@ class SM3(CommonOpt): momentum, SM3 will use just over half as much memory as Adam, and a bit more than Adagrad. - Parameters:: + Parameters + ---------- - lr: float, Scheduler + lr : float, Scheduler learning rate. - momentum: float + momentum : float coefficient used to scale prior updates before adding. This drastically increases memory usage if `momentum > 0.0`. (default: 0.0) - beta: float + beta : float coefficient used for exponential moving averages (default: 0.0) - eps: float + eps : float Term added to square-root in denominator to improve numerical stability (default: 1e-30). - References:: + References + ---------- .. [1] Anil, Rohan, Vineet Gupta, Tomer Koren and Yoram Singer. “Memory Efficient Adaptive Optimization.” Neural Information Processing Systems (2019). diff --git a/brainpy/optim/scheduler.py b/brainpy/optim/scheduler.py index 2ae0d307e..99898beb9 100644 --- a/brainpy/optim/scheduler.py +++ b/brainpy/optim/scheduler.py @@ -88,16 +88,17 @@ class StepLR(Scheduler): """Decays the learning rate of each parameter group by gamma every `step_size` epochs. - Parameters:: + Parameters + ---------- - lr: float + lr : float Initial learning rate. - step_size: int + step_size : int Period of learning rate decay. - gamma: float + gamma : float Multiplicative factor of learning rate decay. Default: 0.1. - last_epoch: int + last_epoch : int The index of last epoch. Default: -1. """ @@ -129,16 +130,17 @@ class MultiStepLR(Scheduler): happen simultaneously with other changes to the learning rate from outside this scheduler. When last_epoch=-1, sets initial lr as lr. - Parameters:: + Parameters + ---------- - lr: float + lr : float Initial learning rate. - milestones: sequence of int + milestones : sequence of int List of epoch indices. Must be increasing. - gamma: float + gamma : float Multiplicative factor of learning rate decay. Default: 0.1. - last_epoch: int + last_epoch : int The index of last epoch. Default: -1. """ @@ -201,15 +203,16 @@ class CosineAnnealingLR(Scheduler): `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only implements the cosine annealing part of SGDR, and not the restarts. - Parameters:: + Parameters + ---------- - lr: float + lr : float Initial learning rate. - T_max: int + T_max : int Maximum number of iterations. - eta_min: float + eta_min : float Minimum learning rate. Default: 0. - last_epoch: int + last_epoch : int The index of last epoch. Default: -1. .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: @@ -250,20 +253,21 @@ class CosineAnnealingWarmRestarts(CallBasedScheduler): It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts`_. - Parameters:: + Parameters + ---------- - lr: float + lr : float Initial learning rate. - num_call_per_epoch: int + num_call_per_epoch : int The number the scheduler to call in each epoch. This usually means the number of batch in each epoch training. - T_0: int + T_0 : int Number of iterations for the first restart. - T_mult: int + T_mult : int A factor increases :math:`T_{i}` after a restart. Default: 1. - eta_min: float + eta_min : float Minimum learning rate. Default: 0. - last_call: int + last_call : int The index of last call. Default: -1. .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: @@ -322,13 +326,14 @@ class ExponentialLR(Scheduler): """Decays the learning rate of each parameter group by gamma every epoch. When last_epoch=-1, sets initial lr as lr. - Parameters:: + Parameters + ---------- - lr: float + lr : float Initial learning rate. - gamma: float + gamma : float Multiplicative factor of learning rate decay. - last_epoch: int + last_epoch : int The index of last epoch. Default: -1. """ diff --git a/brainpy/runners.py b/brainpy/runners.py index 6ff421d83..9418a20a3 100644 --- a/brainpy/runners.py +++ b/brainpy/runners.py @@ -60,14 +60,16 @@ def _is_brainpy_array(x): def check_and_format_inputs(host, inputs): """Check inputs and get the formatted inputs for the given population. - Parameters:: + Parameters + ---------- host : DynamicalSystem The host which contains all data. inputs : tuple, list The inputs of the population. - Returns:: + Returns + ------- formatted_inputs : tuple, list The formatted inputs of the population. @@ -210,7 +212,8 @@ def _f_ops(ops, var, data): class DSRunner(Runner): """The runner for :py:class:`~.DynamicalSystem`. - Parameters:: + Parameters + ---------- target : DynamicalSystem The target model to run. @@ -249,14 +252,14 @@ class DSRunner(Runner): .. versionchanged:: 2.3.1 ``fun_inputs`` are merged into ``inputs``. - fun_inputs: callable + fun_inputs : callable The functional inputs. Manually specify the inputs for the target variables. This input function should receive one argument ``shared`` which contains the shared arguments like time ``t``, time step ``dt``, and index ``i``. .. deprecated:: 2.3.1 Will be removed since version 2.4.0. - monitors: Optional, sequence of str, dict, Monitor + monitors : Optional, sequence of str, dict, Monitor Variables to monitor. - A list of string. Like ``monitors=['a', 'b', 'c']``. @@ -267,7 +270,7 @@ class DSRunner(Runner): .. versionchanged:: 2.3.1 ``fun_monitors`` are merged into ``monitors``. - fun_monitors: dict + fun_monitors : dict Monitoring variables by a dict of callable functions. The dict ``key`` should be a string for the later retrieval by ``runner.mon[key]``. The dict ``value`` should be a callable function which receives two arguments: ``t`` and ``dt``. @@ -277,15 +280,15 @@ class DSRunner(Runner): .. deprecated:: 2.3.1 Will be removed since version 2.4.0. - jit: bool, dict + jit : bool, dict The JIT settings. Using dict is able to set the jit mode at different phase, for instance, ``jit={'predict': True, 'fit': False}``. - progress_bar: bool + progress_bar : bool Use progress bar to report the running progress or not? - dyn_vars: Optional, dict + dyn_vars : Optional, dict The dynamically changed variables. Instance of :py:class:`~.Variable`. These variables together with variable retrieved from the ``target`` constitute all dynamical variables in this runner. @@ -293,7 +296,7 @@ class DSRunner(Runner): numpy_mon_after_run : bool When finishing the network running, transform the JAX arrays into numpy ndarray or not? - data_first_axis: str + data_first_axis : str Set the default data dimension arrangement. To indicate whether the first axis is the batch size (``data_first_axis='B'``) or the time length (``data_first_axis='T'``). @@ -301,7 +304,7 @@ class DSRunner(Runner): .. versionadded:: 2.3.1 - memory_efficient: bool + memory_efficient : bool Whether using the memory-efficient way to just-in-time compile the given target. Default is False. @@ -412,12 +415,13 @@ def predict( Moreover, it can automatically monitor the node variables, states, inputs, and its output. - Parameters:: + Parameters + ---------- - duration: float + duration : float The simulation time length. If you have provided ``inputs``, there is no longer need to provide ``duration``. - inputs: ArrayType, dict of ArrayType, sequence of ArrayType + inputs : ArrayType, dict of ArrayType, sequence of ArrayType The input data. - If the mode of ``target`` is instance of :py:class:`~.BatchingMode`, @@ -427,22 +431,23 @@ def predict( - If the mode of ``target`` is instance of :py:class:`~.NonBatchingMode`, the ``inputs`` should be a PyTree of data with one dimension: ``(time, ...)``. - inputs_are_batching: bool + inputs_are_batching : bool Whether the ``inputs`` are batching. If `True`, the batching axis is the first dimension. .. deprecated:: 2.3.1 Will be removed after version 2.4.0. - reset_state: bool + reset_state : bool Whether reset the model states. - eval_time: bool + eval_time : bool Whether ro evaluate the running time. - shared_args: optional, dict + shared_args : optional, dict The shared arguments across different layers. - Returns:: + Returns + ------- - output: ArrayType, dict, sequence + output : ArrayType, dict, sequence The model output. """ @@ -534,15 +539,17 @@ def __call__(self, *args, **kwargs) -> Union[Output, Tuple[float, Output]]: def _predict(self, indices, *xs, shared_args=None) -> Union[Output, Monitor]: """Predict the output according to the inputs. - Parameters:: + Parameters + ---------- - xs: sequence + xs : sequence If `inputs` is not None, it should be a tensor with the shape of :math:`(num_time, ...)`. - shared_args: optional, dict + shared_args : optional, dict The shared keyword arguments. - Returns:: + Returns + ------- outputs, hists A tuple of pair of (outputs, hists). diff --git a/brainpy/running/jax_multiprocessing.py b/brainpy/running/jax_multiprocessing.py index fb18dd12c..e134c917d 100644 --- a/brainpy/running/jax_multiprocessing.py +++ b/brainpy/running/jax_multiprocessing.py @@ -40,20 +40,22 @@ def jax_vectorize_map( suitable to be used in GPU backends. This is because ``jax.vmap`` can parallelize the mapped axis on GPU devices. - Parameters:: + Parameters + ---------- - func: callable, function + func : callable, function The function to be mapped. - arguments: sequence, dict + arguments : sequence, dict The function arguments, used to define tasks. - num_parallel: int + num_parallel : int The number of batch size. - clear_buffer: bool + clear_buffer : bool Clear the buffer memory after running each batch data. - Returns:: + Returns + ------- - results: Any + results : Any The running results. """ if not isinstance(arguments, (dict, tuple, list)): @@ -105,20 +107,22 @@ def jax_parallelize_map( If you are using it in a single CPU, please set host device count by ``brainpy.math.set_host_device_count(n)`` before. - Parameters:: + Parameters + ---------- - func: callable, function + func : callable, function The function to be mapped. - arguments: sequence, dict + arguments : sequence, dict The function arguments, used to define tasks. - num_parallel: int + num_parallel : int The number of batch size. - clear_buffer: bool + clear_buffer : bool Clear the buffer memory after running each batch data. - Returns:: + Returns + ------- - results: Any + results : Any The running results. """ if not isinstance(arguments, (dict, tuple, list)): diff --git a/brainpy/running/native_multiprocessing.py b/brainpy/running/native_multiprocessing.py index 11f043134..a9f2ffa63 100644 --- a/brainpy/running/native_multiprocessing.py +++ b/brainpy/running/native_multiprocessing.py @@ -30,7 +30,8 @@ def process_pool(func: callable, .. Note:: This multiprocessing function should be called within a `if __main__ == '__main__':` syntax. - Parameters:: + Parameters + ---------- func : callable The function to run model. @@ -40,7 +41,8 @@ def process_pool(func: callable, num_process : int The number of the processes. - Returns:: + Returns + ------- results : list Process results. @@ -84,16 +86,18 @@ def some_func(..., lock, ...): .. Note:: This multiprocessing function should be called within a `if __main__ == '__main__':` syntax. - Parameters:: + Parameters + ---------- - func: callable + func : callable The function to run model. all_params : list, tuple, dict The parameters of the function arguments. num_process : int The number of the processes. - Returns:: + Returns + ------- results : list Process results. diff --git a/brainpy/running/pathos_multiprocessing.py b/brainpy/running/pathos_multiprocessing.py index 045af1cdd..8d55ffc4e 100644 --- a/brainpy/running/pathos_multiprocessing.py +++ b/brainpy/running/pathos_multiprocessing.py @@ -56,26 +56,28 @@ def _parallel( ) -> Generator: """Perform a parallel map with a progress bar. - Parameters:: + Parameters + ---------- - ordered: bool + ordered : bool True for an ordered map, false for an unordered map. - function: callable, function + function : callable, function The function to apply to each element of the given Iterables. - arguments: sequence of Iterable, dict + arguments : sequence of Iterable, dict One or more Iterables containing the data to be mapped. - num_process: int, float + num_process : int, float Number of threads used for parallel running. If `int`, it is the number of threads to be used; if `float`, it is the fraction of total threads to be used for running. - num_task: int + num_task : int The total number of tasks in this parallel running. - tqdm_kwargs: Any + tqdm_kwargs : Any The setting for the progress bar. - Returns:: + Returns + ------- - results: Iterable + results : Iterable A generator which will apply the function to each element of the given Iterables in parallel in order with a progress bar. """ @@ -140,7 +142,8 @@ def cpu_ordered_parallel( ) -> List[Any]: """Performs a parallel ordered map with a progress bar. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -160,24 +163,26 @@ def cpu_ordered_parallel( >>> results = bp.running.cpu_unordered_parallel(simulate, [np.arange(1, 10, 100)], num_process=10) >>> print(results) - Parameters:: + Parameters + ---------- - func: callable, function + func : callable, function The function to apply to each element of the given Iterables. - arguments: sequence of Iterable, dict + arguments : sequence of Iterable, dict One or more Iterables containing the data to be mapped. - num_process: int, float + num_process : int, float Number of threads used for parallel running. If `int`, it is the number of threads to be used; if `float`, it is the fraction of total threads to be used for running. - num_task: int + num_task : int The total number of tasks in this parallel running. - tqdm_kwargs: Any + tqdm_kwargs : Any The setting for the progress bar. - Returns:: + Returns + ------- - results: list + results : list A list which will apply the function to each element of the given tasks. """ generator = _parallel(True, @@ -198,7 +203,8 @@ def cpu_unordered_parallel( ) -> List[Any]: """Performs a parallel unordered map with a progress bar. - Examples:: + Examples + -------- >>> import brainpy as bp >>> import brainpy.math as bm @@ -218,24 +224,26 @@ def cpu_unordered_parallel( >>> results = bp.running.cpu_unordered_parallel(simulate, [np.arange(1, 10, 100)], num_process=10) >>> print(results) - Parameters:: + Parameters + ---------- - func: callable, function + func : callable, function The function to apply to each element of the given Iterables. - arguments: sequence of Iterable, dict + arguments : sequence of Iterable, dict One or more Iterables containing the data to be mapped. - num_process: int, float + num_process : int, float Number of threads used for parallel running. If `int`, it is the number of threads to be used; if `float`, it is the fraction of total threads to be used for running. - num_task: int + num_task : int The total number of tasks in this parallel running. - tqdm_kwargs: Any + tqdm_kwargs : Any The setting for the progress bar. - Returns:: + Returns + ------- - results: list + results : list A list which will apply the function to each element of the given tasks. """ generator = _parallel(False, diff --git a/brainpy/running/runner.py b/brainpy/running/runner.py index 5f18c464b..cb3ef118a 100644 --- a/brainpy/running/runner.py +++ b/brainpy/running/runner.py @@ -33,12 +33,13 @@ class Runner(BrainPyObject): """Base Runner. - Parameters:: + Parameters + ---------- - target: Any + target : Any The target model. - monitors: None, sequence of str, dict, Monitor + monitors : None, sequence of str, dict, Monitor Variables to monitor. - A list of string. Like ``monitors=['a', 'b', 'c']`` @@ -50,20 +51,20 @@ class Runner(BrainPyObject): .. versionchanged:: 2.3.1 ``func_monitors`` are merged into ``monitors``. - fun_monitors: dict + fun_monitors : dict Monitoring variables by a dict of callable functions. The `key` should be a string for later retrieval by `runner.mon[key]`. The `value` should be a callable function which receives two arguments: `t` and `dt`. .. deprecated:: 2.3.1 Use ``monitors`` instead. - jit: bool, dict + jit : bool, dict The JIT settings. - progress_bar: bool + progress_bar : bool Use progress bar to report the running progress or not? - dyn_vars: Optional, Variable, sequence of Variable, dict + dyn_vars : Optional, Variable, sequence of Variable, dict The dynamically changed variables. Instance of :py:class:`~.Variable`. numpy_mon_after_run : bool diff --git a/brainpy/tools/codes.py b/brainpy/tools/codes.py index bd0a5a809..a854403b6 100644 --- a/brainpy/tools/codes.py +++ b/brainpy/tools/codes.py @@ -128,19 +128,22 @@ def get_identifiers(expr: str, include_numbers: bool = False) -> Set[str]: that matches a programming language variable like expression, which is here implemented as the regexp ``\\b[A-Za-z_][A-Za-z0-9_]*\\b``. - Parameters:: + Parameters + ---------- expr : str The string to analyze include_numbers : bool, optional Whether to include number literals in the output. Defaults to ``False``. - Returns:: + Returns + ------- identifiers : set A set of all the identifiers (and, optionally, numbers) in `expr`. - Examples:: + Examples + -------- >>> expr = '3-a*_b+c5+8+f(A - .3e-10, tau_2)*17' >>> ids = get_identifiers(expr) @@ -203,7 +206,8 @@ def word_replace(expr: str, substitutions: Dict[str, Any], exclude_dot: bool = T word ``word`` appearing in ``expr`` is replaced by ``rep``. Here a 'word' means anything matching the regexp ``\\bword\\b``. - Examples:: + Examples + -------- >>> expr = 'a*_b+c5+8+f(A)' >>> print(word_replace(expr, {'a':'banana', 'f':'func'})) @@ -231,12 +235,14 @@ def is_lambda_function(func: Any) -> bool: """Check whether the function is a ``lambda`` function. Comes from https://stackoverflow.com/questions/23852423/how-to-check-that-variable-is-a-lambda-function - Parameters:: + Parameters + ---------- func : callable function The function. - Returns:: + Returns + ------- bool True of False. @@ -260,11 +266,13 @@ def get_main_code(func: Optional[Callable[..., Any]], codes: Optional[str] = Non For lambda function, return the - Parameters:: + Parameters + ---------- func : callable, Optional, int, float - Returns:: + Returns + ------- """ if func is None: diff --git a/brainpy/tools/dicts.py b/brainpy/tools/dicts.py index 4fb543f40..35f94ef10 100644 --- a/brainpy/tools/dicts.py +++ b/brainpy/tools/dicts.py @@ -75,14 +75,16 @@ def update(self, *args: Any, **kwargs: Any) -> 'DotDict': # type: ignore[overri def __add__(self, other: Mapping[Any, Any]) -> 'DotDict': """Merging two dicts. - Parameters:: + Parameters + ---------- - other: dict + other : dict The other dict instance. - Returns:: + Returns + ------- - gather: Collector + gather : Collector The new collector. """ gather = type(self)(self) @@ -92,14 +94,16 @@ def __add__(self, other: Mapping[Any, Any]) -> 'DotDict': def __sub__(self, other: Union[Dict[Any, Any], Sequence[Any]]) -> 'DotDict': """Remove other item in the collector. - Parameters:: + Parameters + ---------- - other: dict, sequence + other : dict, sequence The items to remove. - Returns:: + Returns + ------- - gather: Collector + gather : Collector The new collector. """ if not isinstance(other, (dict, tuple, list)): @@ -160,7 +164,8 @@ def subset(self, var_type: type) -> 'DotDict': >>> # get all ODE integrators >>> some_collector.subset(bp.ode.ODEIntegrator) - Parameters:: + Parameters + ---------- var_type : type The type/class to match. diff --git a/brainpy/tools/functions.py b/brainpy/tools/functions.py index 9cfa29be1..3b9aa5a29 100644 --- a/brainpy/tools/functions.py +++ b/brainpy/tools/functions.py @@ -96,7 +96,8 @@ def __reduce__(self) -> Tuple[Any, ...]: class Compose(object): """ A composition of functions - See Also: + See Also + -------- compose """ __slots__ = 'first', 'funcs' diff --git a/brainpy/tools/others.py b/brainpy/tools/others.py index d20f28248..ecc962d51 100644 --- a/brainpy/tools/others.py +++ b/brainpy/tools/others.py @@ -108,12 +108,14 @@ def to_size(x: Union[int, Sequence[int], None]) -> Optional[Tuple[int, ...]]: def timeout(s: float) -> Callable[[Callable[..., T]], Callable[..., T]]: """Add a timeout parameter to a function and return it. - Parameters:: + Parameters + ---------- s : float Time limit in seconds. - Returns:: + Returns + ------- func : callable Functional results. Or, raise an error of KeyboardInterrupt. diff --git a/brainpy/tools/progress.py b/brainpy/tools/progress.py index a3c3bb8a1..e08cfeb01 100644 --- a/brainpy/tools/progress.py +++ b/brainpy/tools/progress.py @@ -34,11 +34,14 @@ def func_dump(func: python_types.FunctionType) -> Tuple[str, Optional[Tuple[Any, ...]], Optional[Tuple[Any, ...]]]: """Serializes a user defined function. - Args: - func: the function to serialize. - - Returns: - A tuple `(code, defaults, closure)`. + Parameters + ---------- + func : python_types.FunctionType + the function to serialize. + + Returns + ------- + A tuple `(code, defaults, closure)`. """ if os.name == "nt": raw_code = marshal.dumps(func.__code__).replace(b"\\", b"/") @@ -58,14 +61,20 @@ def func_load(code: Any, defaults: Any = None, closure: Any = None, globs: Optional[Dict[str, Any]] = None) -> python_types.FunctionType: """Deserializes a user defined function. - Args: - code: bytecode of the function. - defaults: defaults of the function. - closure: closure of the function. - globs: dictionary of global objects. - - Returns: - A function object. + Parameters + ---------- + code : Any + bytecode of the function. + defaults : Any + defaults of the function. + closure : Any + closure of the function. + globs : Optional[Dict[str, Any]] + dictionary of global objects. + + Returns + ------- + A function object. """ if isinstance(code, (tuple, list)): # unpack previous dump code, defaults, closure = code @@ -75,11 +84,14 @@ def func_load(code: Any, defaults: Any = None, closure: Any = None, def ensure_value_to_cell(value: Any) -> Any: """Ensures that a value is converted to a python cell object. - Args: - value: Any value that needs to be casted to the cell type + Parameters + ---------- + value : Any + Any value that needs to be casted to the cell type - Returns: - A value wrapped as a cell object (see function "func_load") + Returns + ------- + A value wrapped as a cell object (see function "func_load") """ def dummy_fn() -> None: @@ -107,15 +119,22 @@ def dummy_fn() -> None: class Progbar: """Displays a progress bar. - Args: - target: Total number of steps expected, None if unknown. - width: Progress bar width on screen. - verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) - stateful_metrics: Iterable of string names of metrics that should *not* - be averaged over time. Metrics in this list will be displayed as-is. - All others will be averaged by the progbar before display. - interval: Minimum visual progress update interval (in seconds). - unit_name: Display name for step counts (usually "step" or "sample"). + Parameters + ---------- + target : Optional[int] + Total number of steps expected, None if unknown. + width : int + Progress bar width on screen. + verbose : int + Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) + stateful_metrics : Optional[Iterable[str]] + Iterable of string names of metrics that should *not* + be averaged over time. Metrics in this list will be displayed as-is. + All others will be averaged by the progbar before display. + interval : float + Minimum visual progress update interval (in seconds). + unit_name : str + Display name for step counts (usually "step" or "sample"). """ def __init__( @@ -159,14 +178,18 @@ def update(self, current: int, values: Optional[Sequence[Tuple[str, float]]] = N finalize: Optional[bool] = None) -> None: """Updates the progress bar. - Args: - current: Index of current step. - values: List of tuples: `(name, value_for_last_step)`. If `name` is - in `stateful_metrics`, `value_for_last_step` will be displayed - as-is. Else, an average of the metric over time will be - displayed. - finalize: Whether this is the last update for the progress bar. If - `None`, uses `current >= self.target`. Defaults to `None`. + Parameters + ---------- + current : int + Index of current step. + values : Optional[Sequence[Tuple[str, float]]] + List of tuples: `(name, value_for_last_step)`. If `name` is + in `stateful_metrics`, `value_for_last_step` will be displayed + as-is. Else, an average of the metric over time will be + displayed. + finalize : Optional[bool] + Whether this is the last update for the progress bar. If + `None`, uses `current >= self.target`. Defaults to `None`. """ if finalize is None: if self.target is None: @@ -316,11 +339,17 @@ def _format_time(self, time_per_unit: float, unit_name: str) -> str: Given the duration, this function formats it in either milliseconds or seconds and displays the unit (i.e. ms/step or s/epoch) - Args: - time_per_unit: the duration to display - unit_name: the name of the unit to display - Returns: - a string with the correctly formatted duration and units + + Parameters + ---------- + time_per_unit : float + the duration to display + unit_name : str + the name of the unit to display + + Returns + ------- + a string with the correctly formatted duration and units """ formatted = "" if time_per_unit >= 1 or time_per_unit == 0: @@ -341,11 +370,16 @@ def _estimate_step_duration(self, current: int, now: float) -> float: of the (assumed to be non-representative) first step for estimates when more steps are available (i.e. `current>1`). - Args: - current: Index of current step. - now: The current time. + Parameters + ---------- + current : int + Index of current step. + now : float + The current time. - Returns: Estimate of the duration of a single step. + Returns + ------- + Estimate of the duration of a single step. """ if current: # there are a few special scenarios here: @@ -374,12 +408,16 @@ def _update_stateful_metrics(self, stateful_metrics: Iterable[str]) -> None: def make_batches(size: int, batch_size: int) -> List[Tuple[int, int]]: """Returns a list of batch indices (tuples of indices). - Args: - size: Integer, total size of the data to slice into batches. - batch_size: Integer, batch size. + Parameters + ---------- + size : int + Integer, total size of the data to slice into batches. + batch_size : int + Integer, batch size. - Returns: - A list of tuples of array indices. + Returns + ------- + A list of tuples of array indices. """ num_batches = int(np.ceil(size / float(batch_size))) return [ @@ -398,16 +436,23 @@ def slice_arrays(arrays: Any, start: Any = None, stop: Any = None) -> Any: Can also work on list/array of indices: `slice_arrays(x, indices)` - Args: - arrays: Single array or list of arrays. - start: can be an integer index (start index) or a list/array of indices - stop: integer (stop index); should be None if `start` was a list. - - Returns: - A slice of the array(s). - - Raises: - ValueError: If the value of start is a list and stop is not None. + Parameters + ---------- + arrays : Any + Single array or list of arrays. + start : Any + can be an integer index (start index) or a list/array of indices + stop : Any + integer (stop index); should be None if `start` was a list. + + Returns + ------- + A slice of the array(s). + + Raises + ------ + ValueError + If the value of start is a list and stop is not None. """ if arrays is None: return [None] @@ -446,11 +491,14 @@ def to_list(x: Any) -> List[Any]: If a tensor is passed, we return a list of size 1 containing the tensor. - Args: - x: target object to be normalized. + Parameters + ---------- + x : Any + target object to be normalized. - Returns: - A list. + Returns + ------- + A list. """ if isinstance(x, list): return x diff --git a/brainpy/train/back_propagation.py b/brainpy/train/back_propagation.py index 332881e55..627e53f6e 100644 --- a/brainpy/train/back_propagation.py +++ b/brainpy/train/back_propagation.py @@ -51,33 +51,34 @@ class BPTrainer(DSTrainer): For more parameters, users should refer to :py:class:`~.DSRunner`. - Parameters:: + Parameters + ---------- - target: DynamicalSystem + target : DynamicalSystem The target model to train. - loss_fun: str, callable + loss_fun : str, callable The loss function. If it is a string, it should be the function chosen from ``brainpy.losses`` module. Otherwise, a callable function which receives argument of `(predicts, targets)` should be provided. - loss_has_aux: bool + loss_has_aux : bool To indicate whether the `loss_fun` returns auxiliary data. - loss_auto_run: bool + loss_auto_run : bool pass - optimizer: optim.Optimizer + optimizer : optim.Optimizer The optimizer used for training. - numpy_mon_after_run: bool + numpy_mon_after_run : bool Make the monitored results as NumPy arrays. - logger: Any + logger : Any A file-like object (stream); defaults to the current `sys.stdout`. - shuffle_data: bool + shuffle_data : bool .. deprecated:: 2.2.4.1 Control the data shuffling by user self. - seed: int + seed : int .. deprecated:: 2.2.4.1 Control the data shuffling by user self. - kwargs: Any + kwargs : Any Other general parameters please see :py:class:`~.DSRunner`. """ @@ -201,9 +202,10 @@ def fit( ): """Fit the target model according to the given training data. - Parameters:: + Parameters + ---------- - train_data: callable, iterable + train_data : callable, iterable It can be a callable function, or a tuple/list representing `(X, Y)` data. - Callable. This function should return a pair of `(X, Y)` data. - Iterable. It should be a pair of `(X, Y)` train set. @@ -217,18 +219,18 @@ def fit( then we will only fit the model with the only last output. - If the shape of each tensor is `(num_sample, num_time, num_feature)`, then the fitting happens on the whole data series. - test_data: callable, iterable, optional + test_data : callable, iterable, optional Same as ``train_data``. - num_epoch: int + num_epoch : int The number of training epoch. Default 100. - num_report: int + num_report : int The number of step to report the progress. If `num_report=-1`, it will report the training progress each epoch. - reset_state: bool + reset_state : bool Whether reset the initial states of the target model. - shared_args: dict + shared_args : dict The shared keyword arguments for the target models. - fun_after_report: optional, Callable + fun_after_report : optional, Callable The function to call after each report of `fit` phase or `test` phase. The function should receive three arguments: - ``idx`` for the indicator the current the running index. (If ``report=-1``, @@ -238,7 +240,7 @@ def fit( - ``phase``: to indicate the phase of 'fit' or 'test'. .. versionadded:: 2.3.1 - batch_size: int + batch_size : int .. deprecated:: 2.2.4.1 Please set batch size in your dataset. @@ -480,12 +482,13 @@ class BPTT(BPTrainer): For more parameters, users should refer to :py:class:`~.DSRunner`. - Parameters:: + Parameters + ---------- - target: DynamicalSystem + target : DynamicalSystem The target model to train. - loss_fun: str, callable + loss_fun : str, callable The loss function. - If it is a string, it should be the function chosen from ``brainpy.losses`` module. @@ -497,7 +500,7 @@ class BPTT(BPTrainer): parts: the network history prediction outputs, and the monitored values. see BrainPy examples for more information. - loss_has_aux: bool + loss_has_aux : bool To indicate whether the loss function returns auxiliary data expect the loss. Moreover, all auxiliary data should be a dict, whose key is used for logging item name and its data is used for the corresponding value. @@ -507,13 +510,13 @@ class BPTT(BPTrainer): def loss_fun(predicts, targets): return loss, {'acc': acc, 'spike_num': spike_num} - optimizer: Optimizer + optimizer : Optimizer The optimizer used for training. Should be an instance of :py:class:`~.Optimizer`. - numpy_mon_after_run: bool + numpy_mon_after_run : bool Make the monitored results as NumPy arrays. - logger: Any + logger : Any A file-like object (stream). Used to output the running results. Default is the current `sys.stdout`. - data_first_axis: str + data_first_axis : str To indicate whether the first axis is the batch size (``data_first_axis='B'``) or the time length (``data_first_axis='T'``). """ @@ -596,21 +599,23 @@ def predict( Moreover, it can automatically monitor the node variables, states, inputs, feedbacks and its output. - Parameters:: + Parameters + ---------- - inputs: ArrayType, dict + inputs : ArrayType, dict The feedforward input data. It must be a 3-dimensional data which has the shape of `(num_sample, num_time, num_feature)`. - reset_state: bool + reset_state : bool Whether reset the model states. - shared_args: optional, dict + shared_args : optional, dict The shared arguments across different layers. - eval_time: bool + eval_time : bool Evaluate the time used for running. - Returns:: + Returns + ------- - output: ArrayType, dict + output : ArrayType, dict The model output. """ if shared_args is None: diff --git a/brainpy/train/base.py b/brainpy/train/base.py index 4709fe6f3..b420da6af 100644 --- a/brainpy/train/base.py +++ b/brainpy/train/base.py @@ -32,12 +32,13 @@ class DSTrainer(DSRunner): For more parameters, users should refer to :py:class:`~.DSRunner`. - Parameters:: + Parameters + ---------- - target: DynamicalSystem + target : DynamicalSystem The training target. - kwargs: Any + kwargs : Any Other general parameters in :py:class:`~.DSRunner`. """ @@ -81,20 +82,22 @@ def predict( ) -> Output: """Prediction function. - Parameters:: + Parameters + ---------- - inputs: ArrayType, sequence of ArrayType, dict of ArrayType + inputs : ArrayType, sequence of ArrayType, dict of ArrayType The input values. - reset_state: bool + reset_state : bool Reset the target state before running. - eval_time: bool + eval_time : bool Whether we evaluate the running time or not? - shared_args: dict + shared_args : dict The shared arguments across nodes. - Returns:: + Returns + ------- - output: ArrayType, sequence of ArrayType, dict of ArrayType + output : ArrayType, sequence of ArrayType, dict of ArrayType The running output. """ if shared_args is None: diff --git a/brainpy/train/offline.py b/brainpy/train/offline.py index 78863ec16..945780151 100644 --- a/brainpy/train/offline.py +++ b/brainpy/train/offline.py @@ -42,11 +42,12 @@ class OfflineTrainer(DSTrainer): For more parameters, users should refer to :py:class:`~.DSRunner`. - Parameters:: + Parameters + ---------- - target: DynamicalSystem + target : DynamicalSystem The target model to train. - fit_method: OfflineAlgorithm, Callable, dict, str + fit_method : OfflineAlgorithm, Callable, dict, str The fitting method applied to the target model. - It can be a string, which specify the shortcut name of the training algorithm. Like, ``fit_method='ridge'`` means using the Ridge regression method. @@ -60,7 +61,7 @@ class OfflineTrainer(DSTrainer): - It can also be a callable function, which receives three arguments "targets", "x" and "y". For example, ``fit_method=lambda targets, x, y: numpy.linalg.lstsq(x, targets)[0]``. - kwargs: Any + kwargs : Any Other general parameters please see :py:class:`~.DSRunner`. """ @@ -120,20 +121,22 @@ def predict( What's different from `predict()` function in :py:class:`~.DynamicalSystem` is that the `inputs_are_batching` is default `True`. - Parameters:: + Parameters + ---------- - inputs: ArrayType + inputs : ArrayType The input values. - reset_state: bool + reset_state : bool Reset the target state before running. - eval_time: bool + eval_time : bool Whether we evaluate the running time or not? - shared_args: dict + shared_args : dict The shared arguments across nodes. - Returns:: + Returns + ------- - output: ArrayType + output : ArrayType The running output. """ outs = super().predict(inputs=inputs, reset_state=reset_state, @@ -150,9 +153,10 @@ def fit( ) -> Output: """Fit the target model according to the given training and testing data. - Parameters:: + Parameters + ---------- - train_data: sequence of data + train_data : sequence of data It should be a pair of `(X, Y)` train set. - ``X``: should be a tensor or a dict of tensors with the shape of `(num_sample, num_time, num_feature)`, where `num_sample` is @@ -163,9 +167,9 @@ def fit( then we will only fit the model with the only last output. - If the shape of each tensor is `(num_sample, num_time, num_feature)`, then the fitting happens on the whole data series. - reset_state: bool + reset_state : bool Whether reset the initial states of the target model. - shared_args: dict + shared_args : dict The shared keyword arguments for the target models. """ with brainstate.environ.context(fit=True): @@ -276,11 +280,12 @@ class RidgeTrainer(OfflineTrainer): For more parameters, users should refer to :py:class:`~.DSRunner`. - Parameters:: + Parameters + ---------- - target: TrainingSystem, DynamicalSystem + target : TrainingSystem, DynamicalSystem The target model. - alpha: float + alpha : float The regularization coefficient. """ diff --git a/brainpy/train/online.py b/brainpy/train/online.py index e655a1355..d2c920225 100644 --- a/brainpy/train/online.py +++ b/brainpy/train/online.py @@ -44,12 +44,13 @@ class OnlineTrainer(DSTrainer): For more parameters, users should refer to :py:class:`~.DSRunner`. - Parameters:: + Parameters + ---------- - target: DynamicalSystem + target : DynamicalSystem The target model to train. - fit_method: OnlineAlgorithm, Callable, dict, str + fit_method : OnlineAlgorithm, Callable, dict, str The fitting method applied to the target model. - It can be a string, which specify the shortcut name of the training algorithm. @@ -63,7 +64,7 @@ class OnlineTrainer(DSTrainer): For example, ``fit_meth=bp.algorithms.RLS(alpha=1e-5)``. - It can also be a callable function. - kwargs: Any + kwargs : Any Other general parameters please see :py:class:`~.DSRunner`. """ @@ -128,20 +129,22 @@ def predict( What's different from `predict()` function in :py:class:`~.DynamicalSystem` is that the `inputs_are_batching` is default `True`. - Parameters:: + Parameters + ---------- - inputs: ArrayType + inputs : ArrayType The input values. - reset_state: bool + reset_state : bool Reset the target state before running. - shared_args: dict + shared_args : dict The shared arguments across nodes. - eval_time: bool + eval_time : bool Whether we evaluate the running time or not? - Returns:: + Returns + ------- - output: ArrayType + output : ArrayType The running output. """ outs = super().predict(inputs=inputs, @@ -225,16 +228,18 @@ def _fit(self, shared_args: Dict = None): """Predict the output according to the inputs. - Parameters:: + Parameters + ---------- - indices: ArrayType + indices : ArrayType The running indices. - ys: dict + ys : dict Each tensor should have the shape of `(num_time, num_batch, num_feature)`. - shared_args: optional, dict + shared_args : optional, dict The shared keyword arguments. - Returns:: + Returns + ------- outputs, hists A tuple of pair of (outputs, hists). diff --git a/brainpy/transform.py b/brainpy/transform.py index 4422c6f53..8c1885da8 100644 --- a/brainpy/transform.py +++ b/brainpy/transform.py @@ -46,7 +46,8 @@ class LoopOverTime(DynamicalSystem): For more flexible customization, we recommend users to use :py:func:`~.for_loop`, or :py:class:`~.DSRunner`. - Examples:: + Examples + -------- This model can be used for network training: @@ -96,11 +97,12 @@ class LoopOverTime(DynamicalSystem): >>> plt.show() - Parameters:: + Parameters + ---------- - target: DynamicalSystem + target : DynamicalSystem The target to transform. - no_state: bool + no_state : bool Denoting whether the `target` has the shared argument or not. - For ANN layers which are no_state, like :py:class:`~.Dense` or :py:class:`~.Conv2d`, @@ -110,23 +112,23 @@ class LoopOverTime(DynamicalSystem): send data to the object, and reshape output to `shape = [T, N, *]`. In this way, the calculation over different time is parralelized. - out_vars: PyTree + out_vars : PyTree The variables to monitor over the time loop. - t0: float, optional + t0 : float, optional The start time to run the system. If None, ``t`` will be no longer generated in the loop. - i0: int, optional + i0 : int, optional The start index to run the system. If None, ``i`` will be no longer generated in the loop. - dt: float + dt : float The time step. - shared_arg: dict + shared_arg : dict The shared arguments across the nodes. For instance, `shared_arg={'fit': False}` for the prediction phase. - data_first_axis: str + data_first_axis : str Denoting the type of the first axis of input data. If ``'T'``, we treat the data as `(time, ...)`. If ``'B'``, we treat the data as `(batch, time, ...)` when the `target` is in Batching mode. Default is ``'T'``. - name: str + name : str The transformed object name. """ @@ -194,15 +196,17 @@ def __call__( ): """Forward propagation along the time or inputs. - Parameters:: + Parameters + ---------- - duration_or_xs: float, PyTree + duration_or_xs : float, PyTree If `float`, it indicates a running duration. If a PyTree, it is the given inputs. - Returns:: + Returns + ------- - out: PyTree + out : PyTree The accumulated outputs over time. """ # inputs diff --git a/docs/conf.py b/docs/conf.py index 9331e1428..f5c069e3f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -135,3 +135,16 @@ autodoc_default_options = { 'exclude-members': '....,default_rng', } + +# -- Options for napoleon (docstring style) ---------------------------------- +# BrainPy standardizes on NumPy-style docstrings (see CLAUDE.md). Parse only the +# NumPy style so that any stray Google-style docstring renders incorrectly and is +# caught in review, keeping the convention enforced. +napoleon_google_docstring = False +napoleon_numpy_docstring = True +napoleon_include_init_with_doc = False +napoleon_include_private_with_doc = False +napoleon_use_param = True +napoleon_use_rtype = True +napoleon_preprocess_types = True +napoleon_use_ivar = True diff --git a/pyproject.toml b/pyproject.toml index a58c60ab9..8afcf9e65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,16 @@ exclude = [ version = { attr = "brainpy.__version__" } +# --------------------------------------------------------------------------- +# Docstring style +# --------------------------------------------------------------------------- +# BrainPy standardizes on NumPy-style docstrings (see CLAUDE.md and the napoleon +# settings in docs/conf.py). This records the convention as a single source of +# truth so ``pydocstyle`` / ``ruff --select D`` enforce the NumPy style on demand. +[tool.pydocstyle] +convention = "numpy" + + [tool.coverage.run] # Measure coverage of the library source only -- never of the test files # themselves (a test file's lines are not "product code", and several