-
-
Notifications
You must be signed in to change notification settings - Fork 167
Closed
Labels
questionUser queriesUser queries
Description
I'm looking to integrate a sympy expression, but ODE term isn't accepting my rhs_func. Not sure if this should go in diffrax or sympy2jax.
Here's a MWE of a simple harmonic osciallator:
import sympy
import jax
import jax.numpy as jnp
import diffrax
import sympy2jax
x = sympy.symbols("x")
y = sympy.symbols("y")
f1 = 1 * y - .1 * x
f2 = -1 * x - .1 * y
mod = sympy2jax.SymbolicModule([f1, f2])
y0 = .5 * jax.numpy.ones(2)
symbols = [x, y]
# Verify call signature is correct
out = mod(**{str(symb): val for symb, val in zip(symbols, y0)})
def rhs_func(t, y, args):
return mod(**{str(symb): val for symb, val in zip(symbols, y)})
# Verify call signature is correct
rhs_func(0, y0, None)
term = diffrax.ODETerm(rhs_func)
solver = diffrax.Tsit5()
save_at = diffrax.SaveAt(dense=True)
y0_jax = jnp.asarray([1])
sol = diffrax.diffeqsolve(
term,
solver,
t0=0,
t1=1,
dt0=.1, # Initial step size
y0=y0_jax,
args=(),
saveat=save_at,
max_steps = int(100)
)
ValueError: Terms are not compatible with solver!
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/sympy2jax/sympy_module.py:245, in _Func.__call__(self, memodict)
244 try:
--> 245 arg_call = memodict[arg]
246 except KeyError:
KeyError: _Symbol(_name='y')
During handling of the above exception, another exception occurred:
KeyError Traceback (most recent call last)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/sympy2jax/sympy_module.py:137, in _Symbol.__call__(self, memodict)
136 try:
--> 137 return memodict[self._name]
138 except KeyError as e:
KeyError: 'y'
The above exception was the direct cause of the following exception:
KeyError Traceback (most recent call last)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/diffrax/_integrate.py:170, in _assert_term_compatible.<locals>._check(term_cls, term, term_contr_kwargs, yi)
169 try:
--> 170 vf_type = eqx.filter_eval_shape(term.vf, t, yi, args)
171 except Exception as e:
File ~/github/gen-experiments/env/lib/python3.12/site-packages/equinox/_eval_shape.py:38, in filter_eval_shape(fun, *args, **kwargs)
37 dynamic, static = partition((fun, args, kwargs), _filter)
---> 38 dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
39 return combine(dynamic_out, static_out.value)
[... skipping hidden 1 frame]
File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/api.py:3012, in eval_shape(fun, *args, **kwargs)
3011 except TypeError: fun = partial(fun)
-> 3012 return jit(fun).trace(*args, **kwargs).out_info
[... skipping hidden 1 frame]
File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/pjit.py:303, in jit_trace(jit_func, *args, **kwargs)
301 @api_boundary
302 def jit_trace(jit_func, *args, **kwargs) -> stages.Traced:
--> 303 p, args_flat = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs)
304 arg_types = map(convert_to_metaty, args_flat)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/pjit.py:616, in _infer_params(fun, ji, args, kwargs)
615 else:
--> 616 return _infer_params_internal(fun, ji, args, kwargs)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/pjit.py:636, in _infer_params_internal(fun, ji, args, kwargs)
635 dbg = dbg_fn()
--> 636 p, args_flat = _infer_params_impl(
637 fun, ji, ctx_mesh, dbg, args, kwargs, in_avals=avals)
638 if p.params['jaxpr'].jaxpr.is_high:
File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/pjit.py:535, in _infer_params_impl(***failed resolving arguments***)
533 qdd_token = _qdd_cache_index(flat_fun, in_type)
--> 535 jaxpr, consts, out_avals = _create_pjit_jaxpr(
536 flat_fun, in_type, qdd_token, IgnoreKey(ji.inline))
538 if config.mutable_array_checks.value:
File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/linear_util.py:460, in cache.<locals>.memoized_fun(fun, *args)
459 start = time.time()
--> 460 ans = call(fun, *args)
461 if do_explain:
File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/pjit.py:1118, in _create_pjit_jaxpr(***failed resolving arguments***)
1115 with dispatch.log_elapsed_time(
1116 "Finished tracing + transforming {fun_name} for pjit in {elapsed_time:.9f} sec",
1117 fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
-> 1118 jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_type)
1120 if config.debug_key_reuse.value:
1121 # Import here to avoid circular imports
File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/profiler.py:359, in annotate_function.<locals>.wrapper(*args, **kwargs)
358 with TraceAnnotation(name, **decorator_kwargs):
--> 359 return func(*args, **kwargs)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py:2348, in trace_to_jaxpr_dynamic(fun, in_avals, keep_inputs, lower, auto_dce)
2347 with core.set_current_trace(trace):
-> 2348 ans = fun.call_wrapped(*in_tracers)
2349 _check_returned_jaxtypes(fun.debug_info, ans)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/linear_util.py:212, in WrappedFun.call_wrapped(self, *args, **kwargs)
211 """Calls the transformed function"""
--> 212 return self.f_transformed(*args, **kwargs)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/api_util.py:83, in flatten_fun3(f, store, in_tree, *args_flat)
82 py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
---> 83 ans = f(*py_args, **py_kwargs)
84 paths_and_ans, out_tree = tree_flatten_with_path(ans)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/equinox/_eval_shape.py:33, in filter_eval_shape.<locals>._fn(_static, _dynamic)
32 _fun, _args, _kwargs = combine(_static, _dynamic)
---> 33 _out = _fun(*_args, **_kwargs)
34 _dynamic_out, _static_out = partition(_out, _filter)
[... skipping hidden 1 frame]
File ~/github/gen-experiments/env/lib/python3.12/site-packages/diffrax/_term.py:194, in ODETerm.vf(self, t, y, args)
193 def vf(self, t: RealScalarLike, y: Y, args: Args) -> _VF:
--> 194 out = self.vector_field(t, y, args)
195 if jtu.tree_structure(out) != jtu.tree_structure(y):
Cell In[96], line 2, in rhs_func(t, y, args)
1 def rhs_func(t, y, args):
----> 2 return jnp.stack(mod(**{str(symb): val for symb, val in zip(symbols, y)}))
File ~/github/gen-experiments/env/lib/python3.12/site-packages/sympy2jax/sympy_module.py:326, in SymbolicModule.__call__(self, **symbols)
325 memodict = symbols
--> 326 return jtu.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/tree_util.py:362, in tree_map(f, tree, is_leaf, *rest)
361 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 362 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/tree_util.py:362, in <genexpr>(.0)
361 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 362 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File ~/github/gen-experiments/env/lib/python3.12/site-packages/sympy2jax/sympy_module.py:326, in SymbolicModule.__call__.<locals>.<lambda>(n)
325 memodict = symbols
--> 326 return jtu.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/sympy2jax/sympy_module.py:247, in _Func.__call__(self, memodict)
246 except KeyError:
--> 247 arg_call = arg(memodict)
248 memodict[arg] = arg_call
File ~/github/gen-experiments/env/lib/python3.12/site-packages/sympy2jax/sympy_module.py:139, in _Symbol.__call__(self, memodict)
138 except KeyError as e:
--> 139 raise KeyError(f"Missing input for symbol {self._name}") from e
KeyError: 'Missing input for symbol y'
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
File ~/github/gen-experiments/env/lib/python3.12/site-packages/diffrax/_integrate.py:200, in _assert_term_compatible(t, y, args, terms, term_structure, contr_kwargs)
199 with jax.numpy_dtype_promotion("standard"):
--> 200 jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
201 except ValueError as e:
202 # ValueError may also arise from mismatched tree structures
File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/tree_util.py:362, in tree_map(f, tree, is_leaf, *rest)
361 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 362 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File ~/github/gen-experiments/env/lib/python3.12/site-packages/jax/_src/tree_util.py:362, in <genexpr>(.0)
361 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 362 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File ~/github/gen-experiments/env/lib/python3.12/site-packages/diffrax/_integrate.py:172, in _assert_term_compatible.<locals>._check(term_cls, term, term_contr_kwargs, yi)
171 except Exception as e:
--> 172 raise ValueError(f"Error while tracing {term}.vf: " + str(e))
173 vf_type_compatible = eqx.filter_eval_shape(
174 better_isinstance, vf_type, vf_type_expected
175 )
ValueError: Error while tracing ODETerm(vector_field=<function rhs_func>).vf: 'Missing input for symbol y'
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
Cell In[96], line 11
8 save_at = diffrax.SaveAt(dense=True)
9 y0_jax = jnp.asarray([1])
---> 11 sol = diffrax.diffeqsolve(
12 term,
13 solver,
14 t0=0,
15 t1=1,
16 dt0=.1, # Initial step size
17 y0=y0_jax,
18 args=(),
19 saveat=save_at,
20 max_steps = int(100)
21 )
[... skipping hidden 17 frame]
File ~/github/gen-experiments/env/lib/python3.12/site-packages/diffrax/_integrate.py:1103, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, event, max_steps, throw, progress_meter, solver_state, controller_state, made_jump, discrete_terminating_event)
1100 terms = MultiTerm(*terms)
1102 # Error checking for term compatibility
-> 1103 _assert_term_compatible(
1104 t0,
1105 y0,
1106 args,
1107 terms,
1108 solver.term_structure,
1109 solver.term_compatible_contr_kwargs,
1110 )
1112 if is_sde(terms):
1113 if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)):
File ~/github/gen-experiments/env/lib/python3.12/site-packages/diffrax/_integrate.py:205, in _assert_term_compatible(t, y, args, terms, term_structure, contr_kwargs)
203 pretty_term = wl.pformat(terms)
204 pretty_expected = wl.pformat(term_structure)
--> 205 raise ValueError(
206 f"Terms are not compatible with solver! Got:\n{pretty_term}\nbut expected:"
207 f"\n{pretty_expected}\nNote that terms are checked recursively: if you "
208 "scroll up you may find a root-cause error that is more specific."
209 ) from e
ValueError: Terms are not compatible with solver! Got:
ODETerm(vector_field=<function rhs_func>)
but expected:
diffrax.AbstractTerm
Note that terms are checked recursively: if you scroll up you may find a root-cause error that is more specific.
I read #248, but I'm not sure where my MWE differs, other than unpacking the state in a comprehension. Is that the problem?
Metadata
Metadata
Assignees
Labels
questionUser queriesUser queries