Open
Description
I have a solve defined with a steady state event. In my example the steady state condition is already met when the simulation starts at t0. If I define the event as shown in the snippet below, with a root_finder specified, sol.ys is returned as [[0.]] instead of the expected [[10.0]].
import jax.numpy as jnp
import optimistix as optx
import diffrax

controller = diffrax.PIDController(rtol=1e-6, atol=1e-6)
steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6)
root_finder = optx.Newton(atol=1e-4, rtol=1e-4)

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(lambda t, y, args: jnp.zeros_like(y)),
    diffrax.Kvaerno5(),
    t0=0.0,
    t1=1.2,
    dt0=None,
    y0=jnp.array([10.0]),
    stepsize_controller=controller,
    event=diffrax.Event(
        cond_fn=steady_state_event,
        root_finder=root_finder,
    ),
    saveat=diffrax.SaveAt(t1=True),
    max_steps=100,
)
print(sol.ys)  # Array([[0.]], dtype=float32)

If I run the above example but do not specify root_finder, I get the expected outcome of sol.ys = [[10.0]]. This originally arose when I was defining an event that combines a steady state event with another cond_fn, e.g.:

def another_cond_fn(t, y, args, **kw):
    return 0.98 - t

event=diffrax.Event(
    cond_fn=(steady_state_event, another_cond_fn),
    root_finder=root_finder,
    direction=(None, True),
)

I have a workaround, because I can check whether the system is already in a steady state at t0 without calling diffeqsolve, but the fact that the solution depends on whether root_finder is defined in the event seems like it could be a bug.
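For reference, the workaround looks roughly like the sketch below. It evaluates the vector field once at t0 and compares its norm against the same tolerances passed to steady_state_event. The helper names (vector_field, rms_norm, is_steady_at_t0) are made up for this illustration, and the criterion norm(f) < atol + rtol * norm(y) with an RMS norm is my assumption about what the steady state event checks, not something taken from the diffrax source.

```python
import jax.numpy as jnp


def vector_field(t, y, args):
    # Same right-hand side as in the example above: dy/dt = 0 everywhere.
    return jnp.zeros_like(y)


def rms_norm(x):
    # Root-mean-square norm (assumed stand-in for the norm used by the event).
    return jnp.sqrt(jnp.mean(jnp.square(x)))


def is_steady_at_t0(f, t0, y0, args=None, rtol=1e-6, atol=1e-6):
    # Hypothetical helper: treat the state as steady when
    # ||f(t0, y0)|| < atol + rtol * ||y0||.
    return bool(rms_norm(f(t0, y0, args)) < atol + rtol * rms_norm(y0))


y0 = jnp.array([10.0])
already_steady = is_steady_at_t0(vector_field, 0.0, y0)
print(already_steady)
```

If already_steady is true, I skip diffeqsolve and use y0 as the result; otherwise I run the solve with the event as above.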