Skip to content

Using sympy2jax in an ODETerm #719

@Jacob-Stevens-Haas

Description

@Jacob-Stevens-Haas

I'm looking to integrate a sympy expression, but ODE term isn't accepting my rhs_func. Not sure if this should go in diffrax or sympy2jax.

Here's a MWE of a simple harmonic osciallator:

    import sympy
    import jax
    import jax.numpy as jnp
    import diffrax
    import sympy2jax

    x = sympy.symbols("x")
    y = sympy.symbols("y")
    f1 = 1 * y - .1 * x
    f2 = -1 * x - .1 * y
    mod = sympy2jax.SymbolicModule([f1, f2])

    y0 = .5 * jax.numpy.ones(2)
    symbols = [x, y]
    # Verify call signature is correct
    out = mod(**{str(symb): val for symb, val in zip(symbols, y0)})

    def rhs_func(t, y, args):
        return mod(**{str(symb): val for symb, val in zip(symbols, y)})

    # Verify call signature is correct
    rhs_func(0, y0, None)

    term = diffrax.ODETerm(rhs_func)
    solver = diffrax.Tsit5()
    save_at = diffrax.SaveAt(dense=True)
    y0_jax = jnp.asarray([1])

    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=0,
        t1=1,
        dt0=.1,  # Initial step size
        y0=y0_jax,
        args=(),
        saveat=save_at,
        max_steps = int(100)
    )
ValueError: Terms are not compatible with solver!
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/sympy2jax/sympy_module.py:245, in _Func.__call__(self, memodict)
    244 try:
--> 245     arg_call = memodict[arg]
    246 except KeyError:

KeyError: _Symbol(_name='y')

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/sympy2jax/sympy_module.py:137, in _Symbol.__call__(self, memodict)
    136 try:
--> 137     return memodict[self._name]
    138 except KeyError as e:

KeyError: 'y'

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/diffrax/_integrate.py:170, in _assert_term_compatible.<locals>._check(term_cls, term, term_contr_kwargs, yi)
    169 try:
--> 170     vf_type = eqx.filter_eval_shape(term.vf, t, yi, args)
    171 except Exception as e:

File ~/github/gen-experiments/env/lib/python3.12/site-packages/equinox/_eval_shape.py:38, in filter_eval_shape(fun, *args, **kwargs)
     37 dynamic, static = partition((fun, args, kwargs), _filter)
---> 38 dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
     39 return combine(dynamic_out, static_out.value)

    [... skipping hidden 1 frame]

File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/api.py:3012, in eval_shape(fun, *args, **kwargs)
   3011 except TypeError: fun = partial(fun)
-> 3012 return jit(fun).trace(*args, **kwargs).out_info

    [... skipping hidden 1 frame]

File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/pjit.py:303, in jit_trace(jit_func, *args, **kwargs)
    301 @api_boundary
    302 def jit_trace(jit_func, *args, **kwargs) -> stages.Traced:
--> 303   p, args_flat = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs)
    304   arg_types = map(convert_to_metaty, args_flat)

File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/pjit.py:616, in _infer_params(fun, ji, args, kwargs)
    615 else:
--> 616   return _infer_params_internal(fun, ji, args, kwargs)

File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/pjit.py:636, in _infer_params_internal(fun, ji, args, kwargs)
    635 dbg = dbg_fn()
--> 636 p, args_flat = _infer_params_impl(
    637     fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals)
    638 if p.params['jaxpr'].jaxpr.is_high:

File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/pjit.py:535, in _infer_params_impl(***failed resolving arguments***)
    533 qdd_token = _qdd_cache_index(flat_fun, in_type)
--> 535 jaxpr, consts, out_avals = _create_pjit_jaxpr(
    536     flat_fun, in_type, qdd_token, IgnoreKey(ji.inline))
    538 if config.mutable_array_checks.value:

File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/linear_util.py:460, in cache.<locals>.memoized_fun(fun, *args)
    459   start = time.time()
--> 460 ans = call(fun, *args)
    461 if do_explain:

File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/pjit.py:1118, in _create_pjit_jaxpr(***failed resolving arguments***)
   1115 with dispatch.log_elapsed_time(
   1116     "Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
   1117     fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
-> 1118   jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_type)
   1120 if config.debug_key_reuse.value:
   1121   # Import here to avoid circular imports

File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/profiler.py:359, in annotate_function.<locals>.wrapper(*args, **kwargs)
    358 with TraceAnnotation(name, **decorator_kwargs):
--> 359   return func(*args, **kwargs)

File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py:2348, in trace_to_jaxpr_dynamic(fun, in_avals, keep_inputs, lower, auto_dce)
   2347 with core.set_current_trace(trace):
-> 2348   ans = fun.call_wrapped(*in_tracers)
   2349 _check_returned_jaxtypes(fun.debug_info, ans)

File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/linear_util.py:212, in WrappedFun.call_wrapped(self, *args, **kwargs)
    211 """Calls the transformed function"""
--> 212 return self.f_transformed(*args, **kwargs)

File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/api_util.py:83, in flatten_fun3(f, store, in_tree, *args_flat)
     82 py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
---> 83 ans = f(*py_args, **py_kwargs)
     84 paths_and_ans, out_tree = tree_flatten_with_path(ans)

