Skip to content

Commit

Permalink
Utilize ChainRulesCore thunks (#966)
Browse files Browse the repository at this point in the history
* Don't force unthunking of ChainRulesCore thunks

Introduces @_adjoint_keepthunks to mark adjoints that should pass
thunks through.

* Use @_adjoint_keepthunks where appropriate

* Use wrap_chainrules_output in unthunk_tangent

* Fix unthunk_tangent for array of thunks

* Don't unthunk explicitly in unbroadcast

* Define unthunk_tangent for IdDict to support Params

* Make unthunk_tangent for IdDict non-differentiable

Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>

* Revert "Don't unthunk explicitly in unbroadcast"

This reverts commit 34865ea.

* Fix problems related to unthunk_tangent for IdDict

Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>

* Resolve duplicate rrule for unthunk_tangent with IdDict

* Make unthunk_tangent recurse into arrays

* Fix tests

* Unthunk in collect(::Generator)

* Update deps

* Disable thunks for 2nd order AD

* Temporarily use fork of CRC

* Remove hook

* Fix

* Cleanup

* Up deps

* Cleanup

* Remove extra unthunk_tangent

---------

Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
Co-authored-by: Anton Smirnov <tonysmn97@gmail.com>
  • Loading branch information
3 people authored Jan 4, 2025
1 parent bc6cd09 commit d1aa910
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 42 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ ZygoteTrackerExt = "Tracker"

[compat]
AbstractFFTs = "1.3.1"
ChainRules = "1.44.1"
ChainRulesCore = "1.9"
ChainRules = "1.72.2"
ChainRulesCore = "1.25.1"
ChainRulesTestUtils = "1"
Colors = "0.12, 0.13"
DiffRules = "1.4"
Expand Down
3 changes: 2 additions & 1 deletion src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ module Zygote
using LinearAlgebra, Statistics
using LinearAlgebra: copytri!, AbstractTriangular

import ZygoteRules
import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
literal_getproperty, literal_getfield, unthunk_tangent

using ChainRulesCore
using ChainRules: ChainRules, rrule, unthunk, canonicalize
using ChainRules: ChainRules, AbstractThunk, rrule, unthunk, canonicalize
using IRTools
using MacroTools, Requires
using MacroTools: @forward
Expand Down
16 changes: 14 additions & 2 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# TODO: Move some of this to ZygoteRules, or move the `unthunk_tangent` methods for
# Tuple and NamedTuple from ZygoteRules here?
#
# `unthunk_tangent` methods defined below, by argument type:
#   * `AbstractThunk`: forced with `unthunk`, then passed through
#     `wrap_chainrules_output`.
#   * `NTuple` of `Number`s and `AbstractArray` of `Number`s: returned unchanged.
#   * any other `AbstractArray`: `unthunk_tangent` is mapped over the elements.
#   * `IdDict`: a new `IdDict` is built with `unthunk_tangent` applied to every key
#     and every value.
function unthunk_tangent end
@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x
@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
# The `IdDict` method is declared non-differentiable.
@non_differentiable unthunk_tangent(::IdDict)


struct ZygoteRuleConfig{CTX<:AContext} <: RuleConfig{Union{HasReverseMode,NoForwardsMode}}
context::CTX
end
Expand Down Expand Up @@ -107,7 +118,6 @@ is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f)
Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally.
"""
@inline wrap_chainrules_output(x) = x
@inline wrap_chainrules_output(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) # For now we are just not going to deal with thunks
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing.
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing
Expand Down Expand Up @@ -261,7 +271,9 @@ function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs
_pullback(config.context, f_args...)
end

ad_pullback(Δ) = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args)
ad_pullback(Δ) = zygote2differential(
pb(wrap_chainrules_output(unthunk_tangent(Δ))),
f_args)
return y, ad_pullback
end

Expand Down
15 changes: 12 additions & 3 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ end
_pullback(f, args...) = _pullback(Context(), f, args...)

tailmemaybe(::Nothing) = nothing
tailmemaybe(x::Tuple) = Base.tail(x)
tailmemaybe(x::Tuple) = unthunk_tangent(Base.tail(x))

# unthunking is essentially an identity operation on a lazy value, but
# `@adjoint unthunk_tangent(x) = unthunk_tangent(x), ȳ -> (ȳ,)` is not enough to make
# nested AD work, so define
@adjoint tailmemaybe(xs::Tuple) = tailmemaybe(xs), x̄s -> ((nothing, x̄s...),)


"""
pullback(f, args...)
Expand Down Expand Up @@ -351,6 +357,9 @@ function copy!(x::AbstractVector, ps::Params)
x
end

