@@ -64,6 +64,14 @@ There is also [`Optimisers.update!`](@ref) which similarly returns a new model a
 but is free to mutate arrays within the old one for efficiency.
 The method of `apply!` for each rule is likewise free to mutate arrays within its state;
 they are defensively copied when this rule is used with `update`.
+For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`:
+
+```julia
+Base.summarysize(model) / 1024^2  # about 45MB
+Base.summarysize(state) / 1024^2  # about 90MB
+```
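The factor of two can also be seen structurally on a tiny model. This is a sketch, assuming the `setup` and `Leaf` described earlier; the exact contents of `Leaf.state` are an implementation detail and may change between versions:

```julia
using Optimisers

tiny = (weight = rand(Float32, 3, 3),)     # stand-in for a real model
state = Optimisers.setup(Adam(), tiny)

# Each parameter array gets a Leaf; for Adam its state holds two momentum
# arrays of the same size as the parameter, plus the scalar β powers:
mt, vt, βt = state.weight.state
size(mt) == size(vt) == size(tiny.weight)  # the source of the 2× memory
```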
 
 Optimisers.jl does not depend on any one automatic differentiation package,
 but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
@@ -72,6 +80,7 @@ This `∇model` is another tree structure, rather than the dictionary-like objec
 Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
 [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
 
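Putting the pieces together, an explicit-mode step looks roughly like this — a sketch on a toy NamedTuple model rather than a real Flux network; `gradient` returns one result per argument, so `[1]` picks out the tree matching `model`:

```julia
using Optimisers, Zygote

model = (weight = rand(Float32, 3, 3), bias = zeros(Float32, 3))
x = rand(Float32, 3)

state = Optimisers.setup(Adam(), model)

# ∇model is a NamedTuple mirroring the structure of model:
∇model = Zygote.gradient(m -> sum(m.weight * x .+ m.bias), model)[1]

state, model = Optimisers.update(state, model, ∇model)
```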
+
 ## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl)
 
 Yota is another modern automatic differentiation package, an alternative to Zygote.
@@ -89,40 +98,6 @@ loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
 end;
 ```
 
-Unfortunately this example doesn't actually run right now. This is the error:
-```
-julia> loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
-           sum(m(x))
-       end;
-┌ Error: Failed to compile rrule for #233(Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64, relu), Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64)),), extract details via:
-│   (f, args) = Yota.RRULE_VIA_AD_STATE[]
-└ @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:160
-ERROR: No deriative rule found for op %3 = getfield(%1, :x)::Array{Float32, 4}, try defining it using
-
-    ChainRulesCore.rrule(::typeof(getfield), ::Flux.var"#233#234"{Array{Float32, 4}}, ::Symbol) = ...
-
-Stacktrace:
- [1] error(s::String)
-   @ Base ./error.jl:35
- [2] step_back!(tape::Umlaut.Tape{Yota.GradCtx}, y::Umlaut.Variable)
-   @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:197
- [3] back!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
-   @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:238
- [4] gradtape!(tape::Umlaut.Tape{Yota.GradCtx}; seed::Symbol)
-   @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:249
- [5] gradtape(f::Flux.var"#233#234"{Array{Float32, 4}}, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}; ctx::Yota.GradCtx, seed::Symbol)
-   @ Yota ~/.julia/packages/Yota/GIFMf/src/grad.jl:276
- [6] make_rrule(f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
-   @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:109
- [7] rrule_via_ad(#unused#::Yota.YotaRuleConfig, f::Function, args::Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}})
-   @ Yota ~/.julia/packages/Yota/GIFMf/src/cr_api.jl:153
-...
-
-(jl_GWa2lX) pkg> st
-Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GWa2lX/Project.toml`
-⌃ [587475ba] Flux v0.13.4
-  [cd998857] Yota v0.7.4
-```
 
 ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
 
@@ -163,6 +138,14 @@ y, lux_state = Lux.apply(lux_model, images, params, lux_state);
 Besides the parameters stored in `params` and gradually optimised, any other model state
 is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.)
 This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.
+
+```julia
+Base.summarysize(lux_model) / 1024    # just 2KB
+Base.summarysize(params) / 1024^2     # about 45MB, same as Flux model
+Base.summarysize(lux_state) / 1024    # 40KB
+Base.summarysize(opt_state) / 1024^2  # about 90MB, with Adam
+```
+
 If you are certain there is no model state, then the gradient calculation can
 be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:
 
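A sketch of that simplification, re-using the `lux_model`, `images`, `params`, `lux_state` and `opt_state` from the snippet above; `Lux.apply` still returns `(y, new_state)`, so `first` discards the state we are assuming does not matter:

```julia
using Zygote

# Differentiate only with respect to params; the model and its (assumed inert)
# state are closed over rather than tracked:
∇params = Zygote.gradient(p -> sum(first(Lux.apply(lux_model, images, p, lux_state))), params)[1]

opt_state, params = Optimisers.update(opt_state, params, ∇params)
```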