@@ -24,7 +24,7 @@ function setup(rule, x; seen = Base.IdSet())
2424 end
2525end
2626
27- subtract! (x, x̄) = iswriteable (x) ? (x .= x .- x̄) : eltype (x).(x .- x̄)
27+ subtract! (x, x̄) = maywrite (x) ? (x .= x .- x̄) : eltype (x).(x .- x̄)
2828
2929update! (:: Nothing , x, :: Zero , :: Zero... ) = nothing , x
3030update! (:: Nothing , x, x̄s... ) = nothing , x
@@ -44,8 +44,8 @@ function update!(tree, x, x̄s...)
4444end
4545
4646function update (tree, x, x̄s... )
47- t′ = fmap (copy, tree; exclude = iswriteable )
48- x′ = fmap (copy, x; exclude = iswriteable )
47+ t′ = fmap (copy, tree; exclude = maywrite )
48+ x′ = fmap (copy, x; exclude = maywrite )
4949 update! (t′, x′, x̄s... )
5050end
5151
@@ -56,8 +56,17 @@ isnumeric(x::AbstractArray{<:Number}) = isleaf(x) # isleaf to allow for e.g. tr
5656isnumeric (x:: AbstractArray{<:Integer} ) = false
5757isnumeric (x) = false
5858
59- iswriteable (:: DenseArray ) = true # more elaborate versions are possible, wait until needed?
60- iswriteable (_) = false
59+ """
60+ maywrite(x) -> Bool
61+
62+ Should return `true` if we are completely sure that `update!` can write new
63+ values into `x`. Otherwise `false`, indicating a non-mutating path.
64+ For now, simply `x isa DenseArray` allowing `Array`, `CuArray`, etc.
65+ """
66+ maywrite (:: DenseArray ) = true # see https://github.com/FluxML/Optimisers.jl/issues/99 for discussion
67+ maywrite (_) = false
68+
69+ @deprecate iswriteable maywrite false # remove when releasing Optimisers@0.3
6170
6271"""
6372 trainable(x::Layer) -> NamedTuple
8493 @.. x = x + y
8594
8695Sometimes in-place broadcasting macro, for use in `apply!` rules.
87- If `iswriteable (x)` then it is just `@. x = rhs`, but if not, it becomes `x = @. rhs`.
96+ If `maywrite (x)` then it is just `@. x = rhs`, but if not, it becomes `x = @. rhs`.
8897"""
8998macro var".." (ex)
9099 Meta. isexpr (ex, :(= )) || throw (" the macro @.. only accepts assignment, like @.. x = y + z" )
91100 dst = esc (ex. args[1 ])
92101 src = esc (Broadcast. __dot__ (ex. args[2 ]))
93- :($ dst = if $ iswriteable ($ dst)
102+ :($ dst = if $ maywrite ($ dst)
94103 $ dst .= $ src
95104 else
96105 $ src
0 commit comments