@@ -2,6 +2,7 @@ module ReliabilityOptimization
22
33using ImplicitDifferentiation, Zygote, LinearAlgebra, ChainRulesCore, SparseArrays
44using UnPack, NonconvexIpopt, Statistics, Distributions, Reexport, DistributionsAD
5+ using FiniteDifferences, StaticArraysCore
56@reexport using LinearAlgebra
67@reexport using Statistics
78export RandomFunction, FORM, RIA, MvNormal
1516struct FORM{M}
1617 method:: M
1718end
18- struct RIA end
19+ struct RIA{A, O}
20+ optim_alg:: A
21+ optim_options:: O
22+ end
23+ RIA () = RIA (IpoptAlg (), IpoptOptions (print_level = 0 , max_wall_time = 1.0 ))
1924
20- function get_forward (f, p, :: FORM{<:RIA} )
25+ function get_forward (f, p, method:: FORM{<:RIA} )
26+ alg, options = method. method. optim_alg, method. method. optim_options
2127 function forward (x)
2228 # gets an objective function of p
2329 obj = pc -> begin
@@ -39,9 +45,9 @@ function get_forward(f, p, ::FORM{<:RIA})
3945 add_eq_constraint! (innerOptModel, constr)
4046 result = optimize (
4147 innerOptModel,
42- IpoptAlg () ,
48+ alg ,
4349 [mean (p); 0.0 ],
44- options = IpoptOptions (print_level = 0 ) ,
50+ options = options ,
4551 )
4652 return vcat (result. minimizer, result. problem. mult_g[1 ])
4753 end
@@ -54,7 +60,7 @@ function get_conditions(f, ::FORM{<:RIA})
5460 c = pcmult[end - 1 ]
5561 mult = pcmult[end ]
5662 return vcat (
57- 2 * p + Zygote . pullback (p -> f (x, p) , p)[ 2 ](mult)[ 1 ] ,
63+ 2 * p + vec ( _jacobian (f, x , p)) * mult ,
5864 2 c - mult,
5965 f (x, p) .- c,
6066 )
7985_vec (x:: Real ) = [x]
8086_vec (x) = x
8187
82- function _jacobian (f, x)
83- val, pb = Zygote. pullback (f, x)
84- M = length (val)
85- vecs = [Vector (sparsevec ([i], [true ], M)) for i in 1 : M]
86- Jt = reduce (hcat, first .(pb .(vecs)))
87- return copy (Jt' )
88+ function get_identity_vecs (M)
89+ return [Vector (sparsevec ([i], [1.0 ], M)) for i in 1 : M]
90+ end
91+ function ChainRulesCore. rrule (:: typeof (get_identity_vecs), M:: Int )
92+ get_identity_vecs (M), _ -> (NoTangent (), NoTangent ())
93+ end
94+ reduce_hcat (vs) = reduce (hcat, vs)
95+ # function ChainRulesCore.rrule(::typeof(reduce_hcat), vs::Vector{<:Vector})
96+ # return reduce_hcat(vs), Δ -> begin
97+ # return NoTangent(), [Δ[:, i] for i in 1:size(Δ, 2)]
98+ # end
99+ # end
100+
101+ const fdm = FiniteDifferences. central_fdm (5 , 1 )
102+
103+ function _jacobian (f, x1, x2)
104+ # val, pb = Zygote.pullback(f, x1, x2)
105+ # if val isa Vector
106+ # M = length(val)
107+ # vecs = get_identity_vecs(M)
108+ # cotangents = map(pb, vecs)
109+ # Jt = reduce_hcat(map(last, cotangents))
110+ # return copy(Jt')
111+ # elseif val isa Real
112+ # Jt = last(pb(1.0))
113+ # return copy(Jt')
114+ # else
115+ # throw(ArgumentError("Output type not supported."))
116+ # end
117+ ẏs = map (eachindex (x2)) do n
118+ return fdm (zero (eltype (x2))) do ε
119+ xn = x2[n]
120+ xcopy = vcat (x2[1 : n- 1 ], xn + ε, x2[n+ 1 : end ])
121+ ret = copy (f (x1, xcopy)) # copy required incase `f(x)` returns something that aliases `x`
122+ return ret
123+ end
124+ end
125+ return reduce (hcat, ẏs)
88126end
127+ # function ChainRulesCore.rrule(::typeof(_jacobian), f, x1, x2)
128+ # (val, pb), _pb_pb = Zygote.pullback(Zygote.pullback, f, x1, x2)
129+ # M = length(val)
130+ # if val isa Vector
131+ # vecs = get_identity_vecs(M)
132+ # _pb = (pb, v) -> last(pb(v))
133+ # co1, pb_pb = Zygote.pullback(_pb, pb, first(vecs))
134+ # cotangents = vcat([co1], last.(map(pb, @view(vecs[2:end]))))
135+ # Jt, hcat_pb = Zygote.pullback(reduce_hcat, cotangents)
136+ # return copy(Jt'), Δ -> begin
137+ # temp = hcat_pb(Δ')[1]
138+ # co_pb = map(temp) do t
139+ # first(pb_pb(t))
140+ # end
141+ # co_f_x = _pb_pb.(tuple.(Ref(nothing), co_pb))
142+ # co_f = sum(getindex.(co_f_x, 1))
143+ # co_x1 = sum(getindex.(co_f_x, 2))
144+ # co_x2 = sum(getindex.(co_f_x, 3))
145+ # return NoTangent(), co_f, co_x1, co_x2
146+ # end
147+ # elseif val isa Real
148+ # println(1)
149+ # _pb = (pb, v) -> pb(v)[end]
150+ # println(2)
151+ # @show _pb(pb, 1.0)
152+ # Jt, pb_pb = Zygote.pullback(_pb, pb, 1.0)
153+ # println(3)
154+ # return copy(Jt'), Δ -> begin
155+ # println(4)
156+ # @show vec(Δ)
157+ # @show Δ
158+ # @show pb_pb(Δ')
159+ # co_pb = first(pb_pb(vec(Δ)))
160+ # co_f_x = _pb_pb((nothing, co_pb))
161+ # co_f = co_f_x[1]
162+ # co_x1 = co_f_x[2]
163+ # co_x2 = co_f_x[3]
164+ # return NoTangent(), co_f, co_x1, co_x2
165+ # end
166+ # else
167+ # throw(ArgumentError("Output type not supported."))
168+ # end
169+ # end
89170
90171function (f:: RandomFunction )(x)
91172 mup = mean (f. p)
92173 covp = cov (f. p)
93174 p0 = getp0 (f. f, x, f. p, f. method)
94- dfdp0 = _jacobian (p -> f. f (x, p) , p0)
175+ dfdp0 = _jacobian (f. f, x , p0)
95176 fp0 = f. f (x, p0)
96- muf = _vec (fp0) + dfdp0 * (mup - p0)
177+ muf = _vec (fp0) . + dfdp0 * (mup - p0)
97178 covf = dfdp0 * covp * dfdp0'
98179 return MvNormal (muf, covf)
99180end
100181
182+ # necessary type piracy FiniteDifferences._estimate_magnitudes uses this constructor which Zygote struggles to differentiate on its own
183+ function ChainRulesCore. rrule (:: typeof (StaticArraysCore. SVector{3 }), x1:: T , x2:: T , x3:: T ) where {T}
184+ StaticArraysCore. SVector {3} (x1, x2, x3), Δ -> (NoTangent (), Δ[1 ], Δ[2 ], Δ[3 ])
101185end
186+
187+ end
0 commit comments