-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from functools import wraps
 from itertools import zip_longest
 from types import ModuleType
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING

 import numpy as np

 from pytensor.tensor.random.op import RandomVariable


-def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True):
+def params_broadcast_shapes(
+    param_shapes: Sequence, ndims_params: Sequence[int], use_pytensor: bool = True
+) -> list[tuple[int, ...]]:
     """Broadcast parameters that have different dimensions.

     Parameters
@@ -36,12 +38,12 @@ def params_broadcast_shapes(param_shapes, ndims_params, use_pytensor=True):

     Returns
     =======
-    bcast_shapes : list of ndarray
+    bcast_shapes : list of tuples of ints
         The broadcasted values of `params`.
     """
     max_fn = maximum if use_pytensor else max

-    rev_extra_dims = []
+    rev_extra_dims: list[int] = []
     for ndim_param, param_shape in zip(ndims_params, param_shapes):
         # We need this in order to use `len`
         param_shape = tuple(param_shape)
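For context, a minimal sketch of how the newly annotated `params_broadcast_shapes` is typically called. The example is illustrative and not part of this diff; it assumes the function lives in `pytensor.tensor.random.utils` and that `use_pytensor=False` works on plain integer shapes:

```python
# Illustrative only: assumed import path for the function changed above.
from pytensor.tensor.random.utils import params_broadcast_shapes

# A vector parameter (core ndim 1) and a batch of two matrices (core ndim 2):
# only the leading "batch" dimensions are broadcast against each other,
# the rightmost ndims_params dimensions are left untouched.
shapes = params_broadcast_shapes(
    [(3,), (2, 3, 3)], ndims_params=[1, 2], use_pytensor=False
)
print(shapes)  # expected: [(2, 3), (2, 3, 3)]
```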
@@ -71,7 +73,9 @@ def max_bcast(x, y):
     return bcast_shapes


-def broadcast_params(params, ndims_params):
+def broadcast_params(
+    params: Sequence[np.ndarray | TensorVariable], ndims_params: Sequence[int]
+) -> list[np.ndarray]:
     """Broadcast parameters that have different dimensions.

     >>> ndims_params = [1, 2]
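The hunk above cuts off the function's doctest. As a hedged usage sketch (again assuming the `pytensor.tensor.random.utils` module path, which is not shown in the diff), `broadcast_params` applies the same batch-dimension broadcasting to the parameter values themselves:

```python
import numpy as np

# Illustrative sketch, not part of the diff: assumed import path.
from pytensor.tensor.random.utils import broadcast_params

mean = np.array([1.0, 2.0, 3.0])             # core ndim 1
cov = np.stack([np.eye(3), 2 * np.eye(3)])   # shape (2, 3, 3), core ndim 2

bcast_mean, bcast_cov = broadcast_params([mean, cov], ndims_params=[1, 2])
print(bcast_mean.shape)  # (2, 3): the mean picks up the batch dimension
print(bcast_cov.shape)   # (2, 3, 3): already carries the batch dimension
```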
@@ -215,7 +219,9 @@ def __init__(
         self,
         seed: int | None = None,
         namespace: ModuleType | None = None,
-        rng_ctor: Literal[np.random.Generator] = np.random.default_rng,
+        rng_ctor: Callable[
+            [np.random.SeedSequence], np.random.Generator
+        ] = np.random.default_rng,
     ):
         if namespace is None:
             from pytensor.tensor.random import basic  # pylint: disable=import-self
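The last hunk replaces the `Literal[np.random.Generator]` annotation (which does not describe the callable default) with a callable type: any function mapping a `SeedSequence` to a `Generator` now satisfies the signature. A NumPy-only sketch of what the new annotation accepts; the custom constructor below is hypothetical:

```python
import numpy as np

# np.random.default_rng matches Callable[[np.random.SeedSequence], np.random.Generator].
rng_default = np.random.default_rng(np.random.SeedSequence(42))

# A hypothetical custom constructor with the same shape also qualifies,
# e.g. one that pins a specific bit generator.
def pcg64_ctor(seed: np.random.SeedSequence) -> np.random.Generator:
    return np.random.Generator(np.random.PCG64(seed))

rng_custom = pcg64_ctor(np.random.SeedSequence(42))
assert isinstance(rng_default, np.random.Generator)
assert isinstance(rng_custom, np.random.Generator)
```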