# Force `x` with `unthunk` when it is an `AbstractThunk`; return any other value unchanged.
_maybe_unthunk(x::AbstractThunk) = unthunk(x)
_maybe_unthunk(x) = x

"""
Grads(...)
Expand Down Expand Up @@ -385,7 +394,7 @@ end

function Base.getindex(gs::Grads, x)
isbits(x) && error("Only reference types can be differentiated with `Params`.")
return gs.grads[x]
return _maybe_unthunk(gs.grads[x])
end

"""
Expand Down Expand Up @@ -468,7 +477,7 @@ function pullback(f, ps::Params)
cache(cx)[p] = nothing
end
back(Δ)
Grads(cx.cache, ps) # TODO make a copy
Grads(_maybe_unthunk(cx.cache), ps)
end
end

Expand Down
12 changes: 12 additions & 0 deletions src/compiler/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@ using IRTools: IR, Variable, Pipe, xcall, var, prewalk, postwalk,
insertafter!, finish, expand!, prune!, substitute!, substitute,
block, block!, branch!, return!, stmt, meta


# TODO: Temporary, to be removed when ChainRulesCore rrules are required to
# support thunks as an input and all instances of _adjoint_keepthunks in
# Zygote have been replaced by rrules.
#
# Both macros forward the definition `ex` to `ZygoteRules.gradm`. They differ only in
# the second argument (`false` vs. `true`); both pass `true` as the third argument.
# NOTE(review): presumably the second argument selects the mutating (`@adjoint!`-style)
# form and the third tells `gradm` to keep thunks -- confirm against ZygoteRules.gradm.
macro _adjoint_keepthunks(ex)
ZygoteRules.gradm(ex, false, true)
end
macro _adjoint_keepthunks!(ex)
ZygoteRules.gradm(ex, true, true)
end


@inline tuple_va(N, xs) = xs
@inline tuple_va(N, x, xs...) = (x, tuple_va(N, xs...)...)
@inline tuple_va(::Val{N}, ::Nothing) where N = ntuple(_ -> nothing, Val(N))
Expand Down
2 changes: 1 addition & 1 deletion src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator)
x̄ = reconstruct_if_dict(x̄, _keys) # return a dictionary if needed
(nothing, (f = f̄, iter = x̄),)
end
y, collect_pullback
y, collect_pullback ∘ unthunk_tangent
end

collect_if_dict(x::Dict) = collect(x), collect(keys(x))
Expand Down
3 changes: 2 additions & 1 deletion src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
end

function unbroadcast(x::AbstractArray, x̄)
function unbroadcast(x::AbstractArray, maybethunked_x̄)
x̄ = unthunk_tangent(maybethunked_x̄)
N = ndims(x̄)
if length(x) == length(x̄)
_project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors
Expand Down
52 changes: 29 additions & 23 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,23 @@ function accum(x::RefValue, y::RefValue)
return x
end

accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y))
accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y)

# Accumulation involving a thunk stays lazy: each method returns a new `@thunk` whose
# body unthunks the thunked argument(s) and then calls `accum` on the results.
# The third method covers the case where both arguments are thunks.
accum(x, y::AbstractThunk) = @thunk(accum(x, unthunk(y)))
accum(x::AbstractThunk, y) = @thunk(accum(unthunk(x), y))
accum(x::AbstractThunk, y::AbstractThunk) = @thunk(accum(unthunk(x), unthunk(y)))

# Core functions
@adjoint deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)
@_adjoint_keepthunks deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)

@adjoint (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing
@_adjoint_keepthunks (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing

@adjoint ifelse(cond::Bool, t, f) =
@_adjoint_keepthunks ifelse(cond::Bool, t, f) =
ifelse(cond, t, f),
Δ -> cond ? (nothing, Δ, zero(Δ)) : (nothing, zero(Δ), Δ)

@adjoint Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing)
@_adjoint_keepthunks Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing)

accum_param(::Context{false}, _, Δ) = Δ
@generated function accum_param(cx::Context, x, Δ)
Expand All @@ -70,11 +77,11 @@ end

unwrap(x) = x

@adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),)
@_adjoint_keepthunks unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),)

unwrap(ref, x) = x

@adjoint unwrap(ref, x) = unwrap(x), function (x̄)
@_adjoint_keepthunks unwrap(ref, x) = unwrap(x), function (x̄)
accum_global(__context__, ref, x̄)
(accum_param(__context__, x, x̄),)
end
Expand All @@ -88,7 +95,7 @@ function global_set(ref, val)
end
end

@adjoint! function global_set(ref, x)
@_adjoint_keepthunks! function global_set(ref, x)
global_set(ref, x), function (x̄)
gs = cache(__context__)
x̄ = accum(get(gs, ref, nothing), x̄)
Expand All @@ -101,9 +108,9 @@ end

using Base: tail

@adjoint tuple(xs...) = xs, identity
@_adjoint_keepthunks tuple(xs...) = xs, identity

@adjoint function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i}
@_adjoint_keepthunks function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i}
val = xs[i]
function back(Δ)
accum_param(__context__, val, Δ) === nothing && return
Expand All @@ -112,7 +119,7 @@ using Base: tail
val, back
end

@adjoint function getindex(xs::NTuple{N,Any}, i::Integer) where N
@_adjoint_keepthunks function getindex(xs::NTuple{N,Any}, i::Integer) where N
val = xs[i]
function back(Δ)
accum_param(__context__, val, Δ) === nothing && return
Expand All @@ -121,10 +128,10 @@ end
return val, back
end

@adjoint getindex(xs::NTuple{N,Any}, r::AbstractUnitRange) where N =
@_adjoint_keepthunks getindex(xs::NTuple{N,Any}, r::AbstractUnitRange) where N =
(xs[r], Δ -> (ntuple(j -> j in r ? Δ[findfirst(isequal(j), r)] : nothing, Val(N)), nothing))

@adjoint function getindex(xs::NTuple{N,Any}, r::AbstractVector) where N
@_adjoint_keepthunks function getindex(xs::NTuple{N,Any}, r::AbstractVector) where N
val = xs[r]
function back(Δ)
dxs = ntuple(Val(length(xs))) do x
Expand Down Expand Up @@ -155,18 +162,18 @@ function _pullback(cx::AContext, ::typeof(literal_indexed_iterate), xs::Tuple, :
end

# Needed for iteration lowering
@adjoint Core.getfield(xs::NTuple{N,Any}, i::Int) where N =
@_adjoint_keepthunks Core.getfield(xs::NTuple{N,Any}, i::Int) where N =
(xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing))

@adjoint Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Int) where {K,N} =
@_adjoint_keepthunks Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Int) where {K,N} =
(xs[i], Δ -> (NamedTuple{K}(ntuple(j -> i == j ? Δ : nothing, Val(N))), nothing))

@adjoint function Base.first(xs::Tuple)
@_adjoint_keepthunks function Base.first(xs::Tuple)
drest = map(_->nothing, tail(xs))
first(xs), Δ -> ((Δ, drest...),)
end

@adjoint Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),)
@_adjoint_keepthunks Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),)

_empty(x) = length(x)
_empty(x::Union{Tuple,NamedTuple}) = map(_->nothing, x)
Expand All @@ -188,7 +195,7 @@ end

unapply(t, xs) = _unapply(t, xs)[1]

@adjoint! function Core._apply(f, args...)
@_adjoint_keepthunks! function Core._apply(f, args...)
y, back = Core._apply(_pullback, (__context__, f), args...)
st = map(_empty, args)
y, function (Δ)
Expand All @@ -198,7 +205,7 @@ unapply(t, xs) = _unapply(t, xs)[1]
end
end

