Description
In Zygote.jl, we can take the gradient with respect to all fields of a struct `foo` passed through a function `bar` via

```julia
g = Zygote.gradient(f -> bar(f), foo)
```

Can this be done in ForwardDiff as well?
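Reproducer:

```julia
using Zygote
using ForwardDiff

struct Foo
    x::Number
    t::Number
    c::Number
end

function bar(f::Foo)
    return f.x - f.c*f.t
end

foo = Foo(2, 3, 3e8)
println(foo)

# Works: Zygote returns a tuple whose single entry is a NamedTuple with one field per struct field
g = Zygote.gradient(f -> bar(f), foo)
println(g)

# Errors: ForwardDiff.gradient expects an AbstractArray input, not a struct
g = ForwardDiff.gradient(f -> bar(f), foo)
println(g)
```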
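`ForwardDiff.gradient` differentiates with respect to an array, so one workaround is to flatten the struct's fields into a vector, take the gradient with respect to that vector, and rebuild the struct inside the closure. This is a minimal sketch, not a ForwardDiff feature: it assumes `Foo` can be rebuilt by passing its fields positionally to the default constructor, and the helpers `struct_to_vec` and `vec_to_struct` are hypothetical names introduced here for illustration.

```julia
using ForwardDiff

struct Foo
    x::Number
    t::Number
    c::Number
end

bar(f::Foo) = f.x - f.c * f.t

# Hypothetical helper: collect every field of a struct into a Vector{Float64}
struct_to_vec(s) = Float64[getfield(s, n) for n in fieldnames(typeof(s))]

# Hypothetical helper: rebuild a struct of type T from a vector via its positional constructor
vec_to_struct(::Type{T}, v::AbstractVector) where {T} = T(v...)

foo = Foo(2, 3, 3e8)

# ForwardDiff calls the closure with a vector of Dual numbers; the abstract
# `Number` fields of Foo are able to hold them
g = ForwardDiff.gradient(v -> bar(vec_to_struct(Foo, v)), struct_to_vec(foo))

# Label each partial derivative with its field name, similar to Zygote's NamedTuple result
grad = NamedTuple{fieldnames(Foo)}(Tuple(g))
println(grad)  # (x = 1.0, t = -3.0e8, c = -3.0)
```

The rebuild step relies on the fields accepting `ForwardDiff.Dual` values. If `Foo` declared concrete fields such as `x::Float64`, constructing it from Duals would fail; a parametric definition like `struct Foo{T<:Number}` would keep the sketch working while also giving concretely typed fields.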