@@ -6,12 +6,11 @@ using Base: RefValue
66
77isintermediate (eq:: Equation ) = ! (isa (eq. lhs, Operation) && isa (eq. lhs. op, Differential))
88
9- function _unwrap_differenital (O)
10- isa (O, Operation) || return (O, nothing , 0 )
11- isa (O. op, Differential) || return (O, nothing , 0 )
12- (x, t, order) = _unwrap_differenital (O. args[1 ])
13- t === nothing && (t = O. op. x)
14- t == O. op. x || throw (ArgumentError (" non-matching differentials on lhs" ))
9+ function flatten_differential (O:: Operation )
10+ @assert is_derivative (O) " invalid differential: $O "
11+ is_derivative (O. args[1 ]) || return (O. args[1 ], O. op. x, 1 )
12+ (x, t, order) = flatten_differential (O. args[1 ])
13+ t == O. op. x || throw (ArgumentError (" non-matching differentials on lhs: $t , $(O. op. x) " ))
1514 return (x, t, order + 1 )
1615end
1716
@@ -24,7 +23,7 @@ struct DiffEq # dⁿx/dtⁿ = rhs
2423end
2524function Base. convert (:: Type{DiffEq} , eq:: Equation )
2625 isintermediate (eq) && throw (ArgumentError (" intermediate equation received" ))
27- (x, t, n) = _unwrap_differenital (eq. lhs)
26+ (x, t, n) = flatten_differential (eq. lhs)
2827 return DiffEq (x, t, n, eq. rhs)
2928end
3029Base.:(== )(a:: DiffEq , b:: DiffEq ) = (a. x, a. t, a. n, a. rhs) == (b. x, b. t, b. n, b. rhs)
0 commit comments