From 506c5be6b09429feb0a340e2a2c4295a69f38138 Mon Sep 17 00:00:00 2001 From: LuggiStruggi Date: Tue, 27 May 2025 11:42:04 +0200 Subject: [PATCH 1/2] start root finder at time of previous ode_solver_step than final ode_solver_step --- diffrax/_integrate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 0dc09800..37292124 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -727,7 +727,7 @@ def _call_real_impl(): _event_root_find = optx.root_find( _to_root_find, event.root_finder, - y0=final_state.event_tnext, + y0=final_state.event_tprev, options=_options, throw=False, ) From 1039524b47163a77db0a6244f2a817592ec8ea27 Mon Sep 17 00:00:00 2001 From: LuggiStruggi Date: Wed, 28 May 2025 12:28:14 +0200 Subject: [PATCH 2/2] run root finders over all potential first events --- diffrax/_integrate.py | 124 ++++++++++++++++++------------------------ 1 file changed, 54 insertions(+), 70 deletions(-) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 37292124..be481243 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -582,12 +582,7 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i): jtu.tree_structure((0, 0)), event_values__mask, ) - had_event = False - event_mask_leaves = [] - for event_mask_i in jtu.tree_leaves(event_mask): - event_mask_leaves.append(event_mask_i & jnp.invert(had_event)) - had_event = event_mask_i | had_event - event_mask = jtu.tree_unflatten(event_structure, event_mask_leaves) + had_event = jnp.any(jnp.stack(jtu.tree_leaves(event_mask), axis=0)) result = RESULTS.where( had_event, RESULTS.event_occurred, @@ -653,6 +648,7 @@ def body_fun(state): if event is None or event.root_finder is None: tfinal = final_state.tprev yfinal = final_state.y + first_event_mask = final_state.event_mask else: # If we're on this branch, it means that an event may have triggered, and now we # may need to do a root find, in order to locate the event time. @@ -663,19 +659,19 @@ def body_fun(state): event_happened = jnp.max(float_mask) > 0.0 def _root_find(): - _interpolator = solver.interpolation_cls( + interp = solver.interpolation_cls( t0=final_state.event_tprev, t1=final_state.event_tnext, **final_state.event_dense_info, ) - def _to_root_find(_t, _): - _distance_from_t_end = final_state.event_tnext - _t + flat_fns, fn_tree = jtu.tree_flatten(event.cond_fn) + flat_masks, _ = jtu.tree_flatten(event_mask) - def _call_real(_event_mask_i, _cond_fn_i): - def _call_real_impl(): - # First evaluate the triggered event. - _y = _interpolator.evaluate(_t) + def _call_real(_event_mask_i, _cond_fn_i): + def _find(): + def f(_t, _): + _y = interp.evaluate(_t) _value = _cond_fn_i( t=_t, y=_y, @@ -689,67 +685,56 @@ def _call_real_impl(): stepsize_controller=stepsize_controller, max_steps=max_steps, ) - # Second: if this is a boolean event, then normalise to a - # floating point number by having the root occur at the end of - # the last step, i.e. `event_tnext`. - _value_dtype = jnp.result_type(_value) - if jnp.issubdtype(_value_dtype, jnp.bool_): - _value = _distance_from_t_end - else: - assert jnp.issubdtype(_value_dtype, jnp.floating) - return _value - - # Only the triggered event actually gets to the decide what time the - # event occurs; everything else is zeroed out to automatically give - # a root. - # - # We allow this `lax.cond` to be inefficiently transformed into a - # `lax.select` when `_event_mask_i` is batched. There isn't any way - # to avoid this, I think. - _value = lax.cond(_event_mask_i, _call_real_impl, lambda: 0.0) - - # Third: if no events triggered at all, then have the root occur at - # the end of the last step (which will be the `t1` of the overall - # solve). - _value = jnp.where(event_happened, _value, _distance_from_t_end) - return _value - - return jtu.tree_map( - _call_real, - event_mask, - event.cond_fn, - ) + return ( + (final_state.event_tnext - _t) + if jnp.issubdtype(_value.dtype, jnp.bool_) + else _value + ) + + opts = { + "lower": final_state.event_tprev, + "upper": final_state.event_tnext, + } + res = optx.root_find( + f, + event.root_finder, + y0=final_state.event_tnext, + options=opts, + throw=False, + ) + return res.value - _options = { - "lower": final_state.event_tprev, - "upper": final_state.event_tnext, - } - _event_root_find = optx.root_find( - _to_root_find, - event.root_finder, - y0=final_state.event_tprev, - options=_options, - throw=False, + return lax.cond(_event_mask_i, _find, lambda: jnp.inf) + + candidates = jnp.stack( + [_call_real(m, fn) for m, fn in zip(flat_masks, flat_fns)] ) - _tfinal = _event_root_find.value - # TODO: we might need to change the way we evaluate `_yfinal` in order to - # get more accurate derivatives? - _yfinal = _interpolator.evaluate(_tfinal) - _result = RESULTS.where( - _event_root_find.result == optx.RESULTS.successful, + + t_event = jnp.min(candidates) + t_event = jnp.where(jnp.isfinite(t_event), t_event, final_state.event_tnext) + + y_event = interp.evaluate(t_event) + + first_idx = jnp.argmin(candidates) + first_mask_arr = jnp.arange(candidates.shape[0]) == first_idx + first_event_mask = jtu.tree_unflatten(fn_tree, list(first_mask_arr)) + + new_result = RESULTS.where( + jnp.any(jnp.stack(flat_masks)), + RESULTS.event_occurred, result, - RESULTS.promote(_event_root_find.result), ) - return _tfinal, _yfinal, _result + + return t_event, y_event, new_result, first_event_mask # Fastpath: if no event happened anywhere at all, then skip the root-find # altogether. # Note that `_root_find` might still be called on batch elements which did not # have an event, so we still need to access `event_happened` inside of it. - tfinal, yfinal, result = lax.cond( + tfinal, yfinal, result, first_event_mask = lax.cond( eqxi.unvmap_any(event_happened), _root_find, - lambda: (final_state.tprev, final_state.y, result), + lambda: (final_state.tprev, final_state.y, result, final_state.event_mask), ) # We delete all the saved values after the event time. @@ -824,9 +809,13 @@ def _save_t1(subsaveat, save_state): final_state = eqx.tree_at( lambda s: s.save_state, final_state, save_state, is_leaf=_is_none ) + final_state = _handle_static(final_state) result = RESULTS.where(cond_fun(final_state), RESULTS.max_steps_reached, result) aux_stats = dict() # TODO: put something in here? + + # override event mask with first found event + final_state = eqx.tree_at(lambda s: s.event_mask, final_state, first_event_mask) return eqx.tree_at(lambda s: s.result, final_state, result), aux_stats @@ -1339,18 +1328,13 @@ def _outer_cond_fn(cond_fn_i): jtu.tree_structure((0, 0)), event_values__mask, ) - had_event = False - event_mask_leaves = [] - for event_mask_i in jtu.tree_leaves(event_mask): - event_mask_leaves.append(event_mask_i & jnp.invert(had_event)) - had_event = event_mask_i | had_event - event_mask = jtu.tree_unflatten(event_structure, event_mask_leaves) + had_event = jnp.any(jnp.stack(jtu.tree_leaves(event_mask), axis=0)) result = RESULTS.where( had_event, RESULTS.event_occurred, result, ) - del had_event, event_structure, event_mask_leaves, event_values__mask + del had_event, event_structure, event_values__mask # Initialise state init_state = State(