@@ -178,3 +178,71 @@ function jacobian(f, x::AbstractVector)
178178end
179179
"""
    hessian(f, x)

Compute the Hessian of `f` at `x`, formed as the Jacobian of the (nested) gradient.
"""
hessian(f, x) = jacobian(y -> gradient(f, y, nest = true)[1], x)
181+
182+ using Functors: fmap, functor
183+ using Optimisers: _trainable, isnumeric
184+
185+ """
186+ withgradient(f, xs...)
187+
188+ This computes the value `f(xs...)` and the gradient with respect to `xs`.
189+ However, it differs from `gradient` in several other respects:
190+ * It will recurse into `xs` using `fmap`, and thus like Zygote's "explicit mode" it
191+ returns a tree-like gradient matching the shape of a Flux model.
192+ This recursion obeys restrictions imposed by `Optimisers.trainable`, if defined.
193+ * Only objects satisfying `Optimisers.isnumeric` are regarded as parameters,
194+ thus in particular integers are ignored.
195+ * Returns plain arrays, not tracked. Uses `nothing` as a strong zero gradient, like Zygote.
196+
197+ # Examples
198+ ```
199+ julia> nt = (vec = [1.0, 2.0], mat = [4.0;;], fun = sin);
200+
201+ julia> withgradient(nt, 2) do x, p
202+ sum(abs2, x.vec) ^ p
203+ end
204+ (val = 25.0, grad = ((vec = [20.0, 40.0], mat = [0.0;;], fun = nothing), nothing))
205+
206+ julia> using Flux
207+
208+ julia> model = Chain(Dense(2 => 1, tanh), Dense(1 => 1, bias=false));
209+
210+ julia> withgradient(model, rand(Float32, 2)) do m, x
211+ sum(abs2, m(x))
212+ end
213+ (val = 0.035716165f0, grad = ((layers = ((weight = Float32[-0.4241869 -0.16741231], bias = Float32[-0.5529184], σ = nothing), (weight = Float32[-0.04804218;;], bias = nothing, σ = nothing)),), Float32[0.12706584, -0.08858479]))
214+ ```
215+ """
216+ function withgradient (f, xs... )
217+ pxs = fmap (param, xs; exclude = isnumeric, walk = _trainable_walk)
218+ l = f (pxs... )
219+ losscheck (l)
220+ l isa TrackedReal || return (val = l, grad = nothing )
221+ @interrupts back! (l)
222+ (val = data (l), grad = rec_grad (pxs))
223+ end
224+
# Custom `fmap` walk: recurse only into the fields that `Optimisers.trainable`
# exposes, leaving every other child untouched.
function _trainable_walk(f, x)
    children, rebuild = functor(x)
    isempty(children) && return x
    # `_trainable(x)` matches `children` but holds `nothing` for frozen fields.
    mapped = map(f, _trainable(x))
    recombined = map(children, merge(children, mapped)) do old, new
        new === nothing ? old : new
    end
    return rebuild(recombined)  # reconstruct the whole object
end
_trainable_walk(f, x::Tuple) = map(f, x)
234+
# Plain recursion is simpler than `fmap` for pulling the gradients back out:
rec_grad(x::TrackedArray) = grad(x)
rec_grad(x::TrackedReal) = grad(x)
# Untracked numbers contribute no gradient; `nothing` is the strong zero.
rec_grad(x::AbstractArray{<:Number}) = nothing
rec_grad(x::Number) = nothing

rec_grad(x::Union{Tuple,NamedTuple,AbstractArray}) = map(rec_grad, x)
rec_grad(::Tuple{}) = nothing
rec_grad(::NamedTuple{(),Tuple{}}) = nothing
# Generic structs: recurse over every field, keyed by field name.
function rec_grad(x::T) where {T}
    names = fieldnames(T)
    isempty(names) && return nothing
    return NamedTuple{names}(map(fn -> rec_grad(getfield(x, fn)), names))
end