File ~/github/gen-experiments/env/lib/python3.12/site-packages/equinox/_eval_shape.py:33, in filter_eval_shape.<locals>._fn(_static, _dynamic)
     32 _fun, _args, _kwargs = combine(_static, _dynamic)
---> 33 _out = _fun(*_args, **_kwargs)
     34 _dynamic_out, _static_out = partition(_out, _filter)

    [... skipping hidden 1 frame]

File ~/github/gen-experiments/env/lib/python3.12/site-packages/diffrax/_term.py:194, in ODETerm.vf(self, t, y, args)
    193 def vf(self, t: RealScalarLike, y: Y, args: Args) -> _VF:
--> 194     out = self.vector_field(t, y, args)
    195     if jtu.tree_structure(out) != jtu.tree_structure(y):

Cell In[96], line 2, in rhs_func(t, y, args)
      1 def rhs_func(t, y, args):
----> 2     return jnp.stack(mod(**{str(symb): val for symb, val in zip(symbols, y)}))

File ~/github/gen-experiments/env/lib/python3.12/site-packages/sympy2jax/sympy_module.py:326, in SymbolicModule.__call__(self, **symbols)
    325 memodict = symbols
--> 326 return jtu.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)

File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/tree_util.py:362, in tree_map(f, tree, is_leaf, *rest)
    361 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 362 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/tree_util.py:362, in <genexpr>(.0)
    361 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 362 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/github/gen-experiments/env/lib/python3.12/site-packages/sympy2jax/sympy_module.py:326, in SymbolicModule.__call__.<locals>.<lambda>(n)
    325 memodict = symbols
--> 326 return jtu.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)

File ~/github/gen-experiments/env/lib/python3.12/site-packages/sympy2jax/sympy_module.py:247, in _Func.__call__(self, memodict)
    246 except KeyError:
--> 247     arg_call = arg(memodict)
    248     memodict[arg] = arg_call

File ~/github/gen-experiments/env/lib/python3.12/site-packages/sympy2jax/sympy_module.py:139, in _Symbol.__call__(self, memodict)
    138 except KeyError as e:
--> 139     raise KeyError(f"Missing input for symbol {self._name}") from e

KeyError: 'Missing input for symbol y'

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/diffrax/_integrate.py:200, in _assert_term_compatible(t, y, args, terms, term_structure, contr_kwargs)
    199     with jax.numpy_dtype_promotion("standard"):
--> 200         jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
    201 except ValueError as e:
    202     # ValueError may also arise from mismatched tree structures

File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/tree_util.py:362, in tree_map(f, tree, is_leaf, *rest)
    361 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 362 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/tree_util.py:362, in <genexpr>(.0)
    361 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 362 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/github/gen-experiments/env/lib/python3.12/site-packages/diffrax/_integrate.py:172, in _assert_term_compatible.<locals>._check(term_cls, term, term_contr_kwargs, yi)
    171 except Exception as e:
--> 172     raise ValueError(f"Error while tracing {term}.vf: " + str(e))
    173 vf_type_compatible = eqx.filter_eval_shape(
    174     better_isinstance, vf_type, vf_type_expected
    175 )

ValueError: Error while tracing ODETerm(vector_field=<function rhs_func>).vf: 'Missing input for symbol y'

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[96], line 11
      8 save_at = diffrax.SaveAt(dense=True)
      9 y0_jax = jnp.asarray([1])
---> 11 sol = diffrax.diffeqsolve(
     12     term,
     13     solver,
     14     t0=0,
     15     t1=1,
     16     dt0=.1,  # Initial step size
     17     y0=y0_jax,
     18     args=(),
     19     saveat=save_at,
     20     max_steps = int(100)
     21 )

    [... skipping hidden 17 frame]

File ~/github/gen-experiments/env/lib/python3.12/site-packages/diffrax/_integrate.py:1103, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, event, max_steps, throw, progress_meter, solver_state, controller_state, made_jump, discrete_terminating_event)
   1100         terms = MultiTerm(*terms)
   1102 # Error checking for term compatibility
-> 1103 _assert_term_compatible(
   1104     t0,
   1105     y0,
   1106     args,
   1107     terms,
   1108     solver.term_structure,
   1109     solver.term_compatible_contr_kwargs,
   1110 )
   1112 if is_sde(terms):
   1113     if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)):

File ~/github/gen-experiments/env/lib/python3.12/site-packages/diffrax/_integrate.py:205, in _assert_term_compatible(t, y, args, terms, term_structure, contr_kwargs)
    203 pretty_term = wl.pformat(terms)
    204 pretty_expected = wl.pformat(term_structure)
--> 205 raise ValueError(
    206     f"Terms are not compatible with solver! Got:\n{pretty_term}\nbut expected:"
    207     f"\n{pretty_expected}\nNote that terms are checked recursively: if you "
    208     "scroll up you may find a root-cause error that is more specific."
    209 ) from e

ValueError: Terms are not compatible with solver! Got:
ODETerm(vector_field=<function rhs_func>)
but expected:
diffrax.AbstractTerm
Note that terms are checked recursively: if you scroll up you may find a root-cause error that is more specific.

I read #248, but I'm not sure where my MWE differs, other than unpacking the state in a comprehension. Is that the problem?

Metadata

Metadata

Assignees

No one assigned

    Labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions