Description
I am encountering an issue with the Kvaerno and KenCarp steppers (possibly others) when the state has more than one component. This issue seems to have appeared with jax 0.8.2.
diffrax and jax versions
diffrax 0.7.0
jax 0.8.2
jaxlib 0.8.2
optimistix 0.0.11
equinox 0.13.2
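A small helper like the one below can collect the same version information in one go when re-testing against other releases. It is a hypothetical sketch, not part of the original report. It uses only the standard library's `importlib.metadata`, and it also lists `lineax`, which appears in the traceback but not in the version list above.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages involved in the failing call chain (lineax taken from the traceback).
PACKAGES = ("diffrax", "jax", "jaxlib", "optimistix", "equinox", "lineax")


def installed_versions(packages=PACKAGES):
    """Return {package: version string, or None if it is not installed}."""
    out = {}
    for name in packages:
        try:
            out[name] = version(name)
        except PackageNotFoundError:
            out[name] = None
    return out


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name} {ver if ver is not None else '(not installed)'}")
```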
Platform
Apple M1, Darwin 25.1.0
Minimum (non-)working example
from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController, Kvaerno5

vector_field = lambda t, y, args: (-5 * y[0], 0.5 * y[1])
term = ODETerm(vector_field)
solver = Kvaerno5()
saveat = SaveAt(dense=True)
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

sol = diffeqsolve(
    term,
    solver,
    t0=0,
    t1=3,
    dt0=0.1,
    y0=(1, 1),
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)
Expected behaviour
Solver runs to completion and solution is evaluable.
Observed behaviour
JAX JIT tracing fails with the following traceback:
$ python /Users/ds283/Documents/Code/diffrax-issue/test-diffrax.py
Traceback (most recent call last):
File "/Users/ds283/Documents/Code/diffrax-issue/test-diffrax.py", line 9, in <module>
sol = diffeqsolve(
term,
...<6 lines>...
stepsize_controller=stepsize_controller,
)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_jit.py", line 209, in __call__
return _call(self, False, args, kwargs)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_jit.py", line 263, in _call
marker, _, _ = out = jit_wrapper._cached(
~~~~~~~~~~~~~~~~~~~^
dynamic_donate, dynamic_nodonate, static
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/diffrax/_integrate.py", line 1281, in diffeqsolve
_, _, dense_info_struct, _, _ = eqx.filter_eval_shape(
~~~~~~~~~~~~~~~~~~~~~^
solver.step, terms, tprev, tnext, y0, args, solver_state, made_jump
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
return self.__func__(self.__self__, *args, **kwargs)
~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/diffrax/_solver/runge_kutta.py", line 1149, in step
final_val = eqxi.while_loop(
cond_stage,
...<6 lines>...
base=num_stages,
)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/internal/_loop/loop.py", line 107, in while_loop
return checkpointed_while_loop(
cond_fun,
...<4 lines>...
checkpoints=checkpoints,
)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py", line 247, in checkpointed_while_loop
body_fun_ = filter_closure_convert(body_fun_, init_val_)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/internal/_loop/common.py", line 511, in new_body_fun
buffer_val2 = body_fun(buffer_val)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/diffrax/_solver/runge_kutta.py", line 984, in rk_stage
nonlinear_sol = optx.root_find(
_implicit_relation_f,
...<5 lines>...
max_steps=self.root_find_max_steps, # pyright: ignore
)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/optimistix/_root_find.py", line 218, in root_find
return iterative_solve(
fn,
...<10 lines>...
rewrite_fn=_rewrite_fn,
)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/optimistix/_iterate.py", line 344, in iterative_solve
) = adjoint.apply(_iterate, rewrite_fn, inputs, tags)
~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
return self.__func__(self.__self__, *args, **kwargs)
~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/optimistix/_adjoint.py", line 133, in apply
return implicit_jvp(primal_fn, rewrite_fn, inputs, tags, self.linear_solver)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/optimistix/_ad.py", line 60, in implicit_jvp
root, residual = _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver)
~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/optimistix/_ad.py", line 67, in _implicit_impl
return jtu.tree_map(jnp.asarray, fn_primal(inputs))
~~~~~~~~~^^^^^^^^
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/optimistix/_iterate.py", line 240, in _iterate
final_carry = while_loop(cond_fun, body_fun, init_carry, max_steps=max_steps)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/internal/_loop/loop.py", line 103, in while_loop
_, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/internal/_loop/common.py", line 511, in new_body_fun
buffer_val2 = body_fun(buffer_val)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/optimistix/_iterate.py", line 230, in body_fun
new_y, new_state, aux = solver.step(fn, y, args, options, state, tags)
~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
return self.__func__(self.__self__, *args, **kwargs)
~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/diffrax/_root_finder/_verychord.py", line 127, in step
sol = lx.linear_solve(
jac, fx, self.linear_solver, state=linear_state, throw=False
)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/lineax/_solve.py", line 820, in linear_solve
solution, result, stats = eqxi.filter_primitive_bind(
~~~~~~~~~~~~~~~~~~~~~~~~~~^
linear_solve_p, operator, state, vector, options, solver, throw
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/internal/_primitive.py", line 271, in filter_primitive_bind
flat_out = prim.bind(*dynamic, treedef=treedef, static=static, flatten=flatten)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/internal/_primitive.py", line 156, in _wrapper
out = rule(*args)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/lineax/_solve.py", line 126, in _linear_solve_abstract_eval
out = eqx.filter_eval_shape(
_linear_solve_impl,
...<6 lines>...
check_closure=False,
)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/lineax/_solve.py", line 87, in _linear_solve_impl
out = solver.compute(state, vector, options)
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
return self.__func__(self.__self__, *args, **kwargs)
~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/lineax/_solve.py", line 648, in compute
solution, result, _ = solver.compute(state, vector, options)
~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
return self.__func__(self.__self__, *args, **kwargs)
~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/lineax/_solver/lu.py", line 61, in compute
trans = 1 if transpose else 0
^^^^^^^^^
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function _fn at /Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_eval_shape.py:31 for jit. This concrete value was not available in Python because it depends on the value of the argument _dynamic[1][1][1][2].
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
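The last frame fails at `trans = 1 if transpose else 0` in lineax's LU solver, i.e. a Python `if` is evaluated on `transpose` after it has become a traced `bool[]` value. The self-contained sketch below reproduces that error class with plain JAX; the function names are hypothetical and this is not lineax code, nor a claim about how the library should be fixed. It shows the failure and the two usual ways around it: keeping the flag static, or selecting with `jnp.where` inside the traced computation.

```python
import jax
import jax.numpy as jnp


def choose_trans(transpose):
    # Python-level branch: needs a concrete bool, so it fails if `transpose` is traced.
    return 1 if transpose else 0


# Passing `transpose` as a dynamic argument turns it into a bool[] tracer under jit.
choose_trans_traced = jax.jit(choose_trans)

# Option 1: mark the flag as static so it stays a Python bool during tracing.
choose_trans_static = jax.jit(choose_trans, static_argnums=0)


@jax.jit
def choose_trans_where(transpose):
    # Option 2: make the selection part of the traced computation itself.
    return jnp.where(transpose, 1, 0)


if __name__ == "__main__":
    try:
        choose_trans_traced(jnp.array(True))
    except jax.errors.TracerBoolConversionError:
        print("raised TracerBoolConversionError")
    print(choose_trans_static(True), choose_trans_where(jnp.array(True)))
```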
Notes:
- Everything works, even with jax 0.8.2, if the state vector is changed to have a single component.
- Everything works if Kvaerno5() is swapped out for Dopri8() or another explicit method.
- Everything works, for all solvers and whether the state has a single component or multiple components, on jax 0.8.1 and earlier (I checked back at least as far as the 0.6.x series).
- The same error occurs at least for Kvaerno3(), Kvaerno4(), KenCarp3(), KenCarp4() and KenCarp5().
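In JAX terms, the reproducer's `y0=(1, 1)` is a PyTree with two leaves, while packing both components into one array gives a single leaf. The sketch below illustrates the difference; `y0_packed` and `vector_field_packed` are hypothetical names, not from the report. It assumes "single component" in the first note means a single array leaf. Whether passing the packed state to the Kvaerno5 reproducer avoids the error on jax 0.8.2 has not been checked here.

```python
import jax
import jax.numpy as jnp

# State as in the reproducer: a tuple of two scalars -> two PyTree leaves.
y0_tuple = (1, 1)

# Hypothetical packed alternative: one array leaf holding both components.
y0_packed = jnp.array([1.0, 1.0])


def vector_field_packed(t, y, args):
    # Same dynamics as the reproducer: dy0/dt = -5 * y0, dy1/dt = 0.5 * y1.
    return jnp.array([-5.0, 0.5]) * y


n_leaves_tuple = len(jax.tree_util.tree_leaves(y0_tuple))    # 2
n_leaves_packed = len(jax.tree_util.tree_leaves(y0_packed))  # 1
```

If this assumption holds, `ODETerm(vector_field_packed)` with `y0=y0_packed` would be the single-leaf counterpart of the reproducer above.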