From 8e9d99f4799f4aa606d0c9aa7c8df8f4c5ab94ad Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 31 Aug 2025 13:47:19 +0200 Subject: [PATCH] Fixes 681 --- diffrax/_integrate.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index bc319d40..1855d4ae 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -161,10 +161,7 @@ def _check(term_cls, term, term_contr_kwargs, yi): pass elif n_term_args == 2: vf_type_expected, control_type_expected = term_args - try: - vf_type = eqx.filter_eval_shape(term.vf, t, yi, args) - except Exception as e: - raise ValueError(f"Error while tracing {term}.vf: " + str(e)) + vf_type = eqx.filter_eval_shape(term.vf, t, yi, args) vf_type_compatible = eqx.filter_eval_shape( better_isinstance, vf_type, vf_type_expected ) @@ -173,10 +170,7 @@ def _check(term_cls, term, term_contr_kwargs, yi): contr = ft.partial(term.contr, **term_contr_kwargs) # Work around https://github.com/google/jax/issues/21825 - try: - control_type = eqx.filter_eval_shape(contr, t, t) - except Exception as e: - raise ValueError(f"Error while tracing {term}.contr: " + str(e)) + control_type = eqx.filter_eval_shape(contr, t, t) control_type_compatible = eqx.filter_eval_shape( better_isinstance, control_type, control_type_expected )