@adjoint! function Core._apply_iterate(::typeof(iterate), f, args...)
@_adjoint_keepthunks! function Core._apply_iterate(::typeof(iterate), f, args...)
y, back = Core._apply(_pullback, (__context__, f), args...)
st = map(_empty, args)
y, function (Δ)
Expand All @@ -223,7 +230,7 @@ end
@generated pair(::Val{k}, v, _=nothing) where k = :($k = v,)
@generated pair(::Val{k}, v, ::NamedTuple{keys}) where {k,keys} = k isa Int ? :($(getfield(keys, k)) = v,) : :($k = v,)

@adjoint function literal_getfield(x, ::Val{f}) where f
@_adjoint_keepthunks function literal_getfield(x, ::Val{f}) where f
val = getfield(x, f)
function back(Δ)
accum_param(__context__, val, Δ) === nothing && return
Expand Down Expand Up @@ -273,8 +280,7 @@ function _get!(default::Base.Callable, ch, x)
end
end


@adjoint! function setfield!(x, f, val)
@_adjoint_keepthunks! function setfield!(x, f, val)
y = setfield!(x, f, val)
g = grad_mut(__context__, x)
y, function (_)
Expand All @@ -290,13 +296,13 @@ end

Jnew{T}(g) where T = Jnew{T,typeof(g)}(g)

@adjoint! function __new__(T, args...)
@_adjoint_keepthunks! function __new__(T, args...)
x = __new__(T, args...)
g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
x, Jnew{T,typeof(g),false}(g)
end

@adjoint! function __splatnew__(T, args)
@_adjoint_keepthunks! function __splatnew__(T, args)
x = __splatnew__(T, args)
g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
x, Jnew{T,typeof(g),true}(g)
Expand Down
16 changes: 8 additions & 8 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ function ngradient(f, xs::AbstractArray...)
return grads
end

function gradcheck(f, xs...)
function gradcheck(f, xs...; rtol = 1e-5, atol = 1e-5)
grad_zygote = gradient(f, xs...)
grad_finite_difference = ngradient(f, xs...)
return all(isapprox.(grad_zygote, grad_finite_difference; rtol = 1e-5, atol = 1e-5))
return all(isapprox.(grad_zygote, grad_finite_difference; rtol = rtol, atol = atol))
end

gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
gradtest(f, xs::AbstractArray...; kwargs...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...; kwargs...)
gradtest(f, dims...; kwargs...) = gradtest(f, rand.(Float64, dims)...; kwargs...)

# utilities for using gradcheck with complex matrices
_splitreim(A) = (real(A),)
Expand Down Expand Up @@ -160,8 +160,8 @@ end
@test gradient(y, x, z) == ([1, 1, 2], nothing)

# https://github.com/FluxML/Zygote.jl/issues/376
_, back = Zygote._pullback(x->x[1]*im, randn(2))
@test back(1.0)[2] == real([-im, 0]) == [0, 0]
_, back = Zygote.pullback(x -> x[1] * im, randn(2))
@test back(1.0)[1] == real([-im, 0]) == [0, 0]

# _droplike
@test gradient(x -> sum(inv, x[1, :]'), ones(2, 2)) == ([-1 -1; 0 0],)
Expand Down Expand Up @@ -949,8 +949,8 @@ end
_hermsymtype(::Type{<:Symmetric}) = Symmetric
_hermsymtype(::Type{<:Hermitian}) = Hermitian

function _gradtest_hermsym(f, ST, A)
gradtest(_splitreim(collect(A))...) do (args...)
function _gradtest_hermsym(f, ST, A; kwargs...)
gradtest(_splitreim(collect(A))...; kwargs...) do (args...)
B = f(ST(_joinreim(_dropimaggrad.(args)...)))
return sum(_splitreim(B))
end
Expand Down
1 change: 0 additions & 1 deletion test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,5 +269,4 @@ end
@test sgs[d.b] ≈ fill(1.f0, size(d.b))
end


end

0 comments on commit d1aa910

Please sign in to comment.