```julia

using Flux, Metalhead, Zygote, Optimisers

model = Metalhead.ResNet(18) |> gpu  # define a model to train
image = rand(Float32, 224, 224, 3, 1) |> gpu;  # dummy data
@show sum(model(image));  # dummy loss function

rule = Optimisers.Adam()  # use the Adam optimiser with its default settings
state = Optimisers.setup(rule, model);  # initialise this optimiser's momentum etc.

∇model, _ = gradient(model, image) do m, x  # calculate the gradients
    sum(m(x))
end;

state, model = Optimisers.update(state, model, ∇model);
@show sum(model(image));  # reduced

```
Notice that `update` returns both a new model and a new state: internally it walks over the
tree formed by the model and updates the parameters using the gradients.

There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
but is free to mutate arrays within the old one for efficiency.
(The method of `apply!` above is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.)
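
As a rough sketch of how these pieces might fit into a complete training loop, re-using `model` and
`state` from above, one could write something like the following. The function name `train!` and the
`dataset` iterator are hypothetical, and the loss is the same dummy `sum` as before:

```julia
# Hypothetical sketch only: `dataset` stands for some iterator of input batches.
function train!(state, model, dataset)
    for x in dataset
        ∇model, _ = Zygote.gradient((m, d) -> sum(m(d)), model, x)  # gradient w.r.t. the model only
        state, model = Optimisers.update!(state, model, ∇model)     # may re-use arrays inside state & model
    end
    return state, model
end
```
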
For `Adam()`, there are two momenta per parameter, thus `state` is about twice the size of `model`:

```julia
Base.summarysize(model) / 1024^2  # about 45MB
Base.summarysize(state) / 1024^2  # about 90MB
```
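
To see those two momenta concretely, one can set up the same rule for a single parameter array.
This is just an illustrative aside (the array `x` below is unrelated to the ResNet example), and the
exact printed form of the `Leaf` may differ between versions:

```julia
x = rand(Float32, 3)                           # a single parameter array
leaf = Optimisers.setup(Optimisers.Adam(), x)  # a Leaf pairing the rule with its state
leaf.state                                     # two zero arrays the same size as x, plus a tuple tracking powers of β
```
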
Optimisers.jl does not depend on any one automatic differentiation package,
but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
Note that `update` always wants the gradient from Zygote's "explicit" mode, as shown above.
This `∇model` is another tree structure, rather than the dictionary-like object from
Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
[Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
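
For a rough side-by-side illustration of that difference, using the `model` and `image` from above
(the names `∇model` and `grads` here are just for illustration):

```julia
# Explicit mode: differentiate with respect to the model itself.
# The result is a tree of nested NamedTuples, mirroring the model's structure.
∇model, _ = Zygote.gradient((m, x) -> sum(m(x)), model, image)

# Implicit mode: differentiate with respect to a collection of parameter arrays.
# The result is a dictionary-like Grads object, which Optimisers.jl cannot use directly.
grads = Zygote.gradient(() -> sum(model(image)), Flux.params(model))
```
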
## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl)

Yota is another modern automatic differentiation package, an alternative to Zygote.

Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`),
but this gradient also includes a component for the loss function itself.
To extract what Optimisers.jl needs, you can write (for the Flux model above):

```julia
using Yota

loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
    sum(m(x))
end;

# Or else, this may save computing ∇image:
loss, (_, ∇model) = grad(m -> sum(m(image)), model);
```

## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)

The main design difference of Lux from Flux is that the tree of parameters is separate from
the layer structure. It is these parameters which `setup` and `update` need to know about.

Lux describes this separation of parameter storage from model description as "explicit" parameters.
Beware that it has nothing to do with Zygote's notion of "explicit" gradients.
(If the same model is written in Flux and Lux, `∇model` above and `∇params` below will be nearly
identical trees of nested `NamedTuple`s.)
```julia

using Lux, Boltz, Zygote, Optimisers

lux_model, params, lux_state = Boltz.resnet(:resnet18) |> gpu;  # define and initialise model
images = rand(Float32, 224, 224, 3, 4) |> gpu;  # batch of dummy data
y, lux_state = Lux.apply(lux_model, images, params, lux_state);  # run the model
@show sum(y);  # initial dummy loss

rule = Optimisers.Adam()
opt_state = Optimisers.setup(rule, params);  # optimiser state based on model parameters

(loss, lux_state), back = Zygote.pullback(params, images) do p, x
    y, st = Lux.apply(lux_model, x, p, lux_state)
    sum(y), st  # return both the loss, and the updated lux_state
end;
∇params, _ = back((one.(loss), nothing));  # gradient of only the loss, with respect to parameter tree
loss == sum(y)  # not yet changed

opt_state, params = Optimisers.update!(opt_state, params, ∇params);

y, lux_state = Lux.apply(lux_model, images, params, lux_state);
@show sum(y);  # now reduced

```
Besides the parameters stored in `params` and gradually optimised, any other model state
is stored in `lux_state`, and updated by `Lux.apply`. (In this example, BatchNorm has state.)
This is completely unrelated to Optimisers.jl's state, although designed in a similar spirit.
```julia
Base.summarysize(lux_model) / 1024  # just 2KB
Base.summarysize(params) / 1024^2  # about 45MB, same as Flux model
Base.summarysize(lux_state) / 1024  # 40KB
Base.summarysize(opt_state) / 1024^2  # about 90MB, with Adam
```
If you are certain there is no model state, then the gradient calculation can
be simplified to use `Zygote.gradient` instead of `Zygote.pullback`:
```julia
∇params, _ = gradient(params, images) do p, x
    y, _ = Lux.apply(lux_model, x, p, lux_state)  # discards new lux_state
    sum(y)
end;
```
## Non-`trainable` Parameters