@@ -92,28 +92,10 @@ function ChainRulesCore.rrule(::typeof(get_identity_vecs), M::Int)
9292 get_identity_vecs (M), _ -> (NoTangent (), NoTangent ())
9393end
9494reduce_hcat (vs) = reduce (hcat, vs)
95- # function ChainRulesCore.rrule(::typeof(reduce_hcat), vs::Vector{<:Vector})
96- # return reduce_hcat(vs), Δ -> begin
97- # return NoTangent(), [Δ[:, i] for i in 1:size(Δ, 2)]
98- # end
99- # end
10095
10196const fdm = FiniteDifferences. central_fdm (5 , 1 )
10297
10398function _jacobian (f, x1, x2)
104- # val, pb = Zygote.pullback(f, x1, x2)
105- # if val isa Vector
106- # M = length(val)
107- # vecs = get_identity_vecs(M)
108- # cotangents = map(pb, vecs)
109- # Jt = reduce_hcat(map(last, cotangents))
110- # return copy(Jt')
111- # elseif val isa Real
112- # Jt = last(pb(1.0))
113- # return copy(Jt')
114- # else
115- # throw(ArgumentError("Output type not supported."))
116- # end
11799 ẏs = map (eachindex (x2)) do n
118100 return fdm (zero (eltype (x2))) do ε
119101 xn = x2[n]
@@ -124,49 +106,6 @@ function _jacobian(f, x1, x2)
124106 end
125107 return reduce (hcat, ẏs)
126108end
127- # function ChainRulesCore.rrule(::typeof(_jacobian), f, x1, x2)
128- # (val, pb), _pb_pb = Zygote.pullback(Zygote.pullback, f, x1, x2)
129- # M = length(val)
130- # if val isa Vector
131- # vecs = get_identity_vecs(M)
132- # _pb = (pb, v) -> last(pb(v))
133- # co1, pb_pb = Zygote.pullback(_pb, pb, first(vecs))
134- # cotangents = vcat([co1], last.(map(pb, @view(vecs[2:end]))))
135- # Jt, hcat_pb = Zygote.pullback(reduce_hcat, cotangents)
136- # return copy(Jt'), Δ -> begin
137- # temp = hcat_pb(Δ')[1]
138- # co_pb = map(temp) do t
139- # first(pb_pb(t))
140- # end
141- # co_f_x = _pb_pb.(tuple.(Ref(nothing), co_pb))
142- # co_f = sum(getindex.(co_f_x, 1))
143- # co_x1 = sum(getindex.(co_f_x, 2))
144- # co_x2 = sum(getindex.(co_f_x, 3))
145- # return NoTangent(), co_f, co_x1, co_x2
146- # end
147- # elseif val isa Real
148- # println(1)
149- # _pb = (pb, v) -> pb(v)[end]
150- # println(2)
151- # @show _pb(pb, 1.0)
152- # Jt, pb_pb = Zygote.pullback(_pb, pb, 1.0)
153- # println(3)
154- # return copy(Jt'), Δ -> begin
155- # println(4)
156- # @show vec(Δ)
157- # @show Δ
158- # @show pb_pb(Δ')
159- # co_pb = first(pb_pb(vec(Δ)))
160- # co_f_x = _pb_pb((nothing, co_pb))
161- # co_f = co_f_x[1]
162- # co_x1 = co_f_x[2]
163- # co_x2 = co_f_x[3]
164- # return NoTangent(), co_f, co_x1, co_x2
165- # end
166- # else
167- # throw(ArgumentError("Output type not supported."))
168- # end
169- # end
170109
171110function (f:: RandomFunction )(x)
172111 mup = mean (f. p)
0 commit comments