|
9 | 9 | from pytensor.raise_op import Assert |
10 | 10 | from pytensor.tensor import TensorVariable |
11 | 11 | from pytensor.tensor.nlinalg import matrix_dot |
12 | | -from pytensor.tensor.slinalg import SolveTriangular |
| 12 | +from pytensor.tensor.slinalg import solve_triangular |
13 | 13 |
|
14 | 14 | from pymc_experimental.statespace.filters.utilities import ( |
15 | 15 | quad_form_sym, |
|
22 | 22 | MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64")) |
23 | 23 | PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"] |
24 | 24 |
|
25 | | -solve_lower_triangular = SolveTriangular(lower=True) |
26 | 25 | assert_data_is_1d = Assert("UnivariateTimeSeries filter requires data be at most 1-dimensional") |
27 | 26 | assert_time_varying_dim_correct = Assert( |
28 | 27 | "The first dimension of a time varying matrix (the time dimension) must be " |
@@ -684,13 +683,13 @@ def update(self, a, P, y, c, d, Z, H, all_nan_flag): |
684 | 683 | F_chol = pt.linalg.cholesky(F) |
685 | 684 |
|
686 | 685 | # If everything is missing, K = 0, IKZ = I |
687 | | - K = solve_lower_triangular(F_chol.T, solve_lower_triangular(F_chol, PZT.T)).T |
| 686 | + K = solve_triangular(F_chol.T, solve_triangular(F_chol, PZT.T)).T |
688 | 687 | I_KZ = self.eye_states - K.dot(Z) |
689 | 688 |
|
690 | 689 | a_filtered = a + K.dot(v) |
691 | 690 | P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H) |
692 | 691 |
|
693 | | - inner_term = solve_lower_triangular(F_chol.T, solve_lower_triangular(F_chol, v)) |
| 692 | + inner_term = solve_triangular(F_chol.T, solve_triangular(F_chol, v)) |
694 | 693 | n = y.shape[0] |
695 | 694 |
|
696 | 695 | ll = pt.switch( |
|
0 commit comments