
Kvaerno and KenCarp solvers fail JAX tracing with a multi-component state on jax 0.8.2 #717

@ds283

Description

I am encountering an issue with the Kvaerno and KenCarp steppers (and possibly others) when the state has more than one component. The issue appears to have been introduced in jax 0.8.2.

diffrax and jax versions
diffrax 0.7.0
jax 0.8.2
jaxlib 0.8.2
optimistix 0.0.11
equinox 0.13.2
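The version list above can be regenerated with a short standard-library script. This is a sketch: `collect_versions` is a hypothetical helper name. `lineax` is also included because the failing frame is inside lineax, although its version is not recorded in the list above.

```python
from importlib.metadata import PackageNotFoundError, version

# lineax is added here because the traceback ends in lineax/_solver/lu.py.
PACKAGES = ("diffrax", "jax", "jaxlib", "optimistix", "equinox", "lineax")


def collect_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(name, ver if ver is not None else "(not installed)")
```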

Platform
Apple M1, Darwin 25.1.0

Minimal (non-)working example

from diffrax import diffeqsolve, ODETerm, SaveAt, PIDController, Kvaerno5

vector_field = lambda t, y, args: (-5 * y[0], 0.5 * y[1])
term = ODETerm(vector_field)
solver = Kvaerno5()
saveat = SaveAt(dense=True)
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

sol = diffeqsolve(
    term,
    solver,
    t0=0,
    t1=3,
    dt0=0.1,
    y0=(1, 1),
    saveat=saveat,
    stepsize_controller=stepsize_controller,
)

Expected behaviour
The solver runs to completion and the returned dense solution can be evaluated.

Observed behaviour
JAX tracing fails with the following traceback:

$ python /Users/ds283/Documents/Code/diffrax-issue/test-diffrax.py 
Traceback (most recent call last):
  File "/Users/ds283/Documents/Code/diffrax-issue/test-diffrax.py", line 9, in <module>
    sol = diffeqsolve(
        term,
    ...<6 lines>...
        stepsize_controller=stepsize_controller,
    )
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_jit.py", line 209, in __call__
    return _call(self, False, args, kwargs)
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_jit.py", line 263, in _call
    marker, _, _ = out = jit_wrapper._cached(
                         ~~~~~~~~~~~~~~~~~~~^
        dynamic_donate, dynamic_nodonate, static
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/diffrax/_integrate.py", line 1281, in diffeqsolve
    _, _, dense_info_struct, _, _ = eqx.filter_eval_shape(
                                    ~~~~~~~~~~~~~~~~~~~~~^
        solver.step, terms, tprev, tnext, y0, args, solver_state, made_jump
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/diffrax/_solver/runge_kutta.py", line 1149, in step
    final_val = eqxi.while_loop(
        cond_stage,
    ...<6 lines>...
        base=num_stages,
    )
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/internal/_loop/loop.py", line 107, in while_loop
    return checkpointed_while_loop(
        cond_fun,
    ...<4 lines>...
        checkpoints=checkpoints,
    )
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/internal/_loop/checkpointed.py", line 247, in checkpointed_while_loop
    body_fun_ = filter_closure_convert(body_fun_, init_val_)
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/internal/_loop/common.py", line 511, in new_body_fun
    buffer_val2 = body_fun(buffer_val)
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/diffrax/_solver/runge_kutta.py", line 984, in rk_stage
    nonlinear_sol = optx.root_find(
        _implicit_relation_f,
    ...<5 lines>...
        max_steps=self.root_find_max_steps,  # pyright: ignore
    )
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/optimistix/_root_find.py", line 218, in root_find
    return iterative_solve(
        fn,
    ...<10 lines>...
        rewrite_fn=_rewrite_fn,
    )
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/optimistix/_iterate.py", line 344, in iterative_solve
    ) = adjoint.apply(_iterate, rewrite_fn, inputs, tags)
        ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/optimistix/_adjoint.py", line 133, in apply
    return implicit_jvp(primal_fn, rewrite_fn, inputs, tags, self.linear_solver)
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/optimistix/_ad.py", line 60, in implicit_jvp
    root, residual = _implicit_impl(fn_primal, fn_rewrite, inputs, tags, linear_solver)
                     ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/optimistix/_ad.py", line 67, in _implicit_impl
    return jtu.tree_map(jnp.asarray, fn_primal(inputs))
                                     ~~~~~~~~~^^^^^^^^
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/optimistix/_iterate.py", line 240, in _iterate
    final_carry = while_loop(cond_fun, body_fun, init_carry, max_steps=max_steps)
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/internal/_loop/loop.py", line 103, in while_loop
    _, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
                         ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/internal/_loop/common.py", line 511, in new_body_fun
    buffer_val2 = body_fun(buffer_val)
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/optimistix/_iterate.py", line 230, in body_fun
    new_y, new_state, aux = solver.step(fn, y, args, options, state, tags)
                            ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/diffrax/_root_finder/_verychord.py", line 127, in step
    sol = lx.linear_solve(
        jac, fx, self.linear_solver, state=linear_state, throw=False
    )
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/lineax/_solve.py", line 820, in linear_solve
    solution, result, stats = eqxi.filter_primitive_bind(
                              ~~~~~~~~~~~~~~~~~~~~~~~~~~^
        linear_solve_p, operator, state, vector, options, solver, throw
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/internal/_primitive.py", line 271, in filter_primitive_bind
    flat_out = prim.bind(*dynamic, treedef=treedef, static=static, flatten=flatten)
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/internal/_primitive.py", line 156, in _wrapper
    out = rule(*args)
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/lineax/_solve.py", line 126, in _linear_solve_abstract_eval
    out = eqx.filter_eval_shape(
        _linear_solve_impl,
    ...<6 lines>...
        check_closure=False,
    )
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/lineax/_solve.py", line 87, in _linear_solve_impl
    out = solver.compute(state, vector, options)
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/lineax/_solve.py", line 648, in compute
    solution, result, _ = solver.compute(state, vector, options)
                          ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_module/_prebuilt.py", line 33, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/lineax/_solver/lu.py", line 61, in compute
    trans = 1 if transpose else 0
                 ^^^^^^^^^
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function _fn at /Users/ds283/Documents/Code/diffrax-issue/.venv/lib/python3.13/site-packages/equinox/_eval_shape.py:31 for jit. This concrete value was not available in Python because it depends on the value of the argument _dynamic[1][1][1][2].
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
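The last frame fails at `trans = 1 if transpose else 0` in lineax's LU solver. There, `transpose` arrives as a traced `bool[]` rather than a concrete Python bool, so the Python `if` cannot be evaluated. The standalone sketch below illustrates that error class. It does not reproduce the diffrax/lineax code path or explain why the flag becomes traced under jax 0.8.2. `pick_trans` is a hypothetical stand-in for the lineax line.

```python
import jax
import jax.numpy as jnp


def pick_trans(transpose):
    # Python-level branch: needs a concrete bool, like the lineax line above.
    return 1 if transpose else 0


def traced_flag_fails():
    """Return True if tracing a Python branch on a traced bool raises."""
    try:
        jax.jit(pick_trans)(True)  # `transpose` becomes a traced bool[] here
    except jax.errors.TracerBoolConversionError:
        return True
    return False


# Marking the flag as static keeps it a concrete Python bool during tracing.
pick_trans_static = jax.jit(pick_trans, static_argnums=0)


# Alternatively, select on the traced value with array ops instead of `if`.
@jax.jit
def pick_trans_traced(transpose):
    return jnp.where(transpose, 1, 0)
```

In the failing call, the flag is internal to lineax's solver state, so presumably neither remedy is reachable from user code. The snippet only makes the error message concrete.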

Notes:

  • Everything works, even with jax 0.8.2, if the state is changed to have a single component
  • Everything works if Kvaerno5() is swapped for Dopri8() or another explicit solver
  • On jax 0.8.1 and earlier (I checked back to at least the 0.6.x series), everything works for all solvers, whether the state has one component or several
  • The same error occurs for at least Kvaerno3(), Kvaerno4(), KenCarp3(), KenCarp4() and KenCarp5()
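Given the first note, one possible stopgap is to pack the tuple state into a single flat array before calling the solver. This sketch assumes that a single 1-D array counts as a "single component" in the sense of that note. Whether it actually avoids the failure on jax 0.8.2 has not been verified. `flatten_vector_field` is a hypothetical helper. `jax.flatten_util.ravel_pytree` is a real JAX function.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def flatten_vector_field(vector_field, y0):
    """Wrap a PyTree-valued vector field so it acts on one flat array.

    Returns (flat_vector_field, flat_y0, unravel); `unravel` maps a flat
    array back to the PyTree structure of y0 (here, a tuple).
    """
    flat_y0, unravel = ravel_pytree(y0)

    def flat_vector_field(t, y_flat, args):
        y = unravel(y_flat)            # rebuild the tuple state
        dy = vector_field(t, y, args)  # evaluate the original field
        dy_flat, _ = ravel_pytree(dy)  # flatten the derivative to match y_flat
        return dy_flat

    return flat_vector_field, flat_y0, unravel


# Same vector field as in the example above, with float initial values.
vector_field = lambda t, y, args: (-5 * y[0], 0.5 * y[1])
flat_vf, flat_y0, unravel = flatten_vector_field(vector_field, (1.0, 1.0))
```

The wrapped pieces would then be passed as `ODETerm(flat_vf)` and `y0=flat_y0`. Calling `unravel(sol.evaluate(t))` would recover the tuple form of the dense solution at time `t`.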
