@@ -177,47 +177,39 @@ Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1),
177177A = rand (n, n);
178178dA = zeros (n, n);
179179b1 = rand (n);
180- for alg in (
180+
181+ function fnice (A, b, alg)
182+ prob = LinearProblem (A, b)
183+ sol1 = solve (prob, alg)
184+ return sum (sol1. u)
185+ end
186+
187+ @testset for alg in (
181188 LUFactorization (),
182189 RFLUFactorization () # KrylovJL_GMRES(), fails
183190)
184- @show alg
185- function fb (b)
186- prob = LinearProblem (A, b)
187-
188- sol1 = solve (prob, alg)
191+ fb_closure = b -> fnice (A, b, alg)
189192
190- sum (sol1. u)
191- end
192- fb (b1)
193-
194- fd_jac = FiniteDiff. finite_difference_jacobian (fb, b1) |> vec
193+ fd_jac = FiniteDiff. finite_difference_jacobian (fb_closure, b1) |> vec
195194 @show fd_jac
196195
197196 en_jac = map (onehot (b1)) do db1
198- eres = Enzyme. autodiff (Forward, fb, Duplicated ( copy (b1 ), db1))
199- eres[ 1 ]
197+ return only ( Enzyme. autodiff (set_runtime_activity (Forward ), fnice,
198+ Const (A), Duplicated (b1, db1), Const (alg)))
200199 end |> collect
201200 @show en_jac
202201
203202 @test en_jac≈ fd_jac rtol= 1e-4
204203
205- function fA (A)
206- prob = LinearProblem (A, b1)
207-
208- sol1 = solve (prob, alg)
204+ fA_closure = A -> fnice (A, b1, alg)
209205
210- sum (sol1. u)
211- end
212- fA (A)
213-
214- fd_jac = FiniteDiff. finite_difference_jacobian (fA, A) |> vec
206+ fd_jac = FiniteDiff. finite_difference_jacobian (fA_closure, A) |> vec
215207 @show fd_jac
216208
217209 en_jac = map (onehot (A)) do dA
218- eres = Enzyme. autodiff (Forward, fA, Duplicated ( copy (A ), dA))
219- eres[ 1 ]
220- end |> collect
210+ return only ( Enzyme. autodiff (set_runtime_activity (Forward ), fnice,
211+ Duplicated (A, dA), Const (b1), Const (alg)))
212+ end |> collect |> (x -> reshape (x, n, n))
221213 @show en_jac
222214
223215 @test en_jac≈ fd_jac rtol= 1e-4
0 commit comments