Fix #62 (#70)

Open · wants to merge 8 commits into master
29 changes: 18 additions & 11 deletions src/destructure.jl
@@ -53,16 +53,20 @@ end
Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
Base.length(re::Restructure) = re.length

+struct Offset
+  i::Int
+end

# This flattens a model, and returns a web of offsets for later use:
function _flatten(x)
-  isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case
+  isnumeric(x) && return vcat(_vec(x)), Offset(0), length(x) # trivial case
  arrays = AbstractVector[]
  len = Ref(0)
  off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
    push!(arrays, _vec(y))
    o = len[]
    len[] = o + length(y)
-    o
+    Offset(o)
  end
  reduce(vcat, arrays), off, len[]
end
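As an aside (not part of the diff), the effect of wrapping offsets can be seen by calling the internal _flatten on a small model; one would expect roughly:

m = (x = [1.0, 2.0], y = (a = [3.0, 4.0],))
flat, off, len = Optimisers._flatten(m)
# flat == [1.0, 2.0, 3.0, 4.0]
# off  == (x = Offset(0), y = (a = Offset(2),))   (each leaf's starting position, no longer a bare Int)
# len  == 4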
@@ -85,16 +89,18 @@ function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _trai
  end
end

-_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1])
-_getat(y::AbstractArray, o::Int, flat::AbstractVector) =
-  ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes
+_getat(y::Number, off::Offset, flat::AbstractVector) = ProjectTo(y)(flat[off.i + 1])
+_getat(y::AbstractArray, off::Offset, flat::AbstractVector) =
+  ProjectTo(y)(reshape(flat[off.i .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes

function _trainable_biwalk(f, x, aux)
  ch, re = functor(typeof(x), x)
-  au, _ = functor(typeof(x), aux)
+  au = _aux_children(aux)
  _trainmap(f, ch, _trainable(x), au) |> re
end

+_aux_children(off) = functor(off)[1]  # children of the offset/auxiliary tree, discarding the reconstructor

function _trainmap(f, ch, tr, aux)
  map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
    isnothing(t) ? c : f(t, a)
@@ -103,13 +109,14 @@ end

function _Tangent_biwalk(f, x, aux) # use with prune = NoT
  ch, re = functor(typeof(x), x)
-  au, _ = functor(typeof(x), aux)
+  au = _aux_children(aux)
  y = _trainmap(f, ch, _trainable(x), au)
  y isa Tuple{} && return NoT
  p = ProjectTo(x)
  if p isa ProjectTo # e.g. Array, NamedTuple
    p(y)
  else # p === identity for unknown structs
+    y = backing(re(y)) # extract NamedTuple backing from re(y); required if x has children which aren't its own fields
Member:

Note to self, this I need to think about. Some of this complication was working around things that are now fixed in CRC.jl, if I remember right.

Contributor Author:

Yeah, admittedly this line took some trial and error and is a little bit above my pay-grade. I managed to convince myself, but perhaps there's something cleaner.

Member:

Ok, I think I finally understand what's going on. Sorry it took a while.

re constructs another Skip containing the gradient, and backing turns that into a NamedTuple with the same field names, which is what Tangent wants.

The only way I can see this failing is this: If the primal type's constructor is fussy about what types it can accept, then it may not be happy to accept something which is valid as its gradient. E.g. if there is only Skip(::AbstractLayer), and re tries to make one with a Tangent.
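For illustration only (not part of the diff), here is roughly what that re-plus-backing step does for a Skip-like type, assuming ChainRulesCore.backing of a struct returns a NamedTuple of its fields:

using Functors, ChainRulesCore

struct Skip{T}  # mirrors the test type added in test/runtests.jl below
  layers::T
  Skip(ls...) = new{typeof(ls)}(ls)
end
Functors.functor(::Type{<:Skip}, x) = x.layers, ls -> Skip(ls...)

sk = Skip([1.0, 2.0], [3.0, 4.0])
y  = ([0.0, 1.0], [0.0, 0.0])           # stand-in for the mapped gradient children
_, re = Functors.functor(sk)
nt = ChainRulesCore.backing(re(y))      # (layers = ([0.0, 1.0], [0.0, 0.0]),)
Tangent{typeof(sk), typeof(nt)}(nt)     # NamedTuple-backed, which is what Tangent expects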

Contributor Author:

No worries! Yes, I struggled with that edge case too. Unfortunately I think it's quite tricky to work around. For example, suppose you have a user-defined functor(m::MyModel) = (m.w,), w -> .... Then:

  1. In general there's no way to reconstruct MyModel (or even a NamedTuple of fields/values) without re, as you do not know the corresponding field name given only (m.w,), but
  2. As you say, if the primal constructor isn't sufficiently generic then it won't be able to store Tangent/Nothing/etc. values in its fields and will error before backing can unpack it again

Avoiding re would be ideal, but I think that would require functor to always return NamedTuples on custom structs. I noticed that this is the default in @functor, though, so maybe it's not such a painful requirement? In the meantime I can at least add a branch that would avoid re for structs that are functored to NamedTuples.
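A hypothetical sketch of that second failure mode (StrictWrap and its functor method are made up for illustration): a constructor with a tight field type cannot re-wrap a gradient stand-in, so re errors before backing is ever reached.

using Functors, ChainRulesCore

struct StrictWrap{T<:AbstractVector{<:Real}}  # fussy: only accepts real vectors
  w::T
end
Functors.functor(::Type{<:StrictWrap}, m) = (m.w,), w -> StrictWrap(only(w))

_, re = Functors.functor(StrictWrap([1.0, 2.0]))
re(([10.0, 20.0],))   # fine: StrictWrap([10.0, 20.0])
re((NoTangent(),))    # MethodError: NoTangent() is not an AbstractVector{<:Real}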

Member:

In fact there's another problem I didn't spot before, what a mess:

julia> ac = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0]);  # from tests: a,c are functor-ed, and only a is trainable

julia> v2, re2 = destructure(ac)
([1.0, 2.0], Restructure(TwoThirds, ..., 2))

julia> gradient(ac) do x  # with Tangent{typeof(x), typeof(y)}(y)
             w2, _ = destructure(x)
             w2[2]^2
           end
((a = [0.0, 4.0], b = nothing, c = [4.0, 5.0]),) 

# Same, with z = backing(re(y)) :
julia> gradient(ac) do x
             w2, _ = destructure(x)
             w2[2]^2
           end
┌ Info: last case
│   x = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])
│   y = (a = [0.0, 4.0], c = [4.0, 5.0])
└   z = NamedTuple{(:a, :b, :c), Tuple{Any, Any, Any}}(([0.0, 4.0], [3.0], [4.0, 5.0]))
((a = [0.0, 4.0], b = [3.0], c = [4.0, 5.0]),)

Contributor Author:

Oh yikes. That's a good example, hits all the pain points at once. If I'm understanding correctly, the gradient should be ((a = [0.0, 4.0], b = nothing, c = nothing),), right?

I think the problem is the _trainmap above; it populates the nothing values from _trainable (non-trainable fields) with the primal values, when they should be NoT. That's how the b and/or c values get back in there.

Member:

Yes, I think _trainmap needs to do something like isnothing(t) ? NoT : f(t, a) here. That's where c = [4.0, 5.0] is coming from.

But b = [3.0] is coming from this PR's trick of calling the reconstructor made by @functor:

julia> ch, re = Functors.functor(ac)
((a = [1.0, 2.0], c = [4.0, 5.0]), var"#1#2"{TwoThirds}(TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])))

julia> re((a = [10, 20], c = nothing))
TwoThirds([10, 20], [3.0], nothing)
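For concreteness, a minimal sketch of that _trainmap change (the helper name is hypothetical, and NoTangent() stands in for the NoT constant used in destructure.jl):

using ChainRulesCore: NoTangent

function _trainmap_tangent(f, ch, tr, aux)
  map(ch, tr, aux) do c, t, a  # t === nothing marks a non-trainable child
    isnothing(t) ? NoTangent() : f(t, a)  # drop the primal value instead of passing it through
  end
end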

Contributor Author (@jondeuce, May 1, 2022):

Gotcha. So on top of the modified _trainmap to fix c, one would still have to filter backing(re(y)) to replace repopulated primal values which aren't functor-ed with NoT in order to fix b.

EDIT: But, based on the output of Tangent{typeof(x), typeof(y)}(y), maybe the modified _trainmap alone would be enough and backing(re(y)) isn't needed after all, as Tangent will assign NoT to omitted fields in y automatically.

EDIT 2: Never mind, that would still fail for children which aren't fields, like Skip.
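One possible shape of that filtering step, for the common case where functor returns a NamedTuple (as @functor does); the helper name is hypothetical and this is only a sketch:

using ChainRulesCore: NoTangent

# Keep only fields that were functor children (the keys of y); anything the
# reconstructor re-populated from the primal (like b in TwoThirds) becomes NoTangent().
_filter_to_children(nt::NamedTuple, y::NamedTuple) =
  NamedTuple{keys(nt)}(map(k -> k in keys(y) ? nt[k] : NoTangent(), keys(nt)))

# With nt = backing(re(y)) = (a = [0.0, 4.0], b = [3.0], c = NoTangent())
# and  y  = (a = [0.0, 4.0], c = NoTangent()):
# _filter_to_children(nt, y) == (a = [0.0, 4.0], b = NoTangent(), c = NoTangent())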

Contributor Author:

Alright pushed something that works for both Skip and your TwoThirds example (modified _trainmap + filtering backing(re(y))). But since it uses re it would still fail for fussy constructors.

    Tangent{typeof(x), typeof(y)}(y)
  end
end
@@ -126,23 +133,23 @@ ChainRulesCore.@non_differentiable _zero(x)
function _grad!(x, dx, off, flat::AbstractVector)
  x′, _ = functor(typeof(x), x)
  dx′, _ = functor(typeof(x), base(dx))
-  off′, _ = functor(typeof(x), off)
+  off′ = _aux_children(off)
  for (xᵢ, dxᵢ, oᵢ) in zip(x′, dx′, off′)
    flat = _grad!(xᵢ, dxᵢ, oᵢ, flat)
  end
  flat
end
-function _grad!(x, dx, off::Integer, flat::AbstractVector{T}) where T
+function _grad!(x, dx, off::Offset, flat::AbstractVector{T}) where T
  dx_un = unthunk(dx)
  T2 = promote_type(T, eltype(dx_un))
  if T != T2 # then we must widen the type
    flat = copyto!(similar(flat, T2), flat)
  end
-  @views flat[off .+ (1:length(x))] .+= vec(dx_un) # must visit all tied nodes
+  @views flat[off.i .+ (1:length(x))] .+= vec(dx_un) # must visit all tied nodes
  flat
end
_grad!(x, dx::Zero, off, flat::AbstractVector) = flat
-_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = flat # ambiguity
+_grad!(x, dx::Zero, off::Offset, flat::AbstractVector) = flat # ambiguity

# These are only needed for 2nd derivatives:
function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
40 changes: 40 additions & 0 deletions test/destructure.jl
@@ -181,6 +181,46 @@ end
end == ([17,17,4,4],) # Flux gave ([4.0, 4.0, 13.0, 13.0],)
end

@testset "issue 62" begin
# Flux.Chain used to have children which aren't its own fields, which Skip immitates.

sk = Skip([1.0, 2.0], (x=3, y=[4.0, 5.0]))
@test fmap(identity, sk) == sk

gk = gradient(x -> sum(x[2].y), sk)[1]
@test fmap(Zygote.accum, sk, gk) isa Skip # this relies on functor(typeof(x), dx)

st = fmapstructure(identity, sk)
@test st isa Tuple{Vector, NamedTuple}
@test_throws Exception fmap(+, sk, st) # this fails because of functor(typeof(x), dx)

v, re = destructure(sk)
@test v == [1,2,4,5]
@test re(10v) isa Skip
@test re(10v)[1] == [10, 20]

@test gradient(zero(v)) do w
re(w)[2].y[1]
end == ([0,0,1,0],)

@test gradient(sk) do x
w, _ = destructure(x)
w[1] + w[4]
end == ((layers = ([1.0, 0.0], (x = nothing, y = [0.0, 1.0])),),)
#=

ERROR: ArgumentError: Tangent for the primal Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}} should be backed by a NamedTuple type, not by Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}}.
Stacktrace:
[1] _backing_error(P::Type, G::Type, E::Type)
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_types/tangent.jl:62
[2] ChainRulesCore.Tangent{Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}}, Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}}}(backing::Tuple{Vector{Float64}, ChainRulesCore.Tangent{NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}, NamedTuple{(:x, :y), Tuple{ChainRulesCore.NoTangent, Vector{Float64}}}}})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_types/tangent.jl:36
[3] _Tangent_biwalk(f::Function, x::Skip{Tuple{Vector{Float64}, NamedTuple{(:x, :y), Tuple{Int64, Vector{Float64}}}}}, aux::Tuple{Int64, NamedTuple{(:x, :y), Tuple{Tuple{}, Int64}}})
@ Optimisers ~/.julia/dev/Optimisers/src/destructure.jl:116

=#
end

@testset "DiffEqFlux issue 699" begin
# The gradient of `re` is a vector into which we accumulate contributions, and the issue
# is that one contribution may have a wider type than `v`, especially for `Dual` numbers.
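As an aside (not part of the diff), the widening that comment refers to is handled by _grad! above; a hypothetical standalone illustration of the same pattern:

flat = zeros(Float32, 4)        # narrow eltype, like `v`
dx = [1.0, 2.0, 3.0, 4.0]       # a Float64 contribution, wider than Float32
T2 = promote_type(eltype(flat), eltype(dx))
if eltype(flat) != T2
  flat = copyto!(similar(flat, T2), flat)  # widen flat before accumulating
end
flat[1:4] .+= dx                # accumulates in the wider eltype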
17 changes: 17 additions & 0 deletions test/runtests.jl
@@ -13,6 +13,13 @@ struct TwoThirds a; b; c; end
Functors.@functor TwoThirds (a, c)
Optimisers.trainable(x::TwoThirds) = (a = x.a,)

struct Skip{T} # like Flux 0.12's Chain
  layers::T
  Skip(ls...) = new{typeof(ls)}(ls)
end
Base.getindex(x::Skip, i::Integer) = x.layers[i]
Functors.functor(::Type{<:Skip}, x) = x.layers, ls -> Skip(ls...)

@testset verbose=true "Optimisers.jl" begin
@testset verbose=true "Features" begin

@@ -165,6 +172,16 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
@test_throws ArgumentError Optimisers.setup(ADAMW(), m2)
end

@testset "issue 62" begin
m62 = (s = Skip([1.0, 2.0], Foo([3.0], false)), t = [4.0, 5.0])
s62 = Optimisers.setup(Descent(), m62)
g62 = gradient(m -> m.s[2].x[1] + 3 * m.t[2], m62)
s, m = Optimisers.update(s62, m62, g62...)
@test m.s isa Skip
@test m.s[2].x ≈ [2.9]
@test m.t ≈ [4, 4.7]
end

end
@testset verbose=true "Destructure" begin
include("destructure.jl")