Skip to content

Commit 2391791

Browse files
committed
sparsejacobian, jacobian_sparsity
1 parent 000c2da commit 2391791

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

src/direct.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,55 @@ function jacobian(ops::AbstractVector{<:Expression}, vars::AbstractVector{<:Expr
2222
[expand_derivatives(Differential(v)(O),simplify) for O in ops, v in vars]
2323
end
2424

25+
"""
26+
```julia
27+
sparsejacobian(ops::AbstractVector{<:Expression}, vars::AbstractVector{<:Expression}; simplify = true)
28+
```
29+
30+
A helper function for computing the sparse Jacobian of an array of expressions with respect to
31+
an array of variable expressions.
32+
"""
33+
function sparsejacobian(ops::AbstractVector{<:Expression}, vars::AbstractVector{<:Expression}; simplify = true)
34+
I = Int[]
35+
J = Int[]
36+
du = Expression[]
37+
38+
sp = jacobian_sparsity(ops, vars)
39+
I,J,_ = findnz(sp)
40+
41+
exprs = Expression[]
42+
43+
for (i,j) in zip(I, J)
44+
push!(exprs, expand_derivatives(Differential(vars[j])(ops[i])))
45+
end
46+
sparse(I,J, exprs, length(ops), length(vars))
47+
end
48+
49+
using SymbolicUtils: @rule, Rewriters
50+
51+
function jacobian_sparsity(du, u)
52+
dict = Dict(zip(to_symbolic.(u), 1:length(u)))
53+
54+
i = Ref(1)
55+
I = Int[]
56+
J = Int[]
57+
58+
# This rewriter notes down which u's appear in a
59+
# given du (whose index is stored in the `i` Ref)
60+
r = [@rule ~x::(x->haskey(dict, x)) => begin
61+
push!(I, i[])
62+
push!(J, dict[~x])
63+
nothing
64+
end] |> Rewriters.Chain |> Rewriters.Postwalk
65+
66+
for ii = 1:length(du)
67+
i[] = ii
68+
r(to_symbolic(du[ii]))
69+
end
70+
71+
sparse(I, J, true, length(du), length(u))
72+
end
73+
2574
"""
2675
```julia
2776
hessian(O::Expression, vars::AbstractVector{<:Expression}; simplify = true)

0 commit comments

Comments
 